forked from scikit-adaptation/skada
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_cross_validation_for_da.py
More file actions
376 lines (322 loc) · 11.3 KB
/
plot_cross_validation_for_da.py
File metadata and controls
376 lines (322 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
"""
Visualizing cross-validation behavior in skada
==============================================
This example illustrates the behavior of domain adaptation (DA) cross-validation
objects such as :class:`~skada.model_selection.DomainShuffleSplit`.
""" # noqa
# %%
# Let's prepare the imports:
# Author: Yanis Lalou
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 1
# %% imports
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch
from skada.datasets import make_shifted_datasets
from skada.model_selection import (
DomainShuffleSplit,
LeaveOneDomainOut,
SourceTargetShuffleSplit,
StratifiedDomainShuffleSplit,
)
# Seed used for the synthetic dataset generation below.
RANDOM_SEED = 0

# Colormaps for, respectively: the class-label row, the sample_domain row and
# the train/test rows of every figure.
cmap_data, cmap_domain, cmap_cv = plt.cm.PRGn, plt.cm.RdBu, plt.cm.coolwarm

# Number of splits requested from the shuffle-based splitters.
n_splits = 4
# The dataset built below has 2 source and 2 target domains, so the
# leave-one-domain-out splitter can produce at most 4 splits.
n_splits_lodo = 4
# %%
# First we generate a dataset with 4 different domains.
# The domains are drawn from 4 different distributions: 2 source
# and 2 target distributions. The target distributions are shifted
# versions of the source distributions. Thus we will have a domain
# adaptation problem with 2 source domains and 2 target domains.
# Two datasets generated with identical settings and consecutive seeds.
dataset, dataset2 = (
    make_shifted_datasets(
        n_samples_source=3,
        n_samples_target=2,
        shift="conditional_shift",
        label="binary",
        noise=0.4,
        random_state=RANDOM_SEED + seed_offset,
        return_dataset=True,
    )
    for seed_offset in (0, 1)
)
# Rename the second dataset's domains so the merged dataset holds 4 domains.
dataset.merge(dataset2, names_mapping={"s": "s2", "t": "t2"})

# Packed data with the target labels masked ...
X, y, sample_domain = dataset.pack(
    as_sources=["s", "s2"], as_targets=["t", "t2"], mask_target_labels=True
)
# ... and the same packing with the target labels kept, for comparison.
_, target_labels, _ = dataset.pack(
    as_sources=["s", "s2"], as_targets=["t", "t2"], mask_target_labels=False
)

# Order samples by sample_domain, then by (unmasked) label within a domain.
indx_sort = np.lexsort((target_labels, sample_domain))
X, y, target_labels, sample_domain = (
    arr[indx_sort] for arr in (X, y, target_labels, sample_domain)
)

# For Lodo methods: same ordering, by domain then by label.
X_lodo, y_lodo, sample_domain_lodo = dataset.pack_lodo()
indx_sort = np.lexsort((y_lodo, sample_domain_lodo))
X_lodo, y_lodo, sample_domain_lodo = (
    arr[indx_sort] for arr in (X_lodo, y_lodo, sample_domain_lodo)
)
# %%
# We define functions to visualize the behavior of each
# cross-validation object. The number of splits is set to 4
# (also 4 for the lodo method, at most one split per domain). For each
# split, we visualize the indices selected for the training set (in blue)
# and the test set (in red).
# Code source: scikit-learn documentation
# Modified for documentation by Yanis Lalou
# License: BSD 3 clause
def plot_cv_indices(cv, X, y, sample_domain, ax, n_splits, lw=10):
    """Create a sample plot for indices of a cross-validation object.

    Draws one row per split yielded by ``cv.split``: each sample is coloured
    by whether it belongs to the training set (value 0) or the test set
    (value 1) of that split; samples in neither set are left as NaN. Two
    extra rows below the split rows show the class labels ``y`` and the
    ``sample_domain`` of every sample.

    Parameters
    ----------
    cv : object
        Cross-validation splitter, called as
        ``cv.split(X=X, y=y, sample_domain=sample_domain)`` and expected to
        yield ``(train_indices, test_indices)`` pairs.
    X : array-like
        Samples passed to ``cv.split``; ``len(X)`` sets how many samples
        are plotted.
    y : array-like
        Labels passed to ``cv.split`` and drawn in the "class" row.
    sample_domain : array-like
        Domain labels passed to ``cv.split`` and drawn in the
        "sample_domain" row.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    n_splits : int
        Number of split rows reserved on the y-axis; should be at least the
        number of splits ``cv`` yields.
    lw : float, default=10
        Line width of the ``"_"`` markers.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes passed in, after drawing.
    """
    n_samples = len(X)
    # Sample i is drawn at x = i + 0.5, i.e. centred in [i, i + 1].
    x_pos = np.arange(n_samples) + 0.5

    # Generate the training/testing visualizations for each CV split
    cv_args = {"X": X, "y": y, "sample_domain": sample_domain}
    for ii, (tr, tt) in enumerate(cv.split(**cv_args)):
        # 0 marks training samples, 1 test samples, NaN samples in neither set
        indices = np.full(n_samples, np.nan)
        indices[tt] = 1
        indices[tr] = 0
        ax.scatter(
            x_pos,
            [ii + 0.5] * n_samples,
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    # The class and sample_domain rows go right after the ``n_splits``
    # reserved rows, so they always line up with their tick labels -- even
    # when ``cv`` yields fewer splits than ``n_splits`` (or none at all).
    ax.scatter(
        x_pos,
        [n_splits + 0.5] * n_samples,
        c=y,
        marker="_",
        lw=lw,
        cmap=cmap_data,
        vmin=-1.2,
        vmax=0.2,
    )
    ax.scatter(
        x_pos,
        [n_splits + 1.5] * n_samples,
        c=sample_domain,
        marker="_",
        lw=lw,
        cmap=cmap_domain,
        vmin=-3.2,
        vmax=3.2,
    )

    # Formatting: ylim is given top-down so split 0 appears at the top.
    yticklabels = list(range(n_splits)) + ["class", "sample_domain"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, n_samples],
    )
    ax.set_title(f"{type(cv).__name__}", fontsize=15)
    return ax
def plot_lodo_indices(cv, X, y, sample_domain, ax, lw=10, n_splits=None):
    """Create a sample plot for indices of a leave-one-domain-out splitter.

    Draws one row per split yielded by ``cv.split``: each sample is coloured
    by whether it belongs to the training set (value 0) or the test set
    (value 1) of that split; samples in neither set are left as NaN. Two
    extra rows below the split rows show the class labels ``y`` and the
    ``sample_domain`` of every sample.

    Parameters
    ----------
    cv : object
        Cross-validation splitter, called as
        ``cv.split(X=X, y=y, sample_domain=sample_domain)`` and expected to
        yield ``(train_indices, test_indices)`` pairs.
    X : array-like
        Samples passed to ``cv.split``; ``len(X)`` sets how many samples
        are plotted.
    y : array-like
        Labels passed to ``cv.split`` and drawn in the "class" row.
    sample_domain : array-like
        Domain labels passed to ``cv.split`` and drawn in the
        "sample_domain" row.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    lw : float, default=10
        Line width of the ``"_"`` markers.
    n_splits : int or None, default=None
        Number of split rows reserved on the y-axis. When None, the number
        of splits actually yielded by ``cv`` is used, so the layout no
        longer depends on any module-level variable.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes passed in, after drawing.
    """
    n_samples = len(X)
    # Sample i is drawn at x = i + 0.5, i.e. centred in [i, i + 1].
    x_pos = np.arange(n_samples) + 0.5

    cv_args = {"X": X, "y": y, "sample_domain": sample_domain}
    # Materialize the splits so their count is known for the axis layout.
    splits = list(cv.split(**cv_args))
    if n_splits is None:
        n_splits = len(splits)

    # Generate the training/testing visualizations for each CV split
    for ii, (tr, tt) in enumerate(splits):
        # 0 marks training samples, 1 test samples, NaN samples in neither set
        indices = np.full(n_samples, np.nan)
        indices[tt] = 1
        indices[tr] = 0
        # NOTE(review): only this row passes s=1.8; the class/sample_domain
        # rows use matplotlib's default marker size -- confirm intended.
        ax.scatter(
            x_pos,
            [ii + 0.5] * n_samples,
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
            s=1.8,
        )

    # The class and sample_domain rows go right after the ``n_splits``
    # reserved rows, so they always line up with their tick labels.
    ax.scatter(
        x_pos,
        [n_splits + 0.5] * n_samples,
        c=y,
        marker="_",
        lw=lw,
        cmap=cmap_data,
        vmin=-1.2,
        vmax=0.2,
    )
    ax.scatter(
        x_pos,
        [n_splits + 1.5] * n_samples,
        c=sample_domain,
        marker="_",
        lw=lw,
        cmap=cmap_domain,
        vmin=-3.2,
        vmax=3.2,
    )

    # Formatting: ylim is given top-down so split 0 appears at the top.
    yticklabels = list(range(n_splits)) + ["class", "sample_domain"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, n_samples],
    )
    ax.set_title(f"{type(cv).__name__}", fontsize=15)
    return ax
