Skip to content

Commit 1b80cfa

Browse files
Update monai/networks/layers/filtering.py
Signed-off-by: Abdoulaye Diallo <abdoulayediallo338@gmail.com>
1 parent aea9bcf commit 1b80cfa

1 file changed

Lines changed: 15 additions & 6 deletions

File tree

monai/networks/layers/filtering.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,10 @@ def __init__(self, spatial_sigma, color_sigma):
220220
spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]]
221221
self.len_spatial_sigma = 3
222222
else:
223-
raise ValueError(f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3) or be a single float value ({spatial_sigma=}).")
223+
raise ValueError(
224+
f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3)"
225+
f"or be a single float value ({spatial_sigma=})."
226+
)
224227

225228
# Register sigmas as trainable parameters.
226229
self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0]))
@@ -250,7 +253,9 @@ def forward(self, input_tensor):
250253
input_tensor = input_tensor.unsqueeze(4)
251254

252255
if self.len_spatial_sigma != spatial_dims:
253-
raise ValueError(f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`.")
256+
raise ValueError(
257+
f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`."
258+
)
254259

255260
prediction = TrainableBilateralFilterFunction.apply(
256261
input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
@@ -391,7 +396,10 @@ def __init__(self, spatial_sigma, color_sigma):
391396
spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]]
392397
self.len_spatial_sigma = 3
393398
else:
394-
raise ValueError(f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3) or be a single float value ({spatial_sigma=}).")
399+
raise ValueError(
400+
f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3)\n"
401+
f"or be a single float value ({spatial_sigma=})."
402+
)
395403

396404
# Register sigmas as trainable parameters.
397405
self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0]))
@@ -412,8 +420,7 @@ def forward(self, input_tensor, guidance_tensor):
412420
)
413421
if input_tensor.shape != guidance_tensor.shape:
414422
raise ValueError(
415-
"Shape of input image must equal shape of guidance image."
416-
f"Got {input_tensor.shape} and {guidance_tensor.shape}."
423+
f"Shape of input image must equal shape of guidance image.Got {input_tensor.shape} and {guidance_tensor.shape}."
417424
)
418425

419426
len_input = len(input_tensor.shape)
@@ -428,7 +435,9 @@ def forward(self, input_tensor, guidance_tensor):
428435
guidance_tensor = guidance_tensor.unsqueeze(4)
429436

430437
if self.len_spatial_sigma != spatial_dims:
431-
raise ValueError(f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`.")
438+
raise ValueError(
439+
f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`."
440+
)
432441

433442
prediction = TrainableJointBilateralFilterFunction.apply(
434443
input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color

0 commit comments

Comments
 (0)