Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 112 additions & 14 deletions monai/metrics/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
import numpy as np
import torch

from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing
from monai.metrics.utils import (
compute_voronoi_regions_fast,
do_metric_reduction,
get_edge_surface_distance,
ignore_background,
prepare_spacing,
)
from monai.utils import MetricReduction, convert_data_type

from .metric import CumulativeIterationMetric
Expand All @@ -37,6 +43,18 @@ class HausdorffDistanceMetric(CumulativeIterationMetric):

Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

The ``per_component=True`` approach computes the Hausdorff distance on a per-connected component basis in the ground
truth segmentation. This ensures that each component contributes equally to the final metric, regardless of its size.
Traditional Hausdorff distance can be dominated by large structures, but the per-component method gives a more
balanced evaluation, particularly for small or fragmented objects. This provides a granular assessment of segmentation
quality, which is especially important in cases with multiple disconnected foreground components.
Note:
- The input prediction (`y_pred`) and ground truth (`y`) must both have 2 channels (foreground/background),
with binary segmentation (0 for background, 1 for foreground). That is, this assumes the shape of both prediction
and ground truth is B2HW[D].
- This method cannot be used with multiclass segmentation.
For more information, refer to the original paper: https://arxiv.org/abs/2410.18684
Comment thread
coderabbitai[bot] marked this conversation as resolved.

Args:
include_background: whether to include distance computation on the first channel of
the predicted output. Defaults to ``False``.
Expand All @@ -51,7 +69,7 @@ class HausdorffDistanceMetric(CumulativeIterationMetric):
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.

per_component: whether to compute the Hausdorff distance on a per-connected component basis. Defaults to ``False``.
"""

def __init__(
Expand All @@ -62,6 +80,7 @@ def __init__(
directed: bool = False,
reduction: MetricReduction | str = MetricReduction.MEAN,
get_not_nans: bool = False,
per_component: bool = False,
) -> None:
super().__init__()
self.include_background = include_background
Expand All @@ -70,6 +89,7 @@ def __init__(
self.directed = directed
self.reduction = reduction
self.get_not_nans = get_not_nans
self.per_component = per_component

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override]
"""
Expand All @@ -96,7 +116,17 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
dims = y_pred.ndimension()
if dims < 3:
raise ValueError("y_pred should have at least three dimensions.")

if self.per_component:
if y_pred.ndim not in (4, 5) or y.ndim not in (4, 5) or y_pred.shape[1] != 2 or y.shape[1] != 2:
same_rank = y_pred.ndim == y.ndim and y_pred.ndim in (4, 5)
binary_channels = y_pred.shape[1] == 2 and y.shape[1] == 2
same_shape = y_pred.shape == y.shape
if not (same_rank and binary_channels and same_shape):
raise ValueError(
"per_component requires matching 4D/5D binary tensors "
"(B, 2, H, W) or (B, 2, D, H, W). "
f"Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}."
)
# compute (BxC) for each channel for each batch
return compute_hausdorff_distance(
y_pred=y_pred,
Expand All @@ -106,6 +136,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
percentile=self.percentile,
directed=self.directed,
spacing=kwargs.get("spacing"),
per_component=self.per_component,
)

def aggregate(
Expand Down Expand Up @@ -137,6 +168,7 @@ def compute_hausdorff_distance(
percentile: float | None = None,
directed: bool = False,
spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None,
per_component: bool = False,
) -> torch.Tensor:
"""
Compute the Hausdorff distance.
Expand All @@ -162,6 +194,7 @@ def compute_hausdorff_distance(
If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
for all images in batch. Defaults to ``None``.
per_component: whether to compute the Hausdorff distance on a per-connected component basis. Defaults to ``False``.
"""

if not include_background:
Expand All @@ -179,17 +212,82 @@ def compute_hausdorff_distance(
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)

for b, c in np.ndindex(batch_size, n_class):
_, distances, _ = get_edge_surface_distance(
y_pred[b, c],
y[b, c],
distance_metric=distance_metric,
spacing=spacing_list[b],
symmetric=not directed,
class_index=c,
)
percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances]
max_distance = torch.max(torch.stack(percentile_distances))
hd[b, c] = max_distance
if per_component:
pred_empty = y_pred[b, c].sum() == 0
label_empty = y[b, c].sum() == 0
if pred_empty and label_empty:
hd[b, c] = 0.0 if (pred_empty and label_empty) else float("nan")
continue
Comment thread
VijayVignesh1 marked this conversation as resolved.
cc_assignment = compute_voronoi_regions_fast(y[b, c].cpu().numpy())
if cc_assignment.device != y_pred[b, c].device:
cc_assignment = cc_assignment.to(y_pred[b, c].device)
component_scores = []
for cc_id in torch.unique(cc_assignment.view(-1)):
cc_mask = cc_assignment == cc_id

coords = torch.nonzero(cc_mask, as_tuple=False)
min_corner_idx = coords.min(dim=0).values
max_corner_idx = coords.max(dim=0).values

