Skip to content

Commit d90763f

Browse files
committed
handle instrumentation of delegators
1 parent 9c55e99 commit d90763f

File tree

2 files changed

+116
-48
lines changed

2 files changed

+116
-48
lines changed

instrumentation/opentelemetry-instrumentation-sklearn/src/opentelemetry/instrumentation/sklearn/__init__.py

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from sklearn.base import BaseEstimator
6969
from sklearn.pipeline import FeatureUnion, Pipeline
7070
from sklearn.tree import BaseDecisionTree
71+
from sklearn.utils.metaestimators import _IffHasAttrDescriptor
7172

7273
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
7374
from opentelemetry.instrumentation.sklearn.version import __version__
@@ -104,6 +105,55 @@ def wrapper(*args, **kwargs):
104105
return wrapper
105106

106107

108+
def implement_spans_delegator(obj: _IffHasAttrDescriptor):
109+
"""Wrap the descriptor's fn with a span.
110+
111+
Args:
112+
obj: An instance of _IffHasAttrDescriptor
113+
"""
114+
# Don't instrument inherited delegators
115+
if hasattr(obj, "_otel_original_fn"):
116+
logger.debug("Already instrumented: %s", obj.fn.__qualname__)
117+
return
118+
119+
def implement_spans_get(func: Callable):
120+
@wraps(func)
121+
def wrapper(*args, **kwargs):
122+
with get_tracer(__name__, __version__).start_as_current_span(
123+
name=func.__qualname__
124+
):
125+
return func(*args, **kwargs)
126+
127+
return wrapper
128+
129+
logger.debug("Instrumenting: %s", obj.fn.__qualname__)
130+
131+
setattr(obj, "_otel_original_fn", getattr(obj, "fn"))
132+
setattr(obj, "fn", implement_spans_get(obj.fn))
133+
134+
135+
def get_delegator(
136+
estimator: Type[BaseEstimator], method_name: str
137+
) -> Union[_IffHasAttrDescriptor, None]:
138+
"""Get the delegator from a class method or None.
139+
140+
Args:
141+
estimator (BaseEstimator): A class derived from ``sklearn``'s
142+
``BaseEstimator``.
143+
method_name (str): The method name of the estimator on which to
144+
check for delegation.
145+
146+
Returns:
147+
The delegator, if one exists, otherwise None.
148+
"""
149+
class_attr = getattr(estimator, method_name)
150+
if getattr(class_attr, "__closure__", None) is not None:
151+
for cell in class_attr.__closure__:
152+
if isinstance(cell.cell_contents, _IffHasAttrDescriptor):
153+
return cell.cell_contents
154+
return None
155+
156+
107157
def get_base_estimators(packages: List[str]) -> Dict[str, Type[BaseEstimator]]:
108158
"""Walk package hierarchies to get BaseEstimator-derived classes.
109159
@@ -389,7 +439,7 @@ def _check_instrumented(
389439
method_name (str): The method name of the estimator on which to
390440
check for instrumentation.
391441
"""
392-
orig_method_name = "_original_" + method_name
442+
orig_method_name = "_otel_original_" + method_name
393443
has_original = hasattr(estimator, orig_method_name)
394444
orig_class, orig_method = getattr(
395445
estimator, orig_method_name, (None, None)
@@ -419,11 +469,12 @@ def _uninstrument_class_method(
419469
method_name (str): The method name of the estimator on which to
420470
apply a span.
421471
"""
422-
orig_method_name = "_original_" + method_name
472+
orig_method_name = "_otel_original_" + method_name
423473
if isclass(estimator):
424474
qualname = estimator.__qualname__
425475
else:
426476
qualname = estimator.__class__.__qualname__
477+
delegator = get_delegator(estimator, method_name)
427478
if self._check_instrumented(estimator, method_name):
428479
logger.debug(
429480
"Uninstrumenting: %s.%s", qualname, method_name,
@@ -433,6 +484,16 @@ def _uninstrument_class_method(
433484
estimator, method_name, orig_method,
434485
)
435486
delattr(estimator, orig_method_name)
487+
elif delegator is not None:
488+
if not hasattr(delegator, "_otel_original_fn"):
489+
logger.debug(
490+
"Already uninstrumented: %s.%s", qualname, method_name,
491+
)
492+
return
493+
setattr(
494+
delegator, "fn", getattr(delegator, "_otel_original_fn"),
495+
)
496+
delattr(delegator, "_otel_original_fn")
436497
else:
437498
logger.debug(
438499
"Already uninstrumented: %s.%s", qualname, method_name,
@@ -452,7 +513,7 @@ def _uninstrument_instance_method(
452513
method_name (str): The method name of the estimator on which to
453514
apply a span.
454515
"""
455-
orig_method_name = "_original_" + method_name
516+
orig_method_name = "_otel_original_" + method_name
456517
if isclass(estimator):
457518
qualname = estimator.__qualname__
458519
else:
@@ -496,37 +557,25 @@ def _instrument_class_method(
496557
)
497558
return
498559
class_attr = getattr(estimator, method_name)
560+
delegator = get_delegator(estimator, method_name)
499561
if isinstance(class_attr, property):
500562
logger.debug(
501563
"Not instrumenting found property: %s.%s",
502564
estimator.__qualname__,
503565
method_name,
504566
)
567+
elif delegator is not None:
568+
implement_spans_delegator(delegator)
505569
else:
506570
setattr(
507-
estimator, "_original_" + method_name, (estimator, class_attr),
571+
estimator,
572+
"_otel_original_" + method_name,
573+
(estimator, class_attr),
508574
)
509575
setattr(
510576
estimator, method_name, self.spanner(class_attr, estimator),
511577
)
512578

513-
def _function_wrapper(self, function):
514-
"""Get the inner-most decorator of a function."""
515-
if hasattr(function, "__wrapped__"):
516-
if hasattr(function.__wrapped__, "__wrapped__"):
517-
return self._function_wrapper(function.__wrapped__)
518-
return function
519-
return None
520-
521-
def _function_wrapper_wrapper(self, function):
522-
"""Get the second inner-most decorator of a function"""
523-
if hasattr(function, "__wrapped__"):
524-
if hasattr(function.__wrapped__, "__wrapped__"):
525-
if hasattr(function.__wrapped__.__wrapped__, "__wrapped__"):
526-
return self._function_wrapper_wrapper(function.__wrapped__)
527-
return function
528-
return None
529-
530579
def _unwrap_function(self, function):
531580
"""Fetch the function underlying any decorators"""
532581
if hasattr(function, "__wrapped__"):
@@ -564,7 +613,9 @@ def _instrument_instance_method(
564613
)
565614
else:
566615
method = getattr(estimator, method_name)
567-
setattr(estimator, "_original_" + method_name, (estimator, method))
616+
setattr(
617+
estimator, "_otel_original_" + method_name, (estimator, method)
618+
)
568619
setattr(
569620
estimator, method_name, self.spanner(method, estimator),
570621
)

instrumentation/opentelemetry-instrumentation-sklearn/tests/test_sklearn.py

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,54 @@
55
DEFAULT_METHODS,
66
SklearnInstrumentor,
77
get_base_estimators,
8+
get_delegator,
89
)
910
from opentelemetry.test.test_base import TestBase
1011
from opentelemetry.trace import SpanKind
1112

1213
from .fixtures import pipeline, random_input
1314

1415

16+
def assert_instrumented(base_estimators):
17+
for _, estimator in base_estimators.items():
18+
for method_name in DEFAULT_METHODS:
19+
original_method_name = "_otel_original_" + method_name
20+
if issubclass(estimator, tuple(DEFAULT_EXCLUDE_CLASSES)):
21+
assert not hasattr(estimator, original_method_name)
22+
continue
23+
class_attr = getattr(estimator, method_name, None)
24+
if isinstance(class_attr, property):
25+
assert not hasattr(estimator, original_method_name)
26+
continue
27+
delegator = None
28+
if hasattr(estimator, method_name):
29+
delegator = get_delegator(estimator, method_name)
30+
if delegator is not None:
31+
assert hasattr(delegator, "_otel_original_fn")
32+
elif hasattr(estimator, method_name):
33+
assert hasattr(estimator, original_method_name)
34+
35+
36+
def assert_uninstrumented(base_estimators):
37+
for _, estimator in base_estimators.items():
38+
for method_name in DEFAULT_METHODS:
39+
original_method_name = "_otel_original_" + method_name
40+
if issubclass(estimator, tuple(DEFAULT_EXCLUDE_CLASSES)):
41+
assert not hasattr(estimator, original_method_name)
42+
continue
43+
class_attr = getattr(estimator, method_name, None)
44+
if isinstance(class_attr, property):
45+
assert not hasattr(estimator, original_method_name)
46+
continue
47+
delegator = None
48+
if hasattr(estimator, method_name):
49+
delegator = get_delegator(estimator, method_name)
50+
if delegator is not None:
51+
assert not hasattr(delegator, "_otel_original_fn")
52+
elif hasattr(estimator, method_name):
53+
assert not hasattr(estimator, original_method_name)
54+
55+
1556
class TestSklearn(TestBase):
1657
def test_package_instrumentation(self):
1758
ski = SklearnInstrumentor()
@@ -21,42 +62,18 @@ def test_package_instrumentation(self):
2162
model = pipeline()
2263

2364
ski.instrument()
24-
# assert instrumented
25-
for _, estimator in base_estimators.items():
26-
for method_name in DEFAULT_METHODS:
27-
if issubclass(estimator, tuple(DEFAULT_EXCLUDE_CLASSES)):
28-
assert not hasattr(estimator, "_original_" + method_name)
29-
continue
30-
class_attr = getattr(estimator, method_name, None)
31-
if isinstance(class_attr, property):
32-
assert not hasattr(estimator, "_original_" + method_name)
33-
continue
34-
if hasattr(estimator, method_name):
35-
assert hasattr(estimator, "_original_" + method_name)
65+
assert_instrumented(base_estimators)
3666

3767
x_test = random_input()
3868

3969
model.predict(x_test)
4070

4171
spans = self.memory_exporter.get_finished_spans()
42-
for span in spans:
43-
print(span)
4472
self.assertEqual(len(spans), 8)
4573
self.memory_exporter.clear()
4674

4775
ski.uninstrument()
48-
# assert uninstrumented
49-
for _, estimator in base_estimators.items():
50-
for method_name in DEFAULT_METHODS:
51-
if issubclass(estimator, tuple(DEFAULT_EXCLUDE_CLASSES)):
52-
assert not hasattr(estimator, "_original_" + method_name)
53-
continue
54-
class_attr = getattr(estimator, method_name, None)
55-
if isinstance(class_attr, property):
56-
assert not hasattr(estimator, "_original_" + method_name)
57-
continue
58-
if hasattr(estimator, method_name):
59-
assert not hasattr(estimator, "_original_" + method_name)
76+
assert_uninstrumented(base_estimators)
6077

6178
model = pipeline()
6279
x_test = random_input()

0 commit comments

Comments
 (0)