
Add tf/DQN with dueling support#582

Merged
ahtsan merged 19 commits into master from dqn
May 14, 2019
Conversation

@ahtsan
Contributor

@ahtsan ahtsan commented Mar 12, 2019

  • DQN implementation with garage.Model.
  • Currently, for pixel environments, we use a bunch of wrappers from baselines. Later we can set up a data processing pipeline to make that faster.
  • Removed the single-layer MLP in cnn. I think it makes more sense to separate them, so now cnn will return the flattened output. Also modified the unit test for cnn accordingly.
  • Used self.models to store all the models in QFunction. Sometimes there are multiple models and we need something to keep track of them. I assume the models are stored in order, so the input must be self.models[0].input and the output must be self.models[-1].output.
  • Added corresponding functions and properties in QFunction to make references easier, e.g. def q_vals() and the input property.
  • Clone method in QFunction to enable object copying (I think there should be better ways). Since all the necessary operations are done in object construction, we can simply create a new object with a different name (because we want the new object to have a different variable_scope).

Will post benchmark result soon, fixing the scaling.
Will add more tests later.
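The clone-by-reconstruction idea from the description above can be sketched in plain Python. This is a minimal sketch of the pattern, not garage's actual API: the class name, constructor arguments, and the `'dummy-spec'` value are illustrative.

```python
# Hypothetical sketch of clone-by-reconstruction: since all setup happens in
# __init__, cloning just means re-running the constructor with a new name.
class DiscreteMLPQFunction:
    def __init__(self, env_spec, hidden_sizes=(64, 64),
                 name='DiscreteMLPQFunction'):
        self._env_spec = env_spec
        self._hidden_sizes = hidden_sizes
        self.name = name  # used as the variable scope, so clones must differ

    def clone(self, name):
        # Copies the configuration only; learned parameters are NOT copied.
        return self.__class__(
            env_spec=self._env_spec,
            hidden_sizes=self._hidden_sizes,
            name=name)


qf = DiscreteMLPQFunction(env_spec='dummy-spec')
target_qf = qf.clone(name='target_qf')
print(target_qf.name)  # → target_qf
```

A target network for DQN would then be created as `qf.clone(name='target_qf')`, giving it a distinct variable scope while sharing the same architecture.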

@ahtsan ahtsan requested a review from a team as a code owner March 12, 2019 19:46
obs_ph = tf.placeholder(tf.float32, (None, ) + obs_dim, name="obs")

self.model.build(obs_ph)
with tf.variable_scope(self._variable_scope):
Member

with self._variable_scope

Contributor Author

@ahtsan ahtsan Mar 12, 2019

self._variable_scope is a VariableScope object, so we have to do

with tf.variable_scope(VariableScopeObject):

to re-enter the scope.
(see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/variable_scope.py#L1804).
Also, VariableScope is picklable; variable_scope is not.

out = model.build(out)

def q_vals(self):
return self.models[-1].networks['default'].outputs
Member

why not use a dict?

Contributor Author

Do you mean storing q_vals in a dict, or storing the models in a dict?

Member

The models. If the order of insertion changes at any time, your code will break.

You can also just keep the models in instance variables (self.model1, etc.) and also add them to the list self.models...
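The reviewer's suggestion — named instance attributes plus an ordered `self.models` list — can be sketched in plain Python. `Model`, `QFunction`, and the build logic below are stand-ins for illustration, not garage's real classes:

```python
# Hypothetical sketch: keep each model in a named attribute AND in an ordered
# list, so code gets stable references instead of relying on insertion order.
class Model:
    def __init__(self, name):
        self.name = name

    def build(self, inputs):
        # Stand-in for the real graph-building step: record which model ran.
        return inputs + [self.name]


class QFunction:
    def __init__(self):
        self.cnn_model = Model('cnn')
        self.mlp_model = Model('mlp')
        # The list preserves build order; the attributes give direct access.
        self.models = [self.cnn_model, self.mlp_model]

    def build(self, inputs):
        out = inputs
        for model in self.models:
            out = model.build(out)
        return out


qf = QFunction()
print(qf.build([]))  # → ['cnn', 'mlp']
```

With this layout, `qf.models[0].input` / `qf.models[-1].output` style access still works, but code that needs a specific model can use `qf.cnn_model` without assuming list order.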

Comment thread examples/tf/dqn_breakout.py Outdated
strides=(4, 2, 1),
dueling=False)

policy = DiscreteQfDerivedPolicy(env_spec=env, qf=qf)
Member

env.spec

algo.train(sess)


run_experiment(
Member

Please update this with LocalRunner

Comment thread examples/tf/dqn_cartpole.py Outdated
num_timesteps=num_timesteps,
qf_lr=1e-4,
discount=1.0,
min_buffer_size=1e3,
Member

Have you added an int() call to min_buffer_size?

Comment thread examples/tf/dqn_cartpole.py Outdated

replay_buffer = SimpleReplayBuffer(
env_spec=env.spec,
size_in_transitions=int(10000),
Member

int(1e4) or 10000

Comment thread examples/tf/dqn_cartpole.py Outdated
qf = DiscreteMLPQFunction(
env_spec=env.spec, hidden_sizes=(64, 64), dueling=False)

policy = DiscreteQfDerivedPolicy(env_spec=env, qf=qf)
Member

env.spec

Comment thread garage/tf/algos/dqn.py Outdated
qf_lr=0.001,
qf_optimizer=tf.train.AdamOptimizer,
discount=1.0,
name=None,
Member

name='DQN'

Comment thread garage/tf/models/cnn_model.py Outdated
num_filters,
strides,
name=None,
padding="SAME",
Member

I think it's better to use single quotes in new files.

filter_dims,
num_filters,
strides,
name=None,
Member

name='CNNModel'

Contributor Author

Model has a default name if name=None, which will be the class name. How should we do this? If we enforce that all derived model classes have a name, we don't need the default name at all.

Member

@ryanjulian ryanjulian Mar 12, 2019

Seems like you have already violated the Model interface by making name a kwarg:

def __init__(self, name):

Contributor Author

Oh yes, so a name is actually required.

Comment thread garage/tf/q_functions/discrete_cnn_q_function.py
for model in self.models:
out = model.build(out)

def q_vals(self):
Member

@property
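The reviewer is suggesting that `q_vals` be exposed as a read-only attribute rather than a method call. A minimal sketch of that pattern (the class and stored value here are illustrative, not garage's implementation):

```python
# Hypothetical sketch of exposing q_vals via @property: callers write
# qf.q_vals instead of qf.q_vals(), and the attribute is read-only.
class DiscreteQFunction:
    def __init__(self, outputs):
        self._outputs = outputs

    @property
    def q_vals(self):
        # In the real code this would return the network's output tensor.
        return self._outputs


qf = DiscreteQFunction(outputs=[0.1, 0.9])
print(qf.q_vals)  # → [0.1, 0.9]
```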



class TestDQN(TfGraphTestCase):
def test_dqn_cartpole(self):
Member

How long does it take to run this test?

Contributor Author

around 30s.

@CatherineSue CatherineSue changed the title DQN with dueling support Add tf/DQN with dueling support Mar 12, 2019
Comment thread garage/tf/algos/dqn.py Outdated
episode_rewards.append(0.)

for itr in range(self.num_timesteps):
with logger.prefix('Iteration #%d | ' % itr):
Member

'Timestep' sounds more appropriate to me.

Comment thread garage/tf/algos/dqn.py Outdated
self._dueling = dueling

obs_dim = self._env_spec.observation_space.shape
action_dim = env_spec.action_space.flat_dim
Member

self._env_spec.action_space.flat_dim

@@ -0,0 +1,190 @@
"""Discrete MLP QFunction."""
Member

We should state that this CNN network actually supports CNN2MLP. This is not the same as the CNN* primitive in the current garage. If you think the naming is OK, please add more details to this documentation.

out = model.build(out, name=name)
return out

def clone(self, name):
Member

Should clone interface be in the base class?

Contributor Author

Yes, it will eventually be an interface in Model too.

Member

Please update it into the base class.

Comment thread tests/garage/tf/algos/test_dqn.py Outdated
size_in_transitions=int(5000),
time_horizon=max_path_length)
qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
policy = DiscreteQfDerivedPolicy(env_spec=env, qf=qf)
Member

env_spec

Comment thread garage/tf/algos/dqn.py
Comment thread tests/garage/tf/algos/test_dqn.py Outdated
@@ -0,0 +1,55 @@
"""
This script creates a test that fails when garage.tf.algos.DDPG performance is
Member

garage.tf.algos.DQN

@ahtsan
Contributor Author

ahtsan commented Mar 29, 2019

Benchmark with 1M timesteps and random seed=26
[Benchmark plots: SpaceInvadersNoFrameskip-v4, SeaquestNoFrameskip-v4, QbertNoFrameskip-v4, PongNoFrameskip-v4, EnduroNoFrameskip-v4, BreakoutNoFrameskip-v4, BeamRiderNoFrameskip-v4]

Member

@ryanjulian ryanjulian left a comment

Excellent work. +1000

Please address other reviewers' comments, rebase, and submit!

@ahtsan ahtsan force-pushed the dqn branch 2 times, most recently from 3890c4d to 5b95308 Compare May 1, 2019 16:04
@codecov

codecov Bot commented May 1, 2019

Codecov Report

Merging #582 into master will increase coverage by 1.1%.
The diff coverage is 97%.

Impacted file tree graph

@@            Coverage Diff            @@
##           master     #582     +/-   ##
=========================================
+ Coverage   62.25%   63.35%   +1.1%     
=========================================
  Files         163      169      +6     
  Lines        9532     9786    +254     
  Branches     1267     1284     +17     
=========================================
+ Hits         5934     6200    +266     
+ Misses       3288     3269     -19     
- Partials      310      317      +7
Impacted Files Coverage Δ
garage/tf/core/cnn.py 100% <ø> (+3.12%) ⬆️
garage/tf/models/mlp_model.py 100% <ø> (ø) ⬆️
garage/tf/misc/tensor_utils.py 75.73% <100%> (+2.13%) ⬆️
garage/tf/algos/__init__.py 100% <100%> (ø) ⬆️
garage/tf/q_functions/discrete_mlp_q_function.py 100% <100%> (+57.89%) ⬆️
garage/tf/models/__init__.py 100% <100%> (ø) ⬆️
garage/tf/models/cnn_model_max_pooling.py 100% <100%> (ø)
garage/envs/wrappers/__init__.py 100% <100%> (ø) ⬆️
garage/tf/algos/ddpg.py 78.76% <100%> (-0.98%) ⬇️
garage/envs/wrappers/fire_reset.py 100% <100%> (ø) ⬆️
... and 32 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update ee6e96c...ca08e16. Read the comment docs.

@ahtsan
Contributor Author

ahtsan commented May 1, 2019

Fixed all comments above. I will add more tests.

@ahtsan ahtsan force-pushed the dqn branch 2 times, most recently from b2983a1 to b5d8ff4 Compare May 9, 2019 07:06
Comment thread garage/tf/q_functions/base2.py
@ahtsan ahtsan force-pushed the dqn branch 3 times, most recently from 2dac995 to 4dc7abb Compare May 11, 2019 04:56
Comment thread garage/tf/algos/dqn.py Outdated
Comment thread garage/tf/algos/dqn.py Outdated
Comment thread garage/tf/misc/tensor_utils.py Outdated
Comment thread garage/tf/models/cnn_model_max_pooling.py Outdated
Comment thread garage/envs/wrappers/atari_env.py
Comment thread garage/misc/tensor_utils.py
Comment thread garage/tf/algos/dqn.py
Comment thread garage/tf/q_functions/base2.py Outdated
Comment thread garage/tf/q_functions/base2.py Outdated
# Select which episodes to use
time_horizon = buffer["action"].shape[1]
rollout_batch_size = buffer["action"].shape[0]
time_horizon = buffer['action'].shape[1]
Member

Is it still valid to have time_horizon since the replay buffer now stores variable-length episodes? It may be beyond the scope of this PR; I just think the interface seems a bit confusing now.

Contributor Author

@ahtsan ahtsan May 13, 2019

Is it because there is no max_path_length anymore in off policy algos?

Comment thread tests/garage/misc/test_tensor_utils.py
Comment thread garage/tf/algos/dqn.py
Member

@CatherineSue CatherineSue left a comment

I only have some minor comments.

@ahtsan ahtsan merged commit b96d154 into master May 14, 2019
@ahtsan ahtsan deleted the dqn branch May 14, 2019 05:30
@ryanjulian ryanjulian mentioned this pull request May 14, 2019
nish21 pushed a commit that referenced this pull request May 14, 2019
DQN implementation with garage.Model.

This is the first algorithm for pixel environments. This PR
adds the algorithm as well as the models, primitives and
environment wrappers required for training in pixel environments.

* Models
  * MLPDuelingModel
  * CNNModelWithMaxPooling
* Primitives
  * QFunction2 (base class, without parameterized)
  * DiscreteCNNQFunction
* Wrappers
  * AtariEnvWrapper (needed when using env wrappers from baselines)
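The MLPDuelingModel listed above splits the network into a state-value stream and an advantage stream. The standard dueling aggregation (from the dueling DQN architecture) recombines them by subtracting the mean advantage; here is a sketch of that arithmetic in plain Python, illustrating the idea rather than MLPDuelingModel's exact implementation:

```python
# Sketch of dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a').
# Subtracting the mean advantage makes the V/A decomposition identifiable.
def dueling_q_values(state_value, advantages):
    mean_adv = sum(advantages) / len(advantages)
    return [state_value + a - mean_adv for a in advantages]


q = dueling_q_values(state_value=1.0, advantages=[0.5, -0.5])
print(q)  # → [1.5, 0.5]
```

Note that the mean of the resulting Q-values equals V(s), which is exactly what the mean-subtraction is designed to enforce.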

The eviction policy of the replay buffer used to be random. To
make experiments deterministic, it is changed to
First In First Out (FIFO). This was proven to be necessary in
order to achieve better results in complex environments for DQN.
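FIFO eviction as described above can be sketched with a bounded deque; this is a minimal illustration of the eviction policy, not garage's SimpleReplayBuffer API:

```python
from collections import deque


# Minimal sketch of FIFO eviction: once the buffer is full, the OLDEST
# transition is dropped first, unlike random eviction, so runs are repeatable.
class FIFOReplayBuffer:
    def __init__(self, size_in_transitions):
        # deque with maxlen automatically evicts from the left (oldest) end.
        self._transitions = deque(maxlen=size_in_transitions)

    def add_transition(self, transition):
        self._transitions.append(transition)

    def __len__(self):
        return len(self._transitions)


buf = FIFOReplayBuffer(size_in_transitions=3)
for t in range(5):
    buf.add_transition(t)
print(list(buf._transitions))  # → [2, 3, 4]
```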

Added corresponding properties in QFunction to make reference
easier, e.g. q_vals.

Added clone method in QFunction to enable copying configuration,
not including the parameters. Since all the necessary operations
will be done in object construction, we can simply create a new
object with a different name (because we want the new object
to have a different name for variable scope).