def plot_st_shuffle_indices(
    cv, X, y, target_labels, sample_domain, ax, n_splits, lw=10
):
    """Plot the splits of ``cv`` for two labelings of the same samples.

    ``ax[0]`` shows the splits obtained when ``y`` is passed to ``cv.split``
    and ``ax[1]`` those obtained when ``target_labels`` is passed instead
    (e.g. labels with vs. without masked targets). On each axes, one row per
    split colours each sample as training (value 0) or test (value 1), with
    samples in neither set left as NaN; two extra rows show the labels used
    and the ``sample_domain`` of every sample.

    Parameters
    ----------
    cv : object
        Cross-validation splitter, called as
        ``cv.split(X=X, y=labels, sample_domain=sample_domain)`` and expected
        to yield ``(train_indices, test_indices)`` pairs.
    X : array-like
        Samples passed to ``cv.split``; ``len(X)`` sets how many samples
        are plotted.
    y : array-like
        First labeling, plotted on ``ax[0]``.
    target_labels : array-like
        Second labeling, plotted on ``ax[1]``.
    sample_domain : array-like
        Domain labels passed to ``cv.split`` and drawn in the
        "sample_domain" row.
    ax : sequence of matplotlib.axes.Axes
        At least two axes; ``ax[0]`` and ``ax[1]`` are drawn on.
    n_splits : int
        Number of split rows reserved on the y-axis; should be at least the
        number of splits ``cv`` yields.
    lw : float, default=10
        Line width of the ``"_"`` markers.

    Returns
    -------
    ax : sequence of matplotlib.axes.Axes
        The axes passed in, after drawing.
    """
    n_samples = len(X)
    # Sample i is drawn at x = i + 0.5, i.e. centred in [i, i + 1].
    x_pos = np.arange(n_samples) + 0.5
    yticklabels = list(range(n_splits)) + ["class", "sample_domain"]

    for n, labels in enumerate([y, target_labels]):
        # Generate the training/testing visualizations for each CV split
        cv_args = {"X": X, "y": labels, "sample_domain": sample_domain}
        for ii, (tr, tt) in enumerate(cv.split(**cv_args)):
            # 0 marks training samples, 1 test samples, NaN neither
            indices = np.full(n_samples, np.nan)
            indices[tt] = 1
            indices[tr] = 0
            ax[n].scatter(
                x_pos,
                [ii + 0.5] * n_samples,
                c=indices,
                marker="_",
                lw=lw,
                cmap=cmap_cv,
                vmin=-0.2,
                vmax=1.2,
            )

        # The class and sample_domain rows go right after the ``n_splits``
        # reserved rows, so they always line up with their tick labels --
        # even when ``cv`` yields fewer splits (or none at all).
        ax[n].scatter(
            x_pos,
            [n_splits + 0.5] * n_samples,
            c=labels,
            marker="_",
            lw=lw,
            cmap=cmap_data,
            vmin=-1.2,
            vmax=0.2,
        )
        ax[n].scatter(
            x_pos,
            [n_splits + 1.5] * n_samples,
            c=sample_domain,
            marker="_",
            lw=lw,
            cmap=cmap_domain,
            vmin=-3.2,
            vmax=3.2,
        )

        # Formatting: ylim is given top-down so split 0 appears at the top.
        ax[n].set(
            yticks=np.arange(n_splits + 2) + 0.5,
            yticklabels=yticklabels,
            ylim=[n_splits + 2.2, -0.2],
            xlim=[0, n_samples],
        )
    return ax
