Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/tensorflow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,8 @@ def _impl(inputs, attr, params, mod):
if opname == "conv":
attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "OIHW"
elif opname == "conv_transpose":
# conv_transpose in TVM has weights be IOHW for NCHW
attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "IOHW"
# conv_transpose in TVM has weights be IOHW, because the attr["data_format"] always be NCHW when opname='conv_transpose'.
attr["kernel_layout"] = "IOHW"
else:
attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW"

Expand Down
23 changes: 22 additions & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,17 @@ def test_forward_convolution():
"NCHW",
[1, 1, 8, 8],
)

_test_convolution(
"conv_transpose",
[4, 19, 8, 8],
[2, 2, 66, 19],
[1, 1],
[2, 2],
"VALID",
"NCHW",
[4, 66, 16, 16],
)

_test_convolution("conv", [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC")
_test_convolution("conv", [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC")
_test_convolution("conv", [4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC")
Expand Down Expand Up @@ -917,6 +927,17 @@ def test_forward_convolution():
[4, 8, 8, 176],
add_shapes_to_graph_def=False,
)
_test_convolution(
"conv_transpose",
[4, 8, 8, 19],
[2, 2, 66, 19],
[1, 1],
[2, 2],
"VALID",
"NHWC",
[4, 16, 16, 66],
)

# Explicit padding
if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"):
_test_convolution(
Expand Down