@@ -78,7 +78,6 @@ def __init__(
7878 self .optimizer = optimizer .lower ()
7979 self .loss = loss
8080 self .device = torch .device ("cuda:%d" % (GPU ) if torch .cuda .is_available () and GPU >= 0 else "cpu" )
81- self .use_gpu = torch .cuda .is_available ()
8281 self .seed = seed
8382
8483 self .logger .info (
@@ -94,7 +93,7 @@ def __init__(
9493 "\n early_stop : {}"
9594 "\n optimizer : {}"
9695 "\n loss_type : {}"
97- "\n visible_GPU : {}"
96+ "\n device : {}"
9897 "\n use_GPU : {}"
9998 "\n seed : {}" .format (
10099 d_feat ,
@@ -108,7 +107,7 @@ def __init__(
108107 early_stop ,
109108 optimizer .lower (),
110109 loss ,
111- GPU ,
110+ self . device ,
112111 self .use_gpu ,
113112 seed ,
114113 )
@@ -137,6 +136,10 @@ def __init__(
137136 self .fitted = False
138137 self .ALSTM_model .to (self .device )
139138
139+ @property
140+ def use_gpu (self ):
141+ return self .device != torch .device ("cpu" )
142+
140143 def mse (self , pred , label ):
141144 loss = (pred - label ) ** 2
142145 return torch .mean (loss )
@@ -205,12 +208,13 @@ def test_epoch(self, data_x, data_y):
205208 feature = torch .from_numpy (x_values [indices [i : i + self .batch_size ]]).float ().to (self .device )
206209 label = torch .from_numpy (y_values [indices [i : i + self .batch_size ]]).float ().to (self .device )
207210
208- pred = self .ALSTM_model (feature )
209- loss = self .loss_fn (pred , label )
210- losses .append (loss .item ())
211+ with torch .no_grad ():
212+ pred = self .ALSTM_model (feature )
213+ loss = self .loss_fn (pred , label )
214+ losses .append (loss .item ())
211215
212- score = self .metric_fn (pred , label )
213- scores .append (score .item ())
216+ score = self .metric_fn (pred , label )
217+ scores .append (score .item ())
214218
215219 return np .mean (losses ), np .mean (scores )
216220
@@ -292,10 +296,7 @@ def predict(self, dataset):
292296 x_batch = torch .from_numpy (x_values [begin :end ]).float ().to (self .device )
293297
294298 with torch .no_grad ():
295- if self .use_gpu :
296- pred = self .ALSTM_model (x_batch ).detach ().cpu ().numpy ()
297- else :
298- pred = self .ALSTM_model (x_batch ).detach ().numpy ()
299+ pred = self .ALSTM_model (x_batch ).detach ().cpu ().numpy ()
299300
300301 preds .append (pred )
301302
0 commit comments