Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pylib/cqlshlib/copyutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,7 @@ def format_value(self, val, cqltype):
formatted = formatter(val, cqltype=cqltype,
encoding=self.encoding, colormap=NO_COLOR_MAP, date_time_format=self.date_time_format,
float_precision=cqltype.precision, nullval=self.nullval, quote=False,
escape_backslash=False,
Copy link
Copy Markdown
Contributor

@bschoening bschoening May 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider an alternative approach where formatted_value() bypasses display formatting for text.

format_value():
...
if cqltype.type_name in ('text', 'varchar', 'ascii'):
    return val if val.isprintable() else None

decimal_sep=self.decimal_sep, thousands_sep=self.thousands_sep,
boolean_styles=self.boolean_styles)
return formatted
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add unit tests

Expand Down
41 changes: 24 additions & 17 deletions pylib/cqlshlib/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,8 +477,8 @@ def decode_zig_zag_64(n):


@formatter_for('str')
def format_value_text(val, encoding, colormap, quote=False, **_):
escapedval = val.replace('\\', '\\\\')
def format_value_text(val, encoding, colormap, quote=False, escape_backslash=True, **_):
escapedval = val.replace('\\', '\\\\') if escape_backslash else val
if quote:
escapedval = escapedval.replace("'", "''")
escapedval = UNICODE_CONTROLCHARS_RE.sub(_show_control_chars, escapedval)
Expand All @@ -496,11 +496,13 @@ def format_value_text(val, encoding, colormap, quote=False, **_):

def format_simple_collection(val, cqltype, lbracket, rbracket, encoding,
colormap, date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles):
decimal_sep, thousands_sep, boolean_styles,
escape_backslash=True):
subs = [format_value(sval, cqltype=stype, encoding=encoding, colormap=colormap,
date_time_format=date_time_format, float_precision=float_precision,
nullval=nullval, quote=True, decimal_sep=decimal_sep,
thousands_sep=thousands_sep, boolean_styles=boolean_styles)
nullval=nullval, quote=True, escape_backslash=escape_backslash,
decimal_sep=decimal_sep, thousands_sep=thousands_sep,
boolean_styles=boolean_styles)
for sval, stype in zip(val, cqltype.get_n_sub_types(len(val)))]
bval = lbracket + ', '.join(get_str(sval) for sval in subs) + rbracket
if colormap is NO_COLOR_MAP:
Expand All @@ -515,26 +517,29 @@ def format_simple_collection(val, cqltype, lbracket, rbracket, encoding,

@formatter_for('list')
def format_value_list(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles, **_):
decimal_sep, thousands_sep, boolean_styles, escape_backslash=True, **_):
return format_simple_collection(val, cqltype, '[', ']', encoding, colormap,
date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles)
decimal_sep, thousands_sep, boolean_styles,
escape_backslash=escape_backslash)


@formatter_for('tuple')
def format_value_tuple(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles, **_):
decimal_sep, thousands_sep, boolean_styles, escape_backslash=True, **_):
return format_simple_collection(val, cqltype, '(', ')', encoding, colormap,
date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles)
decimal_sep, thousands_sep, boolean_styles,
escape_backslash=escape_backslash)


@formatter_for('set')
def format_value_set(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles, **_):
decimal_sep, thousands_sep, boolean_styles, escape_backslash=True, **_):
return format_simple_collection(val, cqltype, '{', '}', encoding, colormap,
date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles)
decimal_sep, thousands_sep, boolean_styles,
escape_backslash=escape_backslash)


formatter_for('frozenset')(format_value_set)
Expand All @@ -544,12 +549,13 @@ def format_value_set(val, cqltype, encoding, colormap, date_time_format, float_p

@formatter_for('dict')
def format_value_map(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles, **_):
decimal_sep, thousands_sep, boolean_styles, escape_backslash=True, **_):
def subformat(v, t):
return format_value(v, cqltype=t, encoding=encoding, colormap=colormap,
date_time_format=date_time_format, float_precision=float_precision,
nullval=nullval, quote=True, decimal_sep=decimal_sep,
thousands_sep=thousands_sep, boolean_styles=boolean_styles)
nullval=nullval, quote=True, escape_backslash=escape_backslash,
decimal_sep=decimal_sep, thousands_sep=thousands_sep,
boolean_styles=boolean_styles)

subs = [(subformat(k, cqltype.sub_types[0]), subformat(v, cqltype.sub_types[1])) for (k, v) in sorted(val.items())]
bval = '{' + ', '.join(get_str(k) + ': ' + get_str(v) for (k, v) in subs) + '}'
Expand All @@ -572,14 +578,15 @@ def subformat(v, t):


def format_value_utype(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval,
decimal_sep, thousands_sep, boolean_styles, **_):
decimal_sep, thousands_sep, boolean_styles, escape_backslash=True, **_):
def format_field_value(v, t):
if v is None:
return colorme(nullval, colormap, 'error')
return format_value(v, cqltype=t, encoding=encoding, colormap=colormap,
date_time_format=date_time_format, float_precision=float_precision,
nullval=nullval, quote=True, decimal_sep=decimal_sep,
thousands_sep=thousands_sep, boolean_styles=boolean_styles)
nullval=nullval, quote=True, escape_backslash=escape_backslash,
decimal_sep=decimal_sep, thousands_sep=thousands_sep,
boolean_styles=boolean_styles)

