Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,65 +1,83 @@
from typing import Any
from typing import Any, Dict, List, Optional

from ldai import log
from ldai.providers import AgentResult, AgentRunner
from ldai.providers.types import LDAIMetrics
from ldai.providers.types import LDAIMetrics, RunnerResult

from ldai_langchain.langchain_helper import (
extract_last_message_content,
get_tool_calls_from_response,
sum_token_usage_from_messages,
)


class LangChainAgentRunner:
    """
    CAUTION:
    This feature is experimental and should NOT be considered ready for production use.
    It may change or be removed without notice and is not subject to backwards
    compatibility guarantees.

    Runner implementation for a single LangChain agent.

    Wraps a compiled LangChain agent graph (from ``langchain.agents.create_agent``)
    and delegates execution to it. Tool calling and loop management are handled
    internally by the graph.

    Returned by ``LangChainRunnerFactory.create_agent(config, tools)``.

    Implements the unified :class:`~ldai.providers.runner.Runner` protocol.
    """

    def __init__(self, agent: Any):
        # `agent` is the compiled LangChain agent graph; we only ever delegate
        # to its async `ainvoke` entry point.
        self._agent = agent

    async def run(
        self,
        input: Any,
        output_type: Optional[Dict[str, Any]] = None,
    ) -> RunnerResult:
        """
        Run the agent with the given input.

        Delegates to the compiled LangChain agent, which handles
        the tool-calling loop internally.

        :param input: The user prompt or input to the agent; coerced to ``str``
            before being sent as a user message.
        :param output_type: Reserved for future structured output support;
            currently ignored.
        :return: :class:`RunnerResult` with ``content``, ``raw`` response, and
            metrics including aggregated token usage and observed ``tool_calls``.
        """
        try:
            result = await self._agent.ainvoke({
                "messages": [{"role": "user", "content": str(input)}]
            })
            messages: List[Any] = result.get("messages", [])
            content = extract_last_message_content(messages)
            tool_calls = self._extract_tool_calls(messages)
            return RunnerResult(
                content=content,
                metrics=LDAIMetrics(
                    success=True,
                    usage=sum_token_usage_from_messages(messages),
                    # Omit the field entirely when no tools were called.
                    tool_calls=tool_calls if tool_calls else None,
                ),
                raw=result,
            )
        except Exception as error:
            # Best-effort: surface failure via metrics rather than raising,
            # matching the file's error-handling convention.
            log.warning(f"LangChain agent run failed: {error}")
            return RunnerResult(
                content="",
                metrics=LDAIMetrics(success=False, usage=None),
            )

    @staticmethod
    def _extract_tool_calls(messages: List[Any]) -> List[str]:
        """Collect tool call names from all messages in the agent output."""
        names: List[str] = []
        for msg in messages:
            names.extend(get_tool_calls_from_response(msg))
        return names

    def get_agent(self) -> Any:
        """Return the underlying compiled LangChain agent."""
        return self._agent
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from ldai import LDMessage, log
from ldai.providers.model_runner import ModelRunner
from ldai.providers.types import LDAIMetrics, ModelResponse, StructuredResponse
from ldai.providers.types import LDAIMetrics, RunnerResult

from ldai_langchain.langchain_helper import (
convert_messages_to_langchain,
Expand All @@ -13,12 +12,15 @@
)


class LangChainModelRunner(ModelRunner):
class LangChainModelRunner:
"""
ModelRunner implementation for LangChain.
Runner implementation for LangChain chat models.

Holds a fully-configured BaseChatModel.
Returned by LangChainConnector.create_model(config).
Returned by ``LangChainRunnerFactory.create_model(config)``.

Implements the unified :class:`~ldai.providers.runner.Runner` protocol via
:meth:`run`.
"""

