Data Pipeline¶
Components responsible for turning configuration files into ready-to-use PyTorch Lightning data modules and applying reproducible normalization policies.
Dataset selection¶
select_dataset(config)
¶
Build train/val datasets from config and wrap them into a LightningDataModule.
Expected config fields (OmegaConf/dict-like):
- config.Data.dataset_type : str
One of {"ExampleDataset", "SEN2NAIP", "LRHRFolderDataset"}.
- config.Data.root_dir / config.Data.dataset_root : str, optional
Required for LRHRFolderDataset.
- config.Data.sen2naip_taco_file : str, optional
Required for SEN2NAIP unless the dataset class receives another source.
Returns:
| Name | Type | Description |
|---|---|---|
| pl_datamodule | LightningDataModule | A tiny DataModule that exposes train/val DataLoaders built from the selected datasets. |
Source code in opensr_srgan/data/dataset_selector.py
def select_dataset(config):
    """
    Build train/val datasets from `config` and wrap them into a LightningDataModule.

    Expected `config` fields (OmegaConf/dict-like):
    - config.Data.dataset_type : str
        One of {"ExampleDataset", "SEN2NAIP", "LRHRFolderDataset"}.
        "SEN2NAIP" is matched case-insensitively; the other two names must
        match exactly.
    - config.Data.root_dir / config.Data.dataset_root : str, optional
        Required for LRHRFolderDataset (``root_dir`` takes precedence).
    - config.Data.sen2naip_taco_file : str, optional
        Required for SEN2NAIP unless the dataset class receives another source.

    Returns
    -------
    pl_datamodule : LightningDataModule
        A tiny DataModule that exposes train/val DataLoaders built from the selected datasets.

    Raises
    ------
    ValueError
        If LRHRFolderDataset is selected but neither ``Data.root_dir`` nor
        ``Data.dataset_root`` is set.
    FileNotFoundError
        If the LRHRFolderDataset root does not point to an existing directory.
    NotImplementedError
        If ``Data.dataset_type`` does not name a supported dataset.
    """
    dataset_selection = config.Data.dataset_type

    # Dataset classes are imported lazily inside each branch so that only the
    # selected dataset's module is loaded.
    # Legacy keys from older configs intentionally fall through to the error below.
    if dataset_selection == "ExampleDataset":
        from opensr_srgan.data.example_data.example_dataset import ExampleDataset

        print("WARNING -- Using Example Dataset!")
        print(
            "This dataset is exclusively meant for demonstration and debugging, not training or evaluation."
        )
        print("Please use a proper dataset for any serious work.")
        # Relative path: resolved against the current working directory.
        path = "example_dataset/"
        ds_train = ExampleDataset(folder=path, phase="train")
        ds_val = ExampleDataset(folder=path, phase="val")

    elif str(dataset_selection).lower() == "sen2naip":
        from opensr_srgan.data.sen2naip.sen2naip_dataset import SEN2NAIP

        # None is forwarded when the key is absent; the dataset class decides
        # how to handle a missing taco file.
        taco_file = getattr(config.Data, "sen2naip_taco_file", None)
        ds_train = SEN2NAIP(config=config, phase="train", taco_file=taco_file)
        ds_val = SEN2NAIP(config=config, phase="val", taco_file=taco_file)

    elif dataset_selection == "LRHRFolderDataset":
        from pathlib import Path

        from opensr_srgan.data.lrhr_folder.lrhr_folder_dataset import LRHRFolderDataset

        # Prefer Data.root_dir, fall back to Data.dataset_root.
        root_folder = getattr(config.Data, "root_dir", None)
        if root_folder is None:
            root_folder = getattr(config.Data, "dataset_root", None)
        if root_folder is None:
            raise ValueError(
                "LRHRFolderDataset requires Data.root_dir or Data.dataset_root in the config."
            )
        path = Path(root_folder).expanduser()
        if not path.is_dir():
            raise FileNotFoundError(
                f"LRHRFolderDataset root path does not exist: '{path}'. "
                "Set Data.root_dir to a valid dataset directory."
            )
        ds_train = LRHRFolderDataset(config=config, root_folder=path, phase="train")
        ds_val = LRHRFolderDataset(config=config, root_folder=path, phase="val")

    else:
        # Centralized error so unsupported keys fail loudly & clearly.
        # Each fragment ends with a newline so the hints render as a list
        # instead of one run-on sentence.
        raise NotImplementedError(
            f"Dataset {dataset_selection} not implemented!\n"
            "This can happen when:\n"
            " - (a) you misspelled the dataset name in the config\n"
            " - (b) the dataset is not implemented in the data folder.\n"
            " - (c) you are trying to use a custom dataset but forgot to add it in data/dataset_selector.py."
        )

    # Wrap the two datasets into a LightningDataModule with config-driven loader knobs.
    pl_datamodule = datamodule_from_datasets(config, ds_train, ds_val)
    return pl_datamodule
