Skip to content

Commit 834f54a

Browse files
committed
remove cudnn get output
1 parent dcbd9c9 commit 834f54a

4 files changed

Lines changed: 3 additions & 147 deletions

File tree

python/tvm/contrib/cudnn.py

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -285,68 +285,6 @@ def conv_output_shape(
285285
return output
286286

287287

288-
def _conv_output_shape_from_cudnn(
289-
tensor_format, pad, stride, dilation, x_shape, w_shape, data_dtype, conv_dtype, groups=1
290-
):
291-
"""Get output shape of 2D or 3D convolution. The output of this
292-
function should be identical to that of conv_output_shape, but
293-
requires a GPU with CuDNN to be present. This is maintained for
294-
testing purposes to validate the output of conv_output_shape.
295-
296-
Parameters
297-
---------
298-
tensor_format: int
299-
0: CUDNN_TENSOR_NCHW
300-
1: CUDNN_TENSOR_NHWC
301-
2: CUDNN_TENSOR_NCHW_VECT_C
302-
pad: int or list
303-
padding
304-
stride: int or list
305-
stride
306-
dilation: int or list
307-
dilation
308-
x_shape: list
309-
input shape
310-
w_shape: list
311-
weight shape
312-
data_dtype: str
313-
data type
314-
conv_dtype: str
315-
convolution type
316-
groups: int
317-
number of groups
318-
319-
Returns
320-
-------
321-
oshape: list
322-
output shape
323-
324-
"""
325-
dims = len(x_shape)
326-
assert dims in (4, 5)
327-
328-
pad, stride, dilation, xshape, wshape = _prepare_global_func_params(
329-
dims - 2, pad, stride, dilation, x_shape, w_shape
330-
)
331-
oshape = np.zeros((dims), dtype=np.int32)
332-
333-
func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn")
334-
func(
335-
tensor_format,
336-
dims - 2,
337-
_get_np_int32_array_handle(pad),
338-
_get_np_int32_array_handle(stride),
339-
_get_np_int32_array_handle(dilation),
340-
_get_np_int32_array_handle(xshape),
341-
_get_np_int32_array_handle(wshape),
342-
_get_np_int32_array_handle(oshape),
343-
data_dtype,
344-
conv_dtype,
345-
groups,
346-
)
347-
return list(oshape)
348-
349-
350288
def conv_find_algo(
351289
tensor_format,
352290
pad,

src/runtime/contrib/cudnn/conv_forward.cc

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -60,60 +60,6 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
6060
entry_ptr->conv_entry.output_desc, y->data));
6161
}
6262

63-
void OutputShape(int format, int dims, int groups, const int pad[], const int stride[],
64-
const int dilation[], const int x_dim[], const int w_dim[], void* out_shape,
65-
const std::string& data_dtype, const std::string& conv_dtype) {
66-
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
67-
68-
// Set Data Type
69-
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
70-
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(data_dtype));
71-
// Set Format
72-
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
73-
// Dims includes N and C
74-
int full_dims = dims + 2;
75-
76-
// conv desc
77-
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
78-
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
79-
dilation, CUDNN_CROSS_CORRELATION,
80-
entry_ptr->conv_entry.data_type));
81-
82-
if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
83-
ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors";
84-
85-
// Set Input
86-
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
87-
entry_ptr->conv_entry.tensor_format, data_type, x_dim[0],
88-
x_dim[3], x_dim[1], x_dim[2]));
89-
90-
// filter desc
91-
CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
92-
entry_ptr->conv_entry.tensor_format, w_dim[0], w_dim[3],
93-
w_dim[1], w_dim[2]));
94-
95-
CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(
96-
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
97-
entry_ptr->conv_entry.filter_desc, static_cast<int*>(out_shape),
98-
static_cast<int*>(out_shape) + 3, static_cast<int*>(out_shape) + 1,
99-
static_cast<int*>(out_shape) + 2));
100-
} else {
101-
// Set Input
102-
std::vector<int> tensor_stride(full_dims);
103-
GetCudnnStride(full_dims, x_dim, tensor_stride.data());
104-
105-
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
106-
x_dim, tensor_stride.data()));
107-
// filter desc
108-
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
109-
entry_ptr->conv_entry.tensor_format, full_dims, w_dim));
110-
111-
CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim(
112-
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
113-
entry_ptr->conv_entry.filter_desc, full_dims, static_cast<int*>(out_shape)));
114-
}
115-
}
116-
11763
void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
11864
const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
11965
const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
@@ -201,24 +147,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
201147
conv_dtype);
202148
});
203149

204-
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape_from_cudnn")
205-
.set_body([](TVMArgs args, TVMRetValue* ret) {
206-
int format = args[0];
207-
int dims = args[1];
208-
int* pad = static_cast<int*>(static_cast<void*>(args[2]));
209-
int* stride = static_cast<int*>(static_cast<void*>(args[3]));
210-
int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
211-
int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
212-
int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
213-
void* out_shape = args[7];
214-
std::string data_dtype = args[8];
215-
std::string conv_dtype = args[9];
216-
int groups = args[10];
217-
218-
OutputShape(format, dims, groups, pad, stride, dilation, x_dim, w_dim, out_shape, data_dtype,
219-
conv_dtype);
220-
});
221-
222150
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo")
223151
.set_body([](TVMArgs args, TVMRetValue* ret) {
224152
int format = args[0];

tests/python/contrib/test_cudnn.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030

3131
requires_cudnn = pytest.mark.skipif(
32-
tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True) is None,
32+
tvm.get_global_func("tvm.contrib.cudnn.conv2d.forward", True) is None,
3333
reason="CuDNN is not enabled",
3434
)
3535

@@ -307,15 +307,5 @@ def conv_output_shape_kwargs(request):
307307
return request.param
308308

309309

310-
@tvm.testing.requires_gpu
311-
@requires_cudnn
312-
def test_conv_output_shape(conv_output_shape_kwargs):
313-
shape_from_cudnn = cudnn._conv_output_shape_from_cudnn(**conv_output_shape_kwargs)
314-
shape_from_python = cudnn.conv_output_shape(**conv_output_shape_kwargs)
315-
assert shape_from_cudnn == shape_from_python
316-
317-
318310
if __name__ == "__main__":
319-
# sys.exit(pytest.main(sys.argv))
320-
test_conv2d()
321-
test_conv3d()
311+
sys.exit(pytest.main(sys.argv))

tests/python/relay/test_any.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ def verify_any_conv2d(
541541
kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
542542

543543
targets = None
544-
if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True):
544+
if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv2d.forward", True):
545545
targets = [("cuda -libs=cudnn", tvm.cuda(0))]
546546

547547
check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets)

0 commit comments

Comments
 (0)