Skip to content

Commit 701b18a

Browse files
SunsetWolf and Linlang Lv (iSoftStone) authored
fix_issue_715 (microsoft#1070)
* fix_issue_715
* fix_issue_1065

Co-authored-by: Linlang Lv (iSoftStone) <v-linlanglv@microsoft.com>
1 parent 84ff662 commit 701b18a

File tree

2 files changed

+31
-39
lines changed

2 files changed

+31
-39
lines changed

qlib/contrib/model/pytorch_adarnn.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __init__(
144144
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
145145

146146
self.fitted = False
147-
self.model.cuda()
147+
self.model.to(self.device)
148148

149149
@property
150150
def use_gpu(self):
@@ -153,7 +153,7 @@ def use_gpu(self):
153153
def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None):
154154
self.model.train()
155155
criterion = nn.MSELoss()
156-
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
156+
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
157157
len_loader = np.inf
158158
for loader in train_loader_list:
159159
if len(loader) < len_loader:
@@ -165,7 +165,7 @@ def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None)
165165
list_label = []
166166
for data in data_all:
167167
# feature :[36, 24, 6]
168-
feature, label_reg = data[0].cuda().float(), data[1].cuda().float()
168+
feature, label_reg = data[0].to(self.device).float(), data[1].to(self.device).float()
169169
list_feat.append(feature)
170170
list_label.append(label_reg)
171171
flag = False
@@ -179,7 +179,7 @@ def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None)
179179
if flag:
180180
continue
181181

182-
total_loss = torch.zeros(1).cuda()
182+
total_loss = torch.zeros(1).to(self.device)
183183
for i, n in enumerate(index):
184184
feature_s = list_feat[n[0]]
185185
feature_t = list_feat[n[1]]
@@ -325,7 +325,7 @@ def infer(self, x_test):
325325
else:
326326
end = begin + self.batch_size
327327

328-
x_batch = torch.from_numpy(x_values[begin:end]).float().cuda()
328+
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
329329

330330
with torch.no_grad():
331331
pred = self.model.predict(x_batch).detach().cpu().numpy()
@@ -335,7 +335,7 @@ def infer(self, x_test):
335335
return pd.Series(np.concatenate(preds), index=index)
336336

337337
def transform_type(self, init_weight):
338-
weight = torch.ones(self.num_layers, self.len_seq).cuda()
338+
weight = torch.ones(self.num_layers, self.len_seq).to(self.device)
339339
for i in range(self.num_layers):
340340
for j in range(self.len_seq):
341341
weight[i, j] = init_weight[i][j].item()
@@ -389,6 +389,7 @@ def __init__(
389389
len_seq=9,
390390
model_type="AdaRNN",
391391
trans_loss="mmd",
392+
GPU=0,
392393
):
393394
super(AdaRNN, self).__init__()
394395
self.use_bottleneck = use_bottleneck
@@ -399,6 +400,7 @@ def __init__(
399400
self.model_type = model_type
400401
self.trans_loss = trans_loss
401402
self.len_seq = len_seq
403+
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
402404
in_size = self.n_input
403405

404406
features = nn.ModuleList()
@@ -455,7 +457,7 @@ def forward_pre_train(self, x, len_win=0):
455457

456458
out_list_all, out_weight_list = out[1], out[2]
457459
out_list_s, out_list_t = self.get_features(out_list_all)
458-
loss_transfer = torch.zeros((1,)).cuda()
460+
loss_transfer = torch.zeros((1,)).to(self.device)
459461
for i, n in enumerate(out_list_s):
460462
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
461463
h_start = 0
@@ -516,12 +518,12 @@ def forward_Boosting(self, x, weight_mat=None):
516518

517519
out_list_all = out[1]
518520
out_list_s, out_list_t = self.get_features(out_list_all)
519-
loss_transfer = torch.zeros((1,)).cuda()
521+
loss_transfer = torch.zeros((1,)).to(self.device)
520522
if weight_mat is None:
521-
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).cuda()
523+
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).to(self.device)
522524
else:
523525
weight = weight_mat
524-
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
526+
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
525527
for i, n in enumerate(out_list_s):
526528
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
527529
for j in range(self.len_seq):
@@ -553,12 +555,13 @@ def predict(self, x):
553555

554556

555557
class TransferLoss:
556-
def __init__(self, loss_type="cosine", input_dim=512):
558+
def __init__(self, loss_type="cosine", input_dim=512, GPU=0):
557559
"""
558560
Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv
559561
"""
560562
self.loss_type = loss_type
561563
self.input_dim = input_dim
564+
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
562565

