Skip to content

Commit 20b8c4e

Browse files
author
Lisa Dunlap
committed
fixed domainnet dataset config
1 parent 36c9854 commit 20b8c4e

File tree

3 files changed

+5
-2
lines changed

3 files changed

+5
-2
lines changed

data_configs/domain_net.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
DATA:
2+
SOURCE_DOMAIN: clipart # options are clipart, infograph, painting, quickdraw, real, sketch
3+
TARGET_DOMAIN: real # options are clipart, infograph, painting, quickdraw, real, sketch

helpers/data_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def get_config(name="Waterbirds"):
2626
cfg = OmegaConf.load('data_configs/waterbirds.yaml')
2727
elif "ColoredMNIST" in name:
2828
cfg = OmegaConf.load('data_configs/colored_mnist.yaml')
29+
elif "DomainNet" in name:
30+
cfg = OmegaConf.load('data_configs/domain_net.yaml')
2931
else:
3032
raise ValueError(f"{name} Dataset config not found")
3133
args = OmegaConf.merge(base_cfg, cfg)

methods/clip_transformations.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,6 @@ def train_debias(self, inputs, labels, groups, dom_gt, test_inputs, test_labels,
186186
if self.cfg.EXP.CHECKPOINT_VAL:
187187
self.train_val_loop(self.test_loader, epoch, phase="val")
188188
self.save_checkpoint(0.0, epoch, last=True)
189-
# wandb.ss['best val balanced acc'] = self.best_acc
190-
# wandb.ss['best epoch'] = self.best_epoch
191189

192190
def load_checkpoint(self, path):
193191
checkpoint = torch.load(path)

0 commit comments

Comments
 (0)