diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 1b83c93e5b..82a15761cb 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -17,7 +17,13 @@ import numpy as np import torch -from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing +from monai.metrics.utils import ( + compute_voronoi_regions_fast, + do_metric_reduction, + get_edge_surface_distance, + ignore_background, + prepare_spacing, +) from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -37,6 +43,18 @@ class HausdorffDistanceMetric(CumulativeIterationMetric): Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. + The ``per_component=True`` approach computes the Hausdorff distance on a per-connected component basis in the ground + truth segmentation. This ensures that each component contributes equally to the final metric, regardless of its size. + Traditional Hausdorff distance can be dominated by large structures, but the per-component method gives a more + balanced evaluation, particularly for small or fragmented objects. This provides a granular assessment of segmentation + quality, which is especially important in cases with multiple disconnected foreground components. + Note: + - The input prediction (`y_pred`) and ground truth (`y`) must both have 2 channels (foreground/background), + with binary segmentation (0 for background, 1 for foreground). That is, this assumes the shape of both prediction + and ground truth is B2HW[D]. + - This method cannot be used with multiclass segmentation. + For more information, refer to the original paper: https://arxiv.org/abs/2410.18684 + Args: include_background: whether to include distance computation on the first channel of the predicted output. Defaults to ``False``. 
@@ -51,7 +69,7 @@ class HausdorffDistanceMetric(CumulativeIterationMetric):
             ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
         get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
             Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.
-
+        per_component: whether to compute the Hausdorff distance on a per-connected component basis. Defaults to ``False``.
     """
 
     def __init__(
@@ -62,6 +80,7 @@ def __init__(
         directed: bool = False,
         reduction: MetricReduction | str = MetricReduction.MEAN,
         get_not_nans: bool = False,
+        per_component: bool = False,
     ) -> None:
         super().__init__()
         self.include_background = include_background
@@ -70,6 +89,7 @@ def __init__(
         self.directed = directed
         self.reduction = reduction
         self.get_not_nans = get_not_nans
+        self.per_component = per_component
 
     def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor:  # type: ignore[override]
         """
@@ -96,7 +116,17 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
         dims = y_pred.ndimension()
         if dims < 3:
             raise ValueError("y_pred should have at least three dimensions.")