563566
def compute(self, X, Y):
564567
"""Compute adaptation loss
@@ -574,18 +577,18 @@ def compute(self, X, Y):
574577
mmdloss = MMD_loss(kernel_type="linear")
575578
loss = mmdloss(X, Y)
576579
elif self.loss_type == "coral":
577-
loss = CORAL(X, Y)
580+
loss = CORAL(X, Y, self.device)
578581
elif self.loss_type in ("cosine", "cos"):
579582
loss = 1 - cosine(X, Y)
580583
elif self.loss_type == "kl":
581584
loss = kl_div(X, Y)
582585
elif self.loss_type == "js":
583586
loss = js(X, Y)
584587
elif self.loss_type == "mine":
585-
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).cuda()
588+
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).to(self.device)
586589
loss = mine_model(X, Y)
587590
elif self.loss_type == "adv":
588-
loss = adv(X, Y, input_dim=self.input_dim, hidden_dim=32)
591+
loss = adv(X, Y, self.device, input_dim=self.input_dim, hidden_dim=32)
589592
elif self.loss_type == "mmd_rbf":
590593
mmdloss = MMD_loss(kernel_type="rbf")
591594
loss = mmdloss(X, Y)
@@ -630,12 +633,12 @@ def forward(self, x):
630633
return x
631634

632635

633-
def adv(source, target, input_dim=256, hidden_dim=512):
636+
def adv(source, target, device, input_dim=256, hidden_dim=512):
634637
domain_loss = nn.BCELoss()
635638
# !!! Pay attention to .cuda !!!
636-
adv_net = Discriminator(input_dim, hidden_dim).cuda()
637-
domain_src = torch.ones(len(source)).cuda()
638-
domain_tar = torch.zeros(len(target)).cuda()
639+
adv_net = Discriminator(input_dim, hidden_dim).to(device)
640+
domain_src = torch.ones(len(source)).to(device)
641+
domain_tar = torch.zeros(len(target)).to(device)
639642
domain_src, domain_tar = domain_src.view(domain_src.shape[0], 1), domain_tar.view(domain_tar.shape[0], 1)
640643
reverse_src = ReverseLayerF.apply(source, 1)
641644
reverse_tar = ReverseLayerF.apply(target, 1)
@@ -646,16 +649,16 @@ def adv(source, target, input_dim=256, hidden_dim=512):
646649
return loss
647650

648651

649-
def CORAL(source, target):
652+
def CORAL(source, target, device):
650653
d = source.size(1)
651654
ns, nt = source.size(0), target.size(0)
652655

653656
# source covariance
654-
tmp_s = torch.ones((1, ns)).cuda() @ source
657+
tmp_s = torch.ones((1, ns)).to(device) @ source
655658
cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)
656659

657660
# target covariance
658-
tmp_t = torch.ones((1, nt)).cuda() @ target
661+
tmp_t = torch.ones((1, nt)).to(device) @ target
659662
ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)
660663

661664
# frobenius norm

scripts/data_collector/cn_index/collector.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,14 @@ def index_code(self) -> str:
9090
raise NotImplementedError("rewrite index_code")
9191

9292
@property
93-
@abc.abstractmethod
9493
def html_table_index(self) -> int:
9594
"""Which table of changes in html
9695
9796
CSI300: 0
9897
CSI100: 1
9998
:return:
10099
"""
101-
raise NotImplementedError()
100+
raise NotImplementedError("rewrite html_table_index")
102101

103102
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
104103
"""formatting the datetime in an instrument
@@ -184,12 +183,7 @@ def _parse_table(self, content: str, add_date: pd.DataFrame, remove_date: pd.Dat
184183
df = pd.DataFrame()
185184
_tmp_count = 0
186185
for _df in pd.read_html(content):
187-
if (
188-
_df.shape[-1] != 4
189-
or _df.iloc[2:,][0].str.contains(
190-
"."
191-
)[2]
192-
):
186+
if _df.shape[-1] != 4 or _df.isnull().loc(0)[0][0]:
193187
continue
194188
_tmp_count += 1
195189
if self.html_table_index + 1 > _tmp_count:
@@ -341,8 +335,8 @@ def bench_start_date(self) -> pd.Timestamp:
341335
return pd.Timestamp("2005-01-01")
342336

343337
@property
344-
def html_table_index(self):
345-
return 1
338+
def html_table_index(self) -> int:
339+
return 0
346340

347341

348342
class CSI100Index(CSIIndex):
@@ -355,8 +349,8 @@ def bench_start_date(self) -> pd.Timestamp:
355349
return pd.Timestamp("2006-05-29")
356350

357351
@property
358-
def html_table_index(self):
359-
return 2
352+
def html_table_index(self) -> int:
353+
return 1
360354

361355

362356
class CSI500Index(CSIIndex):
@@ -368,10 +362,6 @@ def index_code(self) -> str:
368362
def bench_start_date(self) -> pd.Timestamp:
369363
return pd.Timestamp("2007-01-15")
370364

371-
@property
372-
def html_table_index(self) -> int:
373-
return 0
374-
375365
def get_changes(self) -> pd.DataFrame:
376366
"""get companies changes
377367
@@ -475,5 +465,4 @@ def get_new_companies(self) -> pd.DataFrame:
475465

476466

477467
if __name__ == "__main__":
478-
get_instruments(index_name="CSI300", qlib_dir="~/.qlib/qlib_data/cn_data", method="parse_instruments")
479-
# fire.Fire(get_instruments)
468+
fire.Fire(get_instruments)

0 commit comments

Comments (0)