1- from yahoo_finance import Share
21from matplotlib import pyplot as plt
32import numpy as np
43import random
54import tensorflow as tf
65import random
7-
6+ import pandas as pd
7+ pd .core .common .is_list_like = pd .api .types .is_list_like
8+ from pandas_datareader import data
9+ import datetime
10+ import requests_cache
811
912class DecisionPolicy :
1013 def select_action (self , current_state , step ):
@@ -43,7 +46,7 @@ def __init__(self, actions, input_dim):
4346 loss = tf .square (self .y - self .q )
4447 self .train_op = tf .train .AdagradOptimizer (0.01 ).minimize (loss )
4548 self .sess = tf .Session ()
46- self .sess .run (tf .initialize_all_variables ())
49+ self .sess .run (tf .global_variables_initializer ())
4750
4851 def select_action (self , current_state , step ):
4952 threshold = min (self .epsilon , step / 1000. )
@@ -108,17 +111,12 @@ def run_simulations(policy, budget, num_stocks, prices, hist):
108111 return avg , std
109112
110113
111- def get_prices (share_symbol , start_date , end_date , cache_filename = 'stock_prices.npy' ):
112- try :
113- stock_prices = np .load (cache_filename )
114- except IOError :
115- share = Share (share_symbol )
116- stock_hist = share .get_historical (start_date , end_date )
117- stock_prices = [stock_price ['Open' ] for stock_price in stock_hist ]
118- np .save (cache_filename , stock_prices )
119-
120- return stock_prices
121-
114+ def get_prices (share_symbol , start_date , end_date ):
115+ expire_after = datetime .timedelta (days = 3 )
116+ session = requests_cache .CachedSession (cache_name = 'cache' , backend = 'sqlite' , expire_after = expire_after )
117+ stock_hist = data .DataReader (share_symbol , 'iex' , start_date , end_date , session = session )
118+ open_prices = stock_hist ['open' ]
119+ return open_prices .values .tolist ()
122120
123121def plot_prices (prices ):
124122 plt .title ('Opening stock prices' )
@@ -129,7 +127,7 @@ def plot_prices(prices):
129127
130128
131129if __name__ == '__main__' :
132- prices = get_prices ('MSFT' , '1992 -07-22' , '2016 -07-22' )
130+ prices = get_prices ('MSFT' , '2013 -07-22' , '2018 -07-22' )
133131 plot_prices (prices )
134132 actions = ['Buy' , 'Sell' , 'Hold' ]
135133 hist = 200
0 commit comments