datamodule_from_datasets(config, ds_train, ds_val)
¶
Convert a pair of prebuilt PyTorch Datasets into a minimal PyTorch Lightning DataModule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config
|
OmegaConf / dict - like
|
Expected to contain: - Data.train_batch_size : int (fallback: Data.batch_size or 8) - Data.val_batch_size : int (fallback: Data.batch_size or 8) - Data.num_workers : int (default: 4) - Data.prefetch_factor : int (default: 2) |
required |
ds_train
|
Dataset
|
Training dataset (already instantiated). |
required |
ds_val
|
Dataset
|
Validation dataset (already instantiated). |
required |
Returns:
| Type | Description |
|---|---|
| LightningDataModule | Exposes `train_dataloader()` and `val_dataloader()` using the settings above. |
Source code in opensr_srgan/data/dataset_selector.py
def datamodule_from_datasets(config, ds_train, ds_val):
    """
    Wrap two already-built PyTorch Datasets in a small PyTorch Lightning DataModule.

    Parameters
    ----------
    config : OmegaConf/dict-like
        Loader settings are read from ``config.Data``:
        - Data.train_batch_size : int (fallback: Data.batch_size or 8)
        - Data.val_batch_size   : int (fallback: Data.batch_size or 8)
        - Data.num_workers      : int (default: 4)
        - Data.prefetch_factor  : int (default: 2)
        ``config.Data.dataset_type`` is additionally read for the summary
        line printed on construction.
    ds_train : torch.utils.data.Dataset
        Training dataset (already instantiated).
    ds_val : torch.utils.data.Dataset
        Validation dataset (already instantiated).

    Returns
    -------
    LightningDataModule
        Exposes `train_dataloader()` and `val_dataloader()` using the settings above.
    """
    from pytorch_lightning import LightningDataModule
    from torch.utils.data import DataLoader

    class CustomDataModule(LightningDataModule):
        """Minimal DataModule that hands config-driven settings to DataLoader."""

        def __init__(self, ds_train, ds_val, config):
            super().__init__()
            self.ds_train = ds_train
            self.ds_val = ds_val

            # Split-specific batch sizes win; otherwise use the shared
            # Data.batch_size, and 8 when that is missing as well.
            data_cfg = config.Data
            shared_bs = getattr(data_cfg, "batch_size", 8)
            self.train_bs = getattr(data_cfg, "train_batch_size", shared_bs)
            self.val_bs = getattr(data_cfg, "val_batch_size", shared_bs)
            self.num_workers = getattr(data_cfg, "num_workers", 4)
            self.prefetch_factor = getattr(data_cfg, "prefetch_factor", 2)

            # Report dataset sizes as a quick sanity check.
            print(
                f"Created Dataset type {config.Data.dataset_type} with {len(self.ds_train)} training samples and {len(self.ds_val)} validation samples.\n"
            )

        def _make_loader(self, dataset, batch_size, shuffle):
            """Build a DataLoader for ``dataset`` with the shared loader flags."""
            multi_worker = self.num_workers > 0
            options = {
                "batch_size": batch_size,
                "shuffle": shuffle,
                "num_workers": self.num_workers,
                "pin_memory": True,  # Speeds up host→GPU transfer on CUDA
                # Keep workers alive between epochs (needs at least one worker).
                "persistent_workers": multi_worker,
            }
            # prefetch_factor is only valid when num_workers > 0
            if multi_worker:
                options["prefetch_factor"] = self.prefetch_factor
            return DataLoader(dataset, **options)

        def train_dataloader(self):
            """Return the shuffled training DataLoader."""
            return self._make_loader(self.ds_train, self.train_bs, True)

        def val_dataloader(self):
            """Return the validation DataLoader (no shuffle)."""
            return self._make_loader(self.ds_val, self.val_bs, False)

    return CustomDataModule(ds_train, ds_val, config)
