Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/plot_how_to_use_skada.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@
PCA(n_components=2),
SelectSource(SVC()),
default_selector=SelectSourceTarget,
mask_target_labels=False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it false here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why SelectSourceTarget is used here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you do that you have one PCA for the source and one for the target, but the SVC is trained only on the source.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we should be able to mask target label with SelectSourceTarget no ? We don't want data leakage even if we have one PCA for source and one for target ?

)

pipe_perdomain.fit(X, y, sample_domain=sample_domain)
Expand Down
2 changes: 2 additions & 0 deletions skada/_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,8 @@ def __init__(
self.max_iter = max_iter
self.tol = tol
self.verbose = verbose
# we predict target labels in this function so we can't mask them
self.predicts_target_labels = True

def fit_transform(self, X, y, sample_domain=None, *, sample_weight=None):
X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
Expand Down
29 changes: 22 additions & 7 deletions skada/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def make_da_pipeline(
memory: Optional[Memory] = None,
verbose: bool = False,
default_selector: Union[str, Callable[[BaseEstimator], BaseSelector]] = "shared",
mask_target_labels: bool = True,
) -> Pipeline:
"""Construct a :class:`~sklearn.pipeline.Pipeline` from the given estimators.

Expand Down Expand Up @@ -59,6 +60,9 @@ def make_da_pipeline(
callable that accepts :class:`~sklearn.base.BaseEstimator` and returns
the estimator encapsulated within a domain selector.

mask_target_labels : bool, default=True
Whether to mask target labels in the pipeline.

Returns
-------
p : Pipeline
Expand Down Expand Up @@ -93,8 +97,9 @@ def make_da_pipeline(
else:
names.append(name)
estimators.append(estimator)

wrapped_estimators = _wrap_with_selectors(estimators, default_selector)
wrapped_estimators = _wrap_with_selectors(
estimators, default_selector, mask_target_labels
)
steps = _name_estimators(wrapped_estimators)
named_steps = [
(auto_name, step) if user_name is None else (user_name, step)
Expand All @@ -107,10 +112,11 @@ def make_da_pipeline(
def _wrap_with_selector(
estimator: BaseEstimator,
selector: Union[str, Callable[[BaseEstimator], BaseSelector]],
mask_target_labels: bool = True,
) -> BaseSelector:
if (estimator is not None) and not isinstance(estimator, BaseSelector):
if callable(selector):
estimator = selector(estimator)
estimator = selector(estimator, mask_target_labels=mask_target_labels)
if not isinstance(estimator, BaseSelector):
raise ValueError(
"Callable `default_selector` has to return `BaseSelector` " # noqa: E501
Expand All @@ -123,7 +129,7 @@ def _wrap_with_selector(
f"Unsupported `default_selector` name: {selector}."
f"Use one of {_DEFAULT_SELECTORS.keys().join(', ')}"
)
estimator = selector_cls(estimator)
estimator = selector_cls(estimator, mask_target_labels=mask_target_labels)
else:
raise ValueError(f"Unsupported `default_selector` type: {type(selector)}")
return estimator
Expand All @@ -132,10 +138,19 @@ def _wrap_with_selector(
def _wrap_with_selectors(
    estimators: List[BaseEstimator],
    default_selector: Union[str, Callable[[BaseEstimator], BaseSelector]],
    mask_target_labels: bool = True,
) -> List[BaseEstimator]:
    """Wrap each estimator of the pipeline into a domain selector.

    Parameters
    ----------
    estimators : list of BaseEstimator
        Pipeline steps, in order.
    default_selector : str or callable
        Selector name or factory used to wrap estimators that are not
        already :class:`BaseSelector` instances.
    mask_target_labels : bool, default=True
        Whether target labels should be masked for the wrapped steps.

    Returns
    -------
    list of BaseEstimator
        The wrapped estimators, in the same order.
    """
    wrapped = []
    for estimator in estimators:
        if getattr(estimator, "predicts_target_labels", False):
            # This step fills in (predicts) target labels, so masking must
            # be disabled for it. The flag deliberately stays False for all
            # downstream steps as well: the labels flowing through the rest
            # of the pipeline are the predicted ones and must not be masked
            # away. NOTE(review): confirm the latch-for-downstream behavior
            # is intended, not an accidental loop-variable mutation.
            mask_target_labels = False

        wrapped.append(
            _wrap_with_selector(
                estimator, default_selector, mask_target_labels=mask_target_labels
            )
        )
    return wrapped


def _name_estimators(estimators):
Expand Down
14 changes: 9 additions & 5 deletions skada/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,15 @@
params : dict
Additional parameters declared in the routing
"""
y_type = _find_y_type(y)
if y_type == Y_Type.DISCRETE:
unmasked_idx = y != _DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL
elif y_type == Y_Type.CONTINUOUS:
unmasked_idx = np.isfinite(y)
if "sample_domain" in params:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With these two lines, we avoid semi-supervised DA. I think it's a residue from before, no?

unmasked_idx = params["sample_domain"] >= 0

Check warning on line 165 in skada/_utils.py

View check run for this annotation

Codecov / codecov/patch

skada/_utils.py#L165

Added line #L165 was not covered by tests
else:
y_type = _find_y_type(y)
if y_type == Y_Type.DISCRETE:
unmasked_idx = y != _DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL
elif y_type == Y_Type.CONTINUOUS:
unmasked_idx = np.isfinite(y)

X, y, params = _apply_domain_masks(X, y, params, masks=unmasked_idx)
return X, y, params

Expand Down
43 changes: 40 additions & 3 deletions skada/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
_apply_domain_masks,
_merge_domain_outputs,
_remove_masked,
_route_params
_route_params,
_find_y_type,
Y_Type,
_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
_DEFAULT_MASKED_TARGET_REGRESSION_LABEL,
)
from skada.utils import check_X_domain, check_X_y_domain, extract_source_indices

