Skip to content

Commit 4028dd5

Browse files
author
Lisa Dunlap
committed
added in multiple source descriptions
1 parent 08e812e commit 4028dd5

File tree

3 files changed

+349
-9
lines changed

3 files changed

+349
-9
lines changed

configs/DomainNet/test_aug.yaml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Replication config: LADS directional augmentation on DomainNetMini.
# NOTE(review): the extracted diff lost all leading whitespace, so the
# nesting below is reconstructed -- verify against the repository.
EXP:
  ADVICE_METHOD: "ClipMLP"
  WANDB_SILENT: False
  PROJ: "DomainNetMini_LADS_Replication"
  SEED: 0
  # One prompt list per target domain; {} is filled with the class name.
  TEXT_PROMPTS: [['a realistic photo of a {}.'], ['a painting of a {}.'], ['clipart of a {}.']]
  # Descriptions of the source (training) domain.
  NEUTRAL_TEXT_PROMPTS: ['a sketch of a {}', 'a pencil drawing of a {}.', 'a drawing of a {}.']
  AUGMENTATION: 'DirectionalMulti'
  EPOCHS: 400
  LOG_NN: True
  # NOTE(review): 'ENSAMBLE' spelling preserved as committed -- the code
  # presumably reads this exact key; do not "fix" it here alone.
  ENSAMBLE: False


DATA:
  DATASET: "DomainNetMini"
  LOAD_CACHED: True            # reuse cached CLIP features instead of re-encoding
  SAVE_PATH: "vit14_clip.pth"
  BATCH_SIZE: 256

METHOD:
  MODEL:
    NUM_LAYERS: 1
  # NOTE(review): depth of the keys below (METHOD vs METHOD.MODEL) is a
  # reconstruction guess -- confirm against the config schema.
  DOM_WEIGHT: 1.0
  LR: 0.0001
  CHECKPOINT: 'checkpoint/mlp_simple.pth'
  CHECKPOINT_NAME: 'DomainNetMini-mlp-directional'
  RESUME: False
  USE_DOM_GT: True
  APPLY_TRANSFORMATION: False

AUGMENTATION:
  MODEL:
    LR: 0.0001
    WEIGHT_DECAY: 0.005
    NUM_LAYERS: 1
  EPOCHS: 50                   # epochs per augmentation MLP (one MLP per target prompt)
  GENERIC: False
  ALPHA: 0.5                   # weight on directional loss vs class-consistency loss

methods/augmentations.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,219 @@ def load_checkpoint(self, net, path):
600600
print(f"...loaded checkpoint with acc {checkpoint['acc']}")
601601
return net
602602

603+
from clip_utils import get_domain_text_embs
604+
605+
class DirectionLoss(torch.nn.Module):
    """Distance between two direction (embedding-difference) vectors.

    ``loss_type`` selects the metric: ``'mse'`` (squared error),
    ``'cosine'`` (reported as ``1 - cosine similarity`` so that identical
    directions score 0), or ``'mae'`` (absolute error). An unknown
    ``loss_type`` raises ``KeyError`` at construction time.
    """

    def __init__(self, loss_type='mse'):
        super().__init__()
        self.loss_type = loss_type
        criterion_table = {
            'mse': torch.nn.MSELoss,
            'cosine': torch.nn.CosineSimilarity,
            'mae': torch.nn.L1Loss,
        }
        # Instantiate the selected criterion once, up front.
        self.loss_func = criterion_table[loss_type]()

    def forward(self, x, y):
        """Return the configured distance between ``x`` and ``y``."""
        value = self.loss_func(x, y)
        # Cosine similarity is a *similarity*; flip it into a distance.
        return 1. - value if self.loss_type == "cosine" else value
