Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion sentry_sdk/integrations/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.traces import NoOpStreamedSpan, StreamedSpan
from sentry_sdk.tracing import SOURCE_FOR_STYLE, TransactionSource
from sentry_sdk.tracing_utils import has_span_streaming_enabled
from sentry_sdk.utils import transaction_from_function

from typing import TYPE_CHECKING
Expand All @@ -19,6 +20,7 @@
from sentry_sdk.integrations.starlette import (
StarletteIntegration,
StarletteRequestExtractor,
_set_request_body_data_on_streaming_segment,
)
except DidNotEnable:
raise DidNotEnable("Starlette is not installed")
Expand Down Expand Up @@ -109,7 +111,8 @@ def _sentry_call(*args: "Any", **kwargs: "Any") -> "Any":
old_app = old_get_request_handler(*args, **kwargs)

async def _sentry_app(*args: "Any", **kwargs: "Any") -> "Any":
integration = sentry_sdk.get_client().get_integration(FastApiIntegration)
client = sentry_sdk.get_client()
integration = client.get_integration(FastApiIntegration)
if integration is None:
return await old_app(*args, **kwargs)

Expand Down Expand Up @@ -144,6 +147,9 @@ def event_processor(event: "Event", hint: "Dict[str, Any]") -> "Event":
_make_request_event_processor(request, integration)
)

if has_span_streaming_enabled(client.options):
_set_request_body_data_on_streaming_segment(info)

return await old_app(*args, **kwargs)

return _sentry_app
Expand Down
38 changes: 20 additions & 18 deletions sentry_sdk/integrations/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ async def _sentry_send(*args: "Any", **kwargs: "Any") -> "Any":
return middleware_class


def _serialize_body_data(data: "Any") -> str:
def _serialize_request_body_data(data: "Any") -> str:
# data may be a JSON-serializable value, an AnnotatedValue, or a dict with AnnotatedValue values
def _default(value: "Any") -> "Any":
if isinstance(value, AnnotatedValue):
Expand All @@ -251,6 +251,23 @@ def _default(value: "Any") -> "Any":
return json.dumps(data, default=_default)


def _set_request_body_data_on_streaming_segment(
info: "Optional[Dict[str, Any]]",
) -> None:
current_span = sentry_sdk.get_current_span()
if (
info
and "data" in info
and isinstance(current_span, StreamedSpan)
and not isinstance(current_span, NoOpStreamedSpan)
):
with capture_internal_exceptions():
current_span._segment.set_attribute(
"http.request.body.data",
_serialize_request_body_data(info["data"]),
)


