Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions packages/optimization/src/ldai_optimization/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AIJudgeCallConfig,
GroundTruthOptimizationOptions,
GroundTruthSample,
HandleJudgeCall,
JudgeResult,
OptimizationContext,
OptimizationFromConfigOptions,
Expand Down Expand Up @@ -228,6 +229,11 @@ def _create_optimization_context(
iteration=iteration,
)

@property
def _judge_call(self) -> HandleJudgeCall:
    """Judge handler to invoke for evaluation calls.

    ``handle_judge_call`` is optional on the options object; when it is
    unset, the agent handler doubles as the judge handler (both share the
    same call signature).
    """
    configured = self._options.handle_judge_call
    return configured if configured else self._options.handle_agent_call

def _safe_status_update(
self,
status: Literal[
Expand Down Expand Up @@ -569,10 +575,9 @@ async def _evaluate_config_judge(
LDMessage(role="user", content=judge_user_input),
]

# Collect model parameters from the judge config, separating out any existing tools
model_name = (
judge_config.model.name if judge_config.model else self._options.judge_model
)
# Always use the global judge_model; model parameters (temperature, etc.) from
# the judge flag are still forwarded, but the model name is never overridden.
model_name = self._options.judge_model
model_params: Dict[str, Any] = {}
tools: List[ToolDefinition] = []
if judge_config.model and judge_config.model._parameters:
Expand Down Expand Up @@ -615,8 +620,8 @@ async def _evaluate_config_judge(
)

_judge_start = time.monotonic()
result = self._options.handle_judge_call(
judge_key, judge_call_config, judge_ctx
result = self._judge_call(
judge_key, judge_call_config, judge_ctx, True
)
judge_response: OptimizationResponse = await await_if_needed(result)
judge_duration_ms = (time.monotonic() - _judge_start) * 1000
Expand Down Expand Up @@ -776,8 +781,8 @@ async def _evaluate_acceptance_judge(
)

_judge_start = time.monotonic()
result = self._options.handle_judge_call(
judge_key, judge_call_config, judge_ctx
result = self._judge_call(
judge_key, judge_call_config, judge_ctx, True
)
judge_response: OptimizationResponse = await await_if_needed(result)
judge_duration_ms = (time.monotonic() - _judge_start) * 1000
Expand Down Expand Up @@ -1318,6 +1323,7 @@ async def _generate_new_variation(
self._agent_key,
agent_config,
variation_ctx,
False,
)
variation_response: OptimizationResponse = await await_if_needed(result)
response_str = variation_response.output
Expand Down Expand Up @@ -1717,6 +1723,7 @@ async def _execute_agent_turn(
self._agent_key,
self._build_agent_config_for_context(optimize_context),
optimize_context,
False,
)
agent_response: OptimizationResponse = await await_if_needed(result)
agent_duration_ms = (time.monotonic() - _agent_start) * 1000
Expand Down
33 changes: 21 additions & 12 deletions packages/optimization/src/ldai_optimization/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ async def handle_llm_call(
key: str,
config: LLMCallConfig,
context: LLMCallContext,
is_evaluation: bool,
) -> OptimizationResponse:
model_name = config.model.name if config.model else "gpt-4o"
instructions = config.instructions or ""
Expand All @@ -132,9 +133,12 @@ async def handle_llm_call(
)
"""

key: str
model: Optional[ModelConfig]
instructions: Optional[str]
@property
def key(self) -> str: ...
@property
def model(self) -> Optional[ModelConfig]: ...
@property
def instructions(self) -> Optional[str]: ...


class LLMCallContext(Protocol):
    """Structural (duck-typed) context passed to LLM call handlers.

    Concrete context objects used with ``handle_agent_call`` and
    ``handle_judge_call`` satisfy this Protocol structurally; read-only
    ``@property`` members (rather than plain attributes) keep the Protocol
    compatible with implementations that expose these values as properties.
    """

    @property
    def user_input(self) -> Optional[str]:
        """User input for the current turn, if any."""
        ...

    @property
    def current_variables(self) -> Dict[str, Any]:
        """Interpolated variable values selected for the current turn."""
        ...


@dataclass
Expand Down Expand Up @@ -282,12 +288,12 @@ class OptimizationJudgeContext:
# the concrete types (AIAgentConfig / AIJudgeCallConfig) continue to work
# because those types structurally satisfy the Protocols.
# Handler signatures: (key, config, context, is_evaluation) -> response.
# The trailing bool distinguishes judge/evaluation calls (True) from normal
# agent calls (False), as passed at every call site in client.py.  Handlers
# may be sync or async; ``await_if_needed`` normalizes the result.
HandleAgentCall = Union[
    Callable[[str, LLMCallConfig, LLMCallContext, bool], OptimizationResponse],
    Callable[[str, LLMCallConfig, LLMCallContext, bool], Awaitable[OptimizationResponse]],
]
# Same shape as HandleAgentCall so a single handler can serve both roles
# (handle_judge_call falls back to handle_agent_call when omitted).
HandleJudgeCall = Union[
    Callable[[str, LLMCallConfig, LLMCallContext, bool], OptimizationResponse],
    Callable[[str, LLMCallConfig, LLMCallContext, bool], Awaitable[OptimizationResponse]],
]

_StatusLiteral = Literal[
Expand Down Expand Up @@ -315,7 +321,8 @@ class OptimizationOptions:
] # choices of interpolated variables to be chosen at random per turn, 1 min required
# Actual agent/completion (judge) calls - Required
handle_agent_call: HandleAgentCall
handle_judge_call: HandleJudgeCall
# Optional; falls back to handle_agent_call when omitted (both share the same signature)
handle_judge_call: Optional[HandleJudgeCall] = None
# Criteria for pass/fail - Optional
user_input_options: Optional[List[str]] = (
None # optional list of user input messages to randomly select from
Expand Down Expand Up @@ -401,7 +408,8 @@ class GroundTruthOptimizationOptions:
model_choices: List[str]
judge_model: str
handle_agent_call: HandleAgentCall
handle_judge_call: HandleJudgeCall
# Optional; falls back to handle_agent_call when omitted (both share the same signature)
handle_judge_call: Optional[HandleJudgeCall] = None
judges: Optional[Dict[str, OptimizationJudge]] = None
on_turn: Optional[Callable[[OptimizationContext], bool]] = None
on_sample_result: Optional[Callable[[OptimizationContext], None]] = None
Expand Down Expand Up @@ -461,7 +469,8 @@ class OptimizationFromConfigOptions:

project_key: str
handle_agent_call: HandleAgentCall
handle_judge_call: HandleJudgeCall
# Optional; falls back to handle_agent_call when omitted (both share the same signature)
handle_judge_call: Optional[HandleJudgeCall] = None
on_turn: Optional[Callable[["OptimizationContext"], bool]] = None
on_sample_result: Optional[Callable[["OptimizationContext"], None]] = None
on_passing_result: Optional[Callable[["OptimizationContext"], None]] = None
Expand Down
19 changes: 10 additions & 9 deletions packages/optimization/src/ldai_optimization/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import logging
import re
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
from typing import Any, Awaitable, Dict, List, Optional, Tuple, TypeVar, Union

from ldai_optimization.dataclasses import ToolDefinition

Expand Down Expand Up @@ -156,18 +156,19 @@ def restore_variable_placeholders(
return text, warnings


_T = TypeVar("_T")


async def await_if_needed(result: Union[_T, Awaitable[_T]]) -> _T:
    """Resolve *result*, awaiting it only when it is awaitable.

    Handler callbacks may be implemented either synchronously or as
    coroutines; this normalizes both cases to a plain value.  Using
    ``inspect.isawaitable`` (rather than an ``isinstance(result, str)``
    check) correctly passes through sync results of any type instead of
    awaiting non-string values.

    :param result: Either a plain value or an awaitable producing one.
    :return: The resolved value.
    """
    if inspect.isawaitable(result):
        return await result  # type: ignore[return-value]
    return result  # type: ignore[return-value]


def create_evaluation_tool() -> ToolDefinition:
Expand Down
Loading
Loading