diff --git a/xrspatial/resample.py b/xrspatial/resample.py index b4740d0c..0e4d0cd7 100644 --- a/xrspatial/resample.py +++ b/xrspatial/resample.py @@ -945,12 +945,53 @@ def _run_dask_cupy(data, scale_y, scale_x, method): # -- Public API -------------------------------------------------------------- +def _resolve_nodata(agg, nodata): + """Resolve the input-side nodata sentinel. + + Explicit *nodata* wins. Otherwise fall back to ``_FillValue`` then + ``nodata`` in ``agg.attrs``. Returns ``None`` when no sentinel was + found (the caller skips the masking step). + + NaN sentinels are returned as NaN so the caller can branch on + ``np.isnan`` rather than ``==`` (which never matches NaN). + """ + if nodata is None: + for key in ('_FillValue', 'nodata'): + v = agg.attrs.get(key) + if v is not None: + nodata = v + break + if nodata is None: + return None + nd = float(nodata) + if np.isinf(nd): + raise ValueError(f"nodata must be finite or NaN, got {nodata!r}") + return nd + + +def _apply_nodata_mask(agg, nodata): + """Return a float copy of *agg* with sentinel pixels replaced by NaN. + + Works for numpy, cupy, dask+numpy, and dask+cupy backings via + xarray's ``.where`` (which dispatches per backend). + """ + if nodata is None: + return agg + # Promote to float so NaN can be stored. xr.where keeps the backend. + if not np.issubdtype(agg.dtype, np.floating): + agg = agg.astype(np.float64) + if np.isnan(nodata): + return agg # already-NaN sentinels need no replacement + return agg.where(agg != nodata) + + @supports_dataset def resample( agg, scale_factor=None, target_resolution=None, method='nearest', + nodata=None, name='resample', ): """Change raster resolution without changing its CRS. @@ -960,21 +1001,30 @@ def resample( Parameters ---------- agg : xarray.DataArray - Input raster (2-D). + Input raster. 2-D ``(y, x)`` or 3-D ``(band, y, x)``. For 3-D + inputs each band is resampled independently and the leading + non-spatial coordinate is preserved. scale_factor : float or (float, float), optional Multiplicative factor applied to the number of pixels. ``0.5`` halves the pixel count (doubles the cell size); ``2.0`` doubles the pixel count (halves the cell size). A two-element tuple sets ``(scale_y, scale_x)`` independently. - target_resolution : float, optional + target_resolution : float or (float, float), optional Desired cell size in the same units as the raster coordinates. - Both axes are set to this resolution. + A scalar sets both axes to the same resolution; a 2-tuple sets + ``(res_y, res_x)`` independently. method : str, default ``'nearest'`` Resampling algorithm. Interpolation methods (``'nearest'``, ``'bilinear'``, ``'cubic'``) work for both upsampling and downsampling. Aggregation methods (``'average'``, ``'min'``, ``'max'``, ``'median'``, ``'mode'``) only support downsampling (scale_factor <= 1). + nodata : float, optional + Sentinel value in the input that should be treated as missing. + Input pixels equal to *nodata* are replaced with NaN before + resampling. When ``None``, falls back to ``agg.attrs['_FillValue']`` + then ``agg.attrs['nodata']``. The output uses NaN as the sentinel + regardless of the input convention. name : str, default ``'resample'`` Name for the output DataArray. @@ -984,7 +1034,7 @@ def resample( Resampled raster with updated coordinates, ``res`` attribute, and float32 dtype. """ - _validate_raster(agg, func_name='resample', name='agg') + _validate_raster(agg, func_name='resample', name='agg', ndim=(2, 3)) if method not in ALL_METHODS: raise ValueError( @@ -1025,12 +1075,56 @@ def resample( f"(scale_factor <= 1.0)" ) + # -- nodata: replace sentinels with NaN before resampling ---------------- + nd_resolved = _resolve_nodata(agg, nodata) + has_nodata = nd_resolved is not None + if has_nodata: + agg = _apply_nodata_mask(agg, nd_resolved) + # -- fast path: identity ------------------------------------------------- if scale_y == 1.0 and scale_x == 1.0: out = agg.copy() out.name = name + # When nodata was applied, advertise NaN as the new sentinel. + if has_nodata: + out.attrs['_FillValue'] = float('nan') return out + # -- 3D: dispatch per band ---------------------------------------------- + if agg.ndim == 3: + leading_dim = agg.dims[0] + bands = [] + for i in range(agg.sizes[leading_dim]): + band_2d = agg.isel({leading_dim: i}) + band_out = resample( + band_2d, + scale_factor=scale_factor, + target_resolution=target_resolution, + method=method, + # Pass NaN so the recursive call short-circuits masking + # (we already applied the mask on the 3D input above) and + # ignores the original attrs sentinel. + nodata=float('nan'), + name=name, + ) + bands.append(band_out) + # Stack along the leading dim. concat preserves the per-band + # coordinate when each input has it. + result = xr.concat(bands, dim=leading_dim) + # concat may reorder dims; transpose to the original layout. + result = result.transpose(*agg.dims) + result.name = name + # Carry across input attrs (concat picks the first; merge with input). + new_attrs = dict(agg.attrs) + new_attrs.update(bands[0].attrs) # res from per-band resample + if has_nodata: + new_attrs['_FillValue'] = float('nan') + result.attrs = new_attrs + # Preserve the leading-dim coordinate if it was on the input. + if leading_dim in agg.coords: + result = result.assign_coords({leading_dim: agg.coords[leading_dim]}) + return result + # -- memory guard for eager backends ------------------------------------ # Dask paths build per-chunk allocations lazily (chunk size already # bounds peak memory). The eager numpy and cupy paths allocate the @@ -1077,6 +1171,8 @@ def _new_coords(vals, n_out): new_attrs = dict(agg.attrs) new_attrs['res'] = (abs(px), abs(py)) + if has_nodata: + new_attrs['_FillValue'] = float('nan') # Refresh `transform` if the input had one. The rasterio 6-tuple is # (res_x, 0.0, left, 0.0, -res_y, top). `top` is the upper edge of diff --git a/xrspatial/tests/test_resample.py b/xrspatial/tests/test_resample.py index 92f6f3e0..eccbe835 100644 --- a/xrspatial/tests/test_resample.py +++ b/xrspatial/tests/test_resample.py @@ -728,3 +728,181 @@ def test_dask_aggregate_smoke_200x200(self): assert elapsed < 30.0, ( f"dask aggregate took {elapsed:.2f}s; expected well under 5s" ) + + +# --------------------------------------------------------------------------- +# 3D rasters (issue #1466) +# --------------------------------------------------------------------------- + +class TestThreeDRasters: + """Multi-band ``(band, y, x)`` rasters resample per-band.""" + + def _make_3d(self, backend='numpy'): + # 3 bands of an 8x8 gradient. Each band has a unique offset so we + # can confirm bands aren't mixed during the dispatch. + y, x = np.mgrid[0:8, 0:8] + band0 = (y * 10 + x).astype(np.float32) + band1 = band0 + 100 + band2 = band0 + 200 + data = np.stack([band0, band1, band2], axis=0) + agg = xr.DataArray( + data, + dims=('band', 'y', 'x'), + coords={ + 'band': np.array([1, 2, 3]), + 'y': np.arange(8, dtype=np.float64), + 'x': np.arange(8, dtype=np.float64), + }, + attrs={'res': (1.0, 1.0)}, + name='myraster', + ) + if backend == 'dask': + import dask.array as da + agg = agg.copy() + agg.data = da.from_array(agg.data, chunks=(1, 4, 4)) + elif backend == 'cupy': + import cupy + agg = agg.copy() + agg.data = cupy.asarray(agg.data) + return agg + + def test_3d_numpy_shape_and_band_coord(self): + agg = self._make_3d('numpy') + out = resample(agg, scale_factor=0.5, method='nearest') + assert out.shape == (3, 4, 4) + assert out.dims == ('band', 'y', 'x') + np.testing.assert_array_equal(out['band'].values, [1, 2, 3]) + + def test_3d_per_band_independence(self): + """Each band's output should be the 2D resample of that band.""" + agg = self._make_3d('numpy') + out = resample(agg, scale_factor=0.5, method='average') + for i in range(3): + band_2d = agg.isel(band=i).reset_coords(drop=True) + ref = resample(band_2d, scale_factor=0.5, method='average') + np.testing.assert_allclose(out.isel(band=i).values, ref.values, + atol=1e-5) + + def test_3d_target_resolution_tuple(self): + agg = self._make_3d('numpy') + out = resample(agg, target_resolution=(2.0, 4.0)) + assert out.shape == (3, 4, 2) + + @dask_array_available + def test_3d_dask(self): + agg = self._make_3d('dask') + out = resample(agg, scale_factor=0.5, method='nearest') + assert out.shape == (3, 4, 4) + np.testing.assert_array_equal(out['band'].values, [1, 2, 3]) + + @cuda_and_cupy_available + def test_3d_cupy(self): + agg = self._make_3d('cupy') + out = resample(agg, scale_factor=0.5, method='nearest') + assert out.shape == (3, 4, 4) + np.testing.assert_array_equal(out['band'].values, [1, 2, 3]) + + +# --------------------------------------------------------------------------- +# nodata handling (issue #1466) +# --------------------------------------------------------------------------- + +class TestNodata: + def test_explicit_nodata_int_sentinel(self): + # Integer raster with -9999 sentinel. After resample those pixels + # should become NaN; valid pixels stay as float interpolations. + data = np.array([ + [-9999, -9999, 10, 10], + [-9999, -9999, 10, 10], + [20, 20, 30, 30], + [20, 20, 30, 30], + ], dtype=np.int32) + agg = create_test_raster(data, attrs={'res': (1.0, 1.0)}) + out = resample(agg, scale_factor=0.5, method='nearest', + nodata=-9999) + assert out.shape == (2, 2) + # Top-left output pixel maps to the -9999 region -> NaN + assert np.isnan(out.values[0, 0]) + # Bottom-right pixel maps to a valid region -> finite + assert np.isfinite(out.values[1, 1]) + assert out.attrs.get('_FillValue') is not None + assert np.isnan(out.attrs['_FillValue']) + + def test_nodata_from_fillvalue_attr(self): + # Same data, but sentinel discovered via _FillValue attr. + data = np.array([ + [-9999, -9999, 10, 10], + [-9999, -9999, 10, 10], + [20, 20, 30, 30], + [20, 20, 30, 30], + ], dtype=np.int32) + agg = create_test_raster( + data, attrs={'res': (1.0, 1.0), '_FillValue': -9999} + ) + out = resample(agg, scale_factor=0.5, method='nearest') + assert np.isnan(out.values[0, 0]) + assert np.isfinite(out.values[1, 1]) + + def test_nodata_from_nodata_attr(self): + data = np.array([ + [-9999, -9999, 10, 10], + [-9999, -9999, 10, 10], + [20, 20, 30, 30], + [20, 20, 30, 30], + ], dtype=np.int32) + agg = create_test_raster( + data, attrs={'res': (1.0, 1.0), 'nodata': -9999} + ) + out = resample(agg, scale_factor=0.5, method='average') + assert np.isnan(out.values[0, 0]) + + def test_nodata_none_no_attrs_unchanged(self): + # Without an explicit param or attr, behavior matches the old + # (pre-#1466) implementation -- no masking, no _FillValue added. + data = np.arange(16, dtype=np.float32).reshape(4, 4) + agg = create_test_raster(data, attrs={'res': (1.0, 1.0)}) + out = resample(agg, scale_factor=0.5, method='nearest') + assert '_FillValue' not in out.attrs + + def test_nodata_float_explicit(self): + # Float sentinel -- e.g. -1.0 marking masked pixels. + data = np.array([[-1.0, -1.0, 5.0, 5.0], + [-1.0, -1.0, 5.0, 5.0], + [3.0, 3.0, 7.0, 7.0], + [3.0, 3.0, 7.0, 7.0]], dtype=np.float32) + agg = create_test_raster(data, attrs={'res': (1.0, 1.0)}) + out = resample(agg, scale_factor=0.5, method='nearest', nodata=-1.0) + assert np.isnan(out.values[0, 0]) + + def test_explicit_nodata_overrides_attr(self): + # Explicit param wins over _FillValue attr. + # 4x4 with -1 in the top-left 2x2 block. _FillValue says -999 + # (which doesn't appear); explicit nodata=-1 should mask the corner. + data = np.array([[-1.0, -1.0, 5.0, 5.0], + [-1.0, -1.0, 5.0, 5.0], + [3.0, 3.0, 7.0, 7.0], + [3.0, 3.0, 7.0, 7.0]], dtype=np.float32) + agg = create_test_raster( + data, attrs={'res': (1.0, 1.0), '_FillValue': -999.0} + ) + out = resample(agg, scale_factor=0.5, method='nearest', nodata=-1.0) + # Without override the attr would say -999 (no match) and -1 would + # leak through; with override the top-left output pixel is NaN. + assert np.isnan(out.values[0, 0]) + + +# --------------------------------------------------------------------------- +# target_resolution as 2-tuple (issue #1466) +# --------------------------------------------------------------------------- + +class TestTargetResolutionTuple: + def test_tuple_resolution_independent_axes(self, grid_8x8): + # 8x8 grid with res=(1, 1) -> target (2, 4) -> output (4, 2). + out = resample(grid_8x8, target_resolution=(2.0, 4.0)) + assert out.shape == (4, 2) + + def test_tuple_resolution_matches_scale_factor(self, grid_8x8): + # target_resolution=(2.0, 2.0) should match scale_factor=0.5. + a = resample(grid_8x8, target_resolution=(2.0, 2.0), method='nearest') + b = resample(grid_8x8, scale_factor=0.5, method='nearest') + np.testing.assert_allclose(a.values, b.values)