@ensure_integration_enabled(StarletteIntegration)
def _capture_exception(exception: BaseException, handled: "Any" = False) -> None:
event, hint = event_from_exception(
Expand Down Expand Up @@ -517,23 +534,8 @@ def event_processor(
_make_request_event_processor(request, integration)
)

is_span_streaming_enabled = has_span_streaming_enabled(client.options)
if is_span_streaming_enabled:
current_span = sentry_sdk.get_current_span()

if (
info
and "data" in info
and isinstance(current_span, StreamedSpan)
and not isinstance(current_span, NoOpStreamedSpan)
):
data = info["data"]

with capture_internal_exceptions():
current_span._segment.set_attribute(
"http.request.body.data",
_serialize_body_data(data),
)
if has_span_streaming_enabled(client.options):
_set_request_body_data_on_streaming_segment(info)

return await old_func(*args, **kwargs)

Expand Down
138 changes: 138 additions & 0 deletions tests/integrations/fastapi/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from unittest import mock

import fastapi
import starlette
from fastapi import FastAPI, HTTPException, Request
from fastapi.testclient import TestClient
from fastapi.middleware.trustedhost import TrustedHostMiddleware
Expand All @@ -20,6 +21,7 @@


FASTAPI_VERSION = parse_version(fastapi.__version__)
STARLETTE_VERSION = parse_version(starlette.__version__)

from tests.integrations.conftest import parametrize_test_configurable_status_codes
from tests.integrations.starlette import test_starlette
Expand Down Expand Up @@ -245,6 +247,142 @@ def test_active_thread_id_span_streaming(sentry_init, capture_items, endpoint):
assert str(data["active"]) == segments[0]["attributes"]["thread.id"]


def _post_body_fastapi_app(handler_awaitable):
app = FastAPI()

@app.post("/body")
async def _route(request: Request):
await handler_awaitable(request)
return {"ok": True}

return app


@pytest.mark.parametrize("middleware_spans", [False, True])
def test_request_body_data_does_not_scrub_pii_span_streaming(
sentry_init, capture_items, middleware_spans
):
sentry_init(
auto_enabling_integrations=False,
integrations=[
StarletteIntegration(middleware_spans=middleware_spans),
FastApiIntegration(middleware_spans=middleware_spans),
],
traces_sample_rate=1.0,
_experiments={"trace_lifecycle": "stream"},
)

async def _read_json(request):
await request.json()

items = capture_items("span")

client = TestClient(_post_body_fastapi_app(_read_json))
response = client.post(
"/body",
json={
"password": "ohno",
"authorization": "Bearer token",
"message": "hello",
},
)
assert response.status_code == 200

sentry_sdk.flush()

segments = [item.payload for item in items if item.payload.get("is_segment")]
assert len(segments) == 1
attr = segments[0]["attributes"]["http.request.body.data"]

# Going forward, the sanitization of data will need to happen within the `before_send_span` hooks
# See https://sentry.slack.com/archives/C09RR0KD2N7/p1776951331206129?thread_ts=1776951227.440659&cid=C09RR0KD2N7
assert "ohno" in attr
assert "Bearer token" in attr
assert "hello" in attr


@pytest.mark.skipif(
STARLETTE_VERSION < (0, 21),
reason="Requires Starlette >= 0.21, because earlier versions use a requests-based TestClient which does not support the 'content' kwarg",
)
@pytest.mark.parametrize("middleware_spans", [False, True])
def test_request_body_data_annotated_value_top_level_span_streaming(
sentry_init, capture_items, middleware_spans
):
sentry_init(
auto_enabling_integrations=False,
integrations=[
StarletteIntegration(middleware_spans=middleware_spans),
FastApiIntegration(middleware_spans=middleware_spans),
],
traces_sample_rate=1.0,
_experiments={"trace_lifecycle": "stream"},
)

async def _read_body(request):
await request.body()

items = capture_items("span")

client = TestClient(_post_body_fastapi_app(_read_body))
response = client.post(
"/body",
content=b"not json and not form",
headers={"content-type": "application/octet-stream"},
)
assert response.status_code == 200

sentry_sdk.flush()

segments = [item.payload for item in items if item.payload.get("is_segment")]
assert len(segments) == 1
attr = segments[0]["attributes"]["http.request.body.data"]

assert isinstance(attr, str)
assert attr == '""'


@pytest.mark.parametrize("middleware_spans", [False, True])
def test_request_body_data_annotated_value_nested_span_streaming(
sentry_init, capture_items, middleware_spans
):
pytest.importorskip("multipart")

sentry_init(
auto_enabling_integrations=False,
integrations=[
StarletteIntegration(middleware_spans=middleware_spans),
FastApiIntegration(middleware_spans=middleware_spans),
],
traces_sample_rate=1.0,
_experiments={"trace_lifecycle": "stream"},
)

async def _read_form(request):
await request.form()

items = capture_items("span")

client = TestClient(_post_body_fastapi_app(_read_form))
response = client.post(
"/body",
data={"name": "erica"},
files={"avatar": ("photo.jpg", b"fake-bytes", "image/jpeg")},
)
assert response.status_code == 200

sentry_sdk.flush()

segments = [item.payload for item in items if item.payload.get("is_segment")]
assert len(segments) == 1
attr = segments[0]["attributes"]["http.request.body.data"]

assert isinstance(attr, str)
parsed = json.loads(attr)
assert parsed["name"] == "erica"
assert "fake-bytes" not in attr


@pytest.mark.parametrize("span_streaming", [True, False])
@pytest.mark.asyncio
async def test_original_request_not_scrubbed(
Expand Down
Loading