# pyrawph/l1/l1_tile.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import math
import numpy as np
import matplotlib.pyplot as plt
try:
import torch
except Exception:
torch = None
BandSpec = Union[int, float, str]
_ALIAS_NM: Dict[str, int] = {
"BLUE": 490, "B": 490,
"GREEN": 560, "G": 560,
"RED": 665, "R": 665,
"REDEDGE1": 705, "RE1": 705,
"REDEDGE2": 740, "RE2": 740,
"REDEDGE3": 783, "RE3": 783,
"NIR": 842,
}
def _resolve_band_from_meta(
band: BandSpec,
wavelengths_nm: Optional[Sequence[Optional[int]]],
C: int,
) -> int:
if isinstance(band, int):
if not (0 <= band < C):
raise ValueError(f"Band index out of range: {band} (C={C})")
return band
# float wavelength (nm) => closest wavelength (ignore None)
if isinstance(band, float):
wl = int(round(band))
wls = list(wavelengths_nm or [])
valid = [(i, v) for i, v in enumerate(wls) if v is not None]
if not valid:
raise ValueError("No valid wavelengths in metadata; cannot resolve wavelength.")
return int(min(valid, key=lambda iv: abs(int(iv[1]) - wl))[0])
# string cases
s = str(band).strip().upper()
s_clean = s.replace(" ", "").replace("_", "")
# "842nm"
if s_clean.endswith("NM") and s_clean[:-2].isdigit():
return _resolve_band_from_meta(float(int(s_clean[:-2])), wavelengths_nm, C)
# "3"
if s.isdigit():
return _resolve_band_from_meta(int(s), wavelengths_nm, C)
# "B3" / "BAND_3"
for prefix in ("BAND_", "BAND", "B"):
if s.startswith(prefix) and s[len(prefix):].isdigit():
return _resolve_band_from_meta(int(s[len(prefix):]), wavelengths_nm, C)
# aliases => closest wavelength
if s in _ALIAS_NM:
return _resolve_band_from_meta(float(_ALIAS_NM[s]), wavelengths_nm, C)
raise ValueError(f"Cannot resolve band spec: {band!r}")
[docs]
@dataclass
class L1_tile:
"""
Represent one local tile extracted from an L1 event.
A tile stores a local multiband array together with its metadata and provides
convenience methods for array/tensor conversion, geographic corner retrieval,
tile inspection, band visualization, and optional location plotting.
The tile array is expected to follow the `(C, H, W)` convention, where `C` is
the number of spectral bands. Metadata may include geographic bounds, CRS,
band wavelengths, sensing time, and creation time.
"""
tile_name: str
arr: np.ndarray
meta: Dict[str, Any]
device: str = "cpu"
# basic getters
[docs]
def as_numpy(self) -> np.ndarray:
"""
Return the tile data as a NumPy array.
The returned array is the internal tile array stored by the object and is
expected to have shape `(C, H, W)`.
Returns:
The tile data as a NumPy array of shape `(C, H, W)`.
"""
return self.arr
[docs]
def as_tensor(self, as_float32: bool = True):
"""
Return the tile data as a PyTorch tensor on the configured device.
Args:
as_float32: If `True`, cast the tensor to `torch.float32` before moving
it to the target device.
Returns:
A PyTorch tensor containing the tile data, typically with shape
`(C, H, W)`.
Raises:
ImportError: If PyTorch is not available in the current environment.
"""
if torch is None:
raise ImportError("torch is not available")
t = torch.from_numpy(self.arr)
if as_float32 and t.dtype != torch.float32:
t = t.float()
return t.to(self.device)
# Helpers
[docs]
def get_tile_coordinates(self, latlon_format: bool = True) -> List[List[float]]:
"""
Return the four tile corner coordinates derived from the metadata bounds.
Corners are returned in the following order:
top-left, bottom-left, bottom-right, top-right.
Coordinates are extracted from `meta["bounds"]`. The returned values are
formatted as `[y, x]` pairs. This method does not reproject coordinates; it
only reformats the stored bounds values.
Args:
latlon_format: If `True`, return coordinates in `[lat, lon]`-style
ordering. In the current implementation, coordinates are returned as
`[y, x]` pairs in all cases.
Returns:
A list of four corner coordinates. If no bounds are available, returns an
empty list.
"""
b = self.meta.get("bounds", None)
crs = self.meta.get("crs", None)
if b is None:
return []
left = float(getattr(b, "left", b[0]))
bottom = float(getattr(b, "bottom", b[1]))
right = float(getattr(b, "right", b[2]))
top = float(getattr(b, "top", b[3]))
# corners as (x, y)
corners_xy = [(left, top), (left, bottom), (right, bottom), (right, top)]
if latlon_format and crs == "EPSG:4326":
return [[y, x] for (x, y) in corners_xy] # [lat, lon]
return [[y, x] for (x, y) in corners_xy]
[docs]
def get_tile_info(self):
"""
Return a compact summary of the tile.
The returned tuple contains the tile name, sensing time, creation time,
corner coordinates, and footprint coordinates.
Returns:
A tuple of the form:
- tile name,
- sensing time,
- creation time,
- corner coordinates from :meth:`get_tile_coordinates`,
- footprint coordinates from :meth:`get_tile_footprint_coordinates`.
"""
tile_name = self.tile_name
sensing_time = self.meta.get("sensing_time", None)
creation_time = self.meta.get("creation_time", None)
return (
tile_name,
sensing_time,
creation_time,
self.get_tile_coordinates(latlon_format=True),
self.get_tile_footprint_coordinates(latlon_format=True),
)
# Visualization
[docs]
def show_bands(
self,
bands: Optional[Sequence[BandSpec]] = None,
downsampling: bool = True,
max_size: int = 512,
stretch: Tuple[float, float] = (2.0, 98.0),
cmap: str = "viridis",
) -> None:
"""
Display one or more tile bands for visual inspection.
If `bands` is `None`, all tile bands are displayed. Otherwise, each requested
band is resolved from the tile metadata using integer indices, wavelengths in
nanometers, or string specifications such as `"B3"`, `"842nm"`, `"NIR"`, or
`"RED"`.
Displayed bands can be downsampled for faster rendering, and each band is
independently normalized using a percentile stretch before display.
Args:
bands: Optional sequence of band selectors to display. If `None`, all
tile bands are shown.
downsampling: If `True`, apply stride-based downsampling so that the
displayed image size remains manageable.
max_size: Target maximum size used to determine the downsampling stride.
stretch: Lower and upper percentiles used for contrast stretching.
cmap: Matplotlib colormap used for display.
Returns:
None.
Raises:
ValueError: If one of the requested band selectors cannot be resolved.
"""
arr = self.arr
C, H, W = int(arr.shape[0]), int(arr.shape[1]), int(arr.shape[2])
wls = self.meta.get("band_wavelength_nm", None)
if bands is None:
idxs = list(range(C))
labels = []
for i in idxs:
wl = None
if isinstance(wls, (list, tuple)) and i < len(wls):
wl = wls[i]
labels.append(f"B{i}" if wl is None else f"B{i} ({wl}nm)")
else:
idxs = [_resolve_band_from_meta(b, wls, C) for b in bands]
labels = []
for spec, i in zip(bands, idxs):
wl = None
if isinstance(wls, (list, tuple)) and i < len(wls):
wl = wls[i]
base = str(spec)
labels.append(base if wl is None else f"{base} -> B{i} ({wl}nm)")
stride = 1
if downsampling:
stride = max(1, int(math.ceil(max(H, W) / float(max_size))))
n = len(idxs)
ncols = min(4, n)
nrows = int(math.ceil(n / ncols))
fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows))
axes = np.array(axes).reshape(-1)
for k, ax in enumerate(axes):
if k >= n:
ax.axis("off")
continue
bi = idxs[k]
x = arr[bi][::stride, ::stride].astype(np.float32)
lo, hi = np.nanpercentile(x, stretch)
x = (x - lo) / (hi - lo + 1e-6)
x = np.clip(x, 0, 1)
ax.imshow(x, cmap=cmap)
ax.set_title(labels[k])
ax.axis("off")
plt.tight_layout()
plt.show()
[docs]
def plot_location(self, world: bool = True, title: Optional[str] = None):
"""
Plot the geographic location of the tile from its metadata bounds.
This method delegates the actual plotting to
`pyrawph.utils.optional_plots.plot_bounds`. If available, a world basemap can
be shown in the background; otherwise the tile rectangle alone is plotted.
Args:
world: If `True`, attempt to display the tile on top of a world basemap.
title: Optional plot title. If `None`, the tile name is used.
Returns:
The Matplotlib axes used for the plot.
Raises:
ValueError: If the tile metadata does not contain geographic bounds.
"""
from ..utils.optional_plots import plot_bounds
b = self.meta.get("bounds", None)
if b is None:
raise ValueError("No bounds in tile meta; cannot plot location.")
return plot_bounds(b, world=world, title=title or self.tile_name)