diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index 8519dcf1..cb547dc1 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -287,7 +287,6 @@ def _reproject_chunk_numpy( for b in range(n_bands): band_data = window[:, :, b].astype(np.float64) if not np.isnan(nodata): - band_data = band_data.copy() band_data[band_data == nodata] = np.nan band_result = _resample_numpy(band_data, local_row, local_col, resampling=resampling, nodata=nodata) @@ -302,7 +301,6 @@ def _reproject_chunk_numpy( # Convert sentinel nodata to NaN so numba kernels can detect it if not np.isnan(nodata): - window = window.copy() window[window == nodata] = np.nan result = _resample_numpy(window, local_row, local_col, @@ -353,14 +351,18 @@ def _reproject_chunk_cupy( src_row_px = (src_top - src_y) / src_res_y - 0.5 else: src_row_px = (src_y - src_bottom) / src_res_y - 0.5 - # Need min/max on CPU for window selection - r_min_val = float(cp.nanmin(src_row_px).get()) - if not np.isfinite(r_min_val): - return cp.full(chunk_shape, nodata, dtype=cp.float64) - r_max_val = float(cp.nanmax(src_row_px).get()) - c_min_val = float(cp.nanmin(src_col_px).get()) - c_max_val = float(cp.nanmax(src_col_px).get()) - if not np.isfinite(r_max_val) or not np.isfinite(c_min_val) or not np.isfinite(c_max_val): + # Need min/max on CPU for window selection. + # Stack the four reductions and pull them across in one device-to-host + # transfer to avoid four separate synchronous syncs. + mins_maxes = cp.stack([ + cp.nanmin(src_row_px), cp.nanmax(src_row_px), + cp.nanmin(src_col_px), cp.nanmax(src_col_px), + ]) + r_min_val, r_max_val, c_min_val, c_max_val = ( + float(v) for v in mins_maxes.get() + ) + if not (np.isfinite(r_min_val) and np.isfinite(r_max_val) + and np.isfinite(c_min_val) and np.isfinite(c_max_val)): return cp.full(chunk_shape, nodata, dtype=cp.float64) r_min = int(np.floor(r_min_val)) - 2 r_max = int(np.ceil(r_max_val)) + 3 @@ -440,7 +442,6 @@ def _reproject_chunk_cupy( else: # CPU coordinates -- convert sentinel nodata to NaN before map_coordinates if not np.isnan(nodata): - window = window.copy() window[window == nodata] = cp.nan result = _resample_cupy(window, local_row, local_col, @@ -1119,14 +1120,17 @@ def _reproject_dask_cupy( else: src_row_px = (src_y - src_bottom) / src_res_y - 0.5 - r_min_val = float(cp.nanmin(src_row_px).get()) - if not np.isfinite(r_min_val): - col_offset += cchunk - continue - r_max_val = float(cp.nanmax(src_row_px).get()) - c_min_val = float(cp.nanmin(src_col_px).get()) - c_max_val = float(cp.nanmax(src_col_px).get()) - if not np.isfinite(r_max_val) or not np.isfinite(c_min_val) or not np.isfinite(c_max_val): + # Batch the four reductions into a single device-to-host + # transfer instead of four separate synchronous .get() calls. + mins_maxes = cp.stack([ + cp.nanmin(src_row_px), cp.nanmax(src_row_px), + cp.nanmin(src_col_px), cp.nanmax(src_col_px), + ]) + r_min_val, r_max_val, c_min_val, c_max_val = ( + float(v) for v in mins_maxes.get() + ) + if not (np.isfinite(r_min_val) and np.isfinite(r_max_val) + and np.isfinite(c_min_val) and np.isfinite(c_max_val)): col_offset += cchunk continue r_min = int(np.floor(r_min_val)) - 2 @@ -1190,7 +1194,6 @@ def _reproject_dask_cupy( window = window.astype(cp.float64) if not np.isnan(nodata): - window = window.copy() window[window == nodata] = cp.nan local_row = src_row_px - r_min_clip diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index d7e4b504..21de5828 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -2377,6 +2377,37 @@ def test_dask_cupy_reproject_matches_numpy(self): np_result[finite], dc_vals[finite], rtol=1e-5, atol=1e-5, ) + def test_cupy_reproject_with_nan_chunks(self): + """Regression: target chunks projecting outside the source must + return all-nodata, exercising the batched min/max early-return.""" + from xrspatial.reproject import reproject + data = np.random.RandomState(3).rand(16, 16).astype(np.float64) + # Source covers a small region near the prime meridian / equator. + coords = {'y': np.linspace(2, -2, 16), 'x': np.linspace(-2, 2, 16)} + attrs = {'crs': 'EPSG:4326', 'nodata': np.nan} + cp_raster = xr.DataArray(cp.asarray(data), dims=['y', 'x'], + coords=coords, attrs=attrs) + + # Reproject to a target far outside the source. Coordinates that fall + # outside the source produce NaN row/col pixels, so the batched + # nanmin/nanmax should be NaN and trigger the all-nodata early return. + target_bounds = (5_000_000, 5_000_000, 5_100_000, 5_100_000) + out = reproject(cp_raster, 'EPSG:3857', bounds=target_bounds, + width=8, height=8) + out_vals = out.data.get() if hasattr(out.data, 'get') else np.asarray(out.data) + assert out_vals.shape == (8, 8) + # Out-of-bounds output: all entries must be nodata (NaN here). + assert np.all(np.isnan(out_vals)) + + # Same target as the source exercises the in-bounds branch and must + # return finite values from the same batched-reduction code path. + in_bounds = reproject(cp_raster, 'EPSG:4326', + bounds=(-1.5, -1.5, 1.5, 1.5), + width=8, height=8) + in_vals = (in_bounds.data.get() if hasattr(in_bounds.data, 'get') + else np.asarray(in_bounds.data)) + assert np.isfinite(in_vals).any() + class TestCoordsPreservation: """Non-spatial coords pass through reproject() and merge()."""