Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions xrspatial/reproject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,15 +687,25 @@ def reproject(

ydim, xdim = _find_spatial_dims(raster)
# Carry input attrs forward so units, long_name, scale_factor, etc.
# survive the transform. Pop attrs that are stale after reprojection:
# the affine `transform` and grid `res` describe the old grid, and
# `crs_wkt` would duplicate (or contradict) the canonical `crs` we re-emit.
# survive the transform. Re-emit `transform` and `res` for the new
# output grid (rasterio 6-tuple convention). Drop `crs_wkt` since it
# would duplicate (or contradict) the canonical `crs` we re-emit.
out_left, _out_bottom, _out_right, out_top = grid['bounds']
out_res_x = grid['res_x']
out_res_y = grid['res_y']
out_attrs = {**raster.attrs}
out_attrs.pop('transform', None)
out_attrs.pop('crs_wkt', None)
out_attrs.pop('res', None)
out_attrs['crs'] = tgt_wkt
out_attrs['nodata'] = nd
out_attrs['res'] = (out_res_x, out_res_y)
out_attrs['transform'] = (
out_res_x, 0.0, out_left, 0.0, -out_res_y, out_top,
)
# If the input used `_FillValue`, propagate the resolved nodata to
# both keys so downstream consumers reading either find a consistent
# value. If `_FillValue` was absent, leave it absent.
if '_FillValue' in raster.attrs:
out_attrs['_FillValue'] = nd
if tgt_vertical_crs is not None:
out_attrs['vertical_crs'] = tgt_vertical_crs

Expand Down Expand Up @@ -1479,7 +1489,8 @@ def merge(
"Ensure all rasters have CRS metadata."
)
sb = _source_bounds(r)
ss = (r.sizes[r.dims[-2]], r.sizes[r.dims[-1]])
r_ydim, r_xdim = _find_spatial_dims(r)
ss = (r.sizes[r_ydim], r.sizes[r_xdim])
yd = _is_y_descending(r)
# Per-raster input nodata sentinel. Detected independently of the
# user-supplied output nodata so that mixed-sentinel inputs are
Expand Down Expand Up @@ -1552,8 +1563,7 @@ def merge(
)

y_coords, x_coords = _make_output_coords(out_bounds, out_shape)
ydim = rasters[0].dims[-2]
xdim = rasters[0].dims[-1]
ydim, xdim = _find_spatial_dims(rasters[0])

out_coords = {ydim: y_coords, xdim: x_coords}
# Carry forward non-spatial coords from the first raster (e.g. scalar
Expand All @@ -1567,14 +1577,24 @@ def merge(
out_coords[cname] = cval

# Carry the first raster's attrs forward (matches the default
# strategy='first'). Drop attrs describing the old grid: `transform`,
# `res`, and the duplicate `crs_wkt` are no longer accurate.
# strategy='first'). Drop the duplicate `crs_wkt` and re-emit
# `transform` and `res` for the new output grid (rasterio 6-tuple).
out_left, _out_bottom, _out_right, out_top = grid['bounds']
out_res_x = grid['res_x']
out_res_y = grid['res_y']
out_attrs = {**rasters[0].attrs}
out_attrs.pop('transform', None)
out_attrs.pop('crs_wkt', None)
out_attrs.pop('res', None)
out_attrs['crs'] = tgt_wkt
out_attrs['nodata'] = nd
out_attrs['res'] = (out_res_x, out_res_y)
out_attrs['transform'] = (
out_res_x, 0.0, out_left, 0.0, -out_res_y, out_top,
)
# If the first raster used `_FillValue`, propagate the resolved
# nodata to both keys for consistent round-trip. Leave absent
# otherwise.
if '_FillValue' in rasters[0].attrs:
out_attrs['_FillValue'] = nd

result = xr.DataArray(
result_data,
Expand Down
175 changes: 165 additions & 10 deletions xrspatial/tests/test_reproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,19 +2121,20 @@ def test_reproject_preserves_long_name(self):
result = reproject(raster, 'EPSG:4326', resolution=0.25)
assert result.attrs.get('long_name') == 'elevation'

def test_reproject_drops_stale_transform(self):
def test_reproject_replaces_stale_transform(self):
from xrspatial.reproject import reproject
raster = self._raster_with_attrs(
{'transform': (1.0, 0.0, 0.0, 0.0, -1.0, 0.0)}
)
stale = (1.0, 0.0, 0.0, 0.0, -1.0, 0.0)
raster = self._raster_with_attrs({'transform': stale})
result = reproject(raster, 'EPSG:3857')
assert 'transform' not in result.attrs
assert 'transform' in result.attrs
assert tuple(result.attrs['transform']) != stale

def test_reproject_drops_stale_res(self):
def test_reproject_replaces_stale_res(self):
from xrspatial.reproject import reproject
raster = self._raster_with_attrs({'res': (1.0, 1.0)})
result = reproject(raster, 'EPSG:3857')
assert 'res' not in result.attrs
assert 'res' in result.attrs
assert tuple(result.attrs['res']) != (1.0, 1.0)

def test_reproject_overrides_crs(self):
from xrspatial.reproject import reproject
Expand Down Expand Up @@ -2168,17 +2169,171 @@ def test_merge_preserves_first_raster_attrs(self):
assert result.attrs.get('units') == 'm'
assert result.attrs.get('long_name') == 'elev'

def test_merge_drops_stale_transform(self):
def test_merge_replaces_stale_transform(self):
from xrspatial.reproject import merge
stale = (1.0, 0.0, 0.0, 0.0, -1.0, 0.0)
a = self._raster_with_attrs(
{'transform': (1.0, 0.0, 0.0, 0.0, -1.0, 0.0)},
{'transform': stale},
x_range=(-5, 0), y_range=(-5, 5),
)
b = self._raster_with_attrs(
x_range=(0, 5), y_range=(-5, 5),
)
result = merge([a, b], resolution=1.0)
assert 'transform' not in result.attrs
assert 'transform' in result.attrs
assert tuple(result.attrs['transform']) != stale

# Fresh transform/res emission ----------------------------------------

def test_reproject_emits_fresh_transform(self):
from xrspatial.reproject import reproject
stale = (1.0, 0.0, 0.0, 0.0, -1.0, 10.0)
raster = self._raster_with_attrs({'transform': stale})
result = reproject(raster, 'EPSG:3857', resolution=50000.0)
t = result.attrs['transform']
assert len(t) == 6
# Output is in EPSG:3857 so transform values cannot match the stale
# geographic-degree input.
assert tuple(t) != stale
# transform[0] is res_x, transform[4] is -res_y, transform[2] is
# left edge, transform[5] is top edge.
res_x, res_y = result.attrs['res']
assert t[0] == res_x
assert t[4] == -res_y
# Top edge: y coord of first row plus half a pixel.
y0 = float(result.coords['y'].values[0])
assert t[5] == pytest.approx(y0 + res_y / 2)
# Left edge: x coord of first col minus half a pixel.
x0 = float(result.coords['x'].values[0])
assert t[2] == pytest.approx(x0 - res_x / 2)

def test_reproject_emits_fresh_res(self):
from xrspatial.reproject import reproject
raster = self._raster_with_attrs({'res': (1.0, 1.0)})
# Use an explicit resolution very different from input.
result = reproject(raster, 'EPSG:4326', resolution=0.25)
assert 'res' in result.attrs
res_x, res_y = result.attrs['res']
# Pixel size derived from output coords must match.
x = result.coords['x'].values
y = result.coords['y'].values
actual_res_x = float(abs(x[1] - x[0]))
actual_res_y = float(abs(y[1] - y[0]))
assert res_x == pytest.approx(actual_res_x)
assert res_y == pytest.approx(actual_res_y)

def test_reproject_no_input_transform_still_emits_one(self):
from xrspatial.reproject import reproject
raster = self._raster_with_attrs()
assert 'transform' not in raster.attrs
result = reproject(raster, 'EPSG:4326', resolution=0.25)
assert 'transform' in result.attrs
assert 'res' in result.attrs
assert len(result.attrs['transform']) == 6

def test_merge_emits_fresh_transform_and_res(self):
from xrspatial.reproject import merge
a = self._raster_with_attrs(
x_range=(-5, 0), y_range=(-5, 5),
)
b = self._raster_with_attrs(
x_range=(0, 5), y_range=(-5, 5),
)
result = merge([a, b], resolution=1.0)
assert 'transform' in result.attrs
assert 'res' in result.attrs
t = result.attrs['transform']
assert len(t) == 6
res_x, res_y = result.attrs['res']
assert t[0] == res_x
assert t[4] == -res_y
y0 = float(result.coords['y'].values[0])
x0 = float(result.coords['x'].values[0])
assert t[5] == pytest.approx(y0 + res_y / 2)
assert t[2] == pytest.approx(x0 - res_x / 2)

def test_merge_finds_spatial_dims_with_lat_lon(self):
from xrspatial.reproject import merge
a_data = np.ones((8, 8), dtype=np.float64)
b_data = np.ones((8, 8), dtype=np.float64) * 2
attrs = {'crs': 'EPSG:4326', 'nodata': np.nan}
a = xr.DataArray(
a_data, dims=['lat', 'lon'],
coords={
'lat': np.linspace(5, -5, 8),
'lon': np.linspace(-5, 0, 8),
},
name='a', attrs=attrs,
)
b = xr.DataArray(
b_data, dims=['lat', 'lon'],
coords={
'lat': np.linspace(5, -5, 8),
'lon': np.linspace(0, 5, 8),
},
name='b', attrs=attrs,
)
result = merge([a, b], resolution=1.0)
assert result.dims == ('lat', 'lon')
assert 'lat' in result.coords
assert 'lon' in result.coords

# _FillValue propagation -----------------------------------------------

def test_reproject_propagates_fill_value(self):
from xrspatial.reproject import reproject
# Build a raster with _FillValue set and no nodata key.
data = np.ones((8, 8), dtype=np.float64)
attrs = {'crs': 'EPSG:4326', '_FillValue': -9999}
raster = xr.DataArray(
data, dims=['y', 'x'],
coords={
'y': np.linspace(1, -1, 8),
'x': np.linspace(-1, 1, 8),
},
attrs=attrs,
)
result = reproject(raster, 'EPSG:4326', resolution=0.25)
assert '_FillValue' in result.attrs
assert 'nodata' in result.attrs
assert result.attrs['_FillValue'] == result.attrs['nodata']
assert result.attrs['_FillValue'] == -9999

def test_reproject_omits_fill_value_when_input_omits(self):
from xrspatial.reproject import reproject
raster = self._raster_with_attrs()
assert '_FillValue' not in raster.attrs
result = reproject(raster, 'EPSG:4326', resolution=0.25)
assert '_FillValue' not in result.attrs
assert 'nodata' in result.attrs

def test_merge_propagates_fill_value(self):
from xrspatial.reproject import merge
a_data = np.ones((8, 8), dtype=np.float64)
b_data = np.ones((8, 8), dtype=np.float64) * 2
attrs_a = {'crs': 'EPSG:4326', '_FillValue': -9999}
attrs_b = {'crs': 'EPSG:4326', '_FillValue': -9999}
a = xr.DataArray(
a_data, dims=['y', 'x'],
coords={
'y': np.linspace(5, -5, 8),
'x': np.linspace(-5, 0, 8),
},
name='a', attrs=attrs_a,
)
b = xr.DataArray(
b_data, dims=['y', 'x'],
coords={
'y': np.linspace(5, -5, 8),
'x': np.linspace(0, 5, 8),
},
name='b', attrs=attrs_b,
)
result = merge([a, b], resolution=1.0)
assert '_FillValue' in result.attrs
assert 'nodata' in result.attrs
assert result.attrs['_FillValue'] == result.attrs['nodata']
assert result.attrs['_FillValue'] == -9999

def test_merge_name_falls_back_to_first_raster(self):
from xrspatial.reproject import merge
Expand Down
Loading