Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 220 additions & 15 deletions xrspatial/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,211 @@ def _agg_mode(data, out_h, out_w):
}


# -- Block-aggregation kernels for dask chunks -------------------------------
#
# These mirror the eager `_agg_mean / _agg_min / ...` family but compute
# per-pixel windows from the *global* input/output geometry and a chunk
# offset, rather than from the local block shape. The whole chunk runs
# inside a single jitted call, instead of one numba dispatch per output
# pixel as the previous `func(sub, 1, 1)[0, 0]` loop did.
#
# Window bounds for output pixel `go` (a *global* output index):
# gy0 = int(go * global_in_h / global_out_h) - in_y0
# gy1 = max(gy0 + 1,
# int((go + 1) * global_in_h / global_out_h) - in_y0)
# where `in_y0` is the global input index of the chunk's first row
# (negative if `_add_overlap` extended the chunk past the input edge).

@ngjit
def _agg_block_mean_nb(data, target_h, target_w,
                       go_y0, go_x0,
                       global_in_h, global_in_w,
                       global_out_h, global_out_w,
                       in_y0, in_x0):
    """Mean-aggregate one (possibly overlapped) chunk.

    For each output pixel the window bounds are computed from the
    *global* input/output geometry and shifted into local chunk indices
    via ``in_y0``/``in_x0``.  NaNs are skipped; a window with no valid
    values yields NaN.
    """
    result = np.empty((target_h, target_w), dtype=np.float64)
    for row in range(target_h):
        out_y = go_y0 + row
        y_lo = int(out_y * global_in_h / global_out_h) - in_y0
        # Clamp to at least one input row per output pixel.
        y_hi = max(y_lo + 1,
                   int((out_y + 1) * global_in_h / global_out_h) - in_y0)
        for col in range(target_w):
            out_x = go_x0 + col
            x_lo = int(out_x * global_in_w / global_out_w) - in_x0
            x_hi = max(x_lo + 1,
                       int((out_x + 1) * global_in_w / global_out_w) - in_x0)
            acc = 0.0
            n_valid = 0
            for yy in range(y_lo, y_hi):
                for xx in range(x_lo, x_hi):
                    val = data[yy, xx]
                    if not np.isnan(val):
                        acc += val
                        n_valid += 1
            if n_valid > 0:
                result[row, col] = acc / n_valid
            else:
                result[row, col] = np.nan
    return result


@ngjit
def _agg_block_min_nb(data, target_h, target_w,
                      go_y0, go_x0,
                      global_in_h, global_in_w,
                      global_out_h, global_out_w,
                      in_y0, in_x0):
    """Min-aggregate one (possibly overlapped) chunk.

    Window bounds come from the *global* input/output geometry, shifted
    into local chunk indices via ``in_y0``/``in_x0``.  NaNs are skipped;
    a window with no valid values yields NaN.
    """
    result = np.empty((target_h, target_w), dtype=np.float64)
    for row in range(target_h):
        out_y = go_y0 + row
        y_lo = int(out_y * global_in_h / global_out_h) - in_y0
        # Clamp to at least one input row per output pixel.
        y_hi = max(y_lo + 1,
                   int((out_y + 1) * global_in_h / global_out_h) - in_y0)
        for col in range(target_w):
            out_x = go_x0 + col
            x_lo = int(out_x * global_in_w / global_out_w) - in_x0
            x_hi = max(x_lo + 1,
                       int((out_x + 1) * global_in_w / global_out_w) - in_x0)
            smallest = np.inf
            has_valid = False
            for yy in range(y_lo, y_hi):
                for xx in range(x_lo, x_hi):
                    val = data[yy, xx]
                    if not np.isnan(val) and val < smallest:
                        smallest = val
                        has_valid = True
            if has_valid:
                result[row, col] = smallest
            else:
                result[row, col] = np.nan
    return result


@ngjit
def _agg_block_max_nb(data, target_h, target_w,
                      go_y0, go_x0,
                      global_in_h, global_in_w,
                      global_out_h, global_out_w,
                      in_y0, in_x0):
    """Max-aggregate one (possibly overlapped) chunk.

    Window bounds come from the *global* input/output geometry, shifted
    into local chunk indices via ``in_y0``/``in_x0``.  NaNs are skipped;
    a window with no valid values yields NaN.
    """
    result = np.empty((target_h, target_w), dtype=np.float64)
    for row in range(target_h):
        out_y = go_y0 + row
        y_lo = int(out_y * global_in_h / global_out_h) - in_y0
        # Clamp to at least one input row per output pixel.
        y_hi = max(y_lo + 1,
                   int((out_y + 1) * global_in_h / global_out_h) - in_y0)
        for col in range(target_w):
            out_x = go_x0 + col
            x_lo = int(out_x * global_in_w / global_out_w) - in_x0
            x_hi = max(x_lo + 1,
                       int((out_x + 1) * global_in_w / global_out_w) - in_x0)
            largest = -np.inf
            has_valid = False
            for yy in range(y_lo, y_hi):
                for xx in range(x_lo, x_hi):
                    val = data[yy, xx]
                    if not np.isnan(val) and val > largest:
                        largest = val
                        has_valid = True
            if has_valid:
                result[row, col] = largest
            else:
                result[row, col] = np.nan
    return result


