Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion monai/losses/cldice.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def __init__(
Args:
iter_: Number of iterations for skeletonization. Must be a non-negative integer. Defaults to 3.
smooth_nr: a small constant added to the numerator to avoid zero. Defaults to 1.0.
smooth_dr: a small constant added to the denominator to avoid nan. Defaults to 1.0.
smooth_dr: a small constant added to the denominator of the individual precision /
sensitivity ratios and the internal Dice denominator to avoid nan. Defaults to 1.0.
smooth: a small constant added to the denominator of the harmonic mean to avoid nan. Defaults to 1e-4.
include_background: if False, channel index 0 (background category) is excluded from the calculation.
if the non-background segmentations are small compared to the total image size they can get overwhelmed
Expand Down
46 changes: 46 additions & 0 deletions tests/losses/test_cldice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,29 @@ def test_invalid_iter_value(self):
with self.assertRaises(ValueError):
SoftclDiceLoss(iter_=-1)

def test_zero_input_is_finite(self):
loss = SoftclDiceLoss(smooth=1e-7, smooth_dr=1e-5)
result = loss(torch.zeros((1, 2, 4, 4)), torch.zeros((1, 2, 4, 4)))
self.assertTrue(torch.isfinite(result).all())

def test_non_default_smooth_dr_changes_result(self):
input_tensor = torch.zeros((1, 2, 4, 4))
target = torch.zeros((1, 2, 4, 4))
loss_a = SoftclDiceLoss(smooth=1e-7, smooth_dr=1e-3)
loss_b = SoftclDiceLoss(smooth=1e-7, smooth_dr=1e-5)
result_a = loss_a(input_tensor, target)
result_b = loss_b(input_tensor, target)
self.assertTrue(torch.isfinite(result_a).all())
self.assertTrue(torch.isfinite(result_b).all())
self.assertNotAlmostEqual(result_a.item(), result_b.item(), places=5)

Comment thread
coderabbitai[bot] marked this conversation as resolved.
def test_non_overlapping_input_is_finite(self):
loss = SoftclDiceLoss(smooth=1e-7, smooth_dr=1e-5)
input_tensor = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]])
target = torch.tensor([[[[0.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]]])
result = loss(input_tensor, target)
self.assertTrue(torch.isfinite(result).all())


class TestSoftDiceclDiceLoss(unittest.TestCase):
@parameterized.expand(COMBINED_CASES)
Expand Down Expand Up @@ -146,6 +169,29 @@ def test_invalid_alpha_negative(self):
with self.assertRaises(ValueError):
SoftDiceclDiceLoss(alpha=-0.5)

def test_zero_input_is_finite(self):
loss = SoftDiceclDiceLoss(smooth=1e-7, smooth_dr=1e-5)
result = loss(torch.zeros((1, 2, 4, 4)), torch.zeros((1, 2, 4, 4)))
self.assertTrue(torch.isfinite(result).all())

def test_non_default_smooth_dr_changes_result(self):
input_tensor = torch.zeros((1, 2, 4, 4))
target = torch.zeros((1, 2, 4, 4))
loss_a = SoftDiceclDiceLoss(smooth=1e-7, smooth_dr=1e-3)
loss_b = SoftDiceclDiceLoss(smooth=1e-7, smooth_dr=1e-5)
result_a = loss_a(input_tensor, target)
result_b = loss_b(input_tensor, target)
self.assertTrue(torch.isfinite(result_a).all())
self.assertTrue(torch.isfinite(result_b).all())
self.assertNotAlmostEqual(result_a.item(), result_b.item(), places=5)

def test_non_overlapping_input_is_finite(self):
loss = SoftDiceclDiceLoss(smooth=1e-7, smooth_dr=1e-5)
input_tensor = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]])
target = torch.tensor([[[[0.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]]])
result = loss(input_tensor, target)
self.assertTrue(torch.isfinite(result).all())


if __name__ == "__main__":
unittest.main()
Loading