@@ -177,8 +177,8 @@ def get_ensamble_preds(val_features, probs, zeroshot_weights, dataset_domains=No
177177 except :
178178 outputs = probs
179179 print (outputs .shape )
180- salem_preds = np .argmax (outputs , axis = 1 )
181- print (salem_preds .shape )
180+ lads_preds = np .argmax (outputs , axis = 1 )
181+ print (lads_preds .shape )
182182 # CLIP ZS
183183 zeroshot_weights = zeroshot_weights .cuda ()
184184 images = torch .tensor (val_features ).cuda ()
@@ -196,27 +196,27 @@ def get_ensamble_preds(val_features, probs, zeroshot_weights, dataset_domains=No
196196 dom_preds = []
197197 for i in range (len (ensambled_preds )):
198198 if soft_dom_label [i ] == 0 :
199- dom_preds .append (salem_preds [i ])
199+ dom_preds .append (lads_preds [i ])
200200 else :
201201 dom_preds .append (ensambled_preds [i ])
202202 ret_preds = np .array (dom_preds )
203203 else :
204204 ret_preds = ensambled_preds
205205
206- return salem_preds , zs_preds , ensambled_preds , ret_preds
206+ return lads_preds , zs_preds , ensambled_preds , ret_preds
207207
208- def get_pred_overlap (salem_preds , zs_preds , labels ):
208+ def get_pred_overlap (lads_preds , zs_preds , labels ):
209209 """
210- Get the overlap in correct predictions for salem and zeroshot.
210+ Get the overlap in correct predictions for lads and zeroshot.
211211 """
212- salem_correct = np .where (salem_preds == labels )[0 ]
212+ lads_correct = np .where (lads_preds == labels )[0 ]
213213 zs_correct = np .where (zs_preds == labels )[0 ]
214- print (len (salem_correct ), len (zs_correct ))
215- print ("salem correct " , salem_correct [:10 ])
214+ print (len (lads_correct ), len (zs_correct ))
215+ print ("lads correct " , lads_correct [:10 ])
216216 print ("zs correct " , zs_correct [:10 ])
217- salem_overlap = [i for i in salem_correct if i in zs_correct ]
218- salem_nonverlap = [i for i in salem_correct if not (i in zs_correct )]
219- zs_nonverlap = [i for i in zs_correct if not (i in salem_correct )]
220- num_zs_correct_nonoverlap = len (zs_correct ) - len (salem_overlap )
221- num_salem_correct_nonverlap = len (salem_correct ) - len (salem_overlap )
222- return num_salem_correct_nonverlap , num_salem_correct_nonverlap / len (labels ), num_salem_correct_nonverlap / len (salem_correct )
217+ lads_overlap = [i for i in lads_correct if i in zs_correct ]
218+ lads_nonverlap = [i for i in lads_correct if not (i in zs_correct )]
219+ zs_nonverlap = [i for i in zs_correct if not (i in lads_correct )]
220+ num_zs_correct_nonoverlap = len (zs_correct ) - len (lads_overlap )
221+ num_lads_correct_nonverlap = len (lads_correct ) - len (lads_overlap )
222+ return num_lads_correct_nonverlap , num_lads_correct_nonverlap / len (labels ), num_lads_correct_nonverlap / len (lads_correct )
0 commit comments