@@ -50,7 +50,8 @@ void BatchNormKernel(const Context& dev_ctx,
           data_layout_str,
           FLAGS_npu_storage_format));
 
-  if (FLAGS_npu_storage_format) {
+  if (FLAGS_npu_storage_format &&
+      x_dims.size() == 4) {  // TODO(qili93): add 3D support
     AllocNPUTensor<T>(dev_ctx, ACL_FORMAT_NC1HWC0, y);
   } else {
     dev_ctx.template Alloc<T>(y);
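
Note: this same 4-D guard now appears at four allocation sites in the patch. A hypothetical helper (the name `AllocMaybeNPUFormat` is ours, not the patch's) could centralize the policy, assuming only the `AllocNPUTensor`/`Alloc` calls already shown above:

// Hypothetical sketch, not part of this patch: the private NC1HWC0
// storage format is 4-D only, so fall back to the default allocator
// for 3-D/5-D inputs until 3D support lands (see the TODOs in the diff).
template <typename T, typename Context>
void AllocMaybeNPUFormat(const Context& dev_ctx,
                         const phi::DDim& x_dims,
                         phi::DenseTensor* out) {
  if (FLAGS_npu_storage_format && x_dims.size() == 4) {
    AllocNPUTensor<T>(dev_ctx, ACL_FORMAT_NC1HWC0, out);
  } else {
    dev_ctx.template Alloc<T>(out);
  }
}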
@@ -111,7 +112,8 @@ void BatchNormKernel(const Context& dev_ctx,
111112 {{" epsilon" , epsilon}});
112113 runner_infer.Run (stream);
113114 } else {
114- if (FLAGS_npu_storage_format) {
115+ if (FLAGS_npu_storage_format &&
116+ x_dims.size () == 4 ) { // TODO(qili93): add 3D support
115117 AllocNPUTensor<T>(dev_ctx, ACL_FORMAT_NC1HWC0, mean_out);
116118 AllocNPUTensor<T>(dev_ctx, ACL_FORMAT_NC1HWC0, variance_out);
117119 AllocNPUTensor<T>(dev_ctx, ACL_FORMAT_NC1HWC0, saved_mean);
@@ -123,12 +125,16 @@ void BatchNormKernel(const Context& dev_ctx,
       dev_ctx.template Alloc<float>(saved_variance);
     }
 
+    // BN3DTrainingReduce throws an output size mismatch if the output
+    // tensors are in NCHW format; their format must match the input
+    // tensor's format (NCDHW or NDHWC).
     phi::DenseTensorMeta meta = {
-        phi::DataType::FLOAT32, mean_out->dims(), x.layout()};
+        phi::DataType::FLOAT32, mean_out->dims(), x_tensor.layout()};
     phi::DenseTensor sum, square_sum;
     sum.set_meta(meta);
     square_sum.set_meta(meta);
-    if (FLAGS_npu_storage_format) {
+    if (FLAGS_npu_storage_format &&
+        x_dims.size() == 4) {  // TODO(qili93): add 3D support
       AllocNPUTensor<float>(dev_ctx, ACL_FORMAT_NC1HWC0, &sum);
       AllocNPUTensor<float>(dev_ctx, ACL_FORMAT_NC1HWC0, &square_sum);
     } else {
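
For context on the meta change above: the stats tensors must carry the (possibly transformed) input tensor's layout rather than the default, or the 3-D reducer rejects them. A minimal illustrative helper, assuming nothing beyond the phi types used in the diff (the name `MakeBNStatsMeta` is hypothetical):

// Illustrative only: build the FLOAT32 meta for sum/square_sum (and, for
// 5-D inputs, the running/saved stats) from x_tensor so its layout
// (NCDHW or NDHWC) propagates to the reduce/update outputs.
phi::DenseTensorMeta MakeBNStatsMeta(const phi::DenseTensor& x_tensor,
                                     const phi::DDim& stats_dims) {
  return {phi::DataType::FLOAT32, stats_dims, x_tensor.layout()};
}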
@@ -138,19 +144,43 @@ void BatchNormKernel(const Context& dev_ctx,
 
     std::string reduce_name =
         (x.dims().size() == 5) ? "BN3DTrainingReduce" : "BNTrainingReduce";
-    const auto& runner_reduce = NpuOpRunner(
-        reduce_name, {x_tensor}, {sum, square_sum}, {{"epsilon", epsilon}});
-    runner_reduce.Run(stream);
+    NpuOpRunner runner_reduce;
+    runner_reduce.SetType(reduce_name)
+        .AddInput(x_tensor)
+        .AddOutput(sum)
+        .AddOutput(square_sum)
+        .AddAttrs({{"epsilon", epsilon}})
+        .Run(stream);
+
+    // BN3DTrainingUpdate throws an output size mismatch if the output
+    // tensors are in NCHW format; their format must match the input
+    // tensor's format (NCDHW or NDHWC).
+    if (x_dims.size() == 5) {
+      mean_out->set_meta(meta);
+      variance_out->set_meta(meta);
+      saved_mean->set_meta(meta);
+      saved_variance->set_meta(meta);
+    }
 
     std::string update_name =
         (x.dims().size() == 5) ? "BN3DTrainingUpdate" : "BNTrainingUpdate";
-    const auto& runner_update = NpuOpRunner(
-        update_name,
-        {x_tensor, sum, square_sum, scale, bias, running_mean, running_var},
-        {y_tensor, *mean_out, *variance_out, *saved_mean, *saved_variance},
-        {{"factor", static_cast<float>(momentum)},
-         {"epsilon", static_cast<float>(epsilon)}});
-    runner_update.Run(stream);
+    NpuOpRunner runner_update;
+    runner_update.SetType(update_name)
+        .AddInput(x_tensor)
+        .AddInput(sum)
+        .AddInput(square_sum)
+        .AddInput(scale)
+        .AddInput(bias)
+        .AddInput(running_mean)
+        .AddInput(running_var)
+        .AddOutput(y_tensor)
+        .AddOutput(*mean_out)
+        .AddOutput(*variance_out)
+        .AddOutput(*saved_mean)
+        .AddOutput(*saved_variance)
+        .AddAttrs({{"epsilon", static_cast<float>(epsilon)}})
+        .AddAttrs({{"factor", static_cast<float>(momentum)}})
+        .Run(stream);
   }
 }
 
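The reduce/update rewrite above is a mechanical move from the one-shot NpuOpRunner constructor to its builder API; both forms appear in this patch and run the same op, e.g. for the reduce step (`r1`/`r2` are placeholder names):

// Constructor form (removed): inputs, outputs, and attrs supplied at once.
const auto& r1 = NpuOpRunner(
    reduce_name, {x_tensor}, {sum, square_sum}, {{"epsilon", epsilon}});
r1.Run(stream);

// Builder form (introduced): the same op wired incrementally, which keeps
// the long input/output lists of BNTrainingUpdate readable.
NpuOpRunner r2;
r2.SetType(reduce_name)
    .AddInput(x_tensor)
    .AddOutput(sum)
    .AddOutput(square_sum)
    .AddAttrs({{"epsilon", epsilon}})
    .Run(stream);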
@@ -246,7 +276,8 @@ void BatchNormGradKernel(
 
   auto stream = dev_ctx.stream();
   if (d_scale && d_bias) {
-    if (FLAGS_npu_storage_format) {
+    if (FLAGS_npu_storage_format &&
+        x_dims.size() == 4) {  // TODO(qili93): add 3D support
       AllocNPUTensor<float>(dev_ctx, ACL_FORMAT_NC1HWC0, d_scale);
       AllocNPUTensor<float>(dev_ctx, ACL_FORMAT_NC1HWC0, d_bias);
     } else {
@@ -271,7 +302,8 @@ void BatchNormGradKernel(
   }
 
   if (d_x) {
-    if (FLAGS_npu_storage_format) {
+    if (FLAGS_npu_storage_format &&
+        x_dims.size() == 4) {  // TODO(qili93): add 3D support
       AllocNPUTensor<T>(dev_ctx, ACL_FORMAT_NC1HWC0, d_x);
     } else {
       dev_ctx.template Alloc<T>(d_x);
@@ -332,6 +364,9 @@ void BatchNormInferKernel(const Context& dev_ctx,
   const auto& x_dims = x.dims();
   const bool channel_last = data_layout_str == "NHWC" && x_dims.size() > 2;
 
+  VLOG(1) << "0 -- BatchNormInferKernel: Attr <channel_last> = "
+          << channel_last;
+
   PADDLE_ENFORCE_EQ(
       channel_last && FLAGS_npu_storage_format,
       false,
@@ -343,7 +378,8 @@ void BatchNormInferKernel(const Context& dev_ctx,
           data_layout_str,
           FLAGS_npu_storage_format));
 
-  if (FLAGS_npu_storage_format) {
+  if (FLAGS_npu_storage_format &&
+      x_dims.size() == 4) {  // TODO(qili93): add 3D support
     AllocNPUTensor<T>(dev_ctx, ACL_FORMAT_NC1HWC0, y);
   } else {
     dev_ctx.template Alloc<T>(y);