6868from sklearn .base import BaseEstimator
6969from sklearn .pipeline import FeatureUnion , Pipeline
7070from sklearn .tree import BaseDecisionTree
71+ from sklearn .utils .metaestimators import _IffHasAttrDescriptor
7172
7273from opentelemetry .instrumentation .instrumentor import BaseInstrumentor
7374from opentelemetry .instrumentation .sklearn .version import __version__
@@ -104,6 +105,55 @@ def wrapper(*args, **kwargs):
104105 return wrapper
105106
106107
108+ def implement_spans_delegator (obj : _IffHasAttrDescriptor ):
109+ """Wrap the descriptor's fn with a span.
110+
111+ Args:
112+ obj: An instance of _IffHasAttrDescriptor
113+ """
114+ # Don't instrument inherited delegators
115+ if hasattr (obj , "_otel_original_fn" ):
116+ logger .debug ("Already instrumented: %s" , obj .fn .__qualname__ )
117+ return
118+
119+ def implement_spans_get (func : Callable ):
120+ @wraps (func )
121+ def wrapper (* args , ** kwargs ):
122+ with get_tracer (__name__ , __version__ ).start_as_current_span (
123+ name = func .__qualname__
124+ ):
125+ return func (* args , ** kwargs )
126+
127+ return wrapper
128+
129+ logger .debug ("Instrumenting: %s" , obj .fn .__qualname__ )
130+
131+ setattr (obj , "_otel_original_fn" , getattr (obj , "fn" ))
132+ setattr (obj , "fn" , implement_spans_get (obj .fn ))
133+
134+
135+ def get_delegator (
136+ estimator : Type [BaseEstimator ], method_name : str
137+ ) -> Union [_IffHasAttrDescriptor , None ]:
138+ """Get the delegator from a class method or None.
139+
140+ Args:
141+ estimator (BaseEstimator): A class derived from ``sklearn``'s
142+ ``BaseEstimator``.
143+ method_name (str): The method name of the estimator on which to
144+ check for delegation.
145+
146+ Returns:
147+ The delegator, if one exists, otherwise None.
148+ """
149+ class_attr = getattr (estimator , method_name )
150+ if getattr (class_attr , "__closure__" , None ) is not None :
151+ for cell in class_attr .__closure__ :
152+ if isinstance (cell .cell_contents , _IffHasAttrDescriptor ):
153+ return cell .cell_contents
154+ return None
155+
156+
107157def get_base_estimators (packages : List [str ]) -> Dict [str , Type [BaseEstimator ]]:
108158 """Walk package hierarchies to get BaseEstimator-derived classes.
109159
@@ -389,7 +439,7 @@ def _check_instrumented(
389439 method_name (str): The method name of the estimator on which to
390440 check for instrumentation.
391441 """
392- orig_method_name = "_original_ " + method_name
442+ orig_method_name = "_otel_original_ " + method_name
393443 has_original = hasattr (estimator , orig_method_name )
394444 orig_class , orig_method = getattr (
395445 estimator , orig_method_name , (None , None )
@@ -419,11 +469,12 @@ def _uninstrument_class_method(
419469 method_name (str): The method name of the estimator on which to
420470 apply a span.
421471 """
422- orig_method_name = "_original_ " + method_name
472+ orig_method_name = "_otel_original_ " + method_name
423473 if isclass (estimator ):
424474 qualname = estimator .__qualname__
425475 else :
426476 qualname = estimator .__class__ .__qualname__
477+ delegator = get_delegator (estimator , method_name )
427478 if self ._check_instrumented (estimator , method_name ):
428479 logger .debug (
429480 "Uninstrumenting: %s.%s" , qualname , method_name ,
@@ -433,6 +484,16 @@ def _uninstrument_class_method(
433484 estimator , method_name , orig_method ,
434485 )
435486 delattr (estimator , orig_method_name )
487+ elif delegator is not None :
488+ if not hasattr (delegator , "_otel_original_fn" ):
489+ logger .debug (
490+ "Already uninstrumented: %s.%s" , qualname , method_name ,
491+ )
492+ return
493+ setattr (
494+ delegator , "fn" , getattr (delegator , "_otel_original_fn" ),
495+ )
496+ delattr (delegator , "_otel_original_fn" )
436497 else :
437498 logger .debug (
438499 "Already uninstrumented: %s.%s" , qualname , method_name ,
@@ -452,7 +513,7 @@ def _uninstrument_instance_method(
452513 method_name (str): The method name of the estimator on which to
453514 apply a span.
454515 """
455- orig_method_name = "_original_ " + method_name
516+ orig_method_name = "_otel_original_ " + method_name
456517 if isclass (estimator ):
457518 qualname = estimator .__qualname__
458519 else :
@@ -496,37 +557,25 @@ def _instrument_class_method(
496557 )
497558 return
498559 class_attr = getattr (estimator , method_name )
560+ delegator = get_delegator (estimator , method_name )
499561 if isinstance (class_attr , property ):
500562 logger .debug (
501563 "Not instrumenting found property: %s.%s" ,
502564 estimator .__qualname__ ,
503565 method_name ,
504566 )
567+ elif delegator is not None :
568+ implement_spans_delegator (delegator )
505569 else :
506570 setattr (
507- estimator , "_original_" + method_name , (estimator , class_attr ),
571+ estimator ,
572+ "_otel_original_" + method_name ,
573+ (estimator , class_attr ),
508574 )
509575 setattr (
510576 estimator , method_name , self .spanner (class_attr , estimator ),
511577 )
512578
513- def _function_wrapper (self , function ):
514- """Get the inner-most decorator of a function."""
515- if hasattr (function , "__wrapped__" ):
516- if hasattr (function .__wrapped__ , "__wrapped__" ):
517- return self ._function_wrapper (function .__wrapped__ )
518- return function
519- return None
520-
521- def _function_wrapper_wrapper (self , function ):
522- """Get the second inner-most decorator of a function"""
523- if hasattr (function , "__wrapped__" ):
524- if hasattr (function .__wrapped__ , "__wrapped__" ):
525- if hasattr (function .__wrapped__ .__wrapped__ , "__wrapped__" ):
526- return self ._function_wrapper_wrapper (function .__wrapped__ )
527- return function
528- return None
529-
530579 def _unwrap_function (self , function ):
531580 """Fetch the function underlying any decorators"""
532581 if hasattr (function , "__wrapped__" ):
@@ -564,7 +613,9 @@ def _instrument_instance_method(
564613 )
565614 else :
566615 method = getattr (estimator , method_name )
567- setattr (estimator , "_original_" + method_name , (estimator , method ))
616+ setattr (
617+ estimator , "_otel_original_" + method_name , (estimator , method )
618+ )
568619 setattr (
569620 estimator , method_name , self .spanner (method , estimator ),
570621 )
0 commit comments