# %%
# The following plot illustrates the behavior of
# :class:`~skada.model_selection.SourceTargetShuffleSplit`.
# The left plot shows the indices of the training and
# testing sets for each split, with the dataset packed with
# :func:`~skada.datasets._base.DomainAwareDataset.pack`
# (the labels of the target domains are masked, i.e. set to -1).
# The right plot shows the indices of the training and
# testing sets for each split, with the dataset packed with
# :func:`~skada.datasets._base.DomainAwareDataset.pack` and
# the argument ``mask_target_labels=False``.
# A single splitter here; its splits are drawn side by side for the masked
# labels (left) and the unmasked labels (right), sharing the y-axis.
cvs = [SourceTargetShuffleSplit]
for cv in cvs:
    fig, ax = plt.subplots(1, 2, figsize=(7, 3), sharey=True)
    fig.suptitle(cv.__name__, fontsize=15)
    plot_st_shuffle_indices(
        cv(n_splits), X, y, target_labels, sample_domain, ax, n_splits, lw=10
    )
    # Legend swatches are sampled from cmap_cv, the train/test colormap.
    fig.legend(
        handles=[Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        labels=["Testing set", "Training set"],
        loc="center right",
    )
    # Axis captions shared by both subplots.
    fig.text(0.48, 0.01, "Sample index", ha="center")
    fig.text(0.001, 0.5, "CV iteration", va="center", rotation="vertical")
    # Tighten the layout first, then keep the right 30% free for the legend.
    fig.tight_layout()
    fig.subplots_adjust(right=0.7)
# %%
# The following plot illustrates the behavior of
# :class:`~skada.model_selection.LeaveOneDomainOut`.
# The plot shows the indices of the training and testing sets
# for each split and which domain is used as the target domain
# for each split.
# Leave-one-domain-out works on the data packed with pack_lodo().
cvs = [LeaveOneDomainOut]
for cv in cvs:
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_lodo_indices(
        cv(n_splits_lodo), X_lodo, y_lodo, sample_domain_lodo, ax
    )
    # Legend swatches are sampled from cmap_cv, the train/test colormap.
    fig.legend(
        handles=[Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        labels=["Testing set", "Training set"],
        loc="center right",
    )
    # Axis captions for the whole figure.
    fig.text(0.48, 0.01, "Sample index", ha="center")
    fig.text(0.001, 0.5, "CV iteration", va="center", rotation="vertical")
    # Tighten the layout first, then keep the right 30% free for the legend.
    fig.tight_layout()
    fig.subplots_adjust(right=0.7)
# %%
# Now let's see how the other
# cross-validation objects behave on our dataset.
# One figure per remaining splitter, all using the masked-label packing.
cvs = [
    DomainShuffleSplit,
    StratifiedDomainShuffleSplit,
]
for cv in cvs:
    fig, ax = plt.subplots(figsize=(6, 3))
    splitter = cv(n_splits)
    plot_cv_indices(splitter, X, y, sample_domain, ax, n_splits)
    # Legend swatches are sampled from cmap_cv, the train/test colormap.
    fig.legend(
        handles=[Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        labels=["Testing set", "Training set"],
        loc="center right",
    )
    # Axis captions for the whole figure.
    fig.text(0.48, 0.01, "Sample index", ha="center")
    fig.text(0.001, 0.5, "CV iteration", va="center", rotation="vertical")
    # Tighten the layout first, then keep the right 30% free for the legend.
    fig.tight_layout()
    fig.subplots_adjust(right=0.7)
# %%
# As we can see, each splitter has a very different behavior:
#
# - :class:`~skada.model_selection.SourceTargetShuffleSplit`: Each sample
#   is used once as a test set while the remaining samples
#   form the training set.
# - :class:`~skada.model_selection.DomainShuffleSplit`:
#   Randomly splits the data depending on their sample_domain.
#   Each fold is composed of samples coming from all
#   source and target domains.
# - :class:`~skada.model_selection.StratifiedDomainShuffleSplit`: Same as
#   :class:`~skada.model_selection.DomainShuffleSplit`, but it also
#   preserves the percentage of samples for each class and for each sample
#   domain. The split depends not only on the samples' sample_domain but
#   also on their label.
# - :class:`~skada.model_selection.LeaveOneDomainOut`: The samples of each
#   sample_domain are used once as the target domain, while the samples
#   from the other sample_domains form the source domains (can be used only
#   with :func:`~skada.datasets._base.DomainAwareDataset.pack_lodo`).