Expand Down Expand Up @@ -202,12 +206,13 @@

__metadata_request__transform = {'sample_domain': True}

def __init__(self, base_estimator: BaseEstimator, **kwargs):
def __init__(self, base_estimator: BaseEstimator, mask_target_labels: bool = True, **kwargs):
    """Wrap ``base_estimator`` for use inside a DA pipeline.

    Parameters
    ----------
    base_estimator : BaseEstimator
        The estimator being wrapped by this selector.
    mask_target_labels : bool, default=True
        Whether target-domain labels are masked before fitting the
        base estimator.
    **kwargs : dict
        Extra parameters forwarded verbatim to
        ``base_estimator.set_params``.
    """
    super().__init__()
    self.base_estimator = base_estimator
    # Forward any extra constructor kwargs straight to the wrapped estimator.
    self.base_estimator.set_params(**kwargs)
    # Set externally by the pipeline builder when this selector is the
    # last step — presumably controls predict vs. transform routing
    # (TODO confirm against the pipeline code).
    self._is_final = False
    # Duck-typed check: treat the wrapped object as a transformer iff it
    # exposes a `transform` method.
    self._is_transformer = hasattr(base_estimator, 'transform')
    self.mask_target_labels = mask_target_labels

def get_metadata_routing(self):
return (
Expand Down Expand Up @@ -342,6 +347,16 @@
routed_params = {k: params[k] for k in routing_request._consumes(params=params)}
return routed_params

def _auto_mask_target_labels(self, y, routed_params):
    """Return ``y`` with target-domain labels replaced by the mask sentinel.

    Parameters
    ----------
    y : ndarray or None
        Labels for all samples (source and target). Assumed to support
        boolean-mask assignment (i.e. a numpy array) — TODO confirm.
    routed_params : dict
        Routed parameters; when ``'sample_domain'`` is present it marks
        which samples belong to the source vs. target domains.

    Returns
    -------
    y : ndarray or None
        A masked copy of the labels, or the input unchanged when ``y`` or
        ``sample_domain`` is absent.
    """
    if y is not None and routed_params.get('sample_domain') is not None:
        # Mask a copy: the previous implementation wrote through `y` in
        # place, silently destroying the caller's target labels.
        y = y.copy()
        y_type = _find_y_type(y)
        source_idx = extract_source_indices(routed_params['sample_domain'])
        if y_type == Y_Type.DISCRETE:
            y[~source_idx] = _DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL
        elif y_type == Y_Type.CONTINUOUS:
            y[~source_idx] = _DEFAULT_MASKED_TARGET_REGRESSION_LABEL
    return y

def _remove_masked(self, X, y, routed_params):
"""Removes masked inputs before passing them to a downstream (base) estimator,
ensuring their compatibility with the DA pipeline, particularly for estimators
Expand Down Expand Up @@ -409,6 +424,9 @@

# xxx(okachaiev): solve the problem with parameter renaming
def _fit(self, routing_method, X_container, y=None, **params):
if self.mask_target_labels:
y = self._auto_mask_target_labels(y, params)

X, y, params = X_container.merge_out(y, **params)
routing = get_routing_for_object(self.base_estimator)
routing_request = getattr(routing, routing_method)
Expand Down Expand Up @@ -446,6 +464,9 @@
return self

def _fit(self, method_name, X_container, y, **params):
if self.mask_target_labels:
y = self._auto_mask_target_labels(y, params)

X, y, params = X_container.merge_out(y, **params)
sample_domain = params['sample_domain']
routing = get_routing_for_object(self.base_estimator)
Expand Down Expand Up @@ -473,6 +494,9 @@
domain_outputs = self._fit('fit_transform', X_container, y=y, **params)
output = _merge_domain_outputs(len(X_container), domain_outputs, allow_containers=True)
else:
if self.mask_target_labels:
y = self._auto_mask_target_labels(y, params)

Check warning on line 498 in skada/base.py

View check run for this annotation

Codecov / codecov/patch

skada/base.py#L497-L498

Added lines #L497 - L498 were not covered by tests

self._fit(X_container, y, **params)
X, y, method_params = X_container.merge_out(y, **params)
transform_params = _route_params(self.routing_.transform, method_params, self)
Expand Down Expand Up @@ -563,18 +587,31 @@
class SelectTarget(_BaseSelectDomain):
    """Selects only target domains for fitting base estimator."""

    def __init__(self, base_estimator: BaseEstimator, mask_target_labels: bool = False, **kwargs):
        # Target labels are never masked here: this selector fits the base
        # estimator on the target domain only, so the target labels are
        # exactly what it must be allowed to see. Masking them would leave
        # nothing to train on. The parameter exists only for signature
        # compatibility with the other selectors and must stay False.

        if mask_target_labels:
            raise ValueError("Target labels cannot be masked for SelectTarget.")

        super().__init__(base_estimator, mask_target_labels=mask_target_labels, **kwargs)

    def _select_indices(self, sample_domain):
        # Source samples are marked by `extract_source_indices`; invert the
        # boolean mask to keep only target-domain samples.
        return ~extract_source_indices(sample_domain)


class SelectSourceTarget(BaseSelector):

def __init__(self, source_estimator: BaseEstimator, target_estimator: Optional[BaseEstimator] = None):
def __init__(self, source_estimator: BaseEstimator, target_estimator: Optional[BaseEstimator] = None, mask_target_labels: bool = False, **kwargs):
if target_estimator is not None \
and hasattr(source_estimator, 'transform') \
and not hasattr(target_estimator, 'transform'):
raise TypeError("The provided source and target estimators must "
"both be transformers, or neither should be.")

if mask_target_labels:
raise ValueError("Target labels cannot be masked for SelectSourceTarget.")

self.source_estimator = source_estimator
self.target_estimator = target_estimator
# xxx(okachaiev): the fact that we need to put those variables
Expand Down
Loading