Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from ldai import log
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
from ldai.providers import AgentGraphResult, AgentGraphRunner, ToolRegistry
from ldai.providers.types import LDAIMetrics
from ldai.providers import AgentGraphRunner, ToolRegistry
from ldai.providers.types import AgentGraphRunnerResult, GraphMetrics, LDAIMetrics
from ldai.tracker import TokenUsage

from ldai_openai.openai_helper import (
Expand All @@ -22,6 +22,34 @@ def _sanitize_agent_name(key: str) -> str:
return re.sub(r'[^a-zA-Z0-9_]', '_', key)


class _NodeMetricsAccumulator:
    """Accumulates per-node metrics over the course of a single run.

    Stands in for the removed per-node LDAIConfigTracker: callbacks record
    token usage, segment duration, and tool-call names here, and the final
    snapshot is exported via :meth:`to_ldai_metrics`.
    """

    def __init__(self) -> None:
        # Last-recorded token usage for this node (None until first recorded).
        self.usage: Optional[TokenUsage] = None
        # Duration of this node's active segment, in milliseconds.
        self.duration_ms: Optional[int] = None
        # Names of tools invoked by this node, in call order.
        self.tool_calls: List[str] = []
        # A run is considered successful unless a caller flips this flag.
        self.success: bool = True

    def set_usage(self, usage: Optional[TokenUsage]) -> None:
        """Record token usage; a ``None`` argument keeps the prior value."""
        if usage is None:
            return
        self.usage = usage

    def set_duration_ms(self, duration_ms: int) -> None:
        """Record this node's segment duration in milliseconds."""
        self.duration_ms = duration_ms

    def add_tool_call(self, tool_name: str) -> None:
        """Append a tool name to this node's tool-call list."""
        self.tool_calls.append(tool_name)

    def to_ldai_metrics(self) -> LDAIMetrics:
        """Export the accumulated state as an :class:`LDAIMetrics` snapshot."""
        # An empty tool-call list is reported as None rather than [].
        recorded_calls = self.tool_calls or None
        return LDAIMetrics(
            success=self.success,
            usage=self.usage,
            duration_ms=self.duration_ms,
            tool_calls=recorded_calls,
        )


class _RunState:
"""Mutable state shared across handoff and tool callbacks during a single run."""

Expand All @@ -39,9 +67,10 @@ class OpenAIAgentGraphRunner(AgentGraphRunner):

AgentGraphRunner implementation for the OpenAI Agents SDK.

Runs the agent graph with the OpenAI Agents SDK and automatically records
graph- and node-level AI metric data to the LaunchDarkly trackers on the
graph definition and each node.
Runs the agent graph with the OpenAI Agents SDK and collects graph- and
node-level metrics. Tracking events are emitted by the managed layer
(:class:`~ldai.ManagedAgentGraph`) from the returned
:class:`~ldai.providers.types.AgentGraphRunnerResult`.

Requires ``openai-agents`` to be installed.
"""
Expand All @@ -61,20 +90,19 @@ def __init__(
self._tools = tools
self._agent_name_map: Dict[str, str] = {}
self._tool_name_map: Dict[str, str] = {}
self._node_trackers: Dict[str, Any] = {}
self._node_accumulators: Dict[str, _NodeMetricsAccumulator] = {}

async def run(self, input: Any) -> AgentGraphResult:
async def run(self, input: Any) -> AgentGraphRunnerResult:
"""
Run the agent graph with the given input.

Builds the agent tree via reverse_traverse, then invokes the root
agent with Runner.run(). Tracks path, latency, and invocation
success/failure.
agent with Runner.run(). Collects path, latency, and per-node metrics.
Graph-level tracking events are emitted by the managed layer.

:param input: The string prompt to send to the agent graph
:return: AgentGraphResult with the final output and metrics
:return: AgentGraphRunnerResult with the final content and GraphMetrics
"""
tracker = self._graph.create_tracker()
path: List[str] = []
root_node = self._graph.root()
root_key = root_node.get_key() if root_node else ''
Expand All @@ -86,24 +114,29 @@ async def run(self, input: Any) -> AgentGraphResult:
state = _RunState(last_handoff_ns=start_ns, last_node_key=root_key)
try:
from agents import Runner
root_agent = self._build_agents(path, state, tracker)
root_agent = self._build_agents(path, state)
result = await Runner.run(root_agent, input_str)
self._flush_final_segment(state, result)
self._track_tool_calls(result)
self._collect_tool_calls(result)

duration = (time.perf_counter_ns() - start_ns) // 1_000_000
duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
token_usage = get_ai_usage_from_response(result)

tracker.track_path(path)
tracker.track_duration(duration)
tracker.track_invocation_success()
if token_usage is not None:
tracker.track_total_tokens(token_usage)
node_metrics = {
key: acc.to_ldai_metrics()
for key, acc in self._node_accumulators.items()
}

return AgentGraphResult(
output=str(result.final_output),
return AgentGraphRunnerResult(
content=str(result.final_output),
raw=result,
metrics=LDAIMetrics(success=True, usage=token_usage),
metrics=GraphMetrics(
success=True,
path=path,
duration_ms=duration_ms,
usage=token_usage,
node_metrics=node_metrics,
),
)
except Exception as exc:
if isinstance(exc, ImportError):
Expand All @@ -113,17 +146,19 @@ async def run(self, input: Any) -> AgentGraphResult:
)
else:
log.warning(f'OpenAIAgentGraphRunner run failed: {exc}')
duration = (time.perf_counter_ns() - start_ns) // 1_000_000
tracker.track_duration(duration)
tracker.track_invocation_failure()
return AgentGraphResult(
output='',
duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
return AgentGraphRunnerResult(
content='',
raw=None,
metrics=LDAIMetrics(success=False),
metrics=GraphMetrics(
success=False,
path=path,
duration_ms=duration_ms,
),
)

def _build_agents(
self, path: List[str], state: _RunState, tracker: Any
self, path: List[str], state: _RunState
) -> Any:
"""
Build the agent tree from the graph definition via reverse_traverse.
Expand All @@ -133,7 +168,6 @@ def _build_agents(

:param path: Mutable list to accumulate the execution path
:param state: Shared run state for tracking handoff timing and last node
:param tracker: Graph-level tracker shared across the entire run
:return: The root Agent instance
"""
try:
Expand All @@ -151,12 +185,12 @@ def _build_agents(

name_map: Dict[str, str] = {}
tool_name_map: Dict[str, str] = {}
node_trackers: Dict[str, Any] = {}
node_accumulators: Dict[str, _NodeMetricsAccumulator] = {}

def build_node(node: AgentGraphNode, ctx: dict) -> Any:
node_config = node.get_config()
config_tracker = node_config.create_tracker()
node_trackers[node_config.key] = config_tracker
acc = _NodeMetricsAccumulator()
node_accumulators[node_config.key] = acc
model = node_config.model

if not model:
Expand All @@ -177,8 +211,7 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
node_config.key,
target_key,
path,
tracker,
config_tracker,
acc,
state,
),
)
Expand Down Expand Up @@ -212,20 +245,19 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
root = self._graph.reverse_traverse(fn=build_node)
self._agent_name_map = name_map
self._tool_name_map = tool_name_map
self._node_trackers = node_trackers
self._node_accumulators = node_accumulators
return root

def _make_on_handoff(
self,
src: str,
tgt: str,
path: List[str],
tracker: Any,
config_tracker: Any,
acc: _NodeMetricsAccumulator,
state: _RunState,
):
def on_handoff(run_ctx: Any) -> None:
self._handle_handoff(run_ctx, src, tgt, path, tracker, config_tracker, state)
self._handle_handoff(run_ctx, src, tgt, path, acc, state)
return on_handoff

def _handle_handoff(
Expand All @@ -234,13 +266,11 @@ def _handle_handoff(
src: str,
tgt: str,
path: List[str],
tracker: Any,
config_tracker: Any,
acc: _NodeMetricsAccumulator,
state: _RunState,
) -> None:
path.append(tgt)
state.last_node_key = tgt
tracker.track_handoff_success(src, tgt)

now_ns = time.perf_counter_ns()
duration_ms = (now_ns - state.last_handoff_ns) // 1_000_000
Expand All @@ -254,19 +284,15 @@ def _handle_handoff(
except Exception:
pass

if config_tracker is not None:
if usage is not None:
config_tracker.track_tokens(usage)
if duration_ms is not None:
config_tracker.track_duration(int(duration_ms))
config_tracker.track_success()
acc.set_usage(usage)
acc.set_duration_ms(int(duration_ms))

def _flush_final_segment(self, state: _RunState, result: Any) -> None:
"""Record duration/tokens for the last active agent (no handoff after it)."""
if not state.last_node_key:
return
config_tracker = self._node_trackers.get(state.last_node_key)
if config_tracker is None:
acc = self._node_accumulators.get(state.last_node_key)
if acc is None:
return

now_ns = time.perf_counter_ns()
Expand All @@ -280,18 +306,16 @@ def _flush_final_segment(self, state: _RunState, result: Any) -> None:
except Exception:
pass

if usage is not None:
config_tracker.track_tokens(usage)
config_tracker.track_duration(int(duration_ms))
config_tracker.track_success()
acc.set_usage(usage)
acc.set_duration_ms(int(duration_ms))

def _track_tool_calls(self, result: Any) -> None:
"""Track all tool calls from the run result, attributed to the node that called them."""
def _collect_tool_calls(self, result: Any) -> None:
"""Collect all tool calls from the run result, attributed to the node that called them."""
for agent_name, tool_fn_name in get_tool_calls_from_run_items(result.new_items):
agent_key = self._agent_name_map.get(agent_name, agent_name)
tool_name = self._tool_name_map.get(tool_fn_name)
if tool_name is None:
continue
config_tracker = self._node_trackers.get(agent_key)
if config_tracker is not None:
config_tracker.track_tool_call(tool_name)
acc = self._node_accumulators.get(agent_key)
if acc is not None:
acc.add_tool_call(tool_name)
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@

from ldai.agent_graph import AgentGraphDefinition
from ldai.models import AIAgentGraphConfig, AIAgentConfig, Edge, ModelConfig, ProviderConfig
from ldai.providers import AgentGraphResult, ToolRegistry
from ldai.providers import ToolRegistry
from ldai.providers.types import AgentGraphRunnerResult, GraphMetrics
from ldai_openai.openai_agent_graph_runner import OpenAIAgentGraphRunner
from ldai_openai.openai_runner_factory import OpenAIRunnerFactory
from ldai.evaluator import Evaluator


def _make_graph(enabled: bool = True) -> AgentGraphDefinition:
"""Build a minimal single-node AgentGraphDefinition for testing."""
node_tracker = MagicMock()
graph_tracker = MagicMock()
node_factory = MagicMock(return_value=node_tracker)
graph_factory = MagicMock(return_value=graph_tracker)
node_factory = MagicMock()
graph_factory = MagicMock()
root_config = AIAgentConfig(
key='root-agent',
enabled=enabled,
Expand Down Expand Up @@ -73,41 +72,44 @@ def test_openai_agent_graph_runner_stores_graph_and_tools():

@pytest.mark.asyncio
async def test_openai_agent_graph_runner_run_raises_when_agents_not_installed():
"""Import failure returns AgentGraphRunnerResult with success=False."""
graph = _make_graph()
runner = OpenAIAgentGraphRunner(graph, {})

with patch.dict('sys.modules', {'agents': None}):
# The import inside run() will fail — runner should return failure result
# rather than propagate the ImportError, since it's caught by the except block
result = await runner.run("test input")
assert isinstance(result, AgentGraphResult)
assert isinstance(result, AgentGraphRunnerResult)
assert result.metrics.success is False


@pytest.mark.asyncio
async def test_openai_agent_graph_runner_run_tracks_invocation_failure_on_exception():
async def test_openai_agent_graph_runner_run_failure_returns_metrics():
"""On import failure, returned GraphMetrics has success=False (no tracker needed)."""
graph = _make_graph()
tracker = graph.create_tracker.return_value
runner = OpenAIAgentGraphRunner(graph, {})

with patch.dict('sys.modules', {'agents': None}):
result = await runner.run("fail")

assert isinstance(result, AgentGraphRunnerResult)
assert result.metrics.success is False
tracker.track_invocation_failure.assert_called_once()
tracker.track_duration.assert_called_once()
assert result.metrics.duration_ms is not None
# Runner no longer calls graph tracker — graph.create_tracker should NOT be called
graph.create_tracker.assert_not_called()


@pytest.mark.asyncio
async def test_openai_agent_graph_runner_run_success():
"""Successful run returns AgentGraphRunnerResult with populated GraphMetrics."""
graph = _make_graph()
tracker = graph.create_tracker.return_value

mock_result = MagicMock()
mock_result.final_output = "agent answer"
mock_result.context_wrapper.usage.total_tokens = 0
mock_result.context_wrapper.usage.input_tokens = 0
mock_result.context_wrapper.usage.output_tokens = 0
mock_result.new_items = []
mock_result.context_wrapper.usage.total_tokens = 10
mock_result.context_wrapper.usage.input_tokens = 5
mock_result.context_wrapper.usage.output_tokens = 5
mock_result.context_wrapper.usage.request_usage_entries = []

mock_runner_module = MagicMock()
mock_runner_module.run = AsyncMock(return_value=mock_result)
Expand Down Expand Up @@ -135,28 +137,19 @@ async def test_openai_agent_graph_runner_run_success():
runner = OpenAIAgentGraphRunner(graph, {})
result = await runner.run("find restaurants")

assert isinstance(result, AgentGraphResult)
assert result.output == "agent answer"
assert isinstance(result, AgentGraphRunnerResult)
assert result.content == "agent answer"
assert isinstance(result.metrics, GraphMetrics)
assert result.metrics.success is True
tracker.track_invocation_success.assert_called_once()
tracker.track_path.assert_called_once()
tracker.track_duration.assert_called_once()
assert result.metrics.duration_ms is not None
assert 'root-agent' in result.metrics.path

# The runner caches one tracker per node — verify it is the same instance
# returned by create_tracker() and that all tracking calls hit it.
node_factory = graph.get_node('root-agent').get_config().create_tracker

# The runner caches one tracker per node — verify it is the same instance
# returned by create_tracker and that all tracking calls hit it.
cached = runner._node_trackers['root-agent']
assert cached is node_factory.return_value
cached.track_duration.assert_called_once()
cached.track_tokens.assert_called_once()
cached.track_success.assert_called_once()
# Runner no longer creates or calls the graph tracker
graph.create_tracker.assert_not_called()

# Graph-level create_tracker is called exactly once per run (not twice)
# so that handoff callbacks and run() share the same tracker instance.
graph.create_tracker.assert_called_once()
# Runner no longer creates per-node LDAIConfigTracker instances
node_factory = graph.get_node('root-agent').get_config().create_tracker
node_factory.assert_not_called()

# Node-level create_tracker is called exactly once per node.
node_factory.assert_called_once()
# Runner accumulates per-node metrics in _node_accumulators
assert 'root-agent' in runner._node_accumulators
Loading
Loading