@@ -220,9 +220,7 @@ def __init__(self, spatial_sigma, color_sigma):
220220 spatial_sigma = [spatial_sigma [0 ], spatial_sigma [1 ], spatial_sigma [2 ]]
221221 self .len_spatial_sigma = 3
222222 else :
223- raise ValueError (
224- f"len(spatial_sigma) { spatial_sigma } must match number of spatial dims { self .len_spatial_sigma } ."
225- )
223+ raise ValueError (f"len(spatial_sigma) { spatial_sigma } must match number of spatial dims (1, 2 or 3)." )
226224
227225 # Register sigmas as trainable parameters.
228226 self .sigma_x = torch .nn .Parameter (torch .tensor (spatial_sigma [0 ]))
@@ -393,9 +391,7 @@ def __init__(self, spatial_sigma, color_sigma):
393391 spatial_sigma = [spatial_sigma [0 ], spatial_sigma [1 ], spatial_sigma [2 ]]
394392 self .len_spatial_sigma = 3
395393 else :
396- raise ValueError (
397- f"len(spatial_sigma) { spatial_sigma } must match number of spatial dims { self .len_spatial_sigma } ."
398- )
394+ raise ValueError (f"len(spatial_sigma) { spatial_sigma } must match number of spatial dims (1, 2, or 3)." )
399395
400396 # Register sigmas as trainable parameters.
401397 self .sigma_x = torch .nn .Parameter (torch .tensor (spatial_sigma [0 ]))
@@ -404,9 +400,13 @@ def __init__(self, spatial_sigma, color_sigma):
404400 self .sigma_color = torch .nn .Parameter (torch .tensor (color_sigma ))
405401
406402 def forward (self , input_tensor , guidance_tensor ):
403+ if len (input_tensor .shape ) < 3 :
404+ raise ValueError (
405+ f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got { len (input_tensor .shape )} "
406+ )
407407 if input_tensor .shape [1 ] != 1 :
408408 raise ValueError (
409- f"Currently channel dimensions >1 ({ input_tensor .shape [1 ]} ) are not supported. "
409+ f"Currently channel dimensions > 1 ({ input_tensor .shape [1 ]} ) are not supported. "
410410 "Please use multiple parallel filter layers if you want "
411411 "to filter multiple channels."
412412 )
@@ -417,26 +417,27 @@ def forward(self, input_tensor, guidance_tensor):
417417 )
418418
419419 len_input = len (input_tensor .shape )
420+ spatial_dims = len_input - 2
420421
421422 # C++ extension so far only supports 5-dim inputs.
422- if len_input == 3 :
423+ if spatial_dims == 1 :
423424 input_tensor = input_tensor .unsqueeze (3 ).unsqueeze (4 )
424425 guidance_tensor = guidance_tensor .unsqueeze (3 ).unsqueeze (4 )
425- elif len_input == 4 :
426+ elif spatial_dims == 2 :
426427 input_tensor = input_tensor .unsqueeze (4 )
427428 guidance_tensor = guidance_tensor .unsqueeze (4 )
428429
429- if self .len_spatial_sigma != len_input :
430- raise ValueError (f"Spatial dimension ({ len_input } ) must match initialized len(spatial_sigma)." )
430+ if self .len_spatial_sigma != spatial_dims :
431+ raise ValueError (f"Spatial dimension ({ spatial_dims } ) must match initialized len(spatial_sigma)." )
431432
432433 prediction = TrainableJointBilateralFilterFunction .apply (
433434 input_tensor , guidance_tensor , self .sigma_x , self .sigma_y , self .sigma_z , self .sigma_color
434435 )
435436
436437 # Make sure to return tensor of the same shape as the input.
437- if len_input == 3 :
438+ if spatial_dims == 1 :
438439 prediction = prediction .squeeze (4 ).squeeze (3 )
439- elif len_input == 4 :
440+ elif spatial_dims == 2 :
440441 prediction = prediction .squeeze (4 )
441442
442443 return prediction
0 commit comments