@@ -35,6 +35,7 @@ __global__ void nd_rasterize_backward_kernel(
3535 const float * __restrict__ v_output,
3636 const float * __restrict__ v_output_alpha,
3737 float2 * __restrict__ v_xy,
38+ float2 * __restrict__ v_xy_abs,
3839 float3 * __restrict__ v_conic,
3940 float * __restrict__ v_rgb,
4041 float * __restrict__ v_opacity
@@ -90,6 +91,7 @@ __global__ void nd_rasterize_backward_kernel(
9091 float v_alpha = 0 .f ;
9192 float3 v_conic_local = {0 .f , 0 .f , 0 .f };
9293 float2 v_xy_local = {0 .f , 0 .f };
94+ float2 v_xy_abs_local = {0 .f , 0 .f };
9395 float v_opacity_local = 0 .f ;
9496 if (valid){
9597 // compute the current T for this gaussian
@@ -114,19 +116,24 @@ __global__ void nd_rasterize_backward_kernel(
114116 0 .5f * v_sigma * delta.y * delta.y };
115117 v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y ),
116118 v_sigma * (conic.y * delta.x + conic.z * delta.y )};
119+ v_xy_abs_local = {abs (v_xy_local.x ), abs (v_xy_local.y )};
117120 v_opacity_local = vis * v_alpha;
118121 }
119122 warpSum3 (v_conic_local, warp);
120123 warpSum2 (v_xy_local, warp);
124+ warpSum2 (v_xy_abs_local, warp);
121125 warpSum (v_opacity_local, warp);
122126 if (warp.thread_rank () == 0 ) {
123127 float * v_conic_ptr = (float *)(v_conic);
124128 float * v_xy_ptr = (float *)(v_xy);
129+ float * v_xy_abs_ptr = (float *)(v_xy_abs);
125130 atomicAdd (v_conic_ptr + 3 *g + 0 , v_conic_local.x );
126131 atomicAdd (v_conic_ptr + 3 *g + 1 , v_conic_local.y );
127132 atomicAdd (v_conic_ptr + 3 *g + 2 , v_conic_local.z );
128133 atomicAdd (v_xy_ptr + 2 *g + 0 , v_xy_local.x );
129134 atomicAdd (v_xy_ptr + 2 *g + 1 , v_xy_local.y );
135+ atomicAdd (v_xy_abs_ptr + 2 *g + 0 , v_xy_abs_local.x );
136+ atomicAdd (v_xy_abs_ptr + 2 *g + 1 , v_xy_abs_local.y );
130137 atomicAdd (v_opacity + g, v_opacity_local);
131138 }
132139 }
@@ -147,6 +154,7 @@ __global__ void rasterize_backward_kernel(
147154 const float3 * __restrict__ v_output,
148155 const float * __restrict__ v_output_alpha,
149156 float2 * __restrict__ v_xy,
157+ float2 * __restrict__ v_xy_abs,
150158 float3 * __restrict__ v_conic,
151159 float3 * __restrict__ v_rgb,
152160 float * __restrict__ v_opacity
@@ -251,6 +259,7 @@ __global__ void rasterize_backward_kernel(
251259 float3 v_rgb_local = {0 .f , 0 .f , 0 .f };
252260 float3 v_conic_local = {0 .f , 0 .f , 0 .f };
253261 float2 v_xy_local = {0 .f , 0 .f };
262+ float2 v_xy_abs_local = {0 .f , 0 .f };
254263 float v_opacity_local = 0 .f ;
255264 // initialize everything to 0, only set if the lane is valid
256265 if (valid){
@@ -284,11 +293,13 @@ __global__ void rasterize_backward_kernel(
284293 0 .5f * v_sigma * delta.y * delta.y };
285294 v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y ),
286295 v_sigma * (conic.y * delta.x + conic.z * delta.y )};
296+ v_xy_abs_local = {abs (v_xy_local.x ), abs (v_xy_local.y )};
287297 v_opacity_local = vis * v_alpha;
288298 }
289299 warpSum3 (v_rgb_local, warp);
290300 warpSum3 (v_conic_local, warp);
291301 warpSum2 (v_xy_local, warp);
302+ warpSum2 (v_xy_abs_local, warp);
292303 warpSum (v_opacity_local, warp);
293304 if (warp.thread_rank () == 0 ) {
294305 int32_t g = id_batch[t];
@@ -305,6 +316,10 @@ __global__ void rasterize_backward_kernel(
305316 float * v_xy_ptr = (float *)(v_xy);
306317 atomicAdd (v_xy_ptr + 2 *g + 0 , v_xy_local.x );
307318 atomicAdd (v_xy_ptr + 2 *g + 1 , v_xy_local.y );
319+
320+ float * v_xy_abs_ptr = (float *)(v_xy_abs);
321+ atomicAdd (v_xy_abs_ptr + 2 *g + 0 , v_xy_abs_local.x );
322+ atomicAdd (v_xy_abs_ptr + 2 *g + 1 , v_xy_abs_local.y );
308323
309324 atomicAdd (v_opacity + g, v_opacity_local);
310325 }
0 commit comments