diff --git a/sentry_sdk/integrations/litellm.py b/sentry_sdk/integrations/litellm.py index 28bcc34d3e..6e3d444954 100644 --- a/sentry_sdk/integrations/litellm.py +++ b/sentry_sdk/integrations/litellm.py @@ -230,8 +230,11 @@ def _success_callback( ) finally: - # Always finish the span and clean up - span.__exit__(None, None, None) + is_streaming = kwargs.get("stream") + # Callback is fired multiple times when streaming a response. + # Streaming flag checked at https://github.com/BerriAI/litellm/blob/33c3f13443eaf990ac8c6e3da78bddbc2b7d0e7a/litellm/litellm_core_utils/litellm_logging.py#L1603 + if is_streaming is not True or "complete_streaming_response" in kwargs: + span.__exit__(None, None, None) def _failure_callback( diff --git a/tests/conftest.py b/tests/conftest.py index 71f2431aac..d1cd95fb77 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1063,6 +1063,120 @@ def inner(response_content, serialize_pydantic=False, request_headers=None): return inner +@pytest.fixture +def streaming_chat_completions_model_response(): + return [ + openai.types.chat.ChatCompletionChunk( + id="chatcmpl-test", + object="chat.completion.chunk", + created=10000000, + model="gpt-3.5-turbo", + choices=[ + openai.types.chat.chat_completion_chunk.Choice( + index=0, + delta=openai.types.chat.chat_completion_chunk.ChoiceDelta( + role="assistant" + ), + finish_reason=None, + ), + ], + ), + openai.types.chat.ChatCompletionChunk( + id="chatcmpl-test", + object="chat.completion.chunk", + created=10000000, + model="gpt-3.5-turbo", + choices=[ + openai.types.chat.chat_completion_chunk.Choice( + index=0, + delta=openai.types.chat.chat_completion_chunk.ChoiceDelta( + content="Tes" + ), + finish_reason=None, + ), + ], + ), + openai.types.chat.ChatCompletionChunk( + id="chatcmpl-test", + object="chat.completion.chunk", + created=10000000, + model="gpt-3.5-turbo", + choices=[ + openai.types.chat.chat_completion_chunk.Choice( + index=0, + delta=openai.types.chat.chat_completion_chunk.ChoiceDelta( + content="t r" + ), + finish_reason=None, + ), + ], + ), + openai.types.chat.ChatCompletionChunk( + id="chatcmpl-test", + object="chat.completion.chunk", + created=10000000, + model="gpt-3.5-turbo", + choices=[ + openai.types.chat.chat_completion_chunk.Choice( + index=0, + delta=openai.types.chat.chat_completion_chunk.ChoiceDelta( + content="esp" + ), + finish_reason=None, + ), + ], + ), + openai.types.chat.ChatCompletionChunk( + id="chatcmpl-test", + object="chat.completion.chunk", + created=10000000, + model="gpt-3.5-turbo", + choices=[ + openai.types.chat.chat_completion_chunk.Choice( + index=0, + delta=openai.types.chat.chat_completion_chunk.ChoiceDelta( + content="ons" + ), + finish_reason=None, + ), + ], + ), + openai.types.chat.ChatCompletionChunk( + id="chatcmpl-test", + object="chat.completion.chunk", + created=10000000, + model="gpt-3.5-turbo", + choices=[ + openai.types.chat.chat_completion_chunk.Choice( + index=0, + delta=openai.types.chat.chat_completion_chunk.ChoiceDelta( + content="e" + ), + finish_reason=None, + ), + ], + ), + openai.types.chat.ChatCompletionChunk( + id="chatcmpl-test", + object="chat.completion.chunk", + created=10000000, + model="gpt-3.5-turbo", + choices=[ + openai.types.chat.chat_completion_chunk.Choice( + index=0, + delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(), + finish_reason="stop", + ), + ], + usage=openai.types.CompletionUsage( + prompt_tokens=10, + completion_tokens=20, + total_tokens=30, + ), + ), + ] + + @pytest.fixture def nonstreaming_responses_model_response(): return openai.types.responses.Response( diff --git a/tests/integrations/litellm/test_litellm.py b/tests/integrations/litellm/test_litellm.py index 9022093fa3..3403c2e5a0 100644 --- a/tests/integrations/litellm/test_litellm.py +++ b/tests/integrations/litellm/test_litellm.py @@ -31,10 +31,26 @@ async def __call__(self, *args, **kwargs): ) from sentry_sdk.utils import package_version +from openai import OpenAI + +from concurrent.futures import ThreadPoolExecutor + +import litellm.utils as litellm_utils +from litellm.litellm_core_utils import streaming_handler +from litellm.litellm_core_utils import thread_pool_executor + LITELLM_VERSION = package_version("litellm") +@pytest.fixture() +def reset_litellm_executor(): + yield + thread_pool_executor.executor = ThreadPoolExecutor(max_workers=100) + litellm_utils.executor = thread_pool_executor.executor + streaming_handler.executor = thread_pool_executor.executor + + @pytest.fixture def clear_litellm_cache(): """ @@ -212,7 +228,14 @@ def test_nonstreaming_chat_completion( ], ) def test_streaming_chat_completion( - sentry_init, capture_events, send_default_pii, include_prompts + reset_litellm_executor, + sentry_init, + capture_events, + send_default_pii, + include_prompts, + get_model_response, + server_side_event_chunks, + streaming_chat_completions_model_response, ): sentry_init( integrations=[LiteLLMIntegration(include_prompts=include_prompts)], @@ -222,29 +245,45 @@ def test_streaming_chat_completion( events = capture_events() messages = [{"role": "user", "content": "Hello!"}] - mock_response = MockCompletionResponse() - with start_transaction(name="litellm test"): - kwargs = { - "model": "gpt-3.5-turbo", - "messages": messages, - "stream": True, - } + client = OpenAI(api_key="z") - _input_callback(kwargs) - _success_callback( - kwargs, - mock_response, - datetime.now(), - datetime.now(), - ) + model_response = get_model_response( + server_side_event_chunks( + streaming_chat_completions_model_response, + include_event_type=False, + ), + request_headers={"X-Stainless-Raw-Response": "True"}, + ) + + with mock.patch.object( + client.completions._client._client, + "send", + return_value=model_response, + ): + with start_transaction(name="litellm test"): + response = litellm.completion( + model="gpt-3.5-turbo", + messages=messages, + client=client, + stream=True, + ) + for _ in response: + pass + + streaming_handler.executor.shutdown(wait=True) assert len(events) == 1 (event,) = events assert event["type"] == "transaction" - assert len(event["spans"]) == 1 - (span,) = event["spans"] + chat_spans = list( + x + for x in event["spans"] + if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm" + ) + assert len(chat_spans) == 1 + span = chat_spans[0] assert span["op"] == OP.GEN_AI_CHAT assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True