Skip to content

Commit d7cef20

Browse files
authored
[MERGE] Add new scorer: MixValScorer (#221)
* Add new scorer: MixValScorer * Modif following appendix A of paper * Add arg to compute intra - Inter - Both Ice scores * Change impl to the one in the paper git
1 parent 28aaf82 commit d7cef20

File tree

5 files changed

+179
-0
lines changed

5 files changed

+179
-0
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,6 @@ The library is distributed under the 3-Clause BSD license.
235235

236236
[31] Redko, Ievgen, Nicolas Courty, Rémi Flamary, and Devis Tuia.[ "Optimal transport for multi-source domain adaptation under target shift."](https://proceedings.mlr.press/v89/redko19a/redko19a.pdf) In The 22nd International Conference on artificial intelligence and statistics, pp. 849-858. PMLR, 2019.
237237

238+
[32] Hu, D., Liang, J., Liew, J. H., Xue, C., Bai, S., & Wang, X. (2023). [Mixed Samples as Probes for Unsupervised Model Selection in Domain Adaptation](https://proceedings.neurips.cc/paper_files/paper/2023/file/7721f1fea280e9ffae528dc78c732576-Paper-Conference.pdf). Advances in Neural Information Processing Systems 36 (2024).
239+
240+

docs/source/all.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ DA metrics :py:mod:`skada.metrics`
180180
DeepEmbeddedValidation
181181
SoftNeighborhoodDensity
182182
CircularValidation
183+
MixValScorer
183184

184185

185186
Model Selection :py:mod:`skada.model_selection`

skada/deep/tests/test_deep_scorer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from skada.metrics import (
1414
CircularValidation,
1515
DeepEmbeddedValidation,
16+
MixValScorer,
1617
PredictionEntropyScorer,
1718
SoftNeighborhoodDensity,
1819
)
@@ -25,6 +26,7 @@
2526
PredictionEntropyScorer(),
2627
SoftNeighborhoodDensity(),
2728
CircularValidation(),
29+
MixValScorer(),
2830
],
2931
)
3032
def test_generic_scorer_on_deepmodel(scorer, da_dataset):

skada/metrics.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,3 +619,125 @@ def _score(self, estimator, X, y, sample_domain=None):
619619
score = self.source_scorer(y[source_idx], y_pred_source)
620620

621621
return self._sign * score
622+
623+
624+
class MixValScorer(_BaseDomainAwareScorer):
625+
"""
626+
MixVal scorer for unsupervised domain adaptation.
627+
628+
This scorer uses mixup to create mixed samples from the target domain,
629+
and evaluates the model's consistency on these mixed samples.
630+
631+
See [32]_ for details.
632+
633+
Parameters
634+
----------
635+
alpha : float, default=0.55
636+
Mixing parameter for mixup.
637+
random_state : int, RandomState instance or None, default=None
638+
Controls the randomness of the mixing process.
639+
greater_is_better : bool, default=True
640+
Whether higher scores are better.
641+
ice_type : {'both', 'intra', 'inter'}, default='both'
642+
Type of ICE score to compute:
643+
- 'both': Compute both intra-cluster and inter-cluster ICE scores (average).
644+
- 'intra': Compute only intra-cluster ICE score.
645+
- 'inter': Compute only inter-cluster ICE score.
646+
647+
Attributes
648+
----------
649+
alpha : float
650+
Mixing parameter.
651+
random_state : RandomState
652+
Random number generator.
653+
_sign : int
654+
1 if greater_is_better is True, -1 otherwise.
655+
ice_type : str
656+
Type of ICE score to compute.
657+
658+
References
659+
----------
660+
.. [32] Dapeng Hu et al. Mixed Samples as Probes for Unsupervised Model
661+
Selection in Domain Adaptation.
662+
NeurIPS, 2023.
663+
"""
664+
665+
def __init__(
666+
self,
667+
alpha=0.55,
668+
random_state=None,
669+
greater_is_better=True,
670+
ice_type="both",
671+
):
672+
super().__init__()
673+
self.alpha = alpha
674+
self.random_state = random_state
675+
self._sign = 1 if greater_is_better else -1
676+
self.ice_type = ice_type
677+
678+
if self.ice_type not in ["both", "intra", "inter"]:
679+
raise ValueError("ice_type must be 'both', 'intra', or 'inter'")
680+
681+
def _score(self, estimator, X, y=None, sample_domain=None, **params):
682+
"""
683+
Compute the Interpolation Consistency Evaluation (ICE) score.
684+
685+
Parameters
686+
----------
687+
estimator : object
688+
The fitted estimator to evaluate.
689+
X : array-like of shape (n_samples, n_features)
690+
The input samples.
691+
y : Ignored
692+
Not used, present for API consistency by convention.
693+
sample_domain : array-like, default=None
694+
Domain labels for each sample.
695+
696+
Returns
697+
-------
698+
score : float
699+
The ICE score.
700+
"""
701+
X, _, sample_domain = check_X_y_domain(X, y, sample_domain)
702+
source_idx = extract_source_indices(sample_domain)
703+
X_target = X[~source_idx]
704+
705+
rng = check_random_state(self.random_state)
706+
rand_idx = rng.permutation(X_target.shape[0])
707+
708+
# Get predictions for target samples
709+
labels_a = estimator.predict(X_target, sample_domain=sample_domain[~source_idx])
710+
labels_b = labels_a[rand_idx]
711+
712+
# Intra-cluster and inter-cluster mixup
713+
same_idx = (labels_a == labels_b).nonzero()[0]
714+
diff_idx = (labels_a != labels_b).nonzero()[0]
715+
716+
# Mixup with images and hard pseudo labels
717+
mix_inputs = self.alpha * X_target + (1 - self.alpha) * X_target[rand_idx]
718+
mix_labels = self.alpha * labels_a + (1 - self.alpha) * labels_b
719+
720+
# Obtain predictions for the mixed samples
721+
mix_pred = estimator.predict(
722+
mix_inputs, sample_domain=np.full(mix_inputs.shape[0], -1)
723+
)
724+
725+
# Calculate ICE scores based on ice_type
726+
if self.ice_type in ["both", "intra"]:
727+
ice_same = (
728+
np.sum(mix_pred[same_idx] == mix_labels[same_idx]) / same_idx.shape[0]
729+
)
730+
731+
if self.ice_type in ["both", "inter"]:
732+
ice_diff = (
733+
np.sum(mix_pred[diff_idx] == mix_labels[diff_idx]) / diff_idx.shape[0]
734+
)
735+
736+
if self.ice_type == "both":
737+
ice_score = (ice_same + ice_diff) / 2
738+
elif self.ice_type == "intra":
739+
ice_score = ice_same
740+
else: # self.ice_type == 'inter'
741+
ice_score = ice_diff
742+
743+
return self._sign * ice_score

skada/tests/test_scorer.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
CircularValidation,
2424
DeepEmbeddedValidation,
2525
ImportanceWeightedScorer,
26+
MixValScorer,
2627
PredictionEntropyScorer,
2728
SoftNeighborhoodDensity,
2829
SupervisedScorer,
@@ -246,3 +247,53 @@ def test_deep_embedding_validation_no_transform(da_dataset):
246247
)["test_score"]
247248
assert scores.shape[0] == 3, "evaluate 3 splits"
248249
assert np.all(~np.isnan(scores)), "all scores are computed"
250+
251+
252+
def test_mixval_scorer(da_dataset):
253+
X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"])
254+
estimator = make_da_pipeline(
255+
DensityReweightAdapter(),
256+
LogisticRegression()
257+
.set_fit_request(sample_weight=True)
258+
.set_score_request(sample_weight=True),
259+
)
260+
cv = ShuffleSplit(n_splits=3, test_size=0.3, random_state=0)
261+
262+
# Test with default parameters
263+
scorer = MixValScorer(alpha=0.55, random_state=42)
264+
scores = cross_validate(
265+
estimator,
266+
X,
267+
y,
268+
cv=cv,
269+
params={"sample_domain": sample_domain},
270+
scoring=scorer,
271+
)["test_score"]
272+
273+
assert scores.shape[0] == 3, "evaluate 3 splits"
274+
assert np.all(~np.isnan(scores)), "all scores are computed"
275+
assert np.all(scores >= 0) and np.all(scores <= 1), "scores are between 0 and 1"
276+
277+
# Test different ice_type options
278+
for ice_type in ["both", "intra", "inter"]:
279+
scorer = MixValScorer(alpha=0.55, random_state=42, ice_type=ice_type)
280+
scores = cross_validate(
281+
estimator,
282+
X,
283+
y,
284+
cv=cv,
285+
params={"sample_domain": sample_domain},
286+
scoring=scorer,
287+
)["test_score"]
288+
289+
assert scores.shape[0] == 3, f"evaluate 3 splits for ice_type={ice_type}"
290+
assert np.all(
291+
~np.isnan(scores)
292+
), f"all scores are computed for ice_type={ice_type}"
293+
assert np.all(scores >= 0) and np.all(
294+
scores <= 1
295+
), f"scores are between 0 and 1 for ice_type={ice_type}"
296+
297+
# Test invalid ice_type
298+
with pytest.raises(ValueError):
299+
MixValScorer(ice_type="invalid")

0 commit comments

Comments
 (0)