@ngjit
def _agg_block_median_nb(data, target_h, target_w,
                         go_y0, go_x0,
                         global_in_h, global_in_w,
                         global_out_h, global_out_w,
                         in_y0, in_x0):
    """Median-aggregate one (possibly overlapped) chunk.

    Window bounds come from the *global* input/output geometry, shifted
    into local chunk indices via ``in_y0``/``in_x0``.  NaNs are skipped;
    a window with no valid values yields NaN.  Even-sized windows use
    the mean of the two middle values.
    """
    result = np.empty((target_h, target_w), dtype=np.float64)
    for row in range(target_h):
        out_y = go_y0 + row
        y_lo = int(out_y * global_in_h / global_out_h) - in_y0
        # Clamp to at least one input row per output pixel.
        y_hi = max(y_lo + 1,
                   int((out_y + 1) * global_in_h / global_out_h) - in_y0)
        for col in range(target_w):
            out_x = go_x0 + col
            x_lo = int(out_x * global_in_w / global_out_w) - in_x0
            x_hi = max(x_lo + 1,
                       int((out_x + 1) * global_in_w / global_out_w) - in_x0)
            # Gather the window's non-NaN values into a scratch buffer.
            scratch = np.empty((y_hi - y_lo) * (x_hi - x_lo),
                               dtype=np.float64)
            n_valid = 0
            for yy in range(y_lo, y_hi):
                for xx in range(x_lo, x_hi):
                    val = data[yy, xx]
                    if not np.isnan(val):
                        scratch[n_valid] = val
                        n_valid += 1
            if n_valid == 0:
                result[row, col] = np.nan
            else:
                ordered = np.sort(scratch[:n_valid])
                mid = n_valid // 2
                if n_valid % 2 == 1:
                    result[row, col] = ordered[mid]
                else:
                    result[row, col] = (ordered[mid - 1] + ordered[mid]) / 2.0
    return result


@ngjit
def _agg_block_mode_nb(data, target_h, target_w,
                       go_y0, go_x0,
                       global_in_h, global_in_w,
                       global_out_h, global_out_w,
                       in_y0, in_x0):
    # Mode-aggregate one (possibly overlapped) chunk: each output pixel
    # takes the most frequent non-NaN value in its input window, or NaN
    # if the window has no valid values.  Per the header comment above,
    # `in_y0`/`in_x0` shift global input indices into local chunk
    # indices and may be negative when `_add_overlap` extended the chunk
    # past the input edge.
    out = np.empty((target_h, target_w), dtype=np.float64)
    for lo_y in range(target_h):
        go_y = go_y0 + lo_y
        # Window bounds from the global geometry; clamp so every output
        # pixel covers at least one input row/column.
        gy0 = int(go_y * global_in_h / global_out_h) - in_y0
        gy1 = int((go_y + 1) * global_in_h / global_out_h) - in_y0
        if gy1 < gy0 + 1:
            gy1 = gy0 + 1
        for lo_x in range(target_w):
            go_x = go_x0 + lo_x
            gx0 = int(go_x * global_in_w / global_out_w) - in_x0
            gx1 = int((go_x + 1) * global_in_w / global_out_w) - in_x0
            if gx1 < gx0 + 1:
                gx1 = gx0 + 1
            # Gather the window's non-NaN values into a scratch buffer.
            buf = np.empty((gy1 - gy0) * (gx1 - gx0), dtype=np.float64)
            n = 0
            for y in range(gy0, gy1):
                for x in range(gx0, gx1):
                    v = data[y, x]
                    if not np.isnan(v):
                        buf[n] = v
                        n += 1
            if n == 0:
                out[lo_y, lo_x] = np.nan
                continue
            # Sorting groups equal values into contiguous runs; scan the
            # runs and keep the longest one.  Ties keep the earlier
            # (smaller) value, since a later run must be strictly longer
            # to replace the current best.
            s = np.sort(buf[:n])
            best_val = s[0]
            best_cnt = 1
            cur_val = s[0]
            cur_cnt = 1
            for i in range(1, n):
                if s[i] == cur_val:
                    cur_cnt += 1
                else:
                    if cur_cnt > best_cnt:
                        best_cnt = cur_cnt
                        best_val = cur_val
                    cur_val = s[i]
                    cur_cnt = 1
            # The loop never closes the final run; check it here.
            if cur_cnt > best_cnt:
                best_val = cur_val
            out[lo_y, lo_x] = best_val
    return out


# Dispatch table: resample `method` name -> per-chunk jitted aggregation
# kernel.  The keys mirror the eager `_AGG_FUNCS` table so the dask path
# supports the same set of aggregation methods.
_AGG_BLOCK_FUNCS = {
    'average': _agg_block_mean_nb,
    'min': _agg_block_min_nb,
    'max': _agg_block_max_nb,
    'median': _agg_block_median_nb,
    'mode': _agg_block_mode_nb,
}


