Skip to content

Commit 2ddee2a

Browse files
liruilong940607, Ruilong Li, Jianbo Ye, kerrj
authored
Implement AbsGS (#166)
* absgrad
* add version bump

---------

Co-authored-by: Ruilong Li <397653553@qq.com>
Co-authored-by: Jianbo Ye <jianboye@amazon.com>
Co-authored-by: Justin Kerr <justin.g.kerr@gmail.com>
1 parent 8a19034 commit 2ddee2a

File tree

6 files changed

+36
-4
lines changed

6 files changed

+36
-4
lines changed

gsplat/cuda/csrc/backward.cu

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ __global__ void nd_rasterize_backward_kernel(
3535
const float* __restrict__ v_output,
3636
const float* __restrict__ v_output_alpha,
3737
float2* __restrict__ v_xy,
38+
float2* __restrict__ v_xy_abs,
3839
float3* __restrict__ v_conic,
3940
float* __restrict__ v_rgb,
4041
float* __restrict__ v_opacity
@@ -90,6 +91,7 @@ __global__ void nd_rasterize_backward_kernel(
9091
float v_alpha = 0.f;
9192
float3 v_conic_local = {0.f, 0.f, 0.f};
9293
float2 v_xy_local = {0.f, 0.f};
94+
float2 v_xy_abs_local = {0.f, 0.f};
9395
float v_opacity_local = 0.f;
9496
if(valid){
9597
// compute the current T for this gaussian
@@ -114,19 +116,24 @@ __global__ void nd_rasterize_backward_kernel(
114116
0.5f * v_sigma * delta.y * delta.y};
115117
v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y),
116118
v_sigma * (conic.y * delta.x + conic.z * delta.y)};
119+
v_xy_abs_local = {abs(v_xy_local.x), abs(v_xy_local.y)};
117120
v_opacity_local = vis * v_alpha;
118121
}
119122
warpSum3(v_conic_local, warp);
120123
warpSum2(v_xy_local, warp);
124+
warpSum2(v_xy_abs_local, warp);
121125
warpSum(v_opacity_local, warp);
122126
if (warp.thread_rank() == 0) {
123127
float* v_conic_ptr = (float*)(v_conic);
124128
float* v_xy_ptr = (float*)(v_xy);
129+
float* v_xy_abs_ptr = (float*)(v_xy_abs);
125130
atomicAdd(v_conic_ptr + 3*g + 0, v_conic_local.x);
126131
atomicAdd(v_conic_ptr + 3*g + 1, v_conic_local.y);
127132
atomicAdd(v_conic_ptr + 3*g + 2, v_conic_local.z);
128133
atomicAdd(v_xy_ptr + 2*g + 0, v_xy_local.x);
129134
atomicAdd(v_xy_ptr + 2*g + 1, v_xy_local.y);
135+
atomicAdd(v_xy_abs_ptr + 2*g + 0, v_xy_abs_local.x);
136+
atomicAdd(v_xy_abs_ptr + 2*g + 1, v_xy_abs_local.y);
130137
atomicAdd(v_opacity + g, v_opacity_local);
131138
}
132139
}
@@ -147,6 +154,7 @@ __global__ void rasterize_backward_kernel(
147154
const float3* __restrict__ v_output,
148155
const float* __restrict__ v_output_alpha,
149156
float2* __restrict__ v_xy,
157+
float2* __restrict__ v_xy_abs,
150158
float3* __restrict__ v_conic,
151159
float3* __restrict__ v_rgb,
152160
float* __restrict__ v_opacity
@@ -251,6 +259,7 @@ __global__ void rasterize_backward_kernel(
251259
float3 v_rgb_local = {0.f, 0.f, 0.f};
252260
float3 v_conic_local = {0.f, 0.f, 0.f};
253261
float2 v_xy_local = {0.f, 0.f};
262+
float2 v_xy_abs_local = {0.f, 0.f};
254263
float v_opacity_local = 0.f;
255264
//initialize everything to 0, only set if the lane is valid
256265
if(valid){
@@ -284,11 +293,13 @@ __global__ void rasterize_backward_kernel(
284293
0.5f * v_sigma * delta.y * delta.y};
285294
v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y),
286295
v_sigma * (conic.y * delta.x + conic.z * delta.y)};
296+
v_xy_abs_local = {abs(v_xy_local.x), abs(v_xy_local.y)};
287297
v_opacity_local = vis * v_alpha;
288298
}
289299
warpSum3(v_rgb_local, warp);
290300
warpSum3(v_conic_local, warp);
291301
warpSum2(v_xy_local, warp);
302+
warpSum2(v_xy_abs_local, warp);
292303
warpSum(v_opacity_local, warp);
293304
if (warp.thread_rank() == 0) {
294305
int32_t g = id_batch[t];
@@ -305,6 +316,10 @@ __global__ void rasterize_backward_kernel(
305316
float* v_xy_ptr = (float*)(v_xy);
306317
atomicAdd(v_xy_ptr + 2*g + 0, v_xy_local.x);
307318
atomicAdd(v_xy_ptr + 2*g + 1, v_xy_local.y);
319+
320+
float* v_xy_abs_ptr = (float*)(v_xy_abs);
321+
atomicAdd(v_xy_abs_ptr + 2*g + 0, v_xy_abs_local.x);
322+
atomicAdd(v_xy_abs_ptr + 2*g + 1, v_xy_abs_local.y);
308323

309324
atomicAdd(v_opacity + g, v_opacity_local);
310325
}

gsplat/cuda/csrc/backward.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ __global__ void nd_rasterize_backward_kernel(
4646
const float* __restrict__ v_output,
4747
const float* __restrict__ v_output_alpha,
4848
float2* __restrict__ v_xy,
49+
float2* __restrict__ v_xy_abs,
4950
float3* __restrict__ v_conic,
5051
float* __restrict__ v_rgb,
5152
float* __restrict__ v_opacity
@@ -66,6 +67,7 @@ __global__ void rasterize_backward_kernel(
6667
const float3* __restrict__ v_output,
6768
const float* __restrict__ v_output_alpha,
6869
float2* __restrict__ v_xy,
70+
float2* __restrict__ v_xy_abs,
6971
float3* __restrict__ v_conic,
7072
float3* __restrict__ v_rgb,
7173
float* __restrict__ v_opacity

gsplat/cuda/csrc/bindings.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,7 @@ nd_rasterize_forward_tensor(
525525
std::
526526
tuple<
527527
torch::Tensor, // dL_dxy
528+
torch::Tensor, // dL_dxy_abs
528529
torch::Tensor, // dL_dconic
529530
torch::Tensor, // dL_dcolors
530531
torch::Tensor // dL_dopacity
@@ -568,6 +569,7 @@ std::
568569
const int channels = colors.size(1);
569570

570571
torch::Tensor v_xy = torch::zeros({num_points, 2}, xys.options());
572+
torch::Tensor v_xy_abs = torch::zeros({num_points, 2}, xys.options());
571573
torch::Tensor v_conic = torch::zeros({num_points, 3}, xys.options());
572574
torch::Tensor v_colors =
573575
torch::zeros({num_points, channels}, xys.options());
@@ -595,17 +597,19 @@ std::
595597
v_output.contiguous().data_ptr<float>(),
596598
v_output_alpha.contiguous().data_ptr<float>(),
597599
(float2 *)v_xy.contiguous().data_ptr<float>(),
600+
(float2 *)v_xy_abs.contiguous().data_ptr<float>(),
598601
(float3 *)v_conic.contiguous().data_ptr<float>(),
599602
v_colors.contiguous().data_ptr<float>(),
600603
v_opacity.contiguous().data_ptr<float>()
601604
);
602605

603-
return std::make_tuple(v_xy, v_conic, v_colors, v_opacity);
606+
return std::make_tuple(v_xy, v_xy_abs, v_conic, v_colors, v_opacity);
604607
}
605608

606609
std::
607610
tuple<
608611
torch::Tensor, // dL_dxy
612+
torch::Tensor, // dL_dxy_abs
609613
torch::Tensor, // dL_dconic
610614
torch::Tensor, // dL_dcolors
611615
torch::Tensor // dL_dopacity
@@ -649,6 +653,7 @@ std::
649653
const int channels = colors.size(1);
650654

651655
torch::Tensor v_xy = torch::zeros({num_points, 2}, xys.options());
656+
torch::Tensor v_xy_abs = torch::zeros({num_points, 2}, xys.options());
652657
torch::Tensor v_conic = torch::zeros({num_points, 3}, xys.options());
653658
torch::Tensor v_colors =
654659
torch::zeros({num_points, channels}, xys.options());
@@ -669,10 +674,11 @@ std::
669674
(float3 *)v_output.contiguous().data_ptr<float>(),
670675
v_output_alpha.contiguous().data_ptr<float>(),
671676
(float2 *)v_xy.contiguous().data_ptr<float>(),
677+
(float2 *)v_xy_abs.contiguous().data_ptr<float>(),
672678
(float3 *)v_conic.contiguous().data_ptr<float>(),
673679
(float3 *)v_colors.contiguous().data_ptr<float>(),
674680
v_opacity.contiguous().data_ptr<float>()
675681
);
676682

677-
return std::make_tuple(v_xy, v_conic, v_colors, v_opacity);
683+
return std::make_tuple(v_xy, v_xy_abs, v_conic, v_colors, v_opacity);
678684
}

