|
15 | 15 | # limitations under the License. |
16 | 16 | # |
17 | 17 |
|
18 | | -import threading |
19 | 18 | from typing import List |
20 | 19 |
|
21 | 20 | from skywalking import agent, config |
@@ -111,12 +110,13 @@ def start(self, span: Span): |
111 | 110 | self.spans.append(span) |
112 | 111 |
|
113 | 112 | def stop(self, span: Span) -> bool: |
114 | | - assert span is self.spans[-1] |
| 113 | + idx = self.spans.index(span) # span SHOULD always be at end but in async-world it MAY not be, so don't crash |
115 | 114 |
|
116 | | - span.finish(self.segment) and self.spans.pop() |
| 115 | + if span.finish(self.segment): |
| 116 | + del self.spans[idx] |
117 | 117 |
|
118 | 118 | if len(self.spans) == 0: |
119 | | - _thread_local.context = None |
| 119 | + _local().context = None |
120 | 120 | agent.archive(self.segment) |
121 | 121 |
|
122 | 122 | return len(self.spans) == 0 |
@@ -212,13 +212,38 @@ def continued(self, snapshot: 'Snapshot'): |
212 | 212 | self._correlation.update(snapshot.correlation) |
213 | 213 |
|
214 | 214 |
|
215 | | -_thread_local = threading.local() |
216 | | -_thread_local.context = None |
| 215 | +try: # attempt to use async-local instead of thread-local context |
| 216 | + import contextvars |
| 217 | + |
| 218 | + __local = contextvars.ContextVar('local') |
| 219 | + |
| 220 | + class AsyncLocal: |
| 221 | + pass |
| 222 | + |
| 223 | + def _local(): |
| 224 | + try: |
| 225 | + return __local.get() |
| 226 | + |
| 227 | + except LookupError: |
| 228 | + local = AsyncLocal() |
| 229 | + __local.set(local) |
| 230 | + |
| 231 | + return local |
| 232 | + |
| 233 | +except ImportError: |
| 234 | + import threading |
| 235 | + |
| 236 | + __local = threading.local() |
| 237 | + |
| 238 | + def _local(): |
| 239 | + return __local |
217 | 240 |
|
218 | 241 |
|
219 | 242 | def get_context() -> SpanContext: |
220 | | - if not hasattr(_thread_local, 'context'): |
221 | | - _thread_local.context = None |
222 | | - _thread_local.context = _thread_local.context or (SpanContext() if agent.connected() else NoopContext()) |
| 243 | + local = _local() |
| 244 | + context = getattr(local, 'context', False) |
| 245 | + |
| 246 | + if not context: |
| 247 | + context = local.context = (SpanContext() if agent.connected() else NoopContext()) |
223 | 248 |
|
224 | | - return _thread_local.context |
| 249 | + return context |
0 commit comments