Skip to content

Commit 3e16edd

Browse files
author
Lisa Dunlap
committed
improving results
1 parent f09cf81 commit 3e16edd

File tree

7 files changed

+815
-164
lines changed

7 files changed

+815
-164
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ Official Implementation of [LADS (Latent Augmentation using Domain descriptionS)
55

66
*WARNING: this is still WIP, please raise an issue if you run into any bugs.*
77

8+
## TODOs
9+
- [ ] add e2e method for DA
10+
- [ ] add in selective augmentation
11+
- [ ] run 2 layer mlp baselines
12+
813
## Getting started
914

1015
1. Install the dependencies for our code using Conda. You may need to adjust the environment YAML file depending on your setup.

clip_utils.py

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -26,67 +26,6 @@
2626

2727
import omegaconf
2828

29-
def get_domain_text_embs(model, cfg, source_text_prompts, target_text_prompts, class_names):
    """
    Gets the text embeddings of the prompts describing the source and target domains.

    If cfg.AUGMENTATION.GENERIC is True, source_text_prompts and target_text_prompts are
    plain strings; otherwise they are templates with a "{}" placeholder that each class
    name is formatted into (one prompt per class).

    Args:
        model: CLIP-style model passed through to zeroshot_classifier.
        cfg: omegaconf config; reads AUGMENTATION.GENERIC, METHOD.NORMALIZE,
            EXP.IMAGE_FEATURES.
        source_text_prompts: prompts (or templates) describing the source domain.
        target_text_prompts: prompts (or templates) describing the target domain(s).
        class_names: class names substituted into the templates in the non-generic path.

    Returns:
        (source_embeddings, target_embeddings) as torch tensors; source_embeddings is
        all-zeros (via torch.zeros_like) when no source prompts are given.

    NOTE(review): reconstructed from a deleted-code diff; indentation of the trailing
    print/return relative to the if/else is a best-effort guess — confirm against the
    parent commit.
    """
    if cfg.AUGMENTATION.GENERIC:
        # Generic path: prompts are used verbatim (no per-class formatting).
        text_embeddings = zeroshot_classifier(target_text_prompts, model, normalize=cfg.METHOD.NORMALIZE, model_type=cfg.EXP.IMAGE_FEATURES)
        text_embeddings = np.transpose(text_embeddings, (1,0))
        orig_prompts = text_embeddings  # NOTE(review): assigned but never used below
        if len(source_text_prompts) > 0:
            source_embeddings = zeroshot_classifier(source_text_prompts, model, normalize=cfg.METHOD.NORMALIZE, model_type=cfg.EXP.IMAGE_FEATURES)
            print("source emb before averaging", source_embeddings.shape)
            # Average the source-prompt embeddings into a single source direction.
            source_embeddings = source_embeddings.mean(dim=0)
            print("source emb after averaging", source_embeddings.shape)
            # NOTE(review): source_embeddings[0] after mean(dim=0) indexes into the
            # embedding dimension, not a prompt — possibly should be source_embeddings
            # itself; confirm intent.
            diffs = torch.stack([emb-source_embeddings[0] for emb in text_embeddings])
            # NOTE(review): divides by the norms of text_embeddings, not of diffs —
            # looks like it may have been meant to normalize diffs; confirm.
            diffs /= text_embeddings.norm(dim=-1, keepdim=True)
    else:
        print(target_text_prompts)
        # print("yo", len(source_text_prompts), len(source_text_prompts[0]))
        # go on a per class basis: format each template with every class name.
        templates = target_text_prompts
        all_texts = []
        # One embedding matrix per source template (each row = one class's prompt).
        for t in source_text_prompts:
            texts = [[t.format(c)] for c in class_names]
            text_emb = zeroshot_classifier(texts, model, normalize=cfg.METHOD.NORMALIZE, model_type=cfg.EXP.IMAGE_FEATURES).T
            print(texts, "text_emb", text_emb.shape)
            all_texts.append(text_emb)
        # Allow a flat list of template strings by wrapping it into a list of lists.
        if type(target_text_prompts[0]) == str:
            target_text_prompts = [target_text_prompts]
        print(target_text_prompts)
        # One embedding matrix per target prompt group.
        for p in target_text_prompts:
            print(p)
            texts = [[t.format(c) for t in p] for c in class_names]
            text_emb = zeroshot_classifier(texts, model, normalize=cfg.METHOD.NORMALIZE, model_type=cfg.EXP.IMAGE_FEATURES).T
            all_texts.append(text_emb)
        # this subtracts the neutral embedding from the domain embeddings and normalizes.
        # (NOTE(review): comment predates this version — no subtraction happens below.)
        text_pairs = torch.stack(all_texts)
        print("text pairs", text_pairs.shape)
        target_embeddings, source_embeddings = text_pairs, []
        if len(source_text_prompts) > 0:
            # Source templates were appended first, so they occupy the leading slice.
            source_embeddings = text_pairs[:len(source_text_prompts)]
            target_embeddings = text_pairs[len(source_text_prompts):]
        else:
            source_embeddings = torch.zeros_like(target_embeddings)
        # text_diffs = []
        # source_domain = text_pairs[0]
        # for target_domain in text_pairs[1:]:
        #     diff = target_domain - source_domain
        #     diff /= np.linalg.norm(diff, axis=-1, keepdims=True)
        #     # diff = np.expand_dims(diff, axis=0)
        #     text_diffs.append(diff)
        # else:
        #     target_embeddings = text_pairs
        #     text_diffs = text_pairs
        # diffs = torch.stack(text_diffs).permute(1,0,2) # should be (num_classes, num_domains, emb_size)
        # print("diffs shape", diffs.shape)
        # print("source embeddings", source_embeddings.shape)
        print("target embeddings", target_embeddings.shape)
        return source_embeddings, target_embeddings
89-
9029
def get_features(dataset, model, device, model_type):
9130
if model_type != 'clip' and model_type != 'openclip':
9231
return get_resnet_features(dataset, model, device)
Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,8 @@ EXP:
33
WANDB_SILENT: False
44
PROJ: "CUB_Painting_E2E"
55
SEED: 0
6-
# TEXT_PROMPTS: [["clipart"], ["painting"], ["real photo"]]
7-
# TEXT_PROMPTS: [['a painting of a bird']]
8-
TEXT_PROMPTS: [['a painting of a {} bird.']]
9-
# TEXT_PROMPTS: ['an anime drawing of a {} bird.']
6+
TEXT_PROMPTS: [['a painting of a {} bird.'], ['art of a {} bird.'], ['a drawing of a {} bird.']]
107
NEUTRAL_TEXT_PROMPTS: ['a photo of a {} bird.']
11-
# NEUTRAL_TEXT_PROMPTS: [['a photo of a bird']]
12-
# AUGMENTATION: 'None'
138
EPOCHS: 200
149
LOG_HIST: False
1510
ENSAMBLE: True
@@ -41,4 +36,6 @@ AUGMENTATION:
4136
DOM_WEIGHT: 1
4237
ALPHA: 0.5
4338
GENERIC: False
44-
DOM_LABELS: ['painting']
39+
DOM_LABELS: ['painting']
40+
RANDOMIZE_PROB: 0.0
41+
RANDOMIZE: False

0 commit comments

Comments
 (0)