Skip to content

Commit 9790e44

Browse files
author
Lisa Dunlap
committed
updated embeddings/configs
1 parent 0cf64d6 commit 9790e44

File tree

19 files changed

+108
-172
lines changed

19 files changed

+108
-172
lines changed

clip_utils.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ def get_ensamble_preds(val_features, probs, zeroshot_weights, dataset_domains=No
177177
except:
178178
outputs = probs
179179
print(outputs.shape)
180-
salem_preds = np.argmax(outputs, axis=1)
181-
print(salem_preds.shape)
180+
lads_preds = np.argmax(outputs, axis=1)
181+
print(lads_preds.shape)
182182
# CLIP ZS
183183
zeroshot_weights = zeroshot_weights.cuda()
184184
images = torch.tensor(val_features).cuda()
@@ -196,27 +196,27 @@ def get_ensamble_preds(val_features, probs, zeroshot_weights, dataset_domains=No
196196
dom_preds = []
197197
for i in range(len(ensambled_preds)):
198198
if soft_dom_label[i] == 0:
199-
dom_preds.append(salem_preds[i])
199+
dom_preds.append(lads_preds[i])
200200
else:
201201
dom_preds.append(ensambled_preds[i])
202202
ret_preds = np.array(dom_preds)
203203
else:
204204
ret_preds = ensambled_preds
205205

206-
return salem_preds, zs_preds, ensambled_preds, ret_preds
206+
return lads_preds, zs_preds, ensambled_preds, ret_preds
207207

208-
def get_pred_overlap(salem_preds, zs_preds, labels):
208+
def get_pred_overlap(lads_preds, zs_preds, labels):
209209
"""
210-
Get the overlap in correct predictions for salem and zeroshot.
210+
Get the overlap in correct predictions for lads and zeroshot.
211211
"""
212-
salem_correct = np.where(salem_preds == labels)[0]
212+
lads_correct = np.where(lads_preds == labels)[0]
213213
zs_correct = np.where(zs_preds == labels)[0]
214-
print(len(salem_correct), len(zs_correct))
215-
print("salem correct ", salem_correct[:10])
214+
print(len(lads_correct), len(zs_correct))
215+
print("lads correct ", lads_correct[:10])
216216
print("zs correct ", zs_correct[:10])
217-
salem_overlap = [i for i in salem_correct if i in zs_correct]
218-
salem_nonverlap = [i for i in salem_correct if not (i in zs_correct)]
219-
zs_nonverlap = [i for i in zs_correct if not (i in salem_correct)]
220-
num_zs_correct_nonoverlap = len(zs_correct) - len(salem_overlap)
221-
num_salem_correct_nonverlap = len(salem_correct) - len(salem_overlap)
222-
return num_salem_correct_nonverlap, num_salem_correct_nonverlap/len(labels), num_salem_correct_nonverlap/len(salem_correct)
217+
lads_overlap = [i for i in lads_correct if i in zs_correct]
218+
lads_nonverlap = [i for i in lads_correct if not (i in zs_correct)]
219+
zs_nonverlap = [i for i in zs_correct if not (i in lads_correct)]
220+
num_zs_correct_nonoverlap = len(zs_correct) - len(lads_overlap)
221+
num_lads_correct_nonverlap = len(lads_correct) - len(lads_overlap)
222+
return num_lads_correct_nonverlap, num_lads_correct_nonverlap/len(labels), num_lads_correct_nonverlap/len(lads_correct)

configs/ColoredMNIST/ZS.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
EXP:
22
ADVICE_METHOD: "CLIPZS"
33
WANDB_SILENT: False
4-
PROJ: "ColoredMNIST_ViT_Final"
4+
PROJ: "ColoredMNIST"
55
SEED: 0
66
TEXT_PROMPTS: ['0','1','2','3','4','5','6','7','8','9']
77
TEMPLATES: 'color_mnist_templates'
8-
# TEMPLATES: 'no_template'
98

109
DATA:
1110
DATASET: "ColoredMNISTBinary"
1211
LOAD_CACHED: True
1312
BATCH_SIZE: 256
13+
ROOT: './data'
1414

1515
METHOD:
1616
MODEL:

configs/ColoredMNIST/lads.yaml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
EXP:
22
ADVICE_METHOD: "ClipMLP"
33
WANDB_SILENT: False
4-
PROJ: "ColoredMNIST_ViT_Final"
4+
PROJ: "ColoredMNIST"
55
SEED: 0
6+
AUGMENTATION: 'LADSBias'
67
TEXT_PROMPTS: [['a photo of a red number {}.'], ['a photo of a blue number {}.']]
7-
AUGMENTATION: 'BiasDirectional'
8+
NEUTRAL_TEXT_PROMPTS: ['a photo of a red number {}.', 'a photo of a blue number {}.']
89
EPOCHS: 200
910
CHECKPOINT_VAL: True
1011
ENSAMBLE: False
1112

12-
# CLIP_MODEL: 'ViT-H-14'
13-
# CLIP_PRETRAINED_DATASET: 'laion2b_s32b_b79k'
14-
# IMAGE_FEATURES: 'openclip'
15-
1613
DATA:
1714
DATASET: "ColoredMNISTBinary"
1815
LOAD_CACHED: True
1916
BATCH_SIZE: 256
17+
ROOT: './data'
2018

2119
METHOD:
2220
MODEL:

configs/ColoredMNIST/mlp.yaml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
EXP:
22
ADVICE_METHOD: "ClipMLP"
33
WANDB_SILENT: False
4-
PROJ: "ColoredMNIST_ViT_Final"
4+
PROJ: "ColoredMNIST"
55
SEED: 0
66
TEXT_PROMPTS: ['a photo of a red number "{}".', 'a photo of a blue number "{}".']
77
NEUTRAL_TEXT_PROMPTS: ['a photo of a white number "{}".']
@@ -10,14 +10,11 @@ EXP:
1010
GENERIC: False
1111
LOG_NN: True
1212

13-
# CLIP_MODEL: 'ViT-H-14'
14-
# CLIP_PRETRAINED_DATASET: 'laion2b_s32b_b79k'
15-
# IMAGE_FEATURES: 'openclip'
16-
1713
DATA:
1814
DATASET: "ColoredMNISTBinary"
1915
LOAD_CACHED: True
2016
BATCH_SIZE: 256
17+
ROOT: './data'
2118

2219
METHOD:
2320
MODEL:

configs/Waterbirds/ZS.yaml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,4 @@ DATA:
1212
BATCH_SIZE: 256
1313
LOAD_CACHED: True
1414
UPWEIGHT_CLASSES: True
15-
16-
METHOD:
17-
MODEL:
18-
NUM_LAYERS: 1
19-
SEPERATE_CLASSES: True
20-
RESUME: True
21-
NORMALIZE: False
15+
ROOT: /shared/lisabdunlap/vl-attention/data

configs/Waterbirds/lads.yaml

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ EXP:
44
PROJ: "LADS_Waterbirds_Rebuttal"
55
SEED: 0
66
TEXT_PROMPTS: [['a photo of a {} on forest.'], ['a photo of a {} on water.']]
7-
NEUTRAL_TEXT_PROMPTS: []
8-
AUGMENTATION: 'BiasDirectional'
7+
NEUTRAL_TEXT_PROMPTS: ['a photo of a {} on forest.', 'a photo of a {} on water.']
8+
AUGMENTATION: 'LADSBias'
99
EPOCHS: 200
1010
CHECKPOINT_VAL: True
1111
ENSAMBLE: False
@@ -15,31 +15,27 @@ DATA:
1515
DATASET: "Waterbirds"
1616
LOAD_CACHED: True
1717
BATCH_SIZE: 256
18+
ROOT: /shared/lisabdunlap/vl-attention/data
1819

1920
METHOD:
2021
MODEL:
2122
NUM_LAYERS: 1
2223
DOM_WEIGHT: 1.0
2324
LR: 0.001
24-
WEIGHT_DECAY: 0.05
25-
CHECKPOINT: 'checkpoint/mlp_simple.pth'
26-
CHECKPOINT_NAME: 'mlp-directional'
25+
CHECKPOINT: 'checkpoint/mlp.pth'
26+
CHECKPOINT_NAME: "mlp"
2727
RESUME: False
28-
# USE_DOM_GT: True
2928

3029
AUGMENTATION:
3130
MODEL:
3231
LR: 0.005
33-
WEIGHT_DECAY: 0.05
32+
WEIGHT_DECAY: 0.005
3433
NUM_LAYERS: 1
3534
HIDDEN_DIM: 512
3635
EPOCHS: 50
3736
GENERIC: False
3837
DOM_LABELS: ["forest", "water"]
39-
DOM_SPECIFIC_XE: true
38+
DOM_SPECIFIC_XE: True
4039
ALPHA: 0.75
41-
# CLIP_NN_LOSS: True
42-
# COMPARE_BEFORE_AUG: True
43-
# NN_INCLUDE_SAMPLE: True
44-
# DOM_WEIGHT: 1.0
45-
# NN_WEIGHT: 1.0
40+
COMPARE_BEFORE_AUG: True
41+
NN_INCLUDE_SAMPLE: True

configs/Waterbirds/mlp.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ EXP:
1010
DATA:
1111
DATASET: "Waterbirds"
1212
BATCH_SIZE: 256
13+
ROOT: /shared/lisabdunlap/vl-attention/data
1314
LOAD_CACHED: True
1415

1516
METHOD:
1617
MODEL:
1718
NUM_LAYERS: 1
1819
DOM_WEIGHT: 1.0
1920
LR: 0.001
20-
# WEIGHT_DECAY: 0.005
2121
CHECKPOINT: 'checkpoint/mlp.pth'
2222
CHECKPOINT_NAME: "mlp"
2323
RESUME: False

configs/Waterbirds/mlpzs.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ DATA:
1212
DATASET: "Waterbirds"
1313
BATCH_SIZE: 256
1414
LOAD_CACHED: True
15+
ROOT: /shared/lisabdunlap/vl-attention/data
1516

1617
METHOD:
1718
MODEL:

configs/Waterbirds/new_lads.yaml

Lines changed: 0 additions & 45 deletions
This file was deleted.

configs/base.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ DATA:
2323
UPWEIGHT_DOMAINS: False
2424
UPWEIGHT_CLASSES: True
2525
MODEL_DIM: 1024
26+
ROOT: '/shared/lisabdunlap'
2627

2728
METHOD:
2829
MODEL:

0 commit comments

Comments
 (0)