diff --git a/flatdata-py/flatdata/lib/data_access.py b/flatdata-py/flatdata/lib/data_access.py index 6ffd021a..025e817e 100644 --- a/flatdata-py/flatdata/lib/data_access.py +++ b/flatdata-py/flatdata/lib/data_access.py @@ -9,61 +9,71 @@ _SIGN_BITS = [0] + [(1 << (bits - 1)) for bits in range(1, 65)] -def read_value(data, offset_bits, num_bits, is_signed): - offset_bytes, offset_extra_bits = divmod(offset_bits, 8) - total_bytes = (num_bits + 7) // 8 - - if num_bits == 1: - return int((data[offset_bytes] & (1 << offset_extra_bits)) != 0) - - result = int.from_bytes(data[offset_bytes: offset_bytes + total_bytes], byteorder="little") - result >>= offset_extra_bits - if (total_bytes * 8 - offset_extra_bits) < num_bits: - remainder = data[offset_bytes + total_bytes] - result |= remainder << (total_bytes * 8 - offset_extra_bits) +def make_field_reader(offset_bits, num_bits, is_signed): + """Build a specialized closure for reading a single field from a structure. - if num_bits < 64 or offset_extra_bits > 0: - result = result & ((1 << num_bits) - 1) - - if not is_signed: - return result - - return (result & (_SIGN_BITS[num_bits] - 1)) - (result & _SIGN_BITS[num_bits]) - - -def write_value(data, offset_bits, num_bits, is_signed, value): - assert num_bits <= 64, f'Number of bits to write is greater than 64' - - offset_bytes, offset_extra_bits = divmod(offset_bits, 8) + Returns a function reader(data, pos_bytes) that reads the field value + from ``data`` at byte position ``pos_bytes``. All constants (byte offset, + bit shift, mask, sign handling) are pre-computed and captured by the + closure so the hot path does minimal work. + """ + offset_bytes, offset_extra = divmod(offset_bits, 8) total_bytes = (num_bits + 7) // 8 - - if num_bits == 1: - if value == 1: - data[offset_bytes] |= 1 << offset_extra_bits - else: - data[offset_bytes] &= ~(1 << offset_extra_bits) - return - + end_byte = offset_bytes + total_bytes mask = (1 << num_bits) - 1 - value <<= offset_extra_bits - value &= mask << offset_extra_bits - value_in_little_endian = value.to_bytes(total_bytes + 1, byteorder="little", signed=is_signed) - surrounding_bits = data[offset_bytes] & ((1 << offset_bits) - 1) + needs_extra = (total_bytes * 8 - offset_extra) < num_bits + extra_shift = total_bytes * 8 - offset_extra - byte_idx = 0 - data[offset_bytes] = value_in_little_endian[byte_idx] - data[offset_bytes] |= surrounding_bits - - byte_idx += 1 - while byte_idx < total_bytes: - data[offset_bytes + byte_idx] = value_in_little_endian[byte_idx] - byte_idx += 1 + if num_bits == 1: + bit_mask = 1 << offset_extra + def reader(data, pos): + return int((data[pos + offset_bytes] & bit_mask) != 0) + return reader - bits_written = total_bytes * 8 - offset_extra_bits - if bits_written < num_bits: - surrounding_bits = data[offset_bytes + byte_idx] & ~((1 << offset_bits) - 1) - data[offset_bytes + byte_idx] = value_in_little_endian[byte_idx] & ((1 << (8 - (bits_written % 8))) - 1) - data[offset_bytes + byte_idx] |= surrounding_bits + if is_signed: + sign_bit = _SIGN_BITS[num_bits] + sign_mask = sign_bit - 1 + if needs_extra: + def reader(data, pos): + result = int.from_bytes( + data[pos + offset_bytes: pos + end_byte], byteorder="little") + result >>= offset_extra + result |= data[pos + end_byte] << extra_shift + result &= mask + return (result & sign_mask) - (result & sign_bit) + elif offset_extra: + def reader(data, pos): + result = (int.from_bytes( + data[pos + offset_bytes: pos + end_byte], + byteorder="little") >> offset_extra) & mask + return (result & sign_mask) - (result & sign_bit) + else: + def reader(data, pos): + result = int.from_bytes( + data[pos + offset_bytes: pos + end_byte], + byteorder="little") & mask + return (result & sign_mask) - (result & sign_bit) + return reader + + # Unsigned paths + if needs_extra: + def reader(data, pos): + result = int.from_bytes( + data[pos + offset_bytes: pos + end_byte], byteorder="little") + result >>= offset_extra + result |= data[pos + end_byte] << extra_shift + return result & mask + elif offset_extra: + def reader(data, pos): + return (int.from_bytes( + data[pos + offset_bytes: pos + end_byte], + byteorder="little") >> offset_extra) & mask + else: + def reader(data, pos): + return int.from_bytes( + data[pos + offset_bytes: pos + end_byte], + byteorder="little") & mask + return reader def read_field_vectorized(raw_bytes_2d, field_offset_bits, field_width_bits, is_signed): @@ -110,3 +120,49 @@ def read_field_vectorized(raw_bytes_2d, field_offset_bits, field_width_bits, is_ result = np.where(result & sign_bit, signed, result.astype(np.int64)) return result + + +def read_value(data, offset_bits, num_bits, is_signed): + """Read a bit-packed value from data at the given bit offset. + + This is a convenience wrapper around :func:`make_field_reader` for one-off + reads. For repeated reads of the same field, prefer building a reader once + with ``make_field_reader`` and reusing it. + """ + reader = make_field_reader(offset_bits, num_bits, is_signed) + return reader(data, 0) + + +def write_value(data, offset_bits, num_bits, is_signed, value): + assert num_bits <= 64, f'Number of bits to write is greater than 64' + + offset_bytes, offset_extra_bits = divmod(offset_bits, 8) + total_bytes = (num_bits + 7) // 8 + + if num_bits == 1: + if value == 1: + data[offset_bytes] |= 1 << offset_extra_bits + else: + data[offset_bytes] &= ~(1 << offset_extra_bits) + return + + mask = (1 << num_bits) - 1 + value <<= offset_extra_bits + value &= mask << offset_extra_bits + value_in_little_endian = value.to_bytes(total_bytes + 1, byteorder="little", signed=is_signed) + surrounding_bits = data[offset_bytes] & ((1 << offset_bits) - 1) + + byte_idx = 0 + data[offset_bytes] = value_in_little_endian[byte_idx] + data[offset_bytes] |= surrounding_bits + + byte_idx += 1 + while byte_idx < total_bytes: + data[offset_bytes + byte_idx] = value_in_little_endian[byte_idx] + byte_idx += 1 + + bits_written = total_bytes * 8 - offset_extra_bits + if bits_written < num_bits: + surrounding_bits = data[offset_bytes + byte_idx] & ~((1 << offset_bits) - 1) + data[offset_bytes + byte_idx] = value_in_little_endian[byte_idx] & ((1 << (8 - (bits_written % 8))) - 1) + data[offset_bytes + byte_idx] |= surrounding_bits diff --git a/flatdata-py/flatdata/lib/structure.py b/flatdata-py/flatdata/lib/structure.py index 4b19d900..e0bdcc42 100644 --- a/flatdata-py/flatdata/lib/structure.py +++ b/flatdata-py/flatdata/lib/structure.py @@ -2,7 +2,7 @@ import json import numpy as np -from .data_access import read_value +from .data_access import make_field_reader FieldSignature = namedtuple( "FieldSignature", ["offset", "width", "is_signed", "dtype"]) @@ -10,6 +10,14 @@ class Structure: __slots__ = ('_mem', '_pos') + _READERS = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + fields = cls.__dict__.get('_FIELDS') + if fields is not None: + cls._READERS = {name: make_field_reader(f.offset, f.width, f.is_signed) + for name, f in fields.items()} def __init__(self, mem, pos): self._mem = mem @@ -17,13 +25,10 @@ def __init__(self, mem, pos): def __getattr__(self, name): try: - field = self._FIELDS[name] + reader = self._READERS[name] except KeyError: raise AttributeError("Field %s not found in structure" % name) - return self._get_value(field) - - def _get_value(self, field): - return read_value(self._mem, self._pos * 8 + field.offset, field.width, field.is_signed) + return reader(self._mem, self._pos) def __dir__(self): return self._FIELD_KEYS @@ -33,20 +38,24 @@ def __iter__(self): yield getattr(self, name) def as_dict(self): - return {name: self._get_value(field) for name, field in self._FIELDS.items()} + mem, pos = self._mem, self._pos + return {name: reader(mem, pos) for name, reader in self._READERS.items()} def as_list(self): - return [self._get_value(field) for field in self._FIELDS.values()] + mem, pos = self._mem, self._pos + return [reader(mem, pos) for reader in self._READERS.values()] def as_tuple(self): - return tuple(self._get_value(field) for field in self._FIELDS.values()) + mem, pos = self._mem, self._pos + return tuple(reader(mem, pos) for reader in self._READERS.values()) @classmethod def dtype(cls): return [(name, np.dtype(field.dtype)) for name, field in cls._FIELDS.items()] def as_nparray(self): - return np.array([tuple(self._get_value(field) for name, field in self._FIELDS.items())], + mem, pos = self._mem, self._pos + return np.array([tuple(reader(mem, pos) for reader in self._READERS.values())], dtype=self.dtype()) def schema(self): diff --git a/flatdata-py/tests/test_data_access.py b/flatdata-py/tests/test_data_access.py index fa461858..356e898b 100644 --- a/flatdata-py/tests/test_data_access.py +++ b/flatdata-py/tests/test_data_access.py @@ -1,5 +1,5 @@ import pytest -from flatdata.lib.data_access import read_value, write_value +from flatdata.lib.data_access import read_value, write_value, make_field_reader def test_reader(): @@ -2264,3 +2264,21 @@ def _test_writer(buffer, offset, num_bits, is_signed, expected): _test_writer(b"\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", 3, 16, True, 8192) _test_writer(b"\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", 3, 16, True, 16384) _test_writer(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", 3, 2, True, 0) + + +def test_make_field_reader_with_nonzero_pos(): + """Reader closures must produce correct results at arbitrary byte positions.""" + data = bytearray(20) + struct_bytes = b'\xab\xcd\xef\x12\x98\x76\x54\x32\x10' + data[0:9] = struct_bytes + data[10:19] = struct_bytes + + for offset_bits in [0, 3, 8, 13]: + for num_bits in [1, 5, 8, 16, 32, 64]: + for is_signed in [False, True]: + if offset_bits + num_bits > len(struct_bytes) * 8: + continue + reader = make_field_reader(offset_bits, num_bits, is_signed) + assert reader(data, 0) == reader(data, 10), ( + f"offset={offset_bits}, width={num_bits}, signed={is_signed}: " + f"pos=0 got {reader(data, 0)}, pos=10 got {reader(data, 10)}")