From b75e5baa97bf2ae791161602d8079d3fd3d71b0e Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Wed, 27 Jun 2018 23:58:31 +0530 Subject: [PATCH 1/7] [AVG POOL] Asymmetric padding (SAME) support. --- nnvm/include/nnvm/top/nn.h | 3 +- nnvm/python/nnvm/frontend/tensorflow.py | 33 +++++++---- nnvm/src/top/nn/pooling.cc | 30 +++++++--- .../frontend/tensorflow/test_forward.py | 59 +++++++++---------- topi/include/topi/nn/pooling.h | 45 ++++++++++---- 5 files changed, 104 insertions(+), 66 deletions(-) diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index f37811315e43..693678ae657b 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -266,7 +266,8 @@ struct AvgPool2DParam : public dmlc::Parameter { .describe("Specifies the strides of the convolution."); DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); + "on both sides for padding number of points" + "Supports asymmetric padding as (pad_top, pad_left, pad_bottom, pad_height)"); DMLC_DECLARE_FIELD(layout).set_default("NCHW") .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 71536517810d..3b0e9c7aed1b 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -137,22 +137,31 @@ def _impl(inputs, attr, params): pad_v = _get_pad_pair(in_h, kernel_h, stride_h) pad_h = _get_pad_pair(in_w, kernel_w, stride_w) - if attr['data_format'] == 'NHWC': - inputs[0] = _sym.pad(data=inputs[0], - pad_width=((0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]), - (0, 0))) + pad_val = float("-inf") if name == "max_pool" else 0 + if name == 'max_pool': + if attr['data_format'] == 'NHWC': + inputs[0] = _sym.pad(data=inputs[0], + pad_width=((0, 0), + (pad_v[0], pad_v[1]), + (pad_h[0], pad_h[1]), + (0, 0)), + pad_value=pad_val) + else: + inputs[0] = _sym.pad(data=inputs[0], + pad_width=((0, 0), + (0, 0), + (pad_v[0], pad_v[1]), + (pad_h[0], pad_h[1])), + pad_value=pad_val) + attr['padding'] = [0, 0] else: - inputs[0] = _sym.pad(data=inputs[0], - pad_width=((0, 0), - (0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]))) - attr['padding'] = [0, 0] + attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] else: raise TypeError("Unsupported padding type : {}".format(attr['padding'])) + if name == "avg_pool": + attr['count_include_pad'] = False + return AttrCvt( op_name=_dimension_picker(name), transforms={ diff --git a/nnvm/src/top/nn/pooling.cc b/nnvm/src/top/nn/pooling.cc index 54eb0d7db3be..08a62866dee1 100644 --- a/nnvm/src/top/nn/pooling.cc +++ b/nnvm/src/top/nn/pooling.cc @@ -1,3 +1,4 @@ + /*! 
* Copyright (c) 2017 by Contributors * \file pooling.cc @@ -44,23 +45,32 @@ inline bool Pool2DInferShape(const nnvm::NodeAttrs& attrs, const auto hidx = layout.indexof('H'); const auto widx = layout.indexof('W'); + TShape pad = param.padding; + if (param.padding.ndim() == 4) { + pad[0] += pad[2]; + pad[1] += pad[3]; + } else { + pad[0] *= 2; + pad[1] *= 2; + } + TShape oshape = dshape; - CHECK(param.pool_size[0] <= dshape[hidx] + 2 * param.padding[0]) + CHECK(param.pool_size[0] <= dshape[hidx] + pad[0]) << "pool size (" << param.pool_size[0] << ") exceeds input (" << dshape[hidx] - << " padded to " << (dshape[hidx] + 2*param.padding[0]) << ")"; - CHECK(param.pool_size[1] <= dshape[widx] + 2 * param.padding[1]) + << " padded to " << (dshape[hidx] + pad[0]) << ")"; + CHECK(param.pool_size[1] <= dshape[widx] + pad[1]) << "pool size (" << param.pool_size[1] << ") exceeds input (" << dshape[widx] - << " padded to " << (dshape[widx] + 2*param.padding[1]) << ")"; + << " padded to " << (dshape[widx] + pad[1]) << ")"; if (!param.ceil_mode) { - oshape[hidx] = ((dshape[hidx] + 2 * param.padding[0] - param.pool_size[0]) / + oshape[hidx] = ((dshape[hidx] + pad[0] - param.pool_size[0]) / param.strides[0]) + 1; - oshape[widx] = ((dshape[widx] + 2 * param.padding[1] - param.pool_size[1]) / + oshape[widx] = ((dshape[widx] + pad[1] - param.pool_size[1]) / param.strides[1]) + 1; } else { - oshape[hidx] = ((dshape[hidx] + 2 * param.padding[0] - param.pool_size[0] + + oshape[hidx] = ((dshape[hidx] + pad[0] - param.pool_size[0] + param.strides[0] - 1) / param.strides[0]) + 1; - oshape[widx] = ((dshape[3] + 2 * param.padding[1] - param.pool_size[1] + + oshape[widx] = ((dshape[3] + pad[1] - param.pool_size[1] + param.strides[1] - 1) / param.strides[1]) + 1; } NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape); @@ -108,8 +118,12 @@ NNVM_REGISTER_OP(max_pool2d) (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. out_height and out_width are calculated as:: + for symmetric padding: out_height = floor((height+2*padding[0]-pool_size[0])/strides[0])+1 out_width = floor((width+2*padding[1]-pool_size[1])/strides[1])+1 + for asymmetric padding: + out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1 + out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1 When `ceil_mode` is `True`, ceil will be used instead of floor in this equation.
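A quick way to sanity-check the output-shape formulas above is a small standalone helper (a minimal sketch; `pool_out_dim` is a made-up name for illustration, not code from this patch)::

    import math

    def pool_out_dim(size, pad_before, pad_after, pool_size, stride, ceil_mode=False):
        # padded extent minus window, divided by stride (floor or ceil), plus one
        span = size + pad_before + pad_after - pool_size
        steps = math.ceil(span / stride) if ceil_mode else span // stride
        return steps + 1

    # height=9, pool=3, stride=2, asymmetric pads (1, 0): floor((9+1+0-3)/2)+1 == 4
    assert pool_out_dim(9, 1, 0, 3, 2) == 4
    # ceil_mode rounds the division up instead: ceil(3.5)+1 == 5
    assert pool_out_dim(9, 1, 0, 3, 2, ceil_mode=True) == 5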
diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 6dc8cfab2ab4..fcd866d062ea 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -108,32 +108,32 @@ def _test_pooling(input_shape, **kwargs): def test_forward_pooling(): """ Pooling """ - - _test_pooling(input_shape=[2, 9, 10, 2], - window_shape=[1, 1], - padding='SAME', - pooling_type='MAX', - dilation_rate=[1, 1], - strides=[1, 1]) - - _test_pooling(input_shape=[2, 9, 10, 2], - window_shape=[1, 1], - padding='SAME', - pooling_type='AVG', - dilation_rate=[1, 1], - strides=[1, 1]) - _test_pooling(input_shape=[2, 10, 9, 2], - window_shape=[1, 1], - padding='SAME', - pooling_type='MAX', - dilation_rate=[1, 1], - strides=[1, 1]) - _test_pooling(input_shape=[2, 10, 9, 2], - window_shape=[1, 1], - padding='SAME', - pooling_type='AVG', - dilation_rate=[1, 1], - strides=[1, 1]) + for padding in ["SAME", "VALID"]: + for pooling_type in ["MAX", "AVG"]: + for input_shape in [[2, 9, 10, 2], [2, 10, 9, 2]]: + for window_shape in [[1, 1], [2, 1], [2, 3]]: + if padding != "SAME": + #for dilation_rate in [[1, 1], [2, 1], [1, 2], [2, 3]]: + for dilation_rate in [[1, 1]]: + with tf.Graph().as_default(): + _test_pooling( + input_shape=input_shape, + window_shape=window_shape, + padding=padding, + pooling_type=pooling_type, + dilation_rate=dilation_rate, + strides=[1, 1]) + for strides in [[1, 1], [2, 1], [1, 2], [2, 3]]: + if np.any(np.array(strides) > window_shape): + continue + with tf.Graph().as_default(): + _test_pooling( + input_shape=input_shape, + window_shape=window_shape, + padding=padding, + pooling_type=pooling_type, + dilation_rate=[1, 1], + strides=strides) ####################################################################### # Convolution @@ -382,12 +382,7 @@ def test_forward_inception_v3(): top_tvm = np.squeeze(tvm_output).argsort()[-3:][::-1] top_tf = np.squeeze(tf_output).argsort()[-3:][::-1] - # TVM implementation of SAME padding some times make a slight deviation. - # Hence check for top predictions. 
- top_tvm = np.sort(top_tvm) - top_tf = np.sort(top_tf) - - np.testing.assert_allclose(top_tf, top_tvm) + np.testing.assert_allclose(top_tf, top_tvm, rtol=1e-5, atol=1e-5) ####################################################################### # Inception V1 diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index 20a7ffe975bf..3a25694f36ca 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -52,29 +52,45 @@ inline Tensor pool_impl(const Tensor& x, CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements"; CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements"; - CHECK_EQ(padding_size.size(), 2) << "Pooling padding_size must have 2 elements"; + CHECK((padding_size.size() == 2) || + (padding_size.size() == 4)) << "Pooling padding_size must have 2 or 4 elements"; auto kernel_height = kernel_size[0]; auto kernel_width = kernel_size[1]; auto stride_height = stride_size[0]; auto stride_width = stride_size[1]; - auto padding_height = padding_size[0]; - auto padding_width = padding_size[1]; + auto padding_top = padding_size[0]; + auto padding_left = padding_size[1]; + auto padding_bottom = padding_size[0]; + auto padding_right = padding_size[1]; auto height = x->shape[height_axis]; auto width = x->shape[width_axis]; - auto pad_tuple = detail::GetPadTuple(padding_height, padding_width); + auto pad_tuple = detail::GetPadTuple(padding_top, padding_left); auto pad_top = pad_tuple[0]; auto pad_left = pad_tuple[1]; - auto pad_down = pad_tuple[2]; + auto pad_bottom = pad_tuple[2]; auto pad_right = pad_tuple[3]; + if (padding_size.size() == 4) { + padding_bottom = padding_size[2]; + padding_right = padding_size[3]; + + pad_top = padding_size[0]; + pad_left = padding_size[1]; + pad_bottom = padding_size[2]; + pad_right = padding_size[3]; + } + if (ceil_mode) { // Additional padding to ensure we do ceil instead of floor when // dividing by stride. 
- pad_down += stride_height - 1; + pad_bottom += stride_height - 1; pad_right += stride_width - 1; + + padding_bottom += stride_height - 1; + padding_right += stride_height - 1; } Array pad_before(std::vector(x->shape.size(), 0)); @@ -82,11 +98,11 @@ inline Tensor pool_impl(const Tensor& x, pad_before.Set(width_axis, pad_left); Array pad_after(std::vector(x->shape.size(), 0)); - pad_after.Set(height_axis, pad_down); + pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); auto out_height = tvm::ir::Simplify( - (height - kernel_height + pad_top + pad_down) / stride_height + 1); + (height - kernel_height + pad_top + pad_bottom) / stride_height + 1); auto out_width = tvm::ir::Simplify( (width - kernel_width + pad_left + pad_right) / stride_width + 1); @@ -97,9 +113,12 @@ inline Tensor pool_impl(const Tensor& x, out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); - const int64_t *padding_h = HalideIR::Internal::as_const_int(padding_height); - const int64_t *padding_w = HalideIR::Internal::as_const_int(padding_width); - const bool do_pad = ((padding_h && *padding_h) || (padding_w && *padding_w)); + const int64_t *padding_h0 = HalideIR::Internal::as_const_int(padding_top); + const int64_t *padding_w0 = HalideIR::Internal::as_const_int(padding_left); + const int64_t *padding_h1 = HalideIR::Internal::as_const_int(padding_bottom); + const int64_t *padding_w1 = HalideIR::Internal::as_const_int(padding_right); + const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) || + ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1)); if (pool_type == kMaxPool) { auto temp = do_pad ? pad(x, pad_before, pad_after, x->dtype.min(), "pad_temp") : x; @@ -125,8 +144,8 @@ inline Tensor pool_impl(const Tensor& x, if (count_include_pad) { return tsum(output) / (kernel_height * kernel_width); } else { - Expr h_start = output[height_axis] * stride_height - padding_height; - Expr w_start = output[width_axis] * stride_width - padding_width; + Expr h_start = output[height_axis] * stride_height - padding_top; + Expr w_start = output[width_axis] * stride_width - padding_left; Expr h_end = ir::Min::make(h_start + kernel_height, height); Expr w_end = ir::Min::make(w_start + kernel_width, width); h_start = ir::Max::make(h_start, make_const(Int(32), 0)); From 690ca5586e73e35561cd999d5ff7606480c1ac89 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Mon, 2 Jul 2018 20:40:56 +0530 Subject: [PATCH 2/7] [AVG POOL] typo fix --- nnvm/include/nnvm/top/nn.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index 693678ae657b..34e5769dbed4 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -267,7 +267,7 @@ struct AvgPool2DParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" "on both sides for padding number of points" - "Supports asymmetric padding as (pad_top, pad_left, pad_bottom, pad_height)"); + "Supports asymmetric padding as (pad_top, pad_left, pad_bottom, pad_right)"); DMLC_DECLARE_FIELD(layout).set_default("NCHW") .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" From 652f77899959bc57e5c04087f67b6c91517af187 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Tue, 3 Jul 2018 10:37:52 +0530 Subject: [PATCH 3/7] [POOL] Below changes as per the discussion. 
* topi : always accept 4 padding values as (top, left, bottom, right) * nnvm accepts 1 / 2 / 4 padding inputs, similar to Keras. * Support both Max and Avg. --- nnvm/include/nnvm/top/nn.h | 15 ++++--- nnvm/src/top/nn/pooling.cc | 48 ++++++++++++++++++---- topi/include/topi/nn/pooling.h | 41 +++++------------- topi/python/topi/nn/pooling.py | 4 +- topi/tests/python/test_topi_pooling.py | 37 +++++++++-------- topi/tests/python_cpp/test_topi_pooling.py | 36 ++++++++-------- 6 files changed, 102 insertions(+), 79 deletions(-) diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index 34e5769dbed4..ada0aa4209e1 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -237,9 +237,12 @@ struct MaxPool2DParam : public dmlc::Parameter { .describe("Size of the pooling windows.."); DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) .describe("Specifies the strides of the convolution."); - DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) + DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0, 0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding just as given"); DMLC_DECLARE_FIELD(layout).set_default("NCHW") .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" @@ -264,10 +267,12 @@ struct AvgPool2DParam : public dmlc::Parameter { .describe("Size of the pooling windows.."); DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) .describe("Specifies the strides of the convolution."); - DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) + DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0, 0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points" - "Supports asymmetric padding as (pad_top, pad_left, pad_bottom, pad_right)"); + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding just as given"); DMLC_DECLARE_FIELD(layout).set_default("NCHW") .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" diff --git a/nnvm/src/top/nn/pooling.cc b/nnvm/src/top/nn/pooling.cc index 08a62866dee1..82cb752128db 100644 --- a/nnvm/src/top/nn/pooling.cc +++ b/nnvm/src/top/nn/pooling.cc @@ -46,12 +46,18 @@ inline bool Pool2DInferShape(const nnvm::NodeAttrs& attrs, const auto widx = layout.indexof('W'); TShape pad = param.padding; - if (param.padding.ndim() == 4) { + if (param.padding.ndim() == 1) { + pad[1] = pad[0]; + pad[2] = pad[0]; + pad[3] = pad[0]; + } else if (param.padding.ndim() == 2) { + pad[0] *= 2; + pad[1] *= 2; + } else if (param.padding.ndim() == 4) { pad[0] += pad[2]; pad[1] += pad[3]; } else { - pad[0] *= 2; - pad[1] *= 2; + return false; } TShape oshape = dshape; @@ -118,13 +124,14 @@ NNVM_REGISTER_OP(max_pool2d) (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
out_height and out_width are calculated as:: - for symmetric padding: - out_height = floor((height+2*padding[0]-pool_size[0])/strides[0])+1 - out_width = floor((width+2*padding[1]-pool_size[1])/strides[1])+1 - for asymmetric padding: out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1 out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1 + where padding will be an expanded array based on number of values passed as:: + one int : all sides same padding used. + two int : bottom, right use same as top and left. + four int: same as passed values. + When `ceil_mode` is `True`, ceil will be used instead of floor in this equation. @@ -157,6 +164,15 @@ NNVM_REGISTER_OP(max_pool2d) << "Pool2D only support 4-D input (e.g., NCHW)" << " or 5-D input (last dimension is a split of channel)"; + if (param.padding.ndim() == 1) { + padding.push_back(padding[0]); + padding.push_back(padding[0]); + padding.push_back(padding[0]); + } else if (param.padding.ndim() == 2) { + padding.push_back(padding[0]); + padding.push_back(padding[1]); + } + return Array{ topi::nn::pool(inputs[0], pool_size, strides, padding, topi::nn::kMaxPool, ceil_mode, layout.name())}; @@ -196,8 +212,13 @@ NNVM_REGISTER_OP(avg_pool2d) (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. out_height and out_width are calculated as:: - out_height = floor((height+2*padding[0]-pool_size[0])/strides[0])+1 - out_width = floor((width+2*padding[1]-pool_size[1])/strides[1])+1 + out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1 + out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1 + + where padding will be an expanded array based on number of values passed as:: + one int : all sides same padding used. + two int : bottom, right use same as top and left. + four int: same as passed values. When `ceil_mode` is `True`, ceil will be used instead of floor in this equation.
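The one/two/four-value expansion described above reduces to a simple rule; a minimal Python sketch of the behaviour (`expand_padding` is illustrative only, not code from this patch)::

    def expand_padding(padding):
        # Normalize 1/2/4 padding values to (top, left, bottom, right).
        if len(padding) == 1:
            return [padding[0]] * 4
        if len(padding) == 2:
            top, left = padding  # bottom/right reuse top/left
            return [top, left, top, left]
        if len(padding) == 4:
            return list(padding)
        raise ValueError("padding must have 1, 2 or 4 values")

    assert expand_padding([2]) == [2, 2, 2, 2]
    assert expand_padding([1, 2]) == [1, 2, 1, 2]
    assert expand_padding([1, 2, 3, 4]) == [1, 2, 3, 4]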
@@ -230,6 +251,15 @@ NNVM_REGISTER_OP(avg_pool2d) << "Pool2D only support 4-D input (e.g., NCHW)" << " or 5-D input (last dimension is a split of channel)"; + if (param.padding.ndim() == 1) { + padding.push_back(padding[0]); + padding.push_back(padding[0]); + padding.push_back(padding[0]); + } else if (param.padding.ndim() == 2) { + padding.push_back(padding[0]); + padding.push_back(padding[1]); + } + return Array{ topi::nn::pool(inputs[0], pool_size, strides, padding, topi::nn::kAvgPool, ceil_mode, layout.name(), count_include_pad)}; diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index 3a25694f36ca..34a00dee1947 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -52,45 +52,26 @@ inline Tensor pool_impl(const Tensor& x, CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements"; CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements"; - CHECK((padding_size.size() == 2) || - (padding_size.size() == 4)) << "Pooling padding_size must have 2 or 4 elements"; + CHECK((padding_size.size() == 4)) << "Pooling padding_size must have 4 elements"; auto kernel_height = kernel_size[0]; auto kernel_width = kernel_size[1]; auto stride_height = stride_size[0]; auto stride_width = stride_size[1]; - auto padding_top = padding_size[0]; - auto padding_left = padding_size[1]; - auto padding_bottom = padding_size[0]; - auto padding_right = padding_size[1]; auto height = x->shape[height_axis]; auto width = x->shape[width_axis]; - auto pad_tuple = detail::GetPadTuple(padding_top, padding_left); - auto pad_top = pad_tuple[0]; - auto pad_left = pad_tuple[1]; - auto pad_bottom = pad_tuple[2]; - auto pad_right = pad_tuple[3]; - - if (padding_size.size() == 4) { - padding_bottom = padding_size[2]; - padding_right = padding_size[3]; - - pad_top = padding_size[0]; - pad_left = padding_size[1]; - pad_bottom = padding_size[2]; - pad_right = padding_size[3]; - } + auto pad_top = padding_size[0]; + auto pad_left = padding_size[1]; + auto pad_bottom = padding_size[2]; + auto pad_right = padding_size[3]; if (ceil_mode) { // Additional padding to ensure we do ceil instead of floor when // dividing by stride. 
pad_bottom += stride_height - 1; pad_right += stride_width - 1; - - padding_bottom += stride_height - 1; - padding_right += stride_height - 1; } Array pad_before(std::vector(x->shape.size(), 0)); @@ -113,10 +94,10 @@ inline Tensor pool_impl(const Tensor& x, out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); - const int64_t *padding_h0 = HalideIR::Internal::as_const_int(padding_top); - const int64_t *padding_w0 = HalideIR::Internal::as_const_int(padding_left); - const int64_t *padding_h1 = HalideIR::Internal::as_const_int(padding_bottom); - const int64_t *padding_w1 = HalideIR::Internal::as_const_int(padding_right); + const int64_t *padding_h0 = HalideIR::Internal::as_const_int(pad_top); + const int64_t *padding_w0 = HalideIR::Internal::as_const_int(pad_left); + const int64_t *padding_h1 = HalideIR::Internal::as_const_int(pad_bottom); + const int64_t *padding_w1 = HalideIR::Internal::as_const_int(pad_right); const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) || ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1)); @@ -144,8 +125,8 @@ inline Tensor pool_impl(const Tensor& x, if (count_include_pad) { return tsum(output) / (kernel_height * kernel_width); } else { - Expr h_start = output[height_axis] * stride_height - padding_top; - Expr w_start = output[width_axis] * stride_width - padding_left; + Expr h_start = output[height_axis] * stride_height - pad_top; + Expr w_start = output[width_axis] * stride_width - pad_left; Expr h_end = ir::Min::make(h_start + kernel_height, height); Expr w_end = ir::Min::make(w_start + kernel_width, width); h_start = ir::Max::make(h_start, make_const(Int(32), 0)); diff --git a/topi/python/topi/nn/pooling.py b/topi/python/topi/nn/pooling.py index 0660ce7f1a3a..478141ee1d7e 100644 --- a/topi/python/topi/nn/pooling.py +++ b/topi/python/topi/nn/pooling.py @@ -69,8 +69,8 @@ def pool(data, stride : list/tuple of two ints Stride size, [stride_height, stride_width] - padding : list/tuple of two ints - Pad size, [pad_height, pad_width] + padding : list/tuple of four ints + Pad size, [pad_top, pad_left, pad_bottom, pad_right]] pool_type : str Pool type, 'max' or 'avg' diff --git a/topi/tests/python/test_topi_pooling.py b/topi/tests/python/test_topi_pooling.py index 14a67f11bfe1..a96416606a57 100644 --- a/topi/tests/python/test_topi_pooling.py +++ b/topi/tests/python/test_topi_pooling.py @@ -9,7 +9,7 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_ iw = ih kw = kh sw = sh - ph, pw = padding + pt, pl, pb, pr = padding A = tvm.placeholder((n, ic, ih, iw), name='A') B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, pool_type=pool_type, ceil_mode=ceil_mode, count_include_pad=count_include_pad) @@ -19,16 +19,15 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_ bshape = get_const_tuple(B.shape) ashape = get_const_tuple(A.shape) if ceil_mode: - assert bshape[2] == int(math.ceil(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pw * 2) / sw) + 1) + assert bshape[2] == int(math.ceil(float(ashape[2] - kh + pt + pb) / sh) + 1) + assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pl + pr) / sw) + 1) else: - assert bshape[2] == int(math.floor(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1) - + assert bshape[2] == int(math.floor(float(ashape[2] - kh + pt + pb) / sh) + 1) + assert bshape[3] == 
int(math.floor(float(ashape[3] - kw + pl + pr) / sw) + 1) a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype) - pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype) - no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw))) + pad_np = np.zeros(shape=(n, ic, ih+pt+pb, iw+pl+pr)).astype(dtype) + no_zero = (range(n), range(ic), (range(pt, ih+pt)), (range(pl, iw+pl))) pad_np[np.ix_(*no_zero)] = a_np _, oc, oh, ow = get_const_tuple(B.shape) b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype) @@ -67,15 +66,19 @@ def check_device(device): check_device(device) def test_pool(): - verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False, True) - verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False, True) - verify_pool(1, 256, 32, 2, 2, [1, 2], 'avg', False, False) - verify_pool(1, 256, 31, 4, 4, [3, 3], 'avg', False, False) - verify_pool(1, 256, 31, 4, 4, [0, 0], 'avg', False, False) - verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True) - + verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True) + verify_pool(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True) + verify_pool(1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False) + verify_pool(1, 256, 31, 4, 4, [3, 3, 3, 3], 'avg', False, False) + verify_pool(1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False) + verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False) + verify_pool(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', False) + verify_pool(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', True) + + verify_pool(1, 256, 31, 3, 3, [2, 1, 0, 3], 'avg', False, True) + verify_pool(1, 256, 32, 2, 2, [0, 3, 2, 1], 'avg', False, False) + verify_pool(1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False) + verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True) def verify_global_pool(n, c, h, w, pool_type): diff --git a/topi/tests/python_cpp/test_topi_pooling.py b/topi/tests/python_cpp/test_topi_pooling.py index a2bbc3227d94..42232c8e4848 100644 --- a/topi/tests/python_cpp/test_topi_pooling.py +++ b/topi/tests/python_cpp/test_topi_pooling.py @@ -13,7 +13,7 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_ iw = ih kw = kh sw = sh - ph, pw = padding + pt, pl, pb, pr = padding A = tvm.placeholder((n, ic, ih, iw), name='A') B = topi.cpp.nn.pool(A, [kh, kw], [sh, sw], padding, pool_code[pool_type], ceil_mode, "NCHW", count_include_pad) @@ -23,16 +23,16 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_ bshape = get_const_tuple(B.shape) ashape = get_const_tuple(A.shape) if ceil_mode: - assert bshape[2] == int(math.ceil(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pw * 2) / sw) + 1) + assert bshape[2] == int(math.ceil(float(ashape[2] - kh + pt + pb) / sh) + 1) + assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pl + pr) / sw) + 1) else: - assert bshape[2] == int(math.floor(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1) + assert bshape[2] == int(math.floor(float(ashape[2] - kh + pt + pb) / sh) + 1) + assert bshape[3] == int(math.floor(float(ashape[3] - kw + pl + pr) / sw) + 1) a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype) - pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype) - no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw))) + pad_np = np.zeros(shape=(n, ic, ih+pt+pb, 
iw+pl+pr)).astype(dtype) + no_zero = (range(n), range(ic), (range(pt, ih+pt)), (range(pl, iw+pl))) pad_np[np.ix_(*no_zero)] = a_np _, oc, oh, ow = get_const_tuple(B.shape) b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype) @@ -73,15 +73,19 @@ def check_device(device): check_device(device) def test_pool(): - verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False, True) - verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False, True) - verify_pool(1, 256, 32, 2, 2, [1, 2], 'avg', False, False) - verify_pool(1, 256, 31, 4, 4, [3, 3], 'avg', False, False) - verify_pool(1, 256, 31, 4, 4, [0, 0], 'avg', False, False) - verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True) - + verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True) + verify_pool(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True) + verify_pool(1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False) + verify_pool(1, 256, 31, 4, 4, [3, 3, 3, 3], 'avg', False, False) + verify_pool(1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False) + verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False) + verify_pool(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', False) + verify_pool(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', True) + + verify_pool(1, 256, 31, 3, 3, [2, 1, 0, 3], 'avg', False, True) + verify_pool(1, 256, 32, 2, 2, [0, 3, 2, 1], 'avg', False, False) + verify_pool(1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False) + verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True) def verify_global_pool(n, c, h, w, pool_type): From 4ff83869caf6a51ed6e327cbeccdf941fdc0c597 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Tue, 3 Jul 2018 10:53:38 +0530 Subject: [PATCH 4/7] [NNVM][TENSORFLOW] external pad op removed as the pool now supports asymmetric padding. --- nnvm/python/nnvm/frontend/tensorflow.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 3b0e9c7aed1b..8fd21139a009 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -137,25 +137,7 @@ def _impl(inputs, attr, params): pad_v = _get_pad_pair(in_h, kernel_h, stride_h) pad_h = _get_pad_pair(in_w, kernel_w, stride_w) - pad_val = float("-inf") if name == "max_pool" else 0 - if name == 'max_pool': - if attr['data_format'] == 'NHWC': - inputs[0] = _sym.pad(data=inputs[0], - pad_width=((0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]), - (0, 0)), - pad_value=pad_val) - else: - inputs[0] = _sym.pad(data=inputs[0], - pad_width=((0, 0), - (0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1])), - pad_value=pad_val) - attr['padding'] = [0, 0] - else: - attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] + attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] else: raise TypeError("Unsupported padding type : {}".format(attr['padding'])) From efb0c7e43b43a386b4e1c665bd8f872fad2fbb1d Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Tue, 3 Jul 2018 11:26:47 +0530 Subject: [PATCH 5/7] [POOL] Keep 2 ints as default for pool operators. 
--- nnvm/include/nnvm/top/nn.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index ada0aa4209e1..0b9d54d387f8 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -237,7 +237,7 @@ struct MaxPool2DParam : public dmlc::Parameter { .describe("Size of the pooling windows.."); DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) .describe("Specifies the strides of the convolution."); - DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0, 0, 0})) + DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" "Padding support both symmetric and asymmetric as" "one int : same padding used on all sides" @@ -267,7 +267,7 @@ struct AvgPool2DParam : public dmlc::Parameter { .describe("Size of the pooling windows.."); DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) .describe("Specifies the strides of the convolution."); - DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0, 0, 0})) + DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" "Padding support both symmetric and asymmetric as" "one int : same padding used on all sides" From ec831ae03f46eaea5b07dbb225901800a1a5bdf5 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Wed, 4 Jul 2018 09:17:02 +0530 Subject: [PATCH 6/7] [POOL] Review comments addressed. --- nnvm/include/nnvm/top/nn.h | 4 ++-- nnvm/src/top/nn/pooling.cc | 37 +++++++++++++++++++------------------ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index 0b9d54d387f8..86bdc60a6236 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -242,7 +242,7 @@ struct MaxPool2DParam : public dmlc::Parameter { "Padding support both symmetric and asymmetric as" "one int : same padding used on all sides" "two int : bottom, right will use same padding as top, left" - "four int : padding just as given"); + "four int : padding width in the order of (top, left, bottom, right)"); DMLC_DECLARE_FIELD(layout).set_default("NCHW") .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" @@ -272,7 +272,7 @@ struct AvgPool2DParam : public dmlc::Parameter { "Padding support both symmetric and asymmetric as" "one int : same padding used on all sides" "two int : bottom, right will use same padding as top, left" - "four int : padding just as given"); + "four int : padding width in the order of (top, left, bottom, right)"); DMLC_DECLARE_FIELD(layout).set_default("NCHW") .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." 
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" diff --git a/nnvm/src/top/nn/pooling.cc b/nnvm/src/top/nn/pooling.cc index 82cb752128db..cccd5b1c710b 100644 --- a/nnvm/src/top/nn/pooling.cc +++ b/nnvm/src/top/nn/pooling.cc @@ -45,38 +45,39 @@ inline bool Pool2DInferShape(const nnvm::NodeAttrs& attrs, const auto hidx = layout.indexof('H'); const auto widx = layout.indexof('W'); - TShape pad = param.padding; + dim_t pad_h, pad_w; if (param.padding.ndim() == 1) { - pad[1] = pad[0]; - pad[2] = pad[0]; - pad[3] = pad[0]; + pad_h = param.padding[0] * 2; + pad_w = param.padding[0] * 2; } else if (param.padding.ndim() == 2) { - pad[0] *= 2; - pad[1] *= 2; + // (top, left) + pad_h = param.padding[0] * 2; + pad_w = param.padding[1] * 2; } else if (param.padding.ndim() == 4) { - pad[0] += pad[2]; - pad[1] += pad[3]; + // (top, left, bottom, right) + pad_h = param.padding[0] + param.padding[2]; + pad_w = param.padding[1] + param.padding[3]; } else { return false; } TShape oshape = dshape; - CHECK(param.pool_size[0] <= dshape[hidx] + pad[0]) + CHECK(param.pool_size[0] <= dshape[hidx] + pad_h) << "pool size (" << param.pool_size[0] << ") exceeds input (" << dshape[hidx] - << " padded to " << (dshape[hidx] + pad[0]) << ")"; - CHECK(param.pool_size[1] <= dshape[widx] + pad[1]) + << " padded to " << (dshape[hidx] + pad_h) << ")"; + CHECK(param.pool_size[1] <= dshape[widx] + pad_w) << "pool size (" << param.pool_size[1] << ") exceeds input (" << dshape[widx] - << " padded to " << (dshape[widx] + pad[1]) << ")"; + << " padded to " << (dshape[widx] + pad_w) << ")"; if (!param.ceil_mode) { - oshape[hidx] = ((dshape[hidx] + pad[0] - param.pool_size[0]) / + oshape[hidx] = ((dshape[hidx] + pad_h - param.pool_size[0]) / param.strides[0]) + 1; - oshape[widx] = ((dshape[widx] + pad[1] - param.pool_size[1]) / + oshape[widx] = ((dshape[widx] + pad_w - param.pool_size[1]) / param.strides[1]) + 1; } else { - oshape[hidx] = ((dshape[hidx] + pad[0] - param.pool_size[0] + + oshape[hidx] = ((dshape[hidx] + pad_h - param.pool_size[0] + param.strides[0] - 1) / param.strides[0]) + 1; - oshape[widx] = ((dshape[3] + pad[1] - param.pool_size[1] + + oshape[widx] = ((dshape[3] + pad_w - param.pool_size[1] + param.strides[1] - 1) / param.strides[1]) + 1; } NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape); @@ -130,7 +131,7 @@ NNVM_REGISTER_OP(max_pool2d) where padding will be an expanded array based on number of values passed as:: one int : all sides same padding used. two int : bottom, right use same as top and left. - four int: same as passed values. + four int: padding width in the order of (top, left, bottom, right). When `ceil_mode` is `True`, ceil will be used instead of floor in this equation. @@ -218,7 +219,7 @@ NNVM_REGISTER_OP(avg_pool2d) where padding will be an expanded array based on number of values passed as:: one int : all sides same padding used. two int : bottom, right use same as top and left. - four int: same as passed values. + four int: padding width in the order of (top, left, bottom, right). When `ceil_mode` is `True`, ceil will be used instead of floor in this equation. 
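For context on why asymmetric values arise at all: TensorFlow's SAME padding splits an odd total unevenly, giving the extra cell to the bottom/right side. A sketch of that computation, mirroring the frontend's `_get_pad_pair` helper (treat the exact form as an assumption)::

    def get_pad_pair(in_size, kernel, stride):
        # SAME padding targets out_size = ceil(in_size / stride); any odd
        # leftover padding goes to the second (bottom/right) side.
        out_size = (in_size + stride - 1) // stride
        pad_total = max((out_size - 1) * stride + kernel - in_size, 0)
        pad_before = pad_total // 2
        return [pad_before, pad_total - pad_before]

    # in=10, kernel=3, stride=2 -> out=5, total pad 1, split as (0, 1)
    assert get_pad_pair(10, 3, 2) == [0, 1]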
From ebe789586d7f82f27a94ced58b55e243a3e2dcd4 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Wed, 4 Jul 2018 10:28:57 +0530 Subject: [PATCH 7/7] [POOL] Review comments (1) --- .../frontend/tensorflow/test_forward.py | 79 +++++++++++++------ topi/include/topi/nn/pooling.h | 2 +- 2 files changed, 54 insertions(+), 27 deletions(-) diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index fcd866d062ea..b14b15b2ffe7 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -108,32 +108,59 @@ def _test_pooling(input_shape, **kwargs): def test_forward_pooling(): """ Pooling """ - for padding in ["SAME", "VALID"]: - for pooling_type in ["MAX", "AVG"]: - for input_shape in [[2, 9, 10, 2], [2, 10, 9, 2]]: - for window_shape in [[1, 1], [2, 1], [2, 3]]: - if padding != "SAME": - #for dilation_rate in [[1, 1], [2, 1], [1, 2], [2, 3]]: - for dilation_rate in [[1, 1]]: - with tf.Graph().as_default(): - _test_pooling( - input_shape=input_shape, - window_shape=window_shape, - padding=padding, - pooling_type=pooling_type, - dilation_rate=dilation_rate, - strides=[1, 1]) - for strides in [[1, 1], [2, 1], [1, 2], [2, 3]]: - if np.any(np.array(strides) > window_shape): - continue - with tf.Graph().as_default(): - _test_pooling( - input_shape=input_shape, - window_shape=window_shape, - padding=padding, - pooling_type=pooling_type, - dilation_rate=[1, 1], - strides=strides) + + _test_pooling(input_shape=[2, 9, 10, 2], + window_shape=[1, 1], + padding='SAME', + pooling_type='MAX', + dilation_rate=[1, 1], + strides=[1, 1]) + _test_pooling(input_shape=[2, 9, 10, 2], + window_shape=[1, 1], + padding='SAME', + pooling_type='AVG', + dilation_rate=[1, 1], + strides=[1, 1]) + + _test_pooling(input_shape=[2, 10, 9, 2], + window_shape=[1, 1], + padding='SAME', + pooling_type='MAX', + dilation_rate=[1, 1], + strides=[1, 1]) + _test_pooling(input_shape=[2, 10, 9, 2], + window_shape=[1, 1], + padding='SAME', + pooling_type='AVG', + dilation_rate=[1, 1], + strides=[1, 1]) + + _test_pooling(input_shape=[2, 9, 10, 2], + window_shape=[2, 1], + padding='SAME', + pooling_type='MAX', + dilation_rate=[1, 1], + strides=[1, 1]) + _test_pooling(input_shape=[2, 9, 10, 2], + window_shape=[2, 1], + padding='SAME', + pooling_type='AVG', + dilation_rate=[1, 1], + strides=[2, 1]) + + _test_pooling(input_shape=[2, 10, 9, 2], + window_shape=[2, 3], + padding='SAME', + pooling_type='MAX', + dilation_rate=[1, 1], + strides=[2, 1]) + _test_pooling(input_shape=[2, 10, 9, 2], + window_shape=[2, 3], + padding='SAME', + pooling_type='AVG', + dilation_rate=[1, 1], + strides=[1, 2]) + ####################################################################### # Convolution diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index 34a00dee1947..26d61d42991d 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -52,7 +52,7 @@ inline Tensor pool_impl(const Tensor& x, CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements"; CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements"; - CHECK((padding_size.size() == 4)) << "Pooling padding_size must have 4 elements"; + CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements"; auto kernel_height = kernel_size[0]; auto kernel_width = kernel_size[1];
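With the series applied, TOPI's pool expects exactly four padding values in (top, left, bottom, right) order, matching the calls in the updated tests. A hypothetical end-to-end call (shapes and values are illustrative)::

    import tvm
    import topi

    data = tvm.placeholder((1, 16, 32, 32), name="data")  # NCHW
    out = topi.nn.pool(data, kernel=[3, 3], stride=[2, 2],
                       padding=[1, 0, 0, 1],  # top, left, bottom, right
                       pool_type="max", ceil_mode=False)
    # output H/W: floor((32 + 1 + 0 - 3) / 2) + 1 == 16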