diff --git a/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py b/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py index 7de11cab8d..6cdfda1eba 100644 --- a/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py @@ -190,6 +190,7 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A import fastapi from starlette.applications import Starlette +from starlette.background import BackgroundTask from starlette.middleware.errors import ServerErrorMiddleware from starlette.routing import Match, Route from starlette.types import ASGIApp, Receive, Scope, Send @@ -399,6 +400,16 @@ async def __call__( app, ) + if not hasattr(BackgroundTask, "_otel_original_call"): + BackgroundTask._otel_original_call = BackgroundTask.__call__ + + async def traced_call(self): + span_name = f"BackgroundTask {getattr(self.func, '__name__', self.func.__class__.__name__)}" + with tracer.start_as_current_span(span_name): + return await BackgroundTask._otel_original_call(self) + + BackgroundTask.__call__ = traced_call + app._is_instrumented_by_opentelemetry = True if app not in _InstrumentedFastAPI._instrumented_fastapi_apps: _InstrumentedFastAPI._instrumented_fastapi_apps.add(app) @@ -416,6 +427,11 @@ def uninstrument_app(app: fastapi.FastAPI): app.build_middleware_stack = original_build_middleware_stack del app._original_build_middleware_stack app.middleware_stack = app.build_middleware_stack() + + if hasattr(BackgroundTask, "_otel_original_call"): + BackgroundTask.__call__ = BackgroundTask._otel_original_call + del BackgroundTask._otel_original_call + app._is_instrumented_by_opentelemetry = False # Remove the app from the set of instrumented apps to avoid calling uninstrument twice diff --git a/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py b/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py index aa6189a60e..86362633a5 100644 --- a/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py +++ b/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py @@ -25,11 +25,13 @@ import fastapi import pytest +from fastapi.background import BackgroundTasks from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware from fastapi.responses import JSONResponse, PlainTextResponse from fastapi.routing import APIRoute from fastapi.testclient import TestClient +from starlette.background import BackgroundTask from starlette.routing import Match from starlette.types import Receive, Scope, Send @@ -493,6 +495,51 @@ def test_basic_fastapi_call(self): for span in spans: self.assertIn("GET /foobar", span.name) + def test_background_task_span_parents_inner_spans(self): + """Regression test for #4251: spans created inside a FastAPI + BackgroundTask must be children of a dedicated background-task span + instead of the already-closed request span.""" + self.memory_exporter.clear() + app = fastapi.FastAPI() + self._instrumentor.instrument_app(app) + tracer = self.tracer_provider.get_tracer(__name__) + + async def background_notify(): + with tracer.start_as_current_span("inside-background-task"): + pass + + @app.post("/checkout") + async def checkout(background_tasks: BackgroundTasks): + background_tasks.add_task(background_notify) + return {"status": "processing"} + + with TestClient(app) as client: + response = client.post("/checkout") + self.assertEqual(200, response.status_code) + spans = self.memory_exporter.get_finished_spans() + request_span = next( + span for span in spans if span.name == "POST /checkout" + ) + background_span = next( + span + for span in spans + if span.name == "BackgroundTask background_notify" + ) + inner_span = next( + span for span in spans if span.name == "inside-background-task" + ) + self.assertIsNotNone(background_span.parent) + self.assertEqual( + background_span.parent.span_id, + request_span.context.span_id, + ) + self.assertIsNotNone(inner_span.parent) + self.assertEqual( + inner_span.parent.span_id, + background_span.context.span_id, + ) + otel_fastapi.FastAPIInstrumentor().uninstrument_app(app) + def test_fastapi_route_attribute_added(self): """Ensure that fastapi routes are used as the span name.""" self._client.get("/user/123") @@ -988,6 +1035,49 @@ def test_basic_post_request_metric_success_both_semconv(self): if isinstance(point, NumberDataPoint): self.assertEqual(point.value, 0) + def test_uninstrument_app_restores_background_task_call(self): + """Regression test for #4251: uninstrumentation must restore the + original BackgroundTask.__call__ after FastAPI patches it.""" + self.assertTrue(hasattr(BackgroundTask, "_otel_original_call")) + self._instrumentor.uninstrument_app(self._app) + self.assertFalse(hasattr(BackgroundTask, "_otel_original_call")) + + def test_background_task_span_not_duplicated_on_double_instrument_app( + self, + ): + """Regression test for #4251: repeated instrument_app calls must not + wrap BackgroundTask.__call__ multiple times or duplicate spans.""" + self.memory_exporter.clear() + app = fastapi.FastAPI() + self._instrumentor.instrument_app(app) + self._instrumentor.instrument_app(app) + tracer = self.tracer_provider.get_tracer(__name__) + + async def background_notify(): + with tracer.start_as_current_span("inside-background-task"): + pass + + @app.post("/checkout") + async def checkout(background_tasks: BackgroundTasks): + background_tasks.add_task(background_notify) + return {"status": "processing"} + + with TestClient(app) as client: + response = client.post("/checkout") + self.assertEqual(200, response.status_code) + spans = self.memory_exporter.get_finished_spans() + background_spans = [ + span + for span in spans + if span.name == "BackgroundTask background_notify" + ] + inner_spans = [ + span for span in spans if span.name == "inside-background-task" + ] + self.assertEqual(len(background_spans), 1) + self.assertEqual(len(inner_spans), 1) + otel_fastapi.FastAPIInstrumentor().uninstrument_app(app) + def test_metric_uninstrument_app(self): self._client.get("/foobar") self._instrumentor.uninstrument_app(self._app)