Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions backend/app/services/llm/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@ def run_guardrails_validation(
project_id: int | None,
organization_id: int | None,
suppress_pass_logs: bool = True,
output_text: str | None = None,
) -> dict[str, Any]:
"""
Call the Kaapi guardrails service to validate and process input text.

Args:
input_text: Text to validate and process.
input_text: User query text, maps to payload["input"].
guardrail_config: List of validator configurations to apply.
job_id: Unique identifier for the request.
project_id: Project identifier expected by guardrails API.
organization_id: Organization identifier expected by guardrails API.
suppress_pass_logs: Whether to suppress successful validation logs in guardrails service.
output_text: LLM response text, maps to payload["output"]. Required for validators
that evaluate input/output pairs.

Returns:
JSON response from the guardrails service with validation results.
Expand All @@ -39,14 +42,17 @@ def run_guardrails_validation(
for validator in guardrail_config
]

payload = {
payload: dict[str, Any] = {
"request_id": str(job_id),
"project_id": project_id,
"organization_id": organization_id,
"input": input_text,
"validators": validators,
}

if output_text is not None:
payload["output"] = output_text

headers = {
"accept": "application/json",
"Authorization": f"Bearer {settings.KAAPI_GUARDRAILS_AUTH}",
Expand Down
7 changes: 5 additions & 2 deletions backend/app/services/llm/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ def apply_output_guardrails(
job_id: UUID,
project_id: int,
organization_id: int,
input_text: str | None = None,
) -> tuple[BlockResult, str | None]:
"""Apply output guardrails from a config_blob. Shared by /llm/call and /llm/chain.

Expand All @@ -451,14 +452,15 @@ def apply_output_guardrails(
if not output_guardrails:
return result, None

output_text = result.response.response.output.content.value
llm_output = result.response.response.output.content.value
safe = run_guardrails_validation(
output_text,
input_text or "",
output_guardrails,
job_id,
project_id,
organization_id,
suppress_pass_logs=True,
output_text=llm_output,
)

logger.info(
Expand Down Expand Up @@ -956,6 +958,7 @@ def execute_llm_call(
job_id=job_id,
project_id=project_id,
organization_id=organization_id,
input_text=original_input_value,
)
if output_error:
out_guard_span.set_status(
Expand Down
48 changes: 48 additions & 0 deletions backend/app/tests/services/llm/test_guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,54 @@ def test_run_guardrails_validation_serializes_validator_models(mock_client_cls)
assert kwargs["json"]["validators"] == [{"validator_config_id": str(vid)}]


@patch("app.services.llm.guardrails.httpx.Client")
def test_run_guardrails_validation_includes_output_in_payload(mock_client_cls) -> None:
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {"success": True}

mock_client = MagicMock()
mock_client.post.return_value = mock_response
mock_client_cls.return_value.__enter__.return_value = mock_client

run_guardrails_validation(
TEST_TEXT,
TEST_CONFIG,
TEST_JOB_ID,
TEST_PROJECT_ID,
TEST_ORGANIZATION_ID,
output_text="some llm response",
)

_, kwargs = mock_client.post.call_args
assert kwargs["json"]["input"] == TEST_TEXT
assert kwargs["json"]["output"] == "some llm response"


@patch("app.services.llm.guardrails.httpx.Client")
def test_run_guardrails_validation_omits_output_from_payload_by_default(
mock_client_cls,
) -> None:
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {"success": True}

mock_client = MagicMock()
mock_client.post.return_value = mock_response
mock_client_cls.return_value.__enter__.return_value = mock_client

run_guardrails_validation(
TEST_TEXT,
TEST_CONFIG,
TEST_JOB_ID,
TEST_PROJECT_ID,
TEST_ORGANIZATION_ID,
)

_, kwargs = mock_client.post.call_args
assert "output" not in kwargs["json"]


@patch("app.services.llm.guardrails.httpx.Client")
def test_run_guardrails_validation_allows_disable_suppress_pass_logs(
mock_client_cls,
Expand Down
97 changes: 97 additions & 0 deletions backend/app/tests/services/llm/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,103 @@ def test_guardrails_sanitize_output_after_provider(

assert "REDACTED" in result["data"]["response"]["output"]["content"]["value"]

def test_guardrails_output_validation_sends_input_output_pair(
self, db, job_env, job_for_execution
):
env = job_env
user_query = "What is my Aadhar number?"
llm_output = "Your Aadhar number is 1234-5678-9012"

env["mock_llm_response"].response.output.content.value = llm_output
env["provider"].execute.return_value = (env["mock_llm_response"], None)

with (
patch("app.services.llm.jobs.run_guardrails_validation") as mock_guardrails,
patch("app.services.llm.jobs.list_validators_config") as mock_fetch_configs,
):
mock_guardrails.return_value = {
"success": True,
"bypassed": False,
"data": {
"safe_text": "Your Aadhar number is [REDACTED]",
"rephrase_needed": False,
},
}
mock_fetch_configs.return_value = (
[],
[{"type": "pii_remover", "stage": "output"}],
)

request_data = {
"query": {"input": user_query},
"config": {
"blob": {
"completion": {
"provider": "openai-native",
"type": "text",
"params": {"model": "gpt-4o"},
},
"input_guardrails": [],
"output_guardrails": [
{"validator_config_id": VALIDATOR_CONFIG_ID_2}
],
}
},
}
self._execute_job(job_for_execution, db, request_data)

mock_guardrails.assert_called_once()
_, kwargs = mock_guardrails.call_args
assert kwargs.get("output_text") == llm_output
assert mock_guardrails.call_args[0][0] == user_query

def test_guardrails_bypass_does_not_modify_output(
self, db, job_env, job_for_execution
):
env = job_env
original_llm_output = "Your Aadhar number is 1234-5678-9012"

env["mock_llm_response"].response.output.content.value = original_llm_output
env["provider"].execute.return_value = (env["mock_llm_response"], None)

with (
patch("app.services.llm.jobs.run_guardrails_validation") as mock_guardrails,
patch("app.services.llm.jobs.list_validators_config") as mock_fetch_configs,
):
mock_guardrails.return_value = {
"success": False,
"bypassed": True,
"data": {"safe_text": original_llm_output, "rephrase_needed": False},
}
mock_fetch_configs.return_value = (
[],
[{"type": "pii_remover", "stage": "output"}],
)

request_data = {
"query": {"input": "some question"},
"config": {
"blob": {
"completion": {
"provider": "openai-native",
"type": "text",
"params": {"model": "gpt-4o"},
},
"input_guardrails": [],
"output_guardrails": [
{"validator_config_id": VALIDATOR_CONFIG_ID_2}
],
}
},
}
result = self._execute_job(job_for_execution, db, request_data)

assert result["success"] is True
assert (
result["data"]["response"]["output"]["content"]["value"]
== original_llm_output
)

def test_guardrails_skip_output_validation_for_audio_output(
self, db, job_env, job_for_execution
):
Expand Down
Loading