@@ -60,60 +60,6 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
6060 entry_ptr->conv_entry .output_desc , y->data ));
6161}
6262
63- void OutputShape (int format, int dims, int groups, const int pad[], const int stride[],
64- const int dilation[], const int x_dim[], const int w_dim[], void * out_shape,
65- const std::string& data_dtype, const std::string& conv_dtype) {
66- CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal ();
67-
68- // Set Data Type
69- entry_ptr->conv_entry .data_type = CuDNNDataType::DLTypeToCuDNNType (String2DLDataType (conv_dtype));
70- cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType (String2DLDataType (data_dtype));
71- // Set Format
72- entry_ptr->conv_entry .tensor_format = static_cast <cudnnTensorFormat_t>(format);
73- // Dims includes N and C
74- int full_dims = dims + 2 ;
75-
76- // conv desc
77- CUDNN_CALL (cudnnSetConvolutionGroupCount (entry_ptr->conv_entry .conv_desc , groups));
78- CUDNN_CALL (cudnnSetConvolutionNdDescriptor (entry_ptr->conv_entry .conv_desc , dims, pad, stride,
79- dilation, CUDNN_CROSS_CORRELATION,
80- entry_ptr->conv_entry .data_type ));
81-
82- if (entry_ptr->conv_entry .tensor_format == CUDNN_TENSOR_NHWC) {
83- ICHECK_EQ (full_dims, 4 ) << " Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors" ;
84-
85- // Set Input
86- CUDNN_CALL (cudnnSetTensor4dDescriptor (entry_ptr->conv_entry .input_desc ,
87- entry_ptr->conv_entry .tensor_format , data_type, x_dim[0 ],
88- x_dim[3 ], x_dim[1 ], x_dim[2 ]));
89-
90- // filter desc
91- CUDNN_CALL (cudnnSetFilter4dDescriptor (entry_ptr->conv_entry .filter_desc , data_type,
92- entry_ptr->conv_entry .tensor_format , w_dim[0 ], w_dim[3 ],
93- w_dim[1 ], w_dim[2 ]));
94-
95- CUDNN_CALL (cudnnGetConvolution2dForwardOutputDim (
96- entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .input_desc ,
97- entry_ptr->conv_entry .filter_desc , static_cast <int *>(out_shape),
98- static_cast <int *>(out_shape) + 3 , static_cast <int *>(out_shape) + 1 ,
99- static_cast <int *>(out_shape) + 2 ));
100- } else {
101- // Set Input
102- std::vector<int > tensor_stride (full_dims);
103- GetCudnnStride (full_dims, x_dim, tensor_stride.data ());
104-
105- CUDNN_CALL (cudnnSetTensorNdDescriptor (entry_ptr->conv_entry .input_desc , data_type, full_dims,
106- x_dim, tensor_stride.data ()));
107- // filter desc
108- CUDNN_CALL (cudnnSetFilterNdDescriptor (entry_ptr->conv_entry .filter_desc , data_type,
109- entry_ptr->conv_entry .tensor_format , full_dims, w_dim));
110-
111- CUDNN_CALL (cudnnGetConvolutionNdForwardOutputDim (
112- entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .input_desc ,
113- entry_ptr->conv_entry .filter_desc , full_dims, static_cast <int *>(out_shape)));
114- }
115- }
116-
11763void FindAlgo (int format, int dims, int groups, const int pad[], const int stride[],
11864 const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
11965 const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
@@ -201,24 +147,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
201147 conv_dtype);
202148 });
203149
204- TVM_REGISTER_GLOBAL (" tvm.contrib.cudnn.conv.output_shape_from_cudnn" )
205- .set_body([](TVMArgs args, TVMRetValue* ret) {
206- int format = args[0 ];
207- int dims = args[1 ];
208- int * pad = static_cast <int *>(static_cast <void *>(args[2 ]));
209- int * stride = static_cast <int *>(static_cast <void *>(args[3 ]));
210- int * dilation = static_cast <int *>(static_cast <void *>(args[4 ]));
211- int * x_dim = static_cast <int *>(static_cast <void *>(args[5 ]));
212- int * w_dim = static_cast <int *>(static_cast <void *>(args[6 ]));
213- void * out_shape = args[7 ];
214- std::string data_dtype = args[8 ];
215- std::string conv_dtype = args[9 ];
216- int groups = args[10 ];
217-
218- OutputShape (format, dims, groups, pad, stride, dilation, x_dim, w_dim, out_shape, data_dtype,
219- conv_dtype);
220- });
221-
222150TVM_REGISTER_GLOBAL (" tvm.contrib.cudnn.conv.find_algo" )
223151 .set_body([](TVMArgs args, TVMRetValue* ret) {
224152 int format = args[0 ];
0 commit comments