From 373e0e0c7218eac363d5d2c51ed513d5395906f5 Mon Sep 17 00:00:00 2001 From: Mario Jonke Date: Tue, 4 May 2021 16:34:30 +0200 Subject: [PATCH 1/2] Make propagatros conform to spec * do not modify / set an invalid span in the passed context in case a propagator did not manage to extract * in case no context is passed to propagator.extract assume the root context as default so that a new trace is started instead of continung the current active trace in case extraction fails * fix also jaeger propagator which compared int with str trace/span ids when checking for validity in extract --- .../src/opentelemetry/propagate/__init__.py | 2 +- .../src/opentelemetry/propagators/textmap.py | 2 +- .../trace/propagation/tracecontext.py | 13 +-- .../test_tracecontexthttptextformat.py | 49 +++++++++++ .../opentelemetry/propagators/b3/__init__.py | 4 +- .../tests/test_b3_format.py | 87 +++++++++++++++---- .../propagators/jaeger/__init__.py | 51 ++++++++--- .../tests/test_jaeger_propagator.py | 43 +++++++++ .../src/opentelemetry/test/mock_textmap.py | 11 ++- 9 files changed, 218 insertions(+), 44 deletions(-) diff --git a/opentelemetry-api/src/opentelemetry/propagate/__init__.py b/opentelemetry-api/src/opentelemetry/propagate/__init__.py index 6c63edec3cb..615de99ce33 100644 --- a/opentelemetry-api/src/opentelemetry/propagate/__init__.py +++ b/opentelemetry-api/src/opentelemetry/propagate/__init__.py @@ -96,7 +96,7 @@ def extract( used to construct a Context. This object must be paired with an appropriate getter which understands how to extract a value from it. - context: an optional Context to use. Defaults to current + context: an optional Context to use. Defaults to root context if not set. """ return get_global_textmap().extract(carrier, context, getter=getter) diff --git a/opentelemetry-api/src/opentelemetry/propagators/textmap.py b/opentelemetry-api/src/opentelemetry/propagators/textmap.py index 45c2308f661..0011315cf21 100644 --- a/opentelemetry-api/src/opentelemetry/propagators/textmap.py +++ b/opentelemetry-api/src/opentelemetry/propagators/textmap.py @@ -150,7 +150,7 @@ def extract( used to construct a Context. This object must be paired with an appropriate getter which understands how to extract a value from it. - context: an optional Context to use. Defaults to current + context: an optional Context to use. Defaults to root context if not set. Returns: A Context with configuration found in the carrier. diff --git a/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py index 9fc5cfed242..001db5c7293 100644 --- a/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py @@ -43,14 +43,17 @@ def extract( See `opentelemetry.propagators.textmap.TextMapPropagator.extract` """ + if context is None: + context = Context() + header = getter.get(carrier, self._TRACEPARENT_HEADER_NAME) if not header: - return trace.set_span_in_context(trace.INVALID_SPAN, context) + return context match = re.search(self._TRACEPARENT_HEADER_FORMAT_RE, header[0]) if not match: - return trace.set_span_in_context(trace.INVALID_SPAN, context) + return context version = match.group(1) trace_id = match.group(2) @@ -58,13 +61,13 @@ def extract( trace_flags = match.group(4) if trace_id == "0" * 32 or span_id == "0" * 16: - return trace.set_span_in_context(trace.INVALID_SPAN, context) + return context if version == "00": if match.group(5): - return trace.set_span_in_context(trace.INVALID_SPAN, context) + return context if version == "ff": - return trace.set_span_in_context(trace.INVALID_SPAN, context) + return context tracestate_headers = getter.get(carrier, self._TRACESTATE_HEADER_NAME) if tracestate_headers is None: diff --git a/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py index 98ca50610b9..9d9561a4a5e 100644 --- a/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py +++ b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py @@ -19,6 +19,7 @@ from unittest.mock import Mock, patch from opentelemetry import trace +from opentelemetry.context import Context from opentelemetry.trace.propagation import tracecontext from opentelemetry.trace.span import TraceState @@ -270,3 +271,51 @@ def test_fields(self, mock_get_current_span, mock_invalid_span_context): inject_fields.add(mock_call[1][1]) self.assertEqual(inject_fields, FORMAT.fields) + + def test_extract_no_trace_parent_to_explicit_ctx(self): + carrier = {"tracestate": ["foo=1"]} + orig_ctx = Context({"k1": "v1"}) + + ctx = FORMAT.extract(carrier, orig_ctx) + self.assertDictEqual(orig_ctx, ctx) + + def test_extract_no_trace_parent_to_implicit_ctx(self): + carrier = {"tracestate": ["foo=1"]} + + ctx = FORMAT.extract(carrier) + self.assertDictEqual(Context(), ctx) + + def test_extract_invalid_trace_parent_to_explicit_ctx(self): + trace_parent_headers = [ + "invalid", + "00-00000000000000000000000000000000-1234567890123456-00", + "00-12345678901234567890123456789012-0000000000000000-00", + "00-12345678901234567890123456789012-1234567890123456-00-residue", + ] + for trace_parent in trace_parent_headers: + with self.subTest(trace_parent=trace_parent): + carrier = { + "traceparent": [trace_parent], + "tracestate": ["foo=1"], + } + orig_ctx = Context({"k1": "v1"}) + + ctx = FORMAT.extract(carrier, orig_ctx) + self.assertDictEqual(orig_ctx, ctx) + + def test_extract_invalid_trace_parent_to_implicit_ctx(self): + trace_parent_headers = [ + "invalid", + "00-00000000000000000000000000000000-1234567890123456-00", + "00-12345678901234567890123456789012-0000000000000000-00", + "00-12345678901234567890123456789012-1234567890123456-00-residue", + ] + for trace_parent in trace_parent_headers: + with self.subTest(trace_parent=trace_parent): + carrier = { + "traceparent": [trace_parent], + "tracestate": ["foo=1"], + } + + ctx = FORMAT.extract(carrier) + self.assertDictEqual(Context(), ctx) diff --git a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py index 6977bc32c64..d0beec401a8 100644 --- a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py +++ b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py @@ -50,6 +50,8 @@ def extract( context: typing.Optional[Context] = None, getter: Getter = default_getter, ) -> Context: + if context is None: + context = Context() trace_id = trace.INVALID_TRACE_ID span_id = trace.INVALID_SPAN_ID sampled = "0" @@ -97,8 +99,6 @@ def extract( or self._trace_id_regex.fullmatch(trace_id) is None or self._span_id_regex.fullmatch(span_id) is None ): - if context is None: - return trace.set_span_in_context(trace.INVALID_SPAN, context) return context trace_id = int(trace_id, 16) diff --git a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py index 6ee0be2ce1c..29d8a472ea3 100644 --- a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py +++ b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py @@ -19,7 +19,7 @@ import opentelemetry.sdk.trace as trace import opentelemetry.sdk.trace.id_generator as id_generator import opentelemetry.trace as trace_api -from opentelemetry.context import get_current +from opentelemetry.context import Context, get_current from opentelemetry.propagators.textmap import DefaultGetter FORMAT = b3_format.B3Format() @@ -219,7 +219,7 @@ def test_flags_and_sampling(self): def test_derived_ctx_is_returned_for_success(self): """Ensure returned context is derived from the given context.""" - old_ctx = {"k1": "v1"} + old_ctx = Context({"k1": "v1"}) new_ctx = FORMAT.extract( { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, @@ -229,17 +229,19 @@ def test_derived_ctx_is_returned_for_success(self): old_ctx, ) self.assertIn("current-span", new_ctx) - for key, value in old_ctx.items(): + for key, value in old_ctx.items(): # pylint:disable=no-member self.assertIn(key, new_ctx) + # pylint:disable=unsubscriptable-object self.assertEqual(new_ctx[key], value) def test_derived_ctx_is_returned_for_failure(self): """Ensure returned context is derived from the given context.""" - old_ctx = {"k2": "v2"} + old_ctx = Context({"k2": "v2"}) new_ctx = FORMAT.extract({}, old_ctx) self.assertNotIn("current-span", new_ctx) - for key, value in old_ctx.items(): + for key, value in old_ctx.items(): # pylint:disable=no-member self.assertIn(key, new_ctx) + # pylint:disable=unsubscriptable-object self.assertEqual(new_ctx[key], value) def test_64bit_trace_id(self): @@ -258,18 +260,24 @@ def test_64bit_trace_id(self): new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit ) - def test_extract_invalid_single_header(self): + def test_extract_invalid_single_header_to_explicit_ctx(self): """Given unparsable header, do not modify context""" - old_ctx = {} + old_ctx = Context({"k1": "v1"}) carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} new_ctx = FORMAT.extract(carrier, old_ctx) self.assertDictEqual(new_ctx, old_ctx) - def test_extract_missing_trace_id(self): + def test_extract_invalid_single_header_to_implicit_ctx(self): + carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} + new_ctx = FORMAT.extract(carrier) + + self.assertDictEqual(Context(), new_ctx) + + def test_extract_missing_trace_id_to_explicit_ctx(self): """Given no trace ID, do not modify context""" - old_ctx = {} + old_ctx = Context({"k1": "v1"}) carrier = { FORMAT.SPAN_ID_KEY: self.serialized_span_id, @@ -279,9 +287,18 @@ def test_extract_missing_trace_id(self): self.assertDictEqual(new_ctx, old_ctx) - def test_extract_invalid_trace_id(self): + def test_extract_missing_trace_id_to_implicit_ctx(self): + carrier = { + FORMAT.SPAN_ID_KEY: self.serialized_span_id, + FORMAT.FLAGS_KEY: "1", + } + new_ctx = FORMAT.extract(carrier) + + self.assertDictEqual(Context(), new_ctx) + + def test_extract_invalid_trace_id_to_explicit_ctx(self): """Given invalid trace ID, do not modify context""" - old_ctx = {} + old_ctx = Context({"k1": "v1"}) carrier = { FORMAT.TRACE_ID_KEY: "abc123", @@ -292,9 +309,19 @@ def test_extract_invalid_trace_id(self): self.assertDictEqual(new_ctx, old_ctx) - def test_extract_invalid_span_id(self): + def test_extract_invalid_trace_id_to_implicit_ctx(self): + carrier = { + FORMAT.TRACE_ID_KEY: "abc123", + FORMAT.SPAN_ID_KEY: self.serialized_span_id, + FORMAT.FLAGS_KEY: "1", + } + new_ctx = FORMAT.extract(carrier) + + self.assertDictEqual(Context(), new_ctx) + + def test_extract_invalid_span_id_to_explicit_ctx(self): """Given invalid span ID, do not modify context""" - old_ctx = {} + old_ctx = Context({"k1": "v1"}) carrier = { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, @@ -305,9 +332,19 @@ def test_extract_invalid_span_id(self): self.assertDictEqual(new_ctx, old_ctx) - def test_extract_missing_span_id(self): + def test_extract_invalid_span_id_to_implicit_ctx(self): + carrier = { + FORMAT.TRACE_ID_KEY: self.serialized_trace_id, + FORMAT.SPAN_ID_KEY: "abc123", + FORMAT.FLAGS_KEY: "1", + } + new_ctx = FORMAT.extract(carrier) + + self.assertDictEqual(Context(), new_ctx) + + def test_extract_missing_span_id_to_explicit_ctx(self): """Given no span ID, do not modify context""" - old_ctx = {} + old_ctx = Context({"k1": "v1"}) carrier = { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, @@ -317,15 +354,28 @@ def test_extract_missing_span_id(self): self.assertDictEqual(new_ctx, old_ctx) - def test_extract_empty_carrier(self): + def test_extract_missing_span_id_to_implicit_ctx(self): + carrier = { + FORMAT.TRACE_ID_KEY: self.serialized_trace_id, + FORMAT.FLAGS_KEY: "1", + } + new_ctx = FORMAT.extract(carrier) + + self.assertDictEqual(Context(), new_ctx) + + def test_extract_empty_carrier_to_explicit_ctx(self): """Given no headers at all, do not modify context""" - old_ctx = {} + old_ctx = Context({"k1": "v1"}) carrier = {} new_ctx = FORMAT.extract(carrier, old_ctx) self.assertDictEqual(new_ctx, old_ctx) + def test_extract_empty_carrier_to_implicit_ctx(self): + new_ctx = FORMAT.extract({}) + self.assertDictEqual(Context(), new_ctx) + @staticmethod def test_inject_empty_context(): """If the current context has no span, don't add headers""" @@ -368,5 +418,4 @@ def test_extract_none_context(self): carrier = {} new_ctx = FORMAT.extract(carrier, old_ctx) - self.assertIsNotNone(new_ctx) - self.assertEqual(new_ctx["current-span"], trace_api.INVALID_SPAN) + self.assertDictEqual(Context(), new_ctx) diff --git a/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py b/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py index 47f438531fb..974b9143a5a 100644 --- a/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py +++ b/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py @@ -47,31 +47,26 @@ def extract( ) -> Context: if context is None: - context = get_current() + context = Context() header = getter.get(carrier, self.TRACE_ID_KEY) if not header: - return trace.set_span_in_context(trace.INVALID_SPAN, context) - fields = _extract_first_element(header).split(":") + return context context = self._extract_baggage(getter, carrier, context) - if len(fields) != 4: - return trace.set_span_in_context(trace.INVALID_SPAN, context) - trace_id, span_id, _parent_id, flags = fields + trace_id, span_id, flags = _parse_trace_id_header(header) if ( trace_id == trace.INVALID_TRACE_ID or span_id == trace.INVALID_SPAN_ID ): - return trace.set_span_in_context(trace.INVALID_SPAN, context) + return context span = trace.NonRecordingSpan( trace.SpanContext( - trace_id=int(trace_id, 16), - span_id=int(span_id, 16), + trace_id=trace_id, + span_id=span_id, is_remote=True, - trace_flags=trace.TraceFlags( - int(flags, 16) & trace.TraceFlags.SAMPLED - ), + trace_flags=trace.TraceFlags(flags & trace.TraceFlags.SAMPLED), ) ) return trace.set_span_in_context(span, context) @@ -147,3 +142,35 @@ def _extract_first_element( if items is None: return None return next(iter(items), None) + + +def _parse_trace_id_header( + items: typing.Iterable[CarrierT], +) -> typing.Tuple[int]: + invalid_header_result = (trace.INVALID_TRACE_ID, trace.INVALID_SPAN_ID, 0) + + header = _extract_first_element(items) + if header is None: + return invalid_header_result + + fields = header.split(":") + if len(fields) != 4: + return invalid_header_result + + trace_id_str, span_id_str, _parent_id_str, flags_str = fields + flags = _int_from_hex_str(flags_str, None) + if flags is None: + return invalid_header_result + + trace_id = _int_from_hex_str(trace_id_str, trace.INVALID_TRACE_ID) + span_id = _int_from_hex_str(span_id_str, trace.INVALID_SPAN_ID) + return trace_id, span_id, flags + + +def _int_from_hex_str( + identifier: str, default: typing.Optional[int] +) -> typing.Optional[int]: + try: + return int(identifier, 16) + except ValueError: + return default diff --git a/propagator/opentelemetry-propagator-jaeger/tests/test_jaeger_propagator.py b/propagator/opentelemetry-propagator-jaeger/tests/test_jaeger_propagator.py index 12a0a028ddf..55e096b0954 100644 --- a/propagator/opentelemetry-propagator-jaeger/tests/test_jaeger_propagator.py +++ b/propagator/opentelemetry-propagator-jaeger/tests/test_jaeger_propagator.py @@ -19,6 +19,7 @@ import opentelemetry.sdk.trace.id_generator as id_generator import opentelemetry.trace as trace_api from opentelemetry import baggage +from opentelemetry.context import Context from opentelemetry.propagators import ( # pylint: disable=no-name-in-module jaeger, ) @@ -186,3 +187,45 @@ def test_fields(self): for call in mock_setter.mock_calls: inject_fields.add(call[1][1]) self.assertEqual(FORMAT.fields, inject_fields) + + def test_extract_no_trace_id_to_explicit_ctx(self): + carrier = {} + orig_ctx = Context({"k1": "v1"}) + + ctx = FORMAT.extract(carrier, orig_ctx) + self.assertDictEqual(orig_ctx, ctx) + + def test_extract_no_trace_id_to_implicit_ctx(self): + carrier = {} + + ctx = FORMAT.extract(carrier) + self.assertDictEqual(Context(), ctx) + + def test_extract_invalid_uber_trace_id_header_to_explicit_ctx(self): + trace_id_headers = [ + "000000000000000000000000deadbeef:00000000deadbef0:00", + "00000000000000000000000000000000:00000000deadbef0:00:00", + "000000000000000000000000deadbeef:0000000000000000:00:00", + "000000000000000000000000deadbeef:0000000000000000:00:xyz", + ] + for trace_id_header in trace_id_headers: + with self.subTest(trace_id_header=trace_id_header): + carrier = {"uber-trace-id": trace_id_header} + orig_ctx = Context({"k1": "v1"}) + + ctx = FORMAT.extract(carrier, orig_ctx) + self.assertDictEqual(orig_ctx, ctx) + + def test_extract_invalid_uber_trace_id_header_to_implicit_ctx(self): + trace_id_headers = [ + "000000000000000000000000deadbeef:00000000deadbef0:00", + "00000000000000000000000000000000:00000000deadbef0:00:00", + "000000000000000000000000deadbeef:0000000000000000:00:00", + "000000000000000000000000deadbeef:0000000000000000:00:xyz", + ] + for trace_id_header in trace_id_headers: + with self.subTest(trace_id_header=trace_id_header): + carrier = {"uber-trace-id": trace_id_header} + + ctx = FORMAT.extract(carrier) + self.assertDictEqual(Context(), ctx) diff --git a/tests/util/src/opentelemetry/test/mock_textmap.py b/tests/util/src/opentelemetry/test/mock_textmap.py index 4cdef447d6a..c3e901ee287 100644 --- a/tests/util/src/opentelemetry/test/mock_textmap.py +++ b/tests/util/src/opentelemetry/test/mock_textmap.py @@ -15,7 +15,7 @@ import typing from opentelemetry import trace -from opentelemetry.context import Context, get_current +from opentelemetry.context import Context from opentelemetry.propagators.textmap import ( CarrierT, Getter, @@ -39,7 +39,7 @@ def extract( context: typing.Optional[Context] = None, getter: Getter = default_getter, ) -> Context: - return get_current() + return Context() def inject( self, @@ -66,11 +66,13 @@ def extract( context: typing.Optional[Context] = None, getter: Getter = default_getter, ) -> Context: + if context is None: + context = Context() trace_id_list = getter.get(carrier, self.TRACE_ID_KEY) span_id_list = getter.get(carrier, self.SPAN_ID_KEY) if not trace_id_list or not span_id_list: - return trace.set_span_in_context(trace.INVALID_SPAN) + return context return trace.set_span_in_context( trace.NonRecordingSpan( @@ -79,7 +81,8 @@ def extract( span_id=int(span_id_list[0]), is_remote=True, ) - ) + ), + context, ) def inject( From 0d36115ef9e9509513e2a482fcd68d8fa3230708 Mon Sep 17 00:00:00 2001 From: Mario Jonke Date: Tue, 4 May 2021 17:02:58 +0200 Subject: [PATCH 2/2] changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ae8eb424a67..fef2a064535 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added example for running Django with auto instrumentation ([#1803](https://github.com/open-telemetry/opentelemetry-python/pull/1803)) +### Changed +- Propagators use the root context as default for `extract` and do not modify + the context if extracting from carrier does not work. + ([#1811](https://github.com/open-telemetry/opentelemetry-python/pull/1811)) + ### Removed - Moved `opentelemetry-instrumentation` to contrib repository ([#1797](https://github.com/open-telemetry/opentelemetry-python/pull/1797))