|
| 1 | +import os |
| 2 | +import clip |
| 3 | +import open_clip |
| 4 | +import torch |
| 5 | +import numpy as np |
| 6 | +import torchvision |
| 7 | +import wandb |
| 8 | +import argparse |
| 9 | +from PIL import Image |
| 10 | +import matplotlib.pyplot as plt |
| 11 | +import random |
| 12 | +import omegaconf |
| 13 | +from omegaconf import OmegaConf |
| 14 | + |
| 15 | +import helpers.data_helpers as dh |
| 16 | +import methods.clip_transformations as CLIPTransformations |
| 17 | +from utils import read_unknowns, nest_dict |
| 18 | +from clip_utils import get_features, evaluate, zeroshot_classifier, get_ensamble_preds, get_pred_overlap, get_nn_metrics |
| 19 | +import methods.augmentations |
| 20 | + |
# ---- CLI / config parsing -------------------------------------------------
# Configuration precedence (lowest to highest):
#   configs/base.yaml  ->  --config file  ->  key=value CLI overrides
parser = argparse.ArgumentParser(description='CLIP Advice')
parser.add_argument('--config', default='configs/base.yaml', help="config file")
parser.add_argument('overrides', nargs='*', help="Any key=value arguments to override config values "
                                                 "(use dots for nested overrides)")
# parse_known_args so dotted keys argparse does not recognize (e.g. EXP.SEED=1)
# fall through to `unknown` and are merged into the config below.
flags, unknown = parser.parse_known_args()

overrides = OmegaConf.from_cli(flags.overrides)
cfg = OmegaConf.load(flags.config)
base = OmegaConf.load('configs/base.yaml')
args = OmegaConf.merge(base, cfg, overrides)
if len(unknown) > 0:
    print(unknown)
    # nest_dict turns {"A.B": v} into {"A": {"B": v}} so OmegaConf can merge it.
    config = nest_dict(read_unknowns(unknown))
    to_merge = OmegaConf.create(config)
    args = OmegaConf.merge(args, to_merge)
# Remember which yaml produced this run's config.
args.yaml = flags.config
# Guard against configs meant for the other entry points of this project.
# NOTE: `assert` is stripped under `python -O`; acceptable for a research script.
assert args.EXP.ADVICE_METHOD != 'CNN', "main.py not for CNN baseline, use train.py"
assert args.EXP.ADVICE_METHOD != 'CLIPZS', "main.py not for CLIP zero-shot, use clip_zs.py"

# Must be set before wandb.init() below for the env var to take effect.
if args.EXP.WANDB_SILENT:
    os.environ['WANDB_SILENT']="true"
| 44 | + |
# Nested-config node types. omegaconf may legitimately be unavailable when this
# helper is exercised in isolation; plain dicts are always recursed into.
try:
    _NESTED_CONFIG_TYPES = (dict, omegaconf.dictconfig.DictConfig)
except NameError:  # omegaconf not importable in this environment
    _NESTED_CONFIG_TYPES = (dict,)

def flatten_config(dic, running_key=None, flattened_dict=None):
    """Flatten a (possibly nested) config mapping into {'a.b.c': value} form.

    Args:
        dic: mapping (dict or omegaconf DictConfig) to flatten.
        running_key: dotted prefix accumulated from the enclosing keys.
        flattened_dict: accumulator; a fresh dict is created when omitted.
            (The original used a mutable default argument, so every call after
            the first leaked keys from earlier invocations.)

    Returns:
        dict mapping dotted key paths to leaf values.
    """
    if flattened_dict is None:
        flattened_dict = {}
    for key, value in dic.items():
        full_key = key if running_key is None else '{}.{}'.format(running_key, key)
        if isinstance(value, _NESTED_CONFIG_TYPES):
            # Pass the accumulator explicitly; the original recursion relied on
            # the shared default dict to collect nested keys.
            flatten_config(value, full_key, flattened_dict)
        else:
            flattened_dict[full_key] = value
    return flattened_dict
| 57 | + |
# Initialize experiment tracking; flattening the config makes every nested key
# (e.g. EXP.SEED) individually searchable/filterable in the wandb UI.
# NOTE(review): project name is hard-coded to 'debug' — confirm intended.
run = wandb.init(project='debug', group=args.EXP.ADVICE_METHOD, config=flatten_config(args), allow_val_change=False)
# wandb.save(flags.config)
# wandb.run.log_code(".")

# Seed every RNG in use (torch, numpy, stdlib random) for reproducibility.
torch.manual_seed(args.EXP.SEED)
np.random.seed(args.EXP.SEED)
random.seed(args.EXP.SEED)

