@@ -600,6 +600,219 @@ def load_checkpoint(self, net, path):
600600 print (f"...loaded checkpoint with acc { checkpoint ['acc' ]} " )
601601 return net
602602
603+ from clip_utils import get_domain_text_embs
604+
class DirectionLoss(torch.nn.Module):
    """Distance between two batches of embedding (direction) vectors.

    Supported loss types: 'mse', 'mae' (used as-is) and 'cosine', which is
    converted to a distance as ``1 - cosine_similarity`` so that lower is
    always better.
    """

    # Maps a loss_type string to the criterion class to instantiate.
    _CRITERIA = {
        'mse': torch.nn.MSELoss,
        'cosine': torch.nn.CosineSimilarity,
        'mae': torch.nn.L1Loss,
    }

    def __init__(self, loss_type='mse'):
        super().__init__()
        self.loss_type = loss_type
        # Unknown loss_type raises KeyError, same as the original lookup.
        self.loss_func = self._CRITERIA[loss_type]()

    def forward(self, x, y):
        score = self.loss_func(x, y)
        # Cosine similarity is flipped into a distance; others pass through.
        return 1. - score if self.loss_type == "cosine" else score
623+
class DirectionalMulti(Augment):
    """
    Trains one MLP per target-domain prompt to translate image embeddings
    toward that domain, using a CLIP directional loss (StyleGAN-NADA style)
    combined with a class-consistency loss, then augments embeddings with
    the trained networks.
    """

    def __init__(self, cfg, image_features, labels, group_labels, domain_labels, filenames, text_features, val_image_features, val_labels, val_group_labels, val_domain_labels, val_filenames):
        super().__init__(cfg, image_features, labels, group_labels, domain_labels, filenames, text_features)
        source_embeddings, target_embeddings = get_domain_text_embs(self.model, cfg, self.neutral_prompts, self.prompts, self.class_names)
        # target_embeddings is (num_domains, num_classes, emb_size);
        # source_embeddings is (num_source_domain_descriptions, num_classes, emb_size).
        source_embeddings /= source_embeddings.norm(dim=-1, keepdim=True)
        target_embeddings /= target_embeddings.norm(dim=-1, keepdim=True)
        self.source_embeddings = source_embeddings.cuda().float()
        self.target_embeddings = target_embeddings.cuda().float()

        dataset = EmbeddingDataset(self.cfg, self.image_features, self.labels, self.group_labels, self.domain_labels)
        self.dataset = dataset
        self.train_loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=True)

        val_dataset = EmbeddingDataset(self.cfg, val_image_features, val_labels, val_group_labels, val_domain_labels)
        self.val_dataset = val_dataset
        self.val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=True)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.nets = []
        self.net_checkpoints = []
        # Unique id so checkpoint files from concurrent/repeated runs don't collide.
        self.uid = uuid.uuid4()
        if self.cfg.DATA.DATASET == 'ColoredMNISTBinary':
            text_embs = zeroshot_classifier([[f'a photo of the number "{c}"'] for c in self.class_names], self.model, model_type=self.cfg.EXP.IMAGE_FEATURES)
        else:
            text_embs = zeroshot_classifier([[f"a photo of a {c}"] for c in self.class_names], self.model, model_type=self.cfg.EXP.IMAGE_FEATURES)

        self.class_text_embs = text_embs.float().cuda()
        print("text emb shape ", self.class_text_embs.shape)

        # Train one translation network per target prompt.
        for i in range(len(self.prompts)):
            print(f"Training network for {self.prompts[i]}")
            self.train_network(i)

    def directional_loss_builder(self):
        """
        CLIP directional loss from the (Style)GAN-NADA paper. Ensures that the
        difference in image embeddings is similar to the difference in text
        embeddings of the source and target domain.

        Returns a callable ``custom_loss(predictions, labels, targets)`` that
        averages ``1 - cos(prediction - label, target)`` over the batch.
        """
        def custom_loss(predictions, labels, targets):
            total_sum = None
            delta_i = predictions - labels
            ctr = 0
            for img_delta, delta_t in zip(delta_i, targets):
                ctr += 1
                # 1 - cosine similarity between image-space and text-space deltas.
                cos = torch.dot(img_delta, delta_t) / (torch.norm(img_delta) * torch.norm(delta_t))
                # Fixed: was `total_sum == None`; identity check is the correct
                # (and idiomatic) way to test for the uninitialized accumulator.
                if total_sum is None:
                    total_sum = 1 - cos
                else:
                    total_sum += 1 - cos
            return total_sum / ctr
        return custom_loss

    @staticmethod
    def get_class_logits(outputs, class_embs):
        """Class logits as (L2-normalized outputs) @ class text embeddings."""
        outputs_norm = outputs / outputs.norm(dim=-1, keepdim=True)
        return torch.matmul(outputs_norm, class_embs)

    def train_network(self, num_net):
        """Trains the num_net-th MLP, checkpointing on best validation loss,
        then reloads the best checkpoint into ``self.nets[num_net]``."""
        net = MLP(hidden_dim=self.cfg.AUGMENTATION.MODEL.HIDDEN_DIM, input_dim=self.dataset.embedding_dim)
        self.nets.append(net.cuda())
        self.net_checkpoints.append("")

        self.optimizer = AdamW(self.nets[num_net].parameters(), lr=self.cfg.AUGMENTATION.MODEL.LR, weight_decay=self.cfg.AUGMENTATION.MODEL.WEIGHT_DECAY)
        self.directional_loss = DirectionLoss(self.cfg.AUGMENTATION.LOSS_TYPE)
        self.class_consistency_loss = nn.CrossEntropyLoss(weight=self.dataset.class_weights.cuda())

        if self.cfg.AUGMENTATION.CLIP_NN_LOSS:
            self.clip_nn_loss = nn.CrossEntropyLoss()

        self.nets[num_net].train()

        # Model selection is on the *validation* loss. float('inf') replaces the
        # old 10000 sentinel so the first epoch always checkpoints regardless of
        # the loss magnitude.
        best_val_loss, best_epoch = float('inf'), 0
        for epoch in range(self.cfg.AUGMENTATION.EPOCHS):
            self.training_loop(self.train_loader, num_net, epoch, phase='train')
            val_metrics = self.training_loop(self.val_loader, num_net, epoch, phase='val')
            if val_metrics['val loss'] < best_val_loss:
                best_val_loss = val_metrics['val loss']
                best_epoch = epoch
                self.net_checkpoints[num_net] = self.save_checkpoint(best_val_loss, epoch, num_net)

        wandb.summary[f"{self.prompts[num_net]} best epoch"] = best_epoch
        # Key kept as "best train_loss" for dashboard compatibility, though the
        # tracked quantity is the validation loss.
        wandb.summary[f"{self.prompts[num_net]} best train_loss"] = best_val_loss
        print(f"==> loading checkpoint {self.net_checkpoints[num_net]} at epoch {best_epoch} with loss {best_val_loss}")
        self.nets[num_net] = self.load_checkpoint(self.nets[num_net], self.net_checkpoints[num_net])

    def get_direction_vectors(self, img_embs, labels, num_net):
        """
        Returns the per-sample text direction vectors: for each image, the
        source embedding (for its class) most similar to the image is chosen,
        and the normalized difference ``target - source`` is returned.
        """
        dir_vectors = []
        for im, lbl in zip(img_embs, labels):
            # Similarity of this image to every source-domain description of its class.
            prod = im @ self.source_embeddings[:, lbl, :].T
            _, source_idx = torch.max(prod, dim=0)
            diff = self.target_embeddings[num_net][lbl] - self.source_embeddings[source_idx][lbl]
            if diff.norm() == 0:
                # Degenerate direction (source == target); the normalization
                # below would produce NaNs — surfaced for debugging.
                print(diff)
            dir_vectors.append(diff)
        diffs = torch.stack(dir_vectors)
        diffs /= diffs.norm(dim=-1, keepdim=True)
        return diffs

    def training_loop(self, loader, num_net, epoch, phase='train'):
        """
        Runs one epoch over ``loader``. In the 'train' phase gradients are
        enabled and the optimizer steps; otherwise the net runs in eval mode
        with gradients disabled. Returns the epoch's averaged metrics dict
        (also logged to wandb).
        """
        if phase == 'train':
            self.nets[num_net].train()
        else:
            self.nets[num_net].eval()
        train_directional_loss, train_class_loss, train_loss, total = 0, 0, 0, 0
        with torch.set_grad_enabled(phase == 'train'):
            for i, (inp, cls_target, cls_group, dom_target) in enumerate(loader):
                inp, cls_target = inp.cuda().float(), cls_target.cuda().long()
                cls_outputs = self.nets[num_net](inp)
                text_diffs = self.get_direction_vectors(inp, cls_target, num_net)
                im_diffs = cls_outputs - inp
                # Directional loss between the normalized image-space deltas
                # and the text-space deltas.
                directional_loss = self.directional_loss(im_diffs / im_diffs.norm(dim=-1, keepdim=True), text_diffs).mean()
                cls_logits = self.get_class_logits(cls_outputs, self.class_text_embs)
                cls_consist = self.class_consistency_loss(cls_logits, cls_target)
                # alpha blends the directional loss with class consistency.
                loss = self.alpha * directional_loss + (1 - self.alpha) * cls_consist
                train_class_loss += (1 - self.alpha) * cls_consist.item()
                train_directional_loss += self.alpha * directional_loss.item()

                if phase == 'train':
                    self.optimizer.zero_grad()
                    # NOTE(review): retain_graph looks unnecessary since the graph
                    # is rebuilt each iteration; kept to preserve behavior — confirm
                    # class_text_embs carries no grad history before removing.
                    loss.backward(retain_graph=True)
                    self.optimizer.step()

                train_loss += loss.item()
                total += cls_target.size(0)
                progress_bar(i, len(loader), 'Loss: %.3f' % (train_loss / (i + 1)))

        metrics = {f"{phase} class loss": train_class_loss / (i + 1), f"{phase} directional loss": train_directional_loss / (i + 1), f"{phase} loss": train_loss / (i + 1), "epoch": epoch}
        wandb.log(metrics)
        return metrics

    def augment_single(self, img_embedding, label):
        """
        Returns a list of embeddings for one image: optionally the original
        (when INCLUDE_ORIG_TRAINING is set), plus one translated and
        re-normalized embedding per trained network.
        """
        keep = img_embedding
        output = [keep] if self.cfg.AUGMENTATION.INCLUDE_ORIG_TRAINING else []
        img_embedding = torch.tensor(img_embedding).type(torch.float32).cuda()
        img_embedding /= img_embedding.norm(dim=-1, keepdim=True)
        for net in self.nets:
            o = net(img_embedding)
            o /= o.norm(dim=-1, keepdim=True)
            o = o.detach().cpu().numpy()
            output.append(o)
            wandb.log({"cos sim:": distance.cosine(o, img_embedding.cpu())})
        return list(np.array(output))

    def save_checkpoint(self, acc, epoch, num_net):
        """Saves the num_net-th net's weights (with acc/epoch metadata) and
        returns the checkpoint path."""
        checkpoint_dir = os.path.join("./checkpoint", self.cfg.DATA.DATASET)
        # exist_ok avoids the check-then-create race of exists()+makedirs().
        os.makedirs(checkpoint_dir, exist_ok=True)
        path = f'./checkpoint/{self.cfg.DATA.DATASET}/{self.prompts[num_net]}-{self.cfg.EXP.SEED}-{self.uid}.pth'
        print(f'Saving checkpoint with acc {acc} to {path}...')
        state = {
            "acc": acc,
            "epoch": epoch,
            "net": self.nets[num_net].state_dict()
        }
        torch.save(state, path)
        return path

    def load_checkpoint(self, net, path):
        """Loads weights from ``path`` into ``net`` and returns it."""
        checkpoint = torch.load(path)
        net.load_state_dict(checkpoint['net'])
        print(f"...loaded checkpoint with acc {checkpoint['acc']}")
        return net
815+
603816class BiasDirectional (Directional ):
604817 """
605818 This implements the similar directional loss as the directional class, but routes examples
0 commit comments