@@ -29,8 +29,8 @@ namespace caffe {
2929template <typename Dtype>
3030class CuDNNConvolutionLayer : public ConvolutionLayer <Dtype> {
3131 public:
32- explicit CuDNNConvolutionLayer (const LayerParameter& param)
33- : ConvolutionLayer<Dtype>(param), handles_setup_( false ) {}
32+ explicit CuDNNConvolutionLayer (const LayerParameter& param);
33+
3434 virtual void LayerSetUp (const vector<Blob<Dtype>*>& bottom,
3535 const vector<Blob<Dtype>*>& top);
3636 virtual void Reshape (const vector<Blob<Dtype>*>& bottom,
@@ -43,49 +43,32 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
4343 virtual void Backward_gpu (const vector<Blob<Dtype>*>& top,
4444 const vector<bool >& propagate_down, const vector<Blob<Dtype>*>& bottom);
4545
46-
4746 bool handles_setup_;
4847
4948#ifdef USE_MIOPEN
50- miopenHandle_t* handle_;
51- hipStream_t* stream_;
5249
5350 // algorithms for forward and backwards convolutions
54- miopenConvFwdAlgorithm_t* fwd_algo_;
55- miopenConvBwdWeightsAlgorithm_t* bwd_weight_algo_;
56- miopenConvBwdDataAlgorithm_t* bwd_data_algo_;
51+ vector< miopenConvFwdAlgorithm_t> fwd_algo_;
52+ vector< miopenConvBwdWeightsAlgorithm_t> bwd_weight_algo_;
53+ vector< miopenConvBwdDataAlgorithm_t> bwd_data_algo_;
5754
5855 vector<miopenTensorDescriptor_t> bottom_descs_, top_descs_;
5956 miopenTensorDescriptor_t bias_desc_;
6057 miopenTensorDescriptor_t filter_desc_;
6158 vector<miopenConvolutionDescriptor_t> conv_descs_;
6259
6360 int N_, C_, W_, H_;
64- #endif
65-
66- #ifdef USE_CUDNN
67- cudnnHandle_t* handle_;
68- cudaStream_t* stream_;
69-
70- // algorithms for forward and backwards convolutions
71- cudnnConvolutionFwdAlgo_t *fwd_algo_;
72- cudnnConvolutionBwdFilterAlgo_t *bwd_filter_algo_;
73- cudnnConvolutionBwdDataAlgo_t *bwd_data_algo_;
74-
75- vector<cudnnTensorDescriptor_t> bottom_descs_, top_descs_;
76- cudnnTensorDescriptor_t bias_desc_;
77- cudnnFilterDescriptor_t filter_desc_;
78- vector<cudnnConvolutionDescriptor_t> conv_descs_;
61+ miopenHandle_t handle_;
7962#endif
8063
8164 int bottom_offset_, top_offset_, bias_offset_;
8265
83- size_t * workspace_fwd_sizes_;
84- size_t *workspace_bwd_data_sizes_ ;
85- size_t *workspace_bwd_filter_sizes_ ;
66+ vector< size_t > workspace_fwd_sizes_;
67+ vector< size_t > workspace_bwd_filter_sizes_ ;
68+ vector< size_t > workspace_bwd_data_sizes_ ;
8669 size_t workspaceSizeInBytes; // size of underlying storage
8770 void *workspaceData; // underlying storage
88- void ** workspace; // aliases into workspaceData
71+ vector< void *> workspace; // aliases into workspaceData
8972};
9073#endif
9174
0 commit comments