Skip to content

Commit cefa58b

Browse files
author
Lisa Dunlap
committed
added class_weighting, data loading and saving, CLIPZS, getting worse results
1 parent e1f4393 commit cefa58b

File tree

7 files changed

+376
-155
lines changed

7 files changed

+376
-155
lines changed

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ Official Implementation of [LADS (Latent Augmentation using Domain descriptionS)
66
*WARNING: this is still WIP, please raise an issue if you run into any bugs.*
77

88
## TODOs
9-
[] add e2e method for DA
10-
[] add in selective augmentation (run lp, check per class acc, augment poor performing finetuned classes more towards the text emb)
9+
[X] clean up emb saving/loading
10+
[] fix the Directional vs LADS acc diff
11+
[X] add e2e method for DA
12+
[] get E2E to work well
13+
[X] add in selective augmentation (run lp, check per class acc, augment poor performing finetuned classes more towards the text emb)
1114
[] run 2 layer mlp baselines
1215

1316
## Getting started

class_weighting.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import os
2+
import clip
3+
import open_clip
4+
import torch
5+
import numpy as np
6+
import torchvision
7+
import wandb
8+
import argparse
9+
from PIL import Image
10+
import matplotlib.pyplot as plt
11+
import random
12+
import omegaconf
13+
from omegaconf import OmegaConf
14+
15+
import helpers.data_helpers as dh
16+
import methods.clip_transformations as CLIPTransformations
17+
from utils import read_unknowns, nest_dict
18+
from clip_utils import get_features, evaluate, zeroshot_classifier, get_ensamble_preds, get_pred_overlap, get_nn_metrics
19+
import methods.augmentations
20+
21+
# ---------------------------------------------------------------------------
# Command-line / config parsing: merge the base config, the experiment config,
# and any key=value CLI overrides (dotted keys address nested config entries).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='CLIP Advice')
parser.add_argument('--config', default='configs/base.yaml', help="config file")
parser.add_argument('overrides', nargs='*', help="Any key=value arguments to override config values "
                                                 "(use dots for nested overrides)")
# parse_known_args: unrecognized --key value pairs land in `unknown` and are
# merged into the config below instead of raising.
flags, unknown = parser.parse_known_args()

overrides = OmegaConf.from_cli(flags.overrides)
cfg = OmegaConf.load(flags.config)
base = OmegaConf.load('configs/base.yaml')
# precedence (lowest to highest): base config < experiment config < CLI overrides
args = OmegaConf.merge(base, cfg, overrides)
if unknown:
    print(unknown)
    config = nest_dict(read_unknowns(unknown))
    to_merge = OmegaConf.create(config)
    args = OmegaConf.merge(args, to_merge)
# record which yaml produced this run (shows up in wandb config)
args.yaml = flags.config

# this script is only for the CLIP-feature methods; other baselines have
# dedicated entry points
assert args.EXP.ADVICE_METHOD != 'CNN', "main.py not for CNN baseline, use train.py"
assert args.EXP.ADVICE_METHOD != 'CLIPZS', "main.py not for CLIP zero-shot, use clip_zs.py"

if args.EXP.WANDB_SILENT:
    os.environ['WANDB_SILENT'] = "true"
44+
45+
def flatten_config(dic, running_key=None, flattened_dict=None):
    """Flatten a (possibly nested) config mapping into a flat dict.

    Nested keys are joined with dots, e.g. {'EXP': {'SEED': 0}} becomes
    {'EXP.SEED': 0}. Used to hand the merged OmegaConf config to wandb.init.

    Args:
        dic: mapping to flatten (plain dict or an omegaconf DictConfig,
            which implements the Mapping protocol).
        running_key: dotted prefix accumulated so far (None at the top level).
        flattened_dict: accumulator for results; a fresh dict is created when
            None. NOTE: the previous version used a mutable default argument
            (`flattened_dict={}`) and recursed without passing it, so results
            leaked between separate calls — fixed by the None sentinel and by
            threading the accumulator through the recursion explicitly.

    Returns:
        dict mapping dotted key paths to leaf values.
    """
    from collections.abc import Mapping  # local import: keeps file imports untouched

    if flattened_dict is None:
        flattened_dict = {}
    for key, value in dic.items():
        if running_key is None:
            running_key_temp = key
        else:
            running_key_temp = '{}.{}'.format(running_key, key)
        # DictConfig implements Mapping, so this covers the original
        # isinstance(value, omegaconf.dictconfig.DictConfig) check and
        # additionally handles plain nested dicts.
        if isinstance(value, Mapping):
            flatten_config(value, running_key_temp, flattened_dict)
        else:
            flattened_dict[running_key_temp] = value
    return flattened_dict
57+
58+
# Experiment tracking: one wandb run per invocation, grouped by advice method
# so runs of the same method aggregate in the UI.
run = wandb.init(project='debug', group=args.EXP.ADVICE_METHOD, config=flatten_config(args), allow_val_change=False)

# Seed every RNG in play (python, numpy, torch) for reproducibility.
random.seed(args.EXP.SEED)
np.random.seed(args.EXP.SEED)
torch.manual_seed(args.EXP.SEED)

DATASET_NAME = args.DATA.DATASET
67+
68+
# ---------------------------------------------------------------------------
# Load cached (precomputed) image embeddings for train/val/test.
# NOTE(review): there is no visible branch computing embeddings when
# LOAD_CACHED is false — confirm that path exists elsewhere before relying
# on it.
# ---------------------------------------------------------------------------
if args.DATA.LOAD_CACHED:
    print(args.DATA.LOAD_CACHED)
    if args.EXP.IMAGE_FEATURES in ('clip', 'openclip'):
        model_name = args.EXP.CLIP_MODEL
    else:
        model_name = args.EXP.IMAGE_FEATURES
    cache_file, dataset_classes, dataset_domains = dh.get_cache_file(DATASET_NAME, model_name, args.EXP.BIASED_VAL, args.EXP.IMAGE_FEATURES)
    assert os.path.exists(cache_file), f"{cache_file} does not exist. To compute embeddings, set DATA.LOAD_CACHED=False"
    data = torch.load(cache_file)

    _fields = ('features', 'labels', 'groups', 'domains', 'filenames')
    train_features, train_labels, train_groups, train_domains, train_filenames = (data[f'train_{f}'] for f in _fields)
    val_features, val_labels, val_groups, val_domains, val_filenames = (data[f'val_{f}'] for f in _fields)
    test_features, test_labels, test_groups, test_domains, test_filenames = (data[f'test_{f}'] for f in _fields)

    # move some val data to test: even-indexed val samples stay in val, the
    # odd-indexed half is appended to test (ColoredMNISTBinary keeps raw splits)
    if args.DATA.DATASET != 'ColoredMNISTBinary':
        val_features, val_labels, val_groups, val_domains, val_filenames = (data[f'val_{f}'][::2] for f in _fields)
        test_features, test_labels, test_groups, test_domains, test_filenames = (
            np.concatenate((data[f'test_{f}'], data[f'val_{f}'][1::2])) for f in _fields)

if args.METHOD.NORMALIZE:
    # L2-normalize embeddings (in place) so cosine similarity is a dot product
    for _feats in (train_features, val_features, test_features):
        _feats /= np.linalg.norm(_feats, axis=-1, keepdims=True)
89+
# ---------------------------------------------------------------------------
# Load the backbone that produced the cached image features.
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(args.EXP.IMAGE_FEATURES)
if args.EXP.IMAGE_FEATURES == 'clip':
    # load once and alias: the previous version loaded the identical
    # checkpoint twice into `clip_model` and `model`
    model, preprocess = clip.load(args.EXP.CLIP_MODEL, device)
    clip_model = model
elif args.EXP.IMAGE_FEATURES == 'openclip':
    model, _, preprocess = open_clip.create_model_and_transforms(args.EXP.CLIP_MODEL, pretrained=args.EXP.CLIP_PRETRAINED_DATASET)
    # use the detected device (was hard-coded to 'cuda:0', which crashed
    # CPU-only runs and ignored the device selected above)
    model = model.to(device)
    clip_model = model
else:
    # torchvision backbone (e.g. 'resnet50') selected by name
    model = getattr(torchvision.models, args.EXP.IMAGE_FEATURES)(pretrained=True)
    model = model.to(device)
103+
104+
# ---------------------------------------------------------------------------
# Build the text prompts and the two classifiers (linear probe + zero-shot).
# ---------------------------------------------------------------------------
def _to_prompt_list(cfg_prompts):
    """Convert an OmegaConf prompt list to plain python lists (one level deep)."""
    prompts = list(cfg_prompts)
    # isinstance (not type ==) so subclasses are handled; coerce nested
    # ListConfig entries to plain lists for downstream code
    if prompts and isinstance(prompts[0], omegaconf.listconfig.ListConfig):
        prompts = [list(p) for p in prompts]
    return prompts

# previously this coercion was copy-pasted for both prompt sets
prompts = _to_prompt_list(args.EXP.TEXT_PROMPTS)
neutral_prompts = _to_prompt_list(args.EXP.NEUTRAL_TEXT_PROMPTS)

lp = CLIPTransformations.ClipMLP(prompts, clip_model, args, neutral_prompts)
zs = CLIPTransformations.CLIPZS(prompts, clip_model, args, neutral_prompts)

# train MLP with domain adaptation loss
lp.train_debias(train_features, train_labels, train_groups, train_domains, val_features, val_labels, np.squeeze(val_groups), val_domains)

lp_val_predictions, lp_val_probs = lp.eval(val_features)
zs_val_predictions, zs_val_probs = zs.eval(val_features)
121+
122+
def log_wandb(acc, balanced_acc, class_acc, group_acc, tag='val'):
    """Record summary metrics for one evaluation pass under a `tag` prefix.

    Args:
        acc: overall accuracy.
        balanced_acc: class-balanced accuracy.
        class_acc: per-class accuracy array. BUG FIX: this parameter was
            accepted by every call site but never logged; it is now recorded
            under f"{tag}_class_accuracy".
        group_acc: per-group (domain) accuracy array.
        tag: prefix distinguishing runs, e.g. 'lp_val' vs 'zs_val'.
    """
    wandb.summary[f"{tag}_accuracy"] = acc
    wandb.summary[f"{tag}_balanced_accuracy"] = balanced_acc
    wandb.summary[f"{tag}_class_accuracy"] = class_acc
    wandb.summary[f"{tag}_group_accuracy"] = group_acc
    # per-domain breakdown; relies on module-level `dataset_domains` ordering
    # matching group_acc — NOTE(review): confirm ordering against evaluate()
    for d, d_acc in zip(dataset_domains, group_acc):
        wandb.summary[f"{tag}_{d}_acc"] = d_acc
128+
129+
# Evaluate the linear probe and the zero-shot classifier on validation data.
val_groups_flat = np.squeeze(val_groups)
lp_val_accuracy, lp_val_balanced_acc, lp_val_class_accuracy, lp_val_group_accuracy = evaluate(lp_val_predictions, val_labels, val_groups_flat)
zs_val_accuracy, zs_val_balanced_acc, zs_val_class_accuracy, zs_val_group_accuracy = evaluate(zs_val_predictions, val_labels, val_groups_flat)
log_wandb(lp_val_accuracy, lp_val_balanced_acc, lp_val_class_accuracy, lp_val_group_accuracy, tag='lp_val')
log_wandb(zs_val_accuracy, zs_val_balanced_acc, zs_val_class_accuracy, zs_val_group_accuracy, tag='zs_val')

print('..........................................')
print(f"LP val accuracy: {lp_val_accuracy} \t ZS val accuracy: {zs_val_accuracy}")
# Per-class weighting: classes where zero-shot beats the probe get proportion
# > 0.5; a 0/0 class (both accuracies zero) maps to 0.5 via nan_to_num; the
# proportions are then normalized to sum to 1.
acc_prop = np.nan_to_num(zs_val_class_accuracy / (lp_val_class_accuracy + zs_val_class_accuracy), nan=0.5)
print(f"--------------- acc prop {acc_prop} {type(acc_prop)} {acc_prop.shape} np sum {np.sum(acc_prop)}")
class_weights = acc_prop / np.sum(acc_prop)
print(f"Accuracy difference: {class_weights[:10]} \t Accuracy proportion: {acc_prop[:10]} {np.sum(acc_prop)}")
print('..........................................')

# Keep references to the pre-augmentation splits (consumed by the disabled
# nearest-neighbor logging block further down in this file).
old_val_features, old_val_labels, old_val_groups, old_val_domains, old_val_filenames = val_features, val_labels, val_groups, val_domains, val_filenames
old_test_features, old_test_labels, old_test_groups, old_test_domains, old_test_filenames = test_features, test_labels, test_groups, test_domains, test_filenames
146+
147+
# Optionally augment the training embeddings in latent space.
if args.EXP.AUGMENTATION is not None and args.EXP.AUGMENTATION != 'None':
    print("Augmenting training set...")
    aug_name = args.EXP.AUGMENTATION
    aug_cls = getattr(methods.augmentations, aug_name)
    if "LADS" in aug_name or 'Directional' in aug_name:
        # LADS/Directional variants additionally take the val split and the
        # per-class weights computed from the LP-vs-ZS accuracy proportion
        augment = aug_cls(args, train_features, train_labels, train_groups, train_domains, train_filenames, lp.text_embeddings, val_features, val_labels, val_groups, val_domains, class_weights)
    else:
        augment = aug_cls(args, train_features, train_labels, train_groups, train_domains, train_filenames, lp.text_embeddings)
    # NOTE(review): unpack order here is (..., domains, groups, ...) while the
    # loading code used (..., groups, domains, ...) — confirm augment_dataset()
    # really returns domains before groups.
    train_features, train_labels, train_domains, train_groups, train_filenames = augment.augment_dataset()
    print("Training set augmented!")
155+
156+
# if args.EXP.LOG_NN:
157+
# features, labels, groups, domains, filenames = np.concatenate([old_val_features, old_test_features]), np.concatenate([old_val_labels, old_test_labels]), np.concatenate([old_val_groups, old_test_groups]), np.concatenate([old_val_domains, old_test_domains]), np.concatenate([old_val_filenames, old_test_filenames])
158+
# # features, labels, groups, domains, filenames = old_test_features, old_test_labels, old_test_groups, old_test_domains, old_test_filenames
159+
# if len(np.unique(train_domains)) > 1:
160+
# filtered_idxs = np.where(train_domains != train_domains[0])
161+
# sample_features, sample_domains, sample_labels, sample_filenames = np.array(train_features[filtered_idxs]), train_domains[filtered_idxs], train_labels[filtered_idxs], train_filenames[filtered_idxs]
162+
# sample_idxs = random.sample(list(range(len(sample_filenames))), min((len(train_filenames), 1000)))
163+
# sample_features, sample_domains, sample_labels, sample_filenames = sample_features[sample_idxs], sample_domains[sample_idxs], sample_labels[sample_idxs], sample_filenames[sample_idxs]
164+
# else:
165+
# sample_idxs = random.sample(list(range(len(train_filenames))), min((len(train_filenames), 1000)))
166+
# sample_features, sample_domains, sample_labels, sample_filenames = train_features[sample_idxs], train_domains[sample_idxs], train_labels[sample_idxs], train_filenames[sample_idxs]
167+
# neighbor_domains, neighbor_labels, domain_acc, class_acc, neighbor_samples, prop_unique, mean_cs = get_nn_metrics(sample_features, sample_domains, sample_labels, features, domains, labels)
168+
# plt.rcParams["figure.figsize"] = (20,5)
169+
# f, (axs_orig, axs_new) = plt.subplots(2, 10, sharey=True)
170+
# for i, (original_idx, sample_idx) in enumerate(neighbor_samples):
171+
# try:
172+
# axs_orig[i].imshow(Image.open(sample_filenames[original_idx]).resize((224, 224)))
173+
# axs_orig[i].set_title(f"{dataset_domains[int(sample_domains[int(original_idx)])]} - {sample_labels[int(original_idx)]}")
174+
# axs_orig[i].axis('off')
175+
# axs_new[i].imshow(Image.open(filenames[sample_idx]).resize((224, 224)))
176+
# axs_new[i].set_title(f"{dataset_domains[int(domains[int(sample_idx)])]} - {labels[int(sample_idx)]}")
177+
# axs_new[i].axis('off')
178+
# except:
179+
# print(f"sample idx {sample_idx} is not a valid index")
180+
# wandb.log({"train features NN": wandb.Image(f), "domain consistency acc": domain_acc, "class consistency acc": class_acc, "unique nn": prop_unique})
181+
# # wandb.sklearn.plot_confusion_matrix(sample_domains, neighbor_domains, dataset_domains)
182+
# print("Plotted Nearest Neighbors")
183+
184+
# Retrain the linear probe on the (possibly augmented) training set, then
# report validation and test metrics under 'aug_' tags.
squeezed_val_groups = np.squeeze(val_groups)
lp.train_debias(train_features, train_labels, train_groups, train_domains, val_features, val_labels, squeezed_val_groups, val_domains)

lp_val_predictions, lp_val_probs = lp.eval(val_features)
lp_val_accuracy, lp_val_balanced_acc, lp_val_class_accuracy, lp_val_group_accuracy = evaluate(lp_val_predictions, val_labels, squeezed_val_groups)
log_wandb(lp_val_accuracy, lp_val_balanced_acc, lp_val_class_accuracy, lp_val_group_accuracy, tag='aug_lp_val')

lp_test_predictions, lp_test_probs = lp.eval(test_features)
lp_test_accuracy, lp_test_balanced_acc, lp_test_class_accuracy, lp_test_group_accuracy = evaluate(lp_test_predictions, test_labels, np.squeeze(test_groups))
log_wandb(lp_test_accuracy, lp_test_balanced_acc, lp_test_class_accuracy, lp_test_group_accuracy, tag='aug_lp_test')

clip_utils.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@ def get_features(dataset, model, device, model_type):
4646
all_domains.append(domains)
4747
all_filenames.extend(filenames)
4848

49-
return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy(), torch.cat(all_groups).cpu().numpy(), torch.cat(all_domains).cpu().numpy(), np.array(all_filenames)
49+
features = torch.cat(all_features).cpu().numpy()
50+
labels = torch.cat(all_labels).cpu().numpy()
51+
groups = torch.cat(all_groups).cpu().numpy()
52+
domains = torch.cat(all_domains).cpu().numpy()
53+
filenames = np.array(all_filenames)
54+
55+
return features, labels, groups, domains, filenames
5056

5157
def get_resnet_features(dataset, model, device):
5258
"""
@@ -77,6 +83,20 @@ def hook(model, input, output):
7783

7884
return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy(), torch.cat(all_groups).cpu().numpy(), torch.cat(all_domains).cpu().numpy()
7985

86+
def load_embeddings(cache_file, dataset):
    """Load cached train/val/test embeddings from `cache_file`.

    For every dataset except 'ColoredMNISTBinary', half of the validation
    split is moved into the test split: even-indexed val samples stay in val,
    odd-indexed val samples are appended to test.

    Args:
        cache_file: path to a torch-saved dict with keys
            '{split}_{field}' for split in (train, val, test) and field in
            (features, labels, groups, domains, filenames).
        dataset: dataset name; controls the val/test reshuffle above.

    Returns:
        A 15-tuple: (train_features, train_labels, train_groups,
        train_domains, train_filenames, val_*, test_*) in that order.
    """
    save_dict = torch.load(cache_file)
    fields = ('features', 'labels', 'groups', 'domains', 'filenames')
    train = [save_dict[f'train_{f}'] for f in fields]
    val = [save_dict[f'val_{f}'] for f in fields]
    test = [save_dict[f'test_{f}'] for f in fields]
    if dataset != 'ColoredMNISTBinary':
        # even-indexed val samples remain in val ...
        trimmed_val = [v[::2] for v in val]
        # ... odd-indexed val samples are appended to the test split
        test = [np.concatenate((t, v[1::2])) for t, v in zip(test, val)]
        val = trimmed_val
    return (*train, *val, *test)
99+
80100
def projection(u, v):
    """Return the vector projection of v onto u (a vector parallel to u)."""
    scale = (v * u).sum() / (u * u).sum()
    return scale * u
82102

0 commit comments

Comments
 (0)