[MRG] Add CAN Method #251
Conversation
skada/deep/losses.py
Outdated
```python
if mask.sum() > 0:
    class_features = features_s[mask]
    normalized_features = F.normalize(class_features, p=2, dim=1)
    centroid = normalized_features.mean(dim=0)
```
In the paper it seems to be only a sum, no?
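For what it's worth, if the centroid is only ever used through cosine similarity (as in spherical k-means), the sum and the mean of the normalized features differ only by a positive scalar, so they give the same direction after re-normalization. A quick check with hypothetical toy tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
features_s = torch.randn(8, 4)  # hypothetical source features
mask = torch.tensor([1, 0, 1, 1, 0, 1, 0, 0], dtype=torch.bool)

normalized = F.normalize(features_s[mask], p=2, dim=1)
centroid_sum = normalized.sum(dim=0)    # sum, as in the paper
centroid_mean = normalized.mean(dim=0)  # mean, as in the PR

# After re-normalization the two centroids coincide: they differ
# only by the positive factor 1 / mask.sum().
same_direction = torch.allclose(
    F.normalize(centroid_sum, dim=0), F.normalize(centroid_mean, dim=0)
)
print(same_direction)
```

So the choice only matters if the un-normalized centroid magnitude is used somewhere.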
```python
# Discard ambiguous classes
class_counts = torch.bincount(cluster_labels_t, minlength=n_classes)
valid_classes = class_counts >= class_threshold
mask_t = valid_classes[cluster_labels_t]
```
I don't see what this line is doing.
`class_counts = torch.bincount(cluster_labels_t, minlength=n_classes)` counts how many samples are in each cluster. `valid_classes = class_counts >= class_threshold` creates a boolean tensor where `True` indicates classes that have at least `class_threshold` samples. `mask_t = valid_classes[cluster_labels_t]` uses the cluster labels as indices into the `valid_classes` tensor. This creates a boolean `mask_t`, where `True` indicates samples that belong to classes with enough representation.
This part of the code corresponds to the "Filter the ambiguous classes" step of the paper's pseudo-algorithm.
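The label-indexing trick above can be seen on a tiny hypothetical example (toy labels and threshold, not values from the PR):

```python
import torch

# Six target samples clustered into 3 classes; class 1 has a single sample.
cluster_labels_t = torch.tensor([0, 0, 0, 1, 2, 2])
n_classes, class_threshold = 3, 2

class_counts = torch.bincount(cluster_labels_t, minlength=n_classes)
valid_classes = class_counts >= class_threshold
# Indexing a boolean tensor with the per-sample labels broadcasts the
# per-class decision to a per-sample mask.
mask_t = valid_classes[cluster_labels_t]
print(mask_t)  # only the lone sample of class 1 is masked out
```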
```python
features_t = features_t[mask_t]
cluster_labels_t = cluster_labels_t[mask_t]
```
```python
# Define sigmas
```
Can you not use the MMD distance from DAN?
The formula is not exactly the same as for the MMD, since before computing each mean we apply a class-specific mask.
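To make the difference concrete, here is a minimal sketch of an MMD term restricted to masked subsets. The helper names (`rbf_kernel`, `masked_mmd`) and the single-bandwidth kernel are assumptions for illustration, not the PR's actual code; the CDD of the paper sums such masked terms over class pairs, which is why DAN's unmasked MMD cannot be reused as-is:

```python
import torch

def rbf_kernel(x, y, sigma=1.0):
    # Gaussian kernel matrix between two batches of features.
    d2 = torch.cdist(x, y) ** 2
    return torch.exp(-d2 / (2 * sigma ** 2))

def masked_mmd(feat_s, feat_t, mask_s, mask_t, sigma=1.0):
    # Plain (biased) MMD^2, but the kernel means are taken only over the
    # samples selected by the class masks.
    xs, xt = feat_s[mask_s], feat_t[mask_t]
    return (rbf_kernel(xs, xs, sigma).mean()
            + rbf_kernel(xt, xt, sigma).mean()
            - 2 * rbf_kernel(xs, xt, sigma).mean())

torch.manual_seed(0)
feat_s, feat_t = torch.randn(6, 4), torch.randn(5, 4)
mask_s = torch.tensor([True, True, False, True, False, True])
mask_t = torch.tensor([True, False, True, True, True])
cdd_term = masked_mmd(feat_s, feat_t, mask_s, mask_t)
```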
```python
for n_iter in range(self.max_iter):
    # Assign samples to closest centroids
    dissimilarities = self._compute_dissimilarities(X, centroids)
```
Is there a difference here with torch's `cosine_similarity` function?
In the paper, the cosine dissimilarity is `0.5 * (1 - cosine_similarity)`.


Paper: https://arxiv.org/pdf/1901.00976
Mostly Eq. 3, 4, and 5, plus paragraph 3.4.
New Features:
- `CANLoss` class to `skada/deep/_divergence.py` to implement the contrastive domain discrepancy (CDD) loss.
- `CAN` function to `skada/deep/_divergence.py` to implement the CAN domain adaptation method.

New Utilities:
- `SphericalKMeans` class to `skada/deep/utils.py` for clustering using cosine similarity.

Testing:
- Tests added to `test_deep_divergence.py` to ensure the new method works as expected.

Still needs to be done:
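For reviewers unfamiliar with the clustering step, here is a minimal standalone sketch of the spherical k-means idea (k-means with cosine dissimilarity on the unit sphere). This is a hypothetical illustration, not the PR's `SphericalKMeans` API; in particular the naive first-k initialization is a placeholder for a proper random or k-means++-style init:

```python
import torch
import torch.nn.functional as F

def spherical_kmeans(X, n_clusters, max_iter=50):
    # Project samples onto the unit sphere; for unit vectors u, v the
    # cosine dissimilarity is 0.5 * (1 - u @ v).
    X = F.normalize(X, p=2, dim=1)
    # Naive deterministic init: first k points (placeholder choice).
    centroids = X[:n_clusters].clone()
    labels = torch.zeros(X.shape[0], dtype=torch.long)
    for _ in range(max_iter):
        # Assign each sample to the centroid with smallest dissimilarity.
        dissim = 0.5 * (1 - X @ centroids.T)
        labels = dissim.argmin(dim=1)
        # Update: sum the members and re-normalize (sum and mean give the
        # same direction); keep the old centroid for empty clusters.
        new_centroids = torch.stack([
            F.normalize(X[labels == k].sum(dim=0), dim=0)
            if (labels == k).any() else centroids[k]
            for k in range(n_clusters)
        ])
        if torch.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids

# Two tight directional groups along +x and -x.
X = torch.tensor([[1.0, 0.01], [-1.0, 0.01], [1.0, -0.01], [-1.0, -0.01]])
labels, centroids = spherical_kmeans(X, 2)
```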