Normalization utilities¶
Utility helpers for configuring tensor normalization strategies.
Normalizer
¶
Factory for applying configurable normalization/denormalization.
The normalizer inspects the provided configuration, determines the
requested normalization scheme, and exposes normalize / denormalize
helpers that downstream components can reuse without importing
`utils.spectral_helpers` directly.
Supported methods include remote-sensing-focused helpers such as
"normalise_10k" (0–10000 reflectance → [0, 1]),
"normalise_10k_signed" (0–10000 reflectance → [-1, 1]),
"normalise_s2" (Sentinel-2 symmetric stretch), "zero_one" (clamp to
[0, 1]) and "zero_one_signed" ([0, 1] ↔ [-1, 1]). Custom
strategies can be registered by providing a mapping with
{"name": "custom", "normalize": "module:callable", ...} in the
configuration.
Source code in opensr_srgan/data/utils/normalizer.py
class Normalizer:
    """Factory for applying configurable normalization/denormalization.

    The normalizer inspects the provided configuration, determines the
    requested normalization scheme, and exposes ``normalize`` / ``denormalize``
    helpers that downstream components can reuse without importing
    :mod:`utils.spectral_helpers` directly.

    Supported methods include remote-sensing-focused helpers such as
    ``"normalise_10k"`` (0–10000 reflectance → ``[0, 1]``),
    ``"normalise_10k_signed"`` (0–10000 reflectance → ``[-1, 1]``),
    ``"normalise_s2"`` (Sentinel-2 symmetric stretch), ``"zero_one"`` (clamp to
    ``[0, 1]``) and ``"zero_one_signed"`` (``[0, 1]`` ↔ ``[-1, 1]``), plus
    ``"sen2_stretch"`` and the no-op ``"identity"``. Custom
    strategies can be registered by providing a mapping with
    ``{"name": "custom", "normalize": "module:callable", ...}`` in the
    configuration.

    The scheme is read from ``config.Data.normalization``; when that value is
    absent, ``"sen2_stretch"`` is used.
    """

    # Canonical method name -> strategy pair.
    _STANDARD_METHODS: Dict[str, NormalizationStrategy] = {
        "sen2_stretch": NormalizationStrategy(
            normalize=sen2_stretch,
            # Scales by 3/10 and clamps to [0, 1].
            # NOTE(review): presumably the inverse of sen2_stretch's gain — confirm.
            denormalize=lambda tensor: torch.clamp(tensor * (3.0 / 10.0), 0.0, 1.0),
        ),
        "normalise_10k": NormalizationStrategy(
            normalize=partial(normalise_10k, stage="norm"),
            denormalize=partial(normalise_10k, stage="denorm"),
        ),
        "normalise_10k_signed": NormalizationStrategy(
            normalize=partial(normalise_10k_signed, stage="norm"),
            denormalize=partial(normalise_10k_signed, stage="denorm"),
        ),
        "normalise_s2": NormalizationStrategy(
            normalize=partial(normalise_s2, stage="norm"),
            denormalize=partial(normalise_s2, stage="denorm"),
        ),
        "zero_one": NormalizationStrategy(
            # Clamping is applied in both directions.
            normalize=lambda tensor: torch.clamp(tensor, 0.0, 1.0),
            denormalize=lambda tensor: torch.clamp(tensor, 0.0, 1.0),
        ),
        "zero_one_signed": NormalizationStrategy(
            normalize=partial(zero_one_signed, stage="norm"),
            denormalize=partial(zero_one_signed, stage="denorm"),
        ),
        "identity": NormalizationStrategy(
            normalize=lambda tensor: tensor,
            denormalize=lambda tensor: tensor,
        ),
    }

    # Alternative spellings -> canonical method name.
    # _resolve_strategy rewrites "normalize" to "normalise" before this lookup,
    # so the "normalize_*" keys are already canonical by the time they would
    # be looked up here; they are kept for readers of this table.
    _ALIASES: Dict[str, str] = {
        "normalize_10k": "normalise_10k",
        "reflectance": "normalise_10k",
        "reflectance_0_1": "normalise_10k",
        "reflectance_signed": "normalise_10k_signed",
        "normalize_10k_signed": "normalise_10k_signed",
        "sentinel2": "normalise_10k",
        "sentinel2_signed": "normalise_10k_signed",
        "s2": "normalise_10k",
        "s2_signed": "normalise_10k_signed",
        "normalize_s2": "normalise_s2",
        "zero_to_one": "zero_one",
        "zero_one_range": "zero_one",
        "signed_zero_one": "zero_one_signed",
        "minusone_one": "zero_one_signed",
        "none": "identity",
    }

    def __init__(self, config: Any):
        """Read ``Data.normalization`` from ``config`` and resolve the strategy.

        Parameters
        ----------
        config : Any
            Attribute-style config (``config.Data.normalization``) or a plain
            ``dict`` (``config["Data"]["normalization"]``). A missing ``Data``
            section or ``normalization`` entry selects ``"sen2_stretch"``.

        Raises
        ------
        TypeError
            If the normalization entry is neither a string nor a mapping.
        ValueError
            If the requested method is unknown or a custom config is invalid.
        """
        data_cfg = getattr(config, "Data", None)
        # A plain dict exposes "Data" as a key, not an attribute; without this
        # lookup such configs would silently fall back to the default.
        if data_cfg is None and isinstance(config, dict):
            data_cfg = config.get("Data")

        raw_cfg: Any = None
        if data_cfg is not None:
            raw_cfg = getattr(data_cfg, "normalization", None)
            if raw_cfg is None and isinstance(data_cfg, dict):
                raw_cfg = data_cfg.get("normalization")

        if raw_cfg is None:
            raw_cfg = "sen2_stretch"

        method, strategy = self._resolve_strategy(raw_cfg)
        self._cfg = _NormalizerConfig(method=method)
        self._strategy = strategy

    @property
    def method(self) -> str:
        """Return the normalization method configured for this instance."""
        return self._cfg.method

    def normalize(self, tensor: torch.Tensor) -> torch.Tensor:
        """Normalize ``tensor`` according to the configured method."""
        return self._strategy.normalize(tensor)

    def denormalize(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the configured strategy's ``denormalize`` callable to ``tensor``."""
        return self._strategy.denormalize(tensor)

    @classmethod
    def available_methods(cls) -> Tuple[str, ...]:
        """Return the canonical names of built-in normalization strategies."""
        return tuple(sorted(cls._STANDARD_METHODS.keys()))

    def _resolve_strategy(self, raw_cfg: Any) -> Tuple[str, NormalizationStrategy]:
        """Resolve ``raw_cfg`` into a normalisation strategy.

        Parameters
        ----------
        raw_cfg : Any
            Configuration value extracted from ``Data.normalization``. Can be a
            string alias, a mapping with ``name``/``method`` keys, or a mapping
            describing custom callables.

        Returns
        -------
        Tuple[str, NormalizationStrategy]
            The canonical method name (``"custom"`` for user-supplied
            callables) and the matching strategy.

        Raises
        ------
        TypeError
            If ``raw_cfg`` is neither a string nor a mapping.
        ValueError
            If the method is unknown, or ``"custom"`` is requested as a bare
            string instead of a mapping.
        """
        if isinstance(raw_cfg, Mapping):
            # A mapping without "name"/"method" is treated as a custom config.
            name = raw_cfg.get("name") or raw_cfg.get("method") or "custom"
            name = str(name).strip().lower()
            name = name.replace("normalize", "normalise")
            if name == "custom":
                strategy = self._build_custom_strategy(raw_cfg)
                return "custom", strategy
            # Named built-in: continue through the string path below.
            raw_cfg = name

        if not isinstance(raw_cfg, str):
            raise TypeError(
                "Normalization config must be a string or mapping, "
                f"received: {type(raw_cfg)!r}"
            )

        # Canonicalise spelling first, then map aliases to built-in names.
        method = raw_cfg.strip().lower()
        method = method.replace("normalize", "normalise")
        method = self._ALIASES.get(method, method)

        if method == "custom":
            raise ValueError(
                "Use a mapping with 'name: custom' and callable paths to configure custom normalization."
            )

        try:
            strategy = self._STANDARD_METHODS[method]
        except KeyError as exc:
            supported = ", ".join(sorted(self._STANDARD_METHODS))
            raise ValueError(
                f"Unsupported normalization '{raw_cfg}'. Supported methods: {supported}."
            ) from exc

        return method, strategy

    def _build_custom_strategy(self, cfg: Mapping[str, Any]) -> NormalizationStrategy:
        """Instantiate a strategy from user-supplied callables.

        Parameters
        ----------
        cfg : Mapping[str, Any]
            Must contain ``normalize``; may contain ``denormalize``, shared
            ``kwargs``, and per-direction ``normalize_kwargs`` /
            ``denormalize_kwargs`` (per-direction values override shared
            ones). Keyword mappings set to ``None`` are treated as empty.
            Without ``denormalize`` the identity function is used.

        Raises
        ------
        ValueError
            If ``normalize`` is missing, or ``denormalize_kwargs`` is given
            without a ``denormalize`` callable.
        """
        if "normalize" not in cfg:
            raise ValueError(
                "Custom normalization requires a 'normalize' callable path."
            )

        normalize_path = cfg["normalize"]
        denormalize_path = cfg.get("denormalize")

        # "or {}" tolerates keys that are present but empty (None).
        shared_kwargs = dict(cfg.get("kwargs") or {})
        norm_kwargs = {**shared_kwargs, **(cfg.get("normalize_kwargs") or {})}
        explicit_denorm_kwargs = cfg.get("denormalize_kwargs") or {}
        denorm_kwargs = {**shared_kwargs, **explicit_denorm_kwargs}

        # _load_callable is defined elsewhere in this module.
        # NOTE(review): presumably resolves a "module:callable" path and binds
        # the given kwargs — confirm against its definition.
        normalize_fn = _load_callable(normalize_path, norm_kwargs)

        if denormalize_path is None:
            # Only explicit denormalize kwargs are an error here; shared
            # kwargs legitimately apply to the normalize callable alone.
            if explicit_denorm_kwargs:
                raise ValueError(
                    "'denormalize_kwargs' provided without a 'denormalize' callable."
                )
            denormalize_fn = lambda tensor: tensor
        else:
            denormalize_fn = _load_callable(denormalize_path, denorm_kwargs)

        return NormalizationStrategy(
            normalize=normalize_fn,
            denormalize=denormalize_fn,
        )