-
+        if self.per_component:
+            if y_pred.shape != y.shape or y_pred.ndim not in (4, 5) or y_pred.shape[1] != 2 or y.shape[1] != 2:
+                same_rank = y_pred.ndim == y.ndim and y_pred.ndim in (4, 5)
+                binary_channels = y_pred.shape[1] == 2 and y.shape[1] == 2
+                same_shape = y_pred.shape == y.shape
+                if not (same_rank and binary_channels and same_shape):
+                    raise ValueError(
+                        "per_component requires matching 4D/5D binary tensors "
+                        "(B, 2, H, W) or (B, 2, D, H, W). "
+                        f"Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}."
+ ) # compute (BxC) for each channel for each batch return compute_hausdorff_distance( y_pred=y_pred, @@ -106,6 +136,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) percentile=self.percentile, directed=self.directed, spacing=kwargs.get("spacing"), + per_component=self.per_component, ) def aggregate( @@ -137,6 +168,7 @@ def compute_hausdorff_distance( percentile: float | None = None, directed: bool = False, spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, + per_component: bool = False, ) -> torch.Tensor: """ Compute the Hausdorff distance. @@ -162,6 +194,7 @@ def compute_hausdorff_distance( If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. + per_component: whether to compute the Hausdorff distance on a per-connected component basis. Defaults to ``False``. 
    """
 
     if not include_background:
@@ -179,17 +212,82 @@ def compute_hausdorff_distance(
 
     spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)
     for b, c in np.ndindex(batch_size, n_class):
-        _, distances, _ = get_edge_surface_distance(
-            y_pred[b, c],
-            y[b, c],
-            distance_metric=distance_metric,
-            spacing=spacing_list[b],
-            symmetric=not directed,
-            class_index=c,
-        )
-        percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances]
-        max_distance = torch.max(torch.stack(percentile_distances))
-        hd[b, c] = max_distance
+        if per_component:
+            pred_empty = y_pred[b, c].sum() == 0
+            label_empty = y[b, c].sum() == 0
+            if pred_empty and label_empty:
+                hd[b, c] = 0.0
+                continue
+            cc_assignment = compute_voronoi_regions_fast(y[b, c].cpu().numpy())
+            if cc_assignment.device != y_pred[b, c].device:
+                cc_assignment = cc_assignment.to(y_pred[b, c].device)
+            component_scores = []
+            for cc_id in torch.unique(cc_assignment.view(-1)):
+                cc_mask = cc_assignment == cc_id
+
+                coords = torch.nonzero(cc_mask, as_tuple=False)
+                min_corner_idx = coords.min(dim=0).values
+                max_corner_idx = coords.max(dim=0).values
+
+                crop_pred = (
+                    y_pred[b, c][
+                        min_corner_idx[0] : max_corner_idx[0] + 1,
+                        min_corner_idx[1] : max_corner_idx[1] + 1,
+                        min_corner_idx[2] : max_corner_idx[2] + 1,
+                    ]
+                    if y_pred.ndim == 5
+                    else y_pred[b, c][
+                        min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1
+                    ]
+                )
+
+                crop_label = (
+                    y[b, c][
+                        min_corner_idx[0] : max_corner_idx[0] + 1,
+                        min_corner_idx[1] : max_corner_idx[1] + 1,
+                        min_corner_idx[2] : max_corner_idx[2] + 1,
+                    ]
+                    if y.ndim == 5
+                    else y[b, c][min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1]
+                )
+
+                cc_crop_mask = (
+                    cc_mask[
+                        min_corner_idx[0] : max_corner_idx[0] + 1,
+                        min_corner_idx[1] : max_corner_idx[1] + 1,
+                        min_corner_idx[2] : max_corner_idx[2] + 1,
+                    ]
+                    if
y_pred.ndim == 5 + else cc_mask[min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1] + ) + + pred_masked = crop_pred * cc_crop_mask + label_masked = crop_label * cc_crop_mask + + _, distances, _ = get_edge_surface_distance( + pred_masked, + label_masked, + distance_metric=distance_metric, + spacing=spacing_list[b], + symmetric=not directed, + class_index=c, + ) + percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] + component_scores.append(torch.max(torch.stack(percentile_distances))) + + hd[b, c] = torch.nanmean(torch.stack(component_scores)) if component_scores else 0.0 + else: + _, distances, _ = get_edge_surface_distance( + y_pred[b, c], + y[b, c], + distance_metric=distance_metric, + spacing=spacing_list[b], + symmetric=not directed, + class_index=c, + ) + percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] + max_distance = torch.max(torch.stack(percentile_distances)) + hd[b, c] = max_distance return hd diff --git a/tests/metrics/test_hausdorff_distance.py b/tests/metrics/test_hausdorff_distance.py index 20276a1832..4b1b3fa654 100644 --- a/tests/metrics/test_hausdorff_distance.py +++ b/tests/metrics/test_hausdorff_distance.py @@ -161,6 +161,45 @@ def create_spherical_seg_3d( for i, (metric, directed) in enumerate(product(["euclidean", "chessboard", "taxicab"], [True, False])): TEST_CASES_EXPANDED.append((_device, metric, directed, test_input, test_output[i])) +TEST_CASES_CC_METRICS = [] +y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) +TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.0], [0.0]]]) + +y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) +y_hat[0, 1, 5:10, 5:10, 5:10] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] +TEST_CASES_CC_METRICS.append([[y, y_hat], [[float("inf")], [0.0]]]) + +y = torch.zeros((2, 
2, 32, 32, 32), device=_devices[-1]) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) +y[0, 1, 10:15, 10:15, 10:15] = 1 +y[0, 0] = 1 - y[0, 1] +y_hat[0, 1, 10:15, 10:15, 10:15] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] +TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.0], [0.0]]]) + +y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) +y[0, 1, 10:15, 10:15, 10:15] = 1 +y[0, 1, 20:25, 20:25, 20:25] = 1 +y[0, 0] = 1 - y[0, 1] +y_hat[0, 1, 11:16, 10:15, 10:15] = 1 +y_hat[0, 1, 21:26, 19:24, 20:25] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] +TEST_CASES_CC_METRICS.append([[y, y_hat], [[1.2071], [0.0]]]) + +y = torch.zeros((2, 2, 32, 32), device=_devices[-1]) +y_hat = torch.zeros((2, 2, 32, 32), device=_devices[-1]) +y[0, 1, 10:15, 10:15] = 1 +y[0, 1, 20:25, 20:25] = 1 +y[0, 0] = 1 - y[0, 1] +y_hat[0, 1, 10:15, 10:15] = 1 +y_hat[0, 1, 21:26, 19:24] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] +TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.7071], [0.0]]]) + def _describe_test_case(test_func, test_number, params): _device, metric, directed, test_input, test_output = params.args @@ -204,6 +243,20 @@ def test_nans(self, input_data): np.testing.assert_allclose(0, result, rtol=1e-7) np.testing.assert_allclose(0, not_nans, rtol=1e-7) + @parameterized.expand(TEST_CASES_CC_METRICS) + def test_cc_metrics(self, input_data, expected_value): + [seg_1, seg_2] = input_data + seg_1 = torch.tensor(seg_1) + seg_2 = torch.tensor(seg_2) + hd_metric = HausdorffDistanceMetric(per_component=True) + hd_metric(seg_1, seg_2) + result = hd_metric.aggregate(reduction="none") + np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + + def test_channel_dimensions(self): + with self.assertRaises(ValueError): + HausdorffDistanceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 144, 144])) + if __name__ == "__main__": unittest.main()