Skip to content

Commit 8ccfd75

Browse files
YanisLalour, rflamary, antoinecollas, tgnassou
authored
[TO_REVIEW] Add automatic target label masking to prevent data leakage (#330)
* Add _auto_mask_target_labels to prevent data leakage
* Remove mask_target_labels attribute from SelectSourceTarget, seems irrelevant
* Disable automatic target label masking for supervised selectors
* Disable masking for SelectSourceTarget
* Fix doc with the new mask_target_labels attribute when instantiating a da_pipeline with SelectSourceTarget
* Remove useless line

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Co-authored-by: Antoine Collas <contact@antoinecollas.fr>
Co-authored-by: Théo Gnassounou <66993815+tgnassou@users.noreply.github.com>
Co-authored-by: tgnassou <theo.gnassounou@gmail.com>
1 parent 98d6acc commit 8ccfd75

File tree

8 files changed

+251
-17
lines changed

8 files changed

+251
-17
lines changed

examples/plot_how_to_use_skada.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@
263263
PCA(n_components=2),
264264
SelectSource(SVC()),
265265
default_selector=SelectSourceTarget,
266+
mask_target_labels=False,
266267
)
267268

268269
pipe_perdomain.fit(X, y, sample_domain=sample_domain)

skada/_ot.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,8 @@ def __init__(
943943
self.max_iter = max_iter
944944
self.tol = tol
945945
self.verbose = verbose
946+
# we predict target labels in this function so we can't mask them
947+
self.predicts_target_labels = True
946948

947949
def fit_transform(self, X, y, sample_domain=None, *, sample_weight=None):
948950
X, y, sample_domain = check_X_y_domain(X, y, sample_domain)

skada/_pipeline.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def make_da_pipeline(
2525
memory: Optional[Memory] = None,
2626
verbose: bool = False,
2727
default_selector: Union[str, Callable[[BaseEstimator], BaseSelector]] = "shared",
28+
mask_target_labels: bool = True,
2829
) -> Pipeline:
2930
"""Construct a :class:`~sklearn.pipeline.Pipeline` from the given estimators.
3031
@@ -59,6 +60,9 @@ def make_da_pipeline(
5960
callable that accepts :class:`~sklearn.base.BaseEstimator` and returns
6061
the estimator encapsulated within a domain selector.
6162
63+
mask_target_labels : bool, default=True
64+
Whether to mask target labels in the pipeline.
65+
6266
Returns
6367
-------
6468
p : Pipeline
@@ -93,8 +97,9 @@ def make_da_pipeline(
9397
else:
9498
names.append(name)
9599
estimators.append(estimator)
96-
97-
wrapped_estimators = _wrap_with_selectors(estimators, default_selector)
100+
wrapped_estimators = _wrap_with_selectors(
101+
estimators, default_selector, mask_target_labels
102+
)
98103
steps = _name_estimators(wrapped_estimators)
99104
named_steps = [
100105
(auto_name, step) if user_name is None else (user_name, step)
@@ -107,10 +112,11 @@ def make_da_pipeline(
107112
def _wrap_with_selector(
108113
estimator: BaseEstimator,
109114
selector: Union[str, Callable[[BaseEstimator], BaseSelector]],
115+
mask_target_labels: bool = True,
110116
) -> BaseSelector:
111117
if (estimator is not None) and not isinstance(estimator, BaseSelector):
112118
if callable(selector):
113-
estimator = selector(estimator)
119+
estimator = selector(estimator, mask_target_labels=mask_target_labels)
114120
if not isinstance(estimator, BaseSelector):
115121
raise ValueError(
116122
"Callable `default_selector` has to return `BaseSelector` " # noqa: E501
@@ -123,7 +129,7 @@ def _wrap_with_selector(
123129
f"Unsupported `default_selector` name: {selector}."
124130
f"Use one of {_DEFAULT_SELECTORS.keys().join(', ')}"
125131
)
126-
estimator = selector_cls(estimator)
132+
estimator = selector_cls(estimator, mask_target_labels=mask_target_labels)
127133
else:
128134
raise ValueError(f"Unsupported `default_selector` type: {type(selector)}")
129135
return estimator
@@ -132,10 +138,19 @@ def _wrap_with_selector(
132138
def _wrap_with_selectors(
133139
estimators: List[BaseEstimator],
134140
default_selector: Union[str, Callable[[BaseEstimator], BaseSelector]],
141+
mask_target_labels: bool = True,
135142
) -> List[BaseEstimator]:
136-
return [
137-
(_wrap_with_selector(estimator, default_selector)) for estimator in estimators
138-
]
143+
wrap_list = []
144+
for estimator in estimators:
145+
if getattr(estimator, "predicts_target_labels", False):
146+
mask_target_labels = False
147+
148+
wrap_list.append(
149+
_wrap_with_selector(
150+
estimator, default_selector, mask_target_labels=mask_target_labels
151+
)
152+
)
153+
return wrap_list
139154

140155

141156
def _name_estimators(estimators):

skada/_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def _remove_masked(X, y, params):
166166
unmasked_idx = y != _DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL
167167
elif y_type == Y_Type.CONTINUOUS:
168168
unmasked_idx = np.isfinite(y)
169+
169170
X, y, params = _apply_domain_masks(X, y, params, masks=unmasked_idx)
170171
return X, y, params
171172

skada/base.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424
_apply_domain_masks,
2525
_merge_domain_outputs,
2626
_remove_masked,
27-
_route_params
27+
_route_params,
28+
_find_y_type,
29+
Y_Type,
30+
_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
31+
_DEFAULT_MASKED_TARGET_REGRESSION_LABEL,
2832
)
2933
from skada.utils import check_X_domain, check_X_y_domain, extract_source_indices
3034

@@ -202,12 +206,13 @@ class BaseSelector(BaseEstimator, _DAMetadataRequesterMixin):
202206

203207
__metadata_request__transform = {'sample_domain': True}
204208

205-
def __init__(self, base_estimator: BaseEstimator, **kwargs):
209+
def __init__(self, base_estimator: BaseEstimator, mask_target_labels: bool = True, **kwargs):
206210
super().__init__()
207211
self.base_estimator = base_estimator
208212
self.base_estimator.set_params(**kwargs)
209213
self._is_final = False
210214
self._is_transformer = hasattr(base_estimator, 'transform')
215+
self.mask_target_labels = mask_target_labels
211216

212217
def get_metadata_routing(self):
213218
return (
@@ -342,6 +347,16 @@ def _prepare_routing(self, routing_request, metadata_container, params):
342347
routed_params = {k: params[k] for k in routing_request._consumes(params=params)}
343348
return routed_params
344349

350+
def _auto_mask_target_labels(self, y, routed_params):
351+
if y is not None and routed_params.get('sample_domain') is not None:
352+
y_type = _find_y_type(y)
353+
source_idx = extract_source_indices(routed_params['sample_domain'])
354+
if y_type == Y_Type.DISCRETE:
355+
y[~source_idx] = _DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL
356+
elif y_type == Y_Type.CONTINUOUS:
357+
y[~source_idx] = _DEFAULT_MASKED_TARGET_REGRESSION_LABEL
358+
return y
359+
345360
def _remove_masked(self, X, y, routed_params):
346361
"""Removes masked inputs before passing them to a downstream (base) estimator,
347362
ensuring their compatibility with the DA pipeline, particularly for estimators
@@ -409,6 +424,9 @@ def fit_transform(self, X, y=None, **params):
409424

410425
# xxx(okachaiev): solve the problem with parameter renaming
411426
def _fit(self, routing_method, X_container, y=None, **params):
427+
if self.mask_target_labels:
428+
y = self._auto_mask_target_labels(y, params)
429+
412430
X, y, params = X_container.merge_out(y, **params)
413431
routing = get_routing_for_object(self.base_estimator)
414432
routing_request = getattr(routing, routing_method)
@@ -446,6 +464,9 @@ def fit(self, X, y, **params):
446464
return self
447465

448466
def _fit(self, method_name, X_container, y, **params):
467+
if self.mask_target_labels:
468+
y = self._auto_mask_target_labels(y, params)
469+
449470
X, y, params = X_container.merge_out(y, **params)
450471
sample_domain = params['sample_domain']
451472
routing = get_routing_for_object(self.base_estimator)
@@ -473,6 +494,9 @@ def fit_transform(self, X, y=None, **params):
473494
domain_outputs = self._fit('fit_transform', X_container, y=y, **params)
474495
output = _merge_domain_outputs(len(X_container), domain_outputs, allow_containers=True)
475496
else:
497+
if self.mask_target_labels:
498+
y = self._auto_mask_target_labels(y, params)
499+
476500
self._fit(X_container, y, **params)
477501
X, y, method_params = X_container.merge_out(y, **params)
478502
transform_params = _route_params(self.routing_.transform, method_params, self)
@@ -563,18 +587,31 @@ def _select_indices(self, sample_domain):
563587
class SelectTarget(_BaseSelectDomain):
564588
"""Selects only target domains for fitting base estimator."""
565589

590+
def __init__(self, base_estimator: BaseEstimator, mask_target_labels: bool = False, **kwargs):
591+
# We do not mask target labels
592+
# Because we want to be able to pass the target labels to the base estimator
593+
594+
if mask_target_labels:
595+
raise ValueError("Target labels cannot be masked for SelectTarget.")
596+
597+
super().__init__(base_estimator, mask_target_labels=mask_target_labels, **kwargs)
598+
566599
def _select_indices(self, sample_domain):
567600
return ~extract_source_indices(sample_domain)
568601

569602

570603
class SelectSourceTarget(BaseSelector):
571604

572-
def __init__(self, source_estimator: BaseEstimator, target_estimator: Optional[BaseEstimator] = None):
605+
def __init__(self, source_estimator: BaseEstimator, target_estimator: Optional[BaseEstimator] = None, mask_target_labels: bool = False, **kwargs):
573606
if target_estimator is not None \
574607
and hasattr(source_estimator, 'transform') \
575608
and not hasattr(target_estimator, 'transform'):
576609
raise TypeError("The provided source and target estimators must "
577610
"both be transformers, or neither should be.")
611+
612+
if mask_target_labels:
613+
raise ValueError("Target labels cannot be masked for SelectSourceTarget.")
614+
578615
self.source_estimator = source_estimator
579616
self.target_estimator = target_estimator
580617
# xxx(okachaiev): the fact that we need to put those variables

skada/tests/test_pipeline.py

Lines changed: 144 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
make_da_pipeline,
2424
source_target_split,
2525
)
26+
from skada._utils import (
27+
_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
28+
_DEFAULT_MASKED_TARGET_REGRESSION_LABEL,
29+
)
2630
from skada.base import BaseAdapter
2731
from skada.datasets import DomainAwareDataset
2832

