diff --git a/src/easyscience/base_classes/__init__.py b/src/easyscience/base_classes/__init__.py index b5dc0418..8fd4aab2 100644 --- a/src/easyscience/base_classes/__init__.py +++ b/src/easyscience/base_classes/__init__.py @@ -3,9 +3,10 @@ from .based_base import BasedBase from .collection_base import CollectionBase +from .easy_collection import EasyCollection from .easy_list import EasyList from .model_base import ModelBase from .new_base import NewBase from .obj_base import ObjBase -__all__ = [BasedBase, CollectionBase, ObjBase, ModelBase, NewBase, EasyList] +__all__ = [BasedBase, CollectionBase, ObjBase, ModelBase, NewBase, EasyList, EasyCollection] diff --git a/src/easyscience/base_classes/easy_collection.py b/src/easyscience/base_classes/easy_collection.py new file mode 100644 index 00000000..af3a9bf0 --- /dev/null +++ b/src/easyscience/base_classes/easy_collection.py @@ -0,0 +1,402 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import copy +import warnings +from importlib import import_module +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Type +from typing import cast +from typing import overload + +from easyscience.io.serializer_base import SerializerBase + +from .easy_list import EasyList +from .new_base import NewBase + +CollectionItem = NewBase + + +class EasyCollection(EasyList[CollectionItem]): + """Collection built on :class:`EasyList` for ``NewBase`` objects. + + This class keeps the list storage and ``NewBase`` inheritance from + ``EasyList`` while adding parent-child edges in the global object map. + """ + + _DEFAULT_PROTECTED_TYPES = [NewBase] + + def __init__( + self, + *args: CollectionItem | list[CollectionItem], + protected_types: list[Type[NewBase]] | Type[NewBase] | None = None, + unique_name: Optional[str] = None, + display_name: Optional[str] = None, + ): + """Initialize the collection. + + :param args: Initial collection items. + :param protected_types: Types allowed in the collection. + :param unique_name: Optional unique name for the collection. + :param display_name: Optional display name for the collection. + """ + self.user_data: dict = {} + self._protected_types_explicit = protected_types is not None + + super().__init__(unique_name=unique_name, display_name=display_name) + self._protected_types = self._normalize_protected_types(protected_types) + + for item in self._flatten_items(args): + self.append(item) + + @overload + def __getitem__(self, idx: int) -> CollectionItem: ... + @overload + def __getitem__(self, idx: slice) -> EasyCollection: ... + @overload + def __getitem__(self, idx: str) -> CollectionItem: ... + def __getitem__(self, idx: int | slice | str) -> CollectionItem | EasyCollection: + """Get an item by index, slice, or unique name. + + String lookup returns the single item with the matching ``unique_name``. + Duplicate ``unique_name`` values are rejected at insertion time, so a + successful lookup always resolves to exactly one item. Slice lookup + returns a new collection that references the same item objects and + registers them as children of the sliced collection in the global map. + """ + if isinstance(idx, int): + return self._data[idx] + if isinstance(idx, slice): + return self._new_like(self._data[idx]) + if isinstance(idx, str): + for item in self._data: + if self._get_key(item) == idx: + return item + raise KeyError(f'No item with unique name "{idx}" found') + raise TypeError('Index must be an int, slice, or str') + + @overload + def __setitem__(self, idx: int, value: CollectionItem) -> None: ... + @overload + def __setitem__(self, idx: slice, value: Iterable[CollectionItem]) -> None: ... + def __setitem__( + self, idx: int | slice, value: CollectionItem | Iterable[CollectionItem] + ) -> None: + """Set collection items and keep graph state synchronized.""" + if isinstance(idx, int): + # cast(CollectionItem, value) tells the type checker: + # “for this branch, treat value as one item.” + self._set_single_item(idx, cast(CollectionItem, value)) + return + if isinstance(idx, slice): + if not isinstance(value, Iterable): + raise TypeError('Value must be an iterable for slice assignment') + self._set_slice(idx, value) + return + raise TypeError('Index must be an int or slice') + + def __delitem__(self, idx: int | slice | str) -> None: + """Delete collection items and prune graph edges.""" + if isinstance(idx, int): + item = self._data[idx] + self._prune_child_relation(item) + del self._data[idx] + return + if isinstance(idx, slice): + for item in self._data[idx]: + self._prune_child_relation(item) + del self._data[idx] + return + if isinstance(idx, str): + # Find the matching unique-name entry so missing string keys raise KeyError. + for i, item in enumerate(self._data): + if self._get_key(item) == idx: + self._prune_child_relation(item) + del self._data[i] + return + raise KeyError(f'No item with unique name "{idx}" found') + raise TypeError('Index must be an int, slice, or str') + + def insert(self, index: int, value: CollectionItem) -> None: + """Insert an item and register it as a child of the + collection. + """ + if not isinstance(index, int): + raise TypeError('Index must be an integer') + self._validate_item(value) + if self._contains_key(self._get_key(value)): + warnings.warn( + f'Item with unique name "{self._get_key(value)}" already in EasyCollection, it will be ignored' + ) + return + self._data.insert(index, value) + self._add_child_relation(value) + + def pop(self, index: int | str = -1) -> CollectionItem: + """Remove and return an item by index or unique name. + + Extends :class:`collections.abc.MutableSequence`'s int-only ``pop`` + signature to also accept a ``unique_name`` string for symmetry with + :meth:`__getitem__` and :meth:`__delitem__`. + """ + if isinstance(index, int): + item = self._data[index] + del self[index] + return item + if isinstance(index, str): + # Find the matching unique-name entry so missing string keys raise KeyError. + for i, item in enumerate(self._data): + if self._get_key(item) == index: + del self[i] + return item + raise KeyError(f'No item with unique name "{index}" found') + raise TypeError('Index must be an int or str') + + def __contains__(self, item: CollectionItem | str) -> bool: + """Return whether an item object or unique name exists in the + collection. + """ + if isinstance(item, str): + return self._contains_key(item) + return item in self._data + + def index(self, value: CollectionItem | str, start: int = 0, stop: int | None = None) -> int: + """Return the index of an item object or the first item with the + given unique name. + """ + if isinstance(value, str): + start_index, stop_index, _ = slice(start, stop).indices(len(self._data)) + for i in range(start_index, stop_index): + item = self._data[i] + if self._get_key(item) == value: + return i + raise ValueError(f'{value} is not in EasyCollection') + return self._data.index(value, start, stop) + + def to_dict(self, skip: Optional[List[str]] = None) -> Dict[str, Any]: + """Convert the collection to a serialized dictionary.""" + if skip is None: + skip = [] + dict_repr = self._metadata_dict() + if not self._default_unique_name and 'unique_name' not in skip: + dict_repr['unique_name'] = self.unique_name + if self._display_name is not None and 'display_name' not in skip: + dict_repr['display_name'] = self._display_name + if self._protected_types_explicit and 'protected_types' not in skip: + dict_repr['protected_types'] = [ + {'@module': cls_.__module__, '@class': cls_.__name__} + for cls_ in self._protected_types + ] + dict_repr['data'] = [self._item_to_dict(item, skip=skip) for item in self._data] + return dict_repr + + def as_dict(self, skip: Optional[List[str]] = None) -> Dict[str, Any]: + """Compatibility alias for legacy callers.""" + return self.to_dict(skip=skip) + + @classmethod + def from_dict(cls, obj_dict: Dict[str, Any]) -> EasyCollection: + """Create an ``EasyCollection`` from a serialized dictionary. + + The payload's ``@class`` field must match ``cls.__name__`` exactly; + subclass payloads are not accepted through a parent class. Dispatch to + the correct concrete subclass via :class:`SerializerBase` (or call + ``from_dict`` on the matching subclass directly) before invoking this. + """ + if not SerializerBase._is_serialized_easyscience_object(obj_dict): + raise ValueError( + 'Input must be a dictionary representing an EasyScience EasyCollection object.' + ) + temp_dict = copy.deepcopy(obj_dict) + if temp_dict['@class'] != cls.__name__: + raise ValueError( + f'Class name in dictionary does not match the expected class: {cls.__name__}.' + ) + + protected_types = cls._deserialize_protected_types(temp_dict.pop('protected_types', None)) + kwargs = SerializerBase.deserialize_dict(temp_dict) + data = kwargs.pop('data', []) + return cls(*data, protected_types=protected_types, **kwargs) + + def _new_like(self, data: Iterable[CollectionItem]) -> EasyCollection: + """Create a same-class collection for slice and duplicate-name + results. + """ + return self.__class__( + list(data), + protected_types=self._protected_types, + display_name=self._display_name, + ) + + @classmethod + def _normalize_protected_types( + cls, protected_types: list[Type[NewBase]] | Type[NewBase] | None + ) -> list[Type[NewBase]]: + """Return protected types as a validated list of ``NewBase`` + subclasses. + """ + if protected_types is None: + return list(cls._DEFAULT_PROTECTED_TYPES) + if isinstance(protected_types, type) and issubclass(protected_types, NewBase): + protected_types = [protected_types] + elif isinstance(protected_types, Iterable) and all( + issubclass(t, NewBase) for t in protected_types + ): + protected_types = list(protected_types) + else: + raise TypeError( + 'protected_types must be a NewBase subclass or an iterable of NewBase subclasses' + ) + return protected_types + + @staticmethod + def _flatten_items(args: tuple[Any, ...]) -> list[CollectionItem]: + """Flatten positional item lists into the sequence inserted into + the collection. + """ + items = [] + for item in args: + if isinstance(item, list): + items.extend(item) + else: + items.append(item) + return items + + def _validate_item(self, value: CollectionItem) -> None: + """Raise if the value is not one of the configured protected + types. + """ + if not isinstance(value, tuple(self._protected_types)): + raise TypeError(f'Items must be one of {self._protected_types}, got {type(value)}') + + def _contains_key(self, key: str) -> bool: + """Return whether the collection contains an item with the given + unique name. + """ + return any(self._get_key(item) == key for item in self._data) + + def _set_single_item(self, idx: int, value: CollectionItem) -> None: + """Replace one item while preserving unique-name and graph + invariants. + """ + self._validate_item(value) + old_item = self._data[idx] + value_key = self._get_key(value) + if value is not old_item and any( + self._get_key(item) == value_key for item in self._data if item is not old_item + ): + # Warn if the new item has the same unique name as another existing item (other than the one being replaced) + # and skip the update to avoid duplicate keys. + # Or should we raise here instead? + warnings.warn( + f'Item with unique name "{value_key}" already in EasyCollection, it will be ignored' + ) + return + self._prune_child_relation(old_item) + self._data[idx] = value + self._add_child_relation(value) + + def _set_slice(self, idx: slice, value: Iterable[CollectionItem]) -> None: + """Replace a slice while preserving length, unique-name, and + graph invariants. + """ + replaced = self._data[idx] + new_values = list(value) + if len(new_values) != len(replaced): + raise ValueError( + 'Length of new values must match the length of the slice being replaced' + ) + for new_value in new_values: + self._validate_item(new_value) + + existing_keys = { + self._get_key(item) + for item in self._data + if all(item is not replaced_item for replaced_item in replaced) + } + seen_batch_keys: set[str] = set() + # Track unique names already accepted from ``new_values`` so that + # passing the same item (or two items sharing a unique name) inside one + # slice assignment is rejected the same way as collisions with items + # outside the slice. + for position, new_value in enumerate(new_values): + key = self._get_key(new_value) + if key in existing_keys or key in seen_batch_keys: + warnings.warn( + f'Item with unique name "{key}" already in EasyCollection, it will be ignored' + ) + new_values[position] = replaced[position] + continue + seen_batch_keys.add(key) + + for old_item in replaced: + self._prune_child_relation(old_item) + self._data[idx] = new_values + for new_value in new_values: + self._add_child_relation(new_value) + + def _add_child_relation(self, value: CollectionItem) -> None: + """Register a collection-child edge in the global object map.""" + # ``get_edges`` returns the list of child ``unique_name`` strings, which + # is the same key space as ``_get_key`` so the membership check is direct. + edges = self._global_object.map.get_edges(self) + if self._get_key(value) not in edges: + self._global_object.map.add_edge(self, value) + self._global_object.map.reset_type(value, 'created_internal') + + def _prune_child_relation(self, value: CollectionItem) -> None: + """Remove the collection-child edge for a removed item.""" + self._global_object.map.prune_vertex_from_edge(self, value) + + def _metadata_dict(self) -> Dict[str, Any]: + """Return serialization metadata for the concrete collection + class. + """ + dict_repr: Dict[str, Any] = {'@module': self.__module__, '@class': self.__class__.__name__} + try: + module_version = import_module('easyscience').__version__ + dict_repr['@version'] = f'{module_version}' + except (AttributeError, ImportError): + dict_repr['@version'] = None + return dict_repr + + @staticmethod + def _item_to_dict(item: CollectionItem, skip: List[str]) -> Any: + """Serialize one collection item using the best available + EasyScience encoder. + """ + if hasattr(item, 'to_dict'): + return item.to_dict(skip=skip) + as_dict = getattr(item, 'as_dict', None) + if callable(as_dict): + return as_dict(skip=skip) + return SerializerBase()._recursive_encoder(item, skip=skip) + + @staticmethod + def _deserialize_protected_types( + protected_types: list[dict[str, str]] | None, + ) -> list[Type] | None: + """Convert serialized protected-type metadata back into Python + classes. + """ + if protected_types is None: + return None + deserialized_types = [] + for type_dict in protected_types: + if '@module' not in type_dict or '@class' not in type_dict: + raise ValueError( + 'Each protected type must be a serialized EasyScience class with @module and @class keys' + ) + mod = __import__(type_dict['@module'], globals(), locals(), [type_dict['@class']], 0) + if not hasattr(mod, type_dict['@class']): + raise ImportError( + f'Could not import class {type_dict["@class"]} from module {type_dict["@module"]}' + ) + deserialized_types.append(getattr(mod, type_dict['@class'])) + return deserialized_types diff --git a/tests/unit/base_classes/test_easy_collection.py b/tests/unit/base_classes/test_easy_collection.py new file mode 100644 index 00000000..b3f557b3 --- /dev/null +++ b/tests/unit/base_classes/test_easy_collection.py @@ -0,0 +1,245 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from easyscience import global_object +from easyscience.base_classes import EasyCollection +from easyscience.base_classes import NewBase + + +class Alpha(NewBase): + def __init__(self, unique_name=None, display_name=None): + super().__init__(unique_name=unique_name, display_name=display_name) + + +class TestEasyCollection: + @pytest.fixture(autouse=True) + def clear(self): + global_object.map._clear() + + def test_init_accepts_newbase_items_by_default(self): + alpha = Alpha(unique_name='alpha') + + collection = EasyCollection(alpha, display_name='collection') + + assert collection.display_name == 'collection' + assert collection[0] is alpha + + def test_init_flattens_positional_list_args(self): + first = Alpha(unique_name='first') + second = Alpha(unique_name='second') + third = Alpha(unique_name='third') + + collection = EasyCollection([first, second], third) + + assert list(collection) == [first, second, third] + + def test_protected_types_are_enforced(self): + collection = EasyCollection(protected_types=Alpha) + + collection.append(Alpha(unique_name='alpha')) + + with pytest.raises(TypeError, match='Items must be one of'): + collection.append(NewBase(unique_name='new-base')) + + def test_graph_edges_are_applied_on_append_insert_and_replace(self): + first = Alpha(unique_name='first') + second = Alpha(unique_name='second') + replacement = Alpha(unique_name='replacement') + collection = EasyCollection(first) + + collection.insert(0, second) + collection[1] = replacement + + assert first.unique_name not in global_object.map.get_edges(collection) + assert second.unique_name in global_object.map.get_edges(collection) + assert replacement.unique_name in global_object.map.get_edges(collection) + + def test_delete_and_pop_prune_graph_edges(self): + first = Alpha(unique_name='first') + second = Alpha(unique_name='second') + collection = EasyCollection(first, second) + + del collection[0] + popped = collection.pop('second') + + assert popped is second + assert first.unique_name not in global_object.map.get_edges(collection) + assert second.unique_name not in global_object.map.get_edges(collection) + assert len(collection) == 0 + + def test_string_lookup_uses_unique_name(self): + first = Alpha(unique_name='first') + second = Alpha(unique_name='second') + collection = EasyCollection(first, second) + + assert collection['second'] is second + with pytest.raises(KeyError, match='No item with unique name'): + collection['missing'] + + def test_missing_unique_name_removal_raises_key_error(self): + collection = EasyCollection(Alpha(unique_name='first')) + + with pytest.raises(KeyError, match='No item with unique name'): + del collection['missing'] + + with pytest.raises(KeyError, match='No item with unique name'): + collection.pop('missing') + + def test_slice_preserves_collection_metadata(self): + first = Alpha(unique_name='first') + second = Alpha(unique_name='second') + collection = EasyCollection(first, second, display_name='collection') + + sliced = collection[:1] + + assert isinstance(sliced, EasyCollection) + assert sliced.display_name == collection.display_name + assert sliced[0] is first + assert list(sliced) == [first] + assert first.unique_name in global_object.map.get_edges(sliced) + + def test_index_by_unique_name_uses_python_slice_bounds(self): + first = Alpha(unique_name='first') + second = Alpha(unique_name='second') + third = Alpha(unique_name='third') + collection = EasyCollection(first, second, third) + + assert collection.index('second', 0, -1) == 1 + assert collection.index('second', -2) == 1 + with pytest.raises(ValueError, match='third is not in EasyCollection'): + collection.index('third', 0, -1) + + def test_index_by_unique_name_rejects_non_int_bounds(self): + first = Alpha(unique_name='first') + second = Alpha(unique_name='second') + collection = EasyCollection(first, second) + + with pytest.raises(TypeError, match='slice indices'): + collection.index('second', 0.0) + with pytest.raises(TypeError, match='slice indices'): + collection.index('second', 0, 2.0) + + def test_to_dict_and_from_dict_round_trip(self): + first = Alpha(unique_name='first') + second = Alpha(unique_name='second') + collection = EasyCollection( + first, second, unique_name='collection_key', protected_types=Alpha + ) + + collection_dict = collection.to_dict() + global_object.map._clear() + deserialized = EasyCollection.from_dict(collection_dict) + + assert deserialized.unique_name == 'collection_key' + assert [item.unique_name for item in deserialized] == ['first', 'second'] + assert deserialized.to_dict() == collection_dict + + def test_as_dict_is_compatibility_alias(self): + collection = EasyCollection(Alpha(unique_name='alpha')) + + assert collection.as_dict() == collection.to_dict() + + def test_empty_collection_round_trip_does_not_repopulate(self): + collection = EasyCollection() + + collection_dict = collection.to_dict() + global_object.map._clear() + deserialized = EasyCollection.from_dict(collection_dict) + + assert len(deserialized) == 0 + assert list(deserialized) == [] + assert deserialized.to_dict() == collection_dict + + def test_slice_assignment_rejects_within_batch_duplicates(self): + first = Alpha(unique_name='first') + second = Alpha(unique_name='second') + duplicate = Alpha(unique_name='dup') + collection = EasyCollection(first, second) + + with pytest.warns(UserWarning, match='already in EasyCollection'): + collection[0:2] = [duplicate, duplicate] + + # Second slot should fall back to the original ``second`` because the + # batch already contained ``duplicate``. + assert collection[0] is duplicate + assert collection[1] is second + + def test_getitem_rejects_unsupported_type(self): + collection = EasyCollection(Alpha(unique_name='first')) + + with pytest.raises(TypeError, match='Index must be an int, slice, or str'): + _ = collection[1.0] + + def test_setitem_rejects_non_iterable_slice_assignment(self): + collection = EasyCollection(Alpha(unique_name='first')) + + with pytest.raises(TypeError, match='Value must be an iterable for slice assignment'): + collection[0:1] = Alpha(unique_name='other') # type: ignore[arg-type] + + def test_delitem_rejects_invalid_type(self): + collection = EasyCollection(Alpha(unique_name='first')) + + with pytest.raises(TypeError, match='Index must be an int, slice, or str'): + del collection[1.0] + + def test_insert_rejects_non_integer_index(self): + collection = EasyCollection(Alpha(unique_name='first')) + + with pytest.raises(TypeError, match='Index must be an integer'): + collection.insert('0', Alpha(unique_name='second')) # type: ignore[arg-type] + + def test_pop_rejects_invalid_type(self): + collection = EasyCollection(Alpha(unique_name='first')) + + with pytest.raises(TypeError, match='Index must be an int or str'): + collection.pop(1.0) # type: ignore[arg-type] + + def test_contains_checks_unique_name_and_item_object(self): + first = Alpha(unique_name='first') + second = Alpha(unique_name='second') + collection = EasyCollection(first, second) + + assert 'first' in collection + assert first in collection + assert 'missing' not in collection + + def test_to_dict_respects_skip_keys(self): + collection = EasyCollection( + Alpha(unique_name='alpha'), + unique_name='collection_key', + protected_types=Alpha, + ) + + collection_dict = collection.to_dict(skip=['unique_name', 'protected_types']) + + assert 'unique_name' not in collection_dict + assert 'protected_types' not in collection_dict + assert 'data' in collection_dict + + def test_from_dict_raises_for_invalid_serialized_dict(self): + with pytest.raises( + ValueError, + match='Input must be a dictionary representing an EasyScience EasyCollection object.', + ): + EasyCollection.from_dict({'foo': 'bar'}) + + def test_from_dict_raises_for_mismatched_class(self): + collection = EasyCollection(Alpha(unique_name='alpha')) + collection_dict = collection.to_dict() + collection_dict['@class'] = 'WrongClass' + + global_object.map._clear() + with pytest.raises( + ValueError, + match='Class name in dictionary does not match the expected class: EasyCollection.', + ): + EasyCollection.from_dict(collection_dict) + + def test_protected_types_accepts_iterable_types(self): + collection = EasyCollection(protected_types=[Alpha]) + + collection.append(Alpha(unique_name='alpha')) + with pytest.raises(TypeError, match='Items must be one of'): + collection.append(NewBase(unique_name='new-base'))