
Commit fa95e7b

mbarneche and tgnassou authored
[MRG] Add DeepDADataset (#302)
* DeepDADataset class without containers creation of DeepDADataset class and main functionalities this version is without containers.
* added dictionary handling and getitem method
* finished main DDAD logic, added methods
* fixed merging and empty initialisation
* added methods to modify and manipulate the dataset
* relocated DDAD to deep\base.py
* added docstrings to most methods
* added missing docstrings to DDAD methods
* added methods and most of the tests
* added tests and removed unecessary methods
* DDAD tests completed
* improved DDAD, implemented in prepare_input method
* update DDAD
* adapted to tests
* updated DDAD string representation
* updated after comments
* updated DDAD after discussion, connected to skada
* added simple DDAD guide
* updated for the last tests
* changed according to comments
* Update base.py
* Updated for a test
* updated guide for DDAD because of an error
* added pandas dataframe, updated doc and example
* Update plot_deepdadataset.py
* changed requirements

---------

Co-authored-by: Théo Gnassounou <66993815+tgnassou@users.noreply.github.com>
1 parent cf09ebd commit fa95e7b

6 files changed: +906 -88 lines changed

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+"""
+Deep Domain Aware Datasets
+==========================
+
+This example illustrates some uses of DeepDADatasets.
+"""
+# Author: Maxence Barneche
+#
+# License: BSD 3-Clause
+# sphinx_gallery_thumbnail_number = 4
+# %%
+
+import numpy as np
+import pandas as pd
+import torch
+
+from skada.datasets import make_shifted_datasets
+from skada.deep.base import DeepDADataset
+
+# %%
+# Creation
+# --------
+# Deep domain aware datasets are a unified representation of data for deep
+# methods in skada.
+#
+# In those datasets, a data sample has four (optionally, five) attributes:
+# * the data point :code:`X`
+# * the label :code:`y`
+# * the domain :code:`sample_domain`
+# * optionally, the weight :code:`sample_weight`
+# * the sample index :code:`sample_idx` (automatically generated), which is
+#   the index of the sample in the dataset, relative to its domain.
+#
+# Note that the data is not shuffled, so the order of the samples is preserved.
+#
+# .. warning::
+#     In a dataset, either all data samples have a weight, or none of them do.
+#     A sample may, however, have no associated label or domain.
+#     In that case, it is assigned label :code:`-1` and domain :code:`0`.
+#
+# DeepDADatasets can be created from numpy arrays, torch tensors, lists,
+# tuples, or a dictionary of any of the former.
+#
+# If a dictionary is provided, it must contain the key :code:`X` and may contain
+# :code:`y`, :code:`sample_domain` and :code:`sample_weight`.
+#
+# If both a dictionary and positional arguments are provided, the dictionary
+# entries take precedence over the positional ones.
+
+# practice dataset as numpy arrays
+raw_data = make_shifted_datasets(20, 20, random_state=42)
+X, y, sample_domain = raw_data
+# these are not meaningful weights, but they act as weights throughout the guide.
+weights = np.ones_like(y)
+dict_raw_data = {"X": X, "sample_domain": sample_domain, "y": y}
+weighted_dict_raw_data = {
+    "X": X,
+    "sample_domain": sample_domain,
+    "y": y,
+    "sample_weight": weights,
+}
+
+dataset = DeepDADataset(X, y, sample_domain)
+dataset_from_dict = DeepDADataset(dict_raw_data)
+# weights can be added to the dataset either at creation or later
+dataset_with_weights = DeepDADataset(X, y, sample_domain, weights)
+dataset_with_weights_from_dict = DeepDADataset(weighted_dict_raw_data)
+
+# these methods change the dataset in place and return the dataset itself
+dataset = dataset.add_weights(weights)
+dataset = dataset.remove_weights()
+
+# %%
+# It is also possible to create a DeepDADataset from lists, tuples, tensors,
+# pandas dataframes or any combination of those.
+#
+# .. note::
+#     Just like for the dictionary, if a pandas dataframe is provided it must
+#     contain the column :code:`X` and may contain :code:`y`, :code:`sample_domain`
+#     and :code:`sample_weight`.
+#     Also, the data in the dataframe will take precedence over the positional arguments.
+
+# from lists
+dataset_from_list = DeepDADataset(X.tolist(), y.tolist(), sample_domain.tolist())
+# from tuples
+dataset_from_tuple = DeepDADataset(
+    tuple(X.tolist()), tuple(y.tolist()), tuple(sample_domain.tolist())
+)
+
+# from torch tensors
+dataset_from_tensor = DeepDADataset(
+    torch.tensor(X), torch.tensor(y), torch.tensor(sample_domain)
+)
+
+# from a pandas dataframe with the same structure as the dictionary
+df = pd.DataFrame(
+    {
+        "X": list(X),
+        "y": y,
+        "sample_domain": sample_domain,
+        "sample_weight": weights,
+    }
+)
+dataset_from_df = DeepDADataset(df)
+
+# %%
+# It is also possible to merge two datasets, which will concatenate the data
+# samples, the labels and the domains.
+dataset2 = dataset.merge(dataset)
+
+# %%
+# Accessing data
+# --------------
+#
+# The data can be accessed with the same indexing methods as for a torch tensor.
+# Indexing returns a tuple: the first element is a dictionary with the keys :code:`X`,
+# :code:`sample_domain`, :code:`sample_idx`, and optionally :code:`sample_weight`;
+# the second element is the corresponding label :code:`y`.
+#
+# .. note::
+#     The data is stored in torch tensors, with dimension 1 for :code:`sample_domain`,
+#     :code:`y` and :code:`sample_weight`.
+#
+# It is also possible to access the data through the various selection methods,
+# all of which return DeepDADataset instances.
+
+# indexing methods return a tuple with the data as dict and the label
+first_sample = dataset[0]  # first sample
+first_five_samples = dataset[0:5]  # first five samples
+
+# selecting methods return a DeepDADataset with the selected samples
+domain_1_samples = dataset.select_domain(1)  # all samples from domain 1
+label_1_samples = dataset.select(
+    lambda label: label == 1, on="y"
+)  # all samples with label 1
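
The access contract described in the guide can be sketched without skada or torch. The toy class below is a hypothetical stand-in written for illustration, not skada's `DeepDADataset`: the name `ToyDomainAwareDataset`, the plain-list storage, the integer-only indexing and the way `select` recomputes `sample_idx` are all assumptions. It only mirrors what the guide states: missing labels default to `-1`, missing domains default to `0`, `sample_idx` counts samples within their own domain, and indexing returns a `(dict, label)` tuple.

```python
class ToyDomainAwareDataset:
    """Hypothetical stand-in for illustration only; NOT skada's DeepDADataset."""

    def __init__(self, X, y=None, sample_domain=None, sample_weight=None):
        n = len(X)
        self.X = list(X)
        # Defaults taken from the guide: label -1 and domain 0 when missing.
        self.y = [-1] * n if y is None else list(y)
        self.sample_domain = [0] * n if sample_domain is None else list(sample_domain)
        # Either every sample has a weight or none does.
        self.sample_weight = None if sample_weight is None else list(sample_weight)

        columns = [self.y, self.sample_domain]
        if self.sample_weight is not None:
            columns.append(self.sample_weight)
        if any(len(column) != n for column in columns):
            raise ValueError("all fields must have the same number of samples")

        # sample_idx is the position of a sample within its own domain,
        # counted in insertion order (the data is never shuffled).
        counts = {}
        self.sample_idx = []
        for domain in self.sample_domain:
            self.sample_idx.append(counts.get(domain, 0))
            counts[domain] = counts.get(domain, 0) + 1

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        # Integer indexing only (assumption); returns (features dict, label).
        features = {
            "X": self.X[i],
            "sample_domain": self.sample_domain[i],
            "sample_idx": self.sample_idx[i],
        }
        if self.sample_weight is not None:
            features["sample_weight"] = self.sample_weight[i]
        return features, self.y[i]

    def select(self, condition, on="y"):
        # Keep the samples whose `on` field satisfies `condition`.
        column = getattr(self, on)
        if column is None:
            raise ValueError(f"dataset has no {on!r} field")
        keep = [i for i, value in enumerate(column) if condition(value)]
        weights = None
        if self.sample_weight is not None:
            weights = [self.sample_weight[i] for i in keep]
        # The toy recomputes sample_idx for the subset (assumption).
        return ToyDomainAwareDataset(
            [self.X[i] for i in keep],
            [self.y[i] for i in keep],
            [self.sample_domain[i] for i in keep],
            weights,
        )


ds = ToyDomainAwareDataset(
    X=[[0.0], [1.0], [2.0], [3.0]],
    y=[0, 1, 1, 0],
    sample_domain=[1, -2, 1, -2],
)
features, label = ds[2]  # third sample overall, second sample of domain 1
unlabeled = ToyDomainAwareDataset(X=[[5.0]])  # no label, no domain
label_1 = ds.select(lambda value: value == 1, on="y")
print(features, label)
print(unlabeled[0])
print(len(label_1))
```

Keeping the label outside the dictionary matches the `(input, target)` pairs that torch-style training loops usually consume, which is presumably why the guide describes the return value this way.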

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@ numpy>=1.24
 scipy>=1.10
 scikit-learn>=1.5.0
 pot>=0.9.0
+pandas>=2.3.0

requirements_full.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 -r requirements.txt
 torch
 torchvision
-skorch
+skorch
