diff --git a/frontend/src/components/Chat/ConverterPanel.test.tsx b/frontend/src/components/Chat/ConverterPanel.test.tsx index 7c657eebdc..99922b4408 100644 --- a/frontend/src/components/Chat/ConverterPanel.test.tsx +++ b/frontend/src/components/Chat/ConverterPanel.test.tsx @@ -149,6 +149,29 @@ describe('ConverterPanel loading', () => { renderPanel() await waitFor(() => expect(screen.getByTestId('converter-panel-empty')).toBeInTheDocument()) }) + + it('hides base/helper converters that should not be offered', async () => { + const catalogWithHidden = { + items: [ + ...MOCK_CATALOG.items, + { + converter_type: 'SelectiveTextConverter', + supported_input_types: ['text'], + supported_output_types: ['text'], + parameters: [], + is_llm_based: false, + description: 'Base/helper converter.', + }, + ], + } + mockedConvertersApi.listConverterCatalog.mockResolvedValueOnce(catalogWithHidden as ConverterCatalogResponse) + renderPanel() + await waitForList() + + fireEvent.click(getComboboxInput()) + await waitFor(() => expect(screen.getByTestId('converter-option-Base64Converter')).toBeInTheDocument()) + expect(screen.queryByTestId('converter-option-SelectiveTextConverter')).not.toBeInTheDocument() + }) }) // ─── Close button ──────────────────────────────────────────────── diff --git a/frontend/src/components/Chat/ConverterPanel/ConverterPanel.tsx b/frontend/src/components/Chat/ConverterPanel/ConverterPanel.tsx index d212e97e94..a079202a84 100644 --- a/frontend/src/components/Chat/ConverterPanel/ConverterPanel.tsx +++ b/frontend/src/components/Chat/ConverterPanel/ConverterPanel.tsx @@ -18,6 +18,10 @@ const PIECE_TYPE_LABELS: Record = { video: 'Video', } +// Converter classes the backend can build but that aren't useful to offer in the +// picker (base/helper classes). +const HIDDEN_CONVERTER_TYPES = new Set(['SelectiveTextConverter']) + interface ConverterPanelProps { onClose: () => void previewText?: string @@ -51,7 +55,7 @@ export default function ConverterPanel({ onClose, previewText = '', attachmentDa try { const response = await convertersApi.listConverterCatalog() - setConverters(response.items) + setConverters(response.items.filter((c) => !HIDDEN_CONVERTER_TYPES.has(c.converter_type))) } catch (err) { setConverters([]) setSelectedConverterType('') diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 8bd6199592..4300ad7f16 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -15,14 +15,12 @@ import base64 import inspect import mimetypes -import re import uuid from functools import lru_cache from pathlib import Path from typing import Any, Literal, Union, get_args, get_origin from urllib.parse import parse_qs, urlparse -from pyrit import prompt_converter from pyrit.backend.mappers.converter_mappers import converter_object_to_instance from pyrit.backend.models.converters import ( ConverterCatalogEntry, @@ -38,9 +36,12 @@ ) from pyrit.memory import data_serializer_factory from pyrit.models import PromptDataType -from pyrit.prompt_converter import PromptConverter -from pyrit.prompt_target import PromptTarget -from pyrit.registry.object_registries import ConverterRegistry + +# ``get_union_non_none_args`` is a general type-introspection utility used here to +# render parameter types for the catalog (a presentation concern owned by this +# service). +from pyrit.registry.components import ConverterParameterMetadata, ConverterRegistry +from pyrit.registry.resolution import get_union_non_none_args _DATA_TYPE_EXTENSION: dict[str, str] = { "image_path": ".png", @@ -50,169 +51,31 @@ } -def _build_converter_class_registry() -> dict[str, type]: - """ - Build a registry mapping converter class names to their classes. - - Uses the prompt_converter module's __all__ to discover all available converters. - - Returns: - Dict mapping class name (str) to class (type). - """ - registry: dict[str, type] = {} - for name in prompt_converter.__all__: - cls = getattr(prompt_converter, name, None) - if cls is not None and isinstance(cls, type) and issubclass(cls, PromptConverter): - registry[name] = cls - return registry - - -# Module-level class registry (built once on import) -_CONVERTER_CLASS_REGISTRY: dict[str, type] = _build_converter_class_registry() - -# Types that can be rendered as simple form fields -_SIMPLE_TYPES: set[type] = {str, int, float, bool} - - -def _is_simple_type(annotation: Any) -> bool: - """Return True if the annotation represents a type renderable in a form field.""" - if annotation in _SIMPLE_TYPES: - return True - origin = get_origin(annotation) - if origin is Literal: - return True - if origin is Union: - args = get_args(annotation) - non_none = [a for a in args if a is not type(None)] - return len(non_none) == 1 and _is_simple_type(non_none[0]) - return False - - def _serialize_type(annotation: Any) -> str: """ - Convert a type annotation to a concise human-readable string. + Render a parameter's type annotation as a concise human-readable string. + + Used to populate the catalog DTO consumed by the frontend (e.g. ``"str"``, + ``"Optional[int]"``, ``"Literal['a', 'b']"``). Returns: str: A human-readable representation of the type annotation. """ if annotation is inspect.Parameter.empty: return "Any" - origin = get_origin(annotation) - if origin is Literal: + if get_origin(annotation) is Literal: args = get_args(annotation) return f"Literal[{', '.join(repr(a) for a in args)}]" - if origin is Union: - args = get_args(annotation) - non_none = [a for a in args if a is not type(None)] - if len(non_none) == 1: - inner = _serialize_type(non_none[0]) - return f"Optional[{inner}]" if len(args) > len(non_none) else inner + non_none = get_union_non_none_args(annotation) + if non_none is not None and len(non_none) == 1: + inner = _serialize_type(non_none[0]) + has_none = type(None) in get_args(annotation) + return f"Optional[{inner}]" if has_none else inner if hasattr(annotation, "__name__"): return str(annotation.__name__) return str(annotation) -def _parse_arg_descriptions(converter_class: type) -> dict[str, str]: - """ - Parse parameter descriptions from Google-style docstring Args section. - - Returns: - dict[str, str]: Mapping of parameter names to their descriptions. - """ - doc = (converter_class.__init__.__doc__ or converter_class.__doc__ or "").strip() - match = re.search(r"Args:\s*\n(.*?)(?:\n\s*\n|\n\s*Returns:|\n\s*Raises:|\Z)", doc, re.DOTALL) - if not match: - return {} - args_block = match.group(1) - # Detect indentation of first parameter line - indent_match = re.match(r"^(\s+)", args_block) - indent = indent_match.group(1) if indent_match else r"\s+" - pattern = rf"^{indent}(\w+)\s*(?:\([^)]*\))?\s*:\s*(.+?)(?=\n{indent}\w|\Z)" - descriptions: dict[str, str] = {} - for m in re.finditer(pattern, args_block, re.DOTALL | re.MULTILINE): - descriptions[m.group(1)] = " ".join(m.group(2).split()) - return descriptions - - -def _extract_parameters(converter_class: type) -> list[ConverterParameterSchema]: - """ - Extract simple constructor parameters from a converter class. - - Returns: - list[ConverterParameterSchema]: List of parameter schemas. - """ - try: - sig = inspect.signature(converter_class.__init__) - except (ValueError, TypeError): - return [] - - arg_descriptions = _parse_arg_descriptions(converter_class) - - params: list[ConverterParameterSchema] = [] - for name, p in sig.parameters.items(): - if name in ("self", "args", "kwargs"): - continue - if p.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): - continue - if not _is_simple_type(p.annotation): - continue - - no_default = p.default is inspect.Parameter.empty - is_sentinel = hasattr(p.default, "__class__") and "Sentinel" in type(p.default).__name__ - required = no_default or is_sentinel - - default_value: str | None = None - if not required and p.default is not None: - default_value = str(p.default) - - choices: list[str] | None = None - if get_origin(p.annotation) is Literal: - choices = [str(a) for a in get_args(p.annotation)] - - params.append( - ConverterParameterSchema( - name=name, - type_name=_serialize_type(p.annotation), - required=required, - default_value=default_value, - choices=choices, - description=arg_descriptions.get(name), - ) - ) - - return params - - -def _is_llm_based(converter_class: type) -> bool: - """ - Check if the converter requires a target parameter. - - Matches any converter whose ``__init__`` accepts - a ``PromptTarget`` (or subclass) parameter. - These converters perform LLM-based transformations and should not automatically be applied - - Returns: - bool: True if the converter is LLM-based, False otherwise. - """ - try: - sig = inspect.signature(converter_class.__init__) - except (ValueError, TypeError): - return False - - for name, p in sig.parameters.items(): - if name == "self": - continue - ann = p.annotation - if ann is inspect.Parameter.empty: - continue - try: - if isinstance(ann, type) and issubclass(ann, PromptTarget): - return True - except TypeError: - continue - return False - - class ConverterService: """ Service for managing converter instances. @@ -249,49 +112,55 @@ async def list_converters_async(self) -> ConverterInstanceListResponse: """ items = [ self._build_instance_from_object(converter_id=entry.name, converter_obj=entry.instance) - for entry in self._registry.get_all_instances() + for entry in self._registry.instances.get_all_instances() ] return ConverterInstanceListResponse(items=items) async def list_converter_catalog_async(self) -> ConverterCatalogResponse: """ - List all available converter types from the backend converter registry. + List all available converter types from the converter class registry. + + Returns every constructible converter. Deciding which entries to surface + to a user is a presentation concern owned by the caller (e.g. the + frontend), not this service. Returns: ConverterCatalogResponse containing all available converter classes. """ - items: list[ConverterCatalogEntry] = [] - for converter_type, converter_class in sorted(_CONVERTER_CLASS_REGISTRY.items()): - if ( - converter_type in ("PromptConverter", "ConverterResult", "SelectiveTextConverter") - or "Strategy" in converter_type - ): - continue - - supported_input_types = [ - str(data_type) for data_type in getattr(converter_class, "SUPPORTED_INPUT_TYPES", ()) - ] - supported_output_types = [ - str(data_type) for data_type in getattr(converter_class, "SUPPORTED_OUTPUT_TYPES", ()) - ] - - # Extract first paragraph of docstring as description - raw_doc = (converter_class.__doc__ or "").strip() - description = raw_doc.split("\n\n")[0].replace("\n", " ").strip() or None - - items.append( - ConverterCatalogEntry( - converter_type=converter_type, - supported_input_types=supported_input_types, - supported_output_types=supported_output_types, - parameters=_extract_parameters(converter_class), - is_llm_based=_is_llm_based(converter_class), - description=description, - ) + items: list[ConverterCatalogEntry] = [ + ConverterCatalogEntry( + converter_type=metadata.class_name, + supported_input_types=list(metadata.supported_input_types), + supported_output_types=list(metadata.supported_output_types), + parameters=[self._build_parameter_schema(p) for p in metadata.parameters if p.coercible_from_string], + is_llm_based=metadata.is_llm_based, + description=metadata.class_description or None, ) + for metadata in self._registry.list_class_metadata() + ] return ConverterCatalogResponse(items=items) + @staticmethod + def _build_parameter_schema(parameter: ConverterParameterMetadata) -> ConverterParameterSchema: + """ + Map registry parameter metadata to the catalog DTO. + + Renders the raw annotation to a human-readable ``type_name`` for the + frontend (presentation concern owned by this service). + + Returns: + ConverterParameterSchema: The parameter schema for the catalog entry. + """ + return ConverterParameterSchema( + name=parameter.name, + type_name=_serialize_type(parameter.annotation), + required=parameter.required, + default_value=parameter.default_value, + choices=list(parameter.choices) if parameter.choices is not None else None, + description=parameter.description, + ) + async def get_converter_async(self, *, converter_id: str) -> ConverterInstance | None: """ Get a converter instance by ID. @@ -299,7 +168,7 @@ async def get_converter_async(self, *, converter_id: str) -> ConverterInstance | Returns: ConverterInstance if found, None otherwise. """ - obj = self._registry.get_instance_by_name(converter_id) + obj = self._registry.instances.get(converter_id) if obj is None: return None return self._build_instance_from_object(converter_id=converter_id, converter_obj=obj) @@ -311,7 +180,7 @@ def get_converter_object(self, *, converter_id: str) -> Any | None: Returns: The PromptConverter object if found, None otherwise. """ - return self._registry.get_instance_by_name(converter_id) + return self._registry.instances.get(converter_id) async def create_converter_async(self, *, request: CreateConverterRequest) -> CreateConverterResponse: """ @@ -331,13 +200,17 @@ async def create_converter_async(self, *, request: CreateConverterRequest) -> Cr """ converter_id = str(uuid.uuid4()) - # Resolve any converter references in params and instantiate + # Resolve any converter references in params, persist data-URI params to + # disk (frontend concern), then delegate construction (incl. param + # coercion) to the converter registry. params = self._resolve_converter_params(params=request.params) - converter_class = self._get_converter_class(converter_type=request.type) - params = self._coerce_params(converter_class=converter_class, params=params) + try: + converter_class = self._registry.get_class(request.type) + except KeyError as e: + raise ValueError(f"Converter type '{request.type}' not found") from e params = await self._persist_data_uri_params_async(converter_class=converter_class, params=params) - converter_obj = converter_class(**params) - self._registry.register_instance(converter_obj, name=converter_id) + converter_obj = self._registry.create_instance(request.type, **params) + self._registry.instances.register(converter_obj, name=converter_id) return CreateConverterResponse( converter_id=converter_id, @@ -431,29 +304,6 @@ def get_converter_objects_for_ids(self, *, converter_ids: list[str]) -> list[Any # Private Helper Methods # ======================================================================== - def _get_converter_class(self, *, converter_type: str) -> type: - """ - Get the converter class for a given type name. - - Looks up the class in the module-level converter class registry. - - Args: - converter_type: The exact class name of the converter (e.g., 'Base64Converter'). - - Returns: - The converter class. - - Raises: - ValueError: If the converter type is not found. - """ - cls = _CONVERTER_CLASS_REGISTRY.get(converter_type) - if cls is None: - raise ValueError( - f"Converter type '{converter_type}' not found. " - f"Available types: {sorted(_CONVERTER_CLASS_REGISTRY.keys())}" - ) - return cls - def _resolve_converter_params(self, *, params: dict[str, Any]) -> dict[str, Any]: """ Resolve converter references in params. @@ -474,53 +324,6 @@ def _resolve_converter_params(self, *, params: dict[str, Any]) -> dict[str, Any] resolved["converter"] = conv_obj return resolved - @staticmethod - def _coerce_params(*, converter_class: type, params: dict[str, Any]) -> dict[str, Any]: - """ - Coerce parameter values to match the converter's __init__ type annotations. - - The frontend sends all values as strings; this converts them to int, float, - or bool as needed based on the constructor signature. - - Returns: - Params dict with values coerced to the expected types. - """ - try: - sig = inspect.signature(converter_class.__init__) - except (ValueError, TypeError) as e: - raise ValueError( - f"Failed to inspect __init__ signature for converter '{converter_class.__name__}': {e}" - ) from e - - coerced = dict(params) - for name, value in coerced.items(): - if name not in sig.parameters or not isinstance(value, str): - continue - annotation = sig.parameters[name].annotation - if annotation is inspect.Parameter.empty: - continue - - origin = get_origin(annotation) - # Unwrap X | None to X - if origin is Union: - args = get_args(annotation) - non_none = [a for a in args if a is not type(None)] - if len(non_none) == 1: - annotation = non_none[0] - origin = get_origin(annotation) - - try: - if annotation is int: - coerced[name] = int(value) - elif annotation is float: - coerced[name] = float(value) - elif annotation is bool: - coerced[name] = value.lower() in ("true", "1", "yes") - except (ValueError, TypeError) as e: - raise ValueError(f"Parameter '{name}' expects {annotation.__name__}, got {value!r}") from e - - return coerced - @staticmethod async def _persist_data_uri_params_async( *, diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index c8236aceda..b4997828a0 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -94,9 +94,7 @@ async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunSu init_kwargs = self._build_init_kwargs( request=request, scenario_class=scenario_class, objective_target=objective_target ) - scenario = await self._initialize_scenario_async( - request=request, scenario_class=scenario_class, init_kwargs=init_kwargs - ) + scenario = await self._initialize_scenario_async(request=request, init_kwargs=init_kwargs) except Exception: self._run_semaphore.release() raise @@ -371,15 +369,13 @@ def _build_init_kwargs( return init_kwargs - async def _initialize_scenario_async( - self, *, request: RunScenarioRequest, scenario_class: type[Scenario], init_kwargs: dict[str, Any] - ) -> Scenario: + async def _initialize_scenario_async(self, *, request: RunScenarioRequest, init_kwargs: dict[str, Any]) -> Scenario: """ Instantiate the scenario and call initialize_async. Args: - request: The run request (for scenario_params and scenario_result_id). - scenario_class: The resolved scenario class. + request: The run request (for scenario_name, scenario_params, and + scenario_result_id). init_kwargs: The kwargs to pass to scenario.initialize_async. Returns: @@ -388,7 +384,8 @@ async def _initialize_scenario_async( constructor_kwargs: dict[str, Any] = {} if request.scenario_result_id: constructor_kwargs["scenario_result_id"] = request.scenario_result_id - scenario = scenario_class(**constructor_kwargs) # type: ignore[call-arg] + scenario_registry = ScenarioRegistry.get_registry_singleton() + scenario = scenario_registry.create_instance(request.scenario_name, **constructor_kwargs) scenario.set_params_from_args(args=request.scenario_params or {}) await scenario.initialize_async(**init_kwargs) return scenario @@ -488,7 +485,7 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari scenario_version=scenario_result.scenario_identifier.version, status=status, created_at=scenario_result.creation_time, - updated_at=scenario_result.completion_time, + updated_at=scenario_result.completion_time or scenario_result.creation_time, error=error, error_type=error_type, strategies_used=strategies_used, diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index cd94382b93..c4dc08da67 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -4,6 +4,7 @@ """Registry module for PyRIT class and object registries.""" from pyrit.registry.base import RegistryProtocol +from pyrit.registry.buildable_registry import BuildableRegistry from pyrit.registry.class_registries import ( BaseClassRegistry, ClassEntry, @@ -13,15 +14,24 @@ ScenarioParameterMetadata, ScenarioRegistry, ) +from pyrit.registry.components import ( + ConverterMetadata, + ConverterParameterMetadata, + ConverterRegistry, +) from pyrit.registry.discovery import ( discover_in_directory, discover_in_package, discover_subclasses_in_loaded_modules, ) +from pyrit.registry.instance_registry import ( + DefaultInstanceRegistry, + InstanceRegistry, + SupportsInstances, +) from pyrit.registry.object_registries import ( AttackTechniqueRegistry, BaseInstanceRegistry, - ConverterRegistry, RegistryEntry, RetrievableInstanceRegistry, ScorerRegistry, @@ -33,8 +43,14 @@ "AttackTechniqueRegistry", "BaseClassRegistry", "BaseInstanceRegistry", + "BuildableRegistry", + "ConverterMetadata", + "ConverterParameterMetadata", "ConverterRegistry", + "DefaultInstanceRegistry", + "InstanceRegistry", "RetrievableInstanceRegistry", + "SupportsInstances", "ClassEntry", "discover_in_directory", "discover_in_package", diff --git a/pyrit/registry/buildable_registry.py b/pyrit/registry/buildable_registry.py new file mode 100644 index 0000000000..d93dc6bc45 --- /dev/null +++ b/pyrit/registry/buildable_registry.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Buildable registry base for PyRIT. + +``BuildableRegistry`` is the universal registry capability: discover classes, +introspect them into metadata, and **build** configured instances from a type +name plus a flat argument dict. Construction routes through the shared +``resolve_constructor_args`` primitive, so simple values are coerced and +registry-reference parameters (e.g. a ``PromptTarget``) are resolved by name — +the same mechanism for every domain. + +Every PyRIT registry is buildable. Registries that additionally hold named +instances expose an ``instances`` property (an ``InstanceRegistry``); the +buildable layer itself only concerns the class catalog. +""" + +from __future__ import annotations + +from typing import TypeVar + +from pyrit.registry.class_registries.base_class_registry import BaseClassRegistry +from pyrit.registry.resolution import resolve_constructor_args + +T = TypeVar("T") +MetadataT = TypeVar("MetadataT") + + +class BuildableRegistry(BaseClassRegistry[T, MetadataT]): + """ + Registry base that can build instances from a type name and arguments. + + Extends the class-table infrastructure of ``BaseClassRegistry`` with a + construction path that routes through ``resolve_constructor_args``: string + values are coerced to their annotated scalar types and registry-reference + parameters are resolved by name from the owning domain's registry. A + registered factory, when present, is used as-is (its arguments are not + resolved, since a factory owns its own construction semantics). + + Type Parameters: + T: The type of classes being registered (e.g. ``PromptConverter``). + MetadataT: The metadata dataclass type (e.g. ``ConverterMetadata``). + """ + + def get_class_names(self) -> list[str]: + """ + Get a sorted list of all registered class names. + + Always reflects the class catalog, even on registries that also hold + instances (where the protocol surface ``get_names`` refers to instances on + the ``instances`` property, not here). + + Returns: + list[str]: The sorted class-catalog names. + """ + self._ensure_discovered() + return sorted(self._class_entries.keys()) + + def get_class(self, name: str) -> type[T]: + """ + Get a registered class by its catalog name. + + Overrides the base lookup so the "not found" error lists the class catalog + (``get_class_names``) rather than the instances held under a registry's + ``instances`` property. + + Args: + name (str): The class-catalog name to resolve. + + Returns: + type[T]: The registered class. + + Raises: + KeyError: If the name is not registered in the class catalog. + """ + self._ensure_discovered() + entry = self._class_entries.get(name) + if entry is None: + available = ", ".join(self.get_class_names()) + raise KeyError(f"'{name}' not found in registry. Available: {available}") + return entry.registered_class + + def list_class_metadata( + self, + *, + include_filters: dict[str, object] | None = None, + exclude_filters: dict[str, object] | None = None, + ) -> list[MetadataT]: + """ + List metadata for all registered classes, optionally filtered. + + This is the class-catalog metadata (one entry per registered class), + distinct from any instance-level metadata a container registry exposes. + It always reflects the class catalog, even on container registries where + ``list_metadata`` refers to instances. + + Args: + include_filters (dict[str, object] | None): Filters items must match. + exclude_filters (dict[str, object] | None): Filters items must not match. + + Returns: + list[MetadataT]: Metadata describing each registered class. + """ + return BaseClassRegistry.list_metadata(self, include_filters=include_filters, exclude_filters=exclude_filters) + + def create_instance(self, name: str, **kwargs: object) -> T: + """ + Build a configured instance by class name. + + Arguments are resolved via ``resolve_constructor_args`` (coerce simple + strings, resolve registry references by name, raise on unknown params). + When the class is registered with a factory, the factory is invoked + directly with the given arguments instead. + + Args: + name (str): The class-catalog name to build. + **kwargs (object): Constructor arguments (simple values or registry + names for reference parameters). + + Returns: + T: The constructed instance. + + Raises: + KeyError: If the name is not registered. + ValueError: If an argument is not a valid constructor parameter, a + registry reference cannot be resolved, or a value cannot be coerced. + """ + self._ensure_discovered() + entry = self._class_entries.get(name) + if entry is None: + available = ", ".join(self.get_class_names()) + raise KeyError(f"'{name}' not found in registry. Available: {available}") + + if entry.factory is not None: + return entry.create_instance(**kwargs) + + raw_args = {**entry.default_kwargs, **kwargs} + resolved = resolve_constructor_args(cls=entry.registered_class, raw_args=raw_args) + return entry.registered_class(**resolved) diff --git a/pyrit/registry/components/__init__.py b/pyrit/registry/components/__init__.py new file mode 100644 index 0000000000..be2922cace --- /dev/null +++ b/pyrit/registry/components/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Component registries package. + +This package contains registries for PyRIT components (objects identified by a +``ComponentIdentifier``, such as converters, scorers, and targets). A component +registry is a ``BuildableRegistry`` class catalog that can build instances from +classes and, when it retains pre-configured instances, also exposes them via an +``.instances`` property. + +Shared capabilities and base classes (``BuildableRegistry``, ``InstanceRegistry``, +``DefaultInstanceRegistry``) live at the top level of ``pyrit.registry``. +""" + +from pyrit.registry.components.converter_registry import ( + ConverterMetadata, + ConverterParameterMetadata, + ConverterRegistry, +) + +__all__ = [ + "ConverterRegistry", + "ConverterMetadata", + "ConverterParameterMetadata", +] diff --git a/pyrit/registry/components/converter_registry.py b/pyrit/registry/components/converter_registry.py new file mode 100644 index 0000000000..d5fd3fe3b3 --- /dev/null +++ b/pyrit/registry/components/converter_registry.py @@ -0,0 +1,299 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Converter registry for PyRIT. + +A single registry for ``PromptConverter`` that both: + +- **builds** converters from a type name plus arguments — discovering converter + classes, introspecting their constructor parameters, and constructing instances + via the shared resolver (so LLM converters can be built by passing a + ``converter_target`` registry name), and +- **holds** pre-configured converter instances registered via initializers or the + backend. + +It is a ``BuildableRegistry``: the registry's own surface (``get_class``, +``get_class_names``, ``list_class_metadata``, ``create_instance``) is the buildable +class catalog. Pre-configured instances live under the ``instances`` property +(``register``, ``get``, ``get_all_instances``, ``get_names``), a +``DefaultInstanceRegistry``. +""" + +from __future__ import annotations + +import inspect +import logging +import re +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, get_args, get_origin + +from pyrit.registry.base import ClassRegistryEntry +from pyrit.registry.buildable_registry import BuildableRegistry +from pyrit.registry.class_registries.base_class_registry import ClassEntry +from pyrit.registry.instance_registry import DefaultInstanceRegistry, InstanceRegistry +from pyrit.registry.resolution import get_union_non_none_args, is_coercible_from_string + +if TYPE_CHECKING: + from pyrit.prompt_converter import PromptConverter + +logger = logging.getLogger(__name__) + + +def _prompt_converter_type() -> type[PromptConverter]: + """ + Return the ``PromptConverter`` base class, importing it lazily. + + Used as the ``instance_type`` for the registry's ``instances`` container so + a non-converter cannot be registered, without importing the converter + package at module load (which would defeat lazy discovery). + + Returns: + type[PromptConverter]: The ``PromptConverter`` base class. + """ + from pyrit.prompt_converter import PromptConverter + + return PromptConverter + + +class ConverterParameterMetadata(NamedTuple): + """ + A converter constructor parameter described for dynamic construction. + + Carries raw introspection data so callers can build converters on the fly. + ``annotation`` is the parameter's raw type annotation; rendering it to a + human-readable string is a presentation concern left to the caller. + ``coercible_from_string`` is True when a string value can be coerced to the + annotated type. ``requires_llm`` is True when the parameter expects a + ``PromptTarget`` (i.e. the converter performs an LLM-based transformation). + + NamedTuple so consumers can read fields by name while the value stays + immutable (safe to cache inside a frozen ``ConverterMetadata``). + """ + + name: str + annotation: Any + required: bool + default_value: str | None + choices: tuple[str, ...] | None + description: str | None + coercible_from_string: bool + requires_llm: bool + + +@dataclass(frozen=True) +class ConverterMetadata(ClassRegistryEntry): + """ + Metadata describing a registered ``PromptConverter`` class. + + Use ``ConverterRegistry.get_class()`` to get the actual class or + ``create_instance()`` to build a configured instance. + """ + + # Input data types the converter accepts (stringified PromptDataType values). + supported_input_types: tuple[str, ...] = field(kw_only=True, default=()) + + # Output data types the converter produces (stringified PromptDataType values). + supported_output_types: tuple[str, ...] = field(kw_only=True, default=()) + + # Simple constructor parameters suitable for dynamic form generation. + parameters: tuple[ConverterParameterMetadata, ...] = field(kw_only=True, default=()) + + # Whether the converter requires an LLM target. + is_llm_based: bool = field(kw_only=True, default=False) + + +def _requires_llm_target(annotation: Any) -> bool: + """ + Return True if the annotation expects a ``PromptTarget`` (or subclass). + + Handles unioned forms such as ``PromptTarget | None``. A converter parameter + with such an annotation indicates the converter performs an LLM-based + transformation. + + Returns: + bool: True if the annotation expects a ``PromptTarget``, False otherwise. + """ + if annotation is inspect.Parameter.empty: + return False + + from pyrit.prompt_target import PromptTarget + + candidates = get_union_non_none_args(annotation) + if candidates is None: + candidates = [annotation] + for candidate in candidates: + try: + if isinstance(candidate, type) and issubclass(candidate, PromptTarget): + return True + except TypeError: + continue + return False + + +def _parse_arg_descriptions(converter_class: type) -> dict[str, str]: + """ + Parse parameter descriptions from a Google-style docstring Args section. + + Returns: + dict[str, str]: Mapping of parameter names to their descriptions. + """ + doc = (converter_class.__init__.__doc__ or converter_class.__doc__ or "").strip() + match = re.search(r"Args:\s*\n(.*?)(?:\n\s*\n|\n\s*Returns:|\n\s*Raises:|\Z)", doc, re.DOTALL) + if not match: + return {} + args_block = match.group(1) + # Detect indentation of first parameter line + indent_match = re.match(r"^(\s+)", args_block) + indent = indent_match.group(1) if indent_match else r"\s+" + pattern = rf"^{indent}(\w+)\s*(?:\([^)]*\))?\s*:\s*(.+?)(?=\n{indent}\w|\Z)" + descriptions: dict[str, str] = {} + for m in re.finditer(pattern, args_block, re.DOTALL | re.MULTILINE): + descriptions[m.group(1)] = " ".join(m.group(2).split()) + return descriptions + + +def _extract_parameters(converter_class: type) -> tuple[ConverterParameterMetadata, ...]: + """ + Extract constructor parameters from a converter class. + + Surfaces every settable constructor parameter (excluding ``self`` and + var-args) so a caller has the full picture for dynamic construction. Each + parameter records its raw ``annotation`` and a ``coercible_from_string`` flag + indicating whether a string value can be coerced to its type. + + Returns: + tuple[ConverterParameterMetadata, ...]: The constructor parameters. + """ + try: + sig = inspect.signature(converter_class.__init__) + except (ValueError, TypeError): + return () + + arg_descriptions = _parse_arg_descriptions(converter_class) + + params: list[ConverterParameterMetadata] = [] + for name, p in sig.parameters.items(): + if name in ("self", "args", "kwargs"): + continue + if p.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + continue + + no_default = p.default is inspect.Parameter.empty + is_sentinel = hasattr(p.default, "__class__") and "Sentinel" in type(p.default).__name__ + required = no_default or is_sentinel + + default_value: str | None = None + if not required and p.default is not None: + default_value = str(p.default) + + choices: tuple[str, ...] | None = None + choice_annotation = p.annotation + non_none_choice = get_union_non_none_args(choice_annotation) + if non_none_choice is not None and len(non_none_choice) == 1: + choice_annotation = non_none_choice[0] + if get_origin(choice_annotation) is Literal: + choices = tuple(str(a) for a in get_args(choice_annotation)) + + params.append( + ConverterParameterMetadata( + name=name, + annotation=p.annotation, + required=required, + default_value=default_value, + choices=choices, + description=arg_descriptions.get(name), + coercible_from_string=is_coercible_from_string(p.annotation), + requires_llm=_requires_llm_target(p.annotation), + ) + ) + + return tuple(params) + + +class ConverterRegistry(BuildableRegistry["PromptConverter", ConverterMetadata]): + """ + Registry that discovers, builds, and holds ``PromptConverter`` instances. + + Discovers all concrete ``PromptConverter`` subclasses exported from + ``pyrit.prompt_converter`` (keyed by their exact class name, e.g. + ``"Base64Converter"``) for the buildable catalog. Pre-configured instances + registered via initializers or the backend are held under the ``instances`` + property. + + Building a converter resolves its arguments through the shared resolver, so + LLM converters can be constructed by passing a ``converter_target`` that names + a target in the ``TargetRegistry``. + """ + + def __init__(self, *, lazy_discovery: bool = True) -> None: + """ + Initialize the registry. + + Args: + lazy_discovery (bool): If True, class discovery is deferred until first + access. If False, discovery runs immediately. + """ + super().__init__(lazy_discovery=lazy_discovery) + self.instances: InstanceRegistry[PromptConverter] = DefaultInstanceRegistry( + instance_type=_prompt_converter_type + ) + + def _get_registry_name(self, cls: type) -> str: + """ + Use the exact class name as the catalog key. + + Converters are referenced by their class name (e.g. ``"Base64Converter"``) + rather than the snake_case default used by other class registries. + + Returns: + str: The class name. + """ + return cls.__name__ + + def _discover(self) -> None: + """Discover all concrete ``PromptConverter`` subclasses from ``pyrit.prompt_converter``.""" + from pyrit import prompt_converter + from pyrit.prompt_converter import PromptConverter + + for name in prompt_converter.__all__: + cls = getattr(prompt_converter, name, None) + if cls is None or not isinstance(cls, type): + continue + if not issubclass(cls, PromptConverter) or cls is PromptConverter: + continue + self._class_entries[name] = ClassEntry(registered_class=cls) + logger.debug(f"Registered converter class: {name}") + + def _build_metadata(self, name: str, entry: ClassEntry[PromptConverter]) -> ConverterMetadata: + """ + Build catalog metadata for a ``PromptConverter`` class. + + Args: + name (str): The catalog name (exact class name) of the converter. + entry (ClassEntry[PromptConverter]): The class entry being described. + + Returns: + ConverterMetadata: Metadata describing the converter class. + """ + converter_class = entry.registered_class + + # First paragraph of the docstring as a short description. + raw_doc = (converter_class.__doc__ or "").strip() + description = raw_doc.split("\n\n")[0].replace("\n", " ").strip() + + supported_input_types = tuple(str(dt) for dt in getattr(converter_class, "SUPPORTED_INPUT_TYPES", ())) + supported_output_types = tuple(str(dt) for dt in getattr(converter_class, "SUPPORTED_OUTPUT_TYPES", ())) + + parameters = _extract_parameters(converter_class) + + return ConverterMetadata( + class_name=converter_class.__name__, + class_module=converter_class.__module__, + class_description=description, + registry_name=name, + supported_input_types=supported_input_types, + supported_output_types=supported_output_types, + parameters=parameters, + is_llm_based=any(p.requires_llm for p in parameters), + ) diff --git a/pyrit/registry/instance_registry.py b/pyrit/registry/instance_registry.py new file mode 100644 index 0000000000..cd23d49a25 --- /dev/null +++ b/pyrit/registry/instance_registry.py @@ -0,0 +1,441 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Instance-registry capability for PyRIT registries. + +A registry that retains pre-configured, named instances exposes that capability +as an ``.instances`` property whose type is the ``InstanceRegistry`` protocol. +The concrete default implementation is ``DefaultInstanceRegistry``. + +Modelling instance-holding as a typed property (rather than a base class) makes +the capability visible in the type: a function can accept "a registry that holds +instances" (``SupportsInstances``) or just "an instance registry" +(``InstanceRegistry[T]``) without depending on a concrete class, and a registry +that does not hold instances simply has no ``.instances`` attribute. + +Stored items must implement ``Identifiable`` so per-instance metadata can be +derived from ``get_identifier()``. This module imports no ``pyrit.backend`` code, +so it can be reused anywhere (forms, agents, attack strategies). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable + +from pyrit.models import ComponentIdentifier, Identifiable + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + +T = TypeVar("T", bound=Identifiable) # The type of items stored + + +@dataclass +class RegistryEntry(Generic[T]): + """ + A wrapper around a registered item, holding its name, tags, and the item itself. + + Tags are always stored as ``dict[str, str]``. When callers pass a plain + ``list[str]``, each string is normalized to a key with an empty-string value. + + Attributes: + name (str): The registry name for this entry. + instance (T): The registered object. + tags (dict[str, str]): Key-value tags for categorization and filtering. + metadata (dict[str, Any]): Arbitrary key-value metadata for capability flags + and other per-entry data that should not pollute the tag namespace. + """ + + name: str + instance: T + tags: dict[str, str] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) + + +@runtime_checkable +class InstanceRegistry(Protocol[T]): + """ + Typed instance-container capability a registry exposes as ``.instances``. + + Holds named, pre-configured instances that callers register and retrieve by + name, list, tag, and filter. Stored items must implement ``Identifiable``. + ``DefaultInstanceRegistry`` is the concrete default implementation; expressing + the surface as a protocol lets callers depend on the capability rather than a + concrete class. + + Type Parameters: + T: The type of instances held (must be ``Identifiable``). + """ + + def register( + self, + instance: T, + *, + name: str | None = None, + tags: dict[str, str] | list[str] | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + """Register a pre-configured instance, defaulting its name to the identifier's ``unique_name``.""" + ... + + def get(self, name: str) -> T | None: + """Return the instance registered under ``name``, or None.""" + ... + + def get_entry(self, name: str) -> RegistryEntry[T] | None: + """Return the full entry (including tags) for ``name``, or None.""" + ... + + def get_all_instances(self) -> list[RegistryEntry[T]]: + """Return all entries sorted by name.""" + ... + + def get_by_tag(self, *, tag: str, value: str | None = None) -> list[RegistryEntry[T]]: + """Return entries carrying ``tag`` (optionally matching ``value``), sorted by name.""" + ... + + def add_tags(self, *, name: str, tags: dict[str, str] | list[str]) -> None: + """Add tags to an existing entry.""" + ... + + def find_dependents_of_tag(self, *, tag: str) -> list[RegistryEntry[T]]: + """Return entries whose identifier tree references a tagged entry's ``eval_hash``.""" + ... + + def list_metadata( + self, + *, + include_filters: dict[str, object] | None = None, + exclude_filters: dict[str, object] | None = None, + ) -> list[ComponentIdentifier]: + """List per-instance identifier metadata, optionally filtered.""" + ... + + def get_names(self) -> list[str]: + """Return the sorted names of registered instances.""" + ... + + def __contains__(self, name: str) -> bool: + """Check whether an instance name is registered.""" + ... + + def __len__(self) -> int: + """Return the number of registered instances.""" + ... + + def __iter__(self) -> Iterator[str]: + """Iterate over registered instance names.""" + ... + + +class SupportsInstances(Protocol[T]): + """ + Structural marker for a registry that holds instances. + + Lets callers and type-checkers express "a registry that holds instances" + without naming a concrete class, so a registry's capabilities are legible + from its type. + + Type Parameters: + T: The type of instances held (must be ``Identifiable``). + """ + + instances: InstanceRegistry[T] + + +class DefaultInstanceRegistry(Generic[T]): + """ + Concrete ``InstanceRegistry`` implementation assigned to ``.instances``. + + Holds named, pre-configured instances with tags and derived metadata. It owns + no singleton lifecycle — the registry that exposes it via ``.instances`` owns + that. + + Type Parameters: + T: The type of instances held (must be ``Identifiable``). + """ + + def __init__(self, *, instance_type: type[T] | Callable[[], type[T]] | None = None) -> None: + """ + Initialize an empty instance container. + + Args: + instance_type (type[T] | Callable[[], type[T]] | None): Optional expected + element type. When set, ``register`` raises ``TypeError`` for any + instance that is not of this type, so a registry scoped to one + component kind (e.g. a target registry) cannot silently hold a + different kind (e.g. a scorer). May be the class itself or a + zero-argument callable returning it; the callable form lets owners + defer importing the type so a registry's lazy discovery is preserved. + It is resolved once, on the first ``register`` call, and cached. + """ + self._registry_items: dict[str, RegistryEntry[T]] = {} + self._metadata_cache: list[ComponentIdentifier] | None = None + self._instance_type: type[T] | Callable[[], type[T]] | None = instance_type + + def _resolve_instance_type(self) -> type | None: + """ + Resolve and cache the configured expected element type, if any. + + Returns: + type | None: The expected type, or None when no constraint is set. + """ + if self._instance_type is None or isinstance(self._instance_type, type): + return self._instance_type + resolved = self._instance_type() + self._instance_type = resolved + return resolved + + @staticmethod + def _normalize_tags(tags: dict[str, str] | list[str] | None = None) -> dict[str, str]: + """ + Normalize tags into a ``dict[str, str]``. + + Args: + tags (dict[str, str] | list[str] | None): Tags as a dict, a list of + string keys (values default to ``""``), or None (empty dict). + + Returns: + dict[str, str]: The normalized tags. + """ + if tags is None: + return {} + if isinstance(tags, list): + return dict.fromkeys(tags, "") + return dict(tags) + + def register( + self, + instance: T, + *, + name: str | None = None, + tags: dict[str, str] | list[str] | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + """ + Register a pre-configured instance. + + Args: + instance (T): The instance to register. + name (str | None): The registry name. Defaults to the instance's + identifier ``unique_name``. + tags (dict[str, str] | list[str] | None): Optional tags for + categorization. + metadata (dict[str, Any] | None): Optional per-entry metadata. + + Raises: + TypeError: If this registry was created with an ``instance_type`` and + ``instance`` is not of that type. + """ + expected_type = self._resolve_instance_type() + if expected_type is not None and not isinstance(instance, expected_type): + raise TypeError( + f"Cannot register a {type(instance).__name__!r} instance in a registry " + f"of {expected_type.__name__!r} instances." + ) + + if name is None: + name = instance.get_identifier().unique_name + + self._registry_items[name] = RegistryEntry( + name=name, + instance=instance, + tags=self._normalize_tags(tags), + metadata=metadata or {}, + ) + self._metadata_cache = None + + def get(self, name: str) -> T | None: + """ + Get a registered instance by name. + + Args: + name (str): The registry name of the instance. + + Returns: + T | None: The instance, or None if not found. + """ + entry = self._registry_items.get(name) + return entry.instance if entry is not None else None + + def get_entry(self, name: str) -> RegistryEntry[T] | None: + """ + Get the full entry (including tags) by name. + + Args: + name (str): The registry name of the entry. + + Returns: + RegistryEntry[T] | None: The entry, or None if not found. + """ + return self._registry_items.get(name) + + def get_all_instances(self) -> list[RegistryEntry[T]]: + """ + Get all registered entries sorted by name. + + Returns: + list[RegistryEntry[T]]: The entries sorted by name. + """ + return [self._registry_items[name] for name in sorted(self._registry_items.keys())] + + def get_names(self) -> list[str]: + """ + Get a sorted list of all registered instance names. + + Returns: + list[str]: The instance names sorted alphabetically. + """ + return sorted(self._registry_items.keys()) + + def get_by_tag(self, *, tag: str, value: str | None = None) -> list[RegistryEntry[T]]: + """ + Get entries that carry a given tag, optionally matching a value. + + Args: + tag (str): The tag key to match. + value (str | None): If provided, only entries whose tag value equals + this are returned. If None, any entry with the tag key matches. + + Returns: + list[RegistryEntry[T]]: Matching entries sorted by name. + """ + results: list[RegistryEntry[T]] = [] + for name in sorted(self._registry_items.keys()): + entry = self._registry_items[name] + if tag in entry.tags and (value is None or entry.tags[tag] == value): + results.append(entry) + return results + + def add_tags(self, *, name: str, tags: dict[str, str] | list[str]) -> None: + """ + Add tags to an existing entry. + + Args: + name (str): The registry name of the entry to tag. + tags (dict[str, str] | list[str]): Tags to add. + + Raises: + KeyError: If no entry with the given name exists. + """ + entry = self._registry_items.get(name) + if entry is None: + raise KeyError(f"No instance named '{name}' in registry.") + entry.tags.update(self._normalize_tags(tags)) + self._metadata_cache = None + + def find_dependents_of_tag(self, *, tag: str) -> list[RegistryEntry[T]]: + """ + Find entries whose children depend on entries with the given tag. + + Scans each entry's ``ComponentIdentifier`` tree and checks whether any + child's ``eval_hash`` matches the ``eval_hash`` of an entry that carries + ``tag``. Entries that themselves carry ``tag`` are excluded. + + This enables automatic dependency detection: for example, tagging base + refusal scorers with ``"refusal"`` lets you discover all wrapper scorers + (inverters, composites) that embed a refusal scorer without any explicit + ``depends_on`` declaration. + + Args: + tag (str): The tag key that identifies the "base" entries. + + Returns: + list[RegistryEntry[T]]: Entries that depend on tagged entries, sorted + by name. + """ + tagged_hashes: set[str] = set() + tagged_names: set[str] = set() + for entry in self.get_by_tag(tag=tag): + tagged_names.add(entry.name) + identifier = self._build_metadata(entry.instance) + if identifier.eval_hash: + tagged_hashes.add(identifier.eval_hash) + + if not tagged_hashes: + return [] + + dependents: list[RegistryEntry[T]] = [] + for name in sorted(self._registry_items.keys()): + if name in tagged_names: + continue + entry = self._registry_items[name] + identifier = self._build_metadata(entry.instance) + child_hashes = identifier._collect_child_eval_hashes() + if child_hashes & tagged_hashes: + dependents.append(entry) + return dependents + + def list_metadata( + self, + *, + include_filters: dict[str, object] | None = None, + exclude_filters: dict[str, object] | None = None, + ) -> list[ComponentIdentifier]: + """ + List metadata for all registered instances, optionally filtered. + + Args: + include_filters (dict[str, object] | None): Filters items must match. + exclude_filters (dict[str, object] | None): Filters items must not match. + + Returns: + list[ComponentIdentifier]: The identifier metadata for each instance. + """ + from pyrit.registry.base import _matches_filters + + if self._metadata_cache is None: + self._metadata_cache = [ + self._build_metadata(self._registry_items[name].instance) + for name in sorted(self._registry_items.keys()) + ] + + if not include_filters and not exclude_filters: + return self._metadata_cache + + return [ + m + for m in self._metadata_cache + if _matches_filters(m, include_filters=include_filters, exclude_filters=exclude_filters) + ] + + def _build_metadata(self, instance: T) -> ComponentIdentifier: + """ + Build metadata for an item via its ``Identifiable`` interface. + + Args: + instance (T): The item. + + Returns: + ComponentIdentifier: The item's identifier. + """ + return instance.get_identifier() + + def __contains__(self, name: str) -> bool: + """ + Check if an instance name is registered. + + Returns: + bool: True if the instance name is registered, False otherwise. + """ + return name in self._registry_items + + def __len__(self) -> int: + """ + Get the count of registered instances. + + Returns: + int: The number of registered instances. + """ + return len(self._registry_items) + + def __iter__(self) -> Iterator[str]: + """ + Iterate over registered instance names. + + Returns: + Iterator[str]: An iterator over sorted instance names. + """ + return iter(sorted(self._registry_items.keys())) diff --git a/pyrit/registry/object_registries/__init__.py b/pyrit/registry/object_registries/__init__.py index b6edf16088..9694f12bef 100644 --- a/pyrit/registry/object_registries/__init__.py +++ b/pyrit/registry/object_registries/__init__.py @@ -18,9 +18,6 @@ BaseInstanceRegistry, RegistryEntry, ) -from pyrit.registry.object_registries.converter_registry import ( - ConverterRegistry, -) from pyrit.registry.object_registries.retrievable_instance_registry import ( RetrievableInstanceRegistry, ) @@ -38,7 +35,6 @@ "RegistryEntry", # Concrete registries "AttackTechniqueRegistry", - "ConverterRegistry", "ScorerRegistry", "TargetRegistry", ] diff --git a/pyrit/registry/object_registries/base_instance_registry.py b/pyrit/registry/object_registries/base_instance_registry.py index f076b8f833..4982aac08c 100644 --- a/pyrit/registry/object_registries/base_instance_registry.py +++ b/pyrit/registry/object_registries/base_instance_registry.py @@ -4,6 +4,17 @@ """ Base instance registry for PyRIT. +.. note:: + + **Legacy stack — do not build new registries on this.** New component + registries should subclass ``BuildableRegistry`` (a class catalog that can + build instances by name) and hold pre-configured instances via the + ``.instances`` property (a ``DefaultInstanceRegistry``). See + ``ConverterRegistry`` for the target shape. This class and + ``RetrievableInstanceRegistry`` remain only because ``TargetRegistry``, + ``ScorerRegistry``, and ``AttackTechniqueRegistry`` still subclass them; + the whole stack is removed once those migrate. + This module provides ``BaseInstanceRegistry``, the shared infrastructure for registries that store ``Identifiable`` objects (not classes): singleton lifecycle, registration, tags, metadata, container protocol. @@ -19,46 +30,39 @@ from __future__ import annotations from abc import ABC -from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Generic, TypeVar from pyrit.models import ComponentIdentifier, Identifiable from pyrit.registry.base import RegistryProtocol +from pyrit.registry.instance_registry import RegistryEntry if TYPE_CHECKING: from collections.abc import Iterator from typing_extensions import Self -T = TypeVar("T", bound=Identifiable) # The type of items stored - - -@dataclass -class RegistryEntry(Generic[T]): - """ - A wrapper around a registered item, holding its name, tags, and the item itself. - - Tags are always stored as ``dict[str, str]``. When callers pass a plain - ``list[str]``, each string is normalized to a key with an empty-string value. - - Attributes: - name: The registry name for this entry. - instance: The registered object. - tags: Key-value tags for categorization and filtering. - metadata: Arbitrary key-value metadata for capability flags and - other per-entry data that should not pollute the tag namespace. - """ +# Re-exported for back-compat; the canonical definition now lives in +# ``pyrit.registry.instance_registry`` alongside the new instance-registry capability. +__all__ = ["BaseInstanceRegistry", "RegistryEntry"] - name: str - instance: T - tags: dict[str, str] = field(default_factory=dict) - metadata: dict[str, Any] = field(default_factory=dict) +T = TypeVar("T", bound=Identifiable) # The type of items stored class BaseInstanceRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T]): """ Abstract base class providing shared registry infrastructure. + .. note:: + + **Legacy — do not subclass for new registries.** New component + registries subclass ``BuildableRegistry`` and expose retained instances + via the ``.instances`` property (``DefaultInstanceRegistry``), which + carries this same surface (``register``/``get``/``get_by_tag``/ + ``add_tags``/``find_dependents_of_tag``/``list_metadata``). This class + survives only for the not-yet-migrated ``TargetRegistry``, + ``ScorerRegistry``, and ``AttackTechniqueRegistry`` and is removed once + they move to ``.instances``. + Provides singleton lifecycle, registration, tag-based lookup, metadata filtering, and the standard container protocol (``__contains__``, ``__len__``, ``__iter__``). diff --git a/pyrit/registry/object_registries/converter_registry.py b/pyrit/registry/object_registries/converter_registry.py deleted file mode 100644 index 568d1e6332..0000000000 --- a/pyrit/registry/object_registries/converter_registry.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Converter registry for managing PyRIT converter instances. - -Converters are registered explicitly via initializers as pre-configured instances. - -NOTE: This is a placeholder implementation. A full implementation will be added soon. -""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING - -from pyrit.registry.object_registries.retrievable_instance_registry import ( - RetrievableInstanceRegistry, -) - -if TYPE_CHECKING: - from pyrit.prompt_converter import PromptConverter - -logger = logging.getLogger(__name__) - - -class ConverterRegistry(RetrievableInstanceRegistry["PromptConverter"]): - """ - Registry for managing available converter instances. - - This registry stores pre-configured PromptConverter instances (not classes). - Converters are registered explicitly via initializers after being instantiated - with their required parameters. - """ - - def register_instance( - self, - converter: PromptConverter, - *, - name: str | None = None, - tags: dict[str, str] | list[str] | None = None, - ) -> None: - """ - Register a converter instance. - - Args: - converter: The pre-configured converter instance (not a class). - name: Optional custom registry name. If not provided, - derived from the converter's unique identifier. - tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` - or a ``list[str]`` (each string becomes a key with value ``""``). - """ - if name is None: - name = converter.get_identifier().unique_name - - self.register(converter, name=name, tags=tags) - logger.debug(f"Registered converter instance: {name} ({converter.__class__.__name__})") - - def get_instance_by_name(self, name: str) -> PromptConverter | None: - """ - Get a registered converter instance by name. - - Args: - name: The registry name of the converter. - - Returns: - The converter instance, or None if not found. - """ - return self.get(name) diff --git a/pyrit/registry/object_registries/retrievable_instance_registry.py b/pyrit/registry/object_registries/retrievable_instance_registry.py index 7498b3bf21..69e07fb0f3 100644 --- a/pyrit/registry/object_registries/retrievable_instance_registry.py +++ b/pyrit/registry/object_registries/retrievable_instance_registry.py @@ -4,11 +4,20 @@ """ Retrievable instance registry for PyRIT. +.. note:: + + **Legacy stack — do not build new registries on this.** New component + registries subclass ``BuildableRegistry`` and retain instances via the + ``.instances`` property (``DefaultInstanceRegistry``), which already + provides ``get``/``get_entry``/``get_all_instances``. See + ``ConverterRegistry`` for the target shape. This class remains only for the + not-yet-migrated ``ScorerRegistry`` and ``TargetRegistry`` and is removed + once they migrate. + This module provides ``RetrievableInstanceRegistry``, which extends ``BaseInstanceRegistry`` with ``get()``, ``get_entry()``, and ``get_all_instances()`` for registries where callers retrieve stored -objects directly (e.g., ``ScorerRegistry``, ``ConverterRegistry``, -``TargetRegistry``). +objects directly (e.g., ``ScorerRegistry``, ``TargetRegistry``). For the shared base class, see ``base_instance_registry``. For registries that store classes (type[T]), see ``class_registries/``. @@ -30,6 +39,14 @@ class RetrievableInstanceRegistry(BaseInstanceRegistry[T]): """ Base class for registries that store directly-retrievable instances. + .. note:: + + **Legacy — do not subclass for new registries.** Use + ``BuildableRegistry`` + the ``.instances`` property + (``DefaultInstanceRegistry``), which already exposes + ``get``/``get_entry``/``get_all_instances``. Retained only for the + not-yet-migrated ``ScorerRegistry`` and ``TargetRegistry``. + Extends ``BaseInstanceRegistry`` with ``get()``, ``get_entry()``, and ``get_all_instances()`` for registries where callers retrieve the stored objects directly (e.g., scorers, converters, targets). diff --git a/pyrit/registry/resolution.py b/pyrit/registry/resolution.py new file mode 100644 index 0000000000..1b6a75ad75 --- /dev/null +++ b/pyrit/registry/resolution.py @@ -0,0 +1,308 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Constructor-argument resolution for PyRIT registries. + +This is the shared mechanism that lets any registry build an instance from a +type name plus a flat dict of arguments. Build inputs are exactly two kinds: + +- **Simple values** — strings/ints/floats/bools (and ``Literal`` choices) that + can be coerced to the constructor's annotated type. +- **Registry references** — a parameter whose annotation is a domain base type + (``PromptTarget``, ``PromptConverter``, ``Scorer``) is supplied *by name* and + resolved from that domain's registry. An already-constructed instance passes + through unchanged. + +Unknown parameters raise, so a caller (form, agent, attack strategy) gets a +clear error instead of having values silently dropped. + +This module performs no eager heavy imports and never imports ``pyrit.backend``: +the resolvable-registry lookups are done lazily so it can be reused anywhere. +""" + +from __future__ import annotations + +import inspect +import types +from typing import TYPE_CHECKING, Any, Literal, Protocol, Union, get_args, get_origin + +if TYPE_CHECKING: + from collections.abc import Callable + +# Scalar Python types whose string values can be coerced to the real type. +_SIMPLE_TYPES: set[type] = {str, int, float, bool} + + +class _NamedInstanceRegistry(Protocol): + """Structural type for a registry that resolves stored instances by name.""" + + def get(self, name: str) -> Any | None: + """Return the instance registered under ``name``, or None.""" + ... + + def get_names(self) -> list[str]: + """Return the sorted names of registered instances.""" + ... + + +def get_union_non_none_args(annotation: Any) -> list[Any] | None: + """ + Return the non-``None`` members of a union annotation, or None if not a union. + + Handles both ``typing.Union[X, None]`` and PEP 604 ``X | None``. This is a + general type-introspection utility (not presentation), reused by coercion, + registry-reference detection, and callers that need to render a type. + + Args: + annotation (Any): The type annotation to inspect. + + Returns: + list[Any] | None: The non-None union members, or None when the annotation + is not a union. + """ + origin = get_origin(annotation) + if origin is Union or origin is types.UnionType: + return [a for a in get_args(annotation) if a is not type(None)] + return None + + +def is_coercible_from_string(annotation: Any) -> bool: + """ + Return True if a string value can be coerced to the annotated type. + + Covers the scalar types in ``_SIMPLE_TYPES`` (str/int/float/bool), + ``Literal`` annotations, and an ``Optional`` wrapping one of those. + + Returns: + bool: True if the annotation is coercible from a string, False otherwise. + """ + if annotation in _SIMPLE_TYPES: + return True + if get_origin(annotation) is Literal: + return True + non_none = get_union_non_none_args(annotation) + if non_none is not None: + return len(non_none) == 1 and is_coercible_from_string(non_none[0]) + return False + + +def _resolvable_registries() -> list[tuple[type, Callable[[], _NamedInstanceRegistry]]]: + """ + Return the (base type -> registry singleton getter) pairs that can be resolved by name. + + A constructor parameter whose annotation is (a subclass of) one of these base + types is supplied by name and looked up in the paired registry. Imports are + deferred so this core module stays import-light and free of cycles. + + Returns: + list[tuple[type, Callable[[], _NamedInstanceRegistry]]]: The resolvable + domain base types paired with a callable returning their registry singleton. + """ + from pyrit.prompt_converter import PromptConverter + from pyrit.prompt_target import PromptTarget + from pyrit.registry.components import ConverterRegistry + from pyrit.registry.object_registries import ( + ScorerRegistry, + TargetRegistry, + ) + from pyrit.score.scorer import Scorer + + return [ + (PromptTarget, TargetRegistry.get_registry_singleton), + (PromptConverter, lambda: ConverterRegistry.get_registry_singleton().instances), + (Scorer, ScorerRegistry.get_registry_singleton), + ] + + +def get_resolvable_registry_getter(annotation: Any) -> Callable[[], _NamedInstanceRegistry] | None: + """ + Return the registry-singleton getter for a registry-reference annotation. + + The annotation matches when it is (or unions, e.g. ``X | None``, to) a subclass + of a resolvable domain base type. A parameter with such an annotation is + supplied by name and resolved from the returned registry. + + Args: + annotation (Any): The parameter's type annotation. + + Returns: + Callable[[], _NamedInstanceRegistry] | None: A callable returning the + registry singleton, or None when the annotation is not a registry reference. + """ + if annotation is inspect.Parameter.empty: + return None + + candidates = get_union_non_none_args(annotation) + if candidates is None: + candidates = [annotation] + + for base_type, getter in _resolvable_registries(): + for candidate in candidates: + try: + if isinstance(candidate, type) and issubclass(candidate, base_type): + return getter + except TypeError: + continue + return None + + +def is_registry_reference(annotation: Any) -> bool: + """ + Return True if the annotation is a registry reference (resolved by name). + + Returns: + bool: True if a value for this parameter is supplied by name and resolved + from a registry, False otherwise. + """ + return get_resolvable_registry_getter(annotation) is not None + + +def coerce_string_to_annotation(*, value: str, annotation: Any) -> Any: + """ + Coerce a string value to the annotated scalar type (int/float/bool/Literal). + + ``Optional[X]`` / ``X | None`` is unwrapped to ``X`` first. A ``Literal`` value + is validated against the allowed members and returned as the matching member + (so an int literal comes back as an ``int``); other ``str`` values pass through + unchanged. + + Args: + value (str): The raw string value. + annotation (Any): The parameter's type annotation. + + Returns: + Any: The value coerced to the annotated type, or the original string when + no numeric/boolean/Literal coercion applies. + + Raises: + ValueError: If the value cannot be interpreted as the annotated type, or is + not one of the allowed members of an annotated ``Literal``. + """ + if annotation is inspect.Parameter.empty: + return value + + non_none = get_union_non_none_args(annotation) + if non_none is not None and len(non_none) == 1: + annotation = non_none[0] + + if get_origin(annotation) is Literal: + allowed = get_args(annotation) + for member in allowed: + if value == str(member): + return member + raise ValueError(f"expected one of {[str(a) for a in allowed]}, got {value!r}") + + if annotation is int: + return int(value) + if annotation is float: + return float(value) + if annotation is bool: + lowered = value.strip().lower() + if lowered in ("true", "1", "yes"): + return True + if lowered in ("false", "0", "no"): + return False + raise ValueError(f"cannot interpret {value!r} as a boolean") + return value + + +def _resolve_registry_reference( + *, value: Any, getter: Callable[[], _NamedInstanceRegistry], owner: str, name: str +) -> Any: + """ + Resolve a registry-reference parameter value to a stored instance. + + A string value is looked up by name in the paired registry. An already-built + instance passes through unchanged. + + Args: + value (Any): The raw value (a registry name, or an instance to pass through). + getter (Callable[[], _NamedInstanceRegistry]): Returns the registry singleton. + owner (str): The owning class name, for error messages. + name (str): The parameter name, for error messages. + + Returns: + Any: The resolved instance. + + Raises: + ValueError: If the name is not registered. + """ + if not isinstance(value, str): + return value + + registry = getter() + instance = registry.get(value) + if instance is not None: + return instance + + registry_label = type(registry).__name__ + available_names = registry.get_names() + if not available_names: + raise ValueError( + f"{owner}.{name}: '{value}' not found. The {registry_label} is empty. " + "Make sure to register instances (e.g. via an initializer) before building " + "components that reference them by name." + ) + raise ValueError( + f"{owner}.{name}: '{value}' not found in {registry_label}. Available: {', '.join(available_names)}" + ) + + +def resolve_constructor_args(*, cls: type, raw_args: dict[str, Any]) -> dict[str, Any]: + """ + Resolve a flat argument dict into constructor-ready keyword arguments. + + For each argument: validate it is a real constructor parameter (unless the + constructor accepts ``**kwargs``); resolve registry-reference parameters by + name; coerce simple string values to their annotated scalar type; pass + everything else through unchanged. + + Args: + cls (type): The class whose ``__init__`` signature drives resolution. + raw_args (dict[str, Any]): The raw argument values (e.g. from a form or agent). + + Returns: + dict[str, Any]: Arguments ready to pass to ``cls(**resolved)``. + + Raises: + ValueError: If the signature cannot be inspected, an argument is not a + valid constructor parameter, a registry reference cannot be resolved, + or a simple value cannot be coerced. + """ + try: + sig = inspect.signature(cls.__init__) + except (ValueError, TypeError) as e: + raise ValueError(f"Failed to inspect __init__ signature for '{cls.__name__}': {e}") from e + + accepts_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) + valid_params = { + param_name: p + for param_name, p in sig.parameters.items() + if param_name != "self" and p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + } + + resolved: dict[str, Any] = {} + for name, value in raw_args.items(): + param = valid_params.get(name) + if param is None and not accepts_var_kwargs: + raise ValueError( + f"Unknown parameter '{name}' for '{cls.__name__}'. Valid parameters: {sorted(valid_params.keys())}" + ) + + annotation = param.annotation if param is not None else inspect.Parameter.empty + + registry_getter = get_resolvable_registry_getter(annotation) + if registry_getter is not None: + resolved[name] = _resolve_registry_reference( + value=value, getter=registry_getter, owner=cls.__name__, name=name + ) + elif isinstance(value, str) and is_coercible_from_string(annotation): + try: + resolved[name] = coerce_string_to_annotation(value=value, annotation=annotation) + except (ValueError, TypeError) as e: + raise ValueError(f"Parameter '{name}' of '{cls.__name__}': {e}") from e + else: + resolved[name] = value + + return resolved diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 6b41ef93af..c9a5e793ce 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -14,28 +14,25 @@ ConverterPreviewRequest, CreateConverterRequest, ) -from pyrit.backend.services.converter_service import ConverterService, _is_llm_based, get_converter_service +from pyrit.backend.services.converter_service import ( + ConverterService, + _serialize_type, + get_converter_service, +) from pyrit.models import ComponentIdentifier from pyrit.prompt_converter import ( Base64Converter, CaesarConverter, - LLMGenericTextConverter, - NoiseConverter, - PersuasionConverter, RepeatTokenConverter, SuffixAppendConverter, - TenseConverter, - ToneConverter, - TranslationConverter, - VariationConverter, ) from pyrit.prompt_converter.prompt_converter import get_converter_modalities -from pyrit.registry.object_registries import ConverterRegistry +from pyrit.registry.components import ConverterRegistry @pytest.fixture(autouse=True) def reset_registry(): - """Reset the ConverterRegistry singleton before each test.""" + """Reset the converter registry before each test.""" ConverterRegistry.reset_instance() yield ConverterRegistry.reset_instance() @@ -56,8 +53,7 @@ async def test_list_converters_returns_converters_from_registry(self) -> None: """Test that list_converters returns converters from registry with full params.""" service = ConverterService() - mock_converter = MagicMock() - mock_converter.__class__.__name__ = "MockConverter" + mock_converter = MagicMock(spec=prompt_converter.PromptConverter) mock_identifier = ComponentIdentifier( class_name="MockConverter", class_module="tests.unit.backend.test_converter_service", @@ -69,7 +65,7 @@ async def test_list_converters_returns_converters_from_registry(self) -> None: }, ) mock_converter.get_identifier.return_value = mock_identifier - service._registry.register_instance(mock_converter, name="conv-1") + service._registry.instances.register(mock_converter, name="conv-1") result = await service.list_converters_async() @@ -104,6 +100,60 @@ async def test_list_converter_catalog_includes_supported_types(self) -> None: assert "text" in base64_entry.supported_input_types assert "text" in base64_entry.supported_output_types + async def test_catalog_includes_all_constructible_converters(self) -> None: + """The catalog surfaces every constructible converter, including base/helper classes. + + Whether to display a given converter is left to the caller (e.g. the frontend), + so the service no longer hides anything. + """ + service = ConverterService() + + result = await service.list_converter_catalog_async() + + converter_types = [item.converter_type for item in result.items] + assert "Base64Converter" in converter_types + assert "SelectiveTextConverter" in converter_types + + async def test_catalog_serializes_parameter_type(self) -> None: + """Catalog renders the raw annotation into a human-readable type_name.""" + service = ConverterService() + + result = await service.list_converter_catalog_async() + + caesar_entry = next(item for item in result.items if item.converter_type == "CaesarConverter") + caesar_param = next(p for p in caesar_entry.parameters if p.name == "caesar_offset") + assert caesar_param.type_name == "int" + + async def test_catalog_excludes_non_coercible_params(self) -> None: + """Catalog only surfaces params that can be set from a string (e.g. not the LLM target).""" + service = ConverterService() + + result = await service.list_converter_catalog_async() + + persuasion_entry = next(item for item in result.items if item.converter_type == "PersuasionConverter") + assert persuasion_entry.is_llm_based is True + assert all("Target" not in p.type_name for p in persuasion_entry.parameters) + + +class TestSerializeType: + """Tests for the _serialize_type presentation helper.""" + + def test_empty_annotation(self) -> None: + import inspect + + assert _serialize_type(inspect.Parameter.empty) == "Any" + + def test_plain_type(self) -> None: + assert _serialize_type(int) == "int" + + def test_optional_pep604(self) -> None: + assert _serialize_type(str | None) == "Optional[str]" + + def test_literal(self) -> None: + from typing import Literal + + assert _serialize_type(Literal["a", "b"]) == "Literal['a', 'b']" + class TestGetConverter: """Tests for ConverterService.get_converter method.""" @@ -120,8 +170,7 @@ async def test_get_converter_returns_converter_from_registry(self) -> None: """Test that get_converter returns converter built from registry object.""" service = ConverterService() - mock_converter = MagicMock() - mock_converter.__class__.__name__ = "MockConverter" + mock_converter = MagicMock(spec=prompt_converter.PromptConverter) mock_identifier = ComponentIdentifier( class_name="MockConverter", class_module="tests.unit.backend.test_converter_service", @@ -132,7 +181,7 @@ async def test_get_converter_returns_converter_from_registry(self) -> None: }, ) mock_converter.get_identifier.return_value = mock_identifier - service._registry.register_instance(mock_converter, name="conv-1") + service._registry.instances.register(mock_converter, name="conv-1") result = await service.get_converter_async(converter_id="conv-1") @@ -155,8 +204,8 @@ def test_get_converter_object_returns_none_for_nonexistent(self) -> None: def test_get_converter_object_returns_object_from_registry(self) -> None: """Test that get_converter_object returns the actual converter object.""" service = ConverterService() - mock_converter = MagicMock() - service._registry.register_instance(mock_converter, name="conv-1") + mock_converter = MagicMock(spec=prompt_converter.PromptConverter) + service._registry.instances.register(mock_converter, name="conv-1") result = service.get_converter_object(converter_id="conv-1") @@ -227,8 +276,8 @@ def test_resolve_converter_params_resolves_converter_id_reference(self) -> None: service = ConverterService() # Register a mock converter - mock_converter = MagicMock() - service._registry.register_instance(mock_converter, name="inner-conv") + mock_converter = MagicMock(spec=prompt_converter.PromptConverter) + service._registry.instances.register(mock_converter, name="inner-conv") params = {"converter": {"converter_id": "inner-conv"}} @@ -275,13 +324,12 @@ async def test_preview_conversion_with_converter_ids(self) -> None: """Test preview with converter IDs.""" service = ConverterService() - mock_converter = MagicMock() - mock_converter.__class__.__name__ = "MockConverter" + mock_converter = MagicMock(spec=prompt_converter.PromptConverter) mock_result = MagicMock() mock_result.output_text = "encoded_value" mock_result.output_type = "text" mock_converter.convert_async = AsyncMock(return_value=mock_result) - service._registry.register_instance(mock_converter, name="conv-1") + service._registry.instances.register(mock_converter, name="conv-1") request = ConverterPreviewRequest( original_value="test", @@ -300,22 +348,20 @@ async def test_preview_conversion_chains_multiple_converters(self) -> None: """Test that preview chains multiple converters.""" service = ConverterService() - mock_converter1 = MagicMock() - mock_converter1.__class__.__name__ = "MockConverter1" + mock_converter1 = MagicMock(spec=prompt_converter.PromptConverter) mock_result1 = MagicMock() mock_result1.output_text = "step1_output" mock_result1.output_type = "text" mock_converter1.convert_async = AsyncMock(return_value=mock_result1) - mock_converter2 = MagicMock() - mock_converter2.__class__.__name__ = "MockConverter2" + mock_converter2 = MagicMock(spec=prompt_converter.PromptConverter) mock_result2 = MagicMock() mock_result2.output_text = "step2_output" mock_result2.output_type = "text" mock_converter2.convert_async = AsyncMock(return_value=mock_result2) - service._registry.register_instance(mock_converter1, name="conv-1") - service._registry.register_instance(mock_converter2, name="conv-2") + service._registry.instances.register(mock_converter1, name="conv-1") + service._registry.instances.register(mock_converter2, name="conv-2") request = ConverterPreviewRequest( original_value="input", @@ -344,10 +390,10 @@ def test_get_converter_objects_for_ids_returns_objects(self) -> None: """Test that method returns converter objects in order.""" service = ConverterService() - mock1 = MagicMock() - mock2 = MagicMock() - service._registry.register_instance(mock1, name="conv-1") - service._registry.register_instance(mock2, name="conv-2") + mock1 = MagicMock(spec=prompt_converter.PromptConverter) + mock2 = MagicMock(spec=prompt_converter.PromptConverter) + service._registry.instances.register(mock1, name="conv-1") + service._registry.instances.register(mock2, name="conv-2") result = service.get_converter_objects_for_ids(converter_ids=["conv-1", "conv-2"]) @@ -607,25 +653,3 @@ def test_base64_converter_default_params(self) -> None: # Verify type info is populated from identifier assert isinstance(result.supported_input_types, list) assert isinstance(result.supported_output_types, list) - - -class TestIsLlmBased: - """Tests for the _is_llm_based introspection helper""" - - def test_detects_llm_text_converter(self) -> None: - # Test that _is_llm_based correctly identifies converters that use LLMS as LLM-based. - for cls in ( - LLMGenericTextConverter, - NoiseConverter, - PersuasionConverter, - ToneConverter, - TenseConverter, - TranslationConverter, - VariationConverter, - ): - assert _is_llm_based(cls) is True, f"{cls.__name__} should be detected as LLM-based" - - def test_does_not_flag_non_target_converters(self) -> None: - # Test that _is_llm_based does not incorrectly flag non-LLM converters. - assert _is_llm_based(Base64Converter) is False - assert _is_llm_based(CaesarConverter) is False diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 0a7463d1f0..15116a0ac9 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -110,6 +110,7 @@ def mock_all_registries(mock_memory): mock_sr = MagicMock() mock_sr.get_class.return_value = mock_scenario_class + mock_sr.create_instance.return_value = mock_scenario_instance mock_tr = MagicMock() mock_tr.get_instance_by_name.return_value = MagicMock() @@ -452,23 +453,25 @@ async def test_start_run_runs_initializers(self, mock_all_registries) -> None: assert mock_init_instance.initialize_async.await_count == 2 async def test_start_run_passes_scenario_result_id_for_resume(self, mock_all_registries) -> None: - """Test that scenario_result_id is passed to the scenario constructor for resumption.""" + """Test that scenario_result_id is passed to the registry constructor for resumption.""" service = ScenarioRunService() - mock_scenario_class = mock_all_registries["scenario_class"] + mock_sr = mock_all_registries["scenario_registry"] response = await service.start_run_async(request=_make_request(scenario_result_id="existing-result-uuid")) assert response.status == ScenarioRunStatus.IN_PROGRESS - mock_scenario_class.assert_called_once_with(scenario_result_id="existing-result-uuid") + mock_sr.create_instance.assert_called_once_with( + "foundry.red_team_agent", scenario_result_id="existing-result-uuid" + ) async def test_start_run_omits_scenario_result_id_when_none(self, mock_all_registries) -> None: - """Test that scenario_result_id is not passed to constructor when not provided.""" + """Test that scenario_result_id is not passed to the registry constructor when not provided.""" service = ScenarioRunService() - mock_scenario_class = mock_all_registries["scenario_class"] + mock_sr = mock_all_registries["scenario_registry"] await service.start_run_async(request=_make_request()) - mock_scenario_class.assert_called_once_with() + mock_sr.create_instance.assert_called_once_with("foundry.red_team_agent") class TestScenarioRunServiceGetRun: diff --git a/tests/unit/registry/test_converter_registry.py b/tests/unit/registry/test_converter_registry.py index 7fa2de4599..926de6835e 100644 --- a/tests/unit/registry/test_converter_registry.py +++ b/tests/unit/registry/test_converter_registry.py @@ -1,9 +1,63 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from pyrit.models import ComponentIdentifier, PromptDataType -from pyrit.prompt_converter import ConverterResult, PromptConverter -from pyrit.registry.object_registries.converter_registry import ConverterRegistry +""" +Tests for the merged ``ConverterRegistry`` (buildable catalog + instance container) +and its introspection helpers. +""" + +from typing import Literal + +import pytest + +from pyrit.models import ComponentIdentifier, Message, MessagePiece, PromptDataType +from pyrit.prompt_converter import ( + Base64Converter, + CaesarConverter, + ConverterResult, + LLMGenericTextConverter, + NoiseConverter, + PersuasionConverter, + PromptConverter, + TenseConverter, + ToneConverter, + TranslationConverter, + VariationConverter, +) +from pyrit.prompt_target import PromptTarget, TargetCapabilities, TargetConfiguration +from pyrit.registry.components import ( + ConverterMetadata, + ConverterRegistry, +) +from pyrit.registry.components.converter_registry import ( + _extract_parameters, + _requires_llm_target, +) +from pyrit.registry.object_registries import ( + TargetRegistry, +) + + +class MockPromptTarget(PromptTarget): + """Minimal PromptTarget (with LLM-converter capabilities) for resolution tests.""" + + _DEFAULT_CONFIGURATION: TargetConfiguration = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_system_prompt=True, + supports_editable_history=True, + ) + ) + + def __init__(self, *, model_name: str = "mock_model") -> None: + super().__init__(model_name=model_name) + + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + return [MessagePiece(role="assistant", original_value="mock response").to_message()] + + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: + pass class MockTextConverter(PromptConverter): @@ -15,10 +69,6 @@ class MockTextConverter(PromptConverter): async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: """Convert prompt (no-op for testing). - Args: - prompt (str): The prompt to convert. - input_type (PromptDataType): The input type. Defaults to "text". - Returns: ConverterResult: The unchanged prompt. """ @@ -34,10 +84,6 @@ class MockImageConverter(PromptConverter): async def convert_async(self, *, prompt: str, input_type: PromptDataType = "image_path") -> ConverterResult: """Convert prompt (no-op for testing). - Args: - prompt (str): The prompt to convert. - input_type (PromptDataType): The input type. Defaults to "image_path". - Returns: ConverterResult: The unchanged prompt. """ @@ -53,297 +99,368 @@ class MockMultiModalConverter(PromptConverter): async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: """Convert prompt (no-op for testing). - Args: - prompt (str): The prompt to convert. - input_type (PromptDataType): The input type. Defaults to "text". - Returns: ConverterResult: The unchanged prompt. """ return ConverterResult(output_text=prompt, output_type="text") +@pytest.fixture +def registry(): + """Provide a fresh ``ConverterRegistry`` singleton, reset around each test.""" + ConverterRegistry.reset_instance() + instance = ConverterRegistry.get_registry_singleton() + yield instance + ConverterRegistry.reset_instance() + + +# --------------------------------------------------------------------------- +# Instance container (reached via the ``instances`` property) +# --------------------------------------------------------------------------- + + class TestConverterRegistrySingleton: """Tests for the singleton pattern in ConverterRegistry.""" def setup_method(self): - """Reset the singleton before each test.""" ConverterRegistry.reset_instance() def teardown_method(self): - """Reset the singleton after each test.""" ConverterRegistry.reset_instance() def test_get_registry_singleton_returns_same_instance(self): - """Test that get_registry_singleton returns the same singleton each time.""" - instance1 = ConverterRegistry.get_registry_singleton() - instance2 = ConverterRegistry.get_registry_singleton() - - assert instance1 is instance2 + assert ConverterRegistry.get_registry_singleton() is ConverterRegistry.get_registry_singleton() def test_get_registry_singleton_returns_converter_registry_type(self): - """Test that get_registry_singleton returns a ConverterRegistry instance.""" - instance = ConverterRegistry.get_registry_singleton() - assert isinstance(instance, ConverterRegistry) + assert isinstance(ConverterRegistry.get_registry_singleton(), ConverterRegistry) def test_reset_instance_clears_singleton(self): - """Test that reset_instance clears the singleton.""" instance1 = ConverterRegistry.get_registry_singleton() ConverterRegistry.reset_instance() - instance2 = ConverterRegistry.get_registry_singleton() - - assert instance1 is not instance2 + assert ConverterRegistry.get_registry_singleton() is not instance1 class TestConverterRegistryRegisterInstance: - """Tests for register_instance functionality in ConverterRegistry.""" - - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConverterRegistry.reset_instance() - self.registry = ConverterRegistry.get_registry_singleton() - - def teardown_method(self): - """Reset the singleton after each test.""" - ConverterRegistry.reset_instance() + """Tests for instance registration via the ``instances`` property.""" - def test_register_instance_with_custom_name(self): - """Test registering a converter with a custom name.""" + def test_register_instance_with_custom_name(self, registry: ConverterRegistry): converter = MockTextConverter() - self.registry.register_instance(converter, name="custom_converter") + registry.instances.register(converter, name="custom_converter") - assert "custom_converter" in self.registry - assert self.registry.get("custom_converter") is converter + assert "custom_converter" in registry.instances + assert registry.instances.get("custom_converter") is converter - def test_register_instance_generates_name_from_class(self): - """Test that register_instance generates a name from class name when not provided.""" + def test_register_instance_generates_name_from_class(self, registry: ConverterRegistry): converter = MockTextConverter() - self.registry.register_instance(converter) + registry.instances.register(converter) - # Name should be derived from class name with hash suffix - names = self.registry.get_names() + names = registry.instances.get_names() assert len(names) == 1 assert names[0].startswith("MockTextConverter::") - def test_register_instance_multiple_converters_unique_names(self): - """Test registering multiple converters generates unique names.""" + def test_register_instance_multiple_converters_unique_names(self, registry: ConverterRegistry): + registry.instances.register(MockTextConverter()) + registry.instances.register(MockImageConverter()) + + assert len(registry.instances) == 2 + + def test_register_instance_duplicate_name_overwrites(self, registry: ConverterRegistry): converter1 = MockTextConverter() converter2 = MockImageConverter() - self.registry.register_instance(converter1) - self.registry.register_instance(converter2) + registry.instances.register(converter1, name="shared_name") + registry.instances.register(converter2, name="shared_name") - assert len(self.registry) == 2 + assert len(registry.instances) == 1 + assert registry.instances.get("shared_name") is converter2 - def test_register_instance_same_converter_type_different_names(self): - """Test that same converter class can be registered with different names.""" - converter1 = MockTextConverter() - converter2 = MockTextConverter() + def test_register_instance_rejects_non_converter(self, registry: ConverterRegistry): + class NotAConverter: + pass - self.registry.register_instance(converter1, name="converter_1") - self.registry.register_instance(converter2, name="converter_2") + with pytest.raises(TypeError, match="PromptConverter"): + registry.instances.register(NotAConverter()) # type: ignore[arg-type] - assert len(self.registry) == 2 + assert len(registry.instances) == 0 - def test_register_instance_duplicate_name_overwrites(self): - """Test that registering with a duplicate name silently overwrites the previous instance.""" - converter1 = MockTextConverter() - converter2 = MockImageConverter() - self.registry.register_instance(converter1, name="shared_name") - self.registry.register_instance(converter2, name="shared_name") +class TestConverterRegistryGetInstanceByName: + """Tests for instance lookup via ``instances.get``.""" - assert len(self.registry) == 1 - assert self.registry.get("shared_name") is converter2 + def test_get_instance_by_name_returns_converter(self, registry: ConverterRegistry): + converter = MockTextConverter() + registry.instances.register(converter, name="test_converter") + assert registry.instances.get("test_converter") is converter + def test_get_instance_by_name_nonexistent_returns_none(self, registry: ConverterRegistry): + assert registry.instances.get("nonexistent") is None -class TestConverterRegistryGetInstanceByName: - """Tests for get_instance_by_name functionality in ConverterRegistry.""" - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConverterRegistry.reset_instance() - self.registry = ConverterRegistry.get_registry_singleton() - self.converter = MockTextConverter() - self.registry.register_instance(self.converter, name="test_converter") +class TestConverterRegistryInstanceMetadata: + """Tests for instance-level metadata (``instances.list_metadata``).""" - def teardown_method(self): - """Reset the singleton after each test.""" - ConverterRegistry.reset_instance() + def test_instance_metadata_is_component_identifier(self, registry: ConverterRegistry): + converter = MockTextConverter() + registry.instances.register(converter, name="text_converter") - def test_get_instance_by_name_returns_converter(self): - """Test getting a registered converter by name.""" - result = self.registry.get_instance_by_name("test_converter") - assert result is self.converter + metadata = registry.instances.list_metadata() + assert len(metadata) == 1 + assert isinstance(metadata[0], ComponentIdentifier) + assert metadata[0] == converter.get_identifier() - def test_get_instance_by_name_nonexistent_returns_none(self): - """Test that getting a non-existent converter returns None.""" - result = self.registry.get_instance_by_name("nonexistent") - assert result is None + def test_instance_metadata_filter_by_class_name(self, registry: ConverterRegistry): + registry.instances.register(MockTextConverter(), name="t1") + registry.instances.register(MockTextConverter(), name="t2") + registry.instances.register(MockImageConverter(), name="i1") + metadata = registry.instances.list_metadata(include_filters={"class_name": "MockTextConverter"}) + assert len(metadata) == 2 + assert all(m.class_name == "MockTextConverter" for m in metadata) -class TestConverterRegistryBuildMetadata: - """Tests for _build_metadata functionality in ConverterRegistry.""" - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConverterRegistry.reset_instance() - self.registry = ConverterRegistry.get_registry_singleton() +class TestConverterRegistryContainerProtocol: + """Tests for the ``instances`` container protocol surface.""" - def teardown_method(self): - """Reset the singleton after each test.""" - ConverterRegistry.reset_instance() + def test_contains_and_len_and_iter(self, registry: ConverterRegistry): + registry.instances.register(MockTextConverter(), name="test_converter") + assert "test_converter" in registry.instances + assert "unknown_converter" not in registry.instances + assert len(registry.instances) == 1 + assert "test_converter" in list(registry.instances) - def test_build_metadata_includes_class_name(self): - """Test that metadata includes the converter class name.""" - converter = MockTextConverter() - self.registry.register_instance(converter, name="text_converter") + def test_get_names_returns_sorted_list(self, registry: ConverterRegistry): + registry.instances.register(MockImageConverter(), name="zeta_converter") + registry.instances.register(MockImageConverter(), name="alpha_converter") + assert registry.instances.get_names() == ["alpha_converter", "zeta_converter"] - metadata = self.registry.list_metadata() - assert len(metadata) == 1 - assert metadata[0].class_name == "MockTextConverter" + def test_get_all_instances_returns_all(self, registry: ConverterRegistry): + text = MockTextConverter() + image = MockImageConverter() + registry.instances.register(text, name="text_converter") + registry.instances.register(image, name="image_converter") - def test_build_metadata_includes_supported_input_types(self): - """Test that metadata includes supported_input_types in params.""" - converter = MockTextConverter() - self.registry.register_instance(converter, name="text_converter") + entry_map = {e.name: e for e in registry.instances.get_all_instances()} + assert entry_map["text_converter"].instance is text + assert entry_map["image_converter"].instance is image - metadata = self.registry.list_metadata() - assert metadata[0].params["supported_input_types"] == ("text",) - def test_build_metadata_includes_supported_output_types(self): - """Test that metadata includes supported_output_types in params.""" - converter = MockTextConverter() - self.registry.register_instance(converter, name="text_converter") +# --------------------------------------------------------------------------- +# Buildable class catalog (discovery + introspection + build) +# --------------------------------------------------------------------------- - metadata = self.registry.list_metadata() - assert metadata[0].params["supported_output_types"] == ("text",) - def test_build_metadata_is_component_identifier(self): - """Test that metadata is the converter's ComponentIdentifier.""" - converter = MockTextConverter() - self.registry.register_instance(converter, name="text_converter") +class TestDiscovery: + """Tests for converter class discovery.""" - metadata = self.registry.list_metadata() - assert isinstance(metadata[0], ComponentIdentifier) - assert metadata[0] == converter.get_identifier() + def test_discovers_known_converters(self, registry: ConverterRegistry): + names = registry.get_class_names() + assert "Base64Converter" in names + assert "CaesarConverter" in names - def test_build_metadata_different_modalities(self): - """Test that metadata reflects converter-specific modalities.""" - converter = MockImageConverter() - self.registry.register_instance(converter, name="image_converter") + def test_discovers_non_catalog_converters(self, registry: ConverterRegistry): + # SelectiveTextConverter is hidden from the user-facing catalog (a frontend + # concern) but must remain discoverable/buildable so agents can use it. + assert "SelectiveTextConverter" in registry.get_class_names() - metadata = self.registry.list_metadata() - assert metadata[0].params["supported_input_types"] == ("image_path",) - assert metadata[0].params["supported_output_types"] == ("text",) - assert metadata[0].class_name == "MockImageConverter" + def test_does_not_register_base_class(self, registry: ConverterRegistry): + assert "PromptConverter" not in registry.get_class_names() + def test_keyed_by_exact_class_name(self, registry: ConverterRegistry): + names = registry.get_class_names() + assert "Base64Converter" in names + assert "base64_converter" not in names -class TestConverterRegistryListMetadataFiltering: - """Tests for list_metadata filtering in ConverterRegistry.""" - def setup_method(self): - """Reset and get a fresh registry with multiple converters.""" - ConverterRegistry.reset_instance() - self.registry = ConverterRegistry.get_registry_singleton() +class TestGetClass: + """Tests for get_class (the inherited class-catalog accessor).""" - self.text_converter1 = MockTextConverter() - self.text_converter2 = MockTextConverter() - self.image_converter = MockImageConverter() - self.multi_modal_converter = MockMultiModalConverter() + def test_returns_class(self, registry: ConverterRegistry): + assert registry.get_class("Base64Converter") is Base64Converter - self.registry.register_instance(self.text_converter1, name="text_converter_1") - self.registry.register_instance(self.text_converter2, name="text_converter_2") - self.registry.register_instance(self.image_converter, name="image_converter") - self.registry.register_instance(self.multi_modal_converter, name="multi_modal_converter") + def test_unknown_type_raises(self, registry: ConverterRegistry): + with pytest.raises(KeyError, match="not found"): + registry.get_class("NotARealConverter") - def teardown_method(self): - """Reset the singleton after each test.""" - ConverterRegistry.reset_instance() + def test_is_subclass_relationship(self, registry: ConverterRegistry): + assert issubclass(registry.get_class("Base64Converter"), PromptConverter) - def test_list_metadata_no_filter_returns_all(self): - """Test that list_metadata without filters returns all items.""" - metadata = self.registry.list_metadata() - assert len(metadata) == 4 - def test_list_metadata_filter_by_class_name(self): - """Test filtering metadata by class_name.""" - metadata = self.registry.list_metadata(include_filters={"class_name": "MockTextConverter"}) - assert len(metadata) == 2 - assert all(m.class_name == "MockTextConverter" for m in metadata) +class TestCreateInstance: + """Tests for create_instance (build via the shared resolver).""" - def test_list_metadata_filter_by_supported_input_type(self): - """Test filtering metadata by supported_input_types (containment check).""" - # "text" is in supported_input_types for MockTextConverter and MockMultiModalConverter - metadata = self.registry.list_metadata(include_filters={"supported_input_types": "text"}) - assert len(metadata) == 3 # 2 text converters + 1 multi-modal - class_names = {m.class_name for m in metadata} - assert "MockTextConverter" in class_names - assert "MockMultiModalConverter" in class_names - - def test_list_metadata_exclude_by_class_name(self): - """Test excluding metadata by class_name.""" - metadata = self.registry.list_metadata(exclude_filters={"class_name": "MockTextConverter"}) - assert len(metadata) == 2 - assert all(m.class_name != "MockTextConverter" for m in metadata) - - def test_list_metadata_combined_include_and_exclude(self): - """Test combined include and exclude filters.""" - # Include converters that accept text, exclude MockMultiModalConverter - metadata = self.registry.list_metadata( - include_filters={"supported_input_types": "text"}, - exclude_filters={"class_name": "MockMultiModalConverter"}, + def test_creates_instance(self, registry: ConverterRegistry): + assert isinstance(registry.create_instance("Base64Converter"), Base64Converter) + + def test_coerces_string_params(self, registry: ConverterRegistry): + converter = registry.create_instance("CaesarConverter", caesar_offset="13") + assert isinstance(converter, CaesarConverter) + assert converter.get_identifier().params.get("caesar_offset") == 13 + + def test_unknown_type_raises(self, registry: ConverterRegistry): + with pytest.raises(KeyError, match="not found"): + registry.create_instance("NotARealConverter") + + def test_unknown_param_raises(self, registry: ConverterRegistry): + with pytest.raises(ValueError, match="Unknown parameter"): + registry.create_instance("Base64Converter", not_a_param="x") + + def test_build_does_not_register_instance(self, registry: ConverterRegistry): + registry.create_instance("Base64Converter") + assert len(registry.instances) == 0 + + def test_honors_registered_default_kwargs(self, registry: ConverterRegistry): + registry.register(CaesarConverter, name="CaesarDefault", default_kwargs={"caesar_offset": 5}) + converter = registry.create_instance("CaesarDefault") + assert converter.get_identifier().params.get("caesar_offset") == 5 + + def test_uses_registered_factory(self, registry: ConverterRegistry): + sentinel = Base64Converter() + registry.register(Base64Converter, name="B64Factory", factory=lambda **kwargs: sentinel) + assert registry.create_instance("B64Factory") is sentinel + + +@pytest.mark.usefixtures("patch_central_database") +class TestCreateLLMConverter: + """Tests that LLM converters are buildable by resolving a target by name.""" + + def test_build_llm_converter_resolves_target_by_name(self, registry: ConverterRegistry): + target = MockPromptTarget() + TargetRegistry.reset_instance() + TargetRegistry.get_registry_singleton().register_instance(target, name="my_target") + try: + converter = registry.create_instance("TenseConverter", converter_target="my_target", tense="past") + assert isinstance(converter, TenseConverter) + assert converter._converter_target is target + finally: + TargetRegistry.reset_instance() + + def test_build_llm_converter_unknown_target_raises(self, registry: ConverterRegistry): + TargetRegistry.reset_instance() + try: + with pytest.raises(ValueError, match="not found"): + registry.create_instance("TenseConverter", converter_target="missing", tense="past") + finally: + TargetRegistry.reset_instance() + + +class TestClassMetadata: + """Tests for converter class-catalog metadata building.""" + + def _metadata_for(self, registry: ConverterRegistry, name: str) -> ConverterMetadata: + return next(m for m in registry.list_class_metadata() if m.class_name == name) + + def test_metadata_includes_supported_types(self, registry: ConverterRegistry): + meta = self._metadata_for(registry, "Base64Converter") + assert "text" in meta.supported_input_types + assert "text" in meta.supported_output_types + + def test_metadata_has_no_catalog_visible_field(self, registry: ConverterRegistry): + # catalog_visible is a presentation concern owned by the backend/frontend. + assert not hasattr(self._metadata_for(registry, "Base64Converter"), "catalog_visible") + + def test_is_llm_based_flag(self, registry: ConverterRegistry): + llm_based = ( + LLMGenericTextConverter, + NoiseConverter, + PersuasionConverter, + ToneConverter, + TenseConverter, + TranslationConverter, + VariationConverter, ) - assert len(metadata) == 2 - assert all(m.class_name == "MockTextConverter" for m in metadata) + for cls in llm_based: + meta = self._metadata_for(registry, cls.__name__) + assert meta.is_llm_based is True, f"{cls.__name__} should be LLM-based" + assert self._metadata_for(registry, "Base64Converter").is_llm_based is False + assert self._metadata_for(registry, "CaesarConverter").is_llm_based is False + def test_parameters_extracted(self, registry: ConverterRegistry): + meta = self._metadata_for(registry, "CaesarConverter") + caesar_param = next(p for p in meta.parameters if p.name == "caesar_offset") + assert caesar_param.required is True + assert caesar_param.annotation is int + assert caesar_param.coercible_from_string is True -class TestConverterRegistryInheritedMethods: - """Tests for inherited methods from RetrievableInstanceRegistry.""" + def test_surfaces_non_coercible_params(self, registry: ConverterRegistry): + # An LLM-based converter exposes its target parameter for dynamic + # construction even though it cannot be coerced from a string. + meta = self._metadata_for(registry, "PersuasionConverter") + non_coercible = [p for p in meta.parameters if not p.coercible_from_string] + assert non_coercible, "expected at least one non-coercible parameter (the LLM target)" - def setup_method(self): - """Reset and get a fresh registry.""" - ConverterRegistry.reset_instance() - self.registry = ConverterRegistry.get_registry_singleton() - self.converter = MockTextConverter() - self.registry.register_instance(self.converter, name="test_converter") - def teardown_method(self): - """Reset the singleton after each test.""" - ConverterRegistry.reset_instance() +# --------------------------------------------------------------------------- +# Introspection helpers +# --------------------------------------------------------------------------- + + +class _UnionTargetConverter: + """Helper with a PEP 604 unioned target parameter for introspection tests.""" + + def __init__(self, *, target: PromptTarget | None = None, offset: int | None = None) -> None: + self.target = target + self.offset = offset + + +class _OptionalLiteralConverter: + """Helper with an optional Literal parameter for choices extraction tests.""" + + def __init__(self, *, fmt: Literal["A", "B"] | None = None) -> None: + self.fmt = fmt + + +class TestExtractParameters: + """Tests for the converter-parameter introspection helper.""" + + def test_exposes_raw_annotation(self) -> None: + offset_param = next(p for p in _extract_parameters(_UnionTargetConverter) if p.name == "offset") + assert offset_param.annotation == (int | None) + assert offset_param.coercible_from_string is True + + def test_includes_non_coercible(self) -> None: + target_param = next(p for p in _extract_parameters(_UnionTargetConverter) if p.name == "target") + assert target_param.coercible_from_string is False + + def test_optional_literal_choices(self) -> None: + fmt_param = next(p for p in _extract_parameters(_OptionalLiteralConverter) if p.name == "fmt") + assert fmt_param.choices == ("A", "B") + + def test_sets_requires_llm(self) -> None: + params = _extract_parameters(_UnionTargetConverter) + target_param = next(p for p in params if p.name == "target") + offset_param = next(p for p in params if p.name == "offset") + assert target_param.requires_llm is True + assert offset_param.requires_llm is False + + +class TestRequiresLlmTarget: + """Tests for the _requires_llm_target helper.""" + + def test_plain_target(self) -> None: + assert _requires_llm_target(PromptTarget) is True + + def test_optional_target(self) -> None: + assert _requires_llm_target(PromptTarget | None) is True + + def test_non_target(self) -> None: + assert _requires_llm_target(int) is False + assert _requires_llm_target(str | None) is False + + +class TestNoBackendDependency: + """The registry must be reusable without depending on pyrit.backend.""" + + def test_module_has_no_backend_dependency(self) -> None: + import ast + import inspect + + import pyrit.registry.components.converter_registry as module - def test_contains_registered_name(self): - """Test __contains__ for registered name.""" - assert "test_converter" in self.registry - - def test_contains_unregistered_name(self): - """Test __contains__ for unregistered name.""" - assert "unknown_converter" not in self.registry - - def test_len_returns_count(self): - """Test __len__ returns correct count.""" - assert len(self.registry) == 1 - - def test_iter_yields_names(self): - """Test __iter__ yields registered names.""" - names = list(self.registry) - assert "test_converter" in names - - def test_get_names_returns_sorted_list(self): - """Test get_names returns sorted list of names.""" - self.registry.register_instance(MockImageConverter(), name="alpha_converter") - self.registry.register_instance(MockImageConverter(), name="zeta_converter") - - names = self.registry.get_names() - assert names == ["alpha_converter", "test_converter", "zeta_converter"] - - def test_get_all_instances_returns_all(self): - """Test get_all_instances returns list of all registered entries.""" - image_converter = MockImageConverter() - self.registry.register_instance(image_converter, name="image_converter") - - all_entries = self.registry.get_all_instances() - assert len(all_entries) == 2 - entry_map = {e.name: e for e in all_entries} - assert entry_map["test_converter"].instance is self.converter - assert entry_map["image_converter"].instance is image_converter + tree = ast.parse(inspect.getsource(module)) + imported_modules: list[str] = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imported_modules.extend(alias.name for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module: + imported_modules.append(node.module) + assert not any(name.startswith("pyrit.backend") for name in imported_modules) diff --git a/tests/unit/registry/test_instance_registry.py b/tests/unit/registry/test_instance_registry.py new file mode 100644 index 0000000000..df66cc4506 --- /dev/null +++ b/tests/unit/registry/test_instance_registry.py @@ -0,0 +1,460 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for the instance-registry capability (``DefaultInstanceRegistry``) that a +registry exposes as its ``.instances`` property, plus the ``InstanceRegistry`` and +``SupportsInstances`` protocols. +""" + +import pytest + +from pyrit.models import ComponentIdentifier, Identifiable +from pyrit.registry.instance_registry import ( + DefaultInstanceRegistry, + InstanceRegistry, + RegistryEntry, + SupportsInstances, +) + + +class _TestItem(Identifiable): + """Minimal Identifiable stub wrapping a string value for testing.""" + + def __init__(self, value: str) -> None: + self.value = value + + def _build_identifier(self) -> ComponentIdentifier: + return ComponentIdentifier( + class_name="_TestItem", + class_module="test", + params={"category": "test" if "test" in self.value.lower() else "other"}, + ) + + def __eq__(self, other: object) -> bool: + if isinstance(other, _TestItem): + return self.value == other.value + if isinstance(other, str): + return self.value == other + return NotImplemented + + def __hash__(self) -> int: + return hash(self.value) + + def __repr__(self) -> str: + return f"_TestItem({self.value!r})" + + +def _item(value: str) -> _TestItem: + """Shorthand factory for _TestItem.""" + return _TestItem(value) + + +class _OtherItem(Identifiable): + """A second Identifiable type, unrelated to _TestItem, for type-enforcement tests.""" + + def _build_identifier(self) -> ComponentIdentifier: + return ComponentIdentifier(class_name="_OtherItem", class_module="test") + + +@pytest.fixture +def registry() -> DefaultInstanceRegistry[_TestItem]: + """Provide a fresh, singleton-free instance registry for each test.""" + return DefaultInstanceRegistry() + + +class TestRegistration: + """Tests for registering instances.""" + + def test_register_adds_instance(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("test_value"), name="test_name") + + assert "test_name" in registry + assert registry.get("test_name") == "test_value" + + def test_register_multiple_instances(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("value1"), name="name1") + registry.register(_item("value2"), name="name2") + registry.register(_item("value3"), name="name3") + + assert len(registry) == 3 + assert registry.get("name2") == "value2" + + def test_register_overwrites_existing(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("original"), name="name") + registry.register(_item("updated"), name="name") + + assert len(registry) == 1 + assert registry.get("name") == "updated" + + def test_register_defaults_name_to_identifier_unique_name(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("value1")) + + names = registry.get_names() + assert len(names) == 1 + assert names[0].startswith("_TestItem::") + + def test_register_invalidates_metadata_cache(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("value1"), name="name1") + assert len(registry.list_metadata()) == 1 + + registry.register(_item("value2"), name="name2") + assert len(registry.list_metadata()) == 2 + + +class TestInstanceTypeEnforcement: + """Tests for the optional ``instance_type`` registration constraint.""" + + def test_register_accepts_matching_type(self): + registry: DefaultInstanceRegistry[_TestItem] = DefaultInstanceRegistry(instance_type=_TestItem) + registry.register(_item("value1"), name="name1") + assert registry.get("name1") == "value1" + + def test_register_rejects_mismatched_type(self): + registry: DefaultInstanceRegistry[_TestItem] = DefaultInstanceRegistry(instance_type=_TestItem) + with pytest.raises(TypeError, match="_OtherItem.*_TestItem"): + registry.register(_OtherItem(), name="wrong") # type: ignore[arg-type] + assert "wrong" not in registry + + def test_register_accepts_subclass_of_expected_type(self): + class _SubItem(_TestItem): + pass + + registry: DefaultInstanceRegistry[_TestItem] = DefaultInstanceRegistry(instance_type=_TestItem) + registry.register(_SubItem("value1"), name="name1") + assert registry.get("name1") == "value1" + + def test_instance_type_accepts_lazy_callable(self): + calls = 0 + + def provide_type() -> type[_TestItem]: + nonlocal calls + calls += 1 + return _TestItem + + registry: DefaultInstanceRegistry[_TestItem] = DefaultInstanceRegistry(instance_type=provide_type) + registry.register(_item("value1"), name="name1") + registry.register(_item("value2"), name="name2") + + assert calls == 1 # resolved once and cached + with pytest.raises(TypeError): + registry.register(_OtherItem(), name="wrong") # type: ignore[arg-type] + + def test_no_instance_type_allows_any_identifiable(self): + registry: DefaultInstanceRegistry[Identifiable] = DefaultInstanceRegistry() + registry.register(_item("value1"), name="a") + registry.register(_OtherItem(), name="b") + assert registry.get_names() == ["a", "b"] + + +class TestGet: + """Tests for get and get_entry.""" + + def test_get_existing_instance(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("value1"), name="name1") + assert registry.get("name1") == "value1" + + def test_get_nonexistent_returns_none(self, registry: DefaultInstanceRegistry[_TestItem]): + assert registry.get("missing") is None + + def test_get_entry_returns_registry_entry(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("value1"), name="name1", tags=["fast"]) + entry = registry.get_entry("name1") + assert isinstance(entry, RegistryEntry) + assert entry is not None + assert entry.name == "name1" + assert entry.instance == "value1" + assert entry.tags == {"fast": ""} + + def test_get_entry_nonexistent_returns_none(self, registry: DefaultInstanceRegistry[_TestItem]): + assert registry.get_entry("missing") is None + + +class TestGetNamesAndAllInstances: + """Tests for get_names and get_all_instances.""" + + def test_get_names_empty_registry(self, registry: DefaultInstanceRegistry[_TestItem]): + assert registry.get_names() == [] + + def test_get_names_returns_sorted_list(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="zeta") + registry.register(_item("v2"), name="alpha") + assert registry.get_names() == ["alpha", "zeta"] + + def test_get_all_instances_sorted_by_name(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="zeta") + registry.register(_item("v2"), name="alpha") + assert [e.name for e in registry.get_all_instances()] == ["alpha", "zeta"] + + def test_get_all_instances_preserves_tags(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", tags={"speed": "fast"}) + entry = registry.get_all_instances()[0] + assert entry.tags == {"speed": "fast"} + + def test_get_all_instances_empty_registry(self, registry: DefaultInstanceRegistry[_TestItem]): + assert registry.get_all_instances() == [] + + +class TestListMetadata: + """Tests for list_metadata and its filtering/caching.""" + + def test_list_metadata_returns_all_items(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("test_v1"), name="n1") + registry.register(_item("other_v2"), name="n2") + metadata = registry.list_metadata() + assert len(metadata) == 2 + assert all(isinstance(m, ComponentIdentifier) for m in metadata) + + def test_list_metadata_with_include_filter(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("test_v1"), name="n1") + registry.register(_item("other_v2"), name="n2") + metadata = registry.list_metadata(include_filters={"category": "test"}) + assert len(metadata) == 1 + assert metadata[0].params["category"] == "test" + + def test_list_metadata_with_exclude_filter(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("test_v1"), name="n1") + registry.register(_item("other_v2"), name="n2") + metadata = registry.list_metadata(exclude_filters={"category": "test"}) + assert len(metadata) == 1 + assert metadata[0].params["category"] == "other" + + def test_list_metadata_caches_until_invalidated(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("test_v1"), name="n1") + first = registry.list_metadata() + second = registry.list_metadata() + assert first is second + + +class TestTags: + """Tests for tag storage, normalization, and queries.""" + + def test_register_with_dict_tags(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", tags={"speed": "fast"}) + assert registry.get_entry("n1").tags == {"speed": "fast"} + + def test_register_with_list_tags_defaults_empty_values(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", tags=["fast", "stable"]) + assert registry.get_entry("n1").tags == {"fast": "", "stable": ""} + + def test_get_by_tag_key_only(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", tags=["fast"]) + registry.register(_item("v2"), name="n2", tags=["slow"]) + results = registry.get_by_tag(tag="fast") + assert [e.name for e in results] == ["n1"] + + def test_get_by_tag_key_and_value(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", tags={"speed": "fast"}) + registry.register(_item("v2"), name="n2", tags={"speed": "slow"}) + results = registry.get_by_tag(tag="speed", value="fast") + assert [e.name for e in results] == ["n1"] + + def test_get_by_tag_returns_sorted_by_name(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="zeta", tags=["t"]) + registry.register(_item("v2"), name="alpha", tags=["t"]) + assert [e.name for e in registry.get_by_tag(tag="t")] == ["alpha", "zeta"] + + def test_get_by_tag_no_match(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", tags=["fast"]) + assert registry.get_by_tag(tag="missing") == [] + + def test_normalize_tags_none(self, registry: DefaultInstanceRegistry[_TestItem]): + assert registry._normalize_tags(None) == {} + + def test_normalize_tags_list(self, registry: DefaultInstanceRegistry[_TestItem]): + assert registry._normalize_tags(["a", "b"]) == {"a": "", "b": ""} + + def test_normalize_tags_dict(self, registry: DefaultInstanceRegistry[_TestItem]): + assert registry._normalize_tags({"a": "1"}) == {"a": "1"} + + +class TestAddTags: + """Tests for add_tags.""" + + def test_add_tags_with_list(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1") + registry.add_tags(name="n1", tags=["fast"]) + assert registry.get_entry("n1").tags == {"fast": ""} + + def test_add_tags_merges_with_existing(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", tags={"a": "1"}) + registry.add_tags(name="n1", tags={"b": "2"}) + assert registry.get_entry("n1").tags == {"a": "1", "b": "2"} + + def test_add_tags_raises_for_missing_entry(self, registry: DefaultInstanceRegistry[_TestItem]): + with pytest.raises(KeyError): + registry.add_tags(name="missing", tags=["fast"]) + + def test_add_tags_invalidates_metadata_cache(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1") + first = registry.list_metadata() + registry.add_tags(name="n1", tags=["fast"]) + second = registry.list_metadata() + assert first is not second + + def test_add_tags_entries_findable_by_get_by_tag(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1") + registry.add_tags(name="n1", tags={"speed": "fast"}) + assert [e.name for e in registry.get_by_tag(tag="speed", value="fast")] == ["n1"] + + +class TestDunderMethods: + """Tests for __contains__, __len__, and __iter__.""" + + def test_contains(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1") + assert "n1" in registry + assert "missing" not in registry + + def test_len(self, registry: DefaultInstanceRegistry[_TestItem]): + assert len(registry) == 0 + registry.register(_item("v1"), name="n1") + assert len(registry) == 1 + + def test_iter_returns_sorted_names(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="zeta") + registry.register(_item("v2"), name="alpha") + assert list(registry) == ["alpha", "zeta"] + + +class TestMetadataField: + """Tests for the metadata field on RegistryEntry.""" + + def test_register_with_metadata_stores_it(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", metadata={"accepts_scorer_override": False, "priority": 5}) + entry = registry.get_entry("n1") + assert entry.metadata == {"accepts_scorer_override": False, "priority": 5} + + def test_register_without_metadata_defaults_to_empty_dict(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1") + assert registry.get_entry("n1").metadata == {} + + def test_metadata_does_not_affect_tags(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", tags=["fast"], metadata={"key": "value"}) + entry = registry.get_entry("n1") + assert entry.tags == {"fast": ""} + assert entry.metadata == {"key": "value"} + assert registry.get_by_tag(tag="key") == [] + + +class _IdentifiableStub(Identifiable): + """A minimal stub that holds a ComponentIdentifier for dependency tests.""" + + def __init__(self, identifier: ComponentIdentifier) -> None: + self._stored_identifier = identifier + + def _build_identifier(self) -> ComponentIdentifier: + return self._stored_identifier + + +class TestFindDependentsOfTag: + """Tests for DefaultInstanceRegistry.find_dependents_of_tag.""" + + @pytest.fixture + def registry(self) -> DefaultInstanceRegistry[_IdentifiableStub]: + return DefaultInstanceRegistry() + + def test_no_tagged_entries_returns_empty(self, registry: DefaultInstanceRegistry[_IdentifiableStub]) -> None: + registry.register(_IdentifiableStub(ComponentIdentifier(class_name="A", class_module="mod")), name="a") + assert registry.find_dependents_of_tag(tag="refusal") == [] + + def test_tagged_entry_not_returned_as_dependent(self, registry: DefaultInstanceRegistry[_IdentifiableStub]) -> None: + stub = _IdentifiableStub(ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="r1")) + registry.register(stub, name="refusal_scorer", tags=["refusal"]) + assert registry.find_dependents_of_tag(tag="refusal") == [] + + def test_dependent_found_by_child_eval_hash(self, registry: DefaultInstanceRegistry[_IdentifiableStub]) -> None: + base_id = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="r_hash") + registry.register(_IdentifiableStub(base_id), name="refusal_scorer", tags=["refusal"]) + + child_id = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="r_hash") + wrapper_id = ComponentIdentifier( + class_name="Inverter", + class_module="mod", + eval_hash="w_hash", + children={"sub_scorers": [child_id]}, + ) + registry.register(_IdentifiableStub(wrapper_id), name="inverter") + + dependents = registry.find_dependents_of_tag(tag="refusal") + assert [d.name for d in dependents] == ["inverter"] + + def test_non_dependent_not_returned(self, registry: DefaultInstanceRegistry[_IdentifiableStub]) -> None: + base_id = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="r_hash") + registry.register(_IdentifiableStub(base_id), name="refusal_scorer", tags=["refusal"]) + + unrelated_id = ComponentIdentifier(class_name="Likert", class_module="mod", eval_hash="l_hash") + registry.register(_IdentifiableStub(unrelated_id), name="likert") + + assert registry.find_dependents_of_tag(tag="refusal") == [] + + def test_deeply_nested_dependency_found(self, registry: DefaultInstanceRegistry[_IdentifiableStub]) -> None: + base_id = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="deep_r") + registry.register(_IdentifiableStub(base_id), name="refusal_scorer", tags=["refusal"]) + + inner_child = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="deep_r") + inverter = ComponentIdentifier( + class_name="Inverter", + class_module="mod", + children={"sub_scorers": [inner_child]}, + ) + composite_id = ComponentIdentifier( + class_name="Composite", + class_module="mod", + children={"sub_scorers": [inverter]}, + ) + registry.register(_IdentifiableStub(composite_id), name="composite") + + dependents = registry.find_dependents_of_tag(tag="refusal") + assert [d.name for d in dependents] == ["composite"] + + def test_multiple_dependents_returned_sorted(self, registry: DefaultInstanceRegistry[_IdentifiableStub]) -> None: + base_id = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="r1") + registry.register(_IdentifiableStub(base_id), name="refusal_scorer", tags=["refusal"]) + + child = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="r1") + for wrapper_name in ["z_wrapper", "a_wrapper", "m_wrapper"]: + wrapper_id = ComponentIdentifier( + class_name="Wrapper", + class_module="mod", + children={"sub_scorers": [child]}, + ) + registry.register(_IdentifiableStub(wrapper_id), name=wrapper_name) + + dependents = registry.find_dependents_of_tag(tag="refusal") + assert [d.name for d in dependents] == ["a_wrapper", "m_wrapper", "z_wrapper"] + + +class TestProtocolConformance: + """Tests that DefaultInstanceRegistry satisfies the registry protocols.""" + + def test_default_impl_is_instance_registry(self, registry: DefaultInstanceRegistry[_TestItem]) -> None: + assert isinstance(registry, InstanceRegistry) + + def test_supports_instances_marker(self, registry: DefaultInstanceRegistry[_TestItem]) -> None: + class _Holder: + def __init__(self, instances: InstanceRegistry[_TestItem]) -> None: + self.instances = instances + + holder: SupportsInstances[_TestItem] = _Holder(registry) + holder.instances.register(_item("v1"), name="n1") + assert holder.instances.get("n1") == "v1" + + +class TestNoBackendDependency: + """The instance registry must be reusable without depending on pyrit.backend.""" + + def test_module_has_no_backend_dependency(self) -> None: + import ast + import inspect + + import pyrit.registry.instance_registry as module + + tree = ast.parse(inspect.getsource(module)) + imported_modules: list[str] = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imported_modules.extend(alias.name for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module: + imported_modules.append(node.module) + assert not any(name.startswith("pyrit.backend") for name in imported_modules) diff --git a/tests/unit/registry/test_resolution.py b/tests/unit/registry/test_resolution.py new file mode 100644 index 0000000000..8a7889bd78 --- /dev/null +++ b/tests/unit/registry/test_resolution.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for the shared registry constructor-argument resolution primitive. +""" + +from typing import Literal + +import pytest + +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import PromptTarget +from pyrit.registry.object_registries import TargetRegistry +from pyrit.registry.resolution import ( + coerce_string_to_annotation, + get_resolvable_registry_getter, + get_union_non_none_args, + is_coercible_from_string, + is_registry_reference, + resolve_constructor_args, +) + + +class MockPromptTarget(PromptTarget): + """Minimal PromptTarget for registry-resolution tests.""" + + def __init__(self, *, model_name: str = "mock_model") -> None: + super().__init__(model_name=model_name) + + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + return [MessagePiece(role="assistant", original_value="mock response").to_message()] + + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: + pass + + +class _NeedsTarget: + """Helper whose constructor takes a registry-reference target plus simple params.""" + + def __init__(self, *, converter_target: PromptTarget, offset: int = 0, label: str = "x") -> None: + self.converter_target = converter_target + self.offset = offset + self.label = label + + +class _SimpleOnly: + """Helper whose constructor takes only simple/coercible params.""" + + def __init__( + self, *, count: int = 1, ratio: float = 0.5, flag: bool = False, mode: Literal["a", "b"] = "a" + ) -> None: + self.count = count + self.ratio = ratio + self.flag = flag + self.mode = mode + + +class _AcceptsKwargs: + """Helper whose constructor accepts arbitrary keyword arguments.""" + + def __init__(self, *, name: str = "n", **kwargs: object) -> None: + self.name = name + self.kwargs = kwargs + + +@pytest.fixture +def target_registry(): + """Provide a fresh TargetRegistry singleton with one registered target.""" + TargetRegistry.reset_instance() + registry = TargetRegistry.get_registry_singleton() + registry.register_instance(MockPromptTarget(), name="my_target") + yield registry + TargetRegistry.reset_instance() + + +@pytest.fixture +def empty_target_registry(): + """Provide a fresh, empty TargetRegistry singleton.""" + TargetRegistry.reset_instance() + registry = TargetRegistry.get_registry_singleton() + yield registry + TargetRegistry.reset_instance() + + +class TestTypeHelpers: + """Tests for the type-introspection helpers.""" + + def test_get_union_non_none_args_pep604(self) -> None: + assert get_union_non_none_args(int | None) == [int] + + def test_get_union_non_none_args_not_a_union(self) -> None: + assert get_union_non_none_args(int) is None + + def test_is_coercible_from_string(self) -> None: + assert is_coercible_from_string(str) is True + assert is_coercible_from_string(int | None) is True + assert is_coercible_from_string(Literal["a", "b"]) is True + assert is_coercible_from_string(PromptTarget) is False + + def test_is_registry_reference(self) -> None: + assert is_registry_reference(PromptTarget) is True + assert is_registry_reference(PromptTarget | None) is True + assert is_registry_reference(int) is False + + def test_get_resolvable_registry_getter_returns_target_registry(self) -> None: + getter = get_resolvable_registry_getter(PromptTarget) + assert getter is not None + assert isinstance(getter(), TargetRegistry) + + def test_get_resolvable_registry_getter_none_for_simple(self) -> None: + assert get_resolvable_registry_getter(int) is None + + +class TestCoerceStringToAnnotation: + """Tests for scalar string coercion.""" + + def test_int(self) -> None: + assert coerce_string_to_annotation(value="42", annotation=int) == 42 + + def test_float(self) -> None: + assert coerce_string_to_annotation(value="0.25", annotation=float) == 0.25 + + def test_bool_true(self) -> None: + assert coerce_string_to_annotation(value="yes", annotation=bool) is True + + def test_bool_false(self) -> None: + assert coerce_string_to_annotation(value="0", annotation=bool) is False + + def test_bool_invalid_raises(self) -> None: + with pytest.raises(ValueError, match="boolean"): + coerce_string_to_annotation(value="maybe", annotation=bool) + + def test_optional_unwrapped(self) -> None: + assert coerce_string_to_annotation(value="7", annotation=int | None) == 7 + + def test_str_passthrough(self) -> None: + assert coerce_string_to_annotation(value="hello", annotation=str) == "hello" + + def test_literal_valid(self) -> None: + assert coerce_string_to_annotation(value="b", annotation=Literal["a", "b"]) == "b" + + def test_literal_invalid_raises(self) -> None: + with pytest.raises(ValueError, match="one of"): + coerce_string_to_annotation(value="c", annotation=Literal["a", "b"]) + + def test_literal_coerces_to_member_type(self) -> None: + result = coerce_string_to_annotation(value="2", annotation=Literal[1, 2]) + assert result == 2 + assert isinstance(result, int) + + +@pytest.mark.usefixtures("patch_central_database") +class TestResolveConstructorArgs: + """Tests for the end-to-end resolve_constructor_args.""" + + def test_coerces_simple_params(self) -> None: + resolved = resolve_constructor_args(cls=_SimpleOnly, raw_args={"count": "3", "ratio": "0.75", "flag": "true"}) + assert resolved == {"count": 3, "ratio": 0.75, "flag": True} + + def test_literal_passthrough(self) -> None: + resolved = resolve_constructor_args(cls=_SimpleOnly, raw_args={"mode": "b"}) + assert resolved == {"mode": "b"} + + def test_literal_invalid_raises(self) -> None: + with pytest.raises(ValueError, match="mode"): + resolve_constructor_args(cls=_SimpleOnly, raw_args={"mode": "z"}) + + def test_unknown_param_raises(self) -> None: + with pytest.raises(ValueError, match="Unknown parameter 'nope'"): + resolve_constructor_args(cls=_SimpleOnly, raw_args={"nope": "1"}) + + def test_unknown_param_lists_valid_params(self) -> None: + with pytest.raises(ValueError, match="count"): + resolve_constructor_args(cls=_SimpleOnly, raw_args={"nope": "1"}) + + def test_var_kwargs_accepts_unknown(self) -> None: + resolved = resolve_constructor_args(cls=_AcceptsKwargs, raw_args={"anything": "value"}) + assert resolved == {"anything": "value"} + + def test_invalid_coercion_raises(self) -> None: + with pytest.raises(ValueError, match="count"): + resolve_constructor_args(cls=_SimpleOnly, raw_args={"count": "not-an-int"}) + + def test_resolves_registry_reference_by_name(self, target_registry: TargetRegistry) -> None: + resolved = resolve_constructor_args(cls=_NeedsTarget, raw_args={"converter_target": "my_target", "offset": "5"}) + assert resolved["converter_target"] is target_registry.get_instance_by_name("my_target") + assert resolved["offset"] == 5 + + def test_registry_reference_instance_passthrough(self, target_registry: TargetRegistry) -> None: + instance = MockPromptTarget() + resolved = resolve_constructor_args(cls=_NeedsTarget, raw_args={"converter_target": instance}) + assert resolved["converter_target"] is instance + + def test_unknown_registry_reference_raises_with_names(self, target_registry: TargetRegistry) -> None: + with pytest.raises(ValueError, match="my_target"): + resolve_constructor_args(cls=_NeedsTarget, raw_args={"converter_target": "missing"}) + + def test_unknown_registry_reference_empty_registry_hint(self, empty_target_registry: TargetRegistry) -> None: + with pytest.raises(ValueError, match="is empty"): + resolve_constructor_args(cls=_NeedsTarget, raw_args={"converter_target": "missing"}) + + +def test_module_has_no_backend_dependency() -> None: + # The resolution primitive must be reusable without depending on pyrit.backend. + import ast + import inspect + + import pyrit.registry.resolution as module + + tree = ast.parse(inspect.getsource(module)) + imported_modules: list[str] = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imported_modules.extend(alias.name for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module: + imported_modules.append(node.module) + assert not any(name.startswith("pyrit.backend") for name in imported_modules)