Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 71 additions & 10 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,9 +1269,45 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,
if isinstance(source, str) and source.lower().endswith('.vrt'):
return read_vrt(source, dtype=dtype, name=name, chunks=chunks)

# Metadata-only read: O(1) memory via mmap, no pixel decompression
geo_info, full_h, full_w, file_dtype, n_bands = _read_geo_info(
source, overview_level=overview_level)
# P5: HTTP COG sources used to fire one IFD/header GET per chunk
# task. Parse metadata once here so every delayed task can reuse it.
is_http = (
isinstance(source, str)
and source.startswith(('http://', 'https://'))
)
http_meta = None
http_meta_key = None
if is_http:
import dask
from ._reader import _HTTPSource, _parse_cog_http_meta
_src = _HTTPSource(source)
try:
http_header, http_ifd, http_geo, _ = _parse_cog_http_meta(
_src, overview_level=overview_level)
finally:
_src.close()
http_meta = (http_header, http_ifd)
# Wrap the parsed metadata in a single dask Delayed so every
# window task takes it as a graph input, not a Python closure.
# Without this, the (TIFFHeader, IFD) pair -- which can carry
# multi-million-entry TileOffsets/TileByteCounts tuples on
# large COGs -- would be embedded in each chunk task and
# serialised N times under distributed/process schedulers.
http_meta_key = dask.delayed(http_meta, pure=True)
geo_info = http_geo
Comment thread
brendancol marked this conversation as resolved.
full_h = http_ifd.height
full_w = http_ifd.width
from ._dtypes import resolve_bits_per_sample, tiff_dtype_to_numpy
bps = resolve_bits_per_sample(http_ifd.bits_per_sample)
file_dtype = tiff_dtype_to_numpy(bps, http_ifd.sample_format)
n_bands = (
http_ifd.samples_per_pixel
if http_ifd.samples_per_pixel > 1 else 0
)
else:
# Metadata-only read: O(1) memory via mmap, no pixel decompression
geo_info, full_h, full_w, file_dtype, n_bands = _read_geo_info(
source, overview_level=overview_level)
nodata = geo_info.nodata

# Nodata masking promotes integer arrays to float64 (for NaN).
Expand Down Expand Up @@ -1353,7 +1389,8 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,
_delayed_read_window(source, r0, c0, r1, c1,
overview_level, nodata,
band_arg,
target_dtype=target_dtype if dtype is not None else None),
target_dtype=target_dtype if dtype is not None else None,
http_meta_key=http_meta_key),
shape=block_shape,
dtype=target_dtype,
)
Expand All @@ -1374,13 +1411,37 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,


def _delayed_read_window(source, r0, c0, r1, c1, overview_level, nodata,
band, *, target_dtype=None):
"""Dask-delayed function to read a single window."""
band, *, target_dtype=None, http_meta_key=None):
"""Dask-delayed function to read a single window.

*http_meta_key* is an optional ``Delayed[(TIFFHeader, IFD)]`` parsed
once by :func:`read_geotiff_dask` and wrapped via ``dask.delayed``.
Passing it as a function argument (rather than a closure capture)
makes the metadata a single graph input that all window tasks
depend on, so distributed/process schedulers serialise it once
instead of once per chunk. For local sources it is ``None``.
"""
import dask

@dask.delayed
def _read():
arr, _ = read_to_array(source, window=(r0, c0, r1, c1),
overview_level=overview_level, band=band)
def _read(http_meta):
if http_meta is not None and isinstance(source, str) and \
source.startswith(('http://', 'https://')):
from ._reader import _HTTPSource, _fetch_decode_cog_http_tiles
header, ifd = http_meta
src = _HTTPSource(source)
try:
arr = _fetch_decode_cog_http_tiles(
src, header, ifd, window=(r0, c0, r1, c1))
finally:
src.close()
if (arr.ndim == 3 and ifd.samples_per_pixel > 1
and band is not None):
arr = arr[:, :, band]
else:
arr, _ = read_to_array(source, window=(r0, c0, r1, c1),
overview_level=overview_level,
band=band)
if nodata is not None:
if arr.dtype.kind == 'f' and not np.isnan(nodata):
arr = arr.copy()
Expand All @@ -1393,7 +1454,7 @@ def _read():
if target_dtype is not None:
arr = arr.astype(target_dtype)
return arr
return _read()
return _read(http_meta_key)


def read_geotiff_gpu(source: str, *,
Expand Down
Loading
Loading