diff --git a/src/anonymizer/engine/constants.py b/src/anonymizer/engine/constants.py index fdfecadf..9d3929a7 100644 --- a/src/anonymizer/engine/constants.py +++ b/src/anonymizer/engine/constants.py @@ -91,6 +91,7 @@ COL_DOMAIN_SUPPLEMENT = "_domain_supplement" COL_DOMAIN_SUPPLEMENT_PRIVACY = "_domain_supplement_privacy" COL_SENSITIVITY_DISPOSITION = "_sensitivity_disposition" +COL_SIMPLE_DISPOSITION = "_simple_disposition" # internal hand-off: loose LLM wire output COL_SENSITIVITY_DISPOSITION_BLOCK = "_sensitivity_disposition_block" COL_REWRITE_DISPOSITION_BLOCK = "_rewrite_disposition_block" COL_REPLACEMENT_MAP_FOR_PROMPT = "_replacement_map_for_prompt" diff --git a/src/anonymizer/engine/rewrite/disposition_derivation.py b/src/anonymizer/engine/rewrite/disposition_derivation.py new file mode 100644 index 00000000..001327ad --- /dev/null +++ b/src/anonymizer/engine/rewrite/disposition_derivation.py @@ -0,0 +1,409 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Server-side reconstruction of the strict EntityDispositionSchema from the +loose wire-contract SimpleDispositionResult + the per-entity context columns. + +The disposition_analyzer LLM emits a minimal ``SimpleDispositionResult`` +(8 optional/loose fields per item). This module rebuilds the strict form +deterministically: pair each simple item with its entity context (by id, +with entity_label/value echoes as belt-and-braces), normalize category and +method drift, derive ``combined_risk_level`` consistent with the chosen +method, and template ``protection_reason`` when the model did not provide +one. + +No LLM calls; no I/O. Pure python for the reconstruction column. + +Why server-side reconstruction (vs. asking the LLM to emit the strict +schema directly): DataDesigner runs ``jsonschema.validate()`` on the raw +LLM response BEFORE pydantic's before-validators get a chance to coerce. +Strict ``enum`` / ``required`` / ``minLength`` constraints on the wire +schema therefore become un-coercible gates for small-model drift, dropping +the entire record. The loose ``SimpleDispositionResult`` wire contract +(see ``schemas/rewrite.py``) lets drifted output survive that gate; this +module then rebuilds the strict form server-side. +""" + +from __future__ import annotations + +import logging + +from anonymizer.engine.schemas.rewrite import ( + _ENTITY_LABEL_TO_CATEGORY, + EntityDispositionSchema, + SensitivityDispositionSchema, + SimpleDispositionResult, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Derivation helpers +# --------------------------------------------------------------------------- + + +_VALID_CATEGORIES: frozenset[str] = frozenset({"direct_identifier", "quasi_identifier", "latent_identifier"}) + + +def _normalize_category(raw: object, *, entity_label: str = "") -> str: + """Resolve a free-form category string emitted by the disposition LLM + into a valid EntityCategory value. + + Handles four small-model drift modes observed on gemma4-e2b / Nemotron: + + - **Display variants** — ``"Direct-Identifier"``, ``"DIRECT IDENTIFIERS"`` + → lowercased, separator-normalized, plural-stripped to the enum value. + - **Merged enums** — ``"latent_sensitive_attribute"`` (Nemotron splices + two enums) → matched by substring with strongest-protection priority, + so harm dimension wins over inference dimension. + - **Entity-label confusion** — ``"last_name"``, ``"date_of_birth"`` + written in the category slot → looked up in + ``_ENTITY_LABEL_TO_CATEGORY`` and mapped to the most-likely category. + Source label provenance is preserved by the ``source`` field on the + strict schema. + - **Empty / unknown** — falls back to ``"quasi_identifier"`` (the + conservative default; pessimistic protection rather than dropping + the row). + + Note: ``"sensitive_attribute"`` is no longer a valid EntityCategory + (collapsed into ``quasi_identifier`` by #150), so the substring branch + that used to map it to a dedicated category now falls through to the + quasi_identifier fallback via the entity-label lookup. + """ + if not isinstance(raw, str) or not raw.strip(): + return "quasi_identifier" + normalized = raw.strip().lower().replace("-", "_").replace(" ", "_") + if normalized in _VALID_CATEGORIES: + return normalized + if normalized.endswith("s") and normalized[:-1] in _VALID_CATEGORIES: + return normalized[:-1] + # Merged-enum hallucination: order = strongest protection wins so that + # "latent_direct_identifier" maps to direct_identifier rather than + # latent_identifier. ``"sensitive"`` and ``"quasi"`` both fold into + # quasi_identifier (the conservative protect-cautiously bucket). + for sub, target in ( + ("direct", "direct_identifier"), + ("quasi", "quasi_identifier"), + ("sensitive", "quasi_identifier"), + ("latent", "latent_identifier"), + ): + if sub in normalized: + return target + # Entity-label confusion: model wrote an entity_label value in the slot. + mapped = _ENTITY_LABEL_TO_CATEGORY.get(normalized) + if mapped is not None: + return mapped + if entity_label and normalized == entity_label.strip().lower(): + return "quasi_identifier" + return "quasi_identifier" + + +_VALID_METHODS: frozenset[str] = frozenset({"replace", "generalize", "remove", "suppress_inference", "leave_as_is"}) + + +def _normalize_method(raw: str) -> str: + """Resolve a (potentially case-drifted) protection_method_suggestion + string into a valid ProtectionMethod enum value. + + Returns ``""`` when no recognizable choice can be extracted, signaling + the caller should apply a pessimistic default. + + Strategy mirrors ``_normalize_category``: + * Empty / non-string -> ``""``. + * Exact match -> as-is. + * Substring match in priority order ``suppress_inference -> + leave_as_is -> generalize -> replace -> remove`` so e.g. + ``"replace_with_surrogate"`` resolves to ``"replace"``, + ``"leave_as_is_for_now"`` resolves to ``"leave_as_is"``. + """ + if not raw or not isinstance(raw, str): + return "" + cleaned = raw.strip().lower() + if cleaned in _VALID_METHODS: + return cleaned + for choice in ("suppress_inference", "leave_as_is", "generalize", "replace", "remove"): + if choice in cleaned: + return choice + return "" + + +def derive_combined_risk_level(category: str, method: str, sensitivity: str) -> str: + """Pick a ``CombinedRiskLevel`` consistent with ``EntityDispositionSchema._validate_protection_consistency``. + + The post-#163 invariant is: + * ``low`` requires ``method == "leave_as_is"`` + * ``high`` requires ``method != "leave_as_is"`` + * ``medium`` is permissive (either method) + + We always have ``method`` already (either echoed by the LLM or derived + pessimistically by the reconstructor), so we pick the risk level that + *passes the invariant* AND best reflects the inputs: + + * ``method == "leave_as_is"`` -> ``"low"``. This matches the spirit + of leaving a value alone: the model judged the row low risk, or + the entity is too utility-critical to mask. + * ``method != "leave_as_is"`` -> ``"high"`` if the entity is a direct + identifier of any sensitivity, or has ``sensitivity == "high"``, + else ``"medium"``. Direct identifiers are always high re-id risk + on their own; explicit ``high`` sensitivity from the LLM is + respected; otherwise we don't claim high without evidence. + + Returning ``"medium"`` rather than ``"high"`` for borderline cases + keeps the strict schema's leave_as_is_when_low rule from accidentally + constraining downstream re-validation if the disposition is + re-evaluated upstream. + """ + method = (method or "").strip() + if method == "leave_as_is": + return "low" + sens = (sensitivity or "").strip().lower() + cat = (category or "").strip().lower() + if cat == "direct_identifier": + return "high" + if sens == "high": + return "high" + return "medium" + + +# (category, method) -> template text (without leading sensitivity prefix). +# Sensitivity fills a prefix ("high-risk ...", "moderate-risk ...", ""). +_REASON_TEMPLATES: dict[tuple[str, str], str] = { + ("direct_identifier", "replace"): "direct identifier — replaced with a contextual surrogate", + ("direct_identifier", "remove"): "direct identifier — removed to prevent re-identification", + ("direct_identifier", "generalize"): "direct identifier — generalized to reduce re-identification", + ("direct_identifier", "suppress_inference"): "direct identifier — suppressed to prevent inference", + ("quasi_identifier", "generalize"): "quasi-identifier — generalized to reduce re-identification risk", + ("quasi_identifier", "replace"): "quasi-identifier — replaced with a plausible surrogate", + ("quasi_identifier", "remove"): "quasi-identifier — removed due to re-identification risk", + ("quasi_identifier", "suppress_inference"): "quasi-identifier — suppressed to prevent inference", + ("latent_identifier", "suppress_inference"): "latent inference — suppressed to prevent deduction", + ("latent_identifier", "remove"): "latent identifier — removed to prevent inference", + ("latent_identifier", "generalize"): "latent identifier — generalized to reduce inference", + ("latent_identifier", "replace"): "latent identifier — replaced with a less specific surrogate", +} + +_SENSITIVITY_PREFIX = {"low": "", "medium": "moderate-risk ", "high": "high-risk "} + +# Upper bound for a passthrough (model-authored) protection_reason. The schema +# no longer enforces max_length, so the reconstructor caps length here to keep +# rewrite prompts and parquet bounded without ever failing validation. +_MAX_PROTECTION_REASON_CHARS = 500 + + +def template_protection_reason(category: str, method: str, sensitivity: str) -> str: + """Build a reason string guaranteed >=10 chars. + + Used when the LLM omits or emits a too-short ``protection_reason`` + (the schema no longer enforces a min_length; this keeps a sensible + human-readable floor for the rewrite-context line). + Strong models that provide their own document-specific reason have + theirs kept verbatim by the reconstructor. + """ + method = (method or "").strip() + category = (category or "").strip() + sensitivity = (sensitivity or "").strip().lower() + + if method == "leave_as_is": + cat_label = category.replace("_", " ") if category else "entity" + return f"Low-risk {cat_label}; retained as-is for utility." + + base = _REASON_TEMPLATES.get((category, method)) + if base is None: + cat_label = category.replace("_", " ") if category else "entity" + method_label = method or "an appropriate method" + base = f"{cat_label} — protected via {method_label}" + + prefix = _SENSITIVITY_PREFIX.get(sensitivity, "") + reason = (prefix + base).strip() + return reason[:1].upper() + reason[1:] if reason else "Protection applied per policy." + + +# --------------------------------------------------------------------------- +# Entity-context flattening +# --------------------------------------------------------------------------- + + +def _coerce_entity_list(raw: object) -> list[dict]: + """DataDesigner hands context columns to custom generators in several + shapes: a pydantic-dump dict with a keyed list, a raw list, a JSON- + encoded string, or None. Normalize to a plain list of dicts. + """ + import json + + if raw is None: + return [] + if isinstance(raw, str): + raw = raw.strip() + if not raw: + return [] + try: + raw = json.loads(raw) + except Exception: + return [] + if isinstance(raw, dict): + raw_dict: dict = raw + for key in ("entities_by_value", "latent_entities", "entities", "items"): + inner = raw_dict.get(key) + if isinstance(inner, list): + raw = inner + break + else: + return [] + if not isinstance(raw, list): + return [] + out: list[dict] = [] + for item in raw: + if isinstance(item, dict): + out.append(item) + elif isinstance(item, str): + try: + parsed = json.loads(item) + if isinstance(parsed, dict): + out.append(parsed) + except Exception: + continue + return out + + +def _flatten_context(entities_by_value: object, latent_entities: object) -> list[dict]: + """Produce a flat, ordered list of ``{source, entity_label, entity_value}``. + + Order matches how the disposition prompt enumerates entities: tagged + entries from ``entities_by_value`` (one slot per ``(value, label)`` + pair) followed by latent entries. The returned list index+1 is the + expected id. + """ + flat: list[dict] = [] + for ev in _coerce_entity_list(entities_by_value): + value = ev.get("value", "") + labels = ev.get("labels") or [] + if not labels: + flat.append({"source": "tagged", "entity_label": "", "entity_value": value}) + continue + for label in labels: + flat.append({"source": "tagged", "entity_label": label, "entity_value": value}) + for le in _coerce_entity_list(latent_entities): + flat.append( + { + "source": "latent", + "entity_label": le.get("label", ""), + "entity_value": le.get("value", ""), + } + ) + return flat + + +# --------------------------------------------------------------------------- +# Reconstruction +# --------------------------------------------------------------------------- + + +def reconstruct_full_disposition( + simple: SimpleDispositionResult, + entities_by_value: object = None, + latent_entities: object = None, +) -> SensitivityDispositionSchema: + """Build the strict disposition from the loose LLM output + context columns. + + For each ``SimpleDispositionItem``: + - prefer the model-echoed ``source/entity_label/entity_value`` only + when the id falls outside the context range (orphan path); use + the trusted context entry when in range, since small models + routinely echo garbage. + - normalize ``category`` and ``method`` drift; pessimistically + default ``method`` to ``replace`` for high-risk entities when the + LLM omits it. + - derive ``combined_risk_level`` from category/sensitivity/method + such that ``EntityDispositionSchema._validate_protection_consistency`` + passes (``leave_as_is`` -> low; otherwise medium/high based on + category and sensitivity). ``needs_protection`` is the strict + schema's ``@property`` and falls out of method automatically. + - keep the LLM ``protection_reason`` if it stripped to >=10 chars, + else template one from (category, method, sensitivity). + + Orphan simple items (id outside the context range AND no usable + echoes) are skipped with a warning — better to return a smaller valid + schema than to drop the whole record. Duplicate ids are de-duplicated + (first occurrence wins). + + Raises ``ValidationError`` only when ``full_items`` is empty (the + strict schema requires ``min_length=1``); the workflow column wraps + this case with a try/except and emits an empty disposition rather + than dropping the row. + """ + context = _flatten_context(entities_by_value, latent_entities) + seen_ids: set[int] = set() + full_items: list[EntityDispositionSchema] = [] + + for item in simple.sensitivity_disposition: + if item.id in seen_ids: + logger.warning( + "reconstruct_full_disposition: duplicate id=%s in simple output; keeping first occurrence", + item.id, + ) + continue + seen_ids.add(item.id) + + idx = item.id - 1 + if 0 <= idx < len(context): + ctx = context[idx] + src = ctx["source"] + lbl = ctx["entity_label"] + val = ctx["entity_value"] + else: + echoed_src = (item.source or "").strip().lower() + src = echoed_src if echoed_src in {"tagged", "latent"} else "" + lbl = item.entity_label or "" + val = item.entity_value or "" + + if not src or not lbl or not val: + logger.warning( + "reconstruct_full_disposition: orphan simple item id=%s " + "(missing or drifted source/label/value, out of context range); skipping", + item.id, + ) + continue + + category = _normalize_category(item.category, entity_label=lbl) + sensitivity = (item.sensitivity or "").strip().lower() or "medium" + + raw_method = (item.protection_method_suggestion or "").strip().lower() + method = _normalize_method(raw_method) + if not method: + # Pessimistic default for omitted method: high-risk entities + # default to "replace"; everything else to "leave_as_is". + if category == "direct_identifier" or sensitivity in ("medium", "high"): + method = "replace" + else: + method = "leave_as_is" + + combined_risk = derive_combined_risk_level(category, method, sensitivity) + + raw_reason = (item.protection_reason or "").strip() + if len(raw_reason) >= 10: + # Passthrough from the model: cap length here (the schema no longer + # enforces max_length, so a rambling small-model reason would + # otherwise flow unbounded into the rewrite prompt and parquet). + reason = ( + raw_reason[: _MAX_PROTECTION_REASON_CHARS - 3].rstrip() + "..." + if (len(raw_reason) > _MAX_PROTECTION_REASON_CHARS) + else raw_reason + ) + else: + reason = template_protection_reason(category, method, sensitivity) + + full_items.append( + EntityDispositionSchema( + id=item.id, + source=src, + category=category, + sensitivity=sensitivity, + entity_label=lbl, + entity_value=val, + protection_method_suggestion=method, + combined_risk_level=combined_risk, + protection_reason=reason, + ) + ) + + return SensitivityDispositionSchema(sensitivity_disposition=full_items) diff --git a/src/anonymizer/engine/rewrite/qa_generation.py b/src/anonymizer/engine/rewrite/qa_generation.py index 03978657..e043adb7 100644 --- a/src/anonymizer/engine/rewrite/qa_generation.py +++ b/src/anonymizer/engine/rewrite/qa_generation.py @@ -39,12 +39,18 @@ ) # Derived from the schema so the Jinja key stays in sync with the field name. +# Prefer an annotation-typed lookup (strict-mode contract); fall back to a +# name-based lookup so wire-loose typing of ``domain`` (str instead of Domain +# enum) still resolves to the same field. The Domain enum hint is preserved +# in the field description and the ``_normalize_domain`` before-validator. _DOMAIN_KEY = next( (name for name, info in DomainClassificationSchema.model_fields.items() if info.annotation is Domain), None, ) +if _DOMAIN_KEY is None and "domain" in DomainClassificationSchema.model_fields: + _DOMAIN_KEY = "domain" if _DOMAIN_KEY is None: - raise RuntimeError("DomainClassificationSchema must define a field annotated with Domain") + raise RuntimeError("DomainClassificationSchema must define a 'domain' field") # --------------------------------------------------------------------------- # Stage 1 pre-step: format disposition → disposition block diff --git a/src/anonymizer/engine/rewrite/sensitivity_disposition.py b/src/anonymizer/engine/rewrite/sensitivity_disposition.py index fca24ba8..4d06c744 100644 --- a/src/anonymizer/engine/rewrite/sensitivity_disposition.py +++ b/src/anonymizer/engine/rewrite/sensitivity_disposition.py @@ -3,8 +3,13 @@ from __future__ import annotations -from data_designer.config.column_configs import LLMStructuredColumnConfig +import logging +from typing import Any + +from data_designer.config import custom_column_generator +from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig from data_designer.config.column_types import ColumnConfigT +from pydantic import ValidationError from anonymizer.config.models import RewriteModelSelection from anonymizer.config.rewrite import PrivacyGoal @@ -14,13 +19,27 @@ COL_ENTITIES_BY_VALUE, COL_LATENT_ENTITIES, COL_SENSITIVITY_DISPOSITION, + COL_SIMPLE_DISPOSITION, COL_TAG_NOTATION, COL_TAGGED_TEXT, _jinja, ) from anonymizer.engine.ndd.model_loader import resolve_model_alias from anonymizer.engine.prompt_utils import substitute_placeholders -from anonymizer.engine.schemas import SensitivityDispositionSchema, StrictSensitivityDispositionSchema +from anonymizer.engine.rewrite.disposition_derivation import ( + _flatten_context, + derive_combined_risk_level, + reconstruct_full_disposition, + template_protection_reason, +) +from anonymizer.engine.schemas import ( + EntityDispositionSchema, + SensitivityDispositionSchema, + SimpleDispositionResult, +) +from anonymizer.engine.schemas.rewrite import _ENTITY_LABEL_TO_CATEGORY + +logger = logging.getLogger(__name__) def _get_sensitivity_disposition_prompt( @@ -257,6 +276,165 @@ def _get_sensitivity_disposition_prompt( # --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Pessimistic fallback when reconstruction yields nothing +# --------------------------------------------------------------------------- + + +def _pessimistic_fallback_disposition( + entities_by_value: object, + latent_entities: object, +) -> SensitivityDispositionSchema: + """Build a worst-case disposition from the entity context alone. + + Used when ``reconstruct_full_disposition`` returns an empty list — e.g. + every ``SimpleDispositionItem`` was an orphan, or the LLM emitted no + items at all. Without this fallback, downstream + ``parse_sensitivity_disposition`` raises ``ValidationError`` on + ``min_length=1`` and the row drops, defeating the whole loose-wire + + server-reconstruction architecture this PR exists to add. + + Policy (per Lipika/Andre's review on PR #130, addressing the + record-drop concern): + * ``direct_identifier`` -> ``replace`` (high risk, must be masked). + * everything else -> ``generalize`` (medium risk, mask but keep + rough semantics for utility). + + Categories come from the per-entity ``entity_label`` via the + ``_ENTITY_LABEL_TO_CATEGORY`` map (the same source of truth the + reconstructor uses for entity-label-stuffed-into-category drift); + unmapped labels fall back to ``quasi_identifier``. + """ + flat = _flatten_context(entities_by_value, latent_entities) + items: list[EntityDispositionSchema] = [] + for idx, slot in enumerate(flat, start=1): + label = (slot.get("entity_label") or "").strip() + value = (slot.get("entity_value") or "").strip() + source = slot.get("source") or "tagged" + if not label or not value: + continue + if source == "latent": + category = "latent_identifier" + else: + category = _ENTITY_LABEL_TO_CATEGORY.get(label, "quasi_identifier") + method = "replace" if category == "direct_identifier" else "generalize" + sensitivity = "high" if category == "direct_identifier" else "medium" + combined_risk = derive_combined_risk_level(category, method, sensitivity) + reason = template_protection_reason(category, method, sensitivity) + items.append( + EntityDispositionSchema( + id=idx, + source=source, + category=category, + sensitivity=sensitivity, + entity_label=label, + entity_value=value, + protection_method_suggestion=method, + combined_risk_level=combined_risk, + protection_reason=reason, + ) + ) + if not items: + # Genuinely no entities at all in context. The orchestrator should have + # short-circuited rows with no detected entities before this step, so + # this is a pipeline-invariant violation — but this is the last-resort + # path whose contract is "never drop the row." Emitting an empty list + # would raise on SensitivityDispositionSchema's (and the downstream + # parser's) min_length=1 invariant and drop the record, so we log loudly + # and emit a single no-op (leave_as_is/low) disposition instead. It is + # excluded from protected_entities, so it never reaches the rewrite. + logger.error( + "pessimistic fallback: empty entity context at the disposition step " + "(orchestrator should have short-circuited entity-free rows); emitting " + "a single no-op disposition so the row is not dropped" + ) + items.append( + EntityDispositionSchema( + id=1, + source="tagged", + category="quasi_identifier", + sensitivity="low", + entity_label="", + entity_value="", + protection_method_suggestion="leave_as_is", + combined_risk_level=derive_combined_risk_level("quasi_identifier", "leave_as_is", "low"), + protection_reason=template_protection_reason("quasi_identifier", "leave_as_is", "low"), + ) + ) + return SensitivityDispositionSchema(sensitivity_disposition=items) + + +# --------------------------------------------------------------------------- +# Reconstruction column +# --------------------------------------------------------------------------- + + +@custom_column_generator(required_columns=[COL_SIMPLE_DISPOSITION, COL_ENTITIES_BY_VALUE, COL_LATENT_ENTITIES]) +def _reconstruct_full_disposition_column(row: dict[str, Any]) -> dict[str, Any]: + """Rebuild the strict EntityDispositionSchema list from the loose LLM + output in ``COL_SIMPLE_DISPOSITION`` plus the entity context columns. + + Writes ``COL_SENSITIVITY_DISPOSITION`` so every downstream consumer + reads the same column name / shape as before this refactor. + + Empty-result fallback: when the model returns nothing usable (every + item is an orphan, or the LLM omitted the field entirely), build a + pessimistic disposition from the entity context (direct identifiers + -> replace, everything else -> generalize). This addresses the + record-drop concern Lipika and Andre raised on PR #130 — emitting an + empty disposition would have failed downstream + ``parse_sensitivity_disposition``'s ``min_length=1`` check anyway. + """ + simple_raw = row.get(COL_SIMPLE_DISPOSITION, {}) or {} + if isinstance(simple_raw, SimpleDispositionResult): + simple = simple_raw + else: + if isinstance(simple_raw, str): + import json as _json + + try: + simple_raw = _json.loads(simple_raw) + except Exception: + simple_raw = {} + try: + simple = SimpleDispositionResult.model_validate(simple_raw) + except ValidationError as exc: + logger.warning( + "reconstruct: SimpleDispositionResult failed to validate (%s); " + "falling back to pessimistic disposition from entity context", + str(exc)[:200], + ) + simple = SimpleDispositionResult() + + entities_by_value = row.get(COL_ENTITIES_BY_VALUE) + latent_entities = row.get(COL_LATENT_ENTITIES) + + if not simple.sensitivity_disposition: + logger.warning( + "reconstruct: empty SimpleDispositionResult for row; " + "falling back to pessimistic disposition from entity context" + ) + full = _pessimistic_fallback_disposition(entities_by_value, latent_entities) + else: + try: + full = reconstruct_full_disposition(simple, entities_by_value, latent_entities) + except ValidationError as exc: + logger.warning( + "reconstruct: ValidationError after orphan-skipping (likely all items out of context range); " + "falling back to pessimistic disposition. detail=%s", + str(exc)[:200], + ) + full = _pessimistic_fallback_disposition(entities_by_value, latent_entities) + + row[COL_SENSITIVITY_DISPOSITION] = full.model_dump() + return row + + +# --------------------------------------------------------------------------- +# Workflow +# --------------------------------------------------------------------------- + + class SensitivityDispositionWorkflow: def columns( self, @@ -266,17 +444,46 @@ def columns( data_summary: str | None = None, strict_entity_protection: bool = False, ) -> list[ColumnConfigT]: + """Two-step pipeline for small-model robustness: + + 1. LLM column emits the loose ``SimpleDispositionResult`` to a + hidden ``COL_SIMPLE_DISPOSITION`` column. The wire schema has + no enum/required/minLength constraints, so DataDesigner's + jsonschema pre-validate gate accepts drifted small-model + output that strict ``SensitivityDispositionSchema`` would + reject. ``drop=True`` keeps this internal hand-off out of the + user-facing preview DataFrame. + 2. Pure-python reconstruction column rebuilds the strict + ``SensitivityDispositionSchema`` from the loose wire output + plus the entity-context columns. No LLM call; deterministic; + handles id pairing, category/method drift normalization, + ``combined_risk_level`` derivation, and pessimistic fallback + when the LLM produces nothing usable. + + ``strict_entity_protection`` continues to flow into the prompt's + ```` block — the contract is enforced + at prompt time. The output_format selection between + ``SensitivityDispositionSchema`` and + ``StrictSensitivityDispositionSchema`` is no longer needed + because we always emit ``SimpleDispositionResult`` on the wire + and reconstruct into the canonical (non-strict) schema, which + downstream consumers already accept. + """ disposition_alias = resolve_model_alias("disposition_analyzer", selected_models) - output_schema = StrictSensitivityDispositionSchema if strict_entity_protection else SensitivityDispositionSchema return [ LLMStructuredColumnConfig( - name=COL_SENSITIVITY_DISPOSITION, + name=COL_SIMPLE_DISPOSITION, prompt=_get_sensitivity_disposition_prompt( privacy_goal, data_summary, strict_entity_protection=strict_entity_protection, ), model_alias=disposition_alias, - output_format=output_schema, + output_format=SimpleDispositionResult, + drop=True, + ), + CustomColumnConfig( + name=COL_SENSITIVITY_DISPOSITION, + generator_function=_reconstruct_full_disposition_column, ), ] diff --git a/src/anonymizer/engine/schemas/__init__.py b/src/anonymizer/engine/schemas/__init__.py index 96607d66..f9bc9946 100644 --- a/src/anonymizer/engine/schemas/__init__.py +++ b/src/anonymizer/engine/schemas/__init__.py @@ -53,6 +53,8 @@ RewriteOutputSchema, SensitivityDispositionSchema, SensitivityLevel, + SimpleDispositionItem, + SimpleDispositionResult, StrictCombinedRiskLevel, StrictEntityDispositionSchema, StrictProtectionMethod, @@ -109,6 +111,8 @@ "RewriteOutputSchema", "SensitivityDispositionSchema", "SensitivityLevel", + "SimpleDispositionItem", + "SimpleDispositionResult", "StrictCombinedRiskLevel", "StrictEntityDispositionSchema", "StrictProtectionMethod", diff --git a/src/anonymizer/engine/schemas/rewrite.py b/src/anonymizer/engine/schemas/rewrite.py index 912281dd..7e719199 100644 --- a/src/anonymizer/engine/schemas/rewrite.py +++ b/src/anonymizer/engine/schemas/rewrite.py @@ -40,7 +40,9 @@ from enum import Enum -from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, model_validator +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator + +from anonymizer.engine.schemas.shared import accept_bare_list, loose_list_wrapper_json_schema # --------------------------------------------------------------------------- # Domain @@ -79,10 +81,58 @@ class Domain(str, Enum): class DomainClassificationSchema(BaseModel): - """LLM output schema for domain classification step.""" + """LLM output schema for domain classification step. + + Wire contract is loose so DD's jsonschema pre-check cannot reject + enum or type drift: ``domain`` is typed as ``str`` (not the Domain + enum); ``domain_confidence`` accepts string-typed input. A pair of + before-validators normalize drift before pydantic enforces ranges. + Unknown domains fall back to ``Domain.OTHER``; unparseable + confidences fall back to ``0.5``. + """ - domain: Domain - domain_confidence: float = Field(ge=0.0, le=1.0) + domain: str = Field( + default=Domain.OTHER.value, + # Enumerated inline (derived from the enum) because the field is + # wire-typed as ``str`` and the enum is absent from the JSON schema the + # model sees; the description is its only source of valid values. + description=("One of: " + ", ".join(d.value for d in Domain) + ". Unknown values coerce to OTHER."), + ) + domain_confidence: float = Field(default=0.5, ge=0.0, le=1.0) + + @field_validator("domain", mode="before") + @classmethod + def _normalize_domain(cls, v: object) -> str: + if v is None or not isinstance(v, str) or not v.strip(): + return Domain.OTHER.value + cleaned = v.strip().upper().replace(" ", "_").replace("-", "_") + allowed = {d.value for d in Domain} + if cleaned in allowed: + return cleaned + # Substring match — pick first Domain that appears as substring. + for d in Domain: + if d.value in cleaned or cleaned in d.value: + return d.value + return Domain.OTHER.value + + @field_validator("domain_confidence", mode="before") + @classmethod + def _coerce_confidence(cls, v: object) -> float: + """Accept "0.95", "85%", or numeric input; clamp to [0, 1].""" + if isinstance(v, bool): + return 0.5 + if isinstance(v, (int, float)): + return max(0.0, min(1.0, float(v))) + if isinstance(v, str): + try: + raw = v.strip().rstrip("%") + val = float(raw) + if "%" in v: + val /= 100.0 + return max(0.0, min(1.0, val)) + except (ValueError, TypeError): + return 0.5 + return 0.5 # --------------------------------------------------------------------------- @@ -121,6 +171,109 @@ class CombinedRiskLevel(str, Enum): high = "high" +# --------------------------------------------------------------------------- +# Entity-label -> EntityCategory mapping. +# +# Used by the disposition reconstructor (engine/rewrite/disposition_derivation.py) +# when the disposition LLM outputs an entity_label string in the ``category`` +# slot — observed consistently with small Gemma models. Derived from two +# frozensets so the source of truth lives in one place per category. +# ``test_entity_label_to_category_covers_default_labels`` is a CI guard that +# fires when a label is added to ``DEFAULT_ENTITY_LABELS`` without a category +# assignment here. Any label not in this table falls back to +# ``"quasi_identifier"`` which is the most conservative (protect-cautiously) +# choice. +# +# Note: the original PR carried a third ``_SENSITIVE_ATTR_LABELS`` set +# mapping to a ``sensitive_attribute`` EntityCategory, but that enum value +# was removed when sensitivity disposition was recalibrated (#150) — the +# six former-sensitive labels (gender, sexuality, race_ethnicity, +# religious_belief, political_view, blood_type) are now folded into +# ``_QUASI_ID_LABELS`` (the conservative protect-cautiously choice given +# the now-3-value EntityCategory enum). +# --------------------------------------------------------------------------- + +_DIRECT_ID_LABELS: frozenset[str] = frozenset( + { + "first_name", + "last_name", + "email", + "phone_number", + "fax_number", + "ssn", + "national_id", + "street_address", + "postcode", + "credit_debit_card", + "account_number", + "bank_routing_number", + "tax_id", + "medical_record_number", + "health_plan_beneficiary_number", + "api_key", + "password", + "ipv4", + "ipv6", + "mac_address", + "url", + "user_name", + "employee_id", + "customer_id", + "unique_id", + "biometric_identifier", + "device_identifier", + "license_plate", + "vehicle_identifier", + "swift_bic", + "pin", + "cvv", + "http_cookie", + } +) +_QUASI_ID_LABELS: frozenset[str] = frozenset( + { + "age", + "date", + "date_of_birth", + "date_time", + "time", + "city", + "state", + "country", + "county", + "place_name", + "landmark", + "coordinate", + "occupation", + "organization_name", + "company_name", + "university", + "court_name", + "prison_detention_facility", + "degree", + "field_of_study", + "education_level", + "language", + "nationality", + "employment_status", + "monetary_amount", + "certificate_license_number", + # Former-sensitive labels (#150 collapsed sensitive_attribute into + # quasi_identifier when EntityCategory was reduced to 3 values). + "gender", + "sexuality", + "race_ethnicity", + "religious_belief", + "political_view", + "blood_type", + } +) + +_ENTITY_LABEL_TO_CATEGORY: dict[str, str] = {lbl: "direct_identifier" for lbl in _DIRECT_ID_LABELS} | { + lbl: "quasi_identifier" for lbl in _QUASI_ID_LABELS +} + + class EntityDispositionSchema(BaseModel): """Protection decision for one tagged or latent entity in rewrite planning. @@ -134,9 +287,14 @@ class EntityDispositionSchema(BaseModel): source: EntitySource category: EntityCategory sensitivity: SensitivityLevel - entity_label: str = Field(min_length=1) - entity_value: str = Field(min_length=1) - protection_reason: str = Field(min_length=10, max_length=500) + # No length constraints: this schema is reconstructed server-side from + # trusted entity context, never directly from raw model output. The + # reconstructor guarantees a non-empty templated protection_reason and + # caps its length, so length bounds here would only be redundant tripwires + # that risk dropping a record when a passthrough reason runs long. + entity_label: str + entity_value: str + protection_reason: str protection_method_suggestion: ProtectionMethod combined_risk_level: CombinedRiskLevel @@ -244,6 +402,108 @@ class StrictSensitivityDispositionSchema(SensitivityDispositionSchema): sensitivity_disposition: list[StrictEntityDispositionSchema] = Field(min_length=1) +# --------------------------------------------------------------------------- +# Loose wire-contract schemas for disposition (small-model tolerance) +# +# Used as the ``output_format`` for the disposition_analyzer LLM column. A +# server-side reconstruction column (see +# ``engine/rewrite/disposition_derivation.py``) pairs these with the entity +# context columns to produce the strict ``EntityDispositionSchema`` that +# downstream consumers read. +# --------------------------------------------------------------------------- + + +class SimpleDispositionItem(BaseModel): + """Loose wire-contract shape for one disposition decision from the LLM. + + Why "loose": every field is typed as ``str`` (not the corresponding + enum) and has a permissive default. This keeps the emitted JSON + Schema free of ``enum``, ``required``, and ``minLength`` constraints + that DataDesigner's ``jsonschema.validate()`` runs BEFORE pydantic's + coercion. Small models drift at those constraints; the loose wire + gate lets drifted output survive to the server-side reconstruction. + """ + + id: int = Field(ge=1) + # Echoed from the entity context for belt-and-braces pairing. Optional + # so the server can fall back to id-based lookup if the model omits. + source: str = Field(default="") + entity_label: str = Field(default="") + entity_value: str = Field(default="") + # LLM judgments; typed str so enum drift ("latent_sensitive_attribute", + # "DIRECT IDENTIFIER", etc.) is accepted at the wire layer and + # normalized during reconstruction. Valid values are enumerated inline + # (derived from the enums) since the wire JSON schema carries no enum. + category: str = Field( + default="", + description="One of: " + ", ".join(c.value for c in EntityCategory) + ".", + ) + sensitivity: str = Field( + default="", + description="One of: " + ", ".join(s.value for s in SensitivityLevel) + ".", + ) + protection_method_suggestion: str = Field( + default="", + description="One of: " + ", ".join(m.value for m in ProtectionMethod) + ".", + ) + # Optional: when the model emits a document-specific rationale we keep + # it verbatim; otherwise the reconstructor templates one from + # (category, method, sensitivity). + protection_reason: str = Field(default="") + + @field_validator( + "source", + "entity_label", + "entity_value", + "category", + "sensitivity", + "protection_method_suggestion", + "protection_reason", + mode="before", + ) + @classmethod + def _coerce_scalar_to_str(cls, v: object) -> str: + if v is None: + return "" + if isinstance(v, str): + return v + if isinstance(v, (int, float, bool)): + return str(v) + # Unexpected container (list/dict) from a drifted response: coerce to "" + # rather than letting pydantic raise on the whole SimpleDispositionItem. + # The reconstructor recovers the true value from trusted entity context. + return "" + + +class SimpleDispositionResult(BaseModel): + """Wire-contract wrapper around a list of ``SimpleDispositionItem``. + + Tolerates two LLM output shapes at the wire layer: + + 1. Canonical wrapper: ``{"sensitivity_disposition": [item, ...]}`` + 2. Bare list at the top level: ``[item, ...]`` — observed + consistently on ``nemotron-3-nano:4b`` (rewrite mode, 5/5 records) + and intermittently on Gemma4-edge models for dense entity sets. + + The fix is two-layered (see ``schemas/shared.py``): + + - ``__get_pydantic_json_schema__`` widens the emitted JSON Schema to + a ``oneOf`` of {wrapper-object, bare-array} so DD's + ``jsonschema.validate()`` pre-check accepts both. + - ``_accept_bare_list`` (mode="before") normalizes the bare-list + shape to the wrapper dict so downstream consumers continue to read + the canonical ``sensitivity_disposition`` field. + """ + + sensitivity_disposition: list[SimpleDispositionItem] = Field(default_factory=list) + + _accept_bare_list = model_validator(mode="before")(accept_bare_list(list_field="sensitivity_disposition")) + + @classmethod + def __get_pydantic_json_schema__(cls, schema, handler): + return loose_list_wrapper_json_schema(handler, schema, list_field="sensitivity_disposition") + + # --------------------------------------------------------------------------- # Meaning Units # --------------------------------------------------------------------------- @@ -274,17 +534,111 @@ class MeaningUnitImportance(str, Enum): class MeaningUnitSchema(BaseModel): - id: int = Field(ge=1) - aspect: MeaningUnitAspect - unit: str = Field(min_length=1) - importance: MeaningUnitImportance + """Single meaning unit extracted by the meaning_extractor role. + + Loose wire: ``aspect`` and ``importance`` are typed as ``str`` (not + the corresponding enum) so small-model enum drift (``"ROLE"`` vs + ``"role"``, ``"the role"``, ``"crit"``, etc.) is accepted at the + wire layer and normalized by before-validators. ``id`` defaults to 1 + so missing-id rows survive; ``MeaningUnitsSchema._ensure_list`` + re-numbers when ids collide. + """ + + id: int = Field(ge=1, default=1) + aspect: str = Field( + default="", + # Enumerated inline (derived from the enum, not hand-maintained) because + # the field is wire-typed as ``str``: the enum is absent from the JSON + # schema the model sees, so the description is its only source of truth. + description=("One of: " + ", ".join(a.value for a in MeaningUnitAspect) + ". Use the closest match."), + ) + unit: str = Field(default="") + importance: str = Field( + default=MeaningUnitImportance.important.value, + description=( + "One of: " + ", ".join(i.value for i in MeaningUnitImportance) + ". Defaults to 'important' if unsure." + ), + ) + + @field_validator("aspect", mode="before") + @classmethod + def _normalize_aspect(cls, v: object) -> str: + if v is None or not isinstance(v, str) or not v.strip(): + return "" + cleaned = v.strip().lower().replace(" ", "_").replace("-", "_") + allowed = {a.value for a in MeaningUnitAspect} + if cleaned in allowed: + return cleaned + for a in MeaningUnitAspect: + if a.value in cleaned or cleaned in a.value: + return a.value + return "" + + @field_validator("importance", mode="before") + @classmethod + def _normalize_importance(cls, v: object) -> str: + if v is None or not isinstance(v, str) or not v.strip(): + return MeaningUnitImportance.important.value + cleaned = v.strip().lower() + allowed = {a.value for a in MeaningUnitImportance} + if cleaned in allowed: + return cleaned + for a in MeaningUnitImportance: + if a.value in cleaned or cleaned in a.value: + return a.value + return MeaningUnitImportance.important.value + + @field_validator("unit", mode="before") + @classmethod + def _coerce_unit(cls, v: object) -> str: + if v is None: + return "" + if isinstance(v, (int, float, bool)): + return str(v) + return v class MeaningUnitsSchema(BaseModel): - """LLM output schema for meaning unit extraction step.""" + """LLM output schema for meaning unit extraction step. - # Non-empty by design: meaning extraction only runs when entities were detected. - units: list[MeaningUnitSchema] = Field(min_length=1) + Outer list has ``default_factory=list`` (was ``min_length=1``); if + the model emits an empty list the record still survives so the + pipeline can decide what to do downstream. + + Tolerates two LLM output shapes at the wire layer (same pattern as + ``SimpleDispositionResult``): + + 1. Canonical wrapper: ``{"units": [item, ...]}`` + 2. Bare list at the top level: ``[item, ...]`` — observed on + qwen3.5:4b for legal-court documents. + """ + + units: list[MeaningUnitSchema] = Field(default_factory=list) + + _accept_bare_list = model_validator(mode="before")(accept_bare_list(list_field="units")) + + @classmethod + def __get_pydantic_json_schema__(cls, schema, handler): + return loose_list_wrapper_json_schema(handler, schema, list_field="units") + + @field_validator("units", mode="before") + @classmethod + def _ensure_list(cls, v: object) -> list: + if not isinstance(v, list): + v = [v] if isinstance(v, dict) else [] + # ``MeaningUnitSchema.id`` defaults to 1; if the LLM omits ids the + # wire collapses every unit to id=1. Reassign sequentially when + # any id is missing or duplicated. Explicit unique ids are kept. + if isinstance(v, list) and v: + raw_ids = [item.get("id") if isinstance(item, dict) else getattr(item, "id", None) for item in v] + valid = [i for i in raw_ids if isinstance(i, int) and i >= 1] + if len(valid) != len(raw_ids) or len(set(valid)) != len(valid): + for idx, item in enumerate(v, start=1): + if isinstance(item, dict): + item["id"] = idx + elif hasattr(item, "id"): + item.id = idx # type: ignore[misc] + return v # --------------------------------------------------------------------------- @@ -364,6 +718,53 @@ def _validate_id_coverage(expected_ids: list[int], returned_ids: list[int], labe raise ValueError(f"Extra {label} IDs not in expected set: {extra}") +def _normalize_id_covered_list( + raw: list, + *, + expected_ids: list[int] | None, + default_item_for_id, +) -> list: + """Normalize a list of id-bearing items to exact-coverage shape. + + Used as the wire-layer normalizer on all context-validated answer + schemas (``QualityAnswersSchema`` / ``PrivacyAnswersSchema`` / + ``QACompareResultsSchema``). Dedupes by id (first occurrence wins), + pads missing ids via ``default_item_for_id(id) -> dict``, drops + extras. When ``expected_ids`` is None, only dedupes. + + Each item can be a dict or a BaseModel instance. Items without a + parseable ``id`` are skipped. + """ + seen: set[int] = set() + deduped: list = [] + for item in raw: + if hasattr(item, "model_dump"): + item = item.model_dump() + if not isinstance(item, dict): + continue + try: + iid = int(item.get("id")) + except (TypeError, ValueError): + continue + if iid in seen: + continue + seen.add(iid) + deduped.append(item) + + if expected_ids is None: + return deduped + + expected_set = set(expected_ids) + deduped = [it for it in deduped if int(it["id"]) in expected_set] + present = {int(it["id"]) for it in deduped} + for eid in expected_ids: + if eid not in present: + deduped.append(default_item_for_id(eid)) + order = {eid: i for i, eid in enumerate(expected_ids)} + deduped.sort(key=lambda it: order.get(int(it["id"]), len(order))) + return deduped + + class QualityAnswerSchema(BaseModel): id: int answer: str @@ -373,16 +774,44 @@ class QualityAnswersSchema(BaseModel): """LLM output schema for quality QA re-answer step (on rewritten text). When validated with ``context={"expected_ids": [1, 2, ...]}``, - enforces exact coverage: no missing, duplicate, or extra IDs. + normalizes the returned list to exactly that ID set: dedupes by id + (first wins), pads missing ids with a placeholder ``"missing"`` + answer, and drops extras. This prevents an LLM that emits a + duplicate or misses an id from dropping the whole record (observed + on gpt-oss-20b × 05_legal_court rewrite, qwen3.5:9b on 4-entity + notes). """ - answers: list[QualityAnswerSchema] + answers: list[QualityAnswerSchema] = Field(default_factory=list) + + @model_validator(mode="before") + @classmethod + def _normalize_answers(cls, data: object, info: ValidationInfo) -> object: + if not isinstance(data, dict): + return data + raw = data.get("answers") or [] + if not isinstance(raw, list): + raw = [] + expected = (info.context or {}).get("expected_ids") if info.context else None + data["answers"] = _normalize_id_covered_list( + raw, + expected_ids=expected, + default_item_for_id=lambda i: {"id": i, "answer": "missing"}, + ) + return data @model_validator(mode="after") def _check_coverage(self, info: ValidationInfo) -> QualityAnswersSchema: + # Soft check: the before-validator already normalized. If coverage + # still fails here it indicates a schema-level bug, not LLM drift. expected_ids = (info.context or {}).get("expected_ids") if expected_ids is not None: - _validate_id_coverage(expected_ids, [a.id for a in self.answers], "answer") + try: + _validate_id_coverage(expected_ids, [a.id for a in self.answers], "answer") + except ValueError as e: + import logging + + logging.getLogger(__name__).warning("QualityAnswersSchema post-normalization coverage warning: %s", e) return self @@ -390,24 +819,80 @@ class PrivacyAnswerItemSchema(BaseModel): id: int answer: PrivacyAnswer confidence: float = Field(ge=0.0, le=1.0) - reason: str = Field(min_length=1, max_length=200) + # No length constraints: _truncate_reason (below) already forces every + # value into a non-empty, <=200-char envelope before validation, so a + # min_length/max_length here could never trip and only adds drift surface. + reason: str evidence: list[str] = Field(default_factory=list) + @field_validator("reason", mode="before") + @classmethod + def _truncate_reason(cls, v: object) -> object: + """Coerce small-model reason drift into a non-empty, <=200-char + envelope rather than dropping the record. (The field no longer + carries min_length/max_length constraints; this validator is the + sole guard, so it always normalizes into range.) + + Three observed drift modes: + + - 250+ char prose (nemotron-3-nano on vLLM) → truncated to 197 + chars + "..." + - ``None`` (some models omit the field on ``answer="no"``) → + placeholder "no reason provided" + - Empty / whitespace-only string → placeholder (would otherwise + fail ``min_length=1``) + """ + if isinstance(v, str) and len(v) > 200: + return v[:197].rstrip() + "..." + if v is None or (isinstance(v, str) and not v.strip()): + return "no reason provided" + return v + class PrivacyAnswersSchema(BaseModel): """LLM output schema for privacy QA re-answer step (on rewritten text). When validated with ``context={"expected_ids": [1, 2, ...]}``, - enforces exact coverage: no missing, duplicate, or extra IDs. + normalizes the returned list to exactly that ID set: dedupes by id + (first wins), pads missing ids with a pessimistic answer + (``"yes"`` = assume leak), and drops extras. Pessimistic default + because a missing answer should bias toward triggering human review + rather than silently passing. """ - answers: list[PrivacyAnswerItemSchema] + answers: list[PrivacyAnswerItemSchema] = Field(default_factory=list) + + @model_validator(mode="before") + @classmethod + def _normalize_answers(cls, data: object, info: ValidationInfo) -> object: + if not isinstance(data, dict): + return data + raw = data.get("answers") or [] + if not isinstance(raw, list): + raw = [] + expected = (info.context or {}).get("expected_ids") if info.context else None + data["answers"] = _normalize_id_covered_list( + raw, + expected_ids=expected, + default_item_for_id=lambda i: { + "id": i, + "answer": "yes", # pessimistic default — flag for human review + "confidence": 0.5, + "reason": "missing answer - defaulted to pessimistic", + }, + ) + return data @model_validator(mode="after") def _check_coverage(self, info: ValidationInfo) -> PrivacyAnswersSchema: expected_ids = (info.context or {}).get("expected_ids") if expected_ids is not None: - _validate_id_coverage(expected_ids, [a.id for a in self.answers], "answer") + try: + _validate_id_coverage(expected_ids, [a.id for a in self.answers], "answer") + except ValueError as e: + import logging + + logging.getLogger(__name__).warning("PrivacyAnswersSchema post-normalization coverage warning: %s", e) return self @@ -421,14 +906,36 @@ class QACompareResultsSchema(BaseModel): """LLM output schema for quality QA comparison step. When validated with ``context={"expected_ids": [1, 2, ...]}``, - enforces exact coverage: no missing, duplicate, or extra IDs. + normalizes the returned list to exactly that ID set: dedupes, pads + missing ids with a neutral 0.5 score, and drops extras. """ - per_item: list[QACompareItemSchema] + per_item: list[QACompareItemSchema] = Field(default_factory=list) + + @model_validator(mode="before") + @classmethod + def _normalize_per_item(cls, data: object, info: ValidationInfo) -> object: + if not isinstance(data, dict): + return data + raw = data.get("per_item") or [] + if not isinstance(raw, list): + raw = [] + expected = (info.context or {}).get("expected_ids") if info.context else None + data["per_item"] = _normalize_id_covered_list( + raw, + expected_ids=expected, + default_item_for_id=lambda i: {"id": i, "score": 0.5, "reason": None}, + ) + return data @model_validator(mode="after") def _check_coverage(self, info: ValidationInfo) -> QACompareResultsSchema: expected_ids = (info.context or {}).get("expected_ids") if expected_ids is not None: - _validate_id_coverage(expected_ids, [a.id for a in self.per_item], "compare") + try: + _validate_id_coverage(expected_ids, [a.id for a in self.per_item], "compare") + except ValueError as e: + import logging + + logging.getLogger(__name__).warning("QACompareResultsSchema post-normalization coverage warning: %s", e) return self diff --git a/src/anonymizer/engine/schemas/shared.py b/src/anonymizer/engine/schemas/shared.py index f2eae732..7086c6bd 100644 --- a/src/anonymizer/engine/schemas/shared.py +++ b/src/anonymizer/engine/schemas/shared.py @@ -4,12 +4,15 @@ from __future__ import annotations import json +import logging from typing import TypeVar from pydantic import BaseModel, ValidationError T = TypeVar("T", bound=BaseModel) +_logger = logging.getLogger(__name__) + def _parse_raw_wrapper( model_cls: type[T], @@ -53,3 +56,66 @@ def _safe_validate(candidate_list: list[object]) -> T: if isinstance(as_list, list): return _safe_validate(as_list) return model_cls() + + +# --------------------------------------------------------------------------- +# Loose-list-wrapper helpers +# +# Used by wire-shape schemas whose LLM emitters sometimes drop the wrapper +# key, returning a bare list at the top level (observed: nemotron-3-nano:4b +# on rewrite-mode disposition; qwen3.5:4b on legal-court meaning units). +# Two helpers, one for the JSON Schema widening (consumed by DataDesigner's +# pre-validate gate) and one for the runtime before-validator. +# --------------------------------------------------------------------------- + + +def loose_list_wrapper_json_schema(handler, schema, *, list_field: str) -> dict: + """Widen a wrapper-style pydantic schema to ``oneOf({wrapper}, {array})``. + + DataDesigner runs ``jsonschema.validate()`` on raw LLM output BEFORE + pydantic's before-validators run. If a small model returns + ``[item, ...]`` instead of ``{list_field: [item, ...]}``, the strict + ``type: object`` pre-check rejects the row and the record is dropped. + Widening to a ``oneOf`` of the wrapper-object and the bare-array shape + lets both pass the pre-check; the runtime ``accept_bare_list`` + validator then normalizes to the canonical wrapper form. + + Falls back gracefully to the unwidened wrapper if pydantic ever + restructures so the inline schema for ``list_field`` is not directly + accessible (e.g. moves behind a ``$ref``). Logs a warning so future + regressions are visible at runtime. + """ + wrapped = handler(schema) + items = wrapped.get("properties", {}).get(list_field) + # Degrade gracefully if the property is missing entirely OR if it's + # only a $ref pointer (a future pydantic refactor moving the inline + # array schema behind $defs would do this — we can't safely use a + # ref pointer as the standalone bare-list branch of oneOf because + # DD's jsonschema gate would resolve it against the wrapper's $defs + # and the semantics get murky). + if not isinstance(items, dict) or set(items.keys()) <= {"$ref"}: + _logger.warning( + "loose_list_wrapper_json_schema: inline schema for %r unavailable in %r " + "(items=%r); skipping oneOf widening (DD pre-validate may reject bare-list shape)", + list_field, + wrapped.get("title"), + items, + ) + return wrapped + return {"oneOf": [wrapped, items]} + + +def accept_bare_list(*, list_field: str): + """Build a ``mode="before"`` validator that wraps a top-level bare list. + + Returns a ``classmethod`` suitable as a ``model_validator(mode="before")`` + callable: maps ``[item, ...]`` -> ``{list_field: [item, ...]}`` and + passes anything else through unchanged. + """ + + def _wrap(cls, data): + if isinstance(data, list): + return {list_field: data} + return data + + return classmethod(_wrap) diff --git a/tests/engine/test_disposition_reconstructor.py b/tests/engine/test_disposition_reconstructor.py new file mode 100644 index 00000000..3534f536 --- /dev/null +++ b/tests/engine/test_disposition_reconstructor.py @@ -0,0 +1,473 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the server-side disposition reconstructor. + +Covers: + * ``_normalize_category`` — display variants, merged enums, entity-label + confusion, fallback to quasi_identifier. + * ``_normalize_method`` — case drift, substring match priority, omission. + * ``derive_combined_risk_level`` — invariant-passing pick from method. + * ``reconstruct_full_disposition`` — pairing, orphan skip, dedupe, empty + fallback contract. + * ``_pessimistic_fallback_disposition`` — direct=replace, others=generalize. + * ``_reconstruct_full_disposition_column`` — workflow-column glue + empty + + ValidationError fallbacks. + +These tests pin the contracts the rewrite + QA + repair pipelines depend on +when small-model drift collapses the LLM disposition into a sparse +``SimpleDispositionResult``. +""" + +from __future__ import annotations + +import pandas as pd + +from anonymizer.engine.constants import ( + COL_ENTITIES_BY_VALUE, + COL_LATENT_ENTITIES, + COL_SENSITIVITY_DISPOSITION, + COL_SIMPLE_DISPOSITION, +) +from anonymizer.engine.rewrite.disposition_derivation import ( + _normalize_category, + _normalize_method, + derive_combined_risk_level, + reconstruct_full_disposition, + template_protection_reason, +) +from anonymizer.engine.rewrite.sensitivity_disposition import ( + _pessimistic_fallback_disposition, + _reconstruct_full_disposition_column, +) +from anonymizer.engine.schemas import SimpleDispositionResult + +# --------------------------------------------------------------------------- +# _normalize_category +# --------------------------------------------------------------------------- + + +class TestNormalizeCategory: + def test_canonical_passthrough(self) -> None: + assert _normalize_category("direct_identifier") == "direct_identifier" + assert _normalize_category("quasi_identifier") == "quasi_identifier" + assert _normalize_category("latent_identifier") == "latent_identifier" + + def test_display_variants(self) -> None: + assert _normalize_category("Direct-Identifier") == "direct_identifier" + assert _normalize_category("DIRECT IDENTIFIER") == "direct_identifier" + assert _normalize_category("DIRECT IDENTIFIERS") == "direct_identifier" + + def test_merged_enum_strongest_protection_wins(self) -> None: + """Nemotron emits ``"latent_direct_identifier"`` — direct should win + because re-id risk > inference risk.""" + assert _normalize_category("latent_direct_identifier") == "direct_identifier" + + def test_sensitive_substring_folds_into_quasi(self) -> None: + """``sensitive_attribute`` was removed from ``EntityCategory`` (#150); + the substring branch folds it to quasi_identifier as the conservative + protect-cautiously bucket.""" + assert _normalize_category("latent_sensitive_attribute") == "quasi_identifier" + assert _normalize_category("sensitive_attribute") == "quasi_identifier" + + def test_entity_label_in_category_slot_resolves(self) -> None: + assert _normalize_category("first_name") == "direct_identifier" + assert _normalize_category("date_of_birth") == "quasi_identifier" + assert _normalize_category("gender") == "quasi_identifier" # post-#150 fold + + def test_entity_label_echo_falls_back_to_quasi(self) -> None: + """When the model echoes the same entity_label into both fields, + quasi_identifier is the conservative fallback.""" + assert _normalize_category("zzz_unknown", entity_label="zzz_unknown") == "quasi_identifier" + + def test_blank_or_none_falls_back(self) -> None: + assert _normalize_category(None) == "quasi_identifier" + assert _normalize_category("") == "quasi_identifier" + assert _normalize_category(" ") == "quasi_identifier" + assert _normalize_category(123) == "quasi_identifier" + + def test_truly_unknown_falls_back_to_quasi(self) -> None: + assert _normalize_category("xyz_zzz_random_999") == "quasi_identifier" + + +# --------------------------------------------------------------------------- +# _normalize_method +# --------------------------------------------------------------------------- + + +class TestNormalizeMethod: + def test_canonical_passthrough(self) -> None: + for choice in ("replace", "generalize", "remove", "suppress_inference", "leave_as_is"): + assert _normalize_method(choice) == choice + + def test_uppercase_drift(self) -> None: + assert _normalize_method("REPLACE") == "replace" + assert _normalize_method("Generalize") == "generalize" + + def test_substring_match_priority(self) -> None: + """``"replace_with_surrogate"`` -> replace; ``"leave_as_is_for_now"`` + -> leave_as_is. Substring match is on the canonical underscored form + so it sees the wire-typical ``"suppress_inference_of_the_value"``; + space-form variants fall through to the empty-string return (which + the reconstructor handles via its pessimistic-default path).""" + assert _normalize_method("replace_with_surrogate") == "replace" + assert _normalize_method("leave_as_is_for_now") == "leave_as_is" + assert _normalize_method("suppress_inference_of_the_value") == "suppress_inference" + + def test_unknown_returns_empty_string(self) -> None: + """Empty signals to the caller to apply a pessimistic default.""" + assert _normalize_method("totally novel method") == "" + assert _normalize_method("") == "" + assert _normalize_method(None) == "" # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# derive_combined_risk_level +# --------------------------------------------------------------------------- + + +class TestDeriveCombinedRiskLevel: + def test_leave_as_is_yields_low(self) -> None: + """``low + leave_as_is`` is the only combination the + ``_validate_protection_consistency`` invariant accepts for low; the + reconstructor mirrors that.""" + for cat in ("direct_identifier", "quasi_identifier", "latent_identifier"): + for sens in ("low", "medium", "high"): + assert derive_combined_risk_level(cat, "leave_as_is", sens) == "low" + + def test_direct_identifier_with_protection_yields_high(self) -> None: + assert derive_combined_risk_level("direct_identifier", "replace", "low") == "high" + assert derive_combined_risk_level("direct_identifier", "remove", "medium") == "high" + + def test_high_sensitivity_yields_high(self) -> None: + assert derive_combined_risk_level("quasi_identifier", "replace", "high") == "high" + assert derive_combined_risk_level("latent_identifier", "remove", "high") == "high" + + def test_otherwise_yields_medium(self) -> None: + assert derive_combined_risk_level("quasi_identifier", "generalize", "medium") == "medium" + assert derive_combined_risk_level("latent_identifier", "suppress_inference", "low") == "medium" + + def test_picks_pass_validate_protection_consistency(self) -> None: + """Spot-check: reconstructor's pick is always accepted by the strict + schema's ``_validate_protection_consistency``. Any combination of + (category, method, sensitivity) the reconstructor sees should + produce a valid (combined_risk_level, method) pair.""" + from anonymizer.engine.schemas.rewrite import EntityDispositionSchema + + for cat in ("direct_identifier", "quasi_identifier", "latent_identifier"): + for method in ("replace", "generalize", "remove", "suppress_inference", "leave_as_is"): + for sens in ("low", "medium", "high"): + risk = derive_combined_risk_level(cat, method, sens) + EntityDispositionSchema( + id=1, + source="tagged", + category=cat, + sensitivity=sens, + entity_label="x", + entity_value="y", + protection_reason=template_protection_reason(cat, method, sens), + protection_method_suggestion=method, + combined_risk_level=risk, + ) + + +# --------------------------------------------------------------------------- +# reconstruct_full_disposition +# --------------------------------------------------------------------------- + + +class TestReconstructFullDisposition: + def test_id_indexed_pairing_uses_context_over_echo(self) -> None: + """Belt-and-braces: when context has the entity at id-1, trust it + over the echoed labels (gemma4-e2b emits garbage in the echoes).""" + ebv = [{"value": "Alice", "labels": ["first_name"]}] + simple = SimpleDispositionResult.model_validate( + [ + { + "id": 1, + "source": "garbage", + "entity_label": "wrong_label", + "entity_value": "wrong_value", + "category": "direct_identifier", + "sensitivity": "high", + "protection_method_suggestion": "replace", + } + ] + ) + result = reconstruct_full_disposition(simple, ebv, []) + item = result.sensitivity_disposition[0] + assert item.entity_label == "first_name" + assert item.entity_value == "Alice" + assert item.source == "tagged" + + def test_orphan_skipped_with_warning(self) -> None: + """Items with id outside the context range AND missing/garbage + echoes are skipped — better to return a smaller valid schema than + to drop the row.""" + ebv = [{"value": "Alice", "labels": ["first_name"]}] + simple = SimpleDispositionResult.model_validate( + [ + {"id": 1, "category": "direct_identifier", "sensitivity": "high"}, + {"id": 99, "category": "direct_identifier"}, # orphan + ] + ) + result = reconstruct_full_disposition(simple, ebv, []) + assert [item.id for item in result.sensitivity_disposition] == [1] + + def test_duplicate_ids_first_wins(self) -> None: + ebv = [ + {"value": "Alice", "labels": ["first_name"]}, + {"value": "Bob", "labels": ["first_name"]}, + ] + simple = SimpleDispositionResult.model_validate( + [ + {"id": 1, "category": "direct_identifier", "sensitivity": "high"}, + {"id": 1, "category": "quasi_identifier", "sensitivity": "low"}, # duplicate + {"id": 2, "category": "direct_identifier", "sensitivity": "high"}, + ] + ) + result = reconstruct_full_disposition(simple, ebv, []) + assert [item.id for item in result.sensitivity_disposition] == [1, 2] + assert result.sensitivity_disposition[0].category == "direct_identifier" + + def test_short_reason_replaced_by_template(self) -> None: + ebv = [{"value": "Alice", "labels": ["first_name"]}] + simple = SimpleDispositionResult.model_validate( + [ + { + "id": 1, + "category": "direct_identifier", + "sensitivity": "high", + "protection_method_suggestion": "replace", + "protection_reason": "ok", + } + ] + ) + result = reconstruct_full_disposition(simple, ebv, []) + # Templated reason is >=10 chars and reflects (high, replace, direct). + assert len(result.sensitivity_disposition[0].protection_reason) >= 10 + assert "direct identifier" in result.sensitivity_disposition[0].protection_reason.lower() + + def test_long_llm_reason_kept_verbatim(self) -> None: + ebv = [{"value": "Alice", "labels": ["first_name"]}] + long_reason = "A document-specific judgement that the model should preserve" + simple = SimpleDispositionResult.model_validate( + [ + { + "id": 1, + "category": "direct_identifier", + "sensitivity": "high", + "protection_method_suggestion": "replace", + "protection_reason": long_reason, + } + ] + ) + result = reconstruct_full_disposition(simple, ebv, []) + assert result.sensitivity_disposition[0].protection_reason == long_reason + + def test_rambling_reason_is_capped(self) -> None: + """The schema no longer enforces max_length on protection_reason, so + the reconstructor must cap a runaway small-model reason itself to keep + rewrite prompts and parquet bounded (silent truncate, never a drop).""" + ebv = [{"value": "Alice", "labels": ["first_name"]}] + rambling = "This entity is sensitive because " + ("blah " * 300) + simple = SimpleDispositionResult.model_validate( + [ + { + "id": 1, + "category": "direct_identifier", + "sensitivity": "high", + "protection_method_suggestion": "replace", + "protection_reason": rambling, + } + ] + ) + result = reconstruct_full_disposition(simple, ebv, []) + capped = result.sensitivity_disposition[0].protection_reason + assert len(capped) <= 500 + assert capped.endswith("...") + + def test_omitted_method_pessimistic_default(self) -> None: + """When the LLM omits ``protection_method_suggestion``, default to + ``replace`` for direct/medium-or-high sensitivity, else + ``leave_as_is``.""" + ebv = [{"value": "Alice", "labels": ["first_name"]}] + simple = SimpleDispositionResult.model_validate( + [{"id": 1, "category": "direct_identifier", "sensitivity": "high"}] + ) + result = reconstruct_full_disposition(simple, ebv, []) + assert result.sensitivity_disposition[0].protection_method_suggestion == "replace" + + def test_combined_risk_level_derived_from_method(self) -> None: + ebv = [{"value": "Alice", "labels": ["first_name"]}] + simple = SimpleDispositionResult.model_validate( + [ + { + "id": 1, + "category": "direct_identifier", + "sensitivity": "high", + "protection_method_suggestion": "replace", + } + ] + ) + result = reconstruct_full_disposition(simple, ebv, []) + assert result.sensitivity_disposition[0].combined_risk_level == "high" + + +# --------------------------------------------------------------------------- +# _pessimistic_fallback_disposition +# --------------------------------------------------------------------------- + + +class TestPessimisticFallback: + def test_direct_identifiers_get_replace(self) -> None: + ebv = [{"value": "Alice", "labels": ["first_name"]}] + result = _pessimistic_fallback_disposition(ebv, []) + item = result.sensitivity_disposition[0] + assert item.protection_method_suggestion == "replace" + assert item.combined_risk_level == "high" + + def test_quasi_identifiers_get_generalize(self) -> None: + ebv = [{"value": "40", "labels": ["age"]}] + result = _pessimistic_fallback_disposition(ebv, []) + item = result.sensitivity_disposition[0] + assert item.protection_method_suggestion == "generalize" + + def test_latent_entities_marked_latent_identifier(self) -> None: + result = _pessimistic_fallback_disposition([], [{"label": "occupation", "value": "doctor"}]) + item = result.sensitivity_disposition[0] + assert item.source == "latent" + assert item.category == "latent_identifier" + assert item.protection_method_suggestion == "generalize" + + def test_unmapped_label_falls_back_to_quasi(self) -> None: + ebv = [{"value": "x", "labels": ["wholly_novel_label"]}] + result = _pessimistic_fallback_disposition(ebv, []) + item = result.sensitivity_disposition[0] + assert item.category == "quasi_identifier" + + def test_empty_context_returns_valid_noop_instead_of_raising(self) -> None: + """Empty context (a pipeline-invariant violation) must not raise on the + SensitivityDispositionSchema min_length=1 tripwire and drop the row; + the fallback emits a single no-op (leave_as_is/low) disposition.""" + result = _pessimistic_fallback_disposition([], []) + assert len(result.sensitivity_disposition) == 1 + item = result.sensitivity_disposition[0] + assert item.protection_method_suggestion == "leave_as_is" + assert item.combined_risk_level == "low" + assert result.protected_entities == [] # no-op never reaches the rewrite + + def test_all_blank_slots_returns_valid_noop(self) -> None: + """Slots whose label/value strip to empty are skipped; if that empties + the disposition, the no-op guarantee still holds (no raise/no drop).""" + result = _pessimistic_fallback_disposition([{"value": "", "labels": [""]}], []) + assert len(result.sensitivity_disposition) == 1 + assert result.sensitivity_disposition[0].protection_method_suggestion == "leave_as_is" + + +# --------------------------------------------------------------------------- +# _reconstruct_full_disposition_column (workflow glue) +# --------------------------------------------------------------------------- + + +class TestReconstructionColumn: + def _row(self, simple_payload, ebv=None, latent=None) -> dict: + return { + COL_SIMPLE_DISPOSITION: simple_payload, + COL_ENTITIES_BY_VALUE: ebv or [], + COL_LATENT_ENTITIES: latent or [], + } + + def test_dict_payload(self) -> None: + ebv = [{"value": "Alice", "labels": ["first_name"]}] + row = self._row( + {"sensitivity_disposition": [{"id": 1, "category": "direct_identifier", "sensitivity": "high"}]}, + ebv=ebv, + ) + out = _reconstruct_full_disposition_column(row) + assert out[COL_SENSITIVITY_DISPOSITION]["sensitivity_disposition"][0]["entity_label"] == "first_name" + + def test_json_string_payload(self) -> None: + import json + + ebv = [{"value": "Alice", "labels": ["first_name"]}] + row = self._row( + json.dumps( + {"sensitivity_disposition": [{"id": 1, "category": "direct_identifier", "sensitivity": "high"}]} + ), + ebv=ebv, + ) + out = _reconstruct_full_disposition_column(row) + assert out[COL_SENSITIVITY_DISPOSITION]["sensitivity_disposition"][0]["entity_label"] == "first_name" + + def test_empty_simple_falls_back_to_pessimistic(self) -> None: + """Lipika/Andre's review concern: empty reconstruction must NOT + emit ``{"sensitivity_disposition": []}`` (would fail downstream + ``parse_sensitivity_disposition`` on min_length=1). Instead build + a pessimistic disposition from the entity context.""" + ebv = [{"value": "Alice", "labels": ["first_name"]}] + row = self._row({"sensitivity_disposition": []}, ebv=ebv) + out = _reconstruct_full_disposition_column(row) + items = out[COL_SENSITIVITY_DISPOSITION]["sensitivity_disposition"] + assert len(items) == 1 + assert items[0]["entity_label"] == "first_name" + assert items[0]["protection_method_suggestion"] == "replace" + + def test_empty_simple_and_empty_context_does_not_drop_row(self) -> None: + """Both unguarded fallback call-sites: empty simple output AND empty + entity context must still yield a valid row (the column generator must + never raise out and drop the record).""" + row = self._row({"sensitivity_disposition": []}, ebv=[], latent=[]) + out = _reconstruct_full_disposition_column(row) + items = out[COL_SENSITIVITY_DISPOSITION]["sensitivity_disposition"] + assert len(items) == 1 + assert items[0]["protection_method_suggestion"] == "leave_as_is" + + def test_invalid_simple_payload_falls_back(self) -> None: + """If ``SimpleDispositionResult.model_validate`` raises, fall back + to pessimistic disposition rather than dropping the row.""" + ebv = [{"value": "Alice", "labels": ["first_name"]}] + # A list with a non-dict item is not a valid SimpleDispositionItem; + # validation would raise — fallback should kick in. + row = self._row(["not a dict at all"], ebv=ebv) + out = _reconstruct_full_disposition_column(row) + assert len(out[COL_SENSITIVITY_DISPOSITION]["sensitivity_disposition"]) == 1 + + def test_garbage_string_payload_falls_back(self) -> None: + ebv = [{"value": "Alice", "labels": ["first_name"]}] + row = self._row("not json {}{ at all", ebv=ebv) + out = _reconstruct_full_disposition_column(row) + assert len(out[COL_SENSITIVITY_DISPOSITION]["sensitivity_disposition"]) == 1 + + +# --------------------------------------------------------------------------- +# Sanity: the reconstructor closes the loop via parse_sensitivity_disposition +# --------------------------------------------------------------------------- + + +def test_round_trip_reconstructed_disposition_parses() -> None: + """End-to-end: reconstructor output -> parse_sensitivity_disposition + must succeed. This is the contract that broke in PR #130's original + empty-list fallback (Lipika+Andre review).""" + from anonymizer.engine.rewrite.parsers import parse_sensitivity_disposition + + ebv = [{"value": "Alice", "labels": ["first_name"]}] + simple = SimpleDispositionResult.model_validate([{"id": 1, "category": "direct_identifier", "sensitivity": "high"}]) + result = reconstruct_full_disposition(simple, ebv, []) + parsed = parse_sensitivity_disposition(result.model_dump()) + assert len(parsed.sensitivity_disposition) == 1 + + +def test_round_trip_pessimistic_fallback_parses() -> None: + from anonymizer.engine.rewrite.parsers import parse_sensitivity_disposition + + ebv = [{"value": "Alice", "labels": ["first_name"]}] + result = _pessimistic_fallback_disposition(ebv, []) + parsed = parse_sensitivity_disposition(result.model_dump()) + assert len(parsed.sensitivity_disposition) == 1 + + +# Suppress unused-import warning for pandas in pyright/strict modes; pandas +# may be needed by future fixtures and pre-importing it gives a more +# reliable failure surface than a lazy import inside a test that fires only +# on schema regression. +_ = pd diff --git a/tests/engine/test_entity_label_category_map.py b/tests/engine/test_entity_label_category_map.py new file mode 100644 index 00000000..92d6261b --- /dev/null +++ b/tests/engine/test_entity_label_category_map.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""CI guard: every label in ``DEFAULT_ENTITY_LABELS`` has a category mapping. + +The disposition reconstructor (``engine/rewrite/disposition_derivation.py``) +falls back to the entity-label-to-category map when small models stuff an +entity label into the ``category`` slot, and the pessimistic fallback +disposition (``sensitivity_disposition._pessimistic_fallback_disposition``) +uses it to pick ``replace`` vs ``generalize``. A new ``DEFAULT_ENTITY_LABELS`` +entry without a corresponding category mapping silently degrades both paths +to the conservative ``quasi_identifier`` -> generalize bucket — fine for +attributes, wrong for any new direct identifier (which should be ``replace``). + +This regression makes the gap visible: a contributor adding a new label +must extend the category map at the same time, or this test fails with a +diff. +""" + +from __future__ import annotations + +from anonymizer.engine.constants import DEFAULT_ENTITY_LABELS +from anonymizer.engine.schemas.rewrite import _ENTITY_LABEL_TO_CATEGORY + + +def test_entity_label_to_category_covers_default_labels() -> None: + missing = sorted(set(DEFAULT_ENTITY_LABELS) - set(_ENTITY_LABEL_TO_CATEGORY)) + assert not missing, ( + "Labels in DEFAULT_ENTITY_LABELS without a category mapping in " + "anonymizer.engine.schemas.rewrite._ENTITY_LABEL_TO_CATEGORY:\n " + + "\n ".join(missing) + + "\n\nAdd each label to one of _DIRECT_ID_LABELS or _QUASI_ID_LABELS in " + "src/anonymizer/engine/schemas/rewrite.py. Direct identifiers (names, " + "ids, contact info) -> _DIRECT_ID_LABELS; everything else -> " + "_QUASI_ID_LABELS (the conservative protect-cautiously bucket)." + ) + + +def test_entity_label_to_category_values_are_valid_categories() -> None: + """Every value in the map must be a member of the current ``EntityCategory`` + enum. Catches the case where ``EntityCategory`` is shrunk (cf. #150 + removing ``sensitive_attribute``) without updating the map.""" + from anonymizer.engine.schemas.rewrite import EntityCategory + + valid = {c.value for c in EntityCategory} + invalid = {label: cat for label, cat in _ENTITY_LABEL_TO_CATEGORY.items() if cat not in valid} + assert not invalid, ( + f"_ENTITY_LABEL_TO_CATEGORY has values not in EntityCategory: {invalid}. " + "Update src/anonymizer/engine/schemas/rewrite.py to drop these or remap them." + ) diff --git a/tests/engine/test_schemas.py b/tests/engine/test_schemas.py index 3ceada56..8e0a7df8 100644 --- a/tests/engine/test_schemas.py +++ b/tests/engine/test_schemas.py @@ -381,12 +381,16 @@ def test_qa_compare_results_use_integer_ids() -> None: # Context-validated answer coverage -def test_quality_answers_reject_missing_ids_with_context() -> None: - with pytest.raises(ValidationError, match="Missing answer IDs"): - QualityAnswersSchema.model_validate( - {"answers": [{"id": 1, "answer": "yes"}]}, - context={"expected_ids": [1, 2]}, - ) +def test_quality_answers_pad_missing_ids_with_context() -> None: + """Post-#130: missing ids are padded with placeholder ``"missing"`` instead + of raising. Coverage is now best-effort (small-model drift would otherwise + drop the row); the after-validator logs a warning if normalization fails.""" + result = QualityAnswersSchema.model_validate( + {"answers": [{"id": 1, "answer": "yes"}]}, + context={"expected_ids": [1, 2]}, + ) + by_id = {a.id: a.answer for a in result.answers} + assert by_id == {1: "yes", 2: "missing"} def test_quality_answers_accept_complete_with_context() -> None: @@ -402,36 +406,49 @@ def test_quality_answers_no_enforcement_without_context() -> None: assert len(result.answers) == 1 -def test_privacy_answers_reject_missing_ids_with_context() -> None: - with pytest.raises(ValidationError, match="Missing answer IDs"): - PrivacyAnswersSchema.model_validate( - {"answers": [{"id": 1, "answer": "no", "confidence": 0.0, "reason": "not inferable"}]}, - context={"expected_ids": [1, 2]}, - ) +def test_privacy_answers_pad_missing_ids_pessimistically() -> None: + """Post-#130: missing ids get ``answer="yes"`` (pessimistic — flag for + human review) rather than raising. Padded rows are visible to downstream + QA so they do not silently bypass leak detection.""" + result = PrivacyAnswersSchema.model_validate( + {"answers": [{"id": 1, "answer": "no", "confidence": 0.0, "reason": "not inferable"}]}, + context={"expected_ids": [1, 2]}, + ) + by_id = {a.id: a for a in result.answers} + assert by_id[1].answer == "no" + assert by_id[2].answer == "yes" # pessimistic default + assert "missing" in by_id[2].reason.lower() -def test_qa_compare_reject_missing_ids_with_context() -> None: - with pytest.raises(ValidationError, match="Missing compare IDs"): - QACompareResultsSchema.model_validate( - {"per_item": [{"id": 1, "score": 0.9}]}, - context={"expected_ids": [1, 2]}, - ) +def test_qa_compare_pad_missing_ids_with_neutral_score() -> None: + """Post-#130: missing ids get a neutral 0.5 score rather than raising.""" + result = QACompareResultsSchema.model_validate( + {"per_item": [{"id": 1, "score": 0.9}]}, + context={"expected_ids": [1, 2]}, + ) + by_id = {item.id: item.score for item in result.per_item} + assert by_id == {1: 0.9, 2: 0.5} -def test_quality_answers_reject_duplicate_ids() -> None: - with pytest.raises(ValidationError, match="Duplicate answer IDs"): - QualityAnswersSchema.model_validate( - {"answers": [{"id": 1, "answer": "yes"}, {"id": 1, "answer": "no"}, {"id": 2, "answer": "yes"}]}, - context={"expected_ids": [1, 2]}, - ) +def test_quality_answers_dedupe_duplicate_ids_first_wins() -> None: + """Post-#130: duplicate ids are de-duplicated (first occurrence wins) + instead of raising. Drifted small-model output that emits the same id + twice no longer drops the row.""" + result = QualityAnswersSchema.model_validate( + {"answers": [{"id": 1, "answer": "yes"}, {"id": 1, "answer": "no"}, {"id": 2, "answer": "yes"}]}, + context={"expected_ids": [1, 2]}, + ) + by_id = {a.id: a.answer for a in result.answers} + assert by_id == {1: "yes", 2: "yes"} # first-wins for id=1 -def test_quality_answers_reject_extra_ids() -> None: - with pytest.raises(ValidationError, match="Extra answer IDs"): - QualityAnswersSchema.model_validate( - {"answers": [{"id": 1, "answer": "yes"}, {"id": 2, "answer": "no"}, {"id": 99, "answer": "yes"}]}, - context={"expected_ids": [1, 2]}, - ) +def test_quality_answers_drop_extra_ids() -> None: + """Post-#130: ids not in ``expected_ids`` are dropped instead of raising.""" + result = QualityAnswersSchema.model_validate( + {"answers": [{"id": 1, "answer": "yes"}, {"id": 2, "answer": "no"}, {"id": 99, "answer": "yes"}]}, + context={"expected_ids": [1, 2]}, + ) + assert {a.id for a in result.answers} == {1, 2} # --------------------------------------------------------------------------- diff --git a/tests/engine/test_sensitivity_disposition.py b/tests/engine/test_sensitivity_disposition.py index 8b8a77ff..a152d245 100644 --- a/tests/engine/test_sensitivity_disposition.py +++ b/tests/engine/test_sensitivity_disposition.py @@ -3,7 +3,7 @@ from __future__ import annotations -from data_designer.config.column_configs import LLMStructuredColumnConfig +from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig from anonymizer.config.models import RewriteModelSelection from anonymizer.config.rewrite import PrivacyGoal @@ -12,6 +12,7 @@ COL_ENTITIES_BY_VALUE, COL_LATENT_ENTITIES, COL_SENSITIVITY_DISPOSITION, + COL_SIMPLE_DISPOSITION, COL_TAGGED_TEXT, _jinja, ) @@ -26,17 +27,35 @@ ) -def test_columns_uses_disposition_analyzer_alias( +def test_columns_emits_two_step_pipeline( stub_rewrite_model_selection: RewriteModelSelection, ) -> None: + """Post-#130: the disposition workflow emits a 2-step pipeline — a + ``SimpleDispositionResult`` LLM column (loose wire, dropped from preview) + + a deterministic CustomColumnConfig that reconstructs the strict + ``SensitivityDispositionSchema`` server-side.""" cols = SensitivityDispositionWorkflow().columns( selected_models=stub_rewrite_model_selection, privacy_goal=_STUB_PRIVACY_GOAL, ) - assert len(cols) == 1 - assert isinstance(cols[0], LLMStructuredColumnConfig) - assert cols[0].model_alias == stub_rewrite_model_selection.disposition_analyzer - assert cols[0].name == COL_SENSITIVITY_DISPOSITION + assert len(cols) == 2 + + llm_col, recon_col = cols + assert isinstance(llm_col, LLMStructuredColumnConfig) + assert llm_col.name == COL_SIMPLE_DISPOSITION + assert llm_col.model_alias == stub_rewrite_model_selection.disposition_analyzer + # DataDesigner serializes ``output_format`` to its JSON schema at + # construction time, so we assert on the schema's ``$defs`` rather than + # identity. Looking for ``SimpleDispositionItem`` (the loose wire item) + # is a tighter check than just "some schema is set" — it confirms we + # are passing the loose wrapper and not the strict + # ``SensitivityDispositionSchema``. + assert isinstance(llm_col.output_format, dict) + assert "SimpleDispositionItem" in llm_col.output_format.get("$defs", {}) + assert llm_col.drop is True + + assert isinstance(recon_col, CustomColumnConfig) + assert recon_col.name == COL_SENSITIVITY_DISPOSITION def test_privacy_goal_interpolated_into_prompt() -> None: diff --git a/tests/engine/test_small_model_drift.py b/tests/engine/test_small_model_drift.py index c59ef66b..4a930698 100644 --- a/tests/engine/test_small_model_drift.py +++ b/tests/engine/test_small_model_drift.py @@ -1,8 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Tests covering small-model output drift on detection schemas. +"""Tests covering small-model output drift on rewrite + detection schemas. -These regressions cover the drift modes observed during small-model +These regressions cover the drift modes observed during PR #130 small-model benchmarks (gemma4-e2b, gemma4-e4b, nemotron-3-nano:4b, qwen3.5:4b on legal court / medical visit / employee notes datasets). Each test pins one drift class so a future schema change that re-tightens the wire contract surfaces @@ -11,11 +11,192 @@ class so a future schema change that re-tightens the wire contract surfaces from __future__ import annotations +import pytest +from pydantic import ValidationError + from anonymizer.engine.schemas.detection import ( LatentEntitySchema, RawValidationDecisionSchema, ValidationDecisionSchema, ) +from anonymizer.engine.schemas.rewrite import ( + Domain, + DomainClassificationSchema, + MeaningUnitImportance, + MeaningUnitsSchema, + PrivacyAnswerItemSchema, + SimpleDispositionResult, +) + +# --------------------------------------------------------------------------- +# DomainClassificationSchema — wire-loose domain + confidence coercion +# --------------------------------------------------------------------------- + + +class TestDomainClassificationDrift: + def test_lowercase_domain_normalizes_to_canonical(self) -> None: + result = DomainClassificationSchema.model_validate({"domain": "biography_profile", "domain_confidence": 0.9}) + assert result.domain == Domain.BIOGRAPHY_PROFILE.value + + def test_substring_match_to_canonical_value(self) -> None: + result = DomainClassificationSchema.model_validate( + {"domain": "this is medical clinical content", "domain_confidence": 0.7} + ) + assert result.domain == Domain.MEDICAL_CLINICAL.value + + def test_unknown_domain_falls_back_to_other(self) -> None: + result = DomainClassificationSchema.model_validate( + {"domain": "completely-novel-bucket", "domain_confidence": 0.5} + ) + assert result.domain == Domain.OTHER.value + + def test_missing_or_blank_domain_falls_back_to_other(self) -> None: + for raw in [None, "", " ", 123]: + result = DomainClassificationSchema.model_validate({"domain": raw, "domain_confidence": 0.5}) + assert result.domain == Domain.OTHER.value + + def test_string_confidence_coerces_to_float(self) -> None: + result = DomainClassificationSchema.model_validate({"domain": "LEGAL", "domain_confidence": "0.85"}) + assert result.domain_confidence == pytest.approx(0.85) + + def test_percent_confidence_coerces_to_fractional_float(self) -> None: + result = DomainClassificationSchema.model_validate({"domain": "LEGAL", "domain_confidence": "85%"}) + assert result.domain_confidence == pytest.approx(0.85) + + def test_unparseable_confidence_falls_back_to_default(self) -> None: + result = DomainClassificationSchema.model_validate({"domain": "LEGAL", "domain_confidence": "very high"}) + assert result.domain_confidence == 0.5 + + def test_out_of_range_confidence_clamps(self) -> None: + high = DomainClassificationSchema.model_validate({"domain": "LEGAL", "domain_confidence": 1.5}) + assert high.domain_confidence == 1.0 + low = DomainClassificationSchema.model_validate({"domain": "LEGAL", "domain_confidence": -0.2}) + assert low.domain_confidence == 0.0 + + +# --------------------------------------------------------------------------- +# SimpleDispositionResult — bare-list tolerance + scalar coercion +# --------------------------------------------------------------------------- + + +class TestSimpleDispositionDrift: + def test_canonical_wrapper_validates(self) -> None: + result = SimpleDispositionResult.model_validate( + {"sensitivity_disposition": [{"id": 1, "category": "direct_identifier"}]} + ) + assert len(result.sensitivity_disposition) == 1 + + def test_bare_list_at_top_level_wraps_into_canonical(self) -> None: + """nemotron-3-nano:4b consistently emits the bare list shape on + rewrite-mode disposition; without the ``oneOf`` widening + + ``accept_bare_list`` validator this would fail DD's pre-validate.""" + result = SimpleDispositionResult.model_validate([{"id": 1, "category": "direct_identifier"}]) + assert len(result.sensitivity_disposition) == 1 + assert result.sensitivity_disposition[0].id == 1 + + def test_json_schema_widens_to_oneof_for_dd_pre_validate(self) -> None: + """DataDesigner runs ``jsonschema.validate`` BEFORE pydantic's + before-validators, so the emitted schema must accept both shapes.""" + schema = SimpleDispositionResult.model_json_schema() + assert "oneOf" in schema, schema + oneof_types = {variant.get("type") for variant in schema["oneOf"]} + assert oneof_types == {"object", "array"} + + def test_int_values_in_str_fields_coerce(self) -> None: + """gemma4-e4b observed echoing ints in entity_value when the value + is numeric (age, ssn). Loose wire coerces to str so the row + survives; reconstructor uses the trusted context anyway.""" + result = SimpleDispositionResult.model_validate( + [{"id": 1, "entity_value": 42, "entity_label": "age", "category": "quasi_identifier"}] + ) + assert result.sensitivity_disposition[0].entity_value == "42" + + def test_none_in_str_fields_coerces_to_empty(self) -> None: + result = SimpleDispositionResult.model_validate( + [{"id": 1, "entity_value": None, "entity_label": None, "category": None}] + ) + item = result.sensitivity_disposition[0] + assert item.entity_value == "" + assert item.entity_label == "" + assert item.category == "" + + def test_container_values_in_str_fields_coerce_to_empty(self) -> None: + """A model emitting a list/dict for a scalar str field must not fail the + whole item (which would discard every disposition for the row); coerce + to "" and let the reconstructor recover from trusted context.""" + result = SimpleDispositionResult.model_validate( + [{"id": 1, "entity_label": ["first", "name"], "category": {"x": 1}, "entity_value": "Alice"}] + ) + item = result.sensitivity_disposition[0] + assert item.entity_label == "" + assert item.category == "" + assert item.entity_value == "Alice" + + +# --------------------------------------------------------------------------- +# MeaningUnits — bare list, aspect normalize, importance default, id renumber +# --------------------------------------------------------------------------- + + +class TestMeaningUnitsDrift: + def test_bare_list_top_level_wraps(self) -> None: + """qwen3.5:4b on legal-court emits the bare-list shape.""" + result = MeaningUnitsSchema.model_validate([{"id": 1, "aspect": "role", "unit": "judge"}]) + assert len(result.units) == 1 + + def test_aspect_case_normalizes(self) -> None: + result = MeaningUnitsSchema.model_validate({"units": [{"id": 1, "aspect": "ROLE", "unit": "lawyer"}]}) + assert result.units[0].aspect == "role" + + def test_aspect_substring_match(self) -> None: + result = MeaningUnitsSchema.model_validate( + {"units": [{"id": 1, "aspect": "the procedural status of the case", "unit": "pending"}]} + ) + assert result.units[0].aspect == "procedural_status" + + def test_unknown_aspect_falls_through_to_empty(self) -> None: + result = MeaningUnitsSchema.model_validate({"units": [{"id": 1, "aspect": "xyz", "unit": "u"}]}) + assert result.units[0].aspect == "" + + def test_missing_importance_defaults_to_important(self) -> None: + """Pre-#130 ``MeaningUnitSchema.importance`` was a required enum; + small models that omit it would drop the row. Default to + ``important`` (the safer of the two values for downstream QA).""" + result = MeaningUnitsSchema.model_validate({"units": [{"id": 1, "aspect": "role", "unit": "lawyer"}]}) + assert result.units[0].importance == MeaningUnitImportance.important.value + + def test_drift_on_importance_normalizes(self) -> None: + result = MeaningUnitsSchema.model_validate( + {"units": [{"id": 1, "aspect": "role", "unit": "lawyer", "importance": "Critical"}]} + ) + assert result.units[0].importance == MeaningUnitImportance.critical.value + + def test_duplicate_ids_renumber_sequentially(self) -> None: + """Every-unit-has-id=1 collapse mode (model omitted ids and the + ``id=1`` default kicked in for all of them).""" + result = MeaningUnitsSchema.model_validate( + { + "units": [ + {"aspect": "role", "unit": "a"}, + {"aspect": "role", "unit": "b"}, + {"aspect": "role", "unit": "c"}, + ] + } + ) + assert [u.id for u in result.units] == [1, 2, 3] + + def test_explicit_unique_ids_preserved(self) -> None: + result = MeaningUnitsSchema.model_validate( + {"units": [{"id": 5, "aspect": "role", "unit": "a"}, {"id": 7, "aspect": "role", "unit": "b"}]} + ) + assert [u.id for u in result.units] == [5, 7] + + def test_empty_list_does_not_drop_record(self) -> None: + """Outer min_length=1 was relaxed to default_factory=list — empty + units should validate so downstream can decide what to do.""" + result = MeaningUnitsSchema.model_validate({"units": []}) + assert result.units == [] + # --------------------------------------------------------------------------- # RawValidationDecisionSchema — chunked-validation drift @@ -107,7 +288,7 @@ def test_overlong_rationale_truncates(self) -> None: assert result.rationale.endswith("...") def test_empty_required_fields_default_to_empty_string(self) -> None: - """Pre-loosening these were required ``min_length=1`` and would drop + """Pre-#130 these were required ``min_length=1`` and would drop the row; loose wire allows empty so the parquet-pad sentinel path can build a placeholder row when needed.""" result = LatentEntitySchema() @@ -132,3 +313,63 @@ def test_sensitive_category_drift_normalizes_to_latent_identifier(self) -> None: def test_unknown_category_drift_normalizes_to_latent_identifier(self) -> None: result = LatentEntitySchema.model_validate({"label": "x", "value": "y", "category": "some-novel-bucket"}) assert result.category == "latent_identifier" + + +# --------------------------------------------------------------------------- +# PrivacyAnswerItemSchema — reason field coercion +# --------------------------------------------------------------------------- + + +class TestPrivacyAnswerReasonCoercion: + def _base(self, **overrides: object) -> dict: + return { + "id": 1, + "answer": "no", + "confidence": 0.9, + "reason": "no leak observed in rewrite", + **overrides, + } + + def test_overlong_reason_truncates(self) -> None: + long = "x" * 400 + result = PrivacyAnswerItemSchema.model_validate(self._base(reason=long)) + assert len(result.reason) <= 200 + assert result.reason.endswith("...") + + def test_none_reason_replaced_with_placeholder(self) -> None: + result = PrivacyAnswerItemSchema.model_validate(self._base(reason=None)) + assert "no reason provided" in result.reason + + def test_blank_reason_replaced_with_placeholder(self) -> None: + result = PrivacyAnswerItemSchema.model_validate(self._base(reason=" ")) + assert "no reason provided" in result.reason + + def test_short_valid_reason_passthrough(self) -> None: + result = PrivacyAnswerItemSchema.model_validate(self._base(reason="ok")) + assert result.reason == "ok" + + +# --------------------------------------------------------------------------- +# Catch the regression: the strict wire contract used to drop these rows +# --------------------------------------------------------------------------- + + +def test_simple_disposition_does_not_validate_under_strict_disposition() -> None: + """Sanity check: wire-loose ``SimpleDispositionResult`` accepts inputs + that the strict ``SensitivityDispositionSchema`` would reject. This + is the whole point of PR #130's two-step pipeline; if a future schema + change makes them equivalent we want this test to surface that.""" + from anonymizer.engine.schemas.rewrite import SensitivityDispositionSchema + + drifted = [ + { + "id": 1, + "category": "DIRECT IDENTIFIER", # display-variant drift + "sensitivity": "HIGH", + "protection_method_suggestion": "Replace_With_Surrogate", + # missing entity_label, entity_value, protection_reason, combined_risk_level + } + ] + SimpleDispositionResult.model_validate(drifted) # OK + with pytest.raises(ValidationError): + SensitivityDispositionSchema.model_validate({"sensitivity_disposition": drifted})