Skip to content

Commit f7e42ac

Browse files
committed
Merge remote-tracking branch 'downstream/dev' into dev
2 parents 7b096f6 + 7a4257e commit f7e42ac

File tree

1 file changed

+20
-19
lines changed

1 file changed

+20
-19
lines changed

experiments/train_with_era5land.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,17 @@
3030
logger = logging.getLogger(logger_name)
3131
logger.setLevel(logging.INFO)
3232

33-
show = pd.read_csv(
34-
os.path.join(pathlib.Path(__file__).parent.parent, "data/basin_us.csv"),
35-
dtype={"id": str},
36-
)
37-
gage_id = show["id"].values.tolist()
38-
# gage_id = ["songliao_21401550"]
33+
# show = pd.read_csv(
34+
# os.path.join(pathlib.Path(__file__).parent.parent, "data/basin_us.csv"),
35+
# dtype={"id": str},
36+
# )
37+
# gage_id = show["id"].values.tolist()
38+
gage_id = ["songliao_21401550", "songliao_21401050"]
3939

4040

4141
def config():
4242
# 设置测试所需的项目名称和默认配置文件
43-
project_name = os.path.join("train_with_era5land", "ex7_us_basins_new_torchhydro")
43+
project_name = os.path.join("train_with_era5land", "ex_test")
4444
config_data = default_config_file()
4545

4646
# 填充测试所需的命令行参数
@@ -50,15 +50,15 @@ def config():
5050
"source_name": "selfmadehydrodataset",
5151
"source_path": SETTING["local_data_path"]["datasets-interim"],
5252
"other_settings": {
53-
"time_unit": ["3h"],
53+
"time_unit": ["1h"],
5454
},
5555
},
5656
ctx=[2],
5757
model_name="Seq2Seq",
5858
model_hyperparam={
59-
"en_input_size": 17,
60-
"de_input_size": 18,
61-
"output_size": 2,
59+
"en_input_size": 16,
60+
"de_input_size": 17,
61+
"output_size": 1,
6262
"hidden_size": 256,
6363
"forecast_length": 56,
6464
"hindcast_output_window": 1,
@@ -71,11 +71,11 @@ def config():
7171
hindcast_length=240,
7272
forecast_length=56,
7373
min_time_unit="h",
74-
min_time_interval=3,
74+
min_time_interval=1,
7575
var_t=[
7676
# "precipitationCal",
7777
"total_precipitation_hourly",
78-
"sm_surface",
78+
# "sm_surface",
7979
],
8080
var_c=[
8181
"area", # 面积
@@ -94,21 +94,21 @@ def config():
9494
"cly_pc_sav", # 土壤中的黏土、粉砂、砂粒含量
9595
"dor_pc_pva", # 调节程度
9696
],
97-
var_out=["streamflow", "sm_surface"],
97+
var_out=["streamflow"],
9898
dataset="Seq2SeqDataset",
99-
sampler="BasinBatchSampler",
99+
# sampler="BasinBatchSampler",
100100
scaler="DapengScaler",
101-
train_epoch=100,
101+
train_epoch=3,
102102
save_epoch=1,
103-
train_period=["2015-06-01-01", "2022-11-01-01"],
103+
train_period=["2020-06-01-01", "2022-11-01-01"],
104104
test_period=["2022-11-01-01", "2023-12-01-01"],
105105
valid_period=["2023-11-01-01", "2023-12-01-01"],
106106
loss_func="MultiOutLoss",
107107
loss_param={
108108
"loss_funcs": "RMSESum",
109-
"data_gap": [0, 0],
109+
"data_gap": [0],
110110
"device": [2],
111-
"item_weight": [0.8, 0.2],
111+
"item_weight": [1],
112112
},
113113
opt="Adam",
114114
lr_scheduler={
@@ -118,6 +118,7 @@ def config():
118118
which_first_tensor="batch",
119119
calc_metrics=False,
120120
early_stopping=True,
121+
rolling=56,
121122
# ensemble=True,
122123
# ensemble_items={
123124
# "batch_sizes": [256, 512],

0 commit comments

Comments
 (0)