DATASET_NAME = args.DATA.DATASET
| 67 | + |
# load data
# The cache file holds precomputed image embeddings plus labels / groups /
# domains / filenames for each split, keyed by dataset + backbone name.
if args.DATA.LOAD_CACHED:
    print(args.DATA.LOAD_CACHED)
    # CLIP-style backbones are cached under the CLIP model name; torchvision
    # backbones are cached under the architecture name.
    if args.EXP.IMAGE_FEATURES == 'clip' or args.EXP.IMAGE_FEATURES == 'openclip':
        model_name = args.EXP.CLIP_MODEL
    else:
        model_name = args.EXP.IMAGE_FEATURES
    cache_file, dataset_classes, dataset_domains = dh.get_cache_file(DATASET_NAME, model_name, args.EXP.BIASED_VAL, args.EXP.IMAGE_FEATURES)
    assert os.path.exists(cache_file), f"{cache_file} does not exist. To compute embeddings, set DATA.LOAD_CACHED=False"
    data = torch.load(cache_file)
    train_features, train_labels, train_groups, train_domains, train_filenames = data['train_features'], data['train_labels'], data['train_groups'], data['train_domains'], data['train_filenames']
    val_features, val_labels, val_groups, val_domains, val_filenames = data['val_features'], data['val_labels'], data['val_groups'], data['val_domains'], data['val_filenames']
    test_features, test_labels, test_groups, test_domains, test_filenames = data['test_features'], data['test_labels'], data['test_groups'], data['test_domains'], data['test_filenames']
    # move some val data to test
    # Even-indexed val samples stay in val; odd-indexed ones are appended to test.
    if args.DATA.DATASET != 'ColoredMNISTBinary':
        val_features, val_labels, val_groups, val_domains, val_filenames = data['val_features'][::2], data['val_labels'][::2], data['val_groups'][::2], data['val_domains'][::2], data['val_filenames'][::2]
        test_features, test_labels, test_groups, test_domains, test_filenames = np.concatenate((data['test_features'], data['val_features'][1::2])), np.concatenate((data['test_labels'], data['val_labels'][1::2])), np.concatenate((data['test_groups'], data['val_groups'][1::2])), np.concatenate((data['test_domains'], data['val_domains'][1::2])), np.concatenate((data['test_filenames'], data['val_filenames'][1::2]))
    # L2-normalize feature rows in place.
    if args.METHOD.NORMALIZE:
        train_features /= np.linalg.norm(train_features, axis=-1, keepdims=True)
        val_features /= np.linalg.norm(val_features, axis=-1, keepdims=True)
        test_features /= np.linalg.norm(test_features, axis=-1, keepdims=True)
# NOTE(review): there is no else-branch — with DATA.LOAD_CACHED=False the
# train/val/test variables are never defined and the script fails downstream.
# Load the model (the backbone that produced the cached image features).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(args.EXP.IMAGE_FEATURES)
if args.EXP.IMAGE_FEATURES == 'clip':
    # Load once and alias: the original called clip.load() twice, downloading
    # and allocating the identical model a second time for no benefit.
    clip_model, preprocess = clip.load(args.EXP.CLIP_MODEL, device)
    model = clip_model
elif args.EXP.IMAGE_FEATURES == 'openclip':
    model, _, preprocess = open_clip.create_model_and_transforms(args.EXP.CLIP_MODEL, pretrained=args.EXP.CLIP_PRETRAINED_DATASET)
    # Use the device selected above instead of the original hard-coded
    # 'cuda:0', so the script also runs on CPU-only machines.
    model = model.to(device)
    clip_model = model
else:
    # torchvision backbone (e.g. 'resnet50') selected by name.
    model = getattr(torchvision.models, args.EXP.IMAGE_FEATURES)(pretrained=True)
    model = model.to(device)
| 103 | + |
# Build the text prompt lists. YAML lists arrive as omegaconf ListConfig;
# convert to plain Python lists (and lists-of-lists for grouped prompts).
prompts = list(args.EXP.TEXT_PROMPTS)
# isinstance instead of `type(...) ==`: idiomatic and robust to subclassing.
if len(prompts) > 0 and isinstance(prompts[0], omegaconf.listconfig.ListConfig):
    prompts = [list(p) for p in prompts]

neutral_prompts = list(args.EXP.NEUTRAL_TEXT_PROMPTS)
if len(neutral_prompts) > 0 and isinstance(neutral_prompts[0], omegaconf.listconfig.ListConfig):
    neutral_prompts = [list(p) for p in neutral_prompts]

# Linear-probe (MLP) head and zero-shot CLIP head over the same prompt set.
lp = CLIPTransformations.ClipMLP(prompts, clip_model, args, neutral_prompts)
zs = CLIPTransformations.CLIPZS(prompts, clip_model, args, neutral_prompts)
| 115 | + |
# train MLP with domain adaptation loss
# First training pass on the un-augmented cached features; groups are squeezed
# to 1-D for the training/evaluation helpers.
lp.train_debias(train_features, train_labels, train_groups, train_domains, val_features, val_labels, np.squeeze(val_groups), val_domains)

