I defined a CNN in TensorFlow; here's a chunk of it:
import numpy as np
import tensorflow as tf
# Input: batches of 144x192 images with 3 channels, 2 output classes
HEIGHT, WIDTH, DEPTH = 144, 192, 3
N_CLASSES = 2
x = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH], name="input")
y_ = tf.placeholder(tf.float32, shape=[None, N_CLASSES], name="y_")
# First conv layer: 4 filters with a 5x5 window
WIN_X, WIN_Y = 5, 5
N_FILTERS = 4
W1 = tf.Variable(tf.truncated_normal([WIN_X, WIN_Y, DEPTH, N_FILTERS],
                                     stddev=1/np.sqrt(WIN_X*WIN_Y)))
b1 = tf.Variable(tf.constant(0.1, shape=[N_FILTERS]))
xw = tf.nn.conv2d(x, W1, strides=[1,1,1,1], padding="SAME", name="xw")
h1 = tf.nn.relu(xw + b1, name="h1")
# 2x2 max pool with stride 2 halves height and width
p1 = tf.nn.max_pool(h1, ksize=[1,2,2,1], strides=[1,2,2,1], padding="VALID", name="p1")
#...
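For reference, the shapes this chunk should produce can be worked out by hand: `SAME` padding with stride 1 keeps the spatial size, and a 2x2 `VALID` max pool with stride 2 halves it. A small sketch of that arithmetic, using TensorFlow's documented output-size rules (the helper `out_size` is my own, not a TensorFlow function):

```python
import math

def out_size(n, k, s, padding):
    """Spatial output size of a TensorFlow conv/pool along one axis."""
    if padding == "SAME":
        return math.ceil(n / s)
    if padding == "VALID":
        return math.ceil((n - k + 1) / s)
    raise ValueError("unknown padding: %s" % padding)

HEIGHT, WIDTH, N_FILTERS = 144, 192, 4

# conv2d, 5x5 window, stride 1, SAME padding -> h1
h1_shape = (out_size(HEIGHT, 5, 1, "SAME"), out_size(WIDTH, 5, 1, "SAME"), N_FILTERS)
# max_pool, 2x2 window, stride 2, VALID padding -> p1
p1_shape = (out_size(h1_shape[0], 2, 2, "VALID"), out_size(h1_shape[1], 2, 2, "VALID"), N_FILTERS)

print(h1_shape)  # (144, 192, 4)
print(p1_shape)  # (72, 96, 4)
```

So for the 38 samples, `p1` should come out as `(38, 72, 96, 4)` per sample batch, with 4 channels carried over from `h1`.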
I was able to train and test the model, and then saved it using tfdeploy. I can load it back like so:
import tfdeploy
model = tfdeploy.Model("myModel.pkl")
# THIS CODE WORKS:
x, test_point = model.get("input", "h1")
test_point.eval({x:samps}) # 38 samples
# BUT THIS DOESN'T WORK:
x, test_point = model.get("input", "p1")
test_point.eval({x:samps})
Any idea what is going on? Here's the error message I'm getting:
Traceback (most recent call last):
  File "test.py", line 68, in
    test_point.eval({x:samps})
  File "/home/sbmorphe/Downloads/tfdeploy.py", line 291, in eval
    self.value = self.op.eval(feed_dict=feed_dict, _uuid=_uuid)[self.value_index]
  File "/home/sbmorphe/Downloads/tfdeploy.py", line 462, in eval
    self.value = self.func(*args)
  File "/home/sbmorphe/Downloads/tfdeploy.py", line 474, in func
    return cls.func_numpy(*args)
  File "/home/sbmorphe/Downloads/tfdeploy.py", line 2185, in MaxPool
    patches = _conv_patches(a, np.ones(k[1:] + [1]), strides, padding.decode("ascii"), "edge")
  File "/home/sbmorphe/Downloads/tfdeploy.py", line 2108, in _conv_patches
    src[s + tuple(slice(*tpl) for tpl in zip(pos, pos + f.shape[:-2]))][en] * f
ValueError: could not broadcast input array from shape (38,2,2,4,1) into shape (38,2,2,1,1)
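The shapes in the error look consistent with a channel mismatch: with `ksize=[1,2,2,1]`, the `np.ones(k[1:] + [1])` call in the traceback builds a pooling window of shape `(2, 2, 1, 1)`, i.e. one input channel, while `h1` has `N_FILTERS = 4` channels (the `4` vs `1` in the message). Since evaluating `h1` works, one possible stopgap is to fetch `h1` through tfdeploy and do the 2x2, stride-2, `VALID` max pool in plain NumPy. This is a sketch under those assumptions; `max_pool_2x2_valid` is a hypothetical helper, not part of tfdeploy:

```python
import numpy as np

# The pooling window tfdeploy builds for ksize=[1, 2, 2, 1] (see traceback):
k = [1, 2, 2, 1]
print(np.ones(k[1:] + [1]).shape)  # (2, 2, 1, 1): only one channel

def max_pool_2x2_valid(a):
    """2x2 max pool, stride 2, VALID padding, on an NHWC array.

    Meant to match tf.nn.max_pool(a, ksize=[1,2,2,1], strides=[1,2,2,1],
    padding="VALID") for any number of channels.
    """
    n, h, w, c = a.shape
    # VALID padding drops a trailing odd row/column
    a = a[:, : h - h % 2, : w - w % 2, :]
    # Split each spatial axis into (blocks, 2) and take the max over each 2x2 block
    return a.reshape(n, h // 2, 2, w // 2, 2, c).max(axis=(2, 4))

# Hypothetical usage with tfdeploy (not run here):
#   x, h1 = model.get("input", "h1")
#   p1_vals = max_pool_2x2_valid(h1.eval({x: samps}))
demo = np.arange(2 * 4 * 4 * 4, dtype=np.float32).reshape(2, 4, 4, 4)
print(max_pool_2x2_valid(demo).shape)  # (2, 2, 2, 4)
```

The reshape trick avoids Python loops; it relies on the window and stride both being 2, so it would not cover overlapping or differently sized pools.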