# -- Dask block helpers ------------------------------------------------------
#
# Interpolation uses map_coordinates with *global* coordinate mapping so
Expand Down Expand Up @@ -472,7 +677,12 @@ def _agg_block_np(block, method, global_in_h, global_in_w,
global_out_h, global_out_w,
cum_in_y, cum_in_x, cum_out_y, cum_out_x,
depth_y, depth_x, block_info=None):
"""Block-aggregate one (possibly overlapped) numpy chunk."""
"""Block-aggregate one (possibly overlapped) numpy chunk.

Runs the entire chunk inside one numba dispatch via the
`_agg_block_*_nb` kernels. Earlier versions called a 1x1 jitted
aggregate per output pixel, which scaled badly for large rasters.
"""
yi, xi = block_info[0]['chunk-location']
target_h = int(cum_out_y[yi + 1] - cum_out_y[yi])
target_w = int(cum_out_x[xi + 1] - cum_out_x[xi])
Expand All @@ -481,20 +691,15 @@ def _agg_block_np(block, method, global_in_h, global_in_w,
# The overlapped block starts depth pixels before the original chunk
in_y0 = int(cum_in_y[yi]) - depth_y
in_x0 = int(cum_in_x[xi]) - depth_x
func = _AGG_FUNCS[method]

out = np.empty((target_h, target_w), dtype=np.float64)
for lo_y in range(target_h):
go_y = int(cum_out_y[yi]) + lo_y
gy0 = int(go_y * global_in_h / global_out_h) - in_y0
gy1 = max(gy0 + 1, int((go_y + 1) * global_in_h / global_out_h) - in_y0)
for lo_x in range(target_w):
go_x = int(cum_out_x[xi]) + lo_x
gx0 = int(go_x * global_in_w / global_out_w) - in_x0
gx1 = max(gx0 + 1, int((go_x + 1) * global_in_w / global_out_w) - in_x0)
sub = block[gy0:gy1, gx0:gx1]
out[lo_y, lo_x] = func(sub, 1, 1)[0, 0]

go_y0 = int(cum_out_y[yi])
go_x0 = int(cum_out_x[xi])

kernel = _AGG_BLOCK_FUNCS[method]
out = kernel(block, target_h, target_w,
go_y0, go_x0,
int(global_in_h), int(global_in_w),
int(global_out_h), int(global_out_w),
in_y0, in_x0)
return out.astype(np.float32)


Expand Down
49 changes: 49 additions & 0 deletions xrspatial/tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,52 @@ def test_dask_path_skips_guard(self, grid_4x4):
# guard not to short-circuit a reasonable dask call.
out = resample(dask_agg, scale_factor=100.0, method='nearest')
assert out.shape == (400, 400)


# ---------------------------------------------------------------------------
# Inlined dask aggregate kernel (#1463)
# ---------------------------------------------------------------------------

@dask_array_available
class TestDaskAggregateInlined:
    """Cover the per-chunk numba aggregate kernels (#1463)."""

    @pytest.mark.parametrize('method',
                             ['average', 'min', 'max', 'median', 'mode'])
    def test_inlined_kernel_matches_numpy(self, method):
        # 60x60 with 20x20 chunks and scale_factor=1/3 produces a 20x20
        # output.  Each output pixel collapses a 3x3 input window, and
        # the output chunks straddle the input chunk boundaries because
        # `_add_overlap` extends each chunk by `depth_y = depth_x = 3`.
        rng = np.random.RandomState(1463)
        data = rng.rand(60, 60).astype(np.float32)
        np_agg = create_test_raster(data, backend='numpy',
                                    attrs={'res': (1.0, 1.0)})
        dk_agg = create_test_raster(data, backend='dask+numpy',
                                    attrs={'res': (1.0, 1.0)},
                                    chunks=(20, 20))
        np_out = resample(np_agg, scale_factor=1.0 / 3.0, method=method)
        dk_out = resample(dk_agg, scale_factor=1.0 / 3.0, method=method)
        np.testing.assert_array_equal(dk_out.values, np_out.values)

    def test_dask_aggregate_smoke_200x200(self):
        # Smoke test: confirm the inlined path completes within a
        # generous wall-clock budget on a moderate raster.  Not a perf
        # assertion -- just guards against accidental quadratic
        # regressions in the chunk loop.
        import time

        # The budget must absorb cold-cache numba compilation on slow
        # CI workers; the compiled kernel itself runs in well under a
        # second on a typical laptop.
        budget_s = 30.0
        rng = np.random.RandomState(146301)
        data = rng.rand(200, 200).astype(np.float32)
        dk_agg = create_test_raster(data, backend='dask+numpy',
                                    attrs={'res': (1.0, 1.0)},
                                    chunks=(50, 50))
        t0 = time.perf_counter()
        out = resample(dk_agg, scale_factor=0.25, method='average').compute()
        elapsed = time.perf_counter() - t0
        assert out.shape == (50, 50)
        assert elapsed < budget_s, (
            f"dask aggregate took {elapsed:.2f}s; "
            f"expected under {budget_s:.0f}s"
        )
Loading