def format_field_name(name):
return format_value_text(name, encoding=encoding, colormap=colormap, quote=False)
Expand Down
73 changes: 72 additions & 1 deletion pylib/cqlshlib/test/test_copyutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@
# and $CQL_TEST_PORT to the associated port.


import csv
import io
import unittest

from cassandra.metadata import MIN_LONG, Murmur3Token
from cassandra.policies import SimpleConvictionPolicy
from cassandra.pool import Host
from unittest.mock import Mock

from cqlshlib.copyutil import ExportTask
from cqlshlib.copyutil import ExportProcess, ExportTask
from cqlshlib.displaying import NO_COLOR_MAP
from cqlshlib.formatting import CqlType, DateTimeFormat, format_value_text


class CopyTaskTest(unittest.TestCase):
Expand Down Expand Up @@ -114,3 +118,70 @@ def test_get_ranges_murmur3(self):
(None, MIN_LONG + 1): {'hosts': ('10.0.0.2', '10.0.0.3', '10.0.0.4'), 'attempts': 0, 'rows': 0, 'workerno': -1}
}
self._test_get_ranges_murmur3_base({'endtoken': MIN_LONG + 1}, expected_ranges)


class TestExportFormatValue(unittest.TestCase):
"""
Unit tests for ExportProcess.format_value, the COPY TO serializer.

Regression tests for CASSANDRA-21131: text values - including text nested in
collections - must not be backslash-escaped by the display formatter on export.
The csv.writer already performs CSV-level escaping with the dialect escapechar,
so pre-escaping in format_value_text doubled every backslash and corrupted the
data on each COPY TO / COPY FROM round-trip.
"""

# CSV dialect produced from the default COPY ESCAPE / QUOTE / DELIMITER options.
# No explicit quoting is configured, so csv defaults to QUOTE_MINIMAL.
DIALECT = dict(quotechar='"', escapechar='\\', delimiter=',', doublequote=False)

def _format_value(self, val, typestring):
# Build an ExportProcess without running __init__ (which starts a
# multiprocessing.Process and opens cluster connections); set only the
# attributes that format_value reads.
proc = ExportProcess.__new__(ExportProcess)
proc.formatters = {}
proc.encoding = 'utf-8'
proc.date_time_format = DateTimeFormat()
proc.float_precision = 5
proc.double_precision = 12
proc.nullval = ''
proc.decimal_sep = '.'
proc.thousands_sep = ''
proc.boolean_styles = ['True', 'False']
return proc.format_value(val, CqlType(typestring))

def _csv_round_trip(self, formatted):
buf = io.StringIO()
csv.writer(buf, **self.DIALECT).writerow([formatted])
return next(csv.reader(io.StringIO(buf.getvalue()), **self.DIALECT))[0]

def test_scalar_text_is_not_backslash_escaped(self):
for typestring in ('text', 'varchar', 'ascii'):
self.assertEqual(self._format_value('V\\S', typestring), 'V\\S')
self.assertEqual(self._format_value('C:\\tmp\\f', typestring), 'C:\\tmp\\f')
self.assertEqual(self._format_value('\\"Marianne"\\', typestring), '\\"Marianne"\\')

def test_collection_text_is_not_backslash_escaped(self):
# The type_name of these is list/set/map/tuple, so a scalar-only check in
# format_value would miss them; the fix propagates escape_backslash through
# the collection formatters down to each text element.
self.assertEqual(self._format_value(['V\\S', 'a\\b'], 'list<text>'), "['V\\S', 'a\\b']")
self.assertEqual(self._format_value({'x\\y'}, 'set<text>'), "{'x\\y'}")
self.assertEqual(self._format_value({'k\\1': 'v\\2'}, 'map<text, text>'), "{'k\\1': 'v\\2'}")
self.assertEqual(self._format_value(('a\\b', 'c\\d'), 'tuple<text, text>'), "('a\\b', 'c\\d')")

def test_backslashes_survive_csv_round_trip(self):
# csv.writer adds exactly one layer of escaping that csv.reader removes on
# COPY FROM, so a value written by format_value comes back unchanged.
for stored in ('V\\S', 'C:\\path\\to\\file', '\\"Marianne"\\', 'plain'):
formatted = self._format_value(stored, 'text')
self.assertEqual(self._csv_round_trip(formatted), stored)

def test_display_formatting_still_escapes_backslashes(self):
# The terminal-display path must keep doubling backslashes so SELECT output
# renders them visibly; only the CSV export path opts out via
# escape_backslash=False. This is why the parameter is retained.
self.assertEqual(
format_value_text('V\\S', encoding='utf-8', colormap=NO_COLOR_MAP),
'V\\\\S')