Merged

Commits (37)
a480e48
empty as_sources for train=False in pack
MellotApolline Jun 24, 2025
6f0eba5
deprecate train parameter in pack
MellotApolline Jun 24, 2025
9266481
IMP: add deprecated class into utils and wrap pack_train and pack_test
tom-yneuro Jun 24, 2025
3742ab0
update docstring
MellotApolline Jun 24, 2025
77c772e
Merge branch 'main' into pack
MellotApolline Jun 24, 2025
a192fab
FIX: refactor pack_train and pack_test to pack and mask_target_labels…
tom-yneuro Jun 24, 2025
eca96c3
Merge branch 'pack' of https://github.com/vloison/skada into pack
tom-yneuro Jun 24, 2025
4966ffc
FIX: use mask_target_labels instead of train arguement in pack in plo…
tom-yneuro Jun 24, 2025
aa71480
make_dataset_from_moons_distribution
MellotApolline Jun 24, 2025
861e592
FIX: change test after change of default value of train/mask_target_l…
tom-yneuro Jun 24, 2025
bf8e8d0
Merge branch 'pack' of https://github.com/vloison/skada into pack
tom-yneuro Jun 24, 2025
86f5352
Merge branch 'main' into pack
MellotApolline Jun 24, 2025
4a30f91
fix
MellotApolline Jun 24, 2025
09838f7
fix test deep
MellotApolline Jun 24, 2025
2c27b14
FIX: add missing mask_target_labels argument to pack function
tom-yneuro Jun 24, 2025
d1fce7e
Merge branch 'pack' of https://github.com/vloison/skada into pack
tom-yneuro Jun 24, 2025
069fdff
Merge branch 'main' into pack
MellotApolline Jun 24, 2025
c75bc57
pre-commit
MellotApolline Jun 24, 2025
5cc1a9d
Merge branch 'pack' of github.com:vloison/skada into pack
MellotApolline Jun 24, 2025
dfdacf8
FIX: inside DomainAwareDataset use mask_target_labels in pack instead…
tom-yneuro Jun 24, 2025
83cc0dc
Merge branch 'main' into pack
antoinecollas Jun 24, 2025
9cb5430
Merge branch 'main' into pack
antoinecollas Jun 25, 2025
73fa54d
Merge branch 'main' into pack
tgnassou Jun 25, 2025
7b78159
FIX: remove _is_deprecated because already existing in sklearn
tom-yneuro Jun 25, 2025
fdb5cac
FIX: remove unused import
tom-yneuro Jun 25, 2025
dec2cd7
Improve docs for domain aware dataset
MellotApolline Jun 25, 2025
afd556e
FIX: use sklearn deprecated
tom-yneuro Jun 25, 2025
c9e5bb2
Merge branch 'main' into pack
antoinecollas Jun 25, 2025
3b712a9
update doc with antoine`s comments
MellotApolline Jun 25, 2025
d8c9241
remove empty line
MellotApolline Jun 25, 2025
1ccc1e2
make as_sources and as_targets mandatory args in pack
MellotApolline Jun 25, 2025
d5e82f3
Merge branch 'main' into pack
antoinecollas Jun 25, 2025
123489b
remove last pack_train and pack_test
MellotApolline Jun 25, 2025
1deb412
Merge branch 'pack' of github.com:vloison/skada into pack
MellotApolline Jun 25, 2025
185a0a6
warning deprecated
MellotApolline Jun 25, 2025
d26741d
Merge branch 'main' into pack
antoinecollas Jun 25, 2025
1cf5fe9
empty commit
antoinecollas Jun 25, 2025
58 changes: 33 additions & 25 deletions docs/source/quickstart.md
@@ -48,50 +48,39 @@ model.score(X_test, y_test, sample_domain=sample_domain_test)

## Dataset

The new skada.datasets.DomainAwareDataset class acts as a container for all domains. Its API is built around the add_domain and pack methods, see example below.:
The new skada.datasets.DomainAwareDataset class acts as a dataset container for all domains. Its API is built around two main methods: `add_domain` and `pack`.

The class is initially empty. Data and labels of a new domain can be added to the dataset with `add_domain`:
```python
datasets = DomainAwareDataset()
datasets.add_domain(X_subj1, y_subj1, domain_name="subj_1")
datasets.add_domain(X_subj3, y_subj3, domain_name="subj_3")
datasets.add_domain(X_subj12, y_subj12, domain_name="subj_12")
X, y, sample_domain = datasets.pack(as_sources=['subj_12', 'subj_1'], as_targets=['subj_3'])
```
A domain label (int) is assigned to each domain in the order they were provided. For example, here the domain "subj_1" will have the domain label 1 and "subj_12" will have the domain label 3.
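To make the ordering concrete, here is a minimal sketch of this bookkeeping in plain Python (the `TinyDomainStore` class is purely illustrative and is not skada's actual implementation):

```python
# Sketch of the domain-label bookkeeping described above.
# Labels are assigned in insertion order, starting at 1.

class TinyDomainStore:
    def __init__(self):
        self.domains = {}        # name -> (X, y)
        self.domain_labels = {}  # name -> positive int label

    def add_domain(self, X, y, domain_name):
        # The label depends only on the order in which domains are added.
        self.domain_labels[domain_name] = len(self.domain_labels) + 1
        self.domains[domain_name] = (X, y)

store = TinyDomainStore()
store.add_domain([[0.0]], [0], domain_name="subj_1")
store.add_domain([[1.0]], [1], domain_name="subj_3")
store.add_domain([[2.0]], [0], domain_name="subj_12")
print(store.domain_labels)  # {'subj_1': 1, 'subj_3': 2, 'subj_12': 3}
```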

should be also compatible for fetchers, like
Once all the desired domains have been included in the dataset, the `pack` method is used to aggregate the selected domains and create the associated `sample_domain`, depending on whether each domain is designated as a `source` or as a `target`:

```python
office31 = fetch_office31_surf_all()
X, y, sample_domain = office31.pack(as_sources=['amazon', 'dslr'], as_targets=['webcam'])
```

Method `pack` also accepts optional `return_X_y` argument (defaults to `True`). When this argument is set to `False`, the method returns `Bunch` object with the following set of keys:

```python
>>> office31 = fetch_all_office31_surf()
>>> data = office31.pack(as_sources=['amazon', 'dslr'], as_targets=['webcam'], return_X_y=False)
>>> data.keys()
dict_keys(['X', 'y', 'sample_domain', 'domain_names'])
X, y, sample_domain = datasets.pack(as_sources=['subj_12', 'subj_1'], as_targets=['subj_3'], mask_target_labels=False)
```
The `sample_domain` values are derived from the domain labels, following the convention that a source keeps its positive label (1, 2, ...) while a target gets the negated label (-1, -2, ...). In the previous example, the sample domain values for 'subj_12' and 'subj_3' will be 3 and -2, respectively.

This is mostly to cover use cases where you need access to `'domain_names'` labels. Domain labels are assigned following the convention that source gets non-negative integer (1,2,..) and target always gets negative (-1,-2,...). Labels are assigned in the order that datasets are provided, should make it easier to "reconstruct" labels even working with tuple output (without access to `Bunch` object). Absolute value of the label is always static for a given domain name, for example if "amazon" domain gets index 2 it will be included in `sample_domain` as 2 when included as source and -2 when included as target. Such convention is required to avoid fluctuations of domain labels (otherwise multi-estimator API won't be possible).

Considering different scenarios, the dataset provides the following helpers:

* `pack_train` masks labels for domains designated for being used as targets
* `pack_test` packs requested targets
`mask_target_labels` is a mandatory parameter of the `pack` method.
With `mask_target_labels` set to `True`, the labels `y` of the target domains are masked (set to -1), which enables unsupervised domain adaptation.
With `mask_target_labels` set to `False`, labels are returned for all domains, which is useful for supervised evaluation or analysis.
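A minimal sketch of the masking behavior, using a simplified `pack_labels` helper (illustrative only, not skada's implementation):

```python
# Sketch of what mask_target_labels does during packing:
# source labels pass through; target labels are replaced by the mask value.

def pack_labels(source_ys, target_ys, mask_target_labels, mask=-1):
    packed = []
    for y in source_ys:
        packed.extend(y)                    # sources keep their labels
    for y in target_ys:
        if mask_target_labels:
            packed.extend([mask] * len(y))  # hide target labels for training
        else:
            packed.extend(y)                # keep them for evaluation
    return packed

y_train = pack_labels([[0, 1]], [[1, 0]], mask_target_labels=True)
y_eval = pack_labels([[0, 1]], [[1, 0]], mask_target_labels=False)
print(y_train)  # [0, 1, -1, -1]
print(y_eval)   # [0, 1, 1, 0]
```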

Working with an estimator through the new API would look like the following:

```python
office31 = fetch_office31_surf_all()
X_train, y_train, sample_domain = office31.pack_train(as_sources=['amazon', 'dslr'], as_targets=['webcam'])
X_train, y_train, sample_domain = office31.pack(as_sources=['amazon', 'dslr'], as_targets=['webcam'], mask_target_labels=True)

estimator = make_da_pipeline(CORALAdapter(), LogisticRegression())
estimator.fit(X_train, y_train, sample_domain=sample_domain)

# predict and score on target domain
X_test, y_test, sample_domain = office31.pack_test(as_targets=['webcam'])
X_test, y_test, sample_domain = office31.pack(as_targets=['webcam'], mask_target_labels=False)
webcam_idx = office31.select_domain(sample_domain, 'webcam')
y_target = estimator.predict(X_test[webcam_idx], sample_domain=sample_domain[webcam_idx])
score = estimator.score(X_test[webcam_idx], y=y_test[webcam_idx], sample_domain=sample_domain[webcam_idx])
@@ -107,6 +96,25 @@ from skada.datasets import select_domain
source_idx = select_domain(office31.domain_names, sample_domain, ('amazon', 'dslr'))
```


The `pack` method is also compatible with fetchers, like:

```python
office31 = fetch_office31_surf_all()
X, y, sample_domain = office31.pack(as_sources=['amazon', 'dslr'], as_targets=['webcam'], mask_target_labels=False)
```

`pack` has an optional `return_X_y` argument (defaults to `True`). When this argument is set to `False`, the method returns a `Bunch` object with the following keys:

```python
>>> office31 = fetch_all_office31_surf()
>>> data = office31.pack(as_sources=['amazon', 'dslr'], as_targets=['webcam'], return_X_y=False, mask_target_labels=False)
>>> data.keys()
dict_keys(['X', 'y', 'sample_domain', 'domain_names'])
```

This is mostly to cover use cases where you need access to the `'domain_names'` labels. Since labels are assigned in the order the datasets are provided, it should be easy to "reconstruct" them even when working with the tuple output (without access to the `Bunch` object). The absolute value of the label is always fixed for a given domain name: for example, if the "amazon" domain gets index 2, it appears in `sample_domain` as 2 when used as a source and -2 when used as a target. This convention is required to avoid fluctuations of domain labels (otherwise the multi-estimator API would not be possible).
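Under that sign convention, domain names can be recovered from `sample_domain` values with a few lines of plain Python (the `domain_names` mapping below is illustrative; in practice it comes from the packed `Bunch`):

```python
# Sketch: recovering domain name and role from a sample_domain value,
# using the convention that positive = source and negative = target.

domain_names = {"amazon": 1, "dslr": 2, "webcam": 3}
label_to_name = {v: k for k, v in domain_names.items()}

def describe(sample_domain_value):
    # The absolute value identifies the domain; the sign gives its role.
    name = label_to_name[abs(sample_domain_value)]
    role = "source" if sample_domain_value > 0 else "target"
    return name, role

print(describe(2))   # ('dslr', 'source')
print(describe(-3))  # ('webcam', 'target')
```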

## Adapters and Estimators

### Adapter
@@ -193,13 +201,13 @@ See API usage examples in `skada/tests/test_scorer.py`.
The `SupervisedScorer` is a unique scorer that necessitates special consideration. Since it requires access to target labels, which are masked during the dataset packing process for training, this scorer mandates an additional key to be passed within the `params`. The usage is as follows:

```python
X, y, sample_domain = da_dataset.pack_train(as_sources=['s'], as_targets=['t'])
X, y, sample_domain = da_dataset.pack(as_sources=['s'], as_targets=['t'], mask_target_labels=True)
estimator = make_da_pipeline(
DensityReweightAdapter(),
LogisticRegression().set_score_request(sample_weight=True),
)
cv = ShuffleSplit(n_splits=3, test_size=0.3, random_state=0)
_, target_labels, _ = da_dataset.pack(as_sources=['s'], as_targets=['t'], train=False)
_, target_labels, _ = da_dataset.pack(as_sources=['s'], as_targets=['t'], mask_target_labels=False)
scoring = SupervisedScorer()
scores = cross_validate(
estimator,
@@ -220,7 +228,7 @@ The library includes a range of splitters designed specifically for domain adapt
`skada.model_selection.SourceTargetShuffleSplit`: This splitter functions similarly to the standard `ShuffleSplit` but takes into account the distinct separation between source and target domains. It follows the standard API structure:

```python
X, y, sample_domain = da_dataset.pack_train(as_sources=['s', 's2'], as_targets=['t', 't2'])
X, y, sample_domain = da_dataset.pack(as_sources=['s', 's2'], as_targets=['t', 't2'], mask_target_labels=True)
pipe = make_da_pipeline(
SubspaceAlignmentAdapter(n_components=2),
LogisticRegression(),
8 changes: 6 additions & 2 deletions examples/deep/plot_adversarial.py
@@ -23,8 +23,12 @@
# ----------------------------------------------------------------------------

dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])
X, y, sample_domain = dataset.pack(
as_sources=["mnist"], as_targets=["usps"], mask_target_labels=True
)
X_test, y_test, sample_domain_test = dataset.pack(
as_targets=["usps"], mask_target_labels=False
)

# %%
# Train a classic model
8 changes: 6 additions & 2 deletions examples/deep/plot_divergence.py
@@ -28,8 +28,12 @@
# ----------------------------------------------------------------------------

dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])
X, y, sample_domain = dataset.pack(
as_sources=["mnist"], as_targets=["usps"], mask_target_labels=True
)
X_test, y_test, sample_domain_test = dataset.pack(
as_targets=["usps"], mask_target_labels=False
)

# %%
# Train a classic model
8 changes: 6 additions & 2 deletions examples/deep/plot_optimal_transport.py
@@ -23,8 +23,12 @@
# ----------------------------------------------------------------------------

dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])
X, y, sample_domain = dataset.pack(
as_sources=["mnist"], as_targets=["usps"], mask_target_labels=True
)
X_test, y_test, sample_domain_test = dataset.pack(
as_targets=["usps"], mask_target_labels=False
)

# %%
# Train a classic model
8 changes: 6 additions & 2 deletions examples/deep/plot_training_method.py
@@ -28,8 +28,12 @@
# ----------------------------------------------------------------------------

dataset = load_mnist_usps(n_classes=2, n_samples=0.5, return_dataset=True)
X, y, sample_domain = dataset.pack_train(as_sources=["mnist"], as_targets=["usps"])
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["usps"])
X, y, sample_domain = dataset.pack(
as_sources=["mnist"], as_targets=["usps"], mask_target_labels=True
)
X_test, y_test, sample_domain_test = dataset.pack(
as_targets=["usps"], mask_target_labels=False
)

# %%
# Training parameters
7 changes: 3 additions & 4 deletions examples/methods/plot_subspace.py
@@ -70,14 +70,13 @@
return_dataset=True,
)

X_train, y_train, sample_domain_train = dataset.pack_train(
as_sources=["s"],
as_targets=["t"],
X_train, y_train, sample_domain_train = dataset.pack(
as_sources=["s"], as_targets=["t"], mask_target_labels=True
)
X, y, sample_domain = dataset.pack(
as_sources=["s"],
as_targets=["t"],
train=False,
mask_target_labels=False,
)
Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)

4 changes: 3 additions & 1 deletion examples/plot_method_comparison.py
@@ -135,7 +135,9 @@
# iterate over datasets
for ds_cnt, ds in enumerate(datasets):
# preprocess dataset, split into training and test part
X, y, sample_domain = ds.pack_train(as_sources=["s"], as_targets=["t"])
X, y, sample_domain = ds.pack(
as_sources=["s"], as_targets=["t"], mask_target_labels=True
)
Xs, ys = ds.get_domain("s")
Xt, yt = ds.get_domain("t")

8 changes: 6 additions & 2 deletions examples/validation/plot_cross_val_score_for_da.py
@@ -33,7 +33,9 @@
base_estimator = SVC()
estimator = EntropicOTMapping(base_estimator=base_estimator, reg_e=0.5, tol=1e-3)

X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"])
X, y, sample_domain = dataset.pack(
as_sources=["s"], as_targets=["t"], mask_target_labels=True
)
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain
)
@@ -48,7 +50,7 @@
# by the DA pipeline thanks to :code:`sample_domain`. The :code:`target_labels`
# are only used by the :code:`SupervisedScorer`.

_, target_labels, _ = dataset.pack(as_sources=["s"], as_targets=["t"], train=False)
_, target_labels, _ = dataset.pack(
as_sources=["s"], as_targets=["t"], mask_target_labels=False
)
scores_sup = cross_val_score(
estimator,
X,
11 changes: 7 additions & 4 deletions examples/validation/plot_cross_validation_for_da.py
@@ -64,9 +64,11 @@
)
dataset.merge(dataset2, names_mapping={"s": "s2", "t": "t2"})

X, y, sample_domain = dataset.pack_train(as_sources=["s", "s2"], as_targets=["t", "t2"])
X, y, sample_domain = dataset.pack(
as_sources=["s", "s2"], as_targets=["t", "t2"], mask_target_labels=True
)
_, target_labels, _ = dataset.pack(
as_sources=["s", "s2"], as_targets=["t", "t2"], train=False
as_sources=["s", "s2"], as_targets=["t", "t2"], mask_target_labels=False
)

# Sort by sample_domain first then by target_labels
@@ -277,11 +279,12 @@ def plot_st_shuffle_indices(cv, X, y, target_labels, sample_domain, ax, n_splits
# :class:`~skada.model_selection.SourceTargetShuffleSplit`.
# The left plot shows the indices of the training and
# testing sets for each split and with the dataset packed with
# :func:`~skada.datasets._base.DomainAwareDataset.pack_train`
# :func:`~skada.datasets._base.DomainAwareDataset.pack`
# (the target domains' labels are masked (=-1)).
# While the right plot shows the indices of the training and
# testing sets for each split and with the dataset packed with
# :func:`~skada.datasets._base.DomainAwareDataset.pack_test`.
# :func:`~skada.datasets._base.DomainAwareDataset.pack` with the
# argument ``mask_target_labels=False``.


cvs = [SourceTargetShuffleSplit]
6 changes: 4 additions & 2 deletions examples/validation/plot_gridsearch_for_da.py
@@ -33,8 +33,10 @@
random_state=RANDOM_SEED,
return_dataset=True,
)
X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"])
X_target, y_target, _ = dataset.pack_test(as_targets=["t"])
X, y, sample_domain = dataset.pack(
as_sources=["s"], as_targets=["t"], mask_target_labels=True
)
X_target, y_target, _ = dataset.pack(as_targets=["t"], mask_target_labels=False)

estimator = EntropicOTMapping(base_estimator=SVC(probability=True))
cv = ShuffleSplit(n_splits=5, test_size=0.3, random_state=RANDOM_SEED)
33 changes: 26 additions & 7 deletions skada/datasets/_base.py
@@ -4,11 +4,12 @@
# License: BSD 3-Clause

import os
import warnings
from functools import reduce
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, Union

import numpy as np
from sklearn.utils import Bunch
from sklearn.utils import Bunch, deprecated

_DEFAULT_HOME_FOLDER_KEY = "SKADA_DATA_FOLDER"
_DEFAULT_HOME_FOLDER = "~/skada_datasets"
@@ -202,7 +203,8 @@
as_sources: List[str] = None,
as_targets: List[str] = None,
return_X_y: bool = True,
train: bool = False,
mask_target_labels: Optional[bool] = None,
train: Optional[bool] = None,
mask: Union[None, int, float] = None,
) -> PackedDatasetType:
"""Aggregates datasets from all domains into a unified domain-aware
@@ -219,9 +221,12 @@
When set to True, returns a tuple (X, y, sample_domain). Otherwise
returns a :class:`~sklearn.utils.Bunch` object with the structure
described below.
train: bool, default=False
mask_target_labels : bool, default=None
This parameter should be set to True for training and False for testing.
When set to True, masks labels for target domains with -1
(or the value given by `mask`), so they are not available at train time.
train: Optional[bool], default=None
[DEPRECATED] Use `mask_target_labels` instead.
mask: int | float (optional), default=None
Value to mask labels at training time.

@@ -246,6 +251,18 @@
"""
Xs, ys, sample_domains = [], [], []
domain_labels = {}
if train is not None:
    warnings.warn(
        "The `train` parameter is deprecated and will be removed in "
        "future versions. Use `mask_target_labels` instead.",
        DeprecationWarning,
    )
    mask_target_labels = train
if mask_target_labels is None:
    raise ValueError(
        "The `mask_target_labels` parameter must be set to True for "
        "training or False for testing."
    )
if as_sources is None:
as_sources = []
if as_targets is None:
@@ -279,7 +296,7 @@
X, y = target
else:
raise ValueError("Invalid definition for domain data")
if train:
if mask_target_labels:
if mask is not None:
y = np.array([mask] * X.shape[0], dtype=dtype)
elif y.dtype in (np.int32, np.int64):
@@ -311,6 +328,7 @@
)
)

@deprecated()
def pack_train(
self,
as_sources: List[str],
@@ -350,10 +368,11 @@
as_sources=as_sources,
as_targets=as_targets,
return_X_y=return_X_y,
train=True,
mask_target_labels=True,
mask=mask,
)

@deprecated()
def pack_test(
self,
as_targets: List[str],
@@ -384,7 +403,7 @@
as_sources=[],
as_targets=as_targets,
return_X_y=return_X_y,
train=False,
mask_target_labels=False,
)

def pack_lodo(self, return_X_y: bool = True) -> PackedDatasetType:
@@ -427,7 +446,7 @@
as_sources=list(self.domain_names_.keys()),
as_targets=list(self.domain_names_.keys()),
return_X_y=return_X_y,
train=True,
mask_target_labels=True,
)

def __str__(self) -> str:
5 changes: 4 additions & 1 deletion skada/datasets/_mnist_usps.py
@@ -118,5 +118,8 @@ def load_mnist_usps(
return dataset
else:
return dataset.pack(
as_sources=["mnist"], as_targets=["usps"], return_X_y=return_X_y
as_sources=["mnist"],
as_targets=["usps"],
return_X_y=return_X_y,
mask_target_labels=False,
)