
max pooling across multiple filters #19

@sbmorphe

I defined a CNN in TensorFlow; here's a chunk of it:

import numpy as np
import tensorflow as tf

HEIGHT, WIDTH, DEPTH = 144, 192, 3
N_CLASSES = 2

x = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH], name="input")
y_ = tf.placeholder(tf.float32, shape=[None, N_CLASSES], name="y_")

WIN_X, WIN_Y = 5, 5
N_FILTERS = 4

W1 = tf.Variable(tf.truncated_normal([WIN_X, WIN_Y, DEPTH, N_FILTERS],
                                     stddev=1/np.sqrt(WIN_X*WIN_Y)))
b1 = tf.Variable(tf.constant(0.1, shape=[N_FILTERS]))
xw = tf.nn.conv2d(x, W1, strides=[1,1,1,1], padding="SAME", name="xw")
h1 = tf.nn.relu(xw + b1, name="h1")
p1 = tf.nn.max_pool(h1, ksize=[1,2,2,1], strides=[1,2,2,1], padding="VALID", name="p1")
#...
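As a quick shape check, here is a small sketch that applies TensorFlow's documented output-size rules for SAME and VALID padding to the layers above. The helper `out_size` is a hypothetical name used only for illustration. It shows that h1 should have shape (N, 144, 192, 4) and p1 should have shape (N, 72, 96, 4): pooling halves height and width but keeps all 4 filter channels.

```python
import math

def out_size(in_size, k, stride, padding):
    """Output length along one spatial axis, following TensorFlow's SAME/VALID rules."""
    if padding == "SAME":
        return math.ceil(in_size / stride)
    if padding == "VALID":
        return math.ceil((in_size - k + 1) / stride)
    raise ValueError("unknown padding: %r" % padding)

HEIGHT, WIDTH, N_FILTERS = 144, 192, 4

# conv2d: 5x5 window, stride 1, SAME padding -> spatial size unchanged.
h1_hw = (out_size(HEIGHT, 5, 1, "SAME"), out_size(WIDTH, 5, 1, "SAME"))
# max_pool: 2x2 window, stride 2, VALID padding -> spatial size halved.
p1_hw = (out_size(h1_hw[0], 2, 2, "VALID"), out_size(h1_hw[1], 2, 2, "VALID"))

# The channel dimension (N_FILTERS) passes through pooling unchanged.
print("h1:", (None,) + h1_hw + (N_FILTERS,))  # h1: (None, 144, 192, 4)
print("p1:", (None,) + p1_hw + (N_FILTERS,))  # p1: (None, 72, 96, 4)
```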

I was able to train and test the model, and then saved it using tfdeploy. I can load it like so:

import tfdeploy

model = tfdeploy.Model("myModel.pkl")

# THIS CODE WORKS:
x, test_point = model.get("input", "h1")
test_point.eval({x: samps})  # samps: 38 input samples

# BUT THIS DOESN'T WORK:
x, test_point = model.get("input", "p1")
test_point.eval({x: samps})
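Since evaluating "h1" works, one possible workaround (an assumption on my part, not something tfdeploy provides) is to evaluate "h1" and do the pooling in NumPy. The helper below, `max_pool_2x2_valid` (a hypothetical name), is a minimal sketch of what `tf.nn.max_pool` with `ksize=[1,2,2,1]`, `strides=[1,2,2,1]`, `padding="VALID"` computes on NHWC input: each channel is pooled independently over non-overlapping 2x2 blocks.

```python
import numpy as np

def max_pool_2x2_valid(a):
    """2x2 max pooling with stride 2 and VALID padding on an NHWC array.

    Each channel is pooled independently, so the channel count is unchanged.
    """
    n, h, w, c = a.shape
    h2, w2 = h // 2, w // 2
    # VALID padding: drop a trailing row/column that cannot fill a whole window.
    a = a[:, :h2 * 2, :w2 * 2, :]
    # Split H and W into (blocks, 2) and take the max over each 2x2 block.
    return a.reshape(n, h2, 2, w2, 2, c).max(axis=(2, 4))

# Example on random data shaped like h1 for 38 samples.
h1_value = np.random.rand(38, 144, 192, 4).astype(np.float32)
p1_value = max_pool_2x2_valid(h1_value)
print(p1_value.shape)  # (38, 72, 96, 4)
```

With the model loaded as above, and assuming `eval` returns a NumPy array, this would look like `x, h1 = model.get("input", "h1")` followed by `p1_value = max_pool_2x2_valid(h1.eval({x: samps}))`.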

Any idea what is going on? Here's the error message I'm getting:

Traceback (most recent call last):
  File "test.py", line 68, in <module>
    test_point.eval({x:samps})
  File "/home/sbmorphe/Downloads/tfdeploy.py", line 291, in eval
    self.value = self.op.eval(feed_dict=feed_dict, _uuid=_uuid)[self.value_index]
  File "/home/sbmorphe/Downloads/tfdeploy.py", line 462, in eval
    self.value = self.func(*args)
  File "/home/sbmorphe/Downloads/tfdeploy.py", line 474, in func
    return cls.func_numpy(*args)
  File "/home/sbmorphe/Downloads/tfdeploy.py", line 2185, in MaxPool
    patches = _conv_patches(a, np.ones(k[1:] + [1]), strides, padding.decode("ascii"), "edge")
  File "/home/sbmorphe/Downloads/tfdeploy.py", line 2108, in _conv_patches
    src[s + tuple(slice(*tpl) for tpl in zip(pos, pos + f.shape[:-2]))][en] * f
ValueError: could not broadcast input array from shape (38,2,2,4,1) into shape (38,2,2,1,1)
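The shapes in the last line of the traceback can be reproduced with plain NumPy. With `ksize=[1,2,2,1]`, the `np.ones(k[1:] + [1])` call shown in the traceback builds a pooling filter of shape (2, 2, 1, 1), whose input-channel dimension is 1, while h1 has 4 channels (N_FILTERS = 4). This sketch is not tfdeploy's code; the `window` and `dst` arrays are assumptions chosen only to match the shapes in the error message.

```python
import numpy as np

# Shapes taken from the traceback: 38 samples, a 2x2 pooling window, 4 channels.
N, KH, KW, C = 38, 2, 2, 4

# Pooling "filter" built as in the traceback: np.ones(k[1:] + [1]) with
# ksize k = [1, 2, 2, 1] gives shape (2, 2, 1, 1), i.e. a single input channel.
k = [1, KH, KW, 1]
f = np.ones(k[1:] + [1])

# Assumed: one window slice of the 4-channel input, with a trailing axis.
window = np.random.rand(N, KH, KW, C, 1)

# Assumed: a destination buffer sized for a single input channel.
dst = np.empty((N, KH, KW, 1, 1))

message = None
try:
    # window * f broadcasts to (38, 2, 2, 4, 1), which cannot fit into dst.
    dst[...] = window * f
except ValueError as err:
    message = str(err)
print(message)
```

Setting `C = 1` makes the assignment succeed, which hints (without proving it) that this pooling code path only handles single-channel input.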