gsplat/cuda/csrc/bindings.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ std::tuple<
149149
std::
150150
tuple<
151151
torch::Tensor, // dL_dxy
152+
torch::Tensor, // dL_dxy_abs
152153
torch::Tensor, // dL_dconic
153154
torch::Tensor, // dL_dcolors
154155
torch::Tensor // dL_dopacity
@@ -173,6 +174,7 @@ std::
173174
std::
174175
tuple<
175176
torch::Tensor, // dL_dxy
177+
torch::Tensor, // dL_dxy_abs
176178
torch::Tensor, // dL_dconic
177179
torch::Tensor, // dL_dcolors
178180
torch::Tensor // dL_dopacity

gsplat/rasterize.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torch.autograd import Function
99

1010
import gsplat.cuda as _C
11+
1112
from .utils import bin_and_sort_gaussians, compute_cumulative_intersects
1213

1314

@@ -205,6 +206,7 @@ def backward(ctx, v_out_img, v_out_alpha=None):
205206

206207
if num_intersects < 1:
207208
v_xy = torch.zeros_like(xys)
209+
v_xy_abs = torch.zeros_like(xys)
208210
v_conic = torch.zeros_like(conics)
209211
v_colors = torch.zeros_like(colors)
210212
v_opacity = torch.zeros_like(opacity)
@@ -214,7 +216,7 @@ def backward(ctx, v_out_img, v_out_alpha=None):
214216
rasterize_fn = _C.rasterize_backward
215217
else:
216218
rasterize_fn = _C.nd_rasterize_backward
217-
v_xy, v_conic, v_colors, v_opacity = rasterize_fn(
219+
v_xy, v_xy_abs, v_conic, v_colors, v_opacity = rasterize_fn(
218220
img_height,
219221
img_width,
220222
ctx.block_width,
@@ -231,6 +233,11 @@ def backward(ctx, v_out_img, v_out_alpha=None):
231233
v_out_alpha,
232234
)
233235

236+
# Abs grad for gaussian splitting criterion. See
237+
# - "AbsGS: Recovering Fine Details for 3D Gaussian Splatting"
238+
# - "EfficientGS: Streamlining Gaussian Splatting for Large-Scale High-Resolution Scene Representation"
239+
xys.absgrad = v_xy_abs
240+
234241
return (
235242
v_xy, # xys
236243
None, # depths

gsplat/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.1.10"
1+
__version__ = "0.1.11"

0 commit comments

Comments
 (0)