Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Make propagatros conform to spec
* do not modify / set an invalid span in the passed context in case
  a propagator did not manage to extract
* in case no context is passed to propagator.extract assume the root
  context as default so that a new trace is started instead of continung
  the current active trace in case extraction fails
* fix also jaeger propagator which compared int with str trace/span ids
  when checking for validity in extract
  • Loading branch information
mariojonke committed May 4, 2021
commit 373e0e0c7218eac363d5d2c51ed513d5395906f5
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def extract(
used to construct a Context. This object
must be paired with an appropriate getter
which understands how to extract a value from it.
context: an optional Context to use. Defaults to current
context: an optional Context to use. Defaults to root
context if not set.
"""
return get_global_textmap().extract(carrier, context, getter=getter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def extract(
used to construct a Context. This object
must be paired with an appropriate getter
which understands how to extract a value from it.
context: an optional Context to use. Defaults to current
context: an optional Context to use. Defaults to root
context if not set.
Returns:
A Context with configuration found in the carrier.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,31 @@ def extract(

See `opentelemetry.propagators.textmap.TextMapPropagator.extract`
"""
if context is None:
context = Context()

header = getter.get(carrier, self._TRACEPARENT_HEADER_NAME)

if not header:
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context

match = re.search(self._TRACEPARENT_HEADER_FORMAT_RE, header[0])
if not match:
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context

version = match.group(1)
trace_id = match.group(2)
span_id = match.group(3)
trace_flags = match.group(4)

if trace_id == "0" * 32 or span_id == "0" * 16:
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context

if version == "00":
if match.group(5):
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context
if version == "ff":
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context

tracestate_headers = getter.get(carrier, self._TRACESTATE_HEADER_NAME)
if tracestate_headers is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from unittest.mock import Mock, patch

from opentelemetry import trace
from opentelemetry.context import Context
from opentelemetry.trace.propagation import tracecontext
from opentelemetry.trace.span import TraceState

Expand Down Expand Up @@ -270,3 +271,51 @@ def test_fields(self, mock_get_current_span, mock_invalid_span_context):
inject_fields.add(mock_call[1][1])

self.assertEqual(inject_fields, FORMAT.fields)

def test_extract_no_trace_parent_to_explicit_ctx(self):
carrier = {"tracestate": ["foo=1"]}
orig_ctx = Context({"k1": "v1"})

ctx = FORMAT.extract(carrier, orig_ctx)
self.assertDictEqual(orig_ctx, ctx)

def test_extract_no_trace_parent_to_implicit_ctx(self):
carrier = {"tracestate": ["foo=1"]}

ctx = FORMAT.extract(carrier)
self.assertDictEqual(Context(), ctx)

def test_extract_invalid_trace_parent_to_explicit_ctx(self):
trace_parent_headers = [
"invalid",
"00-00000000000000000000000000000000-1234567890123456-00",
"00-12345678901234567890123456789012-0000000000000000-00",
"00-12345678901234567890123456789012-1234567890123456-00-residue",
]
for trace_parent in trace_parent_headers:
with self.subTest(trace_parent=trace_parent):
carrier = {
"traceparent": [trace_parent],
"tracestate": ["foo=1"],
}
orig_ctx = Context({"k1": "v1"})

ctx = FORMAT.extract(carrier, orig_ctx)
self.assertDictEqual(orig_ctx, ctx)

def test_extract_invalid_trace_parent_to_implicit_ctx(self):
trace_parent_headers = [
"invalid",
"00-00000000000000000000000000000000-1234567890123456-00",
"00-12345678901234567890123456789012-0000000000000000-00",
"00-12345678901234567890123456789012-1234567890123456-00-residue",
]
for trace_parent in trace_parent_headers:
with self.subTest(trace_parent=trace_parent):
carrier = {
"traceparent": [trace_parent],
"tracestate": ["foo=1"],
}

ctx = FORMAT.extract(carrier)
self.assertDictEqual(Context(), ctx)
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def extract(
context: typing.Optional[Context] = None,
getter: Getter = default_getter,
) -> Context:
if context is None:
context = Context()
trace_id = trace.INVALID_TRACE_ID
span_id = trace.INVALID_SPAN_ID
sampled = "0"
Expand Down Expand Up @@ -97,8 +99,6 @@ def extract(
or self._trace_id_regex.fullmatch(trace_id) is None
or self._span_id_regex.fullmatch(span_id) is None
):
if context is None:
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context

trace_id = int(trace_id, 16)
Expand Down
87 changes: 68 additions & 19 deletions propagator/opentelemetry-propagator-b3/tests/test_b3_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import opentelemetry.sdk.trace as trace
import opentelemetry.sdk.trace.id_generator as id_generator
import opentelemetry.trace as trace_api
from opentelemetry.context import get_current
from opentelemetry.context import Context, get_current
from opentelemetry.propagators.textmap import DefaultGetter

FORMAT = b3_format.B3Format()
Expand Down Expand Up @@ -219,7 +219,7 @@ def test_flags_and_sampling(self):

def test_derived_ctx_is_returned_for_success(self):
"""Ensure returned context is derived from the given context."""
old_ctx = {"k1": "v1"}
old_ctx = Context({"k1": "v1"})
new_ctx = FORMAT.extract(
{
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
Expand All @@ -229,17 +229,19 @@ def test_derived_ctx_is_returned_for_success(self):
old_ctx,
)
self.assertIn("current-span", new_ctx)
for key, value in old_ctx.items():
for key, value in old_ctx.items(): # pylint:disable=no-member
self.assertIn(key, new_ctx)
# pylint:disable=unsubscriptable-object
self.assertEqual(new_ctx[key], value)

def test_derived_ctx_is_returned_for_failure(self):
"""Ensure returned context is derived from the given context."""
old_ctx = {"k2": "v2"}
old_ctx = Context({"k2": "v2"})
new_ctx = FORMAT.extract({}, old_ctx)
self.assertNotIn("current-span", new_ctx)
for key, value in old_ctx.items():
for key, value in old_ctx.items(): # pylint:disable=no-member
self.assertIn(key, new_ctx)
# pylint:disable=unsubscriptable-object
self.assertEqual(new_ctx[key], value)

def test_64bit_trace_id(self):
Expand All @@ -258,18 +260,24 @@ def test_64bit_trace_id(self):
new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit
)

def test_extract_invalid_single_header(self):
def test_extract_invalid_single_header_to_explicit_ctx(self):
"""Given unparsable header, do not modify context"""
old_ctx = {}
old_ctx = Context({"k1": "v1"})

carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"}
new_ctx = FORMAT.extract(carrier, old_ctx)

self.assertDictEqual(new_ctx, old_ctx)

def test_extract_missing_trace_id(self):
def test_extract_invalid_single_header_to_implicit_ctx(self):
carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"}
new_ctx = FORMAT.extract(carrier)

self.assertDictEqual(Context(), new_ctx)

def test_extract_missing_trace_id_to_explicit_ctx(self):
"""Given no trace ID, do not modify context"""
old_ctx = {}
old_ctx = Context({"k1": "v1"})

carrier = {
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
Expand All @@ -279,9 +287,18 @@ def test_extract_missing_trace_id(self):

self.assertDictEqual(new_ctx, old_ctx)

def test_extract_invalid_trace_id(self):
def test_extract_missing_trace_id_to_implicit_ctx(self):
carrier = {
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
FORMAT.FLAGS_KEY: "1",
}
new_ctx = FORMAT.extract(carrier)

self.assertDictEqual(Context(), new_ctx)

def test_extract_invalid_trace_id_to_explicit_ctx(self):
"""Given invalid trace ID, do not modify context"""
old_ctx = {}
old_ctx = Context({"k1": "v1"})

carrier = {
FORMAT.TRACE_ID_KEY: "abc123",
Expand All @@ -292,9 +309,19 @@ def test_extract_invalid_trace_id(self):

self.assertDictEqual(new_ctx, old_ctx)

def test_extract_invalid_span_id(self):
def test_extract_invalid_trace_id_to_implicit_ctx(self):
carrier = {
FORMAT.TRACE_ID_KEY: "abc123",
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
FORMAT.FLAGS_KEY: "1",
}
new_ctx = FORMAT.extract(carrier)

self.assertDictEqual(Context(), new_ctx)

def test_extract_invalid_span_id_to_explicit_ctx(self):
"""Given invalid span ID, do not modify context"""
old_ctx = {}
old_ctx = Context({"k1": "v1"})

carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
Expand All @@ -305,9 +332,19 @@ def test_extract_invalid_span_id(self):

self.assertDictEqual(new_ctx, old_ctx)

def test_extract_missing_span_id(self):
def test_extract_invalid_span_id_to_implicit_ctx(self):
carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
FORMAT.SPAN_ID_KEY: "abc123",
FORMAT.FLAGS_KEY: "1",
}
new_ctx = FORMAT.extract(carrier)

self.assertDictEqual(Context(), new_ctx)

def test_extract_missing_span_id_to_explicit_ctx(self):
"""Given no span ID, do not modify context"""
old_ctx = {}
old_ctx = Context({"k1": "v1"})

carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
Expand All @@ -317,15 +354,28 @@ def test_extract_missing_span_id(self):

self.assertDictEqual(new_ctx, old_ctx)

def test_extract_empty_carrier(self):
def test_extract_missing_span_id_to_implicit_ctx(self):
carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
FORMAT.FLAGS_KEY: "1",
}
new_ctx = FORMAT.extract(carrier)

self.assertDictEqual(Context(), new_ctx)

def test_extract_empty_carrier_to_explicit_ctx(self):
"""Given no headers at all, do not modify context"""
old_ctx = {}
old_ctx = Context({"k1": "v1"})

carrier = {}
new_ctx = FORMAT.extract(carrier, old_ctx)

self.assertDictEqual(new_ctx, old_ctx)

def test_extract_empty_carrier_to_implicit_ctx(self):
new_ctx = FORMAT.extract({})
self.assertDictEqual(Context(), new_ctx)

@staticmethod
def test_inject_empty_context():
"""If the current context has no span, don't add headers"""
Expand Down Expand Up @@ -368,5 +418,4 @@ def test_extract_none_context(self):

carrier = {}
new_ctx = FORMAT.extract(carrier, old_ctx)
self.assertIsNotNone(new_ctx)
self.assertEqual(new_ctx["current-span"], trace_api.INVALID_SPAN)
self.assertDictEqual(Context(), new_ctx)
Original file line number Diff line number Diff line change
Expand Up @@ -47,31 +47,26 @@ def extract(
) -> Context:

if context is None:
context = get_current()
context = Context()
header = getter.get(carrier, self.TRACE_ID_KEY)
if not header:
return trace.set_span_in_context(trace.INVALID_SPAN, context)
fields = _extract_first_element(header).split(":")
return context

context = self._extract_baggage(getter, carrier, context)
if len(fields) != 4:
return trace.set_span_in_context(trace.INVALID_SPAN, context)

trace_id, span_id, _parent_id, flags = fields
trace_id, span_id, flags = _parse_trace_id_header(header)
if (
trace_id == trace.INVALID_TRACE_ID
or span_id == trace.INVALID_SPAN_ID
):
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context

span = trace.NonRecordingSpan(
trace.SpanContext(
trace_id=int(trace_id, 16),
span_id=int(span_id, 16),
trace_id=trace_id,
span_id=span_id,
is_remote=True,
trace_flags=trace.TraceFlags(
int(flags, 16) & trace.TraceFlags.SAMPLED
),
trace_flags=trace.TraceFlags(flags & trace.TraceFlags.SAMPLED),
)
)
return trace.set_span_in_context(span, context)
Expand Down Expand Up @@ -147,3 +142,35 @@ def _extract_first_element(
if items is None:
return None
return next(iter(items), None)


def _parse_trace_id_header(
items: typing.Iterable[CarrierT],
) -> typing.Tuple[int]:
invalid_header_result = (trace.INVALID_TRACE_ID, trace.INVALID_SPAN_ID, 0)

header = _extract_first_element(items)
if header is None:
return invalid_header_result

fields = header.split(":")
if len(fields) != 4:
return invalid_header_result

trace_id_str, span_id_str, _parent_id_str, flags_str = fields
flags = _int_from_hex_str(flags_str, None)
if flags is None:
return invalid_header_result

trace_id = _int_from_hex_str(trace_id_str, trace.INVALID_TRACE_ID)
span_id = _int_from_hex_str(span_id_str, trace.INVALID_SPAN_ID)
return trace_id, span_id, flags


def _int_from_hex_str(
identifier: str, default: typing.Optional[int]
) -> typing.Optional[int]:
try:
return int(identifier, 16)
except ValueError:
return default
Loading