Skip to content

Commit fced85b

Browse files
fix: apply same dimension handling fixes to TrainableJointBilateralFilter
Signed-off-by: Abdoulaye Diallo <abdoulayediallo338@gmail.com>
1 parent 4924aa6 commit fced85b

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

monai/networks/layers/filtering.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,7 @@ def __init__(self, spatial_sigma, color_sigma):
220220
spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]]
221221
self.len_spatial_sigma = 3
222222
else:
223-
raise ValueError(
224-
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
225-
)
223+
raise ValueError(f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims (1, 2 or 3).")
226224

227225
# Register sigmas as trainable parameters.
228226
self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0]))
@@ -393,9 +391,7 @@ def __init__(self, spatial_sigma, color_sigma):
393391
spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]]
394392
self.len_spatial_sigma = 3
395393
else:
396-
raise ValueError(
397-
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
398-
)
394+
raise ValueError(f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims (1, 2, or 3).")
399395

400396
# Register sigmas as trainable parameters.
401397
self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0]))
@@ -404,9 +400,13 @@ def __init__(self, spatial_sigma, color_sigma):
404400
self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma))
405401

406402
def forward(self, input_tensor, guidance_tensor):
403+
if len(input_tensor.shape) < 3:
404+
raise ValueError(
405+
f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
406+
)
407407
if input_tensor.shape[1] != 1:
408408
raise ValueError(
409-
f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
409+
f"Currently channel dimensions > 1 ({input_tensor.shape[1]}) are not supported. "
410410
"Please use multiple parallel filter layers if you want "
411411
"to filter multiple channels."
412412
)
@@ -417,26 +417,27 @@ def forward(self, input_tensor, guidance_tensor):
417417
)
418418

419419
len_input = len(input_tensor.shape)
420+
spatial_dims = len_input - 2
420421

421422
# C++ extension so far only supports 5-dim inputs.
422-
if len_input == 3:
423+
if spatial_dims == 1:
423424
input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
424425
guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4)
425-
elif len_input == 4:
426+
elif spatial_dims == 2:
426427
input_tensor = input_tensor.unsqueeze(4)
427428
guidance_tensor = guidance_tensor.unsqueeze(4)
428429

429-
if self.len_spatial_sigma != len_input:
430-
raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).")
430+
if self.len_spatial_sigma != spatial_dims:
431+
raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).")
431432

432433
prediction = TrainableJointBilateralFilterFunction.apply(
433434
input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
434435
)
435436

436437
# Make sure to return tensor of the same shape as the input.
437-
if len_input == 3:
438+
if spatial_dims == 1:
438439
prediction = prediction.squeeze(4).squeeze(3)
439-
elif len_input == 4:
440+
elif spatial_dims == 2:
440441
prediction = prediction.squeeze(4)
441442

442443
return prediction

0 commit comments

Comments
 (0)