Skip to content

Commit f5cf028

Browse files
flefebv, juAlberge and tgnassou
authored
[MRG] Gradual domain adaptation (#354)
* gradual da * gradual da * implement gradual domain adaptation * fix docstring * fix docstrings * add more comprehensive tests * Improves clarity of variable names (`X_t` -> `X_step`), and improves testing of estimator-fit condition --------- Co-authored-by: Julie Alberge <julie.alberge@gmail.com> Co-authored-by: Théo Gnassounou <66993815+tgnassou@users.noreply.github.com>
1 parent c750dde commit f5cf028

File tree

4 files changed

+610
-0
lines changed

4 files changed

+610
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,5 @@ The library is distributed under the 3-Clause BSD license.
249249
[36] Xiao, Zhiqing, Wang, Haobo, Jin, Ying, Feng, Lei, Chen, Gang, Huang, Fei, Zhao, Junbo. [SPA: A Graph Spectral Alignment Perspective for Domain Adaptation](https://arxiv.org/pdf/2310.17594). In NeurIPS, 2023.
250250

251251
[37] Xie, Renchunzi, Odonnat, Ambroise, Feofanov, Vasilii, Deng, Weijian, Zhang, Jianfeng and An, Bo. [MaNo: Exploiting Matrix Norm for Unsupervised Accuracy Estimation Under Distribution Shifts](https://arxiv.org/pdf/2405.18979). In NeurIPS, 2024.
252+
253+
[38] Y. He, H. Wang, B. Li, H. Zhao. [Gradual Domain Adaptation: Theory and Algorithms](https://arxiv.org/pdf/2310.13852). In Journal of Machine Learning Research, 2024.
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
# %%
"""
Gradual Domain Adaptation Using Optimal Transport
=================================================

This example illustrates the GOAT method from [38] on a simple classification task.
However, the CNN is replaced with a MLP.

.. [38] Y. He, H. Wang, B. Li, H. Zhao
    Gradual Domain Adaptation: Theory and Algorithms in
    Journal of Machine Learning Research, 2024.

"""

# Authors: Félix Lefebvre and Julie Alberge
#
# License: BSD 3-Clause

# %% Imports
import matplotlib.pyplot as plt
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.neural_network import MLPClassifier

from skada import source_target_split
from skada._gradual_da import GradualEstimator
from skada.datasets import make_shifted_datasets

# %%
# Generate conditional shift dataset
# ----------------------------------

n, m = 20, 25  # number of source and target samples
X, y, sample_domain = make_shifted_datasets(
    n_samples_source=n,
    n_samples_target=m,
    shift="conditional_shift",
    noise=0.1,
    random_state=42,  # fixed seed so the example is reproducible
)

# %%
# Plot source and target datasets
# -------------------------------

X_source, X_target, y_source, y_target = source_target_split(
    X, y, sample_domain=sample_domain
)
# Common axis limits (with a 0.5 margin) so every panel uses the same scale.
lims = (min(X[:, 0]) - 0.5, max(X[:, 0]) + 0.5, min(X[:, 1]) - 0.5, max(X[:, 1]) + 0.5)

plt.figure(1, figsize=(8, 3.5))
plt.subplot(121)

plt.scatter(X_source[:, 0], X_source[:, 1], c=y_source, vmax=9, cmap="tab10", alpha=0.7)
plt.title("Source domain")
plt.axis(lims)

plt.subplot(122)
plt.scatter(X_target[:, 0], X_target[:, 1], c=y_target, vmax=9, cmap="tab10", alpha=0.7)
plt.title("Target domain")
plt.axis(lims)

# %%
# Fit Gradual Domain Adaptation
# -----------------------------
#
# We use a MLP classifier as the base estimator (default parameters).

base_estimator = MLPClassifier(hidden_layer_sizes=(50, 50))

gradual_adapter = GradualEstimator(
    n_steps=40,  # number of adaptation steps
    base_estimator=base_estimator,
    advanced_ot_plan_sampling=True,
    save_estimators=True,  # keep every per-step estimator for inspection below
    save_intermediate_data=True,  # keep every per-step dataset for inspection below
)

gradual_adapter.fit(
    X,
    y,
    sample_domain=sample_domain,
)

# %%
# Check results
# -------------
# Compute accuracy on source and target with the initial
# estimator and the final estimator.


clfs = gradual_adapter.get_intermediate_estimators()

# clfs[0] is the estimator fitted on the source data only (before adaptation).
ACC_source_init = clfs[0].score(X_source, y_source)
ACC_target_init = clfs[0].score(X_target, y_target)

print(f"Initial accuracy on source domain: {ACC_source_init:.3f}")
print(f"Initial accuracy on target domain: {ACC_target_init:.3f}")
print("")

# Scoring the adapter itself uses the final (fully adapted) estimator.
ACC_source = gradual_adapter.score(X_source, y_source)
ACC_target = gradual_adapter.score(X_target, y_target)

print(f"Final accuracy on source domain: {ACC_source:.3f}")
print(f"Final accuracy on target domain: {ACC_target:.3f}")


# %%
# Inspect intermediate states
# ---------------------------
#
# We can plot the intermediate datasets and decision boundaries.

intermediate_data = gradual_adapter.intermediate_data_

fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.ravel()

# Define which steps to plot (every 5th step out of the 40).
steps_to_plot = [5, 10, 15, 20, 25, 30, 35, 40]

for ax, step in zip(axes, steps_to_plot):
    # Steps are 1-based in the narrative, 0-based in the stored lists.
    X_step, y_step = intermediate_data[step - 1]
    clf = clfs[step - 1]

    ax.scatter(X_step[:, 0], X_step[:, 1], c=y_step, vmax=9, cmap="tab10", alpha=0.7)
    DecisionBoundaryDisplay.from_estimator(
        clf,
        X,
        response_method="predict",
        cmap="gray_r",
        alpha=0.15,
        ax=ax,
        grid_resolution=200,
    )
    ax.set_title(f"t = {step}")
    ax.axis(lims)

plt.tight_layout()


# %%
# Plot decision boundaries on source and target datasets
# ------------------------------------------------------
#
# Now we can see how this gradual domain adaptation has changed
# the decision boundary between the source and target domains.

figure, axis = plt.subplots(1, 2, figsize=(9, 4))
cm = "gray_r"
# Left panel: initial estimator on the source data.
DecisionBoundaryDisplay.from_estimator(
    clfs[0],
    X,
    response_method="predict",
    cmap=cm,
    alpha=0.15,
    ax=axis[0],
    grid_resolution=200,
)
axis[0].scatter(
    X_source[:, 0],
    X_source[:, 1],
    c=y_source,
    vmax=9,
    cmap="tab10",
    alpha=0.7,
)
axis[0].set_title("Source domain")
# Right panel: final (adapted) estimator on the target data.
DecisionBoundaryDisplay.from_estimator(
    clfs[-1],
    X,
    response_method="predict",
    cmap=cm,
    alpha=0.15,
    ax=axis[1],
    grid_resolution=200,
)
axis[1].scatter(
    X_target[:, 0],
    X_target[:, 1],
    c=y_target,
    vmax=9,
    cmap="tab10",
    alpha=0.7,
)
axis[1].set_title("Target domain")

axis[0].text(
    0.05,
    0.1,
    f"Accuracy: {clfs[0].score(X_source, y_source):.1%}",
    transform=axis[0].transAxes,
    ha="left",
    bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.5},
)
axis[1].text(
    0.05,
    0.1,
    f"Accuracy: {gradual_adapter.score(X_target, y_target):.1%}",
    transform=axis[1].transAxes,
    ha="left",
    bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.5},
)

plt.show()
# %%

0 commit comments

Comments
 (0)