def __init__(self, llm: BaseChatModel):
Expand All @@ -32,13 +34,37 @@ def get_llm(self) -> BaseChatModel:
"""
return self._llm

async def run(
    self,
    input: Any,
    output_type: Optional[Dict[str, Any]] = None,
) -> RunnerResult:
    """
    Run the LangChain model with the given input.

    :param input: A string prompt or a list of :class:`LDMessage` objects
    :param output_type: Optional JSON schema dict requesting structured output.
        When provided, ``parsed`` on the returned :class:`RunnerResult` is
        populated with the structured data.
    :return: :class:`RunnerResult` containing ``content``, ``metrics``,
        ``raw`` and (when ``output_type`` is set) ``parsed``.
    :raises TypeError: If ``input`` is neither a string nor a list
        (raised by ``_coerce_input``).
    """
    # Normalize the input once, then dispatch on whether structured
    # output was requested.
    messages = self._coerce_input(input)
    if output_type is not None:
        return await self._run_structured(messages, output_type)
    return await self._run_completion(messages)

@staticmethod
def _coerce_input(input: Any) -> List[LDMessage]:
if isinstance(input, str):
return [LDMessage(role='user', content=input)]
if isinstance(input, list):
return input
raise TypeError(
f"Unsupported input type for LangChainModelRunner.run: {type(input).__name__}"
)

async def _run_completion(self, messages: List[LDMessage]) -> RunnerResult:
try:
langchain_messages = convert_messages_to_langchain(messages)
response: BaseMessage = await self._llm.ainvoke(langchain_messages)
Expand All @@ -52,58 +78,58 @@ async def invoke_model(self, messages: List[LDMessage]) -> ModelResponse:
f'Multimodal response not supported, expecting a string. '
f'Content type: {type(response.content)}, Content: {response.content}'
)
metrics = LDAIMetrics(success=False, usage=metrics.usage)
return RunnerResult(
content='',
metrics=LDAIMetrics(success=False, usage=metrics.usage),
raw=response,
)

return ModelResponse(
message=LDMessage(role='assistant', content=content),
metrics=metrics,
)
return RunnerResult(content=content, metrics=metrics, raw=response)
except Exception as error:
log.warning(f'LangChain model invocation failed: {error}')
return ModelResponse(
message=LDMessage(role='assistant', content=''),
return RunnerResult(
content='',
metrics=LDAIMetrics(success=False, usage=None),
)

async def _run_structured(
    self, messages: List[LDMessage], response_structure: Dict[str, Any]
) -> RunnerResult:
    """
    Invoke the LangChain model with structured output support.

    Wraps the model with ``with_structured_output(..., include_raw=True)``
    so both the parsed payload and the raw model message are available.

    :param messages: Conversation as a list of :class:`LDMessage`
    :param response_structure: Dictionary (JSON schema) defining the output
    :return: :class:`RunnerResult` whose ``parsed`` holds the structured data
        on success; on any failure, a result with ``success=False`` metrics.
    """
    try:
        langchain_messages = convert_messages_to_langchain(messages)
        structured_llm = self._llm.with_structured_output(response_structure, include_raw=True)
        response = await structured_llm.ainvoke(langchain_messages)

        # include_raw=True is documented to return a dict with
        # 'raw', 'parsed' and 'parsing_error' keys; anything else is a failure.
        if not isinstance(response, dict):
            log.warning(f'Structured output did not return a dict. Got: {type(response)}')
            return RunnerResult(
                content='',
                metrics=LDAIMetrics(success=False, usage=None),
            )

        raw_response = response.get('raw')
        usage = get_ai_usage_from_response(raw_response) if raw_response is not None else None
        raw_content = raw_response.content if raw_response is not None and hasattr(raw_response, 'content') else ''

        if response.get('parsing_error'):
            # Surface the raw text and usage even when parsing failed, so
            # callers can inspect/track what the model actually produced.
            log.warning('LangChain structured model invocation had a parsing error')
            return RunnerResult(
                content=raw_content,
                metrics=LDAIMetrics(success=False, usage=usage),
                raw=raw_response,
            )

        parsed = response.get('parsed') or {}
        return RunnerResult(
            content=raw_content,
            metrics=LDAIMetrics(success=True, usage=usage),
            raw=raw_response,
            parsed=parsed,
        )
    except Exception as error:
        log.warning(f'LangChain structured model invocation failed: {error}')
        return RunnerResult(
            content='',
            metrics=LDAIMetrics(success=False, usage=None),
        )

Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,10 @@ async def run(self, input: Any) -> AgentGraphResult:
messages = result.get('messages', [])
output = extract_last_message_content(messages)

# Flush per-node metrics to LD trackers
all_eval_results = await handler.flush(self._graph, pending_eval_tasks)
# Flush per-node metrics to LD trackers; eval results are tracked
# internally and intentionally not exposed on AgentGraphResult here
# — judge dispatch is the managed layer's responsibility.
await handler.flush(self._graph, pending_eval_tasks)

tracker.track_path(handler.path)
tracker.track_duration(duration)
Expand All @@ -341,7 +343,6 @@ async def run(self, input: Any) -> AgentGraphResult:
output=output,
raw=result,
metrics=LDAIMetrics(success=True),
evaluations=all_eval_results,
)

except Exception as exc:
Expand Down
Loading
Loading