crop_pred = (
y_pred[b, c][
min_corner_idx[0] : max_corner_idx[0] + 1,
min_corner_idx[1] : max_corner_idx[1] + 1,
min_corner_idx[2] : max_corner_idx[2] + 1,
]
if y_pred.ndim == 5
else y_pred[b, c][
min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1
]
)

crop_label = (
y[b, c][
min_corner_idx[0] : max_corner_idx[0] + 1,
min_corner_idx[1] : max_corner_idx[1] + 1,
min_corner_idx[2] : max_corner_idx[2] + 1,
]
if y.ndim == 5
else y[b, c][min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1]
)

cc_crop_mask = (
cc_mask[
min_corner_idx[0] : max_corner_idx[0] + 1,
min_corner_idx[1] : max_corner_idx[1] + 1,
min_corner_idx[2] : max_corner_idx[2] + 1,
]
if y_pred.ndim == 5
else cc_mask[min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1]
)

pred_masked = crop_pred * cc_crop_mask
label_masked = crop_label * cc_crop_mask

_, distances, _ = get_edge_surface_distance(
pred_masked,
label_masked,
distance_metric=distance_metric,
spacing=spacing_list[b],
symmetric=not directed,
class_index=c,
)
percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances]
component_scores.append(torch.max(torch.stack(percentile_distances)))

hd[b, c] = torch.nanmean(torch.stack(component_scores)) if component_scores else 0.0
else:
Comment thread
coderabbitai[bot] marked this conversation as resolved.
_, distances, _ = get_edge_surface_distance(
y_pred[b, c],
y[b, c],
distance_metric=distance_metric,
spacing=spacing_list[b],
symmetric=not directed,
class_index=c,
)
percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances]
max_distance = torch.max(torch.stack(percentile_distances))
hd[b, c] = max_distance
return hd


Expand Down
53 changes: 53 additions & 0 deletions tests/metrics/test_hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,45 @@ def create_spherical_seg_3d(
for i, (metric, directed) in enumerate(product(["euclidean", "chessboard", "taxicab"], [True, False])):
TEST_CASES_EXPANDED.append((_device, metric, directed, test_input, test_output[i]))

TEST_CASES_CC_METRICS = []
y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1])
y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1])
TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.0], [0.0]]])

y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1])
y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1])
y_hat[0, 1, 5:10, 5:10, 5:10] = 1
y_hat[0, 0] = 1 - y_hat[0, 1]
TEST_CASES_CC_METRICS.append([[y, y_hat], [[float("inf")], [0.0]]])

y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1])
y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1])
y[0, 1, 10:15, 10:15, 10:15] = 1
y[0, 0] = 1 - y[0, 1]
y_hat[0, 1, 10:15, 10:15, 10:15] = 1
y_hat[0, 0] = 1 - y_hat[0, 1]
TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.0], [0.0]]])

y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1])
y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1])
y[0, 1, 10:15, 10:15, 10:15] = 1
y[0, 1, 20:25, 20:25, 20:25] = 1
y[0, 0] = 1 - y[0, 1]
y_hat[0, 1, 11:16, 10:15, 10:15] = 1
y_hat[0, 1, 21:26, 19:24, 20:25] = 1
y_hat[0, 0] = 1 - y_hat[0, 1]
TEST_CASES_CC_METRICS.append([[y, y_hat], [[1.2071], [0.0]]])

y = torch.zeros((2, 2, 32, 32), device=_devices[-1])
y_hat = torch.zeros((2, 2, 32, 32), device=_devices[-1])
y[0, 1, 10:15, 10:15] = 1
y[0, 1, 20:25, 20:25] = 1
y[0, 0] = 1 - y[0, 1]
y_hat[0, 1, 10:15, 10:15] = 1
y_hat[0, 1, 21:26, 19:24] = 1
y_hat[0, 0] = 1 - y_hat[0, 1]
TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.7071], [0.0]]])

Comment thread
VijayVignesh1 marked this conversation as resolved.

def _describe_test_case(test_func, test_number, params):
_device, metric, directed, test_input, test_output = params.args
Expand Down Expand Up @@ -204,6 +243,20 @@ def test_nans(self, input_data):
np.testing.assert_allclose(0, result, rtol=1e-7)
np.testing.assert_allclose(0, not_nans, rtol=1e-7)

@parameterized.expand(TEST_CASES_CC_METRICS)
def test_cc_metrics(self, input_data, expected_value):
[seg_1, seg_2] = input_data
seg_1 = torch.tensor(seg_1)
seg_2 = torch.tensor(seg_2)
hd_metric = HausdorffDistanceMetric(per_component=True)
hd_metric(seg_1, seg_2)
result = hd_metric.aggregate(reduction="none")
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

def test_channel_dimensions(self):
with self.assertRaises(ValueError):
HausdorffDistanceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 144, 144]))


if __name__ == "__main__":
unittest.main()
Loading