Skip to content

Commit 409bcd3

Browse files
authored
Add backward gradient for background (#170)
* bg grad * delete lines for debug * . * only compute bg grad when needed * reformat
1 parent 2ddee2a commit 409bcd3

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

gsplat/rasterize.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def forward(
103103
img_height: int,
104104
img_width: int,
105105
block_width: int,
106-
background: Optional[Float[Tensor, "channels"]] = None,
106+
background: Float[Tensor, "channels"],
107107
return_alpha: Optional[bool] = False,
108108
) -> Tensor:
109109
num_points = xys.size(0)
@@ -232,6 +232,11 @@ def backward(ctx, v_out_img, v_out_alpha=None):
232232
v_out_img,
233233
v_out_alpha,
234234
)
235+
v_background = None
236+
if background.requires_grad:
237+
v_background = torch.matmul(
238+
v_out_img.float().view(-1, 3).t(), final_Ts.float().view(-1, 1)
239+
).squeeze()
235240

236241
# Abs grad for gaussian splitting criterion. See
237242
# - "AbsGS: Recovering Fine Details for 3D Gaussian Splatting"
@@ -249,6 +254,6 @@ def backward(ctx, v_out_img, v_out_alpha=None):
249254
None, # img_height
250255
None, # img_width
251256
None, # block_width
252-
None, # background
257+
v_background, # background
253258
None, # return_alpha
254259
)

0 commit comments

Comments
 (0)