
Commit 85ef229

Lisa Dunlap committed: verified waterbirds results

1 parent b7941c6

File tree

9 files changed: +104 -981 lines

README.md

Lines changed: 19 additions & 18 deletions

@@ -3,7 +3,7 @@ Official Implementation of [LADS (Latent Augmentation using Domain descriptionS)
 
 ![LADS method overview.](figs/lads-method-2-1.png "LADS method overview")
 
-*WARNING: this is still WIP, please raise an issue or email me if you run into any bugs.*
+*WARNING: this is still WIP, please raise an issue if you run into any bugs.*
 
 ```
 @article{dunlap2023lads,
@@ -14,14 +14,6 @@ Official Implementation of [LADS (Latent Augmentation using Domain descriptionS)
 }
 ```
 
-## TODOs
-[X] clean up emb saving/loading
-[] fix the Directional vs LADS acc diff
-[X] add e2e method for DA
-[] get E2E to work well
-[X] add in selective augmentation (run lp, check per class acc, augment poor performing finetuned classes more towards the text emb)
-[] run 2 layer mlp baselines
-
 ## Getting started
 
 1. Install the dependencies for our code using Conda. You may need to adjust the environment YAML file depending on your setup.
@@ -38,21 +30,30 @@ Official Implementation of [LADS (Latent Augmentation using Domain descriptionS)
 ## Code Structure
 The configurations for each method are in the `configs` folder. To try, say, the baseline of normal LR on the CLIP embeddings:
 ```
-python main.py --config configs/Waterbirds/base.yaml
+python main.py --config configs/Waterbirds/mlp.yaml
+```
+
+You can also override parameters like so:
+```
+python main.py --config configs/Waterbirds/mlp.yaml METHOD.MODEL.LR=0.1 EXP.PROJ=new_project
 ```
 
+### Datasets
+
 Datasets supported are in the [helpers folder](./helpers/data_helpers.py). Currently they are:
-* Waterbirds (100% and 95%)
-* ColoredMNIST (LNTL version and simplified version)
-* DomainNet
-* CUB Paintings
-* OfficeHome
+* Waterbirds (100% and 95%): [our specific split](https://drive.google.com/file/d/1zJpQYGEt1SuwitlNfE06TFyLaWX-st1k/view), [code to generate data](https://github.com/kohpangwei/group_DRO)
+* ColoredMNIST (LNTL version and simplified version) NOTEBOOK COMING SOON
+* DomainNet (the version used in the paper is `DATA.DATASET=DomainNetMini`): [full dataset](http://ai.bu.edu/DomainNet/)
+* CUB Paintings: [photos dataset](https://www.vision.caltech.edu/datasets/cub_200_2011/), [paintings dataset](https://github.com/thuml/PAN)
+* OfficeHome COMING SOON
+
+You can download the CLIP embeddings of these datasets [here](https://drive.google.com/drive/folders/1ItjhX7RPfQ6fQQk6_bEYJPewnkVdcfOC?usp=sharing). We also have the embeddings for CUB, Waterbirds, and DomainNetMini in the [embeddings](./embeddings/) folder.
 
-You can download the CLIP embeddings of these datasets [here](https://drive.google.com/drive/folders/1ItjhX7RPfQ6fQQk6_bEYJPewnkVdcfOC?usp=sharing)
+Since computing the CLIP embeddings for each train/val/test set is time-consuming, you can store the embeddings by setting `DATA.LOAD_CACHED=False`; the embeddings are then stored in a file `embeddings/{dataset}/clip_{openai,LAION}_{model_name}`.
 
-Since computing the CLIP embeddings for each train/val/test set is time consuming, you can store the embeddings by setting `DATA.LOAD_CACHED=False` and `DATA.SAVE_PATH=[path you want to save to]`
+### Methods
 
-Then, add the path to the saved embeddings to DATASET_PATHS in [data_helpers](./helpers/data_helpers.py) and set `DATA.LOAD_CACHED=Tue` in your yaml file
+All the augmentation methods (i.e. LADS and BiasLADS) are in `methods/augmentations`, while the classifiers and baselines are in `methods/clip_transformations.py`.
 
 More description of each method and of the config files is in the config folder.
 
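The override syntax shown in the README (`METHOD.MODEL.LR=0.1 EXP.PROJ=new_project`) applies dotted `KEY=VALUE` pairs on top of the YAML config. As a rough illustration of how such overrides could resolve, here is a minimal pure-Python sketch; the `apply_override` helper and the dict layout are hypothetical, not the repo's actual config machinery (which likely delegates to a config library).

```python
import ast

def apply_override(cfg: dict, override: str) -> None:
    """Apply one dotted KEY=VALUE override (e.g. 'METHOD.MODEL.LR=0.1')
    to a nested dict in place. Illustrative sketch only."""
    key, _, raw = override.partition("=")
    node = cfg
    parts = key.split(".")
    # Walk down the dotted path, creating intermediate dicts as needed.
    for part in parts[:-1]:
        node = node.setdefault(part, {})
    try:
        value = ast.literal_eval(raw)  # parse numbers/bools/lists safely
    except (ValueError, SyntaxError):
        value = raw                    # fall back to the raw string
    node[parts[-1]] = value

cfg = {"METHOD": {"MODEL": {"LR": 0.001}}, "EXP": {"PROJ": "default"}}
apply_override(cfg, "METHOD.MODEL.LR=0.1")
apply_override(cfg, "EXP.PROJ=new_project")
print(cfg["METHOD"]["MODEL"]["LR"])  # 0.1
```

`ast.literal_eval` is used instead of `eval` so only literals (numbers, strings, booleans) can be injected from the command line.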

configs/CUB/lads.yaml

Lines changed: 2 additions & 5 deletions

@@ -6,7 +6,7 @@ EXP:
   TEXT_PROMPTS: ['a painting of a {} bird.']
   NEUTRAL_TEXT_PROMPTS: ['a photo of a {} bird.']
   AUGMENTATION: 'LADS'
-  EPOCHS: 200
+  EPOCHS: 400
   ENSAMBLE: True
 
 
@@ -18,12 +18,9 @@ DATA:
 METHOD:
   MODEL:
     NUM_LAYERS: 1
-    DOM_WEIGHT: 1.0
    LR: 0.001
     WEIGHT_DECAY: 0.05
-    CHECKPOINT_NAME: 'lads'
-    RESUME: False
-    USE_DOM_GT: True
+    CHECKPOINT_NAME: 'cub_lp'
 
 AUGMENTATION:
   MODEL:

configs/DomainNet/test.yaml

Lines changed: 0 additions & 22 deletions
This file was deleted.

configs/DomainNet/test_aug.yaml

Lines changed: 0 additions & 37 deletions
This file was deleted.

configs/Waterbirds/lads.yaml

Lines changed: 6 additions & 6 deletions

@@ -30,16 +30,16 @@ METHOD:
 AUGMENTATION:
   MODEL:
     LR: 0.005
-    WEIGHT_DECAY: 0.005
-    NUM_LAYERS: 2
-    HIDDEN_DIM: 384
+    WEIGHT_DECAY: 0.05
+    NUM_LAYERS: 1
+    HIDDEN_DIM: 512
   EPOCHS: 50
   GENERIC: False
   DOM_LABELS: ["forest", "water"]
-  DOM_SPECIFIC_XE: False
+  DOM_SPECIFIC_XE: true
   ALPHA: 0.75
   # CLIP_NN_LOSS: True
   # COMPARE_BEFORE_AUG: True
   # NN_INCLUDE_SAMPLE: True
-  DOM_WEIGHT: 1.0
-  NN_WEIGHT: 1.0
+  # DOM_WEIGHT: 1.0
+  # NN_WEIGHT: 1.0

configs/base.yaml

Lines changed: 1 addition & 1 deletion

@@ -61,4 +61,4 @@ AUGMENTATION:
   DOM_SPECIFIC_XE: False
   DOM_LABELS: []
   NN_WEIGHT: 0.0
-  REG_WEIGHT: 0.0
+  REG_WEIGHT: 0.1
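The repo keeps shared defaults in `configs/base.yaml` (e.g. `REG_WEIGHT` above) and dataset-specific settings in files like `configs/Waterbirds/lads.yaml`. A recursive deep merge of nested mappings is one plausible way such layers combine; the `deep_merge` helper below is an illustrative sketch, not the project's actual config loader.

```python
def deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict where nested keys in `override` win over `base`,
    recursing into sub-dicts instead of replacing them wholesale.
    Illustrative sketch of base-config + per-dataset-config layering."""
    merged = dict(base)
    for key, val in override.items():
        if isinstance(val, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], val)
        else:
            merged[key] = val
    return merged

# Toy fragments mirroring the YAML structure in this commit.
base = {"AUGMENTATION": {"REG_WEIGHT": 0.1, "NN_WEIGHT": 0.0}}
dataset_cfg = {"AUGMENTATION": {"MODEL": {"LR": 0.005}}}
cfg = deep_merge(base, dataset_cfg)
print(cfg["AUGMENTATION"])
```

Note that a naive `dict.update` would drop `REG_WEIGHT` when the override file also has an `AUGMENTATION:` section; the recursion preserves defaults that the override does not mention.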

main.py

Lines changed: 2 additions & 30 deletions

@@ -85,36 +85,13 @@ def flatten_config(dic, running_key=None, flattened_dict={}):
 model = model.to(device)
 
 # # load data
-# if args.DATA.LOAD_CACHED:
-#     cache_file = f"{args.DATA.SAVE_PATH}/{args.DATA.DATASET}/{args.EXP.IMAGE_FEATURES}_{args.EXP.CLIP_PRETRAINED_DATASET}_{args.EXP.CLIP_MODEL.replace('/','_')}.pt"
-#     dataset_classes, dataset_domains = dh.DATASET_CLASSES[args.DATA.DATASET], dh.DATASET_DOMAINS[args.DATA.DATASET]
-#     assert os.path.exists(cache_file), f"{cache_file} does not exist. To compute embeddings, set DATA.LOAD_CACHED=False"
-#     print(f"Loading cached embeddings from {cache_file}")
-#     train_features, train_labels, train_groups, train_domains, train_filenames, val_features, val_labels, val_groups, val_domains, val_filenames, test_features, test_labels, test_groups, test_domains, test_filenames = load_embeddings(cache_file, args.DATA.DATASET)
-# load data
-# if args.DATA.LOAD_CACHED:
-#     print(args.DATA.LOAD_CACHED)
-#     if args.EXP.IMAGE_FEATURES == 'clip' or args.EXP.IMAGE_FEATURES == 'openclip':
-#         model_name = args.EXP.CLIP_MODEL
-#     else:
-#         model_name = args.EXP.IMAGE_FEATURES
-#     cache_file, dataset_classes, dataset_domains = dh.get_cache_file(DATASET_NAME, model_name, args.EXP.IMAGE_FEATURES)
-#     assert os.path.exists(cache_file), f"{cache_file} does not exist. To compute embeddings, set DATA.LOAD_CACHED=False"
-#     data = torch.load(cache_file)
-#     train_features, train_labels, train_groups, train_domains, train_filenames = data['train_features'], data['train_labels'], data['train_groups'], data['train_domains'], data['train_filenames']
-#     val_features, val_labels, val_groups, val_domains, val_filenames = data['val_features'], data['val_labels'], data['val_groups'], data['val_domains'], data['val_filenames']
-#     test_features, test_labels, test_groups, test_domains, test_filenames = data['test_features'], data['test_labels'], data['test_groups'], data['test_domains'], data['test_filenames']
-#     # move some val data to test
-#     if args.DATA.DATASET != 'ColoredMNISTBinary':
-#         val_features, val_labels, val_groups, val_domains, val_filenames = data['val_features'][::2], data['val_labels'][::2], data['val_groups'][::2], data['val_domains'][::2], data['val_filenames'][::2]
-#         test_features, test_labels, test_groups, test_domains, test_filenames = np.concatenate((data['test_features'], data['val_features'][1::2])), np.concatenate((data['test_labels'], data['val_labels'][1::2])), np.concatenate((data['test_groups'], data['val_groups'][1::2])), np.concatenate((data['test_domains'], data['val_domains'][1::2])), np.concatenate((data['test_filenames'], data['val_filenames'][1::2]))
 cache_file = f"{args.DATA.SAVE_PATH}/{args.DATA.DATASET}/{args.EXP.IMAGE_FEATURES}_{args.EXP.CLIP_PRETRAINED_DATASET}_{args.EXP.CLIP_MODEL.replace('/','_')}.pt"
 dataset_classes, dataset_domains = dh.DATASET_CLASSES[args.DATA.DATASET], dh.DATASET_DOMAINS[args.DATA.DATASET]
 if os.path.exists(cache_file):
     print(f"Loading cached embeddings from {cache_file}")
     train_features, train_labels, train_groups, train_domains, train_filenames, val_features, val_labels, val_groups, val_domains, val_filenames, test_features, test_labels, test_groups, test_domains, test_filenames = load_embeddings(cache_file, args.DATA.DATASET)
 else:
-    # print(f"Computing embeddings and saving to {cache_file}")
+    print(f"Computing embeddings and saving to {cache_file}")
     trainset, valset, testset = dh.get_dataset(DATASET_NAME, preprocess)
     dataset_classes, dataset_domains = dh.get_class(DATASET_NAME), dh.get_domain(DATASET_NAME)
     train_loader = torch.utils.data.DataLoader(trainset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=True)
@@ -154,7 +131,7 @@ def flatten_config(dic, running_key=None, flattened_dict={}):
 print("Advice Method", args.EXP.ADVICE_METHOD)
 bias_correction = getattr(CLIPTransformations, args.EXP.ADVICE_METHOD)(prompts, clip_model, args, neutral_prompts)
 
-# old_train_features, old_train_labels, old_train_groups, old_train_domains, old_train_filenames = train_features, train_labels, train_groups, train_domains, train_filenames
+
 old_val_features, old_val_labels, old_val_groups, old_val_domains, old_val_filenames = val_features, val_labels, val_groups, val_domains, val_filenames
 old_test_features, old_test_labels, old_test_groups, old_test_domains, old_test_filenames = test_features, test_labels, test_groups, test_domains, test_filenames
@@ -251,26 +228,21 @@ def flatten_config(dic, running_key=None, flattened_dict={}):
 print(f"Test accuracy: {group_accuracy} \n Test domain accuracy: {domain_accuracy}")
 
 if 'E2E' in args.EXP.ADVICE_METHOD:
-    # features, labels, groups, domains, filenames = np.concatenate([old_val_features, old_test_features]), np.concatenate([old_val_labels, old_test_labels]), np.concatenate([old_val_groups, old_test_groups]), np.concatenate([old_val_domains, old_test_domains]), np.concatenate([old_val_filenames, old_test_filenames])
     aug_features, aug_labels, aug_domains, aug_filenames = bias_correction.augment_dataset(train_features, train_labels, train_domains, train_filenames)
     sample_idxs = random.sample(list(range(len(aug_filenames))), 1000)
-    # print("SAMPLE SHAPE: ", sample_filenames.shape, sample_domains.shape)
     sample_features, sample_domains, sample_labels, sample_filenames = aug_features[sample_idxs], aug_domains[sample_idxs], aug_labels[sample_idxs], aug_filenames[sample_idxs]
     neighbor_domains, neighbor_labels, domain_acc, class_acc, neighbor_samples, prop_unique, mean_cs = get_nn_metrics(sample_features, sample_domains, sample_labels, old_test_features, old_test_domains, old_test_labels)
     wandb.log({"mean CS for NN": mean_cs})
     print(neighbor_samples)
     plt.rcParams["figure.figsize"] = (20,5)
     f, (axs_orig, axs_new) = plt.subplots(2, 10, sharey=True)
     for i, (original_idx, sample_idx) in enumerate(neighbor_samples):
-        # try:
         print(sample_filenames[original_idx])
         axs_orig[i].imshow(Image.open(sample_filenames[original_idx]).resize((224, 224)))
         axs_orig[i].set_title(f"{dataset_domains[int(sample_domains[int(original_idx)])]} - {sample_labels[int(original_idx)]}")
         axs_orig[i].axis('off')
         axs_new[i].imshow(Image.open(old_test_filenames[sample_idx]).resize((224, 224)))
         axs_new[i].set_title(f"{dataset_domains[int(old_test_domains[int(sample_idx)])]} - {old_test_labels[int(sample_idx)]}")
         axs_new[i].axis('off')
-        # except:
-        #     print(f"sample idx {sample_idx} is not a valid index")
     wandb.log({"train features NN": wandb.Image(f), "domain consistency acc": domain_acc, "class consistency acc": class_acc, "unique nn": prop_unique})
     wandb.sklearn.plot_confusion_matrix(sample_domains, neighbor_domains, dataset_domains)
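The cache path built in `main.py` encodes the dataset, feature extractor, CLIP pretraining set, and model name, with slashes in model names (e.g. `ViT-B/32`) replaced so the name is filesystem-safe. A small sketch of that naming scheme; `embedding_cache_path` is a hypothetical helper mirroring the f-string in the diff, not a function from the repo.

```python
def embedding_cache_path(dataset: str, pretrained: str, model: str,
                         features: str = "clip",
                         root: str = "embeddings") -> str:
    """Mirror the cache-file naming used in main.py (illustrative helper).
    CLIP model ids like 'ViT-B/32' contain a '/', which would create an
    extra directory level, so it is replaced with '_'."""
    return f"{root}/{dataset}/{features}_{pretrained}_{model.replace('/', '_')}.pt"

path = embedding_cache_path("Waterbirds", "openai", "ViT-B/32")
print(path)  # embeddings/Waterbirds/clip_openai_ViT-B_32.pt
```

If the file at this path exists, `main.py` loads the cached embeddings; otherwise it recomputes them and saves to the same path, so the expensive CLIP forward pass runs at most once per (dataset, model) pair.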
