-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patharguments.py
More file actions
120 lines (108 loc) · 7.94 KB
/
arguments.py
File metadata and controls
120 lines (108 loc) · 7.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
def _str2bool(value):
    """Convert a command-line string to a bool, for use as an argparse ``type=``.

    ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True`` because
    any non-empty string is truthy, so such an option could never be disabled.
    This converter accepts the usual textual spellings instead.

    Args:
        value: The raw argument string (a bool is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def args_parser():
    """Build the command-line argument parser for training / evaluation runs.

    Options are grouped by concern: task, data loading, normalizer, training,
    model backbones, the individual fusion methods (mmtm, daft, dp, copula,
    drfuse), and filesystem paths. Arguments are not parsed here.

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    parser = argparse.ArgumentParser(description='arguments')
    # task setting
    parser.add_argument('--task', type=str, default='readmission', help='train or eval for in-hospital-mortality, readmission, phenotyping')
    parser.add_argument('--labels_set', type=str, default='readm', help='pheno, mortality, readm, radiology')
    parser.add_argument('--num_classes', type=int, default=1, help='number of classes ihm:1, pheno:25')
    parser.add_argument('--vision_num_classes', default=1, type=int, help='number of classes ihm:1, pheno:25')
    # dataset setting
    # cxr setting
    parser.add_argument('--resize', default=256, type=int, help='size to resize CXR images to')
    parser.add_argument('--crop', default=224, type=int, help='size to crop CXR images to')
    # dataloader setting
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=4)
    # data pairs setting
    parser.add_argument('--data_pairs', type=str, default='paired_ehr_note', help='paired, partial, ehr, cxr')
    parser.add_argument('--data_ratio', type=float, default=1.0, help='percentage of unpaired data samples')
    parser.add_argument('--missing_token', type=str, default=None, help='zeros, learnable')
    # normalizer setting
    parser.add_argument('--timestep', type=float, default=2.0, help="fixed timestep used in the dataset")
    parser.add_argument('--imputation', type=str, default='previous')
    # NOTE(review): user-specific absolute default path; override on other machines.
    parser.add_argument('--normalizer_state', type=str, default='/disk1/fwu/myProjects/MedFuse/data/ihm_ts.normalizer',
                        help='Path to a state file of a normalizer. Leave none if you want to use one of the provided ones.')
    # training setting
    parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='load imagenet pretrained model')
    parser.add_argument('--mode', type=str, default="train", help='mode: train or test')
    parser.add_argument('--eval', dest='eval', action='store_true', help='eval the pretrained models on val and test split')
    parser.add_argument('--network', type=str)
    parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
    parser.add_argument('--beta_1', type=float, default=0.9, help='beta_1 param for Adam optimizer')
    parser.add_argument('--patience', type=int, default=15, help='number of epochs to wait for best')
    parser.add_argument('--load_state', type=str, default=None, help='state dir path')
    parser.add_argument('--load_state_cxr', type=str, default=None, help='state dir path')
    parser.add_argument('--load_state_ehr', type=str, default=None, help='state dir path')
    parser.add_argument('--load_state_note', type=str, default=None, help='state dir path')
    parser.add_argument('--resume', dest='resume', help='resume training from state to load', action='store_true')
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--rec_dropout', type=float, default=0.0, help="dropout rate for recurrent connections")
    # model setting
    parser.add_argument('--fusion_type', type=str, default='dp', help='train or eval for [fused_ehr, fused_cxr, uni_cxr, uni_ehr, lstm, unified, mmtm, daft, dp, drfuse]')
    # backbone setting
    # vision backbone setting
    parser.add_argument('--layers', default=2, type=int, help='number of lstm stacked modules')
    parser.add_argument('--vision_backbone', default='resnet34', type=str,
                        help='name of the vision backbone (default: resnet34), e.g. densenet121, densenet169, densenet201')
    parser.add_argument('--dim', type=int, default=256, help='number of hidden units')
    parser.add_argument('--depth', type=int, default=1, help='number of bi-LSTMs')
    # text backbone setting
    parser.add_argument('--bert_type', type=str, default="huawei-noah/TinyBERT_General_4L_312D", help='name of the bert pretrained model',
                        choices=['huawei-noah/TinyBERT_General_4L_312D', 'emilyalsentzer/Bio_ClinicalBERT', 'allenai/biomed_roberta_base', 'bert-base-uncased', 'yikuan8/Clinical-Longformer'])
    parser.add_argument('--orig_d_txt', type=int, default=312, help='hidden_size of the bert model 768(max_length=1024)/312(max_length=512)')
    parser.add_argument('--d_txt', type=int, default=512, help='hidden_size of the bert model 768')
    # fusion setting
    parser.add_argument('--fusion', type=str, default='joint', help='train or eval for [early late joint]')
    parser.add_argument('--align', type=float, default=0, help='align weight')
    # mmtm setting
    parser.add_argument('--layer_after', default=4, type=int, help='apply mmtm module after fourth layer -1 indicates mmtm after every layer')
    parser.add_argument('--mmtm_ratio', type=float, default=4, help='mmtm ratio hyperparameter')
    # daft setting
    parser.add_argument('--daft_activation', type=str, default='linear', help='daft activation ')
    # dp setting
    parser.add_argument('--dp', type=float, default=0.000001, help='dp weight')
    parser.add_argument('--dp_fuse_type', type=str, default=None, help='dp_fuse_type: lstm, mha')
    parser.add_argument('--dp_normalize_feats', action="store_true")
    # set_defaults() below forces dp_normalize_feats to True, which makes the
    # store_true flag a no-op; this negated flag is the only way to turn it off.
    parser.add_argument('--no_dp_normalize_feats', dest='dp_normalize_feats', action='store_false',
                        help='set dp_normalize_feats to False (it defaults to True)')
    # type=bool would turn any non-empty string (including "False") into True,
    # so parse explicitly; a bare `--dp_resample` also means True.
    parser.add_argument('--dp_resample', type=_str2bool, nargs='?', const=True, default=True)
    parser.add_argument('--replace_w_align', type=str, default=None, help='replace_w_align: kl, cos, na')
    # copula setting
    parser.add_argument('--copula', type=float, default=0.000001, help='copula weight')
    parser.add_argument('--copula_fuse_type', type=str, default='lstm', help='copula_fuse_type: lstm, mha')
    parser.add_argument('--copula_normalize_feats', action="store_true")
    # Same situation as --no_dp_normalize_feats above.
    parser.add_argument('--no_copula_normalize_feats', dest='copula_normalize_feats', action='store_false',
                        help='set copula_normalize_feats to False (it defaults to True)')
    parser.add_argument('--copula_resample', type=_str2bool, nargs='?', const=True, default=True)
    parser.add_argument('--copula_family', type=str, default='Frank', help='copula_family: Frank, Gumbel, Gaussian')
    # drfuse setting
    parser.add_argument('--lambda_disentangle_shared', type=float, default=1)
    parser.add_argument('--lambda_disentangle_ehr', type=float, default=1)
    parser.add_argument('--lambda_disentangle_cxr', type=float, default=1)
    parser.add_argument('--lambda_disentangle_note', type=float, default=1)
    parser.add_argument('--lambda_pred_ehr', type=float, default=1)
    parser.add_argument('--lambda_pred_cxr', type=float, default=1)
    parser.add_argument('--lambda_pred_note', type=float, default=1)
    parser.add_argument('--lambda_pred_shared', type=float, default=1)
    parser.add_argument('--aug_missing_ratio', type=float, default=0.3)
    parser.add_argument('--lambda_attn_aux', type=float, default=1)
    parser.add_argument('--ehr_n_layers', type=int, default=1)
    parser.add_argument('--ehr_n_head', type=int, default=4)
    parser.add_argument('--hidden_size', type=int, default=256)
    parser.add_argument('--wd', type=float, default=0)
    parser.add_argument('--adaptive_adc_lambda', action="store_true")
    parser.add_argument('--attn_fusion', action="store_true")
    parser.add_argument('--gamma', type=float, default=0)
    # path setting
    # NOTE(review): the data-dir defaults are user-specific absolute paths.
    parser.add_argument('--ehr_data_dir', type=str, help='Path to the data of phenotyping fusion_type',
                        default='/disk1/fwu/myProjects/MedFuse/data/')
    parser.add_argument('--cxr_data_dir', type=str, help='Path to the data of phenotyping fusion_type',
                        default='/disk1/fwu/myProjects/MedFuse/data/mimic-cxr/')
    parser.add_argument('--save_dir', type=str, help='Directory relative to which all output files are stored',
                        default='checkpoints/debug')
    # Copula parameters
    parser.add_argument('--K', type=int, default=3)
    parser.add_argument('--rho_scale', type=float, default=-3)
    parser.add_argument('--eta', type=float, default=1)
    # Temperature annealing
    parser.add_argument('--temperature', type=float, default=0.0001)
    # some default setting
    # set_defaults also updates the default of every action already registered
    # for these dests (including the --no_* flags), so both start out True.
    parser.set_defaults(dp_normalize_feats=True)
    parser.set_defaults(copula_normalize_feats=True)
    return parser