Refactor GRU with garage.tf.models.Model #668

Merged
ahtsan merged 3 commits into master from gru
May 22, 2019

Conversation

@ahtsan
Contributor

@ahtsan ahtsan commented May 17, 2019

Added GRU, GRUModel and CategoricalGRUPolicyWithModel.
It aims to replace the existing GRUNetwork and GRULayer,
which are based on garage.tf.core.layers.

Added test for TRPO with CategoricalGRUPolicyWithModel.

Apart from testing functionality of CategoricalGRUPolicyWithModel
in test_categorical_gru_policy_with_model.py, transitions from the
old model (CategoricalGRUPolicy) to the new model
(CategoricalGRUPolicyWithModel) are also tested in
test_categorical_gru_policy_with_model_transit.py, to make sure
they have the same API.

The existing GRU implementation in GRULayer does not exactly match the
TensorFlow implementation (which follows the original paper), and is
modified in this PR.
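For reference, the GRU update from the original paper can be sketched in NumPy. The parameter names, sizes, and the `(1 - z)` update convention below are illustrative assumptions, not the PR's code — implementations (including TensorFlow's) differ in gate ordering and update convention, which is exactly the discrepancy this PR addresses:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h_prev, params):
    """One GRU step following the original paper's equations.

    x:      (input_dim,) current input
    h_prev: (hidden_dim,) previous hidden state
    params: dict of weights W_* (hidden_dim, input_dim)
            and U_* (hidden_dim, hidden_dim)
    """
    z = sigmoid(params['W_z'] @ x + params['U_z'] @ h_prev)  # update gate
    r = sigmoid(params['W_r'] @ x + params['U_r'] @ h_prev)  # reset gate
    # Candidate state uses the reset gate on the previous state.
    h_tilde = np.tanh(params['W_h'] @ x + params['U_h'] @ (r * h_prev))
    # Paper convention: new state interpolates h_prev and the candidate.
    return (1.0 - z) * h_prev + z * h_tilde

# Tiny demo with random weights (hypothetical sizes).
rng = np.random.RandomState(0)
input_dim, hidden_dim = 3, 4
params = {k: rng.randn(hidden_dim, d)
          for k, d in [('W_z', input_dim), ('W_r', input_dim),
                       ('W_h', input_dim), ('U_z', hidden_dim),
                       ('U_r', hidden_dim), ('U_h', hidden_dim)]}
h = np.zeros(hidden_dim)
for t in range(5):  # unroll over a short sequence
    h = gru_step(rng.randn(input_dim), h, params)
print(h.shape)  # (4,)
```

Because the state is always a convex combination of the previous state and a tanh candidate, its entries stay in (-1, 1); a swapped `z`/`(1 - z)` convention changes numerics but not this bound.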

@ahtsan ahtsan requested review from CatherineSue and ryanjulian May 17, 2019 23:28
@ahtsan ahtsan requested a review from a team as a code owner May 17, 2019 23:28
@ahtsan ahtsan requested a review from nish21 May 17, 2019 23:30
Comment thread tests/garage/tf/core/test_gru.py Outdated
(1, 3, 1, 0.5, 0.5), # yapf: disable
(3, 1, 1, 0.5, 0.5), # yapf: disable
(3, 3, 1, 0.5, 0.5), # yapf: disable
(3, 3, 3, 0.5, 0.5)) # yapf: disable
Member

I think you should be able to just put one `# yapf: disable` right after the @params statement and it should ignore the whole block.

if it doesn't, you can use this form

# yapf: disable
stuff()
more_stuff()
# yapf: enable

Comment thread tests/garage/tf/core/test_gru.py Outdated
obs_inputs = np.full((self.batch_size, time_step, input_dim), 1.)
obs_input = np.full((self.batch_size, input_dim), 1.)

_input_var = tf.placeholder(
Member

Why the `_` (private) prefix?

Contributor Author

Not sure. I will remove the _

@codecov

codecov Bot commented May 18, 2019

Codecov Report

Merging #668 into master will increase coverage by 0.52%.
The diff coverage is 97.43%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #668      +/-   ##
==========================================
+ Coverage   62.73%   63.26%   +0.52%     
==========================================
  Files         164      167       +3     
  Lines        9572     9715     +143     
  Branches     1247     1256       +9     
==========================================
+ Hits         6005     6146     +141     
+ Misses       3263     3254       -9     
- Partials      304      315      +11
| Impacted Files | Coverage Δ |
|---|---|
| garage/tf/policies/categorical_lstm_policy.py | 79.66% <ø> (ø) ⬆️ |
| garage/tf/core/lstm.py | 100% <ø> (ø) ⬆️ |
| garage/experiment/local_tf_runner.py | 88.97% <ø> (ø) ⬆️ |
| garage/tf/core/network.py | 64.28% <ø> (ø) ⬆️ |
| garage/tf/models/__init__.py | 100% <100%> (ø) ⬆️ |
| garage/tf/core/gru.py | 100% <100%> (ø) |
| garage/tf/policies/__init__.py | 100% <100%> (ø) ⬆️ |
| garage/tf/models/gru_model.py | 100% <100%> (ø) |
| garage/tf/core/layers.py | 51.1% <50%> (ø) ⬆️ |
| garage/tf/policies/categorical_gru_policy.py | 79.82% <90%> (-0.18%) ⬇️ |
| ... and 12 more | |

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update a5277f4...6cfa668. Read the comment docs.

Comment thread tests/garage/tf/core/test_gru.py Outdated
input_var = tf.placeholder(
tf.float32, shape=(None, None, input_dim), name='input')
step_input_var = tf.placeholder(
tf.float32, shape=(None, input_dim), name='input')
Member

Why do the two placeholders have the same name?

Contributor Author

Will fix

Comment thread garage/tf/core/layers.py
Comment thread garage/tf/core/gru.py Outdated
name (str): Name of the variable scope.
gru_cell (tf.keras.layers.Layer): GRU cell used to generate
outputs.
all_input_var (tf.Tensor): Place holder for entire time-seried inputs.
Member

time-series

Contributor Author

Will fix

Comment thread garage/tf/core/gru.py Outdated
hidden state is trainable.

Return:
outputs (tf.Tensor): Entire time-seried outputs.
Member

do you mean time-series? please ignore if time-seried is a word.

Contributor Author

Will fix

assert full_output.shape == (self.batch_size, time_step, output_dim)

# yapf: disable
@params(
Member

Could you add a test of variable-length inputs to GRU and LSTM? It seems we don't have any tests using off-policy algorithms with recurrent policies. And since DQN doesn't set a max_path_length, I assume the episodes would have variable lengths. For recurrent policies, we should pad them. Do we pad them somewhere? Please point me to it if we do.

Contributor Author

If we pad the inputs, then we don't have to worry about variable length inputs?

Member

@CatherineSue CatherineSue May 21, 2019

Sorry just saw this.
I think I might have used an incorrect example. It is uncommon to use LSTMNetwork in DQN. Besides, during training, the shape of samples (from the replay buffer) would be (sample_size, 1, obs_dim). During sampling, the shape would be (1, time_steps, obs_dim).
If we pad them, we just need to indicate that <pad> is a condition to stop the traversal. tf.while_loop or tf.scan (which we are using) should have an argument to pass the condition (please check, I am not familiar with tf.scan).
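The pad-to-max-length idea above can be sketched in plain NumPy, independent of tf.scan: pad every episode to the longest length and carry a mask that marks real steps versus padding (the helper name and pad value are illustrative, not garage's actual API):

```python
import numpy as np

def pad_and_mask(episodes, obs_dim, pad_value=0.0):
    """Pad variable-length episodes to a common length and build a mask.

    episodes: list of arrays, each of shape (length_i, obs_dim)
    Returns padded (n, max_len, obs_dim) and mask (n, max_len),
    where mask is 1.0 for real steps and 0.0 for padding.
    """
    max_len = max(len(ep) for ep in episodes)
    n = len(episodes)
    padded = np.full((n, max_len, obs_dim), pad_value)
    mask = np.zeros((n, max_len))
    for i, ep in enumerate(episodes):
        padded[i, :len(ep)] = ep   # copy real steps
        mask[i, :len(ep)] = 1.0    # mark them as valid
    return padded, mask

# Two episodes of lengths 2 and 4 with obs_dim 3.
episodes = [np.ones((2, 3)), np.ones((4, 3))]
padded, mask = pad_and_mask(episodes, obs_dim=3)
print(padded.shape, mask.sum(axis=1))  # (2, 4, 3) [2. 4.]
```

With a mask in hand, the recurrence can run over the full padded length and the loss can simply multiply per-step terms by the mask, which avoids needing an early-exit condition inside the scan loop.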

Member

I think this problem is beyond the PR's scope. If you think we should address this later, or we don't have a situation with variable-length inputs, feel free to ignore this issue.

@ahtsan ahtsan merged commit 39b16ce into master May 22, 2019
@ahtsan ahtsan deleted the gru branch May 22, 2019 16:13