Fix mismatch issue of activations vslabels in layers bcs of slicing o…#57
Fix mismatch issue of activations vslabels in layers bcs of slicing o…#57anandawolz wants to merge 3 commits into
Conversation
…f padding & fix some downstream errors due to that change
| "Consider using discrete binning instead.", UserWarning) | ||
| num_bins = int( | ||
| 0.005 * len(data) | ||
| 0.5 * len(data) |
There was a problem hiding this comment.
temporary change to make calling of compute_metric(rdm) work with parameter is_discrete_labels = False
There was a problem hiding this comment.
solves the IndexError: index 1300 is out of bounds for axis 0 with size 982
problem was that RDM’s bin-indices (self.idxs) are being computed once at initialization over the original label length and then compute RDM tried to use self.idx instead of cutted labels which resulted in a mismatch with the padding-cutted activations
There is still a problem when parameter is_discrete_labels = False. Namly, we pass the entire labels_dict into every single call of RDM.compute, so that inside RDM.compute() we end up having an AttributeError. I tried to only pass the per-sample labels list to RDM.compute, but the not succeed . @CeliaBenquet maybe you can have a look
|
This PR is obsolete; the core issue has been addressed in PR #58. |
Fix label–activation shape mismatch in compute_metric (output_only=False)
Solving a shape mismatch issue in the compute_metric function when output_only=False. The problem arose because activations were padded and then sliced, but labels were not sliced accordingly, resulting in incompatible shapes. Additionally, there were transposition inconsistencies between labels and activations.
The fix ensures that labels are sliced in the same way as activations (without using _cut_array() since it expects 2D inputs and label shapes may vary). I also addressed some downstream issues caused by this change — notably in process_activations() — to ensure compatibility throughout.
Tests now pass for both output_only=False and output_only=True.