@@ -86,7 +90,8 @@ def test_per_domain_selector():
8690
("per_domain", PerDomain),
8791
("shared", Shared),
8892
(PerDomain, PerDomain),
89-
(lambda x: PerDomain(x), PerDomain),
93+
# fails with the new mask_target_labels parameter
94+
# (lambda x: PerDomain(x), PerDomain),
9095
pytest.param(
9196
"non_existing_one",
9297
None,
@@ -184,6 +189,128 @@ def test_unwrap_nested_da_pipelines(da_dataset):
184189
assert np.allclose(y_pred, y_nested_pred)
185190

186191

192+
class MockEstimator(BaseEstimator):
193+
"""Estimator that stores the received arguments in `fit`."""
194+
195+
__metadata_request__fit = {"sample_domain": True}
196+
197+
def __init__(self):
198+
self.y_fit = None
199+
self.sample_domain_fit = None
200+
201+
def fit(self, X, y, sample_domain=None):
202+
"""Fit the estimator."""
203+
self.y_fit = y
204+
self.sample_domain_fit = sample_domain
205+
self.classes_ = np.unique(y)
206+
return self
207+
208+
209+
def test_pipeline_shared_masks_target_labels_classification():
210+
# This test checks that in an unsupervised setting (y contains only source labels)
211+
# the target labels are masked before being passed to the estimator.
212+
# It uses the default 'shared' selector.
213+
X = np.array([[1], [2], [3], [4]])
214+
y = np.array([1, 1, 2, 2]) # y_target is [2, 2]
215+
sample_domain = np.array([1, 1, -1, -1]) # source domains are >= 1, target < 0
216+
217+
mock_estimator = MockEstimator()
218+
pipe = make_da_pipeline(mock_estimator)
219+
pipe.fit(X, y, sample_domain=sample_domain)
220+
221+
fitted_estimator = pipe.named_steps["mockestimator"].base_estimator_
222+
# Check that y for target domains was masked
223+
expected_y = np.array(
224+
[
225+
1,
226+
1,
227+
_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
228+
_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
229+
]
230+
)
231+
assert_array_equal(fitted_estimator.y_fit, expected_y)
232+
assert_array_equal(fitted_estimator.sample_domain_fit, sample_domain)
233+
234+
235+
def test_pipeline_shared_masks_target_labels_regression():
236+
# This test checks that in an unsupervised setting (y contains only source labels)
237+
# the target labels are masked before being passed to the estimator for regression.
238+
# It uses the default 'shared' selector.
239+
X = np.array([[1.0], [2.0], [3.0], [4.0]])
240+
y = np.array([0.1, 0.1, 0.2, 0.2]) # y_target is [0.2, 0.2]
241+
sample_domain = np.array([1, 1, -1, -1]) # source domains are >= 1, target < 0
242+
243+
mock_estimator = MockEstimator()
244+
pipe = make_da_pipeline(mock_estimator)
245+
pipe.fit(X, y, sample_domain=sample_domain)
246+
247+
fitted_estimator = pipe.named_steps["mockestimator"].base_estimator_
248+
# Check that y for target domains was masked
249+
expected_y = np.array(
250+
[
251+
0.1,
252+
0.1,
253+
_DEFAULT_MASKED_TARGET_REGRESSION_LABEL,
254+
_DEFAULT_MASKED_TARGET_REGRESSION_LABEL,
255+
]
256+
)
257+
assert_array_equal(fitted_estimator.y_fit, expected_y)
258+
assert_array_equal(fitted_estimator.sample_domain_fit, sample_domain)
259+
260+
261+
def test_pipeline_per_domain_masks_target_labels():
262+
# This test checks that with PerDomain selector, target labels are masked.
263+
X = np.array([[1], [2], [3], [4], [5], [6]])
264+
# assume domain 1 is source, domain 2 is source, domain -1 is target
265+
y = np.array([1, 1, 2, 2, 1, 1])
266+
sample_domain = np.array([1, 1, 2, 2, -1, -1])
267+
268+
mock_estimator = MockEstimator()
269+
# Use PerDomain selector
270+
pipe = make_da_pipeline(PerDomain(mock_estimator))
271+
pipe.fit(X, y, sample_domain=sample_domain)
272+
273+
# In PerDomain, there are multiple fitted estimators, one per domain
274+
fitted_estimators = pipe.named_steps["perdomain_mockestimator"].estimators_
275+
276+
# Estimator for domain 1 (source)
277+
estimator_domain_1 = fitted_estimators[1]
278+
assert_array_equal(estimator_domain_1.y_fit, np.array([1, 1]))
279+
assert_array_equal(estimator_domain_1.sample_domain_fit, np.array([1, 1]))
280+
281+
# Estimator for domain 2 (source)
282+
estimator_domain_2 = fitted_estimators[2]
283+
assert_array_equal(estimator_domain_2.y_fit, np.array([2, 2]))
284+
assert_array_equal(estimator_domain_2.sample_domain_fit, np.array([2, 2]))
285+
286+
# Estimator for domain -1 (target)
287+
estimator_domain_target = fitted_estimators[-1]
288+
expected_y_target = np.array(
289+
[
290+
_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
291+
_DEFAULT_MASKED_TARGET_CLASSIFICATION_LABEL,
292+
]
293+
)
294+
assert_array_equal(estimator_domain_target.y_fit, expected_y_target)
295+
assert_array_equal(estimator_domain_target.sample_domain_fit, np.array([-1, -1]))
296+
297+
298+
def test_pipeline_no_masking_when_disabled():
299+
# This test checks that when `mask_target_labels=False`, labels are not masked.
300+
X = np.array([[1], [2], [3], [4]])
301+
y = np.array([1, 1, 2, 2]) # y_target is [2, 2]
302+
sample_domain = np.array([1, 1, -1, -1])
303+
304+
mock_estimator = MockEstimator()
305+
pipe = make_da_pipeline(mock_estimator, mask_target_labels=False)
306+
pipe.fit(X, y, sample_domain=sample_domain)
307+
308+
fitted_estimator = pipe.named_steps["mockestimator"].base_estimator_
309+
# y should not be masked
310+
assert_array_equal(fitted_estimator.y_fit, y)
311+
assert_array_equal(fitted_estimator.sample_domain_fit, sample_domain)
312+
313+
187314
@pytest.mark.parametrize("_fit_transform", [(True,), (False,)])
188315
def test_allow_nd_x(_fit_transform):
189316
class CutInputDim(BaseEstimator):
@@ -226,12 +353,23 @@ def test_adaptation_output_propagate_labels(da_reg_dataset):
226353
output = {}
227354

228355
class FakeAdapter(BaseAdapter):
356+
def __init__(self):
357+
super().__init__()
358+
self.predicts_target_labels = True
359+
229360
def fit_transform(self, X, y=None, sample_domain=None):
230361
self.fitted_ = True
231362
if y is not None:
232-
assert not np.any(np.isnan(y)), "Expect unmasked labels"
233-
y[::2] = np.nan
234-
return X, y, dict()
363+
# checks that there is no nan in source label
364+
assert not np.any(
365+
np.isnan(y[sample_domain >= 0])
366+
), "Expect unmasked labels"
367+
# Mimic JCPOTLabelProp behavior
368+
yout = np.ones_like(y) * _DEFAULT_MASKED_TARGET_REGRESSION_LABEL
369+
yout[sample_domain < 0] = np.random.rand(
370+
yout[sample_domain < 0].shape[0]
371+
)
372+
return X, yout, dict()
235373

236374
class FakeEstimator(BaseEstimator):
237375
def fit(self, X, y=None, **params):
@@ -252,5 +390,5 @@ def predict(self, X):
252390
clf.fit(X, y, sample_domain=sample_domain)
253391
clf.predict(X_target, sample_domain=target_domain)
254392

255-
# output should contain only half of targets
256-
assert output["fit_n_samples"] == X.shape[0] // 2
393+
# output should contain as many samples as target
394+
assert output["fit_n_samples"] == X_target.shape[0]

0 commit comments

Comments
 (0)