Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions xrspatial/reproject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,6 @@ def _reproject_chunk_numpy(
for b in range(n_bands):
band_data = window[:, :, b].astype(np.float64)
if not np.isnan(nodata):
band_data = band_data.copy()
band_data[band_data == nodata] = np.nan
band_result = _resample_numpy(band_data, local_row, local_col,
resampling=resampling, nodata=nodata)
Expand All @@ -302,7 +301,6 @@ def _reproject_chunk_numpy(

# Convert sentinel nodata to NaN so numba kernels can detect it
if not np.isnan(nodata):
window = window.copy()
window[window == nodata] = np.nan

result = _resample_numpy(window, local_row, local_col,
Expand Down Expand Up @@ -353,14 +351,18 @@ def _reproject_chunk_cupy(
src_row_px = (src_top - src_y) / src_res_y - 0.5
else:
src_row_px = (src_y - src_bottom) / src_res_y - 0.5
# Need min/max on CPU for window selection
r_min_val = float(cp.nanmin(src_row_px).get())
if not np.isfinite(r_min_val):
return cp.full(chunk_shape, nodata, dtype=cp.float64)
r_max_val = float(cp.nanmax(src_row_px).get())
c_min_val = float(cp.nanmin(src_col_px).get())
c_max_val = float(cp.nanmax(src_col_px).get())
if not np.isfinite(r_max_val) or not np.isfinite(c_min_val) or not np.isfinite(c_max_val):
# Need min/max on CPU for window selection.
# Stack the four reductions and pull them across in one device-to-host
# transfer to avoid four separate synchronous syncs.
mins_maxes = cp.stack([
cp.nanmin(src_row_px), cp.nanmax(src_row_px),
cp.nanmin(src_col_px), cp.nanmax(src_col_px),
])
r_min_val, r_max_val, c_min_val, c_max_val = (
float(v) for v in mins_maxes.get()
)
if not (np.isfinite(r_min_val) and np.isfinite(r_max_val)
and np.isfinite(c_min_val) and np.isfinite(c_max_val)):
return cp.full(chunk_shape, nodata, dtype=cp.float64)
r_min = int(np.floor(r_min_val)) - 2
r_max = int(np.ceil(r_max_val)) + 3
Expand Down Expand Up @@ -440,7 +442,6 @@ def _reproject_chunk_cupy(
else:
# CPU coordinates -- convert sentinel nodata to NaN before map_coordinates
if not np.isnan(nodata):
window = window.copy()
window[window == nodata] = cp.nan

result = _resample_cupy(window, local_row, local_col,
Expand Down Expand Up @@ -1119,14 +1120,17 @@ def _reproject_dask_cupy(
else:
src_row_px = (src_y - src_bottom) / src_res_y - 0.5

r_min_val = float(cp.nanmin(src_row_px).get())
if not np.isfinite(r_min_val):
col_offset += cchunk
continue
r_max_val = float(cp.nanmax(src_row_px).get())
c_min_val = float(cp.nanmin(src_col_px).get())
c_max_val = float(cp.nanmax(src_col_px).get())
if not np.isfinite(r_max_val) or not np.isfinite(c_min_val) or not np.isfinite(c_max_val):
# Batch the four reductions into a single device-to-host
# transfer instead of four separate synchronous .get() calls.
mins_maxes = cp.stack([
cp.nanmin(src_row_px), cp.nanmax(src_row_px),
cp.nanmin(src_col_px), cp.nanmax(src_col_px),
])
r_min_val, r_max_val, c_min_val, c_max_val = (
float(v) for v in mins_maxes.get()
)
if not (np.isfinite(r_min_val) and np.isfinite(r_max_val)
and np.isfinite(c_min_val) and np.isfinite(c_max_val)):
col_offset += cchunk
continue
r_min = int(np.floor(r_min_val)) - 2
Expand Down Expand Up @@ -1190,7 +1194,6 @@ def _reproject_dask_cupy(
window = window.astype(cp.float64)

if not np.isnan(nodata):
window = window.copy()
window[window == nodata] = cp.nan

local_row = src_row_px - r_min_clip
Expand Down
31 changes: 31 additions & 0 deletions xrspatial/tests/test_reproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -2377,6 +2377,37 @@ def test_dask_cupy_reproject_matches_numpy(self):
np_result[finite], dc_vals[finite], rtol=1e-5, atol=1e-5,
)

def test_cupy_reproject_with_nan_chunks(self):
"""Regression: target chunks projecting outside the source must
return all-nodata, exercising the batched min/max early-return."""
from xrspatial.reproject import reproject
data = np.random.RandomState(3).rand(16, 16).astype(np.float64)
# Source covers a small region near the prime meridian / equator.
coords = {'y': np.linspace(2, -2, 16), 'x': np.linspace(-2, 2, 16)}
attrs = {'crs': 'EPSG:4326', 'nodata': np.nan}
cp_raster = xr.DataArray(cp.asarray(data), dims=['y', 'x'],
coords=coords, attrs=attrs)

# Reproject to a target far outside the source. Coordinates that fall
# outside the source produce NaN row/col pixels, so the batched
# nanmin/nanmax should be NaN and trigger the all-nodata early return.
target_bounds = (5_000_000, 5_000_000, 5_100_000, 5_100_000)
out = reproject(cp_raster, 'EPSG:3857', bounds=target_bounds,
width=8, height=8)
out_vals = out.data.get() if hasattr(out.data, 'get') else np.asarray(out.data)
assert out_vals.shape == (8, 8)
# Out-of-bounds output: all entries must be nodata (NaN here).
assert np.all(np.isnan(out_vals))

# Same target as the source exercises the in-bounds branch and must
# return finite values from the same batched-reduction code path.
in_bounds = reproject(cp_raster, 'EPSG:4326',
bounds=(-1.5, -1.5, 1.5, 1.5),
width=8, height=8)
in_vals = (in_bounds.data.get() if hasattr(in_bounds.data, 'get')
else np.asarray(in_bounds.data))
assert np.isfinite(in_vals).any()


class TestCoordsPreservation:
"""Non-spatial coords pass through reproject() and merge()."""
Expand Down
Loading