
CategoricalConvPolicy with model#618

Merged
ahtsan merged 9 commits into master from categorical_conv_policy_with_model
Apr 25, 2019

Conversation

@ahtsan
Contributor

@ahtsan ahtsan commented Apr 11, 2019

Time to work on CNN. This PR does the following:

  • Fix some small issues in cnn and remove the dense layer from it. Instead of having a "CNNnMLP", it's more reasonable to just do CNN -> MLP.
  • Introduce the notion of self.add_model() and self.build_models() in policy. This is needed for stacking multiple models. In the future, when we eventually make the policy itself a Model, we will also put this notion into the rest of the models.
  • And of course, categorical_conv_policy with model, and tests.
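The stacking idea above can be sketched in plain Python. This is a hedged illustration only (no TensorFlow; Model, Policy, and the string-composing build() are stand-ins, not the actual garage API):

```python
# Toy version of the add_model() / build_models() notion: a policy keeps
# an ordered list of models and chains them so each model's output feeds
# the next. A real garage Model would build a TF subgraph in build();
# here build() just records the composition as a string.

class Model:
    def __init__(self, name):
        self.name = name

    def build(self, input_var, name=None):
        # Record "this model applied to that input".
        return '{}({})'.format(self.name, input_var)

class Policy:
    def __init__(self):
        self._models = []

    def add_model(self, model):
        self._models.append(model)

    def build_models(self, input_var, name=None):
        out = input_var
        for model in self._models:
            out = model.build(out, name=name)
        return out

policy = Policy()
policy.add_model(Model('CNN'))
policy.add_model(Model('MLP'))
print(policy.build_models('obs'))  # → MLP(CNN(obs))
```

This is exactly the CNN -> MLP composition the first bullet describes: the CNN's output tensor becomes the MLP's input.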

@ahtsan ahtsan requested a review from a team as a code owner April 11, 2019 22:02
@codecov

codecov Bot commented Apr 11, 2019

Codecov Report

Merging #618 into master will increase coverage by 0.55%.
The diff coverage is 90.47%.


@@            Coverage Diff             @@
##           master     #618      +/-   ##
==========================================
+ Coverage   60.63%   61.18%   +0.55%     
==========================================
  Files         156      159       +3     
  Lines        9069     9159      +90     
  Branches     1241     1242       +1     
==========================================
+ Hits         5499     5604     +105     
+ Misses       3260     3238      -22     
- Partials      310      317       +7
Impacted Files Coverage Δ
...e/tf/policies/categorical_mlp_policy_with_model.py 95.91% <ø> (ø) ⬆️
garage/tf/core/parameter.py 100% <ø> (ø) ⬆️
...tf/regressors/gaussian_mlp_regressor_with_model.py 100% <ø> (ø) ⬆️
...tf/policies/deterministic_mlp_policy_with_model.py 92.85% <ø> (ø) ⬆️
...rage/tf/policies/gaussian_mlp_policy_with_model.py 100% <ø> (ø) ⬆️
garage/tf/policies/discrete_qf_derived_policy.py 96% <ø> (ø) ⬆️
garage/tf/core/mlp.py 100% <ø> (ø) ⬆️
garage/tf/q_functions/discrete_mlp_q_function.py 42.1% <0%> (ø) ⬆️
garage/tf/policies/categorical_mlp_policy.py 96% <100%> (ø) ⬆️
garage/tf/models/__init__.py 100% <100%> (ø) ⬆️
... and 23 more

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 5b8cbc6...a444a91.

Comment thread garage/tf/policies/categorical_conv_policy_with_model.py
Comment thread garage/tf/policies/base2.py Outdated
out = input_var
for model in self._models[:-1]:
    out = model.build(out, name=name)
self.model = self._models[-1]
Member

could you remind me why the last model is self.model?

Contributor Author

This is actually a temp-fix from my previous implementation. We don't need self.model anymore, since it has been replaced by the list of models. Nice catch.

Comment thread garage/tf/core/cnn.py Outdated
"""
strides = [1, stride, stride, 1]

if padding not in ['SAME', 'VALID']:
Member

i think TensorFlow would also throw a ValueError. Any reason you want to raise the error here?

Contributor Author

I was not aware of that. Then I think it's fine to let TensorFlow handle it.

Comment thread garage/tf/models/cnn_model.py Outdated
CNN Model.

Args:
filter_dims: Dimension of the filters.
Member

Please add types to these parameters

e.g. filter_dims (tuple[int]): ...
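For reference, the requested convention looks like this. A hedged sketch only: the signature below is hypothetical, not the actual garage/tf/core/cnn.py function.

```python
# Google-style docstring with a type in parentheses after each
# parameter name, as the reviewer asks for. The stub body does nothing;
# only the docstring format is the point here.

def cnn(input_var, filter_dims, num_filters, strides, padding='SAME'):
    """Build a CNN (illustrative stub, not garage code).

    Args:
        input_var (tf.Tensor): Input tensor, in NHWC format.
        filter_dims (tuple[int]): Dimension of the filters. For example,
            (3, 3) means two convolutional layers, each with 3x3 filters.
        num_filters (tuple[int]): Number of filters per conv layer.
        strides (tuple[int]): Stride of the sliding window per layer.
        padding (str): Padding algorithm, either 'SAME' or 'VALID'.
    """
    pass
```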

Member

@ryanjulian ryanjulian left a comment

This LGTM mostly. Other than basic comments, please resolve:

  • Whether the inputs/outputs API should be on Policy or Model
  • The question of the add_model API for policy -- can't this just be a simple derived model class instead?

Comment thread garage/tf/policies/base2.py Outdated

def build_models(self, input_var, name=None):
    out = input_var
    for model in self._models:
Member

what if my models are not sequential?

Perhaps instead you can define a model class Sequential

class Sequential(Model):
    
    def __init__(self, *models):
        self._models = models

    def _build(self, input_var, name=None):
        out = input_var
        for model in self._models:
            out = model.build(out, name=name)
        
        return out

Contributor Author

That's a good idea.
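For illustration, here is the suggested Sequential pattern in action as a plain-Python sketch. The toy Model below stands in for garage.tf.models.Model, whose build() really constructs a TensorFlow subgraph; everything here is illustrative, not the merged implementation:

```python
# A minimal Sequential: composing models by feeding each one's output
# into the next, with no assumption baked into the policy itself.

class Model:
    def __init__(self, name):
        self.name = name

    def build(self, input_var, name=None):
        # Stand-in for graph construction: compose names as a string.
        return '{}({})'.format(self.name, input_var)

class Sequential(Model):
    """Chains models so each model's output feeds the next."""

    def __init__(self, *models):
        super().__init__('Sequential')
        self._models = models

    def build(self, input_var, name=None):
        out = input_var
        for model in self._models:
            out = model.build(out, name=name)
        return out

pipeline = Sequential(Model('CNN'), Model('MLP'))
print(pipeline.build('obs'))  # → MLP(CNN(obs))
```

The advantage over build_models() on the policy is that non-sequential topologies simply use a different composite Model, leaving the policy unchanged.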

Comment thread garage/tf/policies/base2.py Outdated
return self._models[0].networks['default'].input

@property
def outputs(self):
Member

maybe this should just be an API on Model instead?

Contributor Author

@ahtsan ahtsan Apr 12, 2019

Yes. We should add this into Model. Something like

class Model(...):
    ...
    @property
    def input(self):
        return self.networks['default'].input

    @property
    def output(self):
        return self.networks['default'].output

and for the Sequential model, we can override as

class Sequential(Model):
    ...
    @property
    def input(self):
        return self._models[0].networks['default'].input

    @property
    def output(self):
        return self._models[-1].networks['default'].output

It only works with akro.tf.Discrete action space.

Args:
env_spec: Environment specification.
Member

please include types in docstrings

@overrides
def get_action(self, observation):
    """Return a single action."""
    flat_obs = self.observation_space.flatten(observation)
Member

does this flatten the 2D image, or only the batch?

Contributor Author

@ahtsan ahtsan Apr 12, 2019

currently it flattens the 2D image, since we are doing self.obs_dim = env_spec.observation_space.flat_dim in the policy.

I actually think this is a mistake; it should not flatten the observation here. For example, in a pixel environment we want to pass the original image input with shape (w, h, c) to the policy. Therefore, we should do self.obs_dim = env_spec.observation_space.shape instead.

This was missed because the CNNModel was mocked out.

Member

can you fix it?
i don't think the image should be flattened. that seems wrong.
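The shape-vs-flat_dim distinction under discussion can be illustrated in plain Python. ObservationSpace below is a hypothetical stand-in, not the real akro/garage class:

```python
# Flattening a (w, h, c) image destroys the spatial structure a CNN
# needs, so a conv policy should keep observation_space.shape rather
# than observation_space.flat_dim.

class ObservationSpace:
    """Stand-in for an image observation space (illustrative only)."""

    def __init__(self, shape):
        self.shape = shape  # e.g. (width, height, channels)

    @property
    def flat_dim(self):
        # Total number of scalar entries once the image is flattened.
        dim = 1
        for d in self.shape:
            dim *= d
        return dim

obs_space = ObservationSpace((32, 32, 3))

# What the MLP policies do -- fine for dense layers:
mlp_obs_dim = obs_space.flat_dim   # 3072, spatial structure lost

# What a conv policy should use -- keeps (w, h, c) intact:
conv_obs_dim = obs_space.shape     # (32, 32, 3)

print(mlp_obs_dim, conv_obs_dim)
```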

Comment thread garage/tf/core/cnn.py
of intermediate dense layer(s).
hidden_b_init: Initializer function for the bias
of intermediate dense layer(s).
output_nonlinearity: Activation function for
Member

please take a moment to add types to this docstring

Comment thread garage/tf/core/cnn.py Outdated
pool_stride: The stride of the pooling layer(s).
pool_shapes: Dimension of the pooling layer(s).
pool_strides: The strides of the pooling layer(s).
padding: The type of padding algorithm to use, from "SAME", "VALID".
Member

single quotes

Comment thread garage/tf/models/cnn_model.py Outdated
num_filters=self._num_filters,
strides=self._strides,
padding=self._padding,
name="cnn")
Member

single quotes

Comment thread garage/tf/policies/__init__.py Outdated
__all__ = [
"Policy",
"StochasticPolicy",
"CategoricalConvPolicy",
Member

please take a moment to replace all these with single quotes (when you visit a file)

Comment thread garage/tf/core/cnn.py Outdated
hidden_w_init=tf.glorot_uniform_initializer(),
hidden_b_init=tf.zeros_initializer()):
"""
CNN model. Based on 'NHWC' data format: [batch, height, width, channel].
Member

i don't think it's a model

Member

@ryanjulian ryanjulian left a comment

LGTM. please submit once docstrings are updated. see my suggestion about how to name inner models in a Sequential

Comment thread garage/tf/models/sequential.py Outdated
Sequential Model.

Args:
name: Variable scope of the Sequential model.
Member

types, please

Comment thread garage/tf/models/sequential.py
Comment thread garage/tf/core/cnn.py Outdated
CNN. Based on 'NHWC' data format: [batch, height, width, channel].

Args:
input_var: Input tf.Tensor to the CNN.
Member

types

Comment thread garage/tf/core/cnn.py Outdated
CNN model. Based on 'NHWC' data format: [batch, height, width, channel].

Args:
input_var: Input tf.Tensor to the CNN.
Member

still missing types

Comment thread garage/tf/core/cnn.py Outdated
pool_strides(tuple[int]): The strides of the pooling layer(s). For
example, (2, 2) means that all the pooling layers have
strides (2, 2).
padding: The type of padding algorithm to use,
Member

please provide a type

Comment thread garage/tf/models/base.py Outdated
inputs: Tensor input(s), recommended to be position arguments, e.g.
def build(self, state_input=None, action_input=None, name=None).
It would be usually same as the inputs in build().
name: Variable scope of the inner model, if exist.
Member

please update the docstrings with types

Comment thread garage/tf/models/base.py
def inputs(self):
    return self.networks['default'].inputs

@property
Member

you should document these new properties with docstrings

Comment thread garage/tf/core/mlp.py
Comment thread garage/tf/models/cnn_model.py Outdated
strides(tuple[int]): The stride of the sliding window. For example,
(1, 2) means there are two convolutional layers. The stride of the
filter for first layer is 1 and that of the second layer is 2.
name: Variable scope of the cnn model.
Member

please provide a type for every parameter

return ['sample', 'mean', 'log_std', 'std_param', 'dist']

def _build(self, state_input):
def _build(self, state_input, name=None):
Member

please take a moment to update the docstrings here with types.

self._layer_normalization = layer_normalization

def _build(self, state_input):
def _build(self, state_input, name=None):
Member

please take a moment to update the docstring here with types

self._name = name
self._env_spec = env_spec
self._variable_scope = tf.VariableScope(reuse=False, name=name)
self._models = []
Member

please take a moment to make these docstrings complete


@overrides
def dist_info_sym(self, obs_var, state_info_vars=None, name=None):
    """Symbolic graph of the distribution."""
Member

please provide full docstrings for all methods (unless the parent class provides a docstring which is equivalent)

hidden_nonlinearity=hidden_nonlinearity,
output_nonlinearity=output_nonlinearity,
layer_normalization=layer_normalization)
layer_normalization=layer_normalization,
Member

please take a moment to add types to these docstrings

self.model = MLPModel(
output_dim=action_dim,
name=name,
name='MLPModel',
Member

please take a moment to add types to these docstrings

std_output_nonlinearity=std_output_nonlinearity,
std_parameterization=std_parameterization,
layer_normalization=layer_normalization)
layer_normalization=layer_normalization,
Member

please take a moment to add types to these docstrings

]

def _build(self, state_input):
def _build(self, state_input, name=None):
Member

please take a moment to add types to these docstrings

Comment thread garage/tf/core/mlp.py Outdated
For example, (32, 32) means this MLP consists of two
hidden layers, each with 32 hidden units.
name (str): Network name, also the variable scope.
hidden_nonlinearity: Activation function for
Member

what about these types?

Contributor Author

They are functions. Make it hidden_nonlinearity(function)?

Comment thread garage/tf/models/cnn_model.py Outdated
name (str): Model name, also the variable scope.
padding (str): The type of padding algorithm to use,
either 'SAME' or 'VALID'.
hidden_nonlinearity: Activation function for
Member

what about these types?

Comment thread garage/tf/models/sequential.py Outdated
Args:
name: Variable scope of the Sequential model.
name (str): Model name, also the variable scope.
models (list[garage.Model]): The models to be connected
Member

garage.tf.models.Model?

Contributor Author

oh yes. It only takes garage.tf.models.Model.

hidden_sizes (list[int]): Output dimension of dense layer(s).
For example, (32, 32) means the MLP of this policy consists
of two hidden layers, each with 32 hidden units.
hidden_nonlinearity: Activation function for
Member

they are tf.Operation, right?

Contributor Author

I don't think they are tf.Operation. In python doing type(tf.nn.relu) returns <class 'function'>.


Member

@ryanjulian ryanjulian left a comment

Looking great. Thanks for updating the docstrings. I think activation functions are just tf.Tensor. Feel free to submit once they are all cleared up.

For example, (32, 32) means the MLP of this policy consists
of two hidden layers, each with 32 hidden units.
hidden_nonlinearity: Activation function for
intermediate dense layer(s).
Member

indent

For example, (32, 32) means the MLP of this policy consists of two
hidden layers, each with 32 hidden units.
hidden_nonlinearity: Activation function for
intermediate dense layer(s).
Member

indent

@ahtsan ahtsan force-pushed the categorical_conv_policy_with_model branch from a963ba1 to 4d51f86 on April 25, 2019 00:45
@ahtsan ahtsan force-pushed the categorical_conv_policy_with_model branch from caafe45 to a444a91 on April 25, 2019 04:28
@ahtsan ahtsan merged commit f11c494 into master Apr 25, 2019
@ahtsan ahtsan deleted the categorical_conv_policy_with_model branch April 25, 2019 06:13


3 participants