diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py index 158af79c..7c7d7006 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py @@ -260,7 +260,7 @@ def route(state: WorkflowState) -> str: self._graph.traverse(fn=handle_traversal) - tracker = self._graph.get_tracker() + tracker = self._graph.create_tracker() if self._graph.create_tracker is not None else None graph_key_str = tracker.graph_key if tracker else 'unknown' log.debug( f"LangGraphAgentGraphRunner: graph='{graph_key_str}', root='{root_key}', " @@ -281,7 +281,7 @@ async def run(self, input: Any) -> AgentGraphResult: :param input: The string prompt to send to the agent graph :return: AgentGraphResult with the final output and metrics """ - tracker = self._graph.get_tracker() + tracker = self._graph.create_tracker() if self._graph.create_tracker is not None else None start_ns = time.perf_counter_ns() try: diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py index c4b1d8c9..f239a88d 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py @@ -200,7 +200,7 @@ def flush(self, graph: AgentGraphDefinition) -> None: node = graph.get_node(node_key) if not node: continue - config_tracker = node.get_config().tracker + config_tracker = node.get_config().create_tracker() if not config_tracker: continue diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py b/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py index 28157a71..594d7040 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py @@ -536,7 +536,6 @@ def sync_tool(x: str = '') -> str: ), provider=ProviderConfig(name='openai'), instructions='', - tracker=MagicMock(), ) tools = build_structured_tools(cfg, {'my_tool': sync_tool}) assert len(tools) == 1 @@ -559,7 +558,6 @@ async def async_tool(x: str = '') -> str: ), provider=ProviderConfig(name='openai'), instructions='', - tracker=MagicMock(), ) tools = build_structured_tools(cfg, {'my_tool': async_tool}) assert len(tools) == 1 diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py index 07802cb2..1583a743 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py @@ -11,13 +11,13 @@ def _make_graph(enabled: bool = True) -> AgentGraphDefinition: + graph_tracker = MagicMock() root_config = AIAgentConfig( key='root-agent', enabled=enabled, model=ModelConfig(name='gpt-4'), provider=ProviderConfig(name='openai'), instructions='You are a helpful assistant.', - tracker=MagicMock(), ) graph_config = AIAgentGraphConfig( key='test-graph', @@ -31,7 +31,7 @@ def _make_graph(enabled: bool = True) -> AgentGraphDefinition: nodes=nodes, 
context=MagicMock(), enabled=enabled, - tracker=MagicMock(), + create_tracker=lambda: graph_tracker, ) @@ -78,7 +78,7 @@ async def test_langgraph_runner_run_raises_when_langgraph_not_installed(): @pytest.mark.asyncio async def test_langgraph_runner_run_tracks_failure_on_exception(): graph = _make_graph() - tracker = graph.get_tracker() + tracker = graph.create_tracker() runner = LangGraphAgentGraphRunner(graph, {}) with patch.dict('sys.modules', {'langgraph': None, 'langgraph.graph': None}): @@ -92,7 +92,7 @@ async def test_langgraph_runner_run_tracks_failure_on_exception(): @pytest.mark.asyncio async def test_langgraph_runner_run_success(): graph = _make_graph() - tracker = graph.get_tracker() + tracker = graph.create_tracker() mock_message = MagicMock() mock_message.content = "langgraph answer" diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py index 8b644995..4db150c4 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py @@ -35,6 +35,7 @@ def _make_graph(mock_ld_client: MagicMock, node_key: str = 'root-agent', graph_k model_name='gpt-4', provider_name='openai', context=context, + run_id='test-run-id', graph_key=graph_key, ) graph_tracker = AIGraphTracker( @@ -50,7 +51,7 @@ def _make_graph(mock_ld_client: MagicMock, node_key: str = 'root-agent', graph_k model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='Be helpful.', - tracker=node_tracker, + create_tracker=lambda: node_tracker, ) graph_config = AIAgentGraphConfig( key=graph_key, @@ -64,7 +65,7 @@ def _make_graph(mock_ld_client: MagicMock, node_key: str = 'root-agent', graph_k nodes=nodes, context=context, enabled=True, - tracker=graph_tracker, + create_tracker=lambda: graph_tracker, ) @@ -320,7 +321,7 @@ def test_flush_emits_token_events_to_ld_tracker(): """flush() calls track_tokens on the node's config tracker.""" mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client, node_key='root-agent', graph_key='g1') - tracker = graph.get_tracker() + tracker = graph.create_tracker() handler = LDMetricsCallbackHandler({'root-agent'}, {}) node_run_id = uuid4() @@ -339,7 +340,7 @@ def test_flush_emits_duration(): """flush() calls track_duration when duration was recorded.""" mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client) - tracker = graph.get_tracker() + tracker = graph.create_tracker() handler = LDMetricsCallbackHandler({'root-agent'}, {}) run_id = uuid4() @@ -355,7 +356,7 @@ def test_flush_emits_tool_calls(): """flush() calls track_tool_call for each recorded tool invocation.""" mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client) - tracker = graph.get_tracker() + tracker = graph.create_tracker() handler = LDMetricsCallbackHandler({'root-agent'}, {'fn_search': 'search'}) # The agent node must be started first so it appears in the path for flush() @@ -377,7 +378,7 @@ def test_flush_includes_graph_key_in_node_events(): """flush() passes graph_key to the node tracker so graphKey appears in events.""" mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client, graph_key='my-graph') - tracker = graph.get_tracker() + tracker = graph.create_tracker() handler = LDMetricsCallbackHandler({'root-agent'}, {}) node_run_id = uuid4() @@ -402,6 +403,7 @@ def test_flush_with_no_graph_key_on_node_tracker(): 
model_name='gpt-4', provider_name='openai', context=context, + run_id='test-run-id', ) node_config = AIAgentConfig( key='root-agent', @@ -409,7 +411,7 @@ def test_flush_with_no_graph_key_on_node_tracker(): model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='Be helpful.', - tracker=node_tracker, + create_tracker=lambda: node_tracker, ) graph_config = AIAgentGraphConfig( key='test-graph', @@ -423,7 +425,7 @@ def test_flush_with_no_graph_key_on_node_tracker(): nodes=nodes, context=context, enabled=True, - tracker=None, + create_tracker=lambda: None, ) handler = LDMetricsCallbackHandler({'root-agent'}, {}) @@ -441,7 +443,7 @@ def test_flush_skips_nodes_not_in_path(): """flush() only emits events for nodes that were actually executed.""" mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client) - tracker = graph.get_tracker() + tracker = graph.create_tracker() # Handler with 'root-agent' in node_keys but never started handler = LDMetricsCallbackHandler({'root-agent'}, {}) @@ -463,7 +465,6 @@ def test_flush_skips_node_without_tracker(): model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='', - tracker=None, ) graph_config = AIAgentGraphConfig( key='g', root_config_key='no-track', edges=[], enabled=True @@ -474,7 +475,7 @@ def test_flush_skips_node_without_tracker(): nodes=nodes, context=context, enabled=True, - tracker=None, + create_tracker=lambda: None, ) handler = LDMetricsCallbackHandler({'no-track'}, {}) diff --git a/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py b/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py index 948d58d6..2e598ade 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py @@ -45,6 +45,7 @@ def _make_graph( model_name='gpt-4', provider_name='openai', context=context, + run_id='test-run-id', graph_key=graph_key, ) @@ -68,7 +69,7 @@ def _make_graph( model=ModelConfig(name='gpt-4', parameters={'tools': tool_defs} if tool_defs else {}), provider=ProviderConfig(name='openai'), instructions='You are a helpful assistant.', - tracker=node_tracker, + create_tracker=lambda: node_tracker, ) graph_config = AIAgentGraphConfig( @@ -84,7 +85,7 @@ def _make_graph( nodes=nodes, context=context, enabled=True, - tracker=graph_tracker, + create_tracker=lambda: graph_tracker, ) @@ -142,6 +143,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition': model_name='gpt-4', provider_name='openai', context=context, + run_id='test-run-id', graph_key='two-node-graph', ) child_tracker = LDAIConfigTracker( @@ -152,6 +154,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition': model_name='gpt-4', provider_name='openai', context=context, + run_id='test-run-id', graph_key='two-node-graph', ) graph_tracker = AIGraphTracker( @@ -168,7 +171,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition': model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You are root.', - tracker=root_tracker, + create_tracker=lambda: root_tracker, ) child_config = AIAgentConfig( key='child-agent', @@ -176,7 +179,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition': model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You are child.', - 
tracker=child_tracker, + create_tracker=lambda: child_tracker, ) edge = Edge(key='root-to-child', source_config='root-agent', target_config='child-agent') @@ -196,7 +199,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition': nodes=nodes, context=context, enabled=True, - tracker=graph_tracker, + create_tracker=lambda: graph_tracker, ) @@ -228,7 +231,7 @@ async def test_tracks_node_and_graph_tokens_on_success(): # (mock models don't fire LangChain callbacks, so we test flush directly) mock_ld_client2 = MagicMock() graph2 = _make_graph(mock_ld_client2) - tracker2 = graph2.get_tracker() + tracker2 = graph2.create_tracker() handler = LDMetricsCallbackHandler({'root-agent'}, {}) node_run_id = uuid4() @@ -308,7 +311,7 @@ def get_weather(location: str = 'NYC') -> str: # Simulate tool call tracking via the callback handler directly mock_ld_client2 = MagicMock() graph2 = _make_graph(mock_ld_client2, tool_names=['get_weather']) - tracker2 = graph2.get_tracker() + tracker2 = graph2.create_tracker() handler = LDMetricsCallbackHandler({'root-agent'}, {'get_weather': 'get_weather'}) # Agent node must appear in path for flush() to emit its events @@ -360,7 +363,7 @@ def summarize(text: str = '') -> str: # Simulate multiple tool calls via the callback handler directly mock_ld_client2 = MagicMock() graph2 = _make_graph(mock_ld_client2, tool_names=['search', 'summarize']) - tracker2 = graph2.get_tracker() + tracker2 = graph2.create_tracker() fn_map = {'search': 'search', 'summarize': 'summarize'} handler = LDMetricsCallbackHandler({'root-agent'}, fn_map) @@ -388,7 +391,7 @@ async def test_tracks_graph_key_on_node_events(): mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client, graph_key='my-graph') - tracker = graph.get_tracker() + tracker = graph.create_tracker() handler = LDMetricsCallbackHandler({'root-agent'}, {}) node_run_id = uuid4() @@ -461,7 +464,7 @@ def model_factory(node_config, **kwargs): # Simulate per-node token events via callback handler (mock models don't fire callbacks) mock_ld_client2 = MagicMock() graph2 = _make_two_node_graph(mock_ld_client2) - tracker2 = graph2.get_tracker() + tracker2 = graph2.create_tracker() handler = LDMetricsCallbackHandler({'root-agent', 'child-agent'}, {}) @@ -514,6 +517,7 @@ def _make_multi_child_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition' def _node_tracker(key: str) -> LDAIConfigTracker: return LDAIConfigTracker( ld_client=mock_ld_client, + run_id='test-run-id', variation_key='test-variation', config_key=key, version=1, @@ -538,7 +542,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='Route to the appropriate specialist agent.', - tracker=_node_tracker('orchestrator'), + create_tracker=lambda: _node_tracker('orchestrator'), ), 'agent-a': AIAgentConfig( key='agent-a', @@ -546,7 +550,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You handle topic A.', - tracker=_node_tracker('agent-a'), + create_tracker=lambda: _node_tracker('agent-a'), ), 'agent-b': AIAgentConfig( key='agent-b', @@ -554,7 +558,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You handle topic B.', - tracker=_node_tracker('agent-b'), + create_tracker=lambda: _node_tracker('agent-b'), ), } @@ -574,7 +578,7 @@ def 
_node_tracker(key: str) -> LDAIConfigTracker: nodes=nodes, context=context, enabled=True, - tracker=graph_tracker, + create_tracker=lambda: graph_tracker, ) @@ -625,6 +629,7 @@ def _make_multi_child_graph_with_tools(mock_ld_client: MagicMock, tool_names: li def _node_tracker(key: str) -> LDAIConfigTracker: return LDAIConfigTracker( ld_client=mock_ld_client, + run_id='test-run-id', variation_key='test-variation', config_key=key, version=1, @@ -650,7 +655,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: model=ModelConfig(name='gpt-4', parameters={'tools': tool_defs}), provider=ProviderConfig(name='openai'), instructions='Route to a specialist after gathering info.', - tracker=_node_tracker('orchestrator'), + create_tracker=lambda: _node_tracker('orchestrator'), ), 'agent-a': AIAgentConfig( key='agent-a', @@ -658,7 +663,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You handle topic A.', - tracker=_node_tracker('agent-a'), + create_tracker=lambda: _node_tracker('agent-a'), ), 'agent-b': AIAgentConfig( key='agent-b', @@ -666,7 +671,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You handle topic B.', - tracker=_node_tracker('agent-b'), + create_tracker=lambda: _node_tracker('agent-b'), ), } @@ -686,7 +691,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: nodes=nodes, context=context, enabled=True, - tracker=graph_tracker, + create_tracker=lambda: graph_tracker, ) diff --git a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py index 3cfd595b..07a7b3dd 100644 --- a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py @@ -57,6 +57,7 @@ def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry): self._tools = tools self._agent_name_map: Dict[str, str] = {} self._tool_name_map: Dict[str, str] = {} + self._node_trackers: Dict[str, Any] = {} async def run(self, input: Any) -> AgentGraphResult: """ @@ -69,7 +70,7 @@ async def run(self, input: Any) -> AgentGraphResult: :param input: The string prompt to send to the agent graph :return: AgentGraphResult with the final output and metrics """ - tracker = self._graph.get_tracker() + tracker = self._graph.create_tracker() if self._graph.create_tracker is not None else None path: List[str] = [] root_node = self._graph.root() root_key = root_node.get_key() if root_node else '' @@ -80,7 +81,7 @@ async def run(self, input: Any) -> AgentGraphResult: state = _RunState(last_handoff_ns=start_ns, last_node_key=root_key) try: from agents import Runner - root_agent = self._build_agents(path, state) + root_agent = self._build_agents(path, state, tracker) result = await Runner.run(root_agent, str(input)) self._flush_final_segment(state, result) self._track_tool_calls(result) @@ -118,7 +119,9 @@ async def run(self, input: Any) -> AgentGraphResult: metrics=LDAIMetrics(success=False), ) - def _build_agents(self, path: List[str], state: _RunState) -> Any: + def _build_agents( + self, path: List[str], state: _RunState, tracker: Any + ) -> Any: """ Build the agent tree from the graph definition via reverse_traverse. 
@@ -127,6 +130,7 @@ def _build_agents(self, path: List[str], state: _RunState) -> Any: :param path: Mutable list to accumulate the execution path :param state: Shared run state for tracking handoff timing and last node + :param tracker: Graph-level tracker shared across the entire run :return: The root Agent instance """ try: @@ -142,13 +146,14 @@ def _build_agents(self, path: List[str], state: _RunState) -> Any: "Install it with: pip install openai-agents" ) from exc - tracker = self._graph.get_tracker() name_map: Dict[str, str] = {} tool_name_map: Dict[str, str] = {} + node_trackers: Dict[str, Any] = {} def build_node(node: AgentGraphNode, ctx: dict) -> Any: node_config = node.get_config() - config_tracker = node_config.tracker + config_tracker = node_config.create_tracker() + node_trackers[node_config.key] = config_tracker model = node_config.model if not model: @@ -204,6 +209,7 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any: root = self._graph.reverse_traverse(fn=build_node) self._agent_name_map = name_map self._tool_name_map = tool_name_map + self._node_trackers = node_trackers return root def _make_on_handoff( @@ -263,10 +269,7 @@ def _flush_final_segment( """Record duration/tokens for the last active agent (no handoff after it).""" if not state.last_node_key: return - node = self._graph.get_node(state.last_node_key) - if node is None: - return - config_tracker = node.get_config().tracker + config_tracker = self._node_trackers.get(state.last_node_key) if config_tracker is None: return @@ -293,9 +296,6 @@ def _track_tool_calls(self, result: Any) -> None: tool_name = self._tool_name_map.get(tool_fn_name) if tool_name is None: continue - node = self._graph.get_node(agent_key) - if node is None: - continue - config_tracker = node.get_config().tracker + config_tracker = self._node_trackers.get(agent_key) if config_tracker is not None: config_tracker.track_tool_call(tool_name) diff --git a/packages/ai-providers/server-ai-openai/tests/test_openai_agent_graph_runner.py b/packages/ai-providers/server-ai-openai/tests/test_openai_agent_graph_runner.py index 56dd0c10..e4a7721e 100644 --- a/packages/ai-providers/server-ai-openai/tests/test_openai_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-openai/tests/test_openai_agent_graph_runner.py @@ -12,13 +12,17 @@ def _make_graph(enabled: bool = True) -> AgentGraphDefinition: """Build a minimal single-node AgentGraphDefinition for testing.""" + node_tracker = MagicMock() + graph_tracker = MagicMock() + node_factory = MagicMock(return_value=node_tracker) + graph_factory = MagicMock(return_value=graph_tracker) root_config = AIAgentConfig( key='root-agent', enabled=enabled, model=ModelConfig(name='gpt-4'), provider=ProviderConfig(name='openai'), instructions='You are a helpful assistant.', - tracker=MagicMock(), + create_tracker=node_factory, ) graph_config = AIAgentGraphConfig( key='test-graph', @@ -32,7 +36,7 @@ def _make_graph(enabled: bool = True) -> AgentGraphDefinition: nodes=nodes, context=MagicMock(), enabled=enabled, - tracker=MagicMock(), + create_tracker=graph_factory, ) @@ -81,7 +85,7 @@ async def test_openai_agent_graph_runner_run_raises_when_agents_not_installed(): @pytest.mark.asyncio async def test_openai_agent_graph_runner_run_tracks_invocation_failure_on_exception(): graph = _make_graph() - tracker = graph.get_tracker() + tracker = graph.create_tracker.return_value runner = OpenAIAgentGraphRunner(graph, {}) with patch.dict('sys.modules', {'agents': None}): @@ -95,7 +99,7 @@ async def 
test_openai_agent_graph_runner_run_tracks_invocation_failure_on_except @pytest.mark.asyncio async def test_openai_agent_graph_runner_run_success(): graph = _make_graph() - tracker = graph.get_tracker() + tracker = graph.create_tracker.return_value mock_result = MagicMock() mock_result.final_output = "agent answer" @@ -136,7 +140,21 @@ async def test_openai_agent_graph_runner_run_success(): tracker.track_path.assert_called_once() tracker.track_latency.assert_called_once() - root_tracker = graph.get_node('root-agent').get_config().tracker - root_tracker.track_duration.assert_called_once() - root_tracker.track_tokens.assert_called_once() - root_tracker.track_success.assert_called_once() + # The runner caches one tracker per node — verify it is the same instance + # returned by create_tracker() and that all tracking calls hit it. + node_factory = graph.get_node('root-agent').get_config().create_tracker + + # The runner caches one tracker per node — verify it is the same instance + # returned by create_tracker and that all tracking calls hit it. + cached = runner._node_trackers['root-agent'] + assert cached is node_factory.return_value + cached.track_duration.assert_called_once() + cached.track_tokens.assert_called_once() + cached.track_success.assert_called_once() + + # Graph-level create_tracker is called exactly once per run (not twice) + # so that handoff callbacks and run() share the same tracker instance. + graph.create_tracker.assert_called_once() + + # Node-level create_tracker is called exactly once per node. + node_factory.assert_called_once() diff --git a/packages/ai-providers/server-ai-openai/tests/test_tracking_openai_agents.py b/packages/ai-providers/server-ai-openai/tests/test_tracking_openai_agents.py index 39c75034..74fd0cc9 100644 --- a/packages/ai-providers/server-ai-openai/tests/test_tracking_openai_agents.py +++ b/packages/ai-providers/server-ai-openai/tests/test_tracking_openai_agents.py @@ -40,6 +40,7 @@ def _make_graph( model_name='gpt-4', provider_name='openai', context=context, + run_id='test-run-id', graph_key=graph_key, ) @@ -63,7 +64,7 @@ def _make_graph( model=ModelConfig(name='gpt-4', parameters={'tools': tool_defs} if tool_defs else {}), provider=ProviderConfig(name='openai'), instructions='You are a helpful assistant.', - tracker=node_tracker, + create_tracker=lambda: node_tracker, ) graph_config = AIAgentGraphConfig( @@ -79,7 +80,7 @@ def _make_graph( nodes=nodes, context=context, enabled=True, - tracker=graph_tracker, + create_tracker=lambda: graph_tracker, ) @@ -179,6 +180,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> AgentGraphDefinition: model_name='gpt-4', provider_name='openai', context=context, + run_id='test-run-id', graph_key='two-node-graph', ) child_tracker = LDAIConfigTracker( @@ -189,6 +191,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> AgentGraphDefinition: model_name='gpt-4', provider_name='openai', context=context, + run_id='test-run-id', graph_key='two-node-graph', ) graph_tracker = AIGraphTracker( @@ -205,7 +208,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> AgentGraphDefinition: model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You are root.', - tracker=root_tracker, + create_tracker=lambda: root_tracker, ) child_config = AIAgentConfig( key='child-agent', @@ -213,7 +216,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> AgentGraphDefinition: model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), 
instructions='You are child.', - tracker=child_tracker, + create_tracker=lambda: child_tracker, ) edge = Edge(key='root-to-child', source_config='root-agent', target_config='child-agent') @@ -233,7 +236,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> AgentGraphDefinition: nodes=nodes, context=context, enabled=True, - tracker=graph_tracker, + create_tracker=lambda: graph_tracker, ) @@ -356,6 +359,40 @@ async def test_tracks_multiple_tool_calls(): assert sorted(tool_keys) == ['search', 'summarize'] +@pytest.mark.asyncio +async def test_same_run_id_across_token_success_and_tool_call_events(): + """All node-level events for a single execution share the same runId.""" + mock_ld_client = MagicMock() + graph = _make_graph( + mock_ld_client, node_key='root-agent', graph_key='g', tool_names=['search'] + ) + + tool_item = _make_tool_call_item('root-agent', 'search') + run_result = _make_run_result( + output='ok', total_tokens=10, input_tokens=7, output_tokens=3, + tool_call_items=[tool_item], + ) + + with patch.dict('sys.modules', _make_agents_modules(run_result)): + runner = OpenAIAgentGraphRunner(graph, _tool_registry('search')) + await runner.run('go') + + ev = _events(mock_ld_client) + + # Collect runIds from node-level events + run_ids = set() + for event_name in ( + '$ld:ai:tokens:total', '$ld:ai:tokens:input', '$ld:ai:tokens:output', + '$ld:ai:generation:success', '$ld:ai:generation:duration', '$ld:ai:tool_call', + ): + for data, _ in ev.get(event_name, []): + if data.get('configKey') == 'root-agent': + run_ids.add(data['runId']) + + # All events must share a single runId + assert len(run_ids) == 1 + + @pytest.mark.asyncio async def test_does_not_track_tool_calls_without_graph_and_registry_config(): """RunResult tool items that are not backed by graph + registry tools are ignored.""" diff --git a/packages/sdk/server-ai/src/ldai/agent_graph/__init__.py b/packages/sdk/server-ai/src/ldai/agent_graph/__init__.py index c19552d1..0e28267e 100644 --- a/packages/sdk/server-ai/src/ldai/agent_graph/__init__.py +++ b/packages/sdk/server-ai/src/ldai/agent_graph/__init__.py @@ -54,21 +54,13 @@ def __init__( nodes: Dict[str, AgentGraphNode], context: Context, enabled: bool, - tracker: Optional[AIGraphTracker] = None, + create_tracker: Optional[Callable[[], AIGraphTracker]] = None, ): self._agent_graph = agent_graph self._context = context self._nodes = nodes self.enabled = enabled - self._tracker = tracker - - def get_tracker(self) -> Optional[AIGraphTracker]: - """ - Get the graph tracker for this graph definition. - - :return: The AIGraphTracker instance, or None if not available. 
- """ - return self._tracker + self.create_tracker = create_tracker def is_enabled(self) -> bool: """Check if the graph is enabled.""" diff --git a/packages/sdk/server-ai/src/ldai/client.py b/packages/sdk/server-ai/src/ldai/client.py index d47e00bf..72ccb493 100644 --- a/packages/sdk/server-ai/src/ldai/client.py +++ b/packages/sdk/server-ai/src/ldai/client.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List, Optional, Tuple +import uuid +from typing import Any, Callable, Dict, List, Optional, Tuple import chevron -from ldclient import Context +from ldclient import Context, Result from ldclient.client import LDClient from ldai import log @@ -44,6 +45,10 @@ _INIT_TRACK_CONTEXT = Context.builder('ld-internal-tracking').kind('ld_ai').anonymous(True).build() +_DISABLED_COMPLETION_DEFAULT = AICompletionConfigDefault(enabled=False) +_DISABLED_AGENT_DEFAULT = AIAgentConfigDefault(enabled=False) +_DISABLED_JUDGE_DEFAULT = AIJudgeConfigDefault(enabled=False) + class LDAIClient: """The LaunchDarkly AI SDK client object.""" @@ -61,6 +66,21 @@ def __init__(self, client: LDClient): 1, ) + def create_tracker(self, token: str, context: Context) -> Result: + """ + Reconstruct a tracker from a resumption token. + + Delegates to :meth:`LDAIConfigTracker.from_resumption_token`. + + :param token: A URL-safe Base64-encoded resumption token obtained from + :attr:`LDAIConfigTracker.resumption_token`. + :param context: The context to use for track events. + :return: A :class:`Result` whose ``value`` is a new + :class:`LDAIConfigTracker` on success, or whose ``error`` describes + the problem on failure. + """ + return LDAIConfigTracker.from_resumption_token(token, self._client, context) + def _completion_config( self, key: str, @@ -68,7 +88,8 @@ def _completion_config( default: AICompletionConfigDefault, variables: Optional[Dict[str, Any]] = None, ) -> AICompletionConfig: - model, provider, messages, instructions, tracker, enabled, judge_configuration, _ = self.__evaluate( + (model, provider, messages, instructions, + tracker_factory, enabled, judge_configuration, _) = self.__evaluate( key, context, default.to_dict(), variables ) @@ -78,7 +99,7 @@ def _completion_config( model=model, messages=messages, provider=provider, - tracker=tracker, + create_tracker=tracker_factory, judge_configuration=judge_configuration, ) @@ -104,7 +125,7 @@ def completion_config( self._client.track(_TRACK_USAGE_COMPLETION_CONFIG, context, key, 1) return self._completion_config( - key, context, default or AICompletionConfigDefault.disabled(), variables + key, context, default or _DISABLED_COMPLETION_DEFAULT, variables ) def config( @@ -134,7 +155,8 @@ def _judge_config( default: AIJudgeConfigDefault, variables: Optional[Dict[str, Any]] = None, ) -> AIJudgeConfig: - model, provider, messages, instructions, tracker, enabled, judge_configuration, variation = self.__evaluate( + (model, provider, messages, instructions, + tracker_factory, enabled, judge_configuration, variation) = self.__evaluate( key, context, default.to_dict(), variables ) @@ -162,7 +184,7 @@ def _extract_evaluation_metric_key(variation: Dict[str, Any]) -> Optional[str]: model=model, messages=messages, provider=provider, - tracker=tracker, + create_tracker=tracker_factory, ) return config @@ -187,7 +209,7 @@ def judge_config( self._client.track(_TRACK_USAGE_JUDGE_CONFIG, context, key, 1) return self._judge_config( - key, context, default or AIJudgeConfigDefault.disabled(), variables + key, context, default or _DISABLED_JUDGE_DEFAULT, variables ) async def create_judge( @@ 
-245,17 +267,17 @@ async def create_judge( extended_variables['response_to_evaluate'] = '{{response_to_evaluate}}' judge_config = self._judge_config( - key, context, default or AIJudgeConfigDefault.disabled(), extended_variables + key, context, default or _DISABLED_JUDGE_DEFAULT, extended_variables ) - if not judge_config.enabled or not judge_config.tracker: + if not judge_config.enabled: return None provider = RunnerFactory.create_model(judge_config, default_ai_provider) if not provider: return None - return Judge(judge_config, judge_config.tracker, provider) + return Judge(judge_config, provider) except Exception as error: return None @@ -343,9 +365,9 @@ async def create_model( """ self._client.track(_TRACK_USAGE_CREATE_MODEL, context, key, 1) log.debug(f"Creating managed model for key: {key}") - config = self._completion_config(key, context, default or AICompletionConfigDefault.disabled(), variables) + config = self._completion_config(key, context, default or _DISABLED_COMPLETION_DEFAULT, variables) - if not config.enabled or not config.tracker: + if not config.enabled: return None runner = RunnerFactory.create_model(config, default_ai_provider) @@ -361,7 +383,7 @@ async def create_model( default_ai_provider, ) - return ManagedModel(config, config.tracker, runner, judges) + return ManagedModel(config, runner, judges) async def create_chat( self, @@ -426,16 +448,16 @@ async def create_agent( """ self._client.track(_TRACK_USAGE_CREATE_AGENT, context, key, 1) log.debug(f"Creating managed agent for key: {key}") - config = self.__evaluate_agent(key, context, default or AIAgentConfigDefault.disabled(), variables) + config = self.__evaluate_agent(key, context, default or _DISABLED_AGENT_DEFAULT, variables) - if not config.enabled or not config.tracker: + if not config.enabled: return None runner = RunnerFactory.create_agent(config, tools or {}, default_ai_provider) if not runner: return None - return ManagedAgent(config, config.tracker, runner) + return ManagedAgent(config, runner) def agent_config( self, @@ -465,7 +487,8 @@ def agent_config( if agent.enabled: research_result = agent.instructions # Interpolated instructions - agent.tracker.track_success() + tracker = agent.create_tracker() + tracker.track_success() :param key: The agent configuration key. :param context: The context to evaluate the agent configuration in. @@ -482,7 +505,7 @@ def agent_config( ) return self.__evaluate_agent( - key, context, default or AIAgentConfigDefault.disabled(), variables + key, context, default or _DISABLED_AGENT_DEFAULT, variables ) def agent( @@ -535,7 +558,8 @@ def agent_configs( ], context) research_result = agents["research_agent"].instructions - agents["research_agent"].tracker.track_success() + tracker = agents["research_agent"].create_tracker() + tracker.track_success() :param agent_configs: List of agent configurations to retrieve. :param context: The context to evaluate the agent configurations in. 
@@ -555,7 +579,7 @@ def agent_configs( agent = self.__evaluate_agent( config.key, context, - config.default or AIAgentConfigDefault.disabled(), + config.default or _DISABLED_AGENT_DEFAULT, config.variables ) result[config.key] = agent @@ -576,14 +600,15 @@ def agent_graph( variation_key = variation.get("_ldMeta", {}).get("variationKey", "") version = int(variation.get("_ldMeta", {}).get("version", 1)) - # Create graph tracker - tracker = AIGraphTracker( - self._client, - variation_key, - key, - version, - context, - ) + # Create graph tracker factory + def graph_tracker_factory() -> AIGraphTracker: + return AIGraphTracker( + self._client, + variation_key, + key, + version, + context, + ) if not variation.get("root"): log.debug(f"Agent graph {key} is disabled, no root config key found") @@ -597,7 +622,7 @@ def agent_graph( nodes={}, context=context, enabled=False, - tracker=tracker, + create_tracker=graph_tracker_factory, ) edge_keys = list[str](variation.get("edges", {}).keys()) @@ -628,7 +653,7 @@ def agent_graph( nodes={}, context=context, enabled=False, - tracker=tracker, + create_tracker=graph_tracker_factory, ) try: @@ -659,7 +684,7 @@ def agent_graph( nodes={}, context=context, enabled=False, - tracker=tracker, + create_tracker=graph_tracker_factory, ) nodes = AgentGraphDefinition.build_nodes( @@ -672,7 +697,7 @@ def agent_graph( nodes=nodes, context=context, enabled=agent_graph_config.enabled, - tracker=tracker, + create_tracker=graph_tracker_factory, ) async def create_agent_graph( @@ -727,7 +752,7 @@ async def create_agent_graph( if not runner: return None - return ManagedAgentGraph(runner, graph.get_tracker()) + return ManagedAgentGraph(runner) def agents( self, @@ -754,7 +779,7 @@ def __evaluate( graph_key: Optional[str] = None, ) -> Tuple[ Optional[ModelConfig], Optional[ProviderConfig], Optional[List[LDMessage]], - Optional[str], LDAIConfigTracker, bool, Optional[Any], Dict[str, Any] + Optional[str], Callable[[], LDAIConfigTracker], bool, Optional[Any], Dict[str, Any] ]: """ Internal method to evaluate a configuration and extract components. @@ -764,7 +789,8 @@ def __evaluate( :param default_dict: Default configuration as dictionary. :param variables: Variables for interpolation. :param graph_key: When set, passed to the tracker so all events include ``graphKey``. - :return: Tuple of (model, provider, messages, instructions, tracker, enabled, judge_configuration, variation). + :return: Tuple of (model, provider, messages, instructions, + tracker_factory, enabled, judge_configuration, variation). 
""" variation = self._client.variation(key, context, default_dict) @@ -806,16 +832,23 @@ def __evaluate( custom=custom ) - tracker = LDAIConfigTracker( - self._client, - variation.get('_ldMeta', {}).get('variationKey', ''), - key, - int(variation.get('_ldMeta', {}).get('version', 1)), - model.name if model else '', - provider_config.name if provider_config else '', - context, - graph_key=graph_key, - ) + variation_key = variation.get('_ldMeta', {}).get('variationKey', '') + version = int(variation.get('_ldMeta', {}).get('version', 1)) + model_name = model.name if model else '' + provider_name = provider_config.name if provider_config else '' + + def tracker_factory() -> LDAIConfigTracker: + return LDAIConfigTracker( + ld_client=self._client, + run_id=str(uuid.uuid4()), + config_key=key, + variation_key=variation_key, + version=version, + context=context, + model_name=model_name, + provider_name=provider_name, + graph_key=graph_key, + ) enabled = variation.get('_ldMeta', {}).get('enabled', False) @@ -834,7 +867,10 @@ def __evaluate( if judges: judge_configuration = JudgeConfiguration(judges=judges) - return model, provider_config, messages, instructions, tracker, enabled, judge_configuration, variation + return ( + model, provider_config, messages, instructions, + tracker_factory, enabled, judge_configuration, variation, + ) def __evaluate_agent( self, @@ -854,7 +890,8 @@ def __evaluate_agent( :param graph_key: When set, passed to the tracker so all events include ``graphKey``. :return: Configured AIAgentConfig instance. """ - model, provider, messages, instructions, tracker, enabled, judge_configuration, _ = self.__evaluate( + (model, provider, messages, instructions, + tracker_factory, enabled, judge_configuration, _) = self.__evaluate( key, context, default.to_dict(), variables, graph_key=graph_key ) @@ -867,7 +904,7 @@ def __evaluate_agent( model=model or default.model, provider=provider or default.provider, instructions=final_instructions, - tracker=tracker, + create_tracker=tracker_factory, judge_configuration=judge_configuration or default.judge_configuration, ) diff --git a/packages/sdk/server-ai/src/ldai/judge/__init__.py b/packages/sdk/server-ai/src/ldai/judge/__init__.py index 6db89f68..d889f17f 100644 --- a/packages/sdk/server-ai/src/ldai/judge/__init__.py +++ b/packages/sdk/server-ai/src/ldai/judge/__init__.py @@ -10,7 +10,6 @@ from ldai.models import AIJudgeConfig, LDMessage from ldai.providers.model_runner import ModelRunner from ldai.providers.types import JudgeResult, ModelResponse -from ldai.tracker import LDAIConfigTracker class Judge: @@ -24,18 +23,15 @@ class Judge: def __init__( self, ai_config: AIJudgeConfig, - ai_config_tracker: LDAIConfigTracker, model_runner: ModelRunner, ): """ Initialize the Judge. 
:param ai_config: The judge AI configuration - :param ai_config_tracker: The tracker for the judge configuration :param model_runner: The model runner to use for evaluation """ self._ai_config = ai_config - self._ai_config_tracker = ai_config_tracker self._model_runner = model_runner self._evaluation_response_structure = EvaluationSchemaBuilder.build() @@ -73,10 +69,12 @@ async def evaluate( return judge_result judge_result.sampled = True + + tracker = self._ai_config.create_tracker() messages = self._construct_evaluation_messages(input_text, output_text) assert self._evaluation_response_structure is not None - response = await self._ai_config_tracker.track_metrics_of_async( + response = await tracker.track_metrics_of_async( lambda: self._model_runner.invoke_structured_model(messages, self._evaluation_response_structure), lambda result: result.metrics, ) @@ -125,14 +123,6 @@ def get_ai_config(self) -> AIJudgeConfig: """ return self._ai_config - def get_tracker(self) -> LDAIConfigTracker: - """ - Returns the tracker associated with this judge. - - :return: The tracker for the judge configuration - """ - return self._ai_config_tracker - def get_model_runner(self) -> ModelRunner: """ Returns the model runner used by this judge. diff --git a/packages/sdk/server-ai/src/ldai/managed_agent.py b/packages/sdk/server-ai/src/ldai/managed_agent.py index 12c4d9bd..eb2dee7f 100644 --- a/packages/sdk/server-ai/src/ldai/managed_agent.py +++ b/packages/sdk/server-ai/src/ldai/managed_agent.py @@ -2,25 +2,22 @@ from ldai.models import AIAgentConfig from ldai.providers import AgentResult, AgentRunner -from ldai.tracker import LDAIConfigTracker class ManagedAgent: """ LaunchDarkly managed wrapper for AI agent invocations. - Holds an AgentRunner and an LDAIConfigTracker. Handles tracking automatically. + Holds an AgentRunner. Handles tracking automatically via ``create_tracker()``. Obtain an instance via ``LDAIClient.create_agent()``. """ def __init__( self, ai_config: AIAgentConfig, - tracker: LDAIConfigTracker, agent_runner: AgentRunner, ): self._ai_config = ai_config - self._tracker = tracker self._agent_runner = agent_runner async def run(self, input: str) -> AgentResult: @@ -30,7 +27,8 @@ async def run(self, input: str) -> AgentResult: :param input: The user prompt or input to the agent :return: AgentResult containing the agent's output and metrics """ - return await self._tracker.track_metrics_of_async( + tracker = self._ai_config.create_tracker() + return await tracker.track_metrics_of_async( lambda: self._agent_runner.run(input), lambda result: result.metrics, ) @@ -46,7 +44,3 @@ def get_agent_runner(self) -> AgentRunner: def get_config(self) -> AIAgentConfig: """Return the AI agent config.""" return self._ai_config - - def get_tracker(self) -> LDAIConfigTracker: - """Return the config tracker.""" - return self._tracker diff --git a/packages/sdk/server-ai/src/ldai/managed_agent_graph.py b/packages/sdk/server-ai/src/ldai/managed_agent_graph.py index bb04add3..a146e60e 100644 --- a/packages/sdk/server-ai/src/ldai/managed_agent_graph.py +++ b/packages/sdk/server-ai/src/ldai/managed_agent_graph.py @@ -1,16 +1,15 @@ """ManagedAgentGraph — LaunchDarkly managed wrapper for agent graph execution.""" -from typing import Any, Optional +from typing import Any from ldai.providers import AgentGraphResult, AgentGraphRunner -from ldai.tracker import AIGraphTracker class ManagedAgentGraph: """ LaunchDarkly managed wrapper for AI agent graph execution. - Holds an AgentGraphRunner and an AIGraphTracker. 
Auto-tracking of path, + Holds an AgentGraphRunner. Auto-tracking of path, tool calls, handoffs, latency, and invocation success/failure is handled by the runner implementation. @@ -20,16 +19,13 @@ class ManagedAgentGraph: def __init__( self, runner: AgentGraphRunner, - tracker: Optional[AIGraphTracker] = None, ): """ Initialize ManagedAgentGraph. :param runner: The AgentGraphRunner to delegate execution to - :param tracker: The AIGraphTracker for this graph """ self._runner = runner - self._tracker = tracker async def run(self, input: Any) -> AgentGraphResult: """ @@ -50,11 +46,3 @@ def get_agent_graph_runner(self) -> AgentGraphRunner: :return: The AgentGraphRunner instance """ return self._runner - - def get_tracker(self) -> Optional[AIGraphTracker]: - """ - Return the AIGraphTracker for this graph. - - :return: The AIGraphTracker instance, or None if not available - """ - return self._tracker diff --git a/packages/sdk/server-ai/src/ldai/managed_model.py b/packages/sdk/server-ai/src/ldai/managed_model.py index e982cc00..dc4393a0 100644 --- a/packages/sdk/server-ai/src/ldai/managed_model.py +++ b/packages/sdk/server-ai/src/ldai/managed_model.py @@ -13,20 +13,18 @@ class ManagedModel: """ LaunchDarkly managed wrapper for AI model invocations. - Holds a ModelRunner and an LDAIConfigTracker. Handles conversation - management, judge evaluation dispatch, and tracking automatically. + Holds a ModelRunner. Handles conversation management, judge evaluation + dispatch, and tracking automatically via ``create_tracker()``. Obtain an instance via ``LDAIClient.create_model()``. """ def __init__( self, ai_config: AICompletionConfig, - tracker: LDAIConfigTracker, model_runner: ModelRunner, judges: Optional[Dict[str, Judge]] = None, ): self._ai_config = ai_config - self._tracker = tracker self._model_runner = model_runner self._judges = judges or {} self._messages: List[LDMessage] = [] @@ -42,13 +40,15 @@ async def invoke(self, prompt: str) -> ModelResponse: :param prompt: The user prompt to send to the model :return: ModelResponse containing the model's response and metrics """ + tracker = self._ai_config.create_tracker() + user_message = LDMessage(role='user', content=prompt) self._messages.append(user_message) config_messages = self._ai_config.messages or [] all_messages = config_messages + self._messages - response = await self._tracker.track_metrics_of_async( + response = await tracker.track_metrics_of_async( lambda: self._model_runner.invoke_model(all_messages), lambda result: result.metrics, ) @@ -57,13 +57,14 @@ async def invoke(self, prompt: str) -> ModelResponse: self._ai_config.judge_configuration and self._ai_config.judge_configuration.judges ): - response.evaluations = self._start_judge_evaluations(self._messages, response) + response.evaluations = self._start_judge_evaluations(tracker, self._messages, response) self._messages.append(response.message) return response def _start_judge_evaluations( self, + tracker: LDAIConfigTracker, messages: List[LDMessage], response: ModelResponse, ) -> List[asyncio.Task[Optional[JudgeResult]]]: @@ -77,7 +78,7 @@ async def evaluate_judge(judge_config: Any) -> Optional[JudgeResult]: return None judge_result = await judge.evaluate_messages(messages, response, judge_config.sampling_rate) if judge_result.success: - self._tracker.track_judge_result(judge_result) + tracker.track_judge_result(judge_result) return judge_result return [ @@ -116,10 +117,6 @@ def get_config(self) -> AICompletionConfig: """Return the AI completion config.""" return self._ai_config - 
def get_tracker(self) -> LDAIConfigTracker: - """Return the config tracker.""" - return self._tracker - def get_judges(self) -> Dict[str, Judge]: """Return the judges associated with this model.""" return self._judges diff --git a/packages/sdk/server-ai/src/ldai/models.py b/packages/sdk/server-ai/src/ldai/models.py index 43d9c9b5..346d5162 100644 --- a/packages/sdk/server-ai/src/ldai/models.py +++ b/packages/sdk/server-ai/src/ldai/models.py @@ -1,6 +1,6 @@ import warnings from dataclasses import dataclass, field -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Union @dataclass @@ -150,14 +150,6 @@ class AIConfigDefault: model: Optional[ModelConfig] = None provider: Optional[ProviderConfig] = None - @classmethod - def disabled(cls): - """ - Returns a new disabled config default with enabled set to false. - When called on a subclass, returns an instance of that subclass. - """ - return cls(enabled=False) - def _base_to_dict(self) -> Dict[str, Any]: """ Render the base config fields as a dictionary object. @@ -180,7 +172,7 @@ class AIConfig: enabled: bool model: Optional[ModelConfig] = None provider: Optional[ProviderConfig] = None - tracker: Optional[Any] = None + create_tracker: Callable[[], Any] = lambda: None def _base_to_dict(self) -> Dict[str, Any]: """ diff --git a/packages/sdk/server-ai/src/ldai/tracker.py b/packages/sdk/server-ai/src/ldai/tracker.py index d36b070a..fedee35e 100644 --- a/packages/sdk/server-ai/src/ldai/tracker.py +++ b/packages/sdk/server-ai/src/ldai/tracker.py @@ -1,9 +1,13 @@ +import base64 +import json import time from dataclasses import dataclass from enum import Enum from typing import Any, Callable, Dict, Iterable, List, Optional -from ldclient import Context, LDClient +from ldclient import Context, LDClient, Result + +from ldai import log class FeedbackKind(Enum): @@ -71,24 +75,26 @@ class LDAIConfigTracker: def __init__( self, ld_client: LDClient, - variation_key: str, + run_id: str, config_key: str, + variation_key: str, version: int, + context: Context, model_name: str, provider_name: str, - context: Context, graph_key: Optional[str] = None, ): """ Initialize an AI Config tracker. :param ld_client: LaunchDarkly client instance. - :param variation_key: Variation key for tracking. + :param run_id: Unique identifier for this execution. :param config_key: Configuration key for tracking. + :param variation_key: Variation key for tracking. :param version: Version of the variation. + :param context: Context for evaluation. :param model_name: Name of the model used. :param provider_name: Name of the provider used. - :param context: Context for evaluation. :param graph_key: When set, include ``graphKey`` in all event payloads (e.g. config-level metrics inside a graph). """ @@ -101,6 +107,72 @@ def __init__( self._context = context self._graph_key = graph_key self._summary = LDAIMetricSummary() + self._run_id = run_id + + @property + def resumption_token(self) -> str: + """ + A URL-safe Base64-encoded JSON string that can be used to reconstruct + a tracker in a different process (e.g. for deferred feedback). + + The token contains ``runId``, ``configKey``, ``version``, and + optionally ``variationKey`` and ``graphKey`` (omitted when empty). + ``modelName`` and ``providerName`` are **not** included. 
+        """ + data: dict = { + "runId": self._run_id, + "configKey": self._config_key, + } + if self._variation_key: + data["variationKey"] = self._variation_key + data["version"] = self._version + if self._graph_key: + data["graphKey"] = self._graph_key + payload = json.dumps(data) + return base64.urlsafe_b64encode(payload.encode("utf-8")).rstrip(b"=").decode("utf-8") + + @classmethod + def from_resumption_token(cls, token: str, ld_client: LDClient, context: Context) -> Result: + """ + Reconstruct a tracker from a resumption token. + + This is used for cross-process scenarios such as deferred feedback, + where a different service needs to associate tracking events with the + original execution's ``runId``. + + :param token: A URL-safe Base64-encoded resumption token obtained from + :attr:`resumption_token`. + :param ld_client: LaunchDarkly client instance. + :param context: The context to use for track events. + :return: A :class:`Result` whose ``value`` is a new + :class:`LDAIConfigTracker` bound to the original ``runId`` from the + token on success, or whose ``error`` describes the problem on failure. + """ + try: + padded = token + "=" * (-len(token) % 4) + payload = json.loads( + base64.urlsafe_b64decode(padded.encode("utf-8")).decode("utf-8") + ) + except Exception as e: + return Result.fail(f"Invalid resumption token: {e}", e) + + for field in ("runId", "configKey", "version"): + if field not in payload: + return Result.fail( + f"Invalid resumption token: missing required field '{field}'" + ) + + return Result.success(cls( + ld_client=ld_client, + run_id=payload["runId"], + config_key=payload["configKey"], + variation_key=payload.get("variationKey") or "", + version=payload["version"], + context=context, + model_name="", + provider_name="", + graph_key=payload.get("graphKey"), + )) def __get_track_data(self) -> dict: """ @@ -109,12 +181,14 @@ def __get_track_data(self) -> dict: :return: Dictionary containing variation and config keys. """ data = { - "variationKey": self._variation_key, + "runId": self._run_id, "configKey": self._config_key, "version": self._version, "modelName": self._model_name, "providerName": self._provider_name, } + if self._variation_key: + data["variationKey"] = self._variation_key if self._graph_key is not None: data['graphKey'] = self._graph_key return data @@ -125,6 +199,9 @@ def track_duration(self, duration: int) -> None: :param duration: Duration in milliseconds. """ + if self._summary.duration is not None: + log.warning("Duration has already been tracked for this execution. %s", self.__get_track_data()) + return self._summary._duration = duration self._ld_client.track( "$ld:ai:duration:total", self._context, self.__get_track_data(), duration @@ -136,6 +213,12 @@ def track_time_to_first_token(self, time_to_first_token: int) -> None: :param time_to_first_token: Time to first token in milliseconds. """ + if self._summary.time_to_first_token is not None: + log.warning( + "Time to first token has already been tracked for this execution. %s", + self.__get_track_data(), + ) + return self._summary._time_to_first_token = time_to_first_token self._ld_client.track( "$ld:ai:tokens:ttf", @@ -261,6 +344,9 @@ def track_feedback(self, feedback: Dict[str, FeedbackKind]) -> None: :param feedback: Dictionary containing feedback kind. """ + if self._summary.feedback is not None: + log.warning("Feedback has already been tracked for this execution. 
%s", self.__get_track_data()) + return self._summary._feedback = feedback if feedback["kind"] == FeedbackKind.Positive: self._ld_client.track( @@ -281,6 +367,9 @@ def track_success(self) -> None: """ Track a successful AI generation. """ + if self._summary.success is not None: + log.warning("Success has already been tracked for this execution. %s", self.__get_track_data()) + return self._summary._success = True self._ld_client.track( "$ld:ai:generation:success", self._context, self.__get_track_data(), 1 @@ -290,6 +379,9 @@ def track_error(self) -> None: """ Track an unsuccessful AI generation attempt. """ + if self._summary.success is not None: + log.warning("Success has already been tracked for this execution. %s", self.__get_track_data()) + return self._summary._success = False self._ld_client.track( "$ld:ai:generation:error", self._context, self.__get_track_data(), 1 @@ -356,6 +448,9 @@ def track_tokens(self, tokens: TokenUsage) -> None: :param tokens: Token usage data from either custom, OpenAI, or Bedrock sources. """ + if self._summary.usage is not None: + log.warning("Tokens have already been tracked for this execution. %s", self.__get_track_data()) + return self._summary._usage = tokens td = self.__get_track_data() if tokens.total > 0: diff --git a/packages/sdk/server-ai/tests/test_agent_graph.py b/packages/sdk/server-ai/tests/test_agent_graph.py index 20d2308b..672f68a3 100644 --- a/packages/sdk/server-ai/tests/test_agent_graph.py +++ b/packages/sdk/server-ai/tests/test_agent_graph.py @@ -396,8 +396,9 @@ def test_agent_graph_node_trackers_have_graph_key(ldai_client: LDAIClient): graph.get_node("multi-context-agent"), graph.get_node("minimal-agent")]: config = node.get_config() - assert config.tracker is not None - assert config.tracker._graph_key == "test-agent-graph" + assert callable(config.create_tracker) + tracker = config.create_tracker() + assert tracker._graph_key == "test-agent-graph" def test_agent_graph_handoff(ldai_client: LDAIClient): diff --git a/packages/sdk/server-ai/tests/test_agents.py b/packages/sdk/server-ai/tests/test_agents.py index 4d6aa3f9..8a99d08b 100644 --- a/packages/sdk/server-ai/tests/test_agents.py +++ b/packages/sdk/server-ai/tests/test_agents.py @@ -140,7 +140,7 @@ def test_single_agent_method(ldai_client: LDAIClient): assert agent.provider is not None assert agent.provider.name == 'openai' assert agent.instructions == 'You are a research assistant specializing in quantum computing. Your expertise level should match advanced.' - assert agent.tracker is not None + assert callable(agent.create_tracker) def test_single_agent_with_defaults(ldai_client: LDAIClient): @@ -164,7 +164,7 @@ def test_single_agent_with_defaults(ldai_client: LDAIClient): assert agent.model is not None and agent.model.get_parameter('temp') == 0.8 assert agent.provider is not None and agent.provider.name == 'default-provider' assert agent.instructions == "You are a default assistant for general assistance." 
- assert agent.tracker is not None + assert callable(agent.create_tracker) def test_agents_method_with_configs(ldai_client: LDAIClient): @@ -275,7 +275,7 @@ def test_disabled_agent_single_method(ldai_client: LDAIClient): agent = ldai_client.agent(config, context) assert agent.enabled is False - assert agent.tracker is not None + assert callable(agent.create_tracker) def test_disabled_agent_multiple_method(ldai_client: LDAIClient): @@ -313,7 +313,7 @@ def test_agent_with_missing_metadata(ldai_client: LDAIClient): assert agent.enabled is True # From flag assert agent.instructions == 'Minimal agent configuration.' assert agent.model == config.default.model # Falls back to default - assert agent.tracker is not None + assert callable(agent.create_tracker) def test_agent_config_dataclass(): diff --git a/packages/sdk/server-ai/tests/test_judge.py b/packages/sdk/server-ai/tests/test_judge.py index b4922d61..d100550a 100644 --- a/packages/sdk/server-ai/tests/test_judge.py +++ b/packages/sdk/server-ai/tests/test_judge.py @@ -54,58 +54,67 @@ def context() -> Context: @pytest.fixture def tracker(client: LDClient, context: Context) -> LDAIConfigTracker: return LDAIConfigTracker( - client, 'judge-v1', 'judge-config', 1, 'gpt-4', 'openai', context + ld_client=client, run_id='test-run-id', config_key='judge-config', + variation_key='judge-v1', version=1, model_name='gpt-4', + provider_name='openai', context=context, ) +_SENTINEL = object() + + +def _make_judge_config( + key='judge-config', + enabled=True, + evaluation_metric_key='$ld:ai:judge:relevance', + messages=_SENTINEL, + model=None, + provider=None, + tracker=None, +): + """Create a judge config with create_tracker wired up.""" + if messages is _SENTINEL: + messages = [LDMessage(role='system', content='You are a judge.')] + kwargs = dict( + key=key, + enabled=enabled, + evaluation_metric_key=evaluation_metric_key, + messages=messages, + model=model or ModelConfig('gpt-4'), + provider=provider or ProviderConfig('openai'), + ) + if tracker is not None: + kwargs['create_tracker'] = lambda: tracker + return AIJudgeConfig(**kwargs) + + @pytest.fixture -def judge_config_with_key() -> AIJudgeConfig: +def judge_config_with_key(tracker) -> AIJudgeConfig: """Create a judge config with evaluation_metric_key.""" - return AIJudgeConfig( - key='judge-config', - enabled=True, - evaluation_metric_key='$ld:ai:judge:relevance', - messages=[LDMessage(role='system', content='You are a judge.')], - model=ModelConfig('gpt-4'), - provider=ProviderConfig('openai'), - ) + return _make_judge_config(tracker=tracker) @pytest.fixture -def judge_config_without_key() -> AIJudgeConfig: +def judge_config_without_key(tracker) -> AIJudgeConfig: """Create a judge config without evaluation_metric_key.""" - return AIJudgeConfig( - key='judge-config', - enabled=True, - evaluation_metric_key=None, - messages=[LDMessage(role='system', content='You are a judge.')], - model=ModelConfig('gpt-4'), - provider=ProviderConfig('openai'), - ) + return _make_judge_config(evaluation_metric_key=None, tracker=tracker) @pytest.fixture -def judge_config_without_messages() -> AIJudgeConfig: +def judge_config_without_messages(tracker) -> AIJudgeConfig: """Create a judge config without messages.""" - return AIJudgeConfig( - key='judge-config', - enabled=True, - evaluation_metric_key='$ld:ai:judge:relevance', - messages=None, - model=ModelConfig('gpt-4'), - provider=ProviderConfig('openai'), - ) + return _make_judge_config(messages=None, tracker=tracker) class TestJudgeInitialization: """Tests for Judge 
initialization.""" def test_judge_initializes_with_evaluation_metric_key( - self, judge_config_with_key: AIJudgeConfig, tracker: LDAIConfigTracker, mock_runner + self, judge_config_with_key: AIJudgeConfig, mock_runner ): """Judge should initialize successfully with evaluation_metric_key.""" - judge = Judge(judge_config_with_key, tracker, mock_runner) - + judge = Judge(judge_config_with_key, mock_runner) + assert judge._ai_config == judge_config_with_key assert judge._evaluation_response_structure is not None assert judge._evaluation_response_structure['title'] == 'EvaluationResponse' @@ -119,10 +128,10 @@ class TestJudgeEvaluate: @pytest.mark.asyncio async def test_evaluate_returns_failure_when_evaluation_metric_key_missing( - self, judge_config_without_key: AIJudgeConfig, tracker: LDAIConfigTracker, mock_runner + self, judge_config_without_key: AIJudgeConfig, mock_runner ): """Evaluate should return a failed JudgeResult when evaluation_metric_key is missing.""" - judge = Judge(judge_config_without_key, tracker, mock_runner) + judge = Judge(judge_config_without_key, mock_runner) result = await judge.evaluate("input text", "output text") @@ -133,10 +142,10 @@ async def test_evaluate_returns_failure_when_evaluation_metric_key_missing( @pytest.mark.asyncio async def test_evaluate_returns_failure_when_messages_missing( - self, judge_config_without_messages: AIJudgeConfig, tracker: LDAIConfigTracker, mock_runner + self, judge_config_without_messages: AIJudgeConfig, mock_runner ): """Evaluate should return a failed JudgeResult when messages are missing.""" - judge = Judge(judge_config_without_messages, tracker, mock_runner) + judge = Judge(judge_config_without_messages, mock_runner) result = await judge.evaluate("input text", "output text") @@ -158,14 +167,14 @@ async def test_evaluate_success_with_valid_response( raw_response='{"score": 0.85, "reasoning": "..."}', metrics=LDAIMetrics(success=True) ) - + mock_runner.invoke_structured_model.return_value = mock_response tracker.track_metrics_of_async = AsyncMock(return_value=mock_response) - - judge = Judge(judge_config_with_key, tracker, mock_runner) - + + judge = Judge(judge_config_with_key, mock_runner) + result = await judge.evaluate("What is AI?", "AI is artificial intelligence.") - + assert isinstance(result, JudgeResult) assert result.success is True assert result.sampled is True @@ -190,7 +199,7 @@ async def test_evaluate_success_with_evaluation_response_shape( mock_runner.invoke_structured_model.return_value = mock_response tracker.track_metrics_of_async = AsyncMock(return_value=mock_response) - judge = Judge(judge_config_with_key, tracker, mock_runner) + judge = Judge(judge_config_with_key, mock_runner) result = await judge.evaluate("What is feature flagging?", "Feature flagging is...") assert isinstance(result, JudgeResult) @@ -211,14 +220,14 @@ async def test_evaluate_handles_missing_evaluation_in_response( raw_response='{}', metrics=LDAIMetrics(success=True) ) - + mock_runner.invoke_structured_model.return_value = mock_response tracker.track_metrics_of_async = AsyncMock(return_value=mock_response) - - judge = Judge(judge_config_with_key, tracker, mock_runner) - + + judge = Judge(judge_config_with_key, mock_runner) + result = await judge.evaluate("input", "output") - + assert isinstance(result, JudgeResult) assert result.success is False assert result.score is None @@ -236,14 +245,14 @@ async def test_evaluate_handles_invalid_score( raw_response='{"score": 1.5, "reasoning": "..."}', metrics=LDAIMetrics(success=True) ) - + 
mock_runner.invoke_structured_model.return_value = mock_response tracker.track_metrics_of_async = AsyncMock(return_value=mock_response) - - judge = Judge(judge_config_with_key, tracker, mock_runner) - + + judge = Judge(judge_config_with_key, mock_runner) + result = await judge.evaluate("input", "output") - + assert isinstance(result, JudgeResult) assert result.success is False assert result.score is None @@ -258,14 +267,14 @@ async def test_evaluate_handles_missing_reasoning( raw_response='{"score": 0.8}', metrics=LDAIMetrics(success=True) ) - + mock_runner.invoke_structured_model.return_value = mock_response tracker.track_metrics_of_async = AsyncMock(return_value=mock_response) - - judge = Judge(judge_config_with_key, tracker, mock_runner) - + + judge = Judge(judge_config_with_key, mock_runner) + result = await judge.evaluate("input", "output") - + assert isinstance(result, JudgeResult) assert result.success is False assert result.score is None @@ -277,21 +286,21 @@ async def test_evaluate_handles_exception( """Evaluate should handle exceptions gracefully.""" mock_runner.invoke_structured_model.side_effect = Exception("Provider error") tracker.track_metrics_of_async = AsyncMock(side_effect=Exception("Provider error")) - - judge = Judge(judge_config_with_key, tracker, mock_runner) - + + judge = Judge(judge_config_with_key, mock_runner) + result = await judge.evaluate("input", "output") - + assert isinstance(result, JudgeResult) assert result.success is False assert result.error_message is not None @pytest.mark.asyncio async def test_evaluate_respects_sampling_rate( - self, judge_config_with_key: AIJudgeConfig, tracker: LDAIConfigTracker, mock_runner + self, judge_config_with_key: AIJudgeConfig, mock_runner ): """Evaluate should return sampled=False when skipped due to sampling rate.""" - judge = Judge(judge_config_with_key, tracker, mock_runner) + judge = Judge(judge_config_with_key, mock_runner) result = await judge.evaluate("input", "output", sampling_rate=0.0) @@ -310,18 +319,18 @@ async def test_evaluate_messages_calls_evaluate( ): """evaluate_messages should call evaluate with constructed input/output.""" from ldai.providers.types import ModelResponse - + mock_response = StructuredResponse( data={'score': 0.9, 'reasoning': 'Very relevant'}, raw_response='{"score": 0.9, "reasoning": "..."}', metrics=LDAIMetrics(success=True) ) - + mock_runner.invoke_structured_model.return_value = mock_response tracker.track_metrics_of_async = AsyncMock(return_value=mock_response) - - judge = Judge(judge_config_with_key, tracker, mock_runner) - + + judge = Judge(judge_config_with_key, mock_runner) + messages = [ LDMessage(role='user', content='Question 1'), LDMessage(role='assistant', content='Answer 1'), @@ -330,9 +339,9 @@ async def test_evaluate_messages_calls_evaluate( message=LDMessage(role='assistant', content='Answer 2'), metrics=LDAIMetrics(success=True) ) - + result = await judge.evaluate_messages(messages, chat_response) - + assert result is not None assert result.success is True assert tracker.track_metrics_of_async.called @@ -366,9 +375,9 @@ def test_to_dict_includes_evaluation_metric_key(self): evaluation_metric_key='$ld:ai:judge:relevance', messages=[LDMessage(role='system', content='You are a judge.')], ) - + result = config.to_dict() - + assert result['evaluationMetricKey'] == '$ld:ai:judge:relevance' assert 'evaluationMetricKeys' not in result @@ -380,9 +389,9 @@ def test_to_dict_handles_none_evaluation_metric_key(self): evaluation_metric_key=None, messages=[LDMessage(role='system', 
content='You are a judge.')], ) - + result = config.to_dict() - + assert result['evaluationMetricKey'] is None def test_judge_config_default_to_dict(self): @@ -392,235 +401,7 @@ def test_judge_config_default_to_dict(self): evaluation_metric_key='$ld:ai:judge:relevance', messages=[LDMessage(role='system', content='You are a judge.')], ) - - result = config.to_dict() - - assert result['evaluationMetricKey'] == '$ld:ai:judge:relevance' - assert 'evaluationMetricKeys' not in result - - -class TestClientJudgeConfig: - """Tests for LDAIClient.judge_config() method.""" - - def test_judge_config_extracts_evaluation_metric_key( - self, client: LDClient, context: Context - ): - """judge_config should extract evaluationMetricKey from variation.""" - from ldai import LDAIClient - - ldai_client = LDAIClient(client) - - default = AIJudgeConfigDefault( - enabled=True, - evaluation_metric_key='$ld:ai:judge:relevance', - messages=[LDMessage(role='system', content='You are a judge.')], - model=ModelConfig('gpt-4'), - provider=ProviderConfig('openai'), - ) - - config = ldai_client.judge_config('judge-config', context, default) - - assert config is not None - assert config.evaluation_metric_key == '$ld:ai:judge:relevance' - assert config.enabled is True - assert config.messages is not None - assert len(config.messages) > 0 - - def test_judge_config_uses_default_when_flag_does_not_exist( - self, client: LDClient, context: Context - ): - """judge_config should use default evaluation_metric_key when flag does not exist.""" - from ldai import LDAIClient - from ldclient import Config, LDClient - from ldclient.integrations.test_data import TestData - - td = TestData.data_source() - - test_client = LDClient(Config('sdk-key', update_processor_class=td, send_events=False)) - ldai_client = LDAIClient(test_client) - - default = AIJudgeConfigDefault( - enabled=True, - evaluation_metric_key='$ld:ai:judge:default', - messages=[LDMessage(role='system', content='You are a judge.')], - model=ModelConfig('gpt-4'), - provider=ProviderConfig('openai'), - ) - - config = ldai_client.judge_config('judge-no-key', context, default) - - assert config is not None - assert config.evaluation_metric_key == '$ld:ai:judge:default' - - def test_judge_config_uses_first_evaluation_metric_keys_from_variation( - self, context: Context - ): - """judge_config should use first value from evaluationMetricKeys when evaluationMetricKey is None.""" - from ldai import LDAIClient - from ldclient import Config, LDClient - from ldclient.integrations.test_data import TestData - - td = TestData.data_source() - td.update( - td.flag('judge-with-keys') - .variations( - { - 'model': {'name': 'gpt-4'}, - 'provider': {'name': 'openai'}, - 'messages': [{'role': 'system', 'content': 'You are a judge.'}], - 'evaluationMetricKeys': ['$ld:ai:judge:relevance', '$ld:ai:judge:quality'], - '_ldMeta': {'enabled': True, 'variationKey': 'judge-v1', 'version': 1}, - } - ) - .variation_for_all(0) - ) - - test_client = LDClient(Config('sdk-key', update_processor_class=td, send_events=False)) - ldai_client = LDAIClient(test_client) - - default = AIJudgeConfigDefault( - enabled=True, - evaluation_metric_key=None, - messages=[LDMessage(role='system', content='You are a judge.')], - model=ModelConfig('gpt-4'), - provider=ProviderConfig('openai'), - ) - - config = ldai_client.judge_config('judge-with-keys', context, default) - - assert config is not None - assert config.evaluation_metric_key == '$ld:ai:judge:relevance' - - def 
test_judge_config_uses_first_evaluation_metric_keys_from_default( - self, context: Context - ): - """judge_config should use first value from default evaluation_metric_keys when flag does not exist.""" - from ldai import LDAIClient - from ldclient import Config, LDClient - from ldclient.integrations.test_data import TestData - - td = TestData.data_source() - - test_client = LDClient(Config('sdk-key', update_processor_class=td, send_events=False)) - ldai_client = LDAIClient(test_client) - - default = AIJudgeConfigDefault( - enabled=True, - evaluation_metric_key=None, - evaluation_metric_keys=['$ld:ai:judge:default-key', '$ld:ai:judge:secondary'], - messages=[LDMessage(role='system', content='You are a judge.')], - model=ModelConfig('gpt-4'), - provider=ProviderConfig('openai'), - ) - - config = ldai_client.judge_config('judge-fallback-keys', context, default) - - assert config is not None - assert config.evaluation_metric_key == '$ld:ai:judge:default-key' - - def test_judge_config_prefers_evaluation_metric_key_over_keys( - self, context: Context - ): - """judge_config should prefer evaluationMetricKey over evaluationMetricKeys when both are present.""" - from ldai import LDAIClient - from ldclient import Config, LDClient - from ldclient.integrations.test_data import TestData - - td = TestData.data_source() - td.update( - td.flag('judge-both-present') - .variations( - { - 'model': {'name': 'gpt-4'}, - 'provider': {'name': 'openai'}, - 'messages': [{'role': 'system', 'content': 'You are a judge.'}], - 'evaluationMetricKey': '$ld:ai:judge:preferred', - 'evaluationMetricKeys': ['$ld:ai:judge:relevance', '$ld:ai:judge:quality'], - '_ldMeta': {'enabled': True, 'variationKey': 'judge-v1', 'version': 1}, - } - ) - .variation_for_all(0) - ) - - test_client = LDClient(Config('sdk-key', update_processor_class=td, send_events=False)) - ldai_client = LDAIClient(test_client) - - default = AIJudgeConfigDefault( - enabled=True, - evaluation_metric_key=None, - messages=[LDMessage(role='system', content='You are a judge.')], - model=ModelConfig('gpt-4'), - provider=ProviderConfig('openai'), - ) - - config = ldai_client.judge_config('judge-both-present', context, default) - - assert config is not None - assert config.evaluation_metric_key == '$ld:ai:judge:preferred' - - def test_judge_config_without_default_uses_disabled( - self, context: Context - ): - """judge_config should use a disabled config when no default is provided.""" - from ldai import LDAIClient - from ldclient import Config, LDClient - from ldclient.integrations.test_data import TestData - td = TestData.data_source() - test_client = LDClient(Config('sdk-key', update_processor_class=td, send_events=False)) - ldai_client = LDAIClient(test_client) - - config = ldai_client.judge_config('missing-judge', context) - - assert config is not None - assert config.enabled is False - - def test_judge_config_uses_same_variation_for_consistency( - self, context: Context - ): - """judge_config should use the same variation from __evaluate to avoid race conditions.""" - from ldai import LDAIClient - from ldclient import Config, LDClient - from ldclient.integrations.test_data import TestData - from unittest.mock import patch - - td = TestData.data_source() - td.update( - td.flag('judge-consistency-test') - .variations( - { - 'model': {'name': 'gpt-4'}, - 'provider': {'name': 'openai'}, - 'messages': [{'role': 'system', 'content': 'You are a judge.'}], - 'evaluationMetricKey': '$ld:ai:judge:from-flag', - '_ldMeta': {'enabled': True, 'variationKey': 'judge-v1', 
'version': 1}, - } - ) - .variation_for_all(0) - ) - - test_client = LDClient(Config('sdk-key', update_processor_class=td, send_events=False)) - ldai_client = LDAIClient(test_client) - - default = AIJudgeConfigDefault( - enabled=True, - evaluation_metric_key='$ld:ai:judge:from-default', - messages=[LDMessage(role='system', content='You are a judge.')], - model=ModelConfig('gpt-4'), - provider=ProviderConfig('openai'), - ) - - variation_calls = [] - original_variation = test_client.variation - - def tracked_variation(key, context, default): - result = original_variation(key, context, default) - variation_calls.append((key, result.get('evaluationMetricKey'))) - return result - - with patch.object(test_client, 'variation', side_effect=tracked_variation): - config = ldai_client.judge_config('judge-consistency-test', context, default) + result = config.to_dict() - assert len(variation_calls) == 1, f"Expected 1 variation call, got {len(variation_calls)}" - assert config is not None - assert config.evaluation_metric_key == '$ld:ai:judge:from-flag' + assert result['evaluationMetricKey'] == '$ld:ai:judge:relevance' diff --git a/packages/sdk/server-ai/tests/test_managed_agent.py b/packages/sdk/server-ai/tests/test_managed_agent.py index 60cf7db4..144641fc 100644 --- a/packages/sdk/server-ai/tests/test_managed_agent.py +++ b/packages/sdk/server-ai/tests/test_managed_agent.py @@ -63,6 +63,7 @@ async def test_run_delegates_to_agent_runner(self): metrics=LDAIMetrics(success=True, usage=None), ) ) + mock_config.create_tracker = MagicMock(return_value=mock_tracker) mock_runner = MagicMock() mock_runner.run = AsyncMock( return_value=AgentResult( @@ -72,34 +73,51 @@ async def test_run_delegates_to_agent_runner(self): ) ) - agent = ManagedAgent(mock_config, mock_tracker, mock_runner) + agent = ManagedAgent(mock_config, mock_runner) result = await agent.run("Hello") assert result.output == "Test response" assert result.metrics.success is True + mock_config.create_tracker.assert_called_once() mock_tracker.track_metrics_of_async.assert_called_once() + @pytest.mark.asyncio + async def test_run_uses_create_tracker_for_fresh_tracker(self): + """Should use create_tracker() factory for a fresh tracker per invocation.""" + mock_config = MagicMock(spec=AIAgentConfig) + fresh_tracker = MagicMock() + fresh_tracker.track_metrics_of_async = AsyncMock( + return_value=AgentResult( + output="Fresh tracker response", + raw=None, + metrics=LDAIMetrics(success=True, usage=None), + ) + ) + mock_config.create_tracker = MagicMock(return_value=fresh_tracker) + + mock_runner = MagicMock() + + agent = ManagedAgent(mock_config, mock_runner) + result = await agent.run("Hello") + + assert result.output == "Fresh tracker response" + mock_config.create_tracker.assert_called_once() + fresh_tracker.track_metrics_of_async.assert_called_once() + def test_get_agent_runner_returns_runner(self): """Should return the underlying AgentRunner.""" mock_runner = MagicMock() - agent = ManagedAgent(MagicMock(), MagicMock(), mock_runner) + agent = ManagedAgent(MagicMock(), mock_runner) assert agent.get_agent_runner() is mock_runner def test_get_config_returns_config(self): """Should return the AI agent config.""" mock_config = MagicMock() - agent = ManagedAgent(mock_config, MagicMock(), MagicMock()) + agent = ManagedAgent(mock_config, MagicMock()) assert agent.get_config() is mock_config - def test_get_tracker_returns_tracker(self): - """Should return the tracker.""" - mock_tracker = MagicMock() - agent = ManagedAgent(MagicMock(), mock_tracker, 
MagicMock()) - - assert agent.get_tracker() is mock_tracker - class TestLDAIClientCreateAgent: """Tests for LDAIClient.create_agent.""" diff --git a/packages/sdk/server-ai/tests/test_managed_agent_graph.py b/packages/sdk/server-ai/tests/test_managed_agent_graph.py index 476ac026..35be2766 100644 --- a/packages/sdk/server-ai/tests/test_managed_agent_graph.py +++ b/packages/sdk/server-ai/tests/test_managed_agent_graph.py @@ -8,7 +8,6 @@ from ldai import LDAIClient, ManagedAgentGraph from ldai.providers.types import LDAIMetrics from ldai.providers import AgentGraphResult, AgentGraphRunner, ToolRegistry -from ldai.tracker import AIGraphTracker # --- Test double --- @@ -42,17 +41,6 @@ def test_managed_agent_graph_get_runner(): assert managed.get_agent_graph_runner() is runner -def test_managed_agent_graph_get_tracker_none_by_default(): - runner = StubAgentGraphRunner() - managed = ManagedAgentGraph(runner) - assert managed.get_tracker() is None - - -def test_managed_agent_graph_get_tracker_returns_tracker(): - runner = StubAgentGraphRunner() - tracker = MagicMock(spec=AIGraphTracker) - managed = ManagedAgentGraph(runner, tracker) - assert managed.get_tracker() is tracker # --- LDAIClient.create_agent_graph() integration tests --- diff --git a/packages/sdk/server-ai/tests/test_model_config.py b/packages/sdk/server-ai/tests/test_model_config.py index 636d14a4..9c36c5ed 100644 --- a/packages/sdk/server-ai/tests/test_model_config.py +++ b/packages/sdk/server-ai/tests/test_model_config.py @@ -4,7 +4,7 @@ from ldai import LDAIClient, LDMessage, ModelConfig from ldai.models import (AIAgentConfigDefault, AICompletionConfigDefault, - AIConfigDefault, AIJudgeConfigDefault) + AIJudgeConfigDefault) @pytest.fixture @@ -355,52 +355,148 @@ def test_sdk_info_tracked_on_init(): # ============================================================================ -# disabled() classmethod tests +# Optional default value tests # ============================================================================ -def test_ai_config_default_disabled_returns_disabled_instance(): - result = AIConfigDefault.disabled() - assert isinstance(result, AIConfigDefault) - assert result.enabled is False +def test_completion_config_without_default_uses_disabled(ldai_client: LDAIClient): + context = Context.create('user-key') + config = ldai_client.completion_config('missing-flag', context) -def test_completion_config_default_disabled_returns_correct_type(): - result = AICompletionConfigDefault.disabled() - assert isinstance(result, AICompletionConfigDefault) - assert result.enabled is False - assert result.messages is None - assert result.model is None + assert config.enabled is False -def test_agent_config_default_disabled_returns_correct_type(): - result = AIAgentConfigDefault.disabled() - assert isinstance(result, AIAgentConfigDefault) - assert result.enabled is False - assert result.instructions is None - assert result.model is None +# ============================================================================ +# create_tracker factory tests +# ============================================================================ +def test_enabled_config_has_create_tracker(ldai_client: LDAIClient): + context = Context.create('user-key') + default = AICompletionConfigDefault( + enabled=True, + model=ModelConfig('fakeModel'), + messages=[LDMessage(role='system', content='Hello!')], + ) -def test_judge_config_default_disabled_returns_correct_type(): - result = AIJudgeConfigDefault.disabled() - assert isinstance(result, AIJudgeConfigDefault) - assert 
result.enabled is False
-    assert result.messages is None
-    assert result.evaluation_metric_key is None
+    config = ldai_client.completion_config('model-config', context, default)

+    assert config.enabled is True
+    assert config.create_tracker is not None
+    assert callable(config.create_tracker)

-def test_disabled_returns_new_instance_each_call():
-    first = AICompletionConfigDefault.disabled()
-    second = AICompletionConfigDefault.disabled()
-    assert first is not second
+def test_disabled_config_has_working_create_tracker(ldai_client: LDAIClient):
+    context = Context.create('user-key')
+    default = AICompletionConfigDefault(enabled=False, model=ModelConfig('fake-model'), messages=[])

-# ============================================================================
-# Optional default value tests
-# ============================================================================
+    config = ldai_client.completion_config('off-config', context, default)

-def test_completion_config_without_default_uses_disabled(ldai_client: LDAIClient):
+    assert config.enabled is False
+    assert callable(config.create_tracker)
+    tracker = config.create_tracker()
+    assert tracker is not None
+
+
+def test_create_tracker_returns_new_tracker_each_call(ldai_client: LDAIClient):
     context = Context.create('user-key')
+    default = AICompletionConfigDefault(
+        enabled=True,
+        model=ModelConfig('fakeModel'),
+        messages=[LDMessage(role='system', content='Hello!')],
+    )

-    config = ldai_client.completion_config('missing-flag', context)
+    config = ldai_client.completion_config('model-config', context, default)

-    assert config.enabled is False
+    assert config.create_tracker is not None
+    tracker1 = config.create_tracker()
+    tracker2 = config.create_tracker()
+
+    assert tracker1 is not tracker2
+
+
+def test_create_tracker_produces_fresh_run_id_each_call(ldai_client: LDAIClient):
+    context = Context.create('user-key')
+    default = AICompletionConfigDefault(
+        enabled=True,
+        model=ModelConfig('fakeModel'),
+        messages=[LDMessage(role='system', content='Hello!')],
+    )
+
+    config = ldai_client.completion_config('model-config', context, default)
+
+    assert config.create_tracker is not None
+    tracker1 = config.create_tracker()
+    tracker2 = config.create_tracker()
+
+    # Each call mints a tracker with its own runId; uniqueness is asserted in
+    # test_create_tracker_each_call_has_different_run_id below, where the mock
+    # client makes the emitted runId inspectable.
+    tracker1.track_success()
+    tracker2.track_success()
+
+
+def test_create_tracker_preserves_config_metadata():
+    from unittest.mock import Mock
+
+    mock_client = Mock()
+    mock_client.variation.return_value = {
+        '_ldMeta': {'enabled': True, 'variationKey': 'var-abc', 'version': 7},
+        'model': {'name': 'gpt-4'},
+        'provider': {'name': 'openai'},
+        'messages': []
+    }
+
+    client = LDAIClient(mock_client)
+    context = Context.create('user-key')
+    default = AICompletionConfigDefault(enabled=False, model=ModelConfig('fake'), messages=[])
+
+    config = client.completion_config('my-config-key', context, default)
+
+    assert config.create_tracker is not None
+    tracker = config.create_tracker()
+    tracker.track_success()
+
+    # Find the track_success call (skip the sdk:info and usage calls)
+    success_calls = [
+        c for c in mock_client.track.call_args_list
+        if c.args[0] == '$ld:ai:generation:success'
+    ]
+    assert len(success_calls) == 1
+    track_data = success_calls[0].args[2]
+    assert track_data['configKey'] == 'my-config-key'
+    assert track_data['variationKey'] == 'var-abc'
+    assert track_data['version'] == 7
+    assert track_data['modelName'] == 'gpt-4'
+    assert track_data['providerName'] == 'openai'
+    assert 'runId' in track_data
+
+
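The tests above pin down the factory contract: `create_tracker` is a zero-argument callable that mints an independent tracker on every call, stamping it with a fresh `runId` while reusing the metadata captured when the config was evaluated. A hedged sketch of how such a factory can be built with a closure (the types and helper below are illustrative, not the SDK's construction code):

```python
import uuid
from dataclasses import dataclass
from typing import Callable


@dataclass
class SketchTracker:
    """Stand-in for the SDK tracker, reduced to the fields relevant here."""
    run_id: str
    config_key: str
    variation_key: str
    version: int


def make_tracker_factory(config_key: str, variation_key: str, version: int) -> Callable[[], SketchTracker]:
    def create_tracker() -> SketchTracker:
        # Every call gets its own runId, so two invocations of the same
        # evaluated config never share or clobber each other's metrics.
        return SketchTracker(uuid.uuid4().hex, config_key, variation_key, version)
    return create_tracker


factory = make_tracker_factory("my-config-key", "var-abc", 7)
t1, t2 = factory(), factory()
assert t1 is not t2 and t1.run_id != t2.run_id
assert (t1.config_key, t1.version) == (t2.config_key, t2.version)
```

The mock-based test that follows checks the same property through the public surface, asserting that the `runId` stamped on each tracker's emitted events differs across calls.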
+def test_create_tracker_each_call_has_different_run_id():
+    from unittest.mock import Mock
+
+    mock_client = Mock()
+    mock_client.variation.return_value = {
+        '_ldMeta': {'enabled': True, 'variationKey': 'v1', 'version': 1},
+        'model': {'name': 'test-model'},
+        'provider': {'name': 'test-provider'},
+        'messages': []
+    }
+
+    client = LDAIClient(mock_client)
+    context = Context.create('user-key')
+
+    config = client.completion_config('key', context)
+
+    assert config.create_tracker is not None
+    tracker1 = config.create_tracker()
+    tracker2 = config.create_tracker()
+
+    tracker1.track_success()
+    tracker2.track_success()
+
+    success_calls = [
+        c for c in mock_client.track.call_args_list
+        if c.args[0] == '$ld:ai:generation:success'
+    ]
+    assert len(success_calls) == 2
+    run_id_1 = success_calls[0].args[2]['runId']
+    run_id_2 = success_calls[1].args[2]['runId']
+    assert run_id_1 != run_id_2
diff --git a/packages/sdk/server-ai/tests/test_tracker.py b/packages/sdk/server-ai/tests/test_tracker.py
index 32e018e9..2350e61d 100644
--- a/packages/sdk/server-ai/tests/test_tracker.py
+++ b/packages/sdk/server-ai/tests/test_tracker.py
@@ -1,5 +1,5 @@
 from time import sleep
-from unittest.mock import MagicMock, call
+from unittest.mock import ANY, MagicMock, call

 import pytest
 from ldclient import Config, Context, LDClient
@@ -43,7 +43,7 @@
 def test_summary_starts_empty(client: LDClient):
     context = Context.create("user-key")
-    tracker = LDAIConfigTracker(client, "variation-key", "config-key", 1, "fakeModel", "fakeProvider", context)
+    tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=1, model_name="fakeModel", provider_name="fakeProvider", context=context)

     assert tracker.get_summary().duration is None
     assert tracker.get_summary().feedback is None
@@ -53,13 +53,13 @@
 def test_tracks_duration(client: LDClient):
     context = Context.create("user-key")
-    tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
+    tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context)
     tracker.track_duration(100)

     client.track.assert_called_with( # type: ignore
         "$ld:ai:duration:total",
         context,
-        {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
+        {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
         100,
     )
@@ -68,7 +68,7 @@
 def test_tracks_duration_of(client: LDClient):
     context = Context.create("user-key")
-    tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
+    tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context)
     tracker.track_duration_of(lambda: sleep(0.01))

     calls = client.track.mock_calls # type: ignore
     assert len(calls) == 1
     assert calls[0].args[0] == "$ld:ai:duration:total"
     assert calls[0].args[1] == context
-    assert calls[0].args[2] == {
-        "variationKey": "variation-key",
-        
"configKey": "config-key", - "version": 3, - "modelName": "fakeModel", - "providerName": "fakeProvider", - } + assert calls[0].args[2]["variationKey"] == "variation-key" + assert calls[0].args[2]["configKey"] == "config-key" + assert calls[0].args[2]["version"] == 3 + assert calls[0].args[2]["modelName"] == "fakeModel" + assert calls[0].args[2]["providerName"] == "fakeProvider" + assert "runId" in calls[0].args[2] assert calls[0].args[3] == pytest.approx(10, rel=10) def test_tracks_time_to_first_token(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) tracker.track_time_to_first_token(100) client.track.assert_called_with( # type: ignore "$ld:ai:tokens:ttf", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 100, ) @@ -103,7 +102,7 @@ def test_tracks_time_to_first_token(client: LDClient): def test_tracks_duration_of_with_exception(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) def sleep_and_throw(): sleep(0.01) @@ -120,19 +119,18 @@ def sleep_and_throw(): assert len(calls) == 1 assert calls[0].args[0] == "$ld:ai:duration:total" assert calls[0].args[1] == context - assert calls[0].args[2] == { - "variationKey": "variation-key", - "configKey": "config-key", - "version": 3, - "modelName": "fakeModel", - "providerName": "fakeProvider", - } + assert calls[0].args[2]["variationKey"] == "variation-key" + assert calls[0].args[2]["configKey"] == "config-key" + assert calls[0].args[2]["version"] == 3 + assert calls[0].args[2]["modelName"] == "fakeModel" + assert calls[0].args[2]["providerName"] == "fakeProvider" + assert "runId" in calls[0].args[2] assert calls[0].args[3] == pytest.approx(10, rel=10) def test_tracks_token_usage(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) tokens = TokenUsage(300, 200, 100) tracker.track_tokens(tokens) @@ -141,19 +139,19 @@ def test_tracks_token_usage(client: LDClient): call( "$ld:ai:tokens:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 300, ), call( "$ld:ai:tokens:input", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + 
{"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 200, ), call( "$ld:ai:tokens:output", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 100, ), ] @@ -165,7 +163,7 @@ def test_tracks_token_usage(client: LDClient): def test_tracks_bedrock_metrics(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) bedrock_result = { "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -184,31 +182,31 @@ def test_tracks_bedrock_metrics(client: LDClient): call( "$ld:ai:generation:success", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), call( "$ld:ai:duration:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 50, ), call( "$ld:ai:tokens:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 330, ), call( "$ld:ai:tokens:input", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 220, ), call( "$ld:ai:tokens:output", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 110, ), ] @@ -222,7 +220,7 @@ def test_tracks_bedrock_metrics(client: LDClient): def test_tracks_bedrock_metrics_with_error(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) bedrock_result = { "ResponseMetadata": {"HTTPStatusCode": 500}, @@ -241,31 +239,31 @@ def test_tracks_bedrock_metrics_with_error(client: LDClient): call( "$ld:ai:generation:error", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", 
"providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), call( "$ld:ai:duration:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 50, ), call( "$ld:ai:tokens:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 330, ), call( "$ld:ai:tokens:input", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 220, ), call( "$ld:ai:tokens:output", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 110, ), ] @@ -279,7 +277,7 @@ def test_tracks_bedrock_metrics_with_error(client: LDClient): def test_tracks_openai_metrics(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) class Result: def __init__(self): @@ -302,25 +300,25 @@ def get_result(): call( "$ld:ai:generation:success", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), call( "$ld:ai:tokens:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 330, ), call( "$ld:ai:tokens:input", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 220, ), call( "$ld:ai:tokens:output", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 110, ), ] @@ -332,7 +330,7 @@ def get_result(): def test_tracks_openai_metrics_with_exception(client: LDClient): context = 
Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) def raise_exception(): raise ValueError("Something went wrong") @@ -347,7 +345,7 @@ def raise_exception(): call( "$ld:ai:generation:error", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), ] @@ -366,14 +364,14 @@ def raise_exception(): ) def test_tracks_feedback(client: LDClient, kind: FeedbackKind, label: str): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) tracker.track_feedback({"kind": kind}) client.track.assert_called_with( # type: ignore f"$ld:ai:feedback:user:{label}", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ) assert tracker.get_summary().feedback == {"kind": kind} @@ -381,14 +379,14 @@ def test_tracks_feedback(client: LDClient, kind: FeedbackKind, label: str): def test_tracks_success(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) tracker.track_success() calls = [ call( "$ld:ai:generation:success", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), ] @@ -400,14 +398,14 @@ def test_tracks_success(client: LDClient): def test_tracks_error(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) tracker.track_error() calls = [ call( "$ld:ai:generation:error", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), ] @@ -417,34 +415,26 @@ def test_tracks_error(client: LDClient): assert tracker.get_summary().success is False -def 
test_error_overwrites_success(client: LDClient): +def test_error_after_success_is_blocked(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) tracker.track_success() tracker.track_error() - calls = [ - call( - "$ld:ai:generation:success", - context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, - 1, - ), - call( - "$ld:ai:generation:error", - context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, - 1, - ), - ] - - client.track.assert_has_calls(calls) # type: ignore + # Only the first call (success) should go through; error is blocked by at-most-once guard + client.track.assert_called_once_with( # type: ignore + "$ld:ai:generation:success", + context, + {"runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, + 1, + ) - assert tracker.get_summary().success is False + assert tracker.get_summary().success is True def _base_td() -> dict: return { + "runId": ANY, "variationKey": "variation-key", "configKey": "config-key", "version": 3, @@ -456,7 +446,7 @@ def _base_td() -> dict: def test_config_tracker_includes_graph_key_when_provided(client: LDClient): context = Context.create("user-key") tracker = LDAIConfigTracker( - client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context, graph_key="my-graph" + ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context, graph_key="my-graph" ) expected = {**_base_td(), "graphKey": "my-graph"} tracker.track_success() @@ -466,7 +456,7 @@ def test_config_tracker_includes_graph_key_when_provided(client: LDClient): def test_config_tracker_track_tokens_with_graph_key(client: LDClient): context = Context.create("user-key") tracker = LDAIConfigTracker( - client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context, graph_key="g1" + ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context, graph_key="g1" ) tokens = TokenUsage(10, 4, 6) expected = {**_base_td(), "graphKey": "g1"} @@ -477,7 +467,7 @@ def test_config_tracker_track_tokens_with_graph_key(client: LDClient): def test_config_tracker_track_feedback_with_graph_key(client: LDClient): context = Context.create("user-key") tracker = LDAIConfigTracker( - client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context, graph_key="gx" + ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context, graph_key="gx" ) expected = {**_base_td(), "graphKey": "gx"} tracker.track_feedback({"kind": FeedbackKind.Positive}) @@ -489,7 +479,8 @@ def test_config_tracker_track_feedback_with_graph_key(client: LDClient): def test_config_tracker_track_tool_call(client: LDClient): context = Context.create("user-key") tracker = LDAIConfigTracker( - 
client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context + ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", + version=3, model_name="fakeModel", provider_name="fakeProvider", context=context, ) expected = {**_base_td(), "toolKey": "search"} tracker.track_tool_call("search") @@ -499,7 +490,7 @@ def test_config_tracker_track_tool_call(client: LDClient): def test_config_tracker_track_tool_call_with_graph_key(client: LDClient): context = Context.create("user-key") tracker = LDAIConfigTracker( - client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context, graph_key="my-graph" + ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context, graph_key="my-graph" ) expected = {**_base_td(), "graphKey": "my-graph", "toolKey": "calc"} tracker.track_tool_call("calc") @@ -509,7 +500,7 @@ def test_config_tracker_track_tool_call_with_graph_key(client: LDClient): def test_config_tracker_track_tool_calls(client: LDClient): context = Context.create("user-key") tracker = LDAIConfigTracker( - client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context, graph_key="g" + ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context, graph_key="g" ) tracker.track_tool_calls(["a", "b"]) assert client.track.call_count == 2 # type: ignore @@ -530,7 +521,8 @@ def test_config_tracker_track_tool_calls(client: LDClient): def test_config_tracker_track_metrics_of(client: LDClient): context = Context.create("user-key") tracker = LDAIConfigTracker( - client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context + ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", + version=3, model_name="fakeModel", provider_name="fakeProvider", context=context, ) def fn(): @@ -550,7 +542,7 @@ def extract(r): async def test_config_tracker_track_metrics_of_async_passes_graph_key(client: LDClient): context = Context.create("user-key") tracker = LDAIConfigTracker( - client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context, graph_key="gg" + ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context, graph_key="gg" ) async def fn(): @@ -591,3 +583,327 @@ def test_ai_graph_tracker_track_total_tokens_tracks_when_positive(client: LDClie {"variationKey": "variation-key", "graphKey": "graph-key", "version": 2}, 42, ) + + +# --- At-most-once guard tests --- + + +def test_duplicate_track_duration_is_ignored(client: LDClient): + context = Context.create("user-key") + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) + tracker.track_duration(100) + tracker.track_duration(200) + + assert client.track.call_count == 1 # type: ignore + assert tracker.get_summary().duration == 100 + + +def test_duplicate_track_time_to_first_token_is_ignored(client: LDClient): + context = Context.create("user-key") + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", 
provider_name="fakeProvider", context=context) + tracker.track_time_to_first_token(50) + tracker.track_time_to_first_token(75) + + assert client.track.call_count == 1 # type: ignore + assert tracker.get_summary().time_to_first_token == 50 + + +def test_duplicate_track_tokens_is_ignored(client: LDClient): + context = Context.create("user-key") + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) + tokens1 = TokenUsage(300, 200, 100) + tokens2 = TokenUsage(600, 400, 200) + tracker.track_tokens(tokens1) + tracker.track_tokens(tokens2) + + # 3 track calls for total/input/output from the first call only + assert client.track.call_count == 3 # type: ignore + assert tracker.get_summary().usage == tokens1 + + +def test_duplicate_track_success_is_ignored(client: LDClient): + context = Context.create("user-key") + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) + tracker.track_success() + tracker.track_success() + + assert client.track.call_count == 1 # type: ignore + assert tracker.get_summary().success is True + + +def test_duplicate_track_error_is_ignored(client: LDClient): + context = Context.create("user-key") + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) + tracker.track_error() + tracker.track_error() + + assert client.track.call_count == 1 # type: ignore + assert tracker.get_summary().success is False + + +def test_duplicate_track_feedback_is_ignored(client: LDClient): + context = Context.create("user-key") + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) + tracker.track_feedback({"kind": FeedbackKind.Positive}) + tracker.track_feedback({"kind": FeedbackKind.Negative}) + + assert client.track.call_count == 1 # type: ignore + assert tracker.get_summary().feedback == {"kind": FeedbackKind.Positive} + + +def test_track_data_includes_run_id(client: LDClient): + context = Context.create("user-key") + tracker = LDAIConfigTracker(ld_client=client, run_id="my-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) + tracker.track_success() + + track_data = client.track.call_args[0][2] # type: ignore + assert track_data["runId"] == "my-run-id" + + +def test_run_id_is_consistent_across_track_calls(client: LDClient): + context = Context.create("user-key") + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="config-key", variation_key="variation-key", version=3, model_name="fakeModel", provider_name="fakeProvider", context=context) + tracker.track_success() + tracker.track_duration(100) + + calls = client.track.mock_calls # type: ignore + run_id_1 = calls[0].args[2]["runId"] + run_id_2 = calls[1].args[2]["runId"] + assert run_id_1 == run_id_2 + + +# --- Resumption token tests --- + + +def test_resumption_token_round_trip(client: LDClient): + import base64 + import json + + context = Context.create("user-key") + tracker = LDAIConfigTracker(ld_client=client, 
run_id="test-run-id", config_key="cfg-key", variation_key="var-key", version=5, model_name="gpt-4", provider_name="openai", context=context) + + token = tracker.resumption_token + # Token has no padding — add it back before decoding + padded = token + "=" * (-len(token) % 4) + decoded = json.loads(base64.urlsafe_b64decode(padded.encode("utf-8")).decode("utf-8")) + + assert decoded["runId"] == tracker._run_id + assert decoded["configKey"] == "cfg-key" + assert decoded["variationKey"] == "var-key" + assert decoded["version"] == 5 + # modelName and providerName should NOT be in the token + assert "modelName" not in decoded + assert "providerName" not in decoded + + +def test_resumption_token_has_no_padding(client: LDClient): + context = Context.create("user-key") + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="cfg-key", variation_key="var-key", version=1, model_name="model", provider_name="provider", context=context) + + token = tracker.resumption_token + assert "=" not in token + + +def test_resumption_token_is_url_safe_base64(client: LDClient): + import base64 + + context = Context.create("user-key") + tracker = LDAIConfigTracker(ld_client=client, run_id="test-run-id", config_key="cfg-key", variation_key="var-key", version=1, model_name="model", provider_name="provider", context=context) + + token = tracker.resumption_token + # Should decode without error using urlsafe variant (with padding restored) + padded = token + "=" * (-len(token) % 4) + base64.urlsafe_b64decode(padded.encode("utf-8")) + + +def test_resumption_token_omits_variation_key_when_empty(client: LDClient): + import base64 + import json + + context = Context.create("user-key") + tracker = LDAIConfigTracker( + ld_client=client, run_id="test-run-id", config_key="cfg-key", + variation_key="", version=1, context=context, + model_name="model", provider_name="provider", + ) + + token = tracker.resumption_token + padded = token + "=" * (-len(token) % 4) + decoded = json.loads(base64.urlsafe_b64decode(padded.encode("utf-8")).decode("utf-8")) + + assert "variationKey" not in decoded + assert decoded["runId"] == "test-run-id" + assert decoded["configKey"] == "cfg-key" + assert decoded["version"] == 1 + + +def test_resumption_token_includes_graph_key_when_set(client: LDClient): + import base64 + import json + + context = Context.create("user-key") + tracker = LDAIConfigTracker( + ld_client=client, run_id="test-run-id", config_key="cfg-key", + variation_key="var-key", version=2, context=context, + model_name="model", provider_name="provider", graph_key="my-graph", + ) + + token = tracker.resumption_token + padded = token + "=" * (-len(token) % 4) + decoded = json.loads(base64.urlsafe_b64decode(padded.encode("utf-8")).decode("utf-8")) + + assert decoded["runId"] == "test-run-id" + assert decoded["configKey"] == "cfg-key" + assert decoded["variationKey"] == "var-key" + assert decoded["version"] == 2 + assert decoded["graphKey"] == "my-graph" + # Key order: runId, configKey, variationKey, version, graphKey + assert list(decoded.keys()) == ["runId", "configKey", "variationKey", "version", "graphKey"] + + +def test_resumption_token_omits_graph_key_when_not_set(client: LDClient): + import base64 + import json + + context = Context.create("user-key") + tracker = LDAIConfigTracker( + ld_client=client, run_id="test-run-id", config_key="cfg-key", + variation_key="var-key", version=1, context=context, + model_name="model", provider_name="provider", + ) + + token = tracker.resumption_token + padded = token + "=" * 
(-len(token) % 4) + decoded = json.loads(base64.urlsafe_b64decode(padded.encode("utf-8")).decode("utf-8")) + + assert "graphKey" not in decoded + + +def test_resumption_token_round_trip_with_graph_key(client: LDClient): + context = Context.create("user-key") + tracker = LDAIConfigTracker( + ld_client=client, run_id="test-run-id", config_key="cfg-key", + variation_key="var-key", version=3, context=context, + model_name="model", provider_name="provider", graph_key="my-graph", + ) + + token = tracker.resumption_token + result = LDAIConfigTracker.from_resumption_token(token, client, context) + assert result.is_success() + restored = result.value + + assert restored._run_id == "test-run-id" + assert restored._config_key == "cfg-key" + assert restored._variation_key == "var-key" + assert restored._version == 3 + assert restored._graph_key == "my-graph" + + +def test_tracker_with_explicit_run_id(client: LDClient): + context = Context.create("user-key") + tracker = LDAIConfigTracker( + ld_client=client, run_id="custom-run-id-123", config_key="cfg-key", + variation_key="var-key", version=1, model_name="model", + provider_name="provider", context=context, + ) + tracker.track_success() + + track_data = client.track.call_args[0][2] # type: ignore + assert track_data["runId"] == "custom-run-id-123" + + +def test_client_create_tracker_from_resumption_token(): + from unittest.mock import Mock + + from ldai.client import LDAIClient + + mock_client = Mock() + ai_client = LDAIClient(mock_client) + context = Context.create("feedback-user") + + # Create an original tracker and get its token + original = LDAIConfigTracker( + ld_client=mock_client, run_id="original-run-id-123", + config_key="my-config", variation_key="var-abc", version=7, + model_name="gpt-4", provider_name="openai", + context=Context.create("original-user"), + ) + token = original.resumption_token + + # Reconstruct from token + result = ai_client.create_tracker(token, context) + assert result.is_success() + restored = result.value + + # The restored tracker should use the same runId + restored.track_feedback({"kind": FeedbackKind.Positive}) + + feedback_calls = [ + c for c in mock_client.track.call_args_list + if c.args[0] == "$ld:ai:feedback:user:positive" + ] + assert len(feedback_calls) == 1 + track_data = feedback_calls[0].args[2] + assert track_data["runId"] == original._run_id + assert track_data["configKey"] == "my-config" + assert track_data["variationKey"] == "var-abc" + assert track_data["version"] == 7 + # modelName and providerName are empty when reconstructed from token + assert track_data["modelName"] == "" + assert track_data["providerName"] == "" + # Context should be the new one, not the original + assert feedback_calls[0].args[1] == context + + +def test_client_create_tracker_fails_on_invalid_base64(): + from unittest.mock import Mock + + from ldai.client import LDAIClient + + mock_client = Mock() + ai_client = LDAIClient(mock_client) + context = Context.create("user-key") + + result = ai_client.create_tracker("not-valid-base64!!!", context) + assert not result.is_success() + assert "Invalid resumption token" in result.error + + +def test_client_create_tracker_fails_on_missing_fields(): + import base64 + import json + + from unittest.mock import Mock + + from ldai.client import LDAIClient + + mock_client = Mock() + ai_client = LDAIClient(mock_client) + context = Context.create("user-key") + + # Token missing runId + incomplete = base64.urlsafe_b64encode( + json.dumps({"configKey": "k", "version": 1}).encode() + 
).rstrip(b"=").decode() + + result = ai_client.create_tracker(incomplete, context) + assert not result.is_success() + assert "missing required field 'runId'" in result.error + + +def test_client_create_tracker_fails_on_invalid_json(): + import base64 + + from unittest.mock import Mock + + from ldai.client import LDAIClient + + mock_client = Mock() + ai_client = LDAIClient(mock_client) + context = Context.create("user-key") + + bad_token = base64.urlsafe_b64encode(b"not json").rstrip(b"=").decode() + + result = ai_client.create_tracker(bad_token, context) + assert not result.is_success() + assert "Invalid resumption token" in result.error