Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,5 @@ docs/source/gen_modules/
.vscode

datasets

docs/source/sg_execution_times.rst
examples/skada_logo*
6 changes: 6 additions & 0 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,12 @@ Datasets :py:mod:`skada.datasets`
:no-members:
:no-inherited-members:

.. autosummary::
:toctree: gen_modules/
:template: class.rst

DomainAwareDataset

.. autosummary::
:toctree: gen_modules/
:template: function.rst
Expand Down
137 changes: 134 additions & 3 deletions skada/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,26 @@ def get_data_home(data_home: Union[str, os.PathLike, None]) -> str:


class DomainAwareDataset:
"""
Container carrying all dataset domains.

This class allows to store and manipulate datasets from multiple domains,
keeping track of the domain information for each sample.

Parameters
----------
domains : list of tuple or dict of tuple or None, optional
List or dictionary of domains to add at initialization.
Each domain can be a tuple (X, y) or (X, y, name).

Attributes
----------
domains_ : list
List of domains added, each as a tuple (X, y) or (X,).
domain_names_ : dict
Dictionary mapping each domain name to its internal identifier.
"""

def __init__(
self,
# xxx(okachaiev): not sure if dictionary is a good format :thinking:
Expand All @@ -82,6 +102,23 @@ def __init__(
def add_domain(
self, X, y=None, domain_name: Optional[str] = None
) -> "DomainAwareDataset":
"""
Add a new domain to the dataset.

Parameters
----------
X : np.ndarray
Feature matrix for the domain.
y : np.ndarray or None, optional
Labels for the domain. If None, labels are not provided.
domain_name : str, optional
Name of the domain. If None, a unique name is autogenerated.

Returns
-------
self : DomainAwareDataset
The updated dataset.
"""
if domain_name is not None:
# check the name is unique
# xxx(okachaiev): ValueError would be more appropriate
Expand All @@ -96,6 +133,21 @@ def add_domain(
def merge(
self, dataset: "DomainAwareDataset", names_mapping: Optional[Mapping] = None
) -> "DomainAwareDataset":
"""
Merge another DomainAwareDataset into this one.

Parameters
----------
dataset : DomainAwareDataset
The dataset to merge.
names_mapping : mapping, optional
Mapping from old domain names to new domain names.

Returns
-------
self : DomainAwareDataset
The updated dataset.
"""
for domain_name in dataset.domain_names_:
# xxx(okachaiev): this needs to be more flexible
# as it should be possible to pass only X with y=None
Expand All @@ -107,12 +159,40 @@ def merge(
return self

def get_domain(self, domain_name: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Retrieve the data and labels for a given domain.

Parameters
----------
domain_name : str
Name of the domain to retrieve.

Returns
-------
domain : tuple
Tuple containing (X, y) or (X,) for the specified domain.
"""
domain_id = self.domain_names_[domain_name]
return self.domains_[domain_id - 1]

def select_domain(
self, sample_domain: np.ndarray, domains: Union[str, Iterable[str]]
) -> np.ndarray:
"""
Select samples belonging to one or more domains.

Parameters
----------
sample_domain : np.ndarray
Array of domain labels for each sample.
domains : str or iterable of str
Domain name(s) to select.

Returns
-------
mask : np.ndarray
Boolean mask indicating selected samples.
"""
return select_domain(self.domain_names_, sample_domain, domains)

# xxx(okachaiev): i guess, if we are using names to pack domains into array,
Expand Down Expand Up @@ -238,10 +318,33 @@ def pack_train(
return_X_y: bool = True,
mask: Union[None, int, float] = None,
) -> PackedDatasetType:
"""Same as `pack`.
"""
Aggregate source and target domains for training.

This method is equivalent to :meth:`pack` with ``train=True``.
It masks the labels for target domains (with -1 or a custom mask value)
so that they are not available during training, as required for
domain adaptation scenarios.

Parameters
----------
as_sources : list of str
List of domain names to be used as sources.
as_targets : list of str
List of domain names to be used as targets.
return_X_y : bool, default=True
If True, returns a tuple (X, y, sample_domain). Otherwise,
returns a :class:`sklearn.utils.Bunch` object.
mask : int or float, optional
Value to mask labels at training time. If None, uses -1 for integers
and np.nan for floats.

Masks labels for target domains with -1 so they are not available
at training time.
Returns
-------
data : :class:`sklearn.utils.Bunch`
Dictionary-like object with attributes X, y, sample_domain, domain_names.
(X, y, sample_domain) : tuple if `return_X_y=True`
Tuple of (data, target, sample_domain).
"""
return self.pack(
as_sources=as_sources,
Expand All @@ -256,6 +359,34 @@ def pack_test(
as_targets: List[str],
return_X_y: bool = True,
) -> PackedDatasetType:
"""
Aggregate source and target domains for training.

This method is equivalent to :meth:`pack` with ``train=True``.
It masks the labels for target domains (with -1 or a custom mask value)
so that they are not available during training, as required for
domain adaptation scenarios.

Parameters
----------
as_sources : list of str
List of domain names to be used as sources.
as_targets : list of str
List of domain names to be used as targets.
return_X_y : bool, default=True
If True, returns a tuple (X, y, sample_domain). Otherwise,
returns a :class:`sklearn.utils.Bunch` object.
mask : int or float, optional
Value to mask labels at training time. If None, uses -1 for integers
and np.nan for floats.

Returns
-------
data : :class:`sklearn.utils.Bunch`
Dictionary-like object with attributes X, y, sample_domain, domain_names.
(X, y, sample_domain) : tuple if `return_X_y=True`
Tuple of (data, target, sample_domain).
"""
return self.pack(
as_sources=[],
as_targets=as_targets,
Expand Down