# NOTE(review): the unpacking names suggest eval() returns
# (predicted labels, class probabilities) — verify in CLIPTransformations.
lp_val_predictions, lp_val_probs = lp.eval(val_features)
zs_val_predictions, zs_val_probs = zs.eval(val_features)
| 121 | + |
def log_wandb(acc, balanced_acc, class_acc, group_acc, tag='val'):
    """Record evaluation metrics in the wandb run summary under `tag`.

    Args:
        acc: overall accuracy.
        balanced_acc: class-balanced accuracy.
        class_acc: per-class accuracy (previously accepted but never logged).
        group_acc: per-group accuracy, ordered like the module-level
            `dataset_domains`.
        tag: prefix distinguishing e.g. 'lp_val' from 'zs_val' entries.
    """
    wandb.summary[f"{tag}_accuracy"] = acc
    wandb.summary[f"{tag}_balanced_accuracy"] = balanced_acc
    # Fix: class_acc was silently dropped by the original implementation.
    wandb.summary[f"{tag}_class_accuracy"] = class_acc
    wandb.summary[f"{tag}_group_accuracy"] = group_acc
    # One summary entry per named domain.
    for d, d_acc in zip(dataset_domains, group_acc):
        wandb.summary[f"{tag}_{d}_acc"] = d_acc
| 128 | + |
| 129 | +lp_val_accuracy, lp_val_balanced_acc, lp_val_class_accuracy, lp_val_group_accuracy = evaluate(lp_val_predictions, val_labels, np.squeeze(val_groups)) |
| 130 | +zs_val_accuracy, zs_val_balanced_acc, zs_val_class_accuracy, zs_val_group_accuracy = evaluate(zs_val_predictions, val_labels, np.squeeze(val_groups)) |
| 131 | +log_wandb(lp_val_accuracy, lp_val_balanced_acc, lp_val_class_accuracy, lp_val_group_accuracy, tag='lp_val') |
| 132 | +log_wandb(zs_val_accuracy, zs_val_balanced_acc, zs_val_class_accuracy, zs_val_group_accuracy, tag='zs_val') |
| 133 | + |
| 134 | +print('..........................................') |
| 135 | +print(f"LP val accuracy: {lp_val_accuracy} \t ZS val accuracy: {zs_val_accuracy}") |
| 136 | +# acc_diff = lp_val_class_accuracy - zs_val_class_accuracy |
| 137 | +# print(lp_val_class_accuracy + zs_val_class_accuracy) |
| 138 | +acc_prop = np.nan_to_num(zs_val_class_accuracy / (lp_val_class_accuracy + zs_val_class_accuracy), nan=0.5) |
| 139 | +print(f"--------------- acc prop {acc_prop} {type(acc_prop)} {acc_prop.shape} np sum {np.sum(acc_prop)}") |
| 140 | +class_weights = acc_prop / np.sum(acc_prop) |
| 141 | +print(f"Accuracy difference: {class_weights[:10]} \t Accuracy proportion: {acc_prop[:10]} {np.sum(acc_prop)}") |
| 142 | +print('..........................................') |
| 143 | + |
| 144 | +old_val_features, old_val_labels, old_val_groups, old_val_domains, old_val_filenames = val_features, val_labels, val_groups, val_domains, val_filenames |
| 145 | +old_test_features, old_test_labels, old_test_groups, old_test_domains, old_test_filenames = test_features, test_labels, test_groups, test_domains, test_filenames |
| 146 | + |
# Optionally augment the cached training features in embedding space.
# The config may encode "no augmentation" either as a null or the string 'None'.
if args.EXP.AUGMENTATION is not None and args.EXP.AUGMENTATION != 'None':
    print("Augmenting training set...")
    if "LADS" in args.EXP.AUGMENTATION or 'Directional' in args.EXP.AUGMENTATION:
        # LADS/Directional augmentations additionally take the validation data
        # and the per-class weights computed above.
        augment = getattr(methods.augmentations, args.EXP.AUGMENTATION)(args, train_features, train_labels, train_groups, train_domains, train_filenames, lp.text_embeddings, val_features, val_labels, val_groups, val_domains, class_weights)
    else:
        augment = getattr(methods.augmentations, args.EXP.AUGMENTATION)(args, train_features, train_labels, train_groups, train_domains, train_filenames, lp.text_embeddings)
    # NOTE(review): domains and groups are unpacked in the opposite order to
    # the (labels, groups, domains) convention used elsewhere in this file —
    # confirm this matches augment_dataset()'s return order.
    train_features, train_labels, train_domains, train_groups, train_filenames = augment.augment_dataset()
    print("Training set augmented!")