623+
624+
class DirectionalMulti(Augment):
    """LADS-style directional augmentation with multiple target domains.

    Trains one MLP per target-domain prompt. Each MLP maps a CLIP image
    embedding toward its target domain: a directional loss (StyleGAN-NADA
    style) keeps the image-space shift parallel to the corresponding
    text-space shift, while a class-consistency cross-entropy keeps the
    zero-shot class prediction unchanged.
    """

    def __init__(self, cfg, image_features, labels, group_labels, domain_labels, filenames, text_features, val_image_features, val_labels, val_group_labels, val_domain_labels, val_filenames):
        super().__init__(cfg, image_features, labels, group_labels, domain_labels, filenames, text_features)
        source_embeddings, target_embeddings = get_domain_text_embs(self.model, cfg, self.neutral_prompts, self.prompts, self.class_names)
        # target_embeddings is size (num_domains, num_classes, emb_size)
        # source_embeddings is size (num_source_domain_descriptions, num_classes, emb_size)
        source_embeddings /= source_embeddings.norm(dim=-1, keepdim=True)
        target_embeddings /= target_embeddings.norm(dim=-1, keepdim=True)
        self.source_embeddings = source_embeddings.cuda().float()
        self.target_embeddings = target_embeddings.cuda().float()
        dataset = EmbeddingDataset(self.cfg, self.image_features, self.labels, self.group_labels, self.domain_labels)
        self.dataset = dataset
        self.train_loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=True)

        val_dataset = EmbeddingDataset(self.cfg, val_image_features, val_labels, val_group_labels, val_domain_labels)
        self.val_dataset = val_dataset
        self.val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=True)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.nets = []              # one trained MLP per target prompt
        self.net_checkpoints = []   # checkpoint path for each net
        self.uid = uuid.uuid4()     # makes checkpoint filenames unique per run
        # Zero-shot class classifier used by the class-consistency loss.
        if self.cfg.DATA.DATASET == 'ColoredMNISTBinary':
            text_embs = zeroshot_classifier([[f'a photo of the number "{c}"'] for c in self.class_names], self.model, model_type=self.cfg.EXP.IMAGE_FEATURES)
        else:
            text_embs = zeroshot_classifier([[f"a photo of a {c}"] for c in self.class_names], self.model, model_type=self.cfg.EXP.IMAGE_FEATURES)

        self.class_text_embs = text_embs.float().cuda()
        print("text emb shape ", self.class_text_embs.shape)

        for i in range(len(self.prompts)):
            print(f"Training network for {self.prompts[i]}")
            self.train_network(i)

    def directional_loss_builder(self):
        """
        CLIP directional loss from gan NADA paper. Ensures that the difference in
        image embeddings is similar to the difference in text embeddings of the
        source and target domain.
        """
        def custom_loss(predictions, labels, targets):
            # Mean (1 - cosine similarity) between each image-embedding delta
            # and its matching text-embedding delta.
            total_sum = None
            delta_i = predictions - labels
            ctr = 0
            for i, delta_tt in zip(delta_i, targets):
                ctr += 1
                if total_sum is None:  # fixed: was `total_sum == None`
                    numerator = torch.dot(i, delta_tt)
                    denominator = torch.norm(i) * torch.norm(delta_tt)
                    total_sum = 1 - (numerator / denominator)
                else:
                    total_sum += 1 - (torch.dot(i, delta_tt) / (torch.norm(i) * torch.norm(delta_tt)))
            return total_sum / ctr
        return custom_loss

    @staticmethod
    def get_class_logits(outputs, class_embs):
        """Cosine-style class logits: L2-normalized outputs @ class text embeddings."""
        outputs_norm = outputs / outputs.norm(dim=-1, keepdim=True)
        return torch.matmul(outputs_norm, class_embs)

    def train_network(self, num_net):
        """Train the ``num_net``-th MLP against its target-domain prompt.

        Checkpoints on every validation-loss improvement and reloads the
        best checkpoint when training finishes.
        """
        net = MLP(hidden_dim=self.cfg.AUGMENTATION.MODEL.HIDDEN_DIM, input_dim=self.dataset.embedding_dim)
        self.nets.append(net.cuda())
        self.net_checkpoints.append("")

        self.optimizer = AdamW(self.nets[num_net].parameters(), lr=self.cfg.AUGMENTATION.MODEL.LR, weight_decay=self.cfg.AUGMENTATION.MODEL.WEIGHT_DECAY)
        self.directional_loss = DirectionLoss(self.cfg.AUGMENTATION.LOSS_TYPE)
        self.class_consistency_loss = nn.CrossEntropyLoss(weight=self.dataset.class_weights.cuda())

        if self.cfg.AUGMENTATION.CLIP_NN_LOSS:
            self.clip_nn_loss = nn.CrossEntropyLoss()

        self.nets[num_net].train()

        # NOTE(review): despite its name (and the wandb key below, kept for
        # dashboard continuity), best_train_loss tracks the best *validation* loss.
        best_train_loss, best_epoch = 10000, 0
        for epoch in range(self.cfg.AUGMENTATION.EPOCHS):
            train_metrics = self.training_loop(self.train_loader, num_net, epoch, phase='train')
            val_metrics = self.training_loop(self.val_loader, num_net, epoch, phase='val')
            if val_metrics['val loss'] < best_train_loss:
                best_train_loss = val_metrics['val loss']
                best_epoch = epoch
                self.net_checkpoints[num_net] = self.save_checkpoint(best_train_loss, epoch, num_net)

        wandb.summary[f"{self.prompts[num_net]} best epoch"] = best_epoch
        wandb.summary[f"{self.prompts[num_net]} best train_loss"] = best_train_loss
        print(f"==> loading checkpoint {self.net_checkpoints[num_net]} at epoch {best_epoch} with loss {best_train_loss}")
        self.nets[num_net] = self.load_checkpoint(self.nets[num_net], self.net_checkpoints[num_net])

    def get_direction_vectors(self, img_embs, labels, num_net):
        """
        Returns the direction vectors for the image embeddings by taking the source
        embedding that is most similar to each image embedding and subtracting it from the target.
        """
        dir_vectors = []
        for (im, l) in zip(img_embs, labels):
            # Pick the source-domain description most similar to this image.
            prod = im @ self.source_embeddings[:, l, :].T
            _, source_idx = torch.max(prod, dim=0)
            diff = self.target_embeddings[num_net][l] - self.source_embeddings[source_idx][l]
            if diff.norm() == 0:
                # A zero diff normalizes to NaN below; surface it for debugging.
                print(diff)
            dir_vectors.append(diff)
        diffs = torch.stack(dir_vectors)
        diffs /= diffs.norm(dim=-1, keepdim=True)
        return diffs

    def training_loop(self, loader, num_net, epoch, phase='train'):
        """Run one epoch over ``loader``; return a dict of averaged losses.

        phase='train' enables gradients and optimizer steps; any other
        phase only evaluates. Metrics are also logged to wandb.
        """
        if phase == 'train':
            self.nets[num_net].train()
        else:
            self.nets[num_net].eval()
        train_directional_loss, train_class_loss, train_loss, total = 0, 0, 0, 0
        with torch.set_grad_enabled(phase == 'train'):
            for i, (inp, cls_target, cls_group, dom_target) in enumerate(loader):
                inp, cls_target = inp.cuda().float(), cls_target.cuda().long()
                cls_outputs = self.nets[num_net](inp)
                text_diffs = self.get_direction_vectors(inp, cls_target, num_net)
                im_diffs = cls_outputs - inp
                # Directional term: the image-space shift should be parallel
                # to the text-space shift for this target domain.
                directional_loss = self.directional_loss(im_diffs / im_diffs.norm(dim=-1, keepdim=True), text_diffs).mean()
                # Class-consistency term: the augmented embedding must keep its label.
                cls_logits = self.get_class_logits(cls_outputs, self.class_text_embs)
                cls_consist = self.class_consistency_loss(cls_logits, cls_target)
                loss = self.alpha * directional_loss + (1 - self.alpha) * cls_consist
                train_class_loss += (1 - self.alpha) * cls_consist.item()
                train_directional_loss += self.alpha * directional_loss.item()

                if phase == 'train':
                    self.optimizer.zero_grad()
                    # NOTE(review): retain_graph=True looks unnecessary here -- confirm.
                    loss.backward(retain_graph=True)
                    self.optimizer.step()

                train_loss += loss.item()

                total += cls_target.size(0)
                progress_bar(i, len(loader), 'Loss: %.3f' % (train_loss / (i + 1)))

        # NOTE(review): an empty loader leaves `i` unbound and raises NameError here.
        metrics = {f"{phase} class loss": train_class_loss / (i + 1), f"{phase} directional loss": train_directional_loss / (i + 1), f"{phase} loss": train_loss / (i + 1), "epoch": epoch}
        wandb.log(metrics)
        return metrics

    def augment_single(self, img_embedding, label):
        """Return augmented version(s) of one image embedding as numpy arrays.

        The normalized embedding is pushed through every trained net; the
        original embedding is prepended when INCLUDE_ORIG_TRAINING is set.
        """
        keep = img_embedding
        if self.cfg.AUGMENTATION.INCLUDE_ORIG_TRAINING:
            output = [keep]
        else:
            output = []
        img_embedding = torch.tensor(img_embedding)
        img_embedding = img_embedding.type(torch.float32)
        img_embedding = img_embedding.cuda()
        img_embedding /= img_embedding.norm(dim=-1, keepdim=True)
        for net in self.nets:
            o = net(img_embedding)
            o /= o.norm(dim=-1, keepdim=True)

            o = o.detach().cpu().numpy()
            output.append(o)
            # Track how far each net moves the embedding.
            wandb.log({"cos sim:": distance.cosine(o, img_embedding.cpu())})
        return list(np.array(output))

    def save_checkpoint(self, acc, epoch, num_net):
        """Save net ``num_net`` (with its loss and epoch) to a unique path; return the path."""
        checkpoint_dir = os.path.join("./checkpoint", self.cfg.DATA.DATASET)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        path = f'./checkpoint/{self.cfg.DATA.DATASET}/{self.prompts[num_net]}-{self.cfg.EXP.SEED}-{self.uid}.pth'
        print(f'Saving checkpoint with acc {acc} to {path}...')
        state = {
            "acc": acc,
            "epoch": epoch,
            "net": self.nets[num_net].state_dict()
        }
        torch.save(state, path)
        return path

    def load_checkpoint(self, net, path):
        """Load weights from ``path`` into ``net`` and return it."""
        checkpoint = torch.load(path)
        net.load_state_dict(checkpoint['net'])
        print(f"...loaded checkpoint with acc {checkpoint['acc']}")
        return net
815+
603816
class BiasDirectional(Directional):
604817
"""
605818
This implements the similar directional loss as the directional class, but routes examples

0 commit comments

Comments
 (0)