Skip to content

Commit fef9518

Browse files
committed
fix tests?
1 parent 44ae2e9 commit fef9518

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

skada/metrics.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
from sklearn.metrics import balanced_accuracy_score, check_scoring
1616
from sklearn.model_selection import train_test_split
1717
from sklearn.neighbors import KernelDensity
18+
from sklearn.pipeline import Pipeline
1819
from sklearn.preprocessing import LabelEncoder, Normalizer
1920
from sklearn.utils import check_random_state
2021
from sklearn.utils.extmath import softmax
2122
from sklearn.utils.metadata_routing import _MetadataRequester, get_routing_for_object
2223

23-
from skada.deep.base import DomainAwareNet
24-
2524
from ._utils import (
2625
_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
2726
_DEFAULT_MASKED_TARGET_REGRESSION_LABEL,
@@ -397,7 +396,7 @@ def _score(self, estimator, X, y, sample_domain=None, **kwargs):
397396
)
398397

399398
has_transform_method = False
400-
if isinstance(estimator, DomainAwareNet):
399+
if not isinstance(estimator, Pipeline):
401400
# The estimator is a deep model
402401
if estimator.module_.layer_name is None:
403402
raise ValueError("The layer_name of the estimator is not set.")

0 commit comments

Comments
 (0)