fix undefined type error
Signed-off-by: ytl0623 <david89062388@gmail.com>
ytl0623 committed Dec 19, 2025
commit 1b2483441dccbc0ee117dfdec15f4663bd4a8e73
16 changes: 12 additions & 4 deletions monai/losses/focal_loss.py
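
The "undefined type error" being fixed is, by all appearances, a mypy complaint: once alpha can be converted to a tensor, self.alpha holds either a float, a torch.Tensor, or None, so the attribute needs an explicit union annotation before the branches assign it. A minimal sketch of the annotate-then-branch pattern this commit applies (the Demo class below is a hypothetical stand-in, not part of the diff):

    # Sketch of the typing pattern used in this commit; "Demo" is illustrative.
    import torch

    class Demo:
        def __init__(self, alpha: float | list[float] | None) -> None:
            # Declare the union up front so mypy accepts both branches below.
            self.alpha: float | torch.Tensor | None
            if isinstance(alpha, (list, tuple)):
                self.alpha = torch.tensor(alpha)  # per-class weights -> tensor
            else:
                self.alpha = alpha  # scalar or None passes through unchanged
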
Original file line number Diff line number Diff line change
@@ -112,12 +112,17 @@ def __init__(
         self.include_background = include_background
         self.to_onehot_y = to_onehot_y
         self.gamma = gamma
-        self.alpha = alpha
         self.weight = weight
         self.use_softmax = use_softmax
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
+        self.alpha: float | torch.Tensor | None
+
+        if isinstance(alpha, (list, tuple)):
+            self.alpha = torch.tensor(alpha)
+        else:
+            self.alpha = alpha

     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -159,7 +164,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         input = input.float()
         target = target.float()

-        alpha_arg = self.alpha
+        alpha_arg: float | torch.Tensor | None = self.alpha
+        if isinstance(alpha_arg, torch.Tensor):
+            alpha_arg = alpha_arg.to(input.device)
+
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
                 if isinstance(self.alpha, (float, int)):
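
Moving alpha_arg to input.device matters because torch.tensor(alpha) in __init__ creates the tensor on the CPU, while input may already live on a GPU by the time forward runs. A sketch of the mismatch this guard prevents (assumes a CUDA device is available; names are illustrative):

    import torch

    alpha = torch.tensor([0.25, 0.5, 0.25])          # created on CPU, as in __init__
    logits = torch.randn(2, 3, 8, 8, device="cuda")

    # Without the guard, any op mixing the two tensors raises a RuntimeError
    # ("Expected all tensors to be on the same device ...").
    alpha = alpha.to(logits.device)                  # what forward() now does
    weighted = alpha.view(1, -1, 1, 1) * logits      # per-class broadcast, one device
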
@@ -208,7 +216,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:


 def softmax_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -241,7 +249,7 @@ def softmax_focal_loss(


 def sigmoid_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
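
Both functional helpers now advertise torch.Tensor rather than Sequence[float] for alpha, matching what FocalLoss.forward actually passes down after the conversion in __init__. A hedged sketch of calling them directly under the new signatures (it assumes a length-C alpha tensor is broadcast over the class dimension, as the class-based path suggests):

    import torch
    import torch.nn.functional as F
    from monai.losses.focal_loss import sigmoid_focal_loss, softmax_focal_loss

    logits = torch.randn(2, 3, 8, 8)
    target = F.one_hot(torch.randint(0, 3, (2, 8, 8)), num_classes=3)
    target = target.permute(0, 3, 1, 2).float()      # (B, C, H, W) one-hot

    fl_soft = softmax_focal_loss(logits, target, gamma=2.0,
                                 alpha=torch.tensor([0.25, 0.5, 0.25]))
    fl_sig = sigmoid_focal_loss(logits, target, gamma=2.0, alpha=0.25)
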