@@ -127,7 +127,7 @@ def default_config_file():
127127 "t_range_test" : ["1993-01-01" , "1994-01-01" ],
128128 # the output
129129 "target_cols" : [Q_CAMELS_US_NAME ],
130- "target_rm_nan" : True ,
130+ "target_rm_nan" : False ,
131131 # only for cases in which target data will be used as input:
132132 # data assimilation -- use streamflow from period 0 to t-1 (TODO: not included now)
133133 # for physics-based model -- use streamflow to calibrate models
@@ -169,6 +169,7 @@ def default_config_file():
169169 # for each period, they have multiple forecast data with different lead time
170170 # hence we list them as a seperate type
171171 "forecast_cols" : None ,
172+ "forecast_rm_nan" : True ,
172173 # global variables such as ENSO indictors are used in some long term models
173174 "global_cols" : None ,
174175 # specify the data source of each variable
@@ -346,7 +347,7 @@ def cmd(
346347 lr_scheduler = None ,
347348 opt_param = None ,
348349 batch_size = None ,
349- warmup_length = 0 ,
350+ warmup_length = None ,
350351 # forecast_history will be deprecated in the future
351352 forecast_history = None ,
352353 hindcast_length = None ,
@@ -362,9 +363,9 @@ def cmd(
362363 weight_path = None ,
363364 continue_train = None ,
364365 var_c = None ,
365- c_rm_nan = 1 ,
366+ c_rm_nan = None ,
366367 var_t = None ,
367- t_rm_nan = 1 ,
368+ t_rm_nan = None ,
368369 n_output = None ,
369370 loss_func = None ,
370371 model_hyperparam = None ,
@@ -375,21 +376,22 @@ def cmd(
375376 var_g = None ,
376377 var_out = None ,
377378 var_to_source_map = None ,
378- out_rm_nan = 0 ,
379- target_as_input = 0 ,
380- constant_only = 0 ,
379+ out_rm_nan = None ,
380+ f_rm_nan = None ,
381+ target_as_input = None ,
382+ constant_only = None ,
381383 gage_id_screen = None ,
382384 loss_param = None ,
383385 metrics = None ,
384386 fill_nan = None ,
385387 explainer = None ,
386388 rolling = None ,
387389 calc_metrics = None ,
388- start_epoch = 1 ,
390+ start_epoch = None ,
389391 stat_dict_file = None ,
390392 num_workers = None ,
391393 which_first_tensor = None ,
392- ensemble = 0 ,
394+ ensemble = None ,
393395 ensemble_items = None ,
394396 early_stopping = None ,
395397 patience = None ,
@@ -746,6 +748,13 @@ def cmd(
746748 default = out_rm_nan ,
747749 type = int ,
748750 )
751+ parser .add_argument (
752+ "--f_rm_nan" ,
753+ dest = "f_rm_nan" ,
754+ help = "if true, we remove NaN value for var_f data when scaling" ,
755+ default = f_rm_nan ,
756+ type = int ,
757+ )
749758 parser .add_argument (
750759 "--target_as_input" ,
751760 dest = "target_as_input" ,
@@ -1013,7 +1022,8 @@ def update_cfg(cfg_file, new_args):
10131022 cfg_file ["data_cfgs" ]["constant_cols" ] = []
10141023 else :
10151024 cfg_file ["data_cfgs" ]["constant_cols" ] = new_args .var_c
1016- cfg_file ["data_cfgs" ]["constant_rm_nan" ] = bool (new_args .c_rm_nan != 0 )
1025+ if new_args .c_rm_nan is not None :
1026+ cfg_file ["data_cfgs" ]["constant_rm_nan" ] = bool (new_args .c_rm_nan > 0 )
10171027 if new_args .var_t is not None :
10181028 cfg_file ["data_cfgs" ]["relevant_cols" ] = new_args .var_t
10191029 print (
@@ -1022,7 +1032,8 @@ def update_cfg(cfg_file, new_args):
10221032 print ("If you have POTENTIAL_EVAPOTRANSPIRATION, please set it the 2nd!!!-" )
10231033 if new_args .var_t_type is not None :
10241034 cfg_file ["data_cfgs" ]["relevant_types" ] = new_args .var_t_type
1025- cfg_file ["data_cfgs" ]["relevant_rm_nan" ] = bool (new_args .t_rm_nan != 0 )
1035+ if new_args .t_rm_nan is not None :
1036+ cfg_file ["data_cfgs" ]["relevant_rm_nan" ] = bool (new_args .t_rm_nan > 0 )
10261037 if new_args .var_f is not None :
10271038 cfg_file ["data_cfgs" ]["forecast_cols" ] = new_args .var_f
10281039 if new_args .var_g is not None :
@@ -1034,10 +1045,14 @@ def update_cfg(cfg_file, new_args):
10341045 )
10351046 if new_args .var_to_source_map is not None :
10361047 cfg_file ["data_cfgs" ]["var_to_source_map" ] = new_args .var_to_source_map
1037- cfg_file ["data_cfgs" ]["target_rm_nan" ] = bool (new_args .out_rm_nan != 0 )
1038- if new_args .target_as_input == 0 :
1039- cfg_file ["data_cfgs" ]["target_as_input" ] = False
1040- cfg_file ["data_cfgs" ]["constant_only" ] = bool (new_args .constant_only != 0 )
1048+ if new_args .out_rm_nan is not None :
1049+ cfg_file ["data_cfgs" ]["target_rm_nan" ] = bool (new_args .out_rm_nan > 0 )
1050+ if new_args .f_rm_nan is not None :
1051+ cfg_file ["data_cfgs" ]["forecast_rm_nan" ] = bool (new_args .f_rm_nan > 0 )
1052+ if new_args .target_as_input is not None :
1053+ cfg_file ["data_cfgs" ]["target_as_input" ] = bool (new_args .target_as_input > 0 )
1054+ if new_args .constant_only is not None :
1055+ cfg_file ["data_cfgs" ]["constant_only" ] = bool (new_args .constant_only > 0 )
10411056 else :
10421057 cfg_file ["data_cfgs" ]["target_as_input" ] = True
10431058 if new_args .calc_metrics is not None :
@@ -1055,7 +1070,7 @@ def update_cfg(cfg_file, new_args):
10551070 if new_args .weight_path is not None :
10561071 cfg_file ["model_cfgs" ]["weight_path" ] = new_args .weight_path
10571072 continue_train = bool (
1058- new_args .continue_train is not None and new_args .continue_train != 0
1073+ new_args .continue_train is not None and new_args .continue_train > 0
10591074 )
10601075 cfg_file ["model_cfgs" ]["continue_train" ] = continue_train
10611076 if new_args .weight_path_add is not None :
@@ -1109,7 +1124,7 @@ def update_cfg(cfg_file, new_args):
11091124 cfg_file ["evaluation_cfgs" ]["rolling" ] = new_args .rolling
11101125 if new_args .model_loader is not None :
11111126 cfg_file ["evaluation_cfgs" ]["model_loader" ] = new_args .model_loader
1112- if new_args .warmup_length > 0 :
1127+ if new_args .warmup_length is not None :
11131128 cfg_file ["data_cfgs" ]["warmup_length" ] = new_args .warmup_length
11141129 if (
11151130 "warmup_length" in new_args .model_hyperparam .keys ()
@@ -1136,7 +1151,7 @@ def update_cfg(cfg_file, new_args):
11361151 if new_args .lead_time_start is None :
11371152 raise ValueError ("lead_time_start must be set when lead_time_type is set" )
11381153 cfg_file ["data_cfgs" ]["lead_time_start" ] = new_args .lead_time_start
1139- if new_args .start_epoch > 1 :
1154+ if new_args .start_epoch is not None :
11401155 cfg_file ["training_cfgs" ]["start_epoch" ] = new_args .start_epoch
11411156 if new_args .stat_dict_file is not None :
11421157 stat_dict_file = new_args .stat_dict_file
0 commit comments