diff --git a/pylib/cqlshlib/copyutil.py b/pylib/cqlshlib/copyutil.py index 9586486af1cd..6b9d77bd9009 100644 --- a/pylib/cqlshlib/copyutil.py +++ b/pylib/cqlshlib/copyutil.py @@ -1752,6 +1752,7 @@ def format_value(self, val, cqltype): formatted = formatter(val, cqltype=cqltype, encoding=self.encoding, colormap=NO_COLOR_MAP, date_time_format=self.date_time_format, float_precision=cqltype.precision, nullval=self.nullval, quote=False, + escape_backslash=False, decimal_sep=self.decimal_sep, thousands_sep=self.thousands_sep, boolean_styles=self.boolean_styles) return formatted diff --git a/pylib/cqlshlib/formatting.py b/pylib/cqlshlib/formatting.py index cdf36e0c5308..436b8882e965 100644 --- a/pylib/cqlshlib/formatting.py +++ b/pylib/cqlshlib/formatting.py @@ -477,8 +477,8 @@ def decode_zig_zag_64(n): @formatter_for('str') -def format_value_text(val, encoding, colormap, quote=False, **_): - escapedval = val.replace('\\', '\\\\') +def format_value_text(val, encoding, colormap, quote=False, escape_backslash=True, **_): + escapedval = val.replace('\\', '\\\\') if escape_backslash else val if quote: escapedval = escapedval.replace("'", "''") escapedval = UNICODE_CONTROLCHARS_RE.sub(_show_control_chars, escapedval) @@ -496,11 +496,13 @@ def format_value_text(val, encoding, colormap, quote=False, **_): def format_simple_collection(val, cqltype, lbracket, rbracket, encoding, colormap, date_time_format, float_precision, nullval, - decimal_sep, thousands_sep, boolean_styles): + decimal_sep, thousands_sep, boolean_styles, + escape_backslash=True): subs = [format_value(sval, cqltype=stype, encoding=encoding, colormap=colormap, date_time_format=date_time_format, float_precision=float_precision, - nullval=nullval, quote=True, decimal_sep=decimal_sep, - thousands_sep=thousands_sep, boolean_styles=boolean_styles) + nullval=nullval, quote=True, escape_backslash=escape_backslash, + decimal_sep=decimal_sep, thousands_sep=thousands_sep, + boolean_styles=boolean_styles) for sval, stype in zip(val, cqltype.get_n_sub_types(len(val)))] bval = lbracket + ', '.join(get_str(sval) for sval in subs) + rbracket if colormap is NO_COLOR_MAP: @@ -515,26 +517,29 @@ def format_simple_collection(val, cqltype, lbracket, rbracket, encoding, @formatter_for('list') def format_value_list(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval, - decimal_sep, thousands_sep, boolean_styles, **_): + decimal_sep, thousands_sep, boolean_styles, escape_backslash=True, **_): return format_simple_collection(val, cqltype, '[', ']', encoding, colormap, date_time_format, float_precision, nullval, - decimal_sep, thousands_sep, boolean_styles) + decimal_sep, thousands_sep, boolean_styles, + escape_backslash=escape_backslash) @formatter_for('tuple') def format_value_tuple(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval, - decimal_sep, thousands_sep, boolean_styles, **_): + decimal_sep, thousands_sep, boolean_styles, escape_backslash=True, **_): return format_simple_collection(val, cqltype, '(', ')', encoding, colormap, date_time_format, float_precision, nullval, - decimal_sep, thousands_sep, boolean_styles) + decimal_sep, thousands_sep, boolean_styles, + escape_backslash=escape_backslash) @formatter_for('set') def format_value_set(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval, - decimal_sep, thousands_sep, boolean_styles, **_): + decimal_sep, thousands_sep, boolean_styles, escape_backslash=True, **_): return format_simple_collection(val, cqltype, '{', '}', encoding, colormap, date_time_format, float_precision, nullval, - decimal_sep, thousands_sep, boolean_styles) + decimal_sep, thousands_sep, boolean_styles, + escape_backslash=escape_backslash) formatter_for('frozenset')(format_value_set) @@ -544,12 +549,13 @@ def format_value_set(val, cqltype, encoding, colormap, date_time_format, float_p @formatter_for('dict') def format_value_map(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval, - decimal_sep, thousands_sep, boolean_styles, **_): + decimal_sep, thousands_sep, boolean_styles, escape_backslash=True, **_): def subformat(v, t): return format_value(v, cqltype=t, encoding=encoding, colormap=colormap, date_time_format=date_time_format, float_precision=float_precision, - nullval=nullval, quote=True, decimal_sep=decimal_sep, - thousands_sep=thousands_sep, boolean_styles=boolean_styles) + nullval=nullval, quote=True, escape_backslash=escape_backslash, + decimal_sep=decimal_sep, thousands_sep=thousands_sep, + boolean_styles=boolean_styles) subs = [(subformat(k, cqltype.sub_types[0]), subformat(v, cqltype.sub_types[1])) for (k, v) in sorted(val.items())] bval = '{' + ', '.join(get_str(k) + ': ' + get_str(v) for (k, v) in subs) + '}' @@ -572,14 +578,15 @@ def subformat(v, t): def format_value_utype(val, cqltype, encoding, colormap, date_time_format, float_precision, nullval, - decimal_sep, thousands_sep, boolean_styles, **_): + decimal_sep, thousands_sep, boolean_styles, escape_backslash=True, **_): def format_field_value(v, t): if v is None: return colorme(nullval, colormap, 'error') return format_value(v, cqltype=t, encoding=encoding, colormap=colormap, date_time_format=date_time_format, float_precision=float_precision, - nullval=nullval, quote=True, decimal_sep=decimal_sep, - thousands_sep=thousands_sep, boolean_styles=boolean_styles) + nullval=nullval, quote=True, escape_backslash=escape_backslash, + decimal_sep=decimal_sep, thousands_sep=thousands_sep, + boolean_styles=boolean_styles) def format_field_name(name): return format_value_text(name, encoding=encoding, colormap=colormap, quote=False) diff --git a/pylib/cqlshlib/test/test_copyutil.py b/pylib/cqlshlib/test/test_copyutil.py index 9b30980fdb38..bc65b4e76d95 100644 --- a/pylib/cqlshlib/test/test_copyutil.py +++ b/pylib/cqlshlib/test/test_copyutil.py @@ -18,6 +18,8 @@ # and $CQL_TEST_PORT to the associated port. +import csv +import io import unittest from cassandra.metadata import MIN_LONG, Murmur3Token @@ -25,7 +27,9 @@ from cassandra.pool import Host from unittest.mock import Mock -from cqlshlib.copyutil import ExportTask +from cqlshlib.copyutil import ExportProcess, ExportTask +from cqlshlib.displaying import NO_COLOR_MAP +from cqlshlib.formatting import CqlType, DateTimeFormat, format_value_text class CopyTaskTest(unittest.TestCase): @@ -114,3 +118,70 @@ def test_get_ranges_murmur3(self): (None, MIN_LONG + 1): {'hosts': ('10.0.0.2', '10.0.0.3', '10.0.0.4'), 'attempts': 0, 'rows': 0, 'workerno': -1} } self._test_get_ranges_murmur3_base({'endtoken': MIN_LONG + 1}, expected_ranges) + + +class TestExportFormatValue(unittest.TestCase): + """ + Unit tests for ExportProcess.format_value, the COPY TO serializer. + + Regression tests for CASSANDRA-21131: text values - including text nested in + collections - must not be backslash-escaped by the display formatter on export. + The csv.writer already performs CSV-level escaping with the dialect escapechar, + so pre-escaping in format_value_text doubled every backslash and corrupted the + data on each COPY TO / COPY FROM round-trip. + """ + + # CSV dialect produced from the default COPY ESCAPE / QUOTE / DELIMITER options. + # No explicit quoting is configured, so csv defaults to QUOTE_MINIMAL. + DIALECT = dict(quotechar='"', escapechar='\\', delimiter=',', doublequote=False) + + def _format_value(self, val, typestring): + # Build an ExportProcess without running __init__ (which starts a + # multiprocessing.Process and opens cluster connections); set only the + # attributes that format_value reads. + proc = ExportProcess.__new__(ExportProcess) + proc.formatters = {} + proc.encoding = 'utf-8' + proc.date_time_format = DateTimeFormat() + proc.float_precision = 5 + proc.double_precision = 12 + proc.nullval = '' + proc.decimal_sep = '.' + proc.thousands_sep = '' + proc.boolean_styles = ['True', 'False'] + return proc.format_value(val, CqlType(typestring)) + + def _csv_round_trip(self, formatted): + buf = io.StringIO() + csv.writer(buf, **self.DIALECT).writerow([formatted]) + return next(csv.reader(io.StringIO(buf.getvalue()), **self.DIALECT))[0] + + def test_scalar_text_is_not_backslash_escaped(self): + for typestring in ('text', 'varchar', 'ascii'): + self.assertEqual(self._format_value('V\\S', typestring), 'V\\S') + self.assertEqual(self._format_value('C:\\tmp\\f', typestring), 'C:\\tmp\\f') + self.assertEqual(self._format_value('\\"Marianne"\\', typestring), '\\"Marianne"\\') + + def test_collection_text_is_not_backslash_escaped(self): + # The type_name of these is list/set/map/tuple, so a scalar-only check in + # format_value would miss them; the fix propagates escape_backslash through + # the collection formatters down to each text element. + self.assertEqual(self._format_value(['V\\S', 'a\\b'], 'list'), "['V\\S', 'a\\b']") + self.assertEqual(self._format_value({'x\\y'}, 'set'), "{'x\\y'}") + self.assertEqual(self._format_value({'k\\1': 'v\\2'}, 'map'), "{'k\\1': 'v\\2'}") + self.assertEqual(self._format_value(('a\\b', 'c\\d'), 'tuple'), "('a\\b', 'c\\d')") + + def test_backslashes_survive_csv_round_trip(self): + # csv.writer adds exactly one layer of escaping that csv.reader removes on + # COPY FROM, so a value written by format_value comes back unchanged. + for stored in ('V\\S', 'C:\\path\\to\\file', '\\"Marianne"\\', 'plain'): + formatted = self._format_value(stored, 'text') + self.assertEqual(self._csv_round_trip(formatted), stored) + + def test_display_formatting_still_escapes_backslashes(self): + # The terminal-display path must keep doubling backslashes so SELECT output + # renders them visibly; only the CSV export path opts out via + # escape_backslash=False. This is why the parameter is retained. + self.assertEqual( + format_value_text('V\\S', encoding='utf-8', colormap=NO_COLOR_MAP), + 'V\\\\S')