diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index 8519dcf1..828e47eb 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -687,15 +687,25 @@ def reproject( ydim, xdim = _find_spatial_dims(raster) # Carry input attrs forward so units, long_name, scale_factor, etc. - # survive the transform. Pop attrs that are stale after reprojection: - # the affine `transform` and grid `res` describe the old grid, and - # `crs_wkt` would duplicate (or contradict) the canonical `crs` we re-emit. + # survive the transform. Re-emit `transform` and `res` for the new + # output grid (rasterio 6-tuple convention). Drop `crs_wkt` since it + # would duplicate (or contradict) the canonical `crs` we re-emit. + out_left, _out_bottom, _out_right, out_top = grid['bounds'] + out_res_x = grid['res_x'] + out_res_y = grid['res_y'] out_attrs = {**raster.attrs} - out_attrs.pop('transform', None) out_attrs.pop('crs_wkt', None) - out_attrs.pop('res', None) out_attrs['crs'] = tgt_wkt out_attrs['nodata'] = nd + out_attrs['res'] = (out_res_x, out_res_y) + out_attrs['transform'] = ( + out_res_x, 0.0, out_left, 0.0, -out_res_y, out_top, + ) + # If the input used `_FillValue`, propagate the resolved nodata to + # both keys so downstream consumers reading either find a consistent + # value. If `_FillValue` was absent, leave it absent. + if '_FillValue' in raster.attrs: + out_attrs['_FillValue'] = nd if tgt_vertical_crs is not None: out_attrs['vertical_crs'] = tgt_vertical_crs @@ -1479,7 +1489,8 @@ def merge( "Ensure all rasters have CRS metadata." ) sb = _source_bounds(r) - ss = (r.sizes[r.dims[-2]], r.sizes[r.dims[-1]]) + r_ydim, r_xdim = _find_spatial_dims(r) + ss = (r.sizes[r_ydim], r.sizes[r_xdim]) yd = _is_y_descending(r) # Per-raster input nodata sentinel. Detected independently of the # user-supplied output nodata so that mixed-sentinel inputs are @@ -1552,8 +1563,7 @@ def merge( ) y_coords, x_coords = _make_output_coords(out_bounds, out_shape) - ydim = rasters[0].dims[-2] - xdim = rasters[0].dims[-1] + ydim, xdim = _find_spatial_dims(rasters[0]) out_coords = {ydim: y_coords, xdim: x_coords} # Carry forward non-spatial coords from the first raster (e.g. scalar @@ -1567,14 +1577,24 @@ def merge( out_coords[cname] = cval # Carry the first raster's attrs forward (matches the default - # strategy='first'). Drop attrs describing the old grid: `transform`, - # `res`, and the duplicate `crs_wkt` are no longer accurate. + # strategy='first'). Drop the duplicate `crs_wkt` and re-emit + # `transform` and `res` for the new output grid (rasterio 6-tuple). + out_left, _out_bottom, _out_right, out_top = grid['bounds'] + out_res_x = grid['res_x'] + out_res_y = grid['res_y'] out_attrs = {**rasters[0].attrs} - out_attrs.pop('transform', None) out_attrs.pop('crs_wkt', None) - out_attrs.pop('res', None) out_attrs['crs'] = tgt_wkt out_attrs['nodata'] = nd + out_attrs['res'] = (out_res_x, out_res_y) + out_attrs['transform'] = ( + out_res_x, 0.0, out_left, 0.0, -out_res_y, out_top, + ) + # If the first raster used `_FillValue`, propagate the resolved + # nodata to both keys for consistent round-trip. Leave absent + # otherwise. + if '_FillValue' in rasters[0].attrs: + out_attrs['_FillValue'] = nd result = xr.DataArray( result_data, diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index d7e4b504..5c950845 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -2121,19 +2121,20 @@ def test_reproject_preserves_long_name(self): result = reproject(raster, 'EPSG:4326', resolution=0.25) assert result.attrs.get('long_name') == 'elevation' - def test_reproject_drops_stale_transform(self): + def test_reproject_replaces_stale_transform(self): from xrspatial.reproject import reproject - raster = self._raster_with_attrs( - {'transform': (1.0, 0.0, 0.0, 0.0, -1.0, 0.0)} - ) + stale = (1.0, 0.0, 0.0, 0.0, -1.0, 0.0) + raster = self._raster_with_attrs({'transform': stale}) result = reproject(raster, 'EPSG:3857') - assert 'transform' not in result.attrs + assert 'transform' in result.attrs + assert tuple(result.attrs['transform']) != stale - def test_reproject_drops_stale_res(self): + def test_reproject_replaces_stale_res(self): from xrspatial.reproject import reproject raster = self._raster_with_attrs({'res': (1.0, 1.0)}) result = reproject(raster, 'EPSG:3857') - assert 'res' not in result.attrs + assert 'res' in result.attrs + assert tuple(result.attrs['res']) != (1.0, 1.0) def test_reproject_overrides_crs(self): from xrspatial.reproject import reproject @@ -2168,17 +2169,171 @@ def test_merge_preserves_first_raster_attrs(self): assert result.attrs.get('units') == 'm' assert result.attrs.get('long_name') == 'elev' - def test_merge_drops_stale_transform(self): + def test_merge_replaces_stale_transform(self): from xrspatial.reproject import merge + stale = (1.0, 0.0, 0.0, 0.0, -1.0, 0.0) a = self._raster_with_attrs( - {'transform': (1.0, 0.0, 0.0, 0.0, -1.0, 0.0)}, + {'transform': stale}, x_range=(-5, 0), y_range=(-5, 5), ) b = self._raster_with_attrs( x_range=(0, 5), y_range=(-5, 5), ) result = merge([a, b], resolution=1.0) - assert 'transform' not in result.attrs + assert 'transform' in result.attrs + assert tuple(result.attrs['transform']) != stale + + # Fresh transform/res emission ---------------------------------------- + + def test_reproject_emits_fresh_transform(self): + from xrspatial.reproject import reproject + stale = (1.0, 0.0, 0.0, 0.0, -1.0, 10.0) + raster = self._raster_with_attrs({'transform': stale}) + result = reproject(raster, 'EPSG:3857', resolution=50000.0) + t = result.attrs['transform'] + assert len(t) == 6 + # Output is in EPSG:3857 so transform values cannot match the stale + # geographic-degree input. + assert tuple(t) != stale + # transform[0] is res_x, transform[4] is -res_y, transform[2] is + # left edge, transform[5] is top edge. + res_x, res_y = result.attrs['res'] + assert t[0] == res_x + assert t[4] == -res_y + # Top edge: y coord of first row plus half a pixel. + y0 = float(result.coords['y'].values[0]) + assert t[5] == pytest.approx(y0 + res_y / 2) + # Left edge: x coord of first col minus half a pixel. + x0 = float(result.coords['x'].values[0]) + assert t[2] == pytest.approx(x0 - res_x / 2) + + def test_reproject_emits_fresh_res(self): + from xrspatial.reproject import reproject + raster = self._raster_with_attrs({'res': (1.0, 1.0)}) + # Use an explicit resolution very different from input. + result = reproject(raster, 'EPSG:4326', resolution=0.25) + assert 'res' in result.attrs + res_x, res_y = result.attrs['res'] + # Pixel size derived from output coords must match. + x = result.coords['x'].values + y = result.coords['y'].values + actual_res_x = float(abs(x[1] - x[0])) + actual_res_y = float(abs(y[1] - y[0])) + assert res_x == pytest.approx(actual_res_x) + assert res_y == pytest.approx(actual_res_y) + + def test_reproject_no_input_transform_still_emits_one(self): + from xrspatial.reproject import reproject + raster = self._raster_with_attrs() + assert 'transform' not in raster.attrs + result = reproject(raster, 'EPSG:4326', resolution=0.25) + assert 'transform' in result.attrs + assert 'res' in result.attrs + assert len(result.attrs['transform']) == 6 + + def test_merge_emits_fresh_transform_and_res(self): + from xrspatial.reproject import merge + a = self._raster_with_attrs( + x_range=(-5, 0), y_range=(-5, 5), + ) + b = self._raster_with_attrs( + x_range=(0, 5), y_range=(-5, 5), + ) + result = merge([a, b], resolution=1.0) + assert 'transform' in result.attrs + assert 'res' in result.attrs + t = result.attrs['transform'] + assert len(t) == 6 + res_x, res_y = result.attrs['res'] + assert t[0] == res_x + assert t[4] == -res_y + y0 = float(result.coords['y'].values[0]) + x0 = float(result.coords['x'].values[0]) + assert t[5] == pytest.approx(y0 + res_y / 2) + assert t[2] == pytest.approx(x0 - res_x / 2) + + def test_merge_finds_spatial_dims_with_lat_lon(self): + from xrspatial.reproject import merge + a_data = np.ones((8, 8), dtype=np.float64) + b_data = np.ones((8, 8), dtype=np.float64) * 2 + attrs = {'crs': 'EPSG:4326', 'nodata': np.nan} + a = xr.DataArray( + a_data, dims=['lat', 'lon'], + coords={ + 'lat': np.linspace(5, -5, 8), + 'lon': np.linspace(-5, 0, 8), + }, + name='a', attrs=attrs, + ) + b = xr.DataArray( + b_data, dims=['lat', 'lon'], + coords={ + 'lat': np.linspace(5, -5, 8), + 'lon': np.linspace(0, 5, 8), + }, + name='b', attrs=attrs, + ) + result = merge([a, b], resolution=1.0) + assert result.dims == ('lat', 'lon') + assert 'lat' in result.coords + assert 'lon' in result.coords + + # _FillValue propagation ----------------------------------------------- + + def test_reproject_propagates_fill_value(self): + from xrspatial.reproject import reproject + # Build a raster with _FillValue set and no nodata key. + data = np.ones((8, 8), dtype=np.float64) + attrs = {'crs': 'EPSG:4326', '_FillValue': -9999} + raster = xr.DataArray( + data, dims=['y', 'x'], + coords={ + 'y': np.linspace(1, -1, 8), + 'x': np.linspace(-1, 1, 8), + }, + attrs=attrs, + ) + result = reproject(raster, 'EPSG:4326', resolution=0.25) + assert '_FillValue' in result.attrs + assert 'nodata' in result.attrs + assert result.attrs['_FillValue'] == result.attrs['nodata'] + assert result.attrs['_FillValue'] == -9999 + + def test_reproject_omits_fill_value_when_input_omits(self): + from xrspatial.reproject import reproject + raster = self._raster_with_attrs() + assert '_FillValue' not in raster.attrs + result = reproject(raster, 'EPSG:4326', resolution=0.25) + assert '_FillValue' not in result.attrs + assert 'nodata' in result.attrs + + def test_merge_propagates_fill_value(self): + from xrspatial.reproject import merge + a_data = np.ones((8, 8), dtype=np.float64) + b_data = np.ones((8, 8), dtype=np.float64) * 2 + attrs_a = {'crs': 'EPSG:4326', '_FillValue': -9999} + attrs_b = {'crs': 'EPSG:4326', '_FillValue': -9999} + a = xr.DataArray( + a_data, dims=['y', 'x'], + coords={ + 'y': np.linspace(5, -5, 8), + 'x': np.linspace(-5, 0, 8), + }, + name='a', attrs=attrs_a, + ) + b = xr.DataArray( + b_data, dims=['y', 'x'], + coords={ + 'y': np.linspace(5, -5, 8), + 'x': np.linspace(0, 5, 8), + }, + name='b', attrs=attrs_b, + ) + result = merge([a, b], resolution=1.0) + assert '_FillValue' in result.attrs + assert 'nodata' in result.attrs + assert result.attrs['_FillValue'] == result.attrs['nodata'] + assert result.attrs['_FillValue'] == -9999 def test_merge_name_falls_back_to_first_raster(self): from xrspatial.reproject import merge