diff --git a/libdestruct/backing/fake_resolver.py b/libdestruct/backing/fake_resolver.py index 637e67f..94cc8fc 100644 --- a/libdestruct/backing/fake_resolver.py +++ b/libdestruct/backing/fake_resolver.py @@ -8,18 +8,26 @@ from libdestruct.backing.resolver import Resolver +_PAGE_SIZE = 0x1000 +_ZERO_PAGE = b"\x00" * _PAGE_SIZE + class FakeResolver(Resolver): """A class that can resolve elements in a simulated memory storage.""" def __init__(self: FakeResolver, memory: dict | None = None, address: int | None = 0, endianness: str = "little") -> None: """Initializes a basic fake resolver.""" - self.memory = memory if memory is not None else {} + self._memory = memory if memory is not None else {} self.address = address self.parent = None self.offset = None self.endianness = endianness + @property + def memory(self: FakeResolver) -> dict: + """The backing page dict. Read-only — mutate in place instead of reassigning.""" + return self._memory + def resolve_address(self: FakeResolver) -> int: """Resolves self's address, mainly used by children to determine their own address.""" if self.address is not None: @@ -48,11 +56,11 @@ def resolve(self: FakeResolver, size: int, _: int) -> bytes: result = b"" while size: - page = self.memory.get(page_address, b"\x00" * 0x1000) - page_size = min(size, 0x1000 - page_offset) + page = self.memory.get(page_address, _ZERO_PAGE) + page_size = min(size, _PAGE_SIZE - page_offset) result += page[page_offset : page_offset + page_size] size -= page_size - page_address += 0x1000 + page_address += _PAGE_SIZE page_offset = 0 return result @@ -65,11 +73,11 @@ def modify(self: FakeResolver, size: int, _: int, value: bytes) -> None: page_offset = address & 0xFFF while size: - page = self.memory.get(page_address, b"\x00" * 0x1000) - page_size = min(size, 0x1000 - page_offset) + page = self.memory.get(page_address, _ZERO_PAGE) + page_size = min(size, _PAGE_SIZE - page_offset) page = page[:page_offset] + value[:page_size] + page[page_offset + page_size :] self.memory[page_address] = page size -= page_size value = value[page_size:] - page_address += 0x1000 + page_address += _PAGE_SIZE page_offset = 0 diff --git a/libdestruct/backing/memory_resolver.py b/libdestruct/backing/memory_resolver.py index 34c559d..c892d07 100644 --- a/libdestruct/backing/memory_resolver.py +++ b/libdestruct/backing/memory_resolver.py @@ -19,12 +19,17 @@ class MemoryResolver(Resolver): def __init__(self: MemoryResolver, memory: MutableSequence, address: int | None, endianness: str = "little") -> None: """Initializes a basic memory resolver.""" - self.memory = memory + self._memory = memory self.address = address self.parent = None self.offset = None self.endianness = endianness + @property + def memory(self: MemoryResolver) -> MutableSequence: + """The backing memory buffer. Read-only — mutate in place instead of reassigning.""" + return self._memory + def resolve_address(self: MemoryResolver) -> int: """Resolves self's address, mainly used by childs to determine their own address.""" if self.address is not None: diff --git a/libdestruct/c/c_integer_types.py b/libdestruct/c/c_integer_types.py index bf76989..cc6f9d1 100644 --- a/libdestruct/c/c_integer_types.py +++ b/libdestruct/c/c_integer_types.py @@ -127,3 +127,18 @@ class c_ulong(_c_integer): signed: bool = False """Whether the long is signed.""" + + +_SIGNED_INTEGER_BY_SIZE: dict[int, type[_c_integer]] = { + 1: c_char, + 2: c_short, + 4: c_int, + 8: c_long, +} + + +def signed_integer_for_size(size: int) -> type[_c_integer]: + """Return the signed C integer type for the given byte size (1, 2, 4, or 8).""" + if size not in _SIGNED_INTEGER_BY_SIZE: + raise ValueError("The size of the field must be 1, 2, 4, or 8 bytes.") + return _SIGNED_INTEGER_BY_SIZE[size] diff --git a/libdestruct/c/c_str.py b/libdestruct/c/c_str.py index 3895a31..5a9658e 100644 --- a/libdestruct/c/c_str.py +++ b/libdestruct/c/c_str.py @@ -26,11 +26,12 @@ def count(self: c_str) -> int: def get(self: c_str, index: int = -1) -> bytes: """Return the character at the given index.""" - if (index != -1 and index < 0) or index >= self.count(): + length = self.count() + if (index != -1 and index < 0) or index >= length: raise IndexError("String index out of range.") if index == -1: - return self.resolver.resolve(self.count(), 0) + return self.resolver.resolve(length, 0) return bytes([self.resolver.resolve(index + 1, 0)[-1]]) diff --git a/libdestruct/common/enum/int_enum_field.py b/libdestruct/common/enum/int_enum_field.py index a9ead12..a64408d 100644 --- a/libdestruct/common/enum/int_enum_field.py +++ b/libdestruct/common/enum/int_enum_field.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING -from libdestruct.c.c_integer_types import c_char, c_int, c_long, c_short +from libdestruct.c.c_integer_types import signed_integer_for_size from libdestruct.common.enum.enum import enum from libdestruct.common.enum.enum_field import EnumField @@ -43,20 +43,7 @@ def __init__( self.backing_type = backing_type return - if not 0 < size <= 8: - raise ValueError("The size of the field must be between 1 and 8 bytes.") - - match size: - case 1: - self.backing_type = c_char - case 2: - self.backing_type = c_short - case 4: - self.backing_type = c_int - case 8: - self.backing_type = c_long - case _: - raise ValueError("The size of the field must be a power of 2.") + self.backing_type = signed_integer_for_size(size) def inflate(self: IntEnumField, resolver: Resolver) -> int: """Inflate the field. diff --git a/libdestruct/common/flags/int_flag_field.py b/libdestruct/common/flags/int_flag_field.py index b79d5f5..30f0ae5 100644 --- a/libdestruct/common/flags/int_flag_field.py +++ b/libdestruct/common/flags/int_flag_field.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING -from libdestruct.c.c_integer_types import c_char, c_int, c_long, c_short +from libdestruct.c.c_integer_types import signed_integer_for_size from libdestruct.common.flags.flags import flags from libdestruct.common.flags.flags_field import FlagsField @@ -36,20 +36,7 @@ def __init__( self.backing_type = backing_type return - if not 0 < size <= 8: - raise ValueError("The size of the field must be between 1 and 8 bytes.") - - match size: - case 1: - self.backing_type = c_char - case 2: - self.backing_type = c_short - case 4: - self.backing_type = c_int - case 8: - self.backing_type = c_long - case _: - raise ValueError("The size of the field must be a power of 2.") + self.backing_type = signed_integer_for_size(size) def inflate(self: IntFlagField, resolver: Resolver) -> flags: """Inflate the field.""" diff --git a/libdestruct/common/obj.py b/libdestruct/common/obj.py index 4f59bbd..b6ea0e0 100644 --- a/libdestruct/common/obj.py +++ b/libdestruct/common/obj.py @@ -145,6 +145,10 @@ def _compare_value(self: obj, other: object) -> tuple[object, object] | None: return self_val, other return None + # Restore identity hashing — Python blanks __hash__ when __eq__ is defined. + # Equality is value-based but hash is identity, so {a, b} won't dedupe equal values. + __hash__ = object.__hash__ + def __eq__(self: obj, other: object) -> bool: """Return whether the object is equal to the given value.""" pair = self._compare_value(other) diff --git a/libdestruct/common/ptr/ptr.py b/libdestruct/common/ptr/ptr.py index b550905..83aeed1 100644 --- a/libdestruct/common/ptr/ptr.py +++ b/libdestruct/common/ptr/ptr.py @@ -58,8 +58,8 @@ def __init__(self: ptr, resolver: Resolver, wrapper: type | None = None) -> None """ super().__init__(resolver) self.wrapper = wrapper - self._cached_unwrap: obj | bytes | None = None - self._cache_valid: bool = False + self._cached_unwrap: obj | None = None + self._cached_address: int | None = None self._cached_length: int | None = None def get(self: ptr) -> int: @@ -82,7 +82,7 @@ def _set(self: ptr, value: int) -> None: def invalidate(self: ptr) -> None: """Clear the cached unwrap result.""" self._cached_unwrap = None - self._cache_valid = False + self._cached_address = None self._cached_length = None def unwrap(self: ptr, length: int | None = None) -> obj | bytes: @@ -91,22 +91,22 @@ def unwrap(self: ptr, length: int | None = None) -> obj | bytes: Args: length: The length of the object in memory this points to. """ - if self._cache_valid and self._cached_length == length: - return self._cached_unwrap - address = self.get() - if self.wrapper: - if length: - raise ValueError("Length is not supported when unwrapping a pointer to a wrapper object.") - - result = self.wrapper(self.resolver.absolute_from_own(address)) - else: + if not self.wrapper: + # Bytes are a snapshot; never cache — always read live. target_resolver = self.resolver.absolute_from_own(address) - result = target_resolver.resolve(length if length is not None else 1, 0) + return target_resolver.resolve(length if length is not None else 1, 0) + if length: + raise ValueError("Length is not supported when unwrapping a pointer to a wrapper object.") + + if self._cached_unwrap is not None and self._cached_address == address and self._cached_length == length: + return self._cached_unwrap + + result = self.wrapper(self.resolver.absolute_from_own(address)) self._cached_unwrap = result - self._cache_valid = True + self._cached_address = address self._cached_length = length return result @@ -116,9 +116,6 @@ def try_unwrap(self: ptr, length: int | None = None) -> obj | bytes | None: Args: length: The length of the object in memory this points to. """ - if self._cache_valid and self._cached_length == length: - return self._cached_unwrap - address = self.get() try: @@ -152,12 +149,12 @@ def _element_size(self: ptr) -> int: def __add__(self: ptr, n: int) -> ptr: """Return a new pointer advanced by n elements.""" new_addr = self.get() + n * self._element_size - return ptr(_ArithmeticResolver(self.resolver, new_addr), self.wrapper) + return type(self)(_ArithmeticResolver(self.resolver, new_addr), self.wrapper) def __sub__(self: ptr, n: int) -> ptr: """Return a new pointer retreated by n elements.""" new_addr = self.get() - n * self._element_size - return ptr(_ArithmeticResolver(self.resolver, new_addr), self.wrapper) + return type(self)(_ArithmeticResolver(self.resolver, new_addr), self.wrapper) def __getitem__(self: ptr, n: int) -> obj: """Return the object at index n relative to this pointer.""" diff --git a/libdestruct/common/struct/struct_impl.py b/libdestruct/common/struct/struct_impl.py index 7ef8c02..2c1ac17 100644 --- a/libdestruct/common/struct/struct_impl.py +++ b/libdestruct/common/struct/struct_impl.py @@ -39,14 +39,13 @@ class struct_impl(struct): def __init__(self: struct_impl, resolver: Resolver | None = None, **kwargs: ...) -> None: """Initialize the struct implementation.""" - # If we have kwargs and the resolver is None, we provide a fake resolver if kwargs and resolver is None: resolver = FakeResolver() if not isinstance(resolver, Resolver): raise TypeError("The resolver must be a Resolver instance.") - # struct overrides the __init__ method, so we need to call the parent class __init__ method + # struct.__init__ raises by design; bypass it and call obj.__init__ directly. obj.__init__(self, resolver) object.__setattr__(self, "_struct_name", self.__class__.__name__) @@ -60,13 +59,22 @@ def __init__(self: struct_impl, resolver: Resolver | None = None, **kwargs: ...) def __getattribute__(self: struct_impl, name: str) -> object: """Return the attribute, checking struct members first to avoid collisions with obj properties.""" - # Check _members dict directly to avoid infinite recursion try: members = object.__getattribute__(self, "_members") - if name in members: - return members[name] except AttributeError: - pass + return super().__getattribute__(name) + if name in members: + return members[name] + if name == "size": + # VLA structs store _vla_fixed_offset instead of an instance size attr; + # without this, `instance.size` would fall back to the static class size + # (set by compute_own_size), missing the dynamic VLA contribution. + try: + vla_offset = object.__getattribute__(self, "_vla_fixed_offset") + except AttributeError: + pass + else: + return vla_offset + next(reversed(members.values())).size return super().__getattribute__(name) def __setattr__(self: struct_impl, name: str, value: object) -> None: @@ -81,9 +89,7 @@ def __setattr__(self: struct_impl, name: str, value: object) -> None: object.__setattr__(self, name, value) def __new__(cls: struct_impl, *args: ..., **kwargs: ...) -> Self: - """Create a new struct.""" - # Skip the __new__ method of the parent class - # struct_impl -> struct -> obj becomes struct_impl -> obj + """Create a new struct, bypassing struct.__new__ which is for the user-facing factory.""" return obj.__new__(cls) def _inflate_struct_attributes( @@ -152,13 +158,12 @@ def _inflate_struct_attributes( max_alignment = max(max_alignment, aligned) current_offset = _align_offset(current_offset, max_alignment) - # For VLA structs, size must be computed dynamically since the count - # can change at runtime. Detect VLA by duck-typing: vla_impl has a - # _count_member attribute that plain array_impl does not. + # VLA detection uses duck-typing on _count_member to avoid a circular + # import between struct_impl and vla_impl (vla_impl extends array_impl, + # which imports struct). members = object.__getattribute__(self, "_members") - last_member = list(members.values())[-1] if members else None + last_name, last_member = next(reversed(members.items()), (None, None)) if last_member is not None and hasattr(last_member, "_count_member"): - last_name = list(members.keys())[-1] object.__setattr__(self, "_vla_fixed_offset", self._member_offsets[last_name]) else: object.__setattr__(self, "size", current_offset) @@ -178,14 +183,18 @@ def _resolve_field( Either resolved_inflater or bitfield_field will be non-None (not both). explicit_offset is set when an OffsetAttribute is present. """ - # Unwrap Annotated[type, metadata...] — extract the real type and any metadata annotated_offset = None if get_origin(annotation) is Annotated: ann_args = get_args(annotation) annotation = ann_args[0] - for meta in ann_args[1:]: - if isinstance(meta, OffsetAttribute): - annotated_offset = meta.offset + ann_offsets = [m.offset for m in ann_args[1:] if isinstance(m, OffsetAttribute)] + if len(ann_offsets) > 1: + raise ValueError( + f"Field {name!r} has multiple OffsetAttribute entries in its Annotated metadata; " + f"only one is allowed.", + ) + if ann_offsets: + annotated_offset = ann_offsets[0] if name not in reference.__dict__: return inflater.inflater_for(annotation, owner=owner), None, annotated_offset @@ -197,6 +206,13 @@ def _resolve_field( if sum(isinstance(attr, Field) for attr in attrs) > 1: raise ValueError("Only one Field is allowed per attribute.") + attr_offsets = sum(isinstance(a, OffsetAttribute) for a in attrs) + if attr_offsets + (1 if annotated_offset is not None else 0) > 1: + raise ValueError( + f"Field {name!r} has multiple OffsetAttribute entries (across Annotated metadata " + f"and attribute tuple); only one is allowed.", + ) + resolved_type = None bitfield_field = None explicit_offset = annotated_offset @@ -226,6 +242,7 @@ def compute_own_size(cls: type[struct_impl], reference_type: type) -> None: bf_tracker = BitfieldTracker() aligned = getattr(reference_type, "_aligned_", False) seen_vla = False + seen_names: set[str] = set() for name, annotation, reference in iterate_annotation_chain(reference_type, terminate_at=struct): if name == "_aligned_": @@ -240,13 +257,24 @@ def compute_own_size(cls: type[struct_impl], reference_type: type) -> None: # Detect VLA from default value or subscript annotation default = getattr(reference, name, None) if hasattr(reference, name) else None is_vla = isinstance(default, VLAField) - if not is_vla and isinstance(annotation, GenericAlias): + count_field_name: str | None = None + if is_vla: + count_field_name = default.count_field + elif isinstance(annotation, GenericAlias): args = annotation.__args__ if len(args) == 2 and isinstance(args[1], str): is_vla = True + count_field_name = args[1] if is_vla: + if count_field_name is not None and count_field_name not in seen_names: + raise ValueError( + f"VLA field {name!r} references undefined count field {count_field_name!r}. " + f"The count field must be declared before the VLA in the same struct.", + ) seen_vla = True + seen_names.add(name) + resolved_type, bitfield_field, explicit_offset = struct_impl._resolve_field( name, annotation, reference, cls._inflater, owner=(None, cls), ) @@ -346,13 +374,13 @@ def freeze(self: struct_impl) -> None: super().freeze() def reset(self: struct_impl) -> None: - """Reset each member to its frozen value.""" + """Restore the struct's memory region to the bytes captured at freeze time.""" if not object.__getattribute__(self, "_frozen"): raise RuntimeError("Cannot reset a struct that has not been frozen.") - members = object.__getattribute__(self, "_members") - for member in members.values(): - member.reset() + resolver = object.__getattribute__(self, "resolver") + frozen_bytes = object.__getattribute__(self, "_frozen_struct_bytes") + resolver.modify(len(frozen_bytes), 0, frozen_bytes) def to_str(self: struct_impl, indent: int = 0) -> str: """Return a string representation of the struct.""" @@ -379,6 +407,8 @@ def __repr__(self: struct_impl) -> str: }} }}""" + __hash__ = object.__hash__ + def __eq__(self: struct_impl, value: object) -> bool: """Return whether the struct is equal to the given value.""" if not isinstance(value, struct_impl): diff --git a/libdestruct/common/union/tagged_union_field_inflater.py b/libdestruct/common/union/tagged_union_field_inflater.py index afccd4c..43c8c93 100644 --- a/libdestruct/common/union/tagged_union_field_inflater.py +++ b/libdestruct/common/union/tagged_union_field_inflater.py @@ -40,20 +40,32 @@ def tagged_union_field_inflater( struct_instance = owner[0] def inflate_with_discriminator(resolver: Resolver) -> union: - members = object.__getattribute__(struct_instance, "_members") - disc_value = members[field.discriminator].value - - if disc_value not in field.variants: - raise ValueError( - f"Unknown discriminator value {disc_value!r} for field '{field.discriminator}'. " - f"Valid values: {list(field.variants.keys())}" - ) - - variant_type = field.variants[disc_value] - variant_inflater = registry.inflater_for(variant_type) - variant = variant_inflater(resolver) - - return union(resolver, variant, field.get_size()) + # Per-instance variant cache so repeat reads of the same discriminator value + # return the same variant object (stable identity), but a discriminator change + # transparently switches to a different variant. + variant_cache: dict[object, object] = {} + + def dispatcher() -> object: + members = object.__getattribute__(struct_instance, "_members") + disc_value = members[field.discriminator].value + + if disc_value not in field.variants: + raise ValueError( + f"Unknown discriminator value {disc_value!r} for field '{field.discriminator}'. " + f"Valid values: {list(field.variants.keys())}" + ) + + cached = variant_cache.get(disc_value) + if cached is not None: + return cached + + variant_type = field.variants[disc_value] + variant_inflater = registry.inflater_for(variant_type) + variant = variant_inflater(resolver) + variant_cache[disc_value] = variant + return variant + + return union(resolver, None, field.get_size(), dispatcher=dispatcher) return inflate_with_discriminator diff --git a/libdestruct/common/union/union.py b/libdestruct/common/union/union.py index 9e73f04..af43cd2 100644 --- a/libdestruct/common/union/union.py +++ b/libdestruct/common/union/union.py @@ -11,6 +11,8 @@ from libdestruct.common.obj import obj if TYPE_CHECKING: # pragma: no cover + from collections.abc import Callable + from libdestruct.backing.resolver import Resolver @@ -18,11 +20,14 @@ class union(obj): """A union value, supporting both tagged (single active variant) and plain (all variants overlaid) modes.""" _variant: obj | None - """The single active variant (tagged union mode).""" + """The single active variant (tagged union mode, or snapshot after freeze).""" _variants: dict[str, obj] """Named variants (plain union mode).""" + _dispatcher: Callable[[], obj] | None + """For tagged unions: a callable that returns the current variant based on live discriminator value.""" + _frozen_bytes: bytes | None """The frozen bytes of the full union region.""" @@ -32,6 +37,7 @@ def __init__( variant: obj | None, max_size: int, variants: dict[str, obj] | None = None, + dispatcher: Callable[[], obj] | None = None, ) -> None: """Initialize the union. @@ -40,36 +46,49 @@ def __init__( variant: The single active variant (tagged union mode, None for plain unions). max_size: The size of the union (max of all variant sizes). variants: Named variants dict (plain union mode, None for tagged unions). + dispatcher: For tagged unions, a callable returning the current variant based on + the live discriminator value. When set, the variant is re-dispatched on every + access until the union is frozen. """ super().__init__(resolver) self._variant = variant self._variants = variants or {} + self._dispatcher = dispatcher self.size = max_size self._frozen_bytes = None + def _active_variant(self: union) -> obj | None: + """Return the currently-active variant, dispatching live when applicable.""" + if self._frozen or self._dispatcher is None: + return self._variant + return self._dispatcher() + @property def variant(self: union) -> obj | None: """Return the active variant object (tagged union mode).""" - return self._variant + return self._active_variant() def get(self: union) -> object: """Return the value of the active variant.""" - if self._variant is not None: - return self._variant.get() + active = self._active_variant() + if active is not None: + return active.get() if self._variants: return {name: v.get() for name, v in self._variants.items()} return None def _set(self: union, value: object) -> None: """Set the value of the active variant.""" - if self._variant is None: + active = self._active_variant() + if active is None: raise RuntimeError("Cannot set the value of a union without an active variant.") - self._variant._set(value) + active._set(value) def to_dict(self: union) -> object: """Return a JSON-serializable representation of the union.""" - if self._variant is not None: - return self._variant.to_dict() + active = self._active_variant() + if active is not None: + return active.to_dict() if self._variants: return {name: v.to_dict() for name, v in self._variants.items()} return None @@ -83,11 +102,17 @@ def to_bytes(self: union) -> bytes: return self.resolver.resolve(self.size, 0) def freeze(self: union) -> None: - """Freeze the union and all its variants.""" + """Freeze the union and all its variants. + + Snapshots the current variant (if dispatched live) so that frozen reads + return a consistent value even if the discriminator later changes in memory. + """ if self.resolver is not None: self._frozen_bytes = self.resolver.resolve(self.size, 0) else: self._frozen_bytes = b"\x00" * self.size + if self._dispatcher is not None: + self._variant = self._dispatcher() if self._variant is not None: self._variant.freeze() for v in self._variants.values(): @@ -96,8 +121,9 @@ def freeze(self: union) -> None: def diff(self: union) -> tuple[object, object]: """Return the difference between the frozen and current value.""" - if self._variant is not None: - return self._variant.diff() + active = self._active_variant() + if active is not None: + return active.diff() return {name: v.diff() for name, v in self._variants.items()} def reset(self: union) -> None: @@ -109,8 +135,9 @@ def reset(self: union) -> None: def to_str(self: union, indent: int = 0) -> str: """Return a string representation of the union.""" - if self._variant is not None: - return self._variant.to_str(indent) + active = self._active_variant() + if active is not None: + return active.to_str(indent) if self._variants: members = ", ".join(self._variants) return f"union({members})" @@ -126,7 +153,13 @@ def __getattr__(self: union, name: str) -> object: pass try: - variant = object.__getattribute__(self, "_variant") + frozen = object.__getattribute__(self, "_frozen") + dispatcher = object.__getattribute__(self, "_dispatcher") + variant = ( + object.__getattribute__(self, "_variant") + if frozen or dispatcher is None + else dispatcher() + ) if variant is not None: return getattr(variant, name) except AttributeError: diff --git a/test/scripts/alignment_test.py b/test/scripts/alignment_test.py index b714b5e..3e59160 100644 --- a/test/scripts/alignment_test.py +++ b/test/scripts/alignment_test.py @@ -308,3 +308,55 @@ class s_t(struct): s = s_t.from_bytes(memory) self.assertEqual(s.a.value, 0x41) self.assertEqual(s.flags.value, 5) + + +class AlignedStructTailPaddingInstanceTest(unittest.TestCase): + """Instance size must include tail padding for aligned structs.""" + + def test_instance_size_matches_class_size(self): + """size_of(instance) should equal size_of(class) for aligned structs.""" + class aligned_t(struct): + _aligned_ = True + a: c_int + b: c_char + + self.assertEqual(size_of(aligned_t), 8) + + memory = pystruct.pack("", result) + + +# ---------- union.py coverage ---------- + + +class UnionEmptyTest(unittest.TestCase): + """Empty union edge cases.""" + + def test_empty_union_get(self): + u = union(None, None, 4) + self.assertIsNone(u.get()) + + def test_empty_union_to_dict(self): + u = union(None, None, 4) + self.assertIsNone(u.to_dict()) + + def test_empty_union_to_str(self): + u = union(None, None, 4) + self.assertEqual(u.to_str(), "union(empty)") + + +class UnionSetNoVariantTest(unittest.TestCase): + """Setting a union without an active variant raises RuntimeError.""" + + def test_set_raises(self): + u = union(None, None, 4) + with self.assertRaises(RuntimeError): + u._set(42) + + +class UnionResetNoFreezeTest(unittest.TestCase): + """Resetting an unfrozen union raises RuntimeError.""" + + def test_reset_raises(self): + u = union(None, None, 4) + with self.assertRaises(RuntimeError): + u.reset() + + +class UnionNoneResolverTest(unittest.TestCase): + """Union with None resolver returns zero bytes.""" + + def test_to_bytes_none_resolver(self): + u = union(None, None, 8) + self.assertEqual(u.to_bytes(), b"\x00" * 8) + + def test_freeze_none_resolver(self): + u = union(None, None, 4) + u.freeze() + self.assertEqual(u._frozen_bytes, b"\x00" * 4) + + +class PlainUnionDiffTest(unittest.TestCase): + """Plain union diff returns per-variant diffs.""" + + def test_plain_union_diff(self): + class s_t(struct): + data: union = union_of({"i": c_int, "l": c_long}) + + memory = bytearray(8) + pystruct.pack_into(" offset 12 + memory[12:16] = (20).to_bytes(4, "little") + memory[16:24] = (0).to_bytes(8, "little") # next -> null + + node = Node.from_bytes(memory) + self.assertEqual(node.val.value, 10) + next_node = node.next.unwrap() + self.assertEqual(next_node.val.value, 20) + + +class PtrToStrFieldTest(unittest.TestCase): + """ptr to_str with Field-backed wrapper.""" + + def test_ptr_to_field_wrapper(self): + """ptr_to(c_int) creates a Field-backed wrapper that has a qualified name.""" + memory = bytearray(16) + memory[0:8] = (8).to_bytes(8, "little") + memory[8:12] = (42).to_bytes(4, "little") + + lib = inflater(memory) + p = lib.inflate(ptr_to(c_int), 0) + result = p.to_str() + self.assertIn("0x8", result) + + +# ---------- forward_ref_inflater.py coverage ---------- + + +class LazyPtrFieldUnresolvableTest(unittest.TestCase): + """_LazyPtrField when forward ref cannot be resolved → returns raw ptr.""" + + def test_unresolvable_forward_ref_returns_raw_ptr(self): + """ptr['NonExistentType'] should still inflate, just as an untyped ptr.""" + from typing import ForwardRef + from libdestruct.common.forward_ref_inflater import _LazyPtrField + + lazy = _LazyPtrField(ForwardRef("CompletelyBogusTypeThatDoesNotExist"), owner=None) + memory = bytearray(8) + memory[0:8] = (0).to_bytes(8, "little") + result = lazy.inflate(MemoryResolver(memory, 0)) + self.assertIsInstance(result, ptr) + # wrapper should be None since it couldn't resolve + self.assertIsNone(result.wrapper) + + def test_forward_ref_resolves_to_non_type(self): + """Forward ref that eval's to a non-type value returns None from _resolve_forward_ref.""" + from typing import ForwardRef + from libdestruct.common.forward_ref_inflater import _LazyPtrField + + # "42" eval's to int 42, not a type + lazy = _LazyPtrField(ForwardRef("42"), owner=None) + result = lazy._resolve_forward_ref() + self.assertIsNone(result) + + def test_forward_ref_eval_exception(self): + """Forward ref that raises during eval returns None.""" + from typing import ForwardRef + from libdestruct.common.forward_ref_inflater import _LazyPtrField + + # Valid syntax but unresolvable name → NameError during eval + lazy = _LazyPtrField(ForwardRef("NoSuchTypeAnywhere"), owner=None) + result = lazy._resolve_forward_ref() + self.assertIsNone(result) + + +class SubscriptedPtrHandlerEdgeCasesTest(unittest.TestCase): + """_subscripted_ptr_handler edge cases.""" + + def test_ptr_subscript_none_target(self): + """ptr[()] with empty args → untyped ptr field.""" + from libdestruct.common.forward_ref_inflater import _subscripted_ptr_handler + + result = _subscripted_ptr_handler(ptr, (), owner=None) + self.assertIsNotNone(result) + # Should return a PtrField.inflate bound method + memory = bytearray(8) + p = result(MemoryResolver(memory, 0)) + self.assertIsInstance(p, ptr) + self.assertIsNone(p.wrapper) + + def test_ptr_subscript_string_target(self): + """ptr['SomeString'] goes through ForwardRef path.""" + from libdestruct.common.forward_ref_inflater import _subscripted_ptr_handler + + result = _subscripted_ptr_handler(ptr, ("NonExistentType",), owner=None) + self.assertIsNotNone(result) + # Should return a _LazyPtrField.inflate bound method + memory = bytearray(8) + p = result(MemoryResolver(memory, 0)) + self.assertIsInstance(p, ptr) + + def test_ptr_subscript_non_type_non_string_target(self): + """ptr[42] (invalid target) → fallback untyped ptr.""" + from libdestruct.common.forward_ref_inflater import _subscripted_ptr_handler + + result = _subscripted_ptr_handler(ptr, (42,), owner=None) + self.assertIsNotNone(result) + memory = bytearray(8) + p = result(MemoryResolver(memory, 0)) + self.assertIsInstance(p, ptr) + self.assertIsNone(p.wrapper) + + def test_ptr_subscript_concrete_type(self): + """ptr[c_int] resolves immediately to typed ptr.""" + from libdestruct.common.forward_ref_inflater import _subscripted_ptr_handler + + result = _subscripted_ptr_handler(ptr, (c_int,), owner=None) + self.assertIsNotNone(result) + memory = bytearray(16) + memory[0:8] = (8).to_bytes(8, "little") + memory[8:12] = (42).to_bytes(4, "little") + p = result(MemoryResolver(memory, 0)) + self.assertIsInstance(p, ptr) + self.assertIsNotNone(p.wrapper) + + +class BareForwardRefInflaterTest(unittest.TestCase): + """_forward_ref_inflater for bare ForwardRef annotations.""" + + def test_ptr_forward_ref_string_parsing(self): + """ForwardRef('ptr[SomeType]') is parsed and creates a lazy ptr.""" + from typing import ForwardRef + from libdestruct.common.forward_ref_inflater import _forward_ref_inflater + + ref = ForwardRef("ptr['Node']") + result = _forward_ref_inflater(ref, type(None), owner=None) + self.assertIsNotNone(result) + # Should return a _LazyPtrField.inflate bound method + memory = bytearray(8) + p = result(MemoryResolver(memory, 0)) + self.assertIsInstance(p, ptr) + + def test_ptr_forward_ref_double_quoted(self): + """ForwardRef('ptr[\"Node\"]') with double quotes.""" + from typing import ForwardRef + from libdestruct.common.forward_ref_inflater import _forward_ref_inflater + + ref = ForwardRef('ptr["Node"]') + result = _forward_ref_inflater(ref, type(None), owner=None) + self.assertIsNotNone(result) + memory = bytearray(8) + p = result(MemoryResolver(memory, 0)) + self.assertIsInstance(p, ptr) + + def test_ptr_forward_ref_unquoted(self): + """ForwardRef('ptr[c_int]') with unquoted inner type.""" + from typing import ForwardRef + from libdestruct.common.forward_ref_inflater import _forward_ref_inflater + + ref = ForwardRef("ptr[c_int]") + result = _forward_ref_inflater(ref, type(None), owner=None) + self.assertIsNotNone(result) + memory = bytearray(8) + p = result(MemoryResolver(memory, 0)) + self.assertIsInstance(p, ptr) + + def test_non_ptr_forward_ref_raises(self): + """ForwardRef('SomeRandomThing') raises ValueError.""" + from typing import ForwardRef + from libdestruct.common.forward_ref_inflater import _forward_ref_inflater + + ref = ForwardRef("SomeRandomThing") + with self.assertRaises(ValueError) as ctx: + _forward_ref_inflater(ref, type(None), owner=None) + self.assertIn("SomeRandomThing", str(ctx.exception)) + + def test_lazy_ptr_with_owner_resolves(self): + """_LazyPtrField with owner tuple resolves types from owner's module.""" + from typing import ForwardRef + from libdestruct.common.forward_ref_inflater import _LazyPtrField + + # Create a lazy field that references c_int (available in this module's globals) + lazy = _LazyPtrField(ForwardRef("c_int"), owner=(None, type(self))) + resolved = lazy._resolve_forward_ref() + self.assertIs(resolved, c_int) + + +# ---------- flags int_flag_field.py coverage ---------- + + +class FlagsFieldSizesTest(unittest.TestCase): + """IntFlagField with different sizes.""" + + class Perms(IntFlag): + R = 1 + W = 2 + X = 4 + + def test_flags_size_1(self): + from libdestruct import flags_of + + class s_t(struct): + perms: c_int = flags_of(self.Perms, size=1) + + memory = (7).to_bytes(1, "little") + s = s_t.from_bytes(memory) + self.assertEqual(s.perms.value, self.Perms.R | self.Perms.W | self.Perms.X) + + def test_flags_size_2(self): + from libdestruct import flags_of + + class s_t(struct): + perms: c_int = flags_of(self.Perms, size=2) + + memory = (3).to_bytes(2, "little") + s = s_t.from_bytes(memory) + self.assertEqual(s.perms.value, self.Perms.R | self.Perms.W) + + def test_flags_size_8(self): + from libdestruct import flags_of + + class s_t(struct): + perms: c_int = flags_of(self.Perms, size=8) + + memory = (1).to_bytes(8, "little") + s = s_t.from_bytes(memory) + self.assertEqual(s.perms.value, self.Perms.R) + + def test_flags_invalid_size(self): + from libdestruct import flags_of + with self.assertRaises(ValueError): + flags_of(self.Perms, size=3) + + def test_flags_size_too_large(self): + from libdestruct import flags_of + with self.assertRaises(ValueError): + flags_of(self.Perms, size=9) + + +# ---------- c_str.py coverage ---------- + + +class CStrEdgeCasesTest(unittest.TestCase): + """c_str edge cases.""" + + def test_repr(self): + from libdestruct import c_str + + memory = bytearray(b"Hello\x00") + lib = inflater(memory) + s = lib.inflate(c_str, 0) + r = repr(s) + self.assertIsInstance(r, str) + + def test_negative_index_raises(self): + from libdestruct import c_str + + memory = bytearray(b"Hello\x00") + lib = inflater(memory) + s = lib.inflate(c_str, 0) + with self.assertRaises(IndexError): + s.get(-2) + + def test_set_negative_index_raises(self): + from libdestruct import c_str + + memory = bytearray(b"Hello\x00") + lib = inflater(memory) + s = lib.inflate(c_str, 0) + with self.assertRaises(IndexError): + s._set(b"X", -2) + + def test_set_full_string(self): + from libdestruct import c_str + + memory = bytearray(b"Hello\x00") + lib = inflater(memory) + s = lib.inflate(c_str, 0) + s.value = b"World" + self.assertEqual(s.value, b"World") + + +# ---------- struct_parser.py coverage ---------- + + +class StructParserEdgeCasesTest(unittest.TestCase): + """C parser edge cases for coverage.""" + + def test_enum_in_struct_raises(self): + """Enum inside struct is not yet supported, must raise TypeError.""" + from libdestruct.c.struct_parser import definition_to_type + + with self.assertRaises(TypeError): + definition_to_type("struct test { enum { A, B, C } val; };") + + def test_struct_with_pointer_member(self): + from libdestruct.c.struct_parser import definition_to_type + + t = definition_to_type("struct test { int *p; int x; };") + self.assertIn("p", t.__annotations__) + self.assertIn("x", t.__annotations__) + + def test_struct_with_array_read(self): + from libdestruct.c.struct_parser import definition_to_type + + t = definition_to_type("struct test { int arr[3]; int x; };") + memory = b"".join((i).to_bytes(4, "little") for i in [10, 20, 30, 42]) + s = t.from_bytes(memory) + self.assertEqual(s.x.value, 42) + + def test_named_struct_cached(self): + """Named structs are cached and reusable.""" + from libdestruct.c.struct_parser import definition_to_type, clear_parser_cache + + clear_parser_cache() + t = definition_to_type(""" + struct Inner { int x; }; + struct Outer { struct Inner a; int b; }; + """) + memory = (1).to_bytes(4, "little") + (2).to_bytes(4, "little") + s = t.from_bytes(memory) + self.assertEqual(s.b.value, 2) + clear_parser_cache() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/scripts/ctypes_test.py b/test/scripts/ctypes_integration_test.py similarity index 100% rename from test/scripts/ctypes_test.py rename to test/scripts/ctypes_integration_test.py diff --git a/test/scripts/endianness_test.py b/test/scripts/endianness_test.py index 72f53fd..b8aead0 100644 --- a/test/scripts/endianness_test.py +++ b/test/scripts/endianness_test.py @@ -264,5 +264,34 @@ class s_t(struct): self.assertEqual(memory, pystruct.pack(">I", 5)) +class EndiannessValidationTest(unittest.TestCase): + """inflater() should reject invalid endianness strings.""" + + def test_invalid_endianness_raises(self): + """Passing a typo like 'big-endian' must raise ValueError, not silently produce wrong results.""" + with self.assertRaises(ValueError): + inflater(bytearray(4), endianness="big-endian") + + def test_invalid_endianness_typo(self): + """A random typo must raise ValueError.""" + with self.assertRaises(ValueError): + inflater(bytearray(4), endianness="typo") + + def test_valid_endianness_big(self): + """'big' is accepted without error.""" + lib = inflater(bytearray(4), endianness="big") + self.assertIsNotNone(lib) + + def test_valid_endianness_little(self): + """'little' is accepted without error.""" + lib = inflater(bytearray(4), endianness="little") + self.assertIsNotNone(lib) + + def test_from_bytes_invalid_endianness(self): + """from_bytes with invalid endianness must raise ValueError.""" + with self.assertRaises(ValueError): + c_int.from_bytes(b"\x00\x00\x00\x00", endianness="big-endian") + + if __name__ == "__main__": unittest.main() diff --git a/test/scripts/enum_test.py b/test/scripts/enum_test.py index bd283ac..6807a27 100644 --- a/test/scripts/enum_test.py +++ b/test/scripts/enum_test.py @@ -8,6 +8,8 @@ from enum import Enum, IntEnum from libdestruct import inflater, c_int, enum, enum_of, struct +from libdestruct.backing.memory_resolver import MemoryResolver +from libdestruct.common.enum.enum import enum as ld_enum class EnumTest(unittest.TestCase): def test_enum(self): @@ -131,3 +133,51 @@ class test_t(struct): result = bytes(test) self.assertIsInstance(result, bytes) + + +class EnumLenientSetTest(unittest.TestCase): + """enum._set must handle raw ints from lenient mode without crashing.""" + + def test_set_raw_int_from_lenient_get(self): + """Setting back a raw int obtained from lenient get() should work.""" + class Color(IntEnum): + RED = 0 + GREEN = 1 + + memory = bytearray((99).to_bytes(4, "little")) + e = ld_enum(MemoryResolver(memory, 0), Color, c_int, lenient=True) + + val = e.get() + self.assertEqual(val, 99) + self.assertIsInstance(val, int) + self.assertNotIsInstance(val, IntEnum) + + e.value = val + self.assertEqual(e.get(), 99) + + def test_set_enum_member_still_works(self): + """Setting a valid enum member should still work.""" + class Color(IntEnum): + RED = 0 + GREEN = 1 + + memory = bytearray(4) + e = ld_enum(MemoryResolver(memory, 0), Color, c_int, lenient=True) + + e.value = Color.GREEN + self.assertEqual(e.get(), Color.GREEN) + + def test_reset_after_freeze_with_unknown_value(self): + """freeze() + reset() with unknown enum value should not crash.""" + class Color(IntEnum): + RED = 0 + GREEN = 1 + + memory = bytearray((99).to_bytes(4, "little")) + e = ld_enum(MemoryResolver(memory, 0), Color, c_int, lenient=True) + + e.freeze() + memory[0:4] = (0).to_bytes(4, "little") + + e.reset() + self.assertEqual(e.get(), 99) diff --git a/test/scripts/resolver_unit_test.py b/test/scripts/resolver_test.py similarity index 67% rename from test/scripts/resolver_unit_test.py rename to test/scripts/resolver_test.py index 7c3a54e..14625ab 100644 --- a/test/scripts/resolver_unit_test.py +++ b/test/scripts/resolver_test.py @@ -4,11 +4,13 @@ # Licensed under the MIT license. See LICENSE file in the project root for details. # +import inspect import unittest from libdestruct import c_int, c_str, c_uint, inflater, struct, ptr, ptr_to_self from libdestruct.backing.fake_resolver import FakeResolver from libdestruct.backing.memory_resolver import MemoryResolver +from libdestruct.backing.resolver import Resolver class FakeResolverTest(unittest.TestCase): @@ -107,5 +109,43 @@ def test_c_uint_write_to_bytearray(self): self.assertEqual(obj.value, 0xDEADBEEF) +class ResolverParameterNameTest(unittest.TestCase): + """Resolver method parameter names must match documentation.""" + + def test_resolve_parameter_name_is_index(self): + """Resolver.resolve second parameter should be 'index', not 'offset'.""" + sig = inspect.signature(Resolver.resolve) + params = list(sig.parameters.keys()) + self.assertEqual(params[2], "index") + + def test_relative_from_own_parameter_names(self): + """Resolver.relative_from_own parameters should be address_offset, index_offset.""" + sig = inspect.signature(Resolver.relative_from_own) + params = list(sig.parameters.keys()) + self.assertEqual(params[1], "address_offset") + self.assertEqual(params[2], "index_offset") + + +class MemoryReadOnlyTest(unittest.TestCase): + """Reassigning a resolver's memory is not supported and must raise.""" + + def test_memory_resolver_reassignment_raises(self): + resolver = MemoryResolver(bytearray(8), 0) + with self.assertRaises(AttributeError): + resolver.memory = bytearray(8) + + def test_fake_resolver_reassignment_raises(self): + resolver = FakeResolver() + with self.assertRaises(AttributeError): + resolver.memory = {} + + def test_memory_resolver_in_place_mutation_still_works(self): + """Sanity: in-place mutation of the underlying buffer is the supported path.""" + memory = bytearray((1).to_bytes(4, "little")) + resolver = MemoryResolver(memory, 0) + memory[0:4] = (42).to_bytes(4, "little") + self.assertEqual(int.from_bytes(resolver.resolve(4, 0), "little"), 42) + + if __name__ == "__main__": unittest.main() diff --git a/test/scripts/review_fix_test.py b/test/scripts/review_fix_test.py deleted file mode 100644 index 606f1ba..0000000 --- a/test/scripts/review_fix_test.py +++ /dev/null @@ -1,209 +0,0 @@ -# -# This file is part of libdestruct (https://github.com/mrindeciso/libdestruct). -# Copyright (c) 2026 Roberto Alessandro Bertolini. All rights reserved. -# Licensed under the MIT license. See LICENSE file in the project root for details. -# - -"""Tests that expose bugs found during code review of the dev branch.""" - -import struct as pystruct -import unittest - -from libdestruct import c_float, c_double, c_int, c_long, inflater, struct -from libdestruct.common.union import union, union_of - - -class EndiannessValidationTest(unittest.TestCase): - """inflater() should reject invalid endianness strings.""" - - def test_invalid_endianness_raises(self): - """Passing a typo like 'big-endian' must raise ValueError, not silently produce wrong results.""" - with self.assertRaises(ValueError): - inflater(bytearray(4), endianness="big-endian") - - def test_invalid_endianness_typo(self): - """A random typo must raise ValueError.""" - with self.assertRaises(ValueError): - inflater(bytearray(4), endianness="typo") - - def test_valid_endianness_big(self): - """'big' is accepted without error.""" - lib = inflater(bytearray(4), endianness="big") - self.assertIsNotNone(lib) - - def test_valid_endianness_little(self): - """'little' is accepted without error.""" - lib = inflater(bytearray(4), endianness="little") - self.assertIsNotNone(lib) - - def test_from_bytes_invalid_endianness(self): - """from_bytes with invalid endianness must raise ValueError.""" - with self.assertRaises(ValueError): - c_int.from_bytes(b"\x00\x00\x00\x00", endianness="big-endian") - - -class StructAttributeCollisionTest(unittest.TestCase): - """Struct members named after internal attributes must not break core methods.""" - - def test_struct_with_frozen_field_to_bytes(self): - """A struct with a field named '_frozen' must still serialize correctly after freeze.""" - # This field name collides with obj._frozen used in to_bytes() - class s_t(struct): - _frozen: c_int - b: c_int - - memory = b"" - memory += (10).to_bytes(4, "little") - memory += (20).to_bytes(4, "little") - - s = s_t.from_bytes(memory) - # to_bytes must return the correct serialized data, not crash - self.assertEqual(s.to_bytes(), memory) - - def test_struct_with_frozen_field_hexdump(self): - """A struct with a '_frozen' field must still produce a hexdump.""" - class s_t(struct): - _frozen: c_int - - s = s_t.from_bytes((42).to_bytes(4, "little")) - # hexdump must not crash - dump = s.hexdump() - self.assertIn("2a", dump) - - def test_struct_with_members_field_eq(self): - """A struct with a field named '_members' must still support equality.""" - class s_t(struct): - _members: c_int - - a = s_t.from_bytes((1).to_bytes(4, "little")) - b = s_t.from_bytes((1).to_bytes(4, "little")) - self.assertEqual(a, b) - - def test_struct_with_members_field_to_dict(self): - """A struct with a field named '_members' must still support to_dict.""" - class s_t(struct): - _members: c_int - - s = s_t.from_bytes((5).to_bytes(4, "little")) - d = s.to_dict() - self.assertEqual(d["_members"], 5) - - def test_struct_with_frozen_struct_bytes_field(self): - """A field named '_frozen_struct_bytes' must not break freeze/to_bytes.""" - class s_t(struct): - _frozen_struct_bytes: c_int - - memory = (99).to_bytes(4, "little") - s = s_t.from_bytes(memory) - self.assertEqual(s.to_bytes(), memory) - - -class UnionGetAttrSafetyTest(unittest.TestCase): - """union.__getattr__ must produce clear errors, not internal AttributeError.""" - - def test_missing_attribute_error_message(self): - """Accessing a nonexistent attribute on a union should mention the attribute name, not '_variants'.""" - u = union(None, None, 4) - with self.assertRaises(AttributeError) as ctx: - _ = u.nonexistent_attr - # The error message must mention the user's attribute, not internal implementation details - self.assertIn("nonexistent_attr", str(ctx.exception)) - self.assertNotIn("_variants", str(ctx.exception)) - - def test_getattr_after_del_variants(self): - """Even if _variants is somehow missing, __getattr__ should not expose internal details.""" - u = union(None, None, 4) - del u.__dict__["_variants"] - with self.assertRaises(AttributeError) as ctx: - _ = u.something - self.assertIn("something", str(ctx.exception)) - - -class FloatDuplicationRegressionTest(unittest.TestCase): - """After refactoring c_float/c_double to a shared base, core behavior must be preserved.""" - - def test_c_float_read_write(self): - memory = bytearray(4) - lib = inflater(memory) - f = lib.inflate(c_float, 0) - f.value = 3.14 - self.assertAlmostEqual(f.value, 3.14, places=5) - - def test_c_double_read_write(self): - memory = bytearray(8) - lib = inflater(memory) - d = lib.inflate(c_double, 0) - d.value = 2.718281828 - self.assertAlmostEqual(d.value, 2.718281828, places=8) - - def test_c_float_freeze_diff_reset(self): - memory = bytearray(4) - lib = inflater(memory) - f = lib.inflate(c_float, 0) - f.value = 1.5 - f.freeze() - self.assertAlmostEqual(f.value, 1.5, places=5) - with self.assertRaises(ValueError): - f.value = 2.0 - - def test_c_double_freeze_diff_reset(self): - memory = bytearray(8) - lib = inflater(memory) - d = lib.inflate(c_double, 0) - d.value = 1.5 - d.freeze() - self.assertAlmostEqual(d.value, 1.5, places=5) - with self.assertRaises(ValueError): - d.value = 2.0 - - def test_c_float_from_bytes(self): - data = pystruct.pack("f", 3.14) - f = c_float.from_bytes(original, endianness="big") - self.assertAlmostEqual(f.value, 3.14, places=5) - self.assertEqual(f.to_bytes(), original) - - def test_c_double_big_endian(self): - original = pystruct.pack(">d", 2.718) - d = c_double.from_bytes(original, endianness="big") - self.assertAlmostEqual(d.value, 2.718, places=3) - self.assertEqual(d.to_bytes(), original) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/scripts/review_fix_test_2.py b/test/scripts/review_fix_test_2.py deleted file mode 100644 index bf09ce3..0000000 --- a/test/scripts/review_fix_test_2.py +++ /dev/null @@ -1,289 +0,0 @@ -# -# This file is part of libdestruct (https://github.com/mrindeciso/libdestruct). -# Copyright (c) 2026 Roberto Alessandro Bertolini. All rights reserved. -# Licensed under the MIT license. See LICENSE file in the project root for details. -# - -"""Tests that expose bugs found during the second-pass code review.""" - -import unittest - -from libdestruct import ( - array, - bitfield_of, - c_int, - c_uint, - inflater, - size_of, - struct, -) -from libdestruct.c.struct_parser import clear_parser_cache, definition_to_type - - -class ComparisonOperatorSafetyTest(unittest.TestCase): - """Comparison operators must not raise TypeError for incompatible obj types.""" - - def test_lt_primitive_vs_struct_returns_not_implemented(self): - """c_int < struct should return NotImplemented, not raise TypeError.""" - class s_t(struct): - x: c_int - - memory = bytearray(4) - lib = inflater(memory) - val = lib.inflate(c_int, 0) - s = lib.inflate(s_t, 0) - - # Must not raise TypeError - result = val.__lt__(s) - self.assertIs(result, NotImplemented) - - def test_gt_primitive_vs_struct_returns_not_implemented(self): - class s_t(struct): - x: c_int - - memory = bytearray(4) - lib = inflater(memory) - val = lib.inflate(c_int, 0) - s = lib.inflate(s_t, 0) - - result = val.__gt__(s) - self.assertIs(result, NotImplemented) - - def test_le_primitive_vs_struct_returns_not_implemented(self): - class s_t(struct): - x: c_int - - memory = bytearray(4) - lib = inflater(memory) - val = lib.inflate(c_int, 0) - s = lib.inflate(s_t, 0) - - result = val.__le__(s) - self.assertIs(result, NotImplemented) - - def test_ge_primitive_vs_struct_returns_not_implemented(self): - class s_t(struct): - x: c_int - - memory = bytearray(4) - lib = inflater(memory) - val = lib.inflate(c_int, 0) - s = lib.inflate(s_t, 0) - - result = val.__ge__(s) - self.assertIs(result, NotImplemented) - - def test_eq_primitive_vs_struct_returns_not_implemented(self): - class s_t(struct): - x: c_int - - memory = bytearray(4) - lib = inflater(memory) - val = lib.inflate(c_int, 0) - s = lib.inflate(s_t, 0) - - result = val.__eq__(s) - self.assertIs(result, NotImplemented) - - def test_ne_primitive_vs_struct_returns_not_implemented(self): - class s_t(struct): - x: c_int - - memory = bytearray(4) - lib = inflater(memory) - val = lib.inflate(c_int, 0) - s = lib.inflate(s_t, 0) - - result = val.__ne__(s) - self.assertIs(result, NotImplemented) - - def test_lt_between_compatible_primitives_works(self): - """Comparisons between compatible primitives should still work.""" - memory = bytearray(8) - lib = inflater(memory) - a = lib.inflate(c_int, 0) - b = lib.inflate(c_int, 4) - a.value = 1 - b.value = 2 - - self.assertTrue(a < b) - self.assertFalse(b < a) - - def test_comparison_with_raw_int(self): - memory = bytearray(4) - lib = inflater(memory) - a = lib.inflate(c_int, 0) - a.value = 5 - - self.assertTrue(a < 10) - self.assertTrue(a > 2) - self.assertTrue(a <= 5) - self.assertTrue(a >= 5) - - -class NegativeArrayCountTest(unittest.TestCase): - """array[T, N] must reject non-positive counts at handler time.""" - - def test_negative_count_raises(self): - """array[c_int, -5] must raise ValueError.""" - with self.assertRaises(ValueError): - class s_t(struct): - data: array[c_int, -5] - # Force size computation - size_of(s_t) - - def test_zero_count_raises(self): - """array[c_int, 0] must raise ValueError.""" - with self.assertRaises(ValueError): - class s_t(struct): - data: array[c_int, 0] - size_of(s_t) - - def test_positive_count_works(self): - """array[c_int, 3] must work fine.""" - class s_t(struct): - data: array[c_int, 3] - self.assertEqual(size_of(s_t), 12) - - -class BitfieldFreezeSafetyTest(unittest.TestCase): - """Frozen bitfields must reject writes even for non-owners.""" - - def test_non_owner_bitfield_rejects_write_after_freeze(self): - """The second bitfield in a group (non-owner) must reject writes when frozen.""" - class s_t(struct): - a: c_uint = bitfield_of(c_uint, 1) - b: c_uint = bitfield_of(c_uint, 1) - - memory = bytearray(4) - lib = inflater(memory) - s = lib.inflate(s_t, 0) - - s.a.value = 1 - s.b.value = 1 - - # Freeze the entire struct (which freezes all members) - s.freeze() - - # Both bitfields should reject writes - with self.assertRaises(ValueError): - s.a.value = 0 - - with self.assertRaises(ValueError): - s.b.value = 0 - - def test_individually_frozen_non_owner_rejects_write(self): - """Freezing a non-owner bitfield individually must also reject writes.""" - class s_t(struct): - a: c_uint = bitfield_of(c_uint, 1) - b: c_uint = bitfield_of(c_uint, 1) - - memory = bytearray(4) - lib = inflater(memory) - s = lib.inflate(s_t, 0) - - s.b.value = 1 - - # Freeze only the non-owner bitfield b - s.b.freeze() - - with self.assertRaises(ValueError): - s.b.value = 0 - - -class TypeRegistryDeduplicationTest(unittest.TestCase): - """Repeated handler registration must not accumulate duplicates.""" - - def test_generic_handler_not_duplicated(self): - """Registering the same handler twice must not produce duplicate entries.""" - from libdestruct.common.type_registry import TypeRegistry - - registry = TypeRegistry() - - class DummyType: - pass - - def dummy_handler(item, args, owner): - return None - - initial_count = len(registry.generic_handlers.get(DummyType, [])) - - registry.register_generic_handler(DummyType, dummy_handler) - registry.register_generic_handler(DummyType, dummy_handler) - - count = len(registry.generic_handlers[DummyType]) - self.assertEqual(count, initial_count + 1) - - def test_instance_handler_not_duplicated(self): - """Registering the same instance handler twice must not produce duplicate entries.""" - from libdestruct.common.type_registry import TypeRegistry - - registry = TypeRegistry() - - class DummyField: - pass - - def dummy_handler(item, annotation, owner): - return None - - initial_count = len(registry.instance_handlers.get(DummyField, [])) - - registry.register_instance_handler(DummyField, dummy_handler) - registry.register_instance_handler(DummyField, dummy_handler) - - count = len(registry.instance_handlers[DummyField]) - self.assertEqual(count, initial_count + 1) - - def test_type_handler_not_duplicated(self): - """Registering the same type handler twice must not produce duplicate entries.""" - from libdestruct.common.type_registry import TypeRegistry - - registry = TypeRegistry() - - class DummyParent: - pass - - def dummy_handler(item): - return None - - initial_count = len(registry.type_handlers.get(DummyParent, [])) - - registry.register_type_handler(DummyParent, dummy_handler) - registry.register_type_handler(DummyParent, dummy_handler) - - count = len(registry.type_handlers[DummyParent]) - self.assertEqual(count, initial_count + 1) - - -class ForwardTypedefTest(unittest.TestCase): - """Forward typedef references are a known parser limitation.""" - - def setUp(self): - clear_parser_cache() - - def tearDown(self): - clear_parser_cache() - - def test_chained_typedefs_in_order(self): - """Chained typedefs in declaration order must work.""" - t = definition_to_type(""" - typedef unsigned int u32; - typedef u32 mytype; - struct S { mytype x; }; - """) - data = (42).to_bytes(4, "little") - s = t.from_bytes(data) - self.assertEqual(s.x.value, 42) - - def test_forward_typedef_reference_raises(self): - """Forward typedef reference (use before define) must raise a clear error, not crash.""" - with self.assertRaises((ValueError, TypeError)): - definition_to_type(""" - typedef mytype1 mytype2; - typedef unsigned int mytype1; - struct S { mytype2 x; }; - """) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/scripts/string_test.py b/test/scripts/string_integration_test.py similarity index 100% rename from test/scripts/string_test.py rename to test/scripts/string_integration_test.py diff --git a/test/scripts/struct_parser_unit_test.py b/test/scripts/struct_parser_test.py similarity index 57% rename from test/scripts/struct_parser_unit_test.py rename to test/scripts/struct_parser_test.py index 3593306..376fe2b 100644 --- a/test/scripts/struct_parser_unit_test.py +++ b/test/scripts/struct_parser_test.py @@ -7,7 +7,7 @@ import struct as pystruct import unittest -from libdestruct.c.struct_parser import definition_to_type +from libdestruct.c.struct_parser import clear_parser_cache, definition_to_type, PARSED_STRUCTS from libdestruct import inflater @@ -88,5 +88,61 @@ def test_typedef_inflate_and_read(self): self.assertEqual(s.y.value, -42) +class AnonymousStructCacheTest(unittest.TestCase): + """Anonymous structs must not pollute the parser cache with a None key.""" + + def test_none_key_not_in_cache(self): + """Anonymous struct should not pollute PARSED_STRUCTS with a None key.""" + PARSED_STRUCTS.clear() + + definition_to_type("struct { int x; };") + + self.assertNotIn(None, PARSED_STRUCTS) + + +class UnsizedArrayMemberTest(unittest.TestCase): + """Parser must handle flexible array members (e.g. int data[]) gracefully.""" + + def test_unsized_array_member(self): + """Parsing a struct with a flexible array member should not crash.""" + try: + result = definition_to_type("struct test { int count; int data[]; };") + self.assertTrue(hasattr(result, '__annotations__')) + except (ValueError, TypeError): + pass + except AttributeError: + self.fail("arr_to_type crashed with AttributeError on unsized array - should handle gracefully") + + +class ForwardTypedefTest(unittest.TestCase): + """Forward typedef references are a known parser limitation.""" + + def setUp(self): + clear_parser_cache() + + def tearDown(self): + clear_parser_cache() + + def test_chained_typedefs_in_order(self): + """Chained typedefs in declaration order must work.""" + t = definition_to_type(""" + typedef unsigned int u32; + typedef u32 mytype; + struct S { mytype x; }; + """) + data = (42).to_bytes(4, "little") + s = t.from_bytes(data) + self.assertEqual(s.x.value, 42) + + def test_forward_typedef_reference_raises(self): + """Forward typedef reference (use before define) must raise a clear error, not crash.""" + with self.assertRaises((ValueError, TypeError)): + definition_to_type(""" + typedef mytype1 mytype2; + typedef unsigned int mytype1; + struct S { mytype2 x; }; + """) + + if __name__ == "__main__": unittest.main() diff --git a/test/scripts/struct_unit_test.py b/test/scripts/struct_test.py similarity index 80% rename from test/scripts/struct_unit_test.py rename to test/scripts/struct_test.py index 3d8f704..a99a784 100644 --- a/test/scripts/struct_unit_test.py +++ b/test/scripts/struct_test.py @@ -651,5 +651,144 @@ class flags_t(struct): self.assertIn("execute", dump) +class StructAttributeCollisionTest(unittest.TestCase): + """Struct members named after internal attributes must not break core methods.""" + + def test_struct_with_frozen_field_to_bytes(self): + """A struct with a field named '_frozen' must still serialize correctly after freeze.""" + class s_t(struct): + _frozen: c_int + b: c_int + + memory = b"" + memory += (10).to_bytes(4, "little") + memory += (20).to_bytes(4, "little") + + s = s_t.from_bytes(memory) + self.assertEqual(s.to_bytes(), memory) + + def test_struct_with_frozen_field_hexdump(self): + """A struct with a '_frozen' field must still produce a hexdump.""" + class s_t(struct): + _frozen: c_int + + s = s_t.from_bytes((42).to_bytes(4, "little")) + dump = s.hexdump() + self.assertIn("2a", dump) + + def test_struct_with_members_field_eq(self): + """A struct with a field named '_members' must still support equality.""" + class s_t(struct): + _members: c_int + + a = s_t.from_bytes((1).to_bytes(4, "little")) + b = s_t.from_bytes((1).to_bytes(4, "little")) + self.assertEqual(a, b) + + def test_struct_with_members_field_to_dict(self): + """A struct with a field named '_members' must still support to_dict.""" + class s_t(struct): + _members: c_int + + s = s_t.from_bytes((5).to_bytes(4, "little")) + d = s.to_dict() + self.assertEqual(d["_members"], 5) + + def test_struct_with_frozen_struct_bytes_field(self): + """A field named '_frozen_struct_bytes' must not break freeze/to_bytes.""" + class s_t(struct): + _frozen_struct_bytes: c_int + + memory = (99).to_bytes(4, "little") + s = s_t.from_bytes(memory) + self.assertEqual(s.to_bytes(), memory) + + +class StructResetWithCompositesTest(unittest.TestCase): + """struct.reset() must work for any composite member shape, including arrays.""" + + def test_reset_struct_with_array(self): + from libdestruct.backing.memory_resolver import MemoryResolver + + class S(struct): + arr: array[c_int, 3] + + memory = bytearray((1).to_bytes(4, "little") + (2).to_bytes(4, "little") + (3).to_bytes(4, "little")) + s = S(MemoryResolver(memory, 0)) + s.freeze() + memory[0:4] = (99).to_bytes(4, "little") + memory[4:8] = (98).to_bytes(4, "little") + s.reset() + self.assertEqual(int.from_bytes(memory[0:4], "little"), 1) + self.assertEqual(int.from_bytes(memory[4:8], "little"), 2) + + def test_reset_struct_with_nested_array_of_struct(self): + from libdestruct.backing.memory_resolver import MemoryResolver + + class Inner(struct): + x: c_int + + class Outer(struct): + items: array[Inner, 2] + + memory = bytearray((10).to_bytes(4, "little") + (20).to_bytes(4, "little")) + s = Outer(MemoryResolver(memory, 0)) + s.freeze() + memory[0:4] = (999).to_bytes(4, "little") + s.reset() + self.assertEqual(int.from_bytes(memory[0:4], "little"), 10) + + +class AnnotatedMultipleOffsetTest(unittest.TestCase): + """Multiple OffsetAttribute on a single field must raise instead of silently using one.""" + + def test_two_offsets_in_annotated_raises(self): + from libdestruct.common.attributes.offset_attribute import OffsetAttribute + from libdestruct.backing.memory_resolver import MemoryResolver + + class S(struct): + pad: c_int + field: Annotated[c_int, OffsetAttribute(4), OffsetAttribute(8)] + + with self.assertRaises(ValueError): + S(MemoryResolver(bytearray(16), 0)) + + def test_offset_in_annotated_and_attribute_tuple_raises(self): + from libdestruct.common.attributes.offset_attribute import OffsetAttribute + from libdestruct.backing.memory_resolver import MemoryResolver + + class S(struct): + pad: c_int + field: Annotated[c_int, OffsetAttribute(4)] = OffsetAttribute(8) + + with self.assertRaises(ValueError): + S(MemoryResolver(bytearray(16), 0)) + + +class VLAUndefinedCountFieldTest(unittest.TestCase): + """VLA referencing a count field that doesn't exist must fail at struct definition (inflation).""" + + def test_undefined_count_field_subscript_form(self): + from libdestruct.backing.memory_resolver import MemoryResolver + + class BadVLA(struct): + n: c_int + data: array[c_int, "nonexistent_field"] + + with self.assertRaises(ValueError): + BadVLA(MemoryResolver(bytearray(16), 0)) + + def test_undefined_count_field_descriptor_form(self): + from libdestruct.backing.memory_resolver import MemoryResolver + from libdestruct.common.array.vla_of import vla_of + + class BadVLA(struct): + n: c_int + data: array = vla_of(c_int, "nonexistent_field") + + with self.assertRaises(ValueError): + BadVLA(MemoryResolver(bytearray(16), 0)) + + if __name__ == "__main__": unittest.main() diff --git a/test/scripts/tagged_union_test.py b/test/scripts/tagged_union_test.py index 86ae62c..fa7590b 100644 --- a/test/scripts/tagged_union_test.py +++ b/test/scripts/tagged_union_test.py @@ -7,7 +7,7 @@ import struct as pystruct import unittest -from libdestruct import c_float, c_int, c_long, inflater, size_of, struct +from libdestruct import c_char, c_float, c_int, c_long, inflater, size_of, struct from libdestruct.common.union import tagged_union, union, union_of @@ -223,3 +223,111 @@ class s_t(struct): pystruct.pack_into("f", 3.14) + f = c_float.from_bytes(original, endianness="big") + self.assertAlmostEqual(f.value, 3.14, places=5) + self.assertEqual(f.to_bytes(), original) + + def test_c_double_big_endian(self): + original = pystruct.pack(">d", 2.718) + d = c_double.from_bytes(original, endianness="big") + self.assertAlmostEqual(d.value, 2.718, places=3) + self.assertEqual(d.to_bytes(), original) + + +class ComparisonOperatorSafetyTest(unittest.TestCase): + """Comparison operators must not raise TypeError for incompatible obj types.""" + + def test_lt_primitive_vs_struct_returns_not_implemented(self): + """c_int < struct should return NotImplemented, not raise TypeError.""" + class s_t(struct): + x: c_int + + memory = bytearray(4) + lib = inflater(memory) + val = lib.inflate(c_int, 0) + s = lib.inflate(s_t, 0) + + result = val.__lt__(s) + self.assertIs(result, NotImplemented) + + def test_gt_primitive_vs_struct_returns_not_implemented(self): + class s_t(struct): + x: c_int + + memory = bytearray(4) + lib = inflater(memory) + val = lib.inflate(c_int, 0) + s = lib.inflate(s_t, 0) + + result = val.__gt__(s) + self.assertIs(result, NotImplemented) + + def test_le_primitive_vs_struct_returns_not_implemented(self): + class s_t(struct): + x: c_int + + memory = bytearray(4) + lib = inflater(memory) + val = lib.inflate(c_int, 0) + s = lib.inflate(s_t, 0) + + result = val.__le__(s) + self.assertIs(result, NotImplemented) + + def test_ge_primitive_vs_struct_returns_not_implemented(self): + class s_t(struct): + x: c_int + + memory = bytearray(4) + lib = inflater(memory) + val = lib.inflate(c_int, 0) + s = lib.inflate(s_t, 0) + + result = val.__ge__(s) + self.assertIs(result, NotImplemented) + + def test_eq_primitive_vs_struct_returns_not_implemented(self): + class s_t(struct): + x: c_int + + memory = bytearray(4) + lib = inflater(memory) + val = lib.inflate(c_int, 0) + s = lib.inflate(s_t, 0) + + result = val.__eq__(s) + self.assertIs(result, NotImplemented) + + def test_ne_primitive_vs_struct_returns_not_implemented(self): + class s_t(struct): + x: c_int + + memory = bytearray(4) + lib = inflater(memory) + val = lib.inflate(c_int, 0) + s = lib.inflate(s_t, 0) + + result = val.__ne__(s) + self.assertIs(result, NotImplemented) + + def test_lt_between_compatible_primitives_works(self): + """Comparisons between compatible primitives should still work.""" + memory = bytearray(8) + lib = inflater(memory) + a = lib.inflate(c_int, 0) + b = lib.inflate(c_int, 4) + a.value = 1 + b.value = 2 + + self.assertTrue(a < b) + self.assertFalse(b < a) + + def test_comparison_with_raw_int(self): + memory = bytearray(4) + lib = inflater(memory) + a = lib.inflate(c_int, 0) + a.value = 5 + + self.assertTrue(a < 10) + self.assertTrue(a > 2) + self.assertTrue(a <= 5) + self.assertTrue(a >= 5) + + +class PtrUnwrapLengthZeroTest(unittest.TestCase): + """ptr.unwrap(0) must read 0 bytes, not 1.""" + + def test_unwrap_length_zero_returns_empty(self): + """unwrap(0) should return 0 bytes, not 1 byte.""" + memory = bytearray(16) + memory[0:8] = (8).to_bytes(8, "little") + memory[8] = 0xAB + + p = ptr(MemoryResolver(memory, 0)) + + result = p.unwrap(0) + self.assertEqual(len(result), 0) + self.assertEqual(result, b"") + + def test_unwrap_length_none_returns_one_byte(self): + """unwrap() (default None) should still return 1 byte.""" + memory = bytearray(16) + memory[0:8] = (8).to_bytes(8, "little") + memory[8] = 0xAB + + p = ptr(MemoryResolver(memory, 0)) + + result = p.unwrap() + self.assertEqual(len(result), 1) + self.assertEqual(result, bytes([0xAB])) + + def test_try_unwrap_length_zero(self): + """try_unwrap(0) should also return empty bytes, not 1 byte.""" + memory = bytearray(16) + memory[0:8] = (8).to_bytes(8, "little") + memory[8] = 0xAB + + p = ptr(MemoryResolver(memory, 0)) + + result = p.try_unwrap(0) + self.assertIsNotNone(result) + self.assertEqual(len(result), 0) + + +class TypeRegistryDeduplicationTest(unittest.TestCase): + """Repeated handler registration must not accumulate duplicates.""" + + def test_generic_handler_not_duplicated(self): + """Registering the same handler twice must not produce duplicate entries.""" + registry = TypeRegistry() + + class DummyType: + pass + + def dummy_handler(item, args, owner): + return None + + initial_count = len(registry.generic_handlers.get(DummyType, [])) + + registry.register_generic_handler(DummyType, dummy_handler) + registry.register_generic_handler(DummyType, dummy_handler) + + count = len(registry.generic_handlers[DummyType]) + self.assertEqual(count, initial_count + 1) + + def test_instance_handler_not_duplicated(self): + """Registering the same instance handler twice must not produce duplicate entries.""" + registry = TypeRegistry() + + class DummyField: + pass + + def dummy_handler(item, annotation, owner): + return None + + initial_count = len(registry.instance_handlers.get(DummyField, [])) + + registry.register_instance_handler(DummyField, dummy_handler) + registry.register_instance_handler(DummyField, dummy_handler) + + count = len(registry.instance_handlers[DummyField]) + self.assertEqual(count, initial_count + 1) + + def test_type_handler_not_duplicated(self): + """Registering the same type handler twice must not produce duplicate entries.""" + registry = TypeRegistry() + + class DummyParent: + pass + + def dummy_handler(item): + return None + + initial_count = len(registry.type_handlers.get(DummyParent, [])) + + registry.register_type_handler(DummyParent, dummy_handler) + registry.register_type_handler(DummyParent, dummy_handler) + + count = len(registry.type_handlers[DummyParent]) + self.assertEqual(count, initial_count + 1) + + +class PtrCacheStalenessTest(unittest.TestCase): + """ptr.unwrap() must observe address changes that happen via memory mutation, not just _set().""" + + def test_unwrap_returns_fresh_view_when_address_bytes_change(self): + class Inner(struct): + val: c_int + + class Outer(struct): + p: c_long # raw 8-byte address; we'll wrap with a ptr + + # Layout: 8 bytes ptr value, then two c_int payloads at 8 and 16 + memory = bytearray(20) + pystruct.pack_into("