diff --git a/xrspatial/geotiff/_compression.py b/xrspatial/geotiff/_compression.py index 007c78bd..40a39e7e 100644 --- a/xrspatial/geotiff/_compression.py +++ b/xrspatial/geotiff/_compression.py @@ -696,6 +696,15 @@ def _fp_predictor_encode_row(row_data, width, bps): row_data[i] = np.uint8((np.int32(row_data[i]) - np.int32(row_data[i - 1])) & 0xFF) +@ngjit +def _fp_predictor_encode_rows(data, width, height, bps): + """Dispatch per-row encode from Numba, avoiding Python loop overhead.""" + row_len = width * bps + for row in range(height): + start = row * row_len + _fp_predictor_encode_row(data[start:start + row_len], width, bps) + + def fp_predictor_encode(data: np.ndarray, width: int, height: int, bytes_per_sample: int) -> np.ndarray: """Apply floating-point predictor (predictor=3). @@ -715,10 +724,7 @@ def fp_predictor_encode(data: np.ndarray, width: int, height: int, Encoded array. """ buf = np.ascontiguousarray(data) - row_len = width * bytes_per_sample - for row in range(height): - start = row * row_len - _fp_predictor_encode_row(buf[start:start + row_len], width, bytes_per_sample) + _fp_predictor_encode_rows(buf, width, height, bytes_per_sample) return buf diff --git a/xrspatial/geotiff/tests/test_predictor_fp_write_1313.py b/xrspatial/geotiff/tests/test_predictor_fp_write_1313.py index d168a209..bcdf2872 100644 --- a/xrspatial/geotiff/tests/test_predictor_fp_write_1313.py +++ b/xrspatial/geotiff/tests/test_predictor_fp_write_1313.py @@ -9,6 +9,7 @@ """ from __future__ import annotations +import os import struct import numpy as np @@ -201,3 +202,72 @@ def test_predictor3_multiband_round_trip(tmp_path): else: out_arr = out.values np.testing.assert_array_equal(out_arr, arr) + + +def test_predictor3_large_round_trip_value_exact(tmp_path): + """1024x1024 float32 deflate+predictor=3 round-trips with no value drift. + + The encode path was refactored to dispatch the per-row kernel from + inside an ``@ngjit`` wrapper instead of from a Python ``for`` loop. + Guards against any silent corruption from the refactor by asserting + the output array is byte-for-byte identical to the input: dtype must + match, and a ``uint8`` view of the bytes must compare equal so the + check catches signed-zero drift, NaN payload changes, and any other + bit-level divergence that ``assert_array_equal`` would mask. + """ + h, w = 1024, 1024 + arr = _smooth_float((h, w), np.float32) + da = _da(arr) + path = tmp_path / 'pred3_large_round_trip_1313.tif' + to_geotiff(da, str(path), compression='deflate', predictor=3) + + assert _read_predictor_tag(str(path)) == 3 + out = open_geotiff(str(path)) + out_arr = np.ascontiguousarray(out.values) + assert out_arr.dtype == arr.dtype, ( + f"dtype drift: in={arr.dtype}, out={out_arr.dtype}" + ) + assert out_arr.shape == arr.shape + assert out_arr.tobytes() == arr.tobytes(), ( + "predictor=3 round-trip diverged at the bit level " + "(signed zero, NaN payload, or actual corruption)" + ) + + +def test_predictor3_encode_within_2x_of_predictor2(tmp_path): + """Loose regression check: predictor=3 encode is within 2x of predictor=2. + + Before the ngjit row-loop refactor, predictor=3 was ~2.5x slower than + predictor=2 because the row loop was in Python. Opt-in via + ``XRSPATIAL_RUN_PERF_TESTS=1`` -- shared CI runners, CPU throttling, + debug builds, and noisy filesystems all make absolute wall-clock + timings flaky, so the test stays off by default. Matches the + convention from ``test_streaming_write_parallel.py``. + """ + if os.environ.get('XRSPATIAL_RUN_PERF_TESTS') != '1': + pytest.skip( + "set XRSPATIAL_RUN_PERF_TESTS=1 to run wall-clock perf tests") + + import time + + arr = _smooth_float((1024, 1024), np.float32) + da = _da(arr) + p2 = tmp_path / 'pred2_timing.tif' + p3 = tmp_path / 'pred3_timing.tif' + + # Warm up numba + to_geotiff(da, str(p2), compression='deflate', predictor=2) + to_geotiff(da, str(p3), compression='deflate', predictor=3) + + t0 = time.perf_counter() + to_geotiff(da, str(p2), compression='deflate', predictor=2) + t_p2 = time.perf_counter() - t0 + + t0 = time.perf_counter() + to_geotiff(da, str(p3), compression='deflate', predictor=3) + t_p3 = time.perf_counter() - t0 + + assert t_p3 < 2.0 * t_p2, ( + f'predictor=3 ({t_p3*1000:.1f} ms) is more than 2x slower than ' + f'predictor=2 ({t_p2*1000:.1f} ms); ngjit row loop may have regressed' + )