| 155 | + |
| 156 | +# if args.EXP.LOG_NN: |
| 157 | +# features, labels, groups, domains, filenames = np.concatenate([old_val_features, old_test_features]), np.concatenate([old_val_labels, old_test_labels]), np.concatenate([old_val_groups, old_test_groups]), np.concatenate([old_val_domains, old_test_domains]), np.concatenate([old_val_filenames, old_test_filenames]) |
| 158 | +# # features, labels, groups, domains, filenames = old_test_features, old_test_labels, old_test_groups, old_test_domains, old_test_filenames |
| 159 | +# if len(np.unique(train_domains)) > 1: |
| 160 | +# filtered_idxs = np.where(train_domains != train_domains[0]) |
| 161 | +# sample_features, sample_domains, sample_labels, sample_filenames = np.array(train_features[filtered_idxs]), train_domains[filtered_idxs], train_labels[filtered_idxs], train_filenames[filtered_idxs] |
| 162 | +# sample_idxs = random.sample(list(range(len(sample_filenames))), min((len(train_filenames), 1000))) |
| 163 | +# sample_features, sample_domains, sample_labels, sample_filenames = sample_features[sample_idxs], sample_domains[sample_idxs], sample_labels[sample_idxs], sample_filenames[sample_idxs] |
| 164 | +# else: |
| 165 | +# sample_idxs = random.sample(list(range(len(train_filenames))), min((len(train_filenames), 1000))) |
| 166 | +# sample_features, sample_domains, sample_labels, sample_filenames = train_features[sample_idxs], train_domains[sample_idxs], train_labels[sample_idxs], train_filenames[sample_idxs] |
| 167 | +# neighbor_domains, neighbor_labels, domain_acc, class_acc, neighbor_samples, prop_unique, mean_cs = get_nn_metrics(sample_features, sample_domains, sample_labels, features, domains, labels) |
| 168 | +# plt.rcParams["figure.figsize"] = (20,5) |
| 169 | +# f, (axs_orig, axs_new) = plt.subplots(2, 10, sharey=True) |
| 170 | +# for i, (original_idx, sample_idx) in enumerate(neighbor_samples): |
| 171 | +# try: |
| 172 | +# axs_orig[i].imshow(Image.open(sample_filenames[original_idx]).resize((224, 224))) |
| 173 | +# axs_orig[i].set_title(f"{dataset_domains[int(sample_domains[int(original_idx)])]} - {sample_labels[int(original_idx)]}") |
| 174 | +# axs_orig[i].axis('off') |
| 175 | +# axs_new[i].imshow(Image.open(filenames[sample_idx]).resize((224, 224))) |
| 176 | +# axs_new[i].set_title(f"{dataset_domains[int(domains[int(sample_idx)])]} - {labels[int(sample_idx)]}") |
| 177 | +# axs_new[i].axis('off') |
| 178 | +# except: |
| 179 | +# print(f"sample idx {sample_idx} is not a valid index") |
| 180 | +# wandb.log({"train features NN": wandb.Image(f), "domain consistency acc": domain_acc, "class consistency acc": class_acc, "unique nn": prop_unique}) |
| 181 | +# # wandb.sklearn.plot_confusion_matrix(sample_domains, neighbor_domains, dataset_domains) |
| 182 | +# print("Plotted Nearest Neighbors") |
| 183 | + |
# retrain the model on the augmented dataset
# NOTE(review): assumes train_debias re-initializes the MLP rather than
# continuing from the first pass — confirm in ClipMLP.
lp.train_debias(train_features, train_labels, train_groups, train_domains, val_features, val_labels, np.squeeze(val_groups), val_domains)
lp_val_predictions, lp_val_probs = lp.eval(val_features)
lp_val_accuracy, lp_val_balanced_acc, lp_val_class_accuracy, lp_val_group_accuracy = evaluate(lp_val_predictions, val_labels, np.squeeze(val_groups))
log_wandb(lp_val_accuracy, lp_val_balanced_acc, lp_val_class_accuracy, lp_val_group_accuracy, tag='aug_lp_val')
# Final held-out test evaluation, logged under the aug_lp_test prefix.
lp_test_predictions, lp_test_probs = lp.eval(test_features)
lp_test_accuracy, lp_test_balanced_acc, lp_test_class_accuracy, lp_test_group_accuracy = evaluate(lp_test_predictions, test_labels, np.squeeze(test_groups))
log_wandb(lp_test_accuracy, lp_test_balanced_acc, lp_test_class_accuracy, lp_test_group_accuracy, tag='aug_lp_test')
0 commit comments