Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/tf/cluster_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def run_task(v):
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(
env=env,
env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=100,
Expand Down
2 changes: 1 addition & 1 deletion examples/tf/cluster_gym_mujoco_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def run_task(vv):
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(
env=env,
env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=100,
Expand Down
2 changes: 1 addition & 1 deletion examples/tf/ddpg_pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def run_task(*_):
env_spec=env.spec, size_in_transitions=int(1e6), time_horizon=100)

ddpg = DDPG(
env,
env_spec=env.spec,
policy=policy,
policy_lr=1e-4,
qf_lr=1e-3,
Expand Down
2 changes: 1 addition & 1 deletion examples/tf/erwr_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def run_task(*_):
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = ERWR(
env=env,
env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=100,
Expand Down
4 changes: 2 additions & 2 deletions examples/tf/her_ddpg_fetchreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def run_task(*_):
reward_fun=env.compute_reward)

ddpg = DDPG(
env,
env_spec=env.spec,
policy=policy,
policy_lr=1e-3,
qf_lr=1e-3,
qf=qf,
replay_buffer=replay_buffer,
plot=False,
target_update_tau=0.05,
n_epoch_cycles=20,
max_path_length=100,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this line was deleted in the off-policy algorithm?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, no. I keep n_epoch_cycles so that an algorithm can compute the epoch as iteration/n_epoch_cycles. This is a little awkward, but it is required for DDPG.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand. What I meant was: isn't max_path_length=100 unnecessary here?

n_train_steps=40,
discount=0.9,
Expand Down
2 changes: 1 addition & 1 deletion examples/tf/ppo_pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def run_task(*_):
)

algo = PPO(
env=env,
env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=100,
Expand Down
2 changes: 1 addition & 1 deletion examples/tf/reps_gym_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def run_task(*_):
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = REPS(
env=env,
env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=100,
Expand Down
2 changes: 1 addition & 1 deletion examples/tf/trpo_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def run_task(*_):
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(
env=env,
env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=100,
Expand Down
2 changes: 1 addition & 1 deletion examples/tf/trpo_cartpole_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def run_task(*_):
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(
env=env,
env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=max_path_length,
Expand Down
2 changes: 1 addition & 1 deletion examples/tf/trpo_cartpole_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def run_task(*_):
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(
env=env,
env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=100,
Expand Down
2 changes: 1 addition & 1 deletion examples/tf/trpo_gym_tf_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def run_task(*_):
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(
env=env,
env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=200,
Expand Down
2 changes: 1 addition & 1 deletion examples/tf/trpo_swimmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def run_task(*_):
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(
env=env,
env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=500,
Expand Down
2 changes: 1 addition & 1 deletion examples/tf/vpg_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def run_task(*_):
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = VPG(
env=env,
env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=100,
Expand Down
20 changes: 12 additions & 8 deletions garage/experiment/local_tf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,20 @@ def setup(self, algo, env, sampler_cls=None, sampler_args=None):
from garage.tf.samplers import OffPolicyVectorizedSampler
sampler_cls = OffPolicyVectorizedSampler

self.sampler = sampler_cls(algo, **sampler_args)
self.sampler = sampler_cls(algo, env, **sampler_args)

self.initialize_tf_vars()
self.has_setup = True

def initialize_tf_vars(self):
"""Initialize all uninitialized variables in session."""
self.sess.run(
tf.variables_initializer([
v for v in tf.global_variables()
if v.name.split(':')[0] in str(
self.sess.run(tf.report_uninitialized_variables()))
]))
with tf.name_scope("initialize_tf_vars"):
self.sess.run(
tf.variables_initializer([
v for v in tf.global_variables()
if v.name.split(':')[0] in str(
self.sess.run(tf.report_uninitialized_variables()))
]))

def start_worker(self):
"""Start Plotter and Sampler workers."""
Expand Down Expand Up @@ -173,8 +174,11 @@ def save_snapshot(self, itr, paths=None):
paths: Batch of samples after preprocessed.

"""
assert self.has_setup

logger.log("Saving snapshot...")
params = self.algo.get_itr_snapshot(itr, paths)
params = self.algo.get_itr_snapshot(itr)
params['env'] = self.env
Comment thread
naeioi marked this conversation as resolved.
if paths:
params["paths"] = paths
Comment thread
ryanjulian marked this conversation as resolved.
logger.save_itr_params(itr, params)
Expand Down
3 changes: 2 additions & 1 deletion garage/sampler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ def shutdown_worker(self):


class BaseSampler(Sampler):
def __init__(self, algo):
def __init__(self, algo, env):
"""
:type algo: BatchPolopt
"""
self.algo = algo
self.env = env

def process_samples(self, itr, paths):
baselines = []
Expand Down
7 changes: 4 additions & 3 deletions garage/tf/algos/batch_polopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class BatchPolopt(RLAlgorithm):
"""

def __init__(self,
env,
env_spec,
policy,
baseline,
scope=None,
Expand All @@ -24,7 +24,8 @@ def __init__(self,
fixed_horizon=False,
**kwargs):
"""
:param env: Environment
:param env_spec: Environment specification.
:type env_spec: EnvSpec
:param policy: Policy
:type policy: Policy
:param baseline: Baseline
Expand All @@ -41,7 +42,7 @@ def __init__(self,
advantages will be standardized before shifting.
:return:
"""
self.env = env
self.env_spec = env_spec
self.policy = policy
self.baseline = baseline
self.scope = scope
Expand Down
30 changes: 10 additions & 20 deletions garage/tf/algos/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class DDPG(OffPolicyRLAlgorithm):
"""

def __init__(self,
env,
env_spec,
replay_buffer,
target_update_tau=0.01,
policy_lr=1e-4,
Expand All @@ -47,7 +47,7 @@ def __init__(self,
Construct class.

Args:
env(): Environment.
env_spec(EnvSpec): Environment specification.
target_update_tau(float): Interpolation parameter for doing the
soft target update.
discount(float): Discount factor for the cumulative return.
Expand All @@ -65,7 +65,7 @@ def __init__(self,
max_action(float): Maximum action magnitude.
name(str): Name of the algorithm shown in computation graph.
"""
action_bound = env.action_space.high
action_bound = env_spec.action_space.high
self.max_action = action_bound if max_action is None else max_action
self.tau = target_update_tau
self.policy_lr = policy_lr
Expand All @@ -86,7 +86,7 @@ def __init__(self,
self.epoch_qs = []

super(DDPG, self).__init__(
env=env,
env_spec=env_spec,
replay_buffer=replay_buffer,
use_target=True,
discount=discount,
Expand Down Expand Up @@ -119,18 +119,18 @@ def init_opt(self):

with tf.name_scope("inputs"):
if self.input_include_goal:
obs_dim = self.env.observation_space.flat_dim_with_keys(
["observation", "desired_goal"])
obs_dim = self.env_spec.observation_space.\
flat_dim_with_keys(["observation", "desired_goal"])
else:
obs_dim = self.env.observation_space.flat_dim
obs_dim = self.env_spec.observation_space.flat_dim
y = tf.placeholder(tf.float32, shape=(None, 1), name="input_y")
obs = tf.placeholder(
tf.float32,
shape=(None, obs_dim),
name="input_observation")
actions = tf.placeholder(
tf.float32,
shape=(None, self.env.action_space.flat_dim),
shape=(None, self.env_spec.action_space.flat_dim),
name="input_action")

# Set up policy training function
Expand Down Expand Up @@ -196,10 +196,6 @@ def train_once(self, itr, paths):

if itr % self.n_epoch_cycles == 0:
logger.log("Training finished")
logger.log("Saving snapshot #{}".format(int(epoch)))
params = self.get_itr_snapshot(epoch, paths)
logger.save_itr_params(epoch, params)
logger.log("Saved")
if self.evaluate:
logger.record_tabular('Epoch', epoch)
logger.record_tabular('AverageReturn',
Expand Down Expand Up @@ -231,12 +227,6 @@ def train_once(self, itr, paths):
self.epoch_ys = []
self.epoch_qs = []

if self.plot:
self.plotter.update_plot(self.policy, self.max_path_length)
if self.pause_for_plot:
input("Plotting evaluation run: Press Enter to "
"continue...")

self.success_history.clear()

return last_average_return
Expand Down Expand Up @@ -289,8 +279,8 @@ def optimize_policy(self, itr, samples_data):
return qval_loss, ys, qval, action_loss

@overrides
def get_itr_snapshot(self, itr, samples_data):
return dict(itr=itr, policy=self.policy, env=self.env)
def get_itr_snapshot(self, itr):
return dict(itr=itr, policy=self.policy)


def get_target_ops(variables, target_variables, tau):
Expand Down
3 changes: 1 addition & 2 deletions garage/tf/algos/npo.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,11 @@ def optimize_policy(self, itr, samples_data):
self._fit_baseline(samples_data)

@overrides
def get_itr_snapshot(self, itr, samples_data):
def get_itr_snapshot(self, itr):
return dict(
itr=itr,
policy=self.policy,
baseline=self.baseline,
env=self.env,
)

def _build_inputs(self):
Expand Down
22 changes: 2 additions & 20 deletions garage/tf/algos/off_policy_rl_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,19 @@
Off-policy algorithms such as DQN, DDPG can inherit from it.
"""
from garage.algos import RLAlgorithm
from garage.tf.samplers import BatchSampler
from garage.tf.samplers import OffPolicyVectorizedSampler


class OffPolicyRLAlgorithm(RLAlgorithm):
"""This class implements OffPolicyRLAlgorithm."""

def __init__(
self,
env,
env_spec,
policy,
qf,
replay_buffer,
use_target=False,
discount=0.99,
n_epochs=500,
n_epoch_cycles=20,
max_path_length=None,
n_train_steps=50,
Expand All @@ -29,15 +26,10 @@ def __init__(
reward_scale=1.,
input_include_goal=False,
smooth_return=True,
sampler_cls=None,
sampler_args=None,
force_batch_sampler=False,
plot=False,
pause_for_plot=False,
exploration_strategy=None,
):
"""Construct an OffPolicyRLAlgorithm class."""
self.env = env
self.env_spec = env_spec
self.policy = policy
self.qf = qf
self.replay_buffer = replay_buffer
Expand All @@ -52,17 +44,7 @@ def __init__(
self.evaluate = False
self.input_include_goal = input_include_goal
self.smooth_return = smooth_return
if sampler_cls is None:
if policy.vectorized and not force_batch_sampler:
sampler_cls = OffPolicyVectorizedSampler
else:
sampler_cls = BatchSampler
if sampler_args is None:
sampler_args = dict()
self.sampler = sampler_cls(self, **sampler_args)
self.max_path_length = max_path_length
self.plot = plot
self.pause_for_plot = pause_for_plot
self.es = exploration_strategy
self.init_opt()

Expand Down
Loading