From 0cb02001a17ae41e0c06b247fd955ab97f2e17b7 Mon Sep 17 00:00:00 2001 From: Tomasz Pytel Date: Tue, 24 Nov 2020 18:50:44 -0300 Subject: [PATCH] [Enhancement] Async tasks should work 100% --- skywalking/trace/context.py | 114 +++++++++++++++++++++--------------- 1 file changed, 68 insertions(+), 46 deletions(-) diff --git a/skywalking/trace/context.py b/skywalking/trace/context.py index 9f73bf78..814fca5a 100644 --- a/skywalking/trace/context.py +++ b/skywalking/trace/context.py @@ -15,8 +15,6 @@ # limitations under the License. # -from typing import List - from skywalking import agent, config from skywalking.trace import ID from skywalking.trace.carrier import Carrier @@ -27,9 +25,52 @@ from skywalking.utils.counter import Counter +try: # attempt to use async-local instead of thread-local context and spans + import contextvars + + __local = contextvars.ContextVar('local') + __spans = contextvars.ContextVar('spans') # this needs to be a per-task variable, can't be part of __local + _spans = __spans.get + _spans_set = __spans.set # pyre-ignore + + class AsyncLocal: + pass + + def _local(): + try: + return __local.get() + + except LookupError: + local = AsyncLocal() + __local.set(local) + + return local + + def _spans_dup(): + spans = __spans.get()[:] + __spans.set(spans) + + return spans + +except ImportError: + import threading + + __local = threading.local() + + def _local(): + return __local + + def _spans(): + return __local.spans + + def _spans_set(spans): + __local.spans = spans + + _spans_dup = _spans + + class SpanContext(object): def __init__(self): - self.spans = [] # type: List[Span] self.segment = Segment() # type: Segment self._sid = Counter() self._correlation = {} # type: dict @@ -39,7 +80,8 @@ def new_local_span(self, op: str) -> Span: if span is not None: return span - parent = self.spans[-1] if self.spans else None # type: Span + spans = _spans_dup() + parent = spans[-1] if spans else None # type: Span return Span( context=self, @@ -54,7 +96,8 @@ def new_entry_span(self, op: str, carrier: 'Carrier' = None) -> Span: if span is not None: return span - parent = self.spans[-1] if self.spans else None # type: Span + spans = _spans_dup() + parent = spans[-1] if spans else None # type: Span span = parent if parent is not None and parent.kind.is_entry else EntrySpan( context=self, @@ -73,7 +116,8 @@ def new_exit_span(self, op: str, peer: str, carrier: 'Carrier' = None) -> Span: if span is not None: return span - parent = self.spans[-1] if self.spans else None # type: Span + spans = _spans_dup() + parent = spans[-1] if spans else None # type: Span span = parent if parent is not None and parent.kind.is_exit else ExitSpan( context=self, @@ -106,24 +150,27 @@ def ignore_check(self, op: str, kind: Kind): return None def start(self, span: Span): - if span not in self.spans: - self.spans.append(span) + spans = _spans() + if span not in spans: + spans.append(span) def stop(self, span: Span) -> bool: - idx = self.spans.index(span) # span SHOULD always be at end but in async-world it MAY not be, so don't crash + spans = _spans() + idx = spans.index(span) # span SHOULD now always be at end even in async-world, but just in case if span.finish(self.segment): - del self.spans[idx] + del spans[idx] - if len(self.spans) == 0: + if len(spans) == 0: _local().context = None agent.archive(self.segment) - return len(self.spans) == 0 + return len(spans) == 0 def active_span(self): - if self.spans: - return self.spans[len(self.spans) - 1] + spans = _spans() + if spans: + return spans[len(spans) - 1] return None @@ -146,14 +193,15 @@ def put_correlation(self, key, value): self._correlation[key] = value def capture(self): - if len(self.spans) == 0: + spans = _spans() + if len(spans) == 0: return None return Snapshot( segment_id=str(self.segment.segment_id), span_id=self.active_span().sid, trace_id=self.segment.related_traces[0], - endpoint=self.spans[0].op, + endpoint=spans[0].op, correlation=self._correlation, ) @@ -176,22 +224,22 @@ def __init__(self): self.correlation = {} # type: dict def new_local_span(self, op: str) -> Span: - self._depth += 1 return self._noop_span def new_entry_span(self, op: str, carrier: 'Carrier' = None) -> Span: - self._depth += 1 if carrier is not None: self._noop_span.extract(carrier) return self._noop_span def new_exit_span(self, op: str, peer: str, carrier: 'Carrier' = None) -> Span: - self._depth += 1 if carrier is not None: self._noop_span.inject(carrier) return self._noop_span + def start(self, span: Span): + self._depth += 1 + def stop(self, span: Span) -> bool: self._depth -= 1 return self._depth == 0 @@ -212,38 +260,12 @@ def continued(self, snapshot: 'Snapshot'): self._correlation.update(snapshot.correlation) -try: # attempt to use async-local instead of thread-local context - import contextvars - - __local = contextvars.ContextVar('local') - - class AsyncLocal: - pass - - def _local(): - try: - return __local.get() - - except LookupError: - local = AsyncLocal() - __local.set(local) - - return local - -except ImportError: - import threading - - __local = threading.local() - - def _local(): - return __local - - def get_context() -> SpanContext: local = _local() context = getattr(local, 'context', False) if not context: context = local.context = (SpanContext() if agent.connected() else NoopContext()) + _spans_set([]) # XXX would be better in SpanContext.__init__() but for some reason doesn't work there return context