Skip to content

Commit 28cfe9b

Browse files
authored
[Enhancement] Async task context support (#88)
1 parent 1f7f2ab commit 28cfe9b

2 files changed

Lines changed: 37 additions & 11 deletions

File tree

skywalking/agent/protocol/grpc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616
#
1717

18+
import logging
1819
from skywalking.loggings import logger
1920
import traceback
2021
from queue import Queue
@@ -62,7 +63,7 @@ def connected(self):
6263
return self.state == grpc.ChannelConnectivity.READY
6364

6465
def on_error(self):
65-
traceback.print_exc()
66+
traceback.print_exc() if logger.isEnabledFor(logging.DEBUG) else None
6667
self.channel.unsubscribe(self._cb)
6768
self.channel.subscribe(self._cb, try_to_connect=True)
6869

skywalking/trace/context.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# limitations under the License.
1616
#
1717

18-
import threading
1918
from typing import List
2019

2120
from skywalking import agent, config
@@ -111,12 +110,13 @@ def start(self, span: Span):
111110
self.spans.append(span)
112111

113112
def stop(self, span: Span) -> bool:
114-
assert span is self.spans[-1]
113+
idx = self.spans.index(span) # span SHOULD always be at end but in async-world it MAY not be, so don't crash
115114

116-
span.finish(self.segment) and self.spans.pop()
115+
if span.finish(self.segment):
116+
del self.spans[idx]
117117

118118
if len(self.spans) == 0:
119-
_thread_local.context = None
119+
_local().context = None
120120
agent.archive(self.segment)
121121

122122
return len(self.spans) == 0
@@ -212,13 +212,38 @@ def continued(self, snapshot: 'Snapshot'):
212212
self._correlation.update(snapshot.correlation)
213213

214214

215-
_thread_local = threading.local()
216-
_thread_local.context = None
215+
try: # attempt to use async-local instead of thread-local context
216+
import contextvars
217+
218+
__local = contextvars.ContextVar('local')
219+
220+
class AsyncLocal:
221+
pass
222+
223+
def _local():
224+
try:
225+
return __local.get()
226+
227+
except LookupError:
228+
local = AsyncLocal()
229+
__local.set(local)
230+
231+
return local
232+
233+
except ImportError:
234+
import threading
235+
236+
__local = threading.local()
237+
238+
def _local():
239+
return __local
217240

218241

219242
def get_context() -> SpanContext:
220-
if not hasattr(_thread_local, 'context'):
221-
_thread_local.context = None
222-
_thread_local.context = _thread_local.context or (SpanContext() if agent.connected() else NoopContext())
243+
local = _local()
244+
context = getattr(local, 'context', False)
245+
246+
if not context:
247+
context = local.context = (SpanContext() if agent.connected() else NoopContext())
223248

224-
return _thread_local.context
249+
return context

0 commit comments

Comments
 (0)