Inference API¶
Utilities for exporting pretrained SRGAN checkpoints and running tiled inference on Sentinel-2 imagery.
Core helpers¶
End-to-end inference utilities for Sentinel-2 Super-Resolution with SRGAN.
This module provides:

- A helper to build and load an SRGAN model from a configuration and checkpoint.
- A convenience wrapper to run patch-based Sentinel-2 inference using the `opensr_utils.large_file_processing` interface.
- A simple `main()` entry point for standalone command-line runs.
Typical usage¶
>>> from opensr_srgan.inference import run_sen2_inference
>>> run_sen2_inference(
...     sen2_path="/data/S2A_MSIL2A_EXAMPLE.SAFE",
...     config_path="configs/config_20m.yaml",
...     ckpt_path="checkpoints/srgan-20m-6band/last.ckpt",
...     gpus=[0]
... )
load_model(config_path=None, ckpt_path=None, device=None)¶
Instantiate an SRGAN model and optionally load pretrained weights.
This helper is safe to call from tests or scripts. It builds the model from a provided configuration file, loads weights from a checkpoint if available, and transfers the model to the selected device.
Parameters¶
config_path : str or Path, optional
Path to the YAML configuration file describing the generator/discriminator setup.
ckpt_path : str or Path, optional
Path to a Lightning checkpoint or raw PyTorch state dictionary.
device : str or torch.device, optional
Device on which to place the model. Defaults to "cuda" if available.
Returns¶
model : SRGAN_model
The loaded and ready-to-infer SRGAN Lightning module.
device : str
The device string used for model placement (e.g., "cuda" or "cpu").
Notes¶
- Automatically switches the model to evaluation mode.
- Tries to use the Lightning checkpoint API first; falls back to a raw `state_dict` for compatibility with manually saved checkpoints.
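A minimal smoke test of this helper, assuming the six-band 20 m configuration used in the module's own examples; the paths are placeholders, the input tensor must match the band count of the chosen config, and the output shape assumes a 4x generator:

>>> import torch
>>> from opensr_srgan.inference import load_model
>>> model, device = load_model(
...     config_path="configs/config_20m.yaml",
...     ckpt_path="checkpoints/srgan-20m-6band/last.ckpt",
...     device="cpu",
... )
>>> with torch.no_grad():
...     sr = model(torch.rand(1, 6, 64, 64, device=device))
>>> sr.shape  # 4x upscaling assumed
torch.Size([1, 6, 256, 256])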
Source code in opensr_srgan/inference.py
def load_model(config_path=None, ckpt_path=None, device=None):
    """Instantiate an SRGAN model and optionally load pretrained weights.

    This helper is safe to call from tests or scripts. It builds the model from
    a provided configuration file, loads weights from a checkpoint if available,
    and transfers the model to the selected device.

    Parameters
    ----------
    config_path : str or Path, optional
        Path to the YAML configuration file describing the generator/discriminator setup.
    ckpt_path : str or Path, optional
        Path to a Lightning checkpoint or raw PyTorch state dictionary.
    device : str or torch.device, optional
        Device on which to place the model. Defaults to `"cuda"` if available.

    Returns
    -------
    model : SRGAN_model
        The loaded and ready-to-infer SRGAN Lightning module.
    device : str
        The device string used for model placement (e.g., `"cuda"` or `"cpu"`).

    Notes
    -----
    - Automatically switches the model to evaluation mode.
    - Tries to use the Lightning checkpoint API first; falls back to a raw `state_dict`
      for compatibility with manually saved checkpoints.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SRGAN_model(config_file_path=config_path).eval().to(device)
    if ckpt_path:
        # Try Lightning API first (without 'strict'); fall back to raw state_dict
        try:
            model = (
                SRGAN_model.load_from_checkpoint(ckpt_path, map_location=device)
                .eval()
                .to(device)
            )
        except TypeError:
            state = torch.load(ckpt_path, map_location=device)
            state = state.get("state_dict", state)
            model.load_state_dict(state, strict=False)
    return model, device
run_sen2_inference(sen2_path=None, config_path=None, ckpt_path=None, gpus=None, window_size=(128, 128), factor=4, overlap=12, eliminate_border_px=2, save_preview=False, debug=False)¶
Run super-resolution inference on a Sentinel-2 SAFE product.
This function prepares the SRGAN model, launches patch-based processing of
large Sentinel-2 tiles, and orchestrates the full inference pipeline using
the opensr_utils.large_file_processing backend.
WARNING: currently supports RGB-NIR inputs only.
Parameters¶
sen2_path : str or Path, optional
Path to the Sentinel-2 .SAFE folder to process.
config_path : str or Path, optional
Path to the model configuration YAML file.
ckpt_path : str or Path, optional
Path to the model checkpoint file.
gpus : list[int], optional
GPU IDs to use for inference (e.g., [0, 1]). If None, automatically
selects the first GPU if available.
window_size : tuple(int, int), default=(128, 128)
Size of each input patch (in low-resolution pixels) processed per pass.
factor : int, default=4
Super-resolution scaling factor.
overlap : int, default=12
Overlap (in LR pixels) between inference windows to reduce edge artifacts.
eliminate_border_px : int, default=2
Number of border pixels to remove from each patch during blending.
save_preview : bool, default=False
If True, saves a small visual preview of the reconstructed image.
debug : bool, default=False
Enables verbose debug logs for troubleshooting.
Returns¶
sr_object : opensr_utils.large_file_processing
The inference handler object that manages tiling, batching, and output writing.
Notes¶
- Automatically sets `CUDA_VISIBLE_DEVICES` when `gpus` is provided.
- Uses the first GPU by default when running on CUDA-enabled systems.
- The returned object can be further inspected or extended post-run.
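Tiling behaviour can be tuned through the keyword arguments documented above; larger windows mean fewer patches (and fewer visible seams) at the cost of more GPU memory per forward pass. A hedged sketch with placeholder paths:

>>> from opensr_srgan.inference import run_sen2_inference
>>> sr_object = run_sen2_inference(
...     sen2_path="/data/S2A_MSIL2A_EXAMPLE.SAFE",
...     config_path="configs/config_20m.yaml",
...     ckpt_path="checkpoints/srgan-20m-6band/last.ckpt",
...     gpus=[0],
...     window_size=(256, 256),  # larger LR patch per pass
...     overlap=24,              # wider overlap between windows
...     eliminate_border_px=4,   # drop more border pixels when blending
...     save_preview=True,
... )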
Source code in opensr_srgan/inference.py
def run_sen2_inference(
    sen2_path=None,
    config_path=None,
    ckpt_path=None,
    gpus=None,
    window_size=(128, 128),
    factor=4,
    overlap=12,
    eliminate_border_px=2,
    save_preview=False,
    debug=False,
):
    """Run super-resolution inference on a Sentinel-2 SAFE product.

    This function prepares the SRGAN model, launches patch-based processing of
    large Sentinel-2 tiles, and orchestrates the full inference pipeline using
    the `opensr_utils.large_file_processing` backend.

    WARNING: currently supports RGB-NIR inputs only.

    Parameters
    ----------
    sen2_path : str or Path, optional
        Path to the Sentinel-2 `.SAFE` folder to process.
    config_path : str or Path, optional
        Path to the model configuration YAML file.
    ckpt_path : str or Path, optional
        Path to the model checkpoint file.
    gpus : list[int], optional
        GPU IDs to use for inference (e.g., `[0, 1]`). If `None`, automatically
        selects the first GPU if available.
    window_size : tuple(int, int), default=(128, 128)
        Size of each input patch (in low-resolution pixels) processed per pass.
    factor : int, default=4
        Super-resolution scaling factor.
    overlap : int, default=12
        Overlap (in LR pixels) between inference windows to reduce edge artifacts.
    eliminate_border_px : int, default=2
        Number of border pixels to remove from each patch during blending.
    save_preview : bool, default=False
        If True, saves a small visual preview of the reconstructed image.
    debug : bool, default=False
        Enables verbose debug logs for troubleshooting.

    Returns
    -------
    sr_object : opensr_utils.large_file_processing
        The inference handler object that manages tiling, batching, and output writing.

    Notes
    -----
    - Automatically sets `CUDA_VISIBLE_DEVICES` when `gpus` is provided.
    - Uses the first GPU by default when running on CUDA-enabled systems.
    - The returned object can be further inspected or extended post-run.
    """
    if gpus is not None and len(gpus) > 0:
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", ",".join(map(str, gpus)))
    model, device = load_model(config_path=config_path, ckpt_path=ckpt_path)

    import opensr_utils

    sr_object = opensr_utils.large_file_processing(
        root=sen2_path,
        model=model,
        window_size=window_size,
        factor=factor,
        overlap=overlap,
        eliminate_border_px=eliminate_border_px,
        device=device,
        gpus=gpus if gpus is not None else ([0] if device == "cuda" else []),
        save_preview=save_preview,
        debug=debug,
    )
    sr_object.start_super_resolution()
    return sr_object
main()¶
Example standalone entry point for testing Sentinel-2 inference.
Sets up example paths for a sample SAFE product, model config, and checkpoint, then runs the super-resolution pipeline on GPU 0.
This function is primarily intended for smoke testing and manual debugging.
Source code in opensr_srgan/inference.py
def main():
    """Example standalone entry point for testing Sentinel-2 inference.

    Sets up example paths for a sample SAFE product, model config, and checkpoint,
    then runs the super-resolution pipeline on GPU 0.

    This function is primarily intended for smoke testing and manual debugging.
    """
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    # ---- Define placeholders ----
    sen2_path = Path(__file__).resolve().parent / "data" / "S2A_MSIL2A_EXAMPLE.SAFE"
    config_path = Path(__file__).resolve().parent / "configs" / "config_20m.yaml"
    ckpt_path = "checkpoints/srgan-20m-6band/last.ckpt"
    gpus = [0]

    run_sen2_inference(
        sen2_path=str(sen2_path),
        config_path=str(config_path),
        ckpt_path=ckpt_path,
        gpus=gpus,
    )
Pretrained model factory¶
Utility helpers to instantiate pretrained SRGAN models.
This module provides convenience functions to:

- Load `SRGAN_model` instances from local YAML configuration files and optional checkpoints.
- Download and instantiate predefined pretrained models from the Hugging Face Hub (e.g., RGB-NIR and SWIR variants).
- Transparently handle local or remote checkpoints, temporary file storage, and EMA restoration.
Typical usage¶
>>> from opensr_srgan.model.loading import load_inference_model
>>> model = load_inference_model("RGB-NIR")
>>> model.eval()
>>> sr = model(torch.randn(1, 4, 64, 64))
load_from_config(config_path, checkpoint_uri=None, *, map_location=None, mode='train')¶
Instantiate an `SRGAN_model` from a configuration and optional checkpoint.
Parameters¶
config_path : str or Path
Filesystem path to the YAML configuration describing generator/discriminator
architecture and training parameters.
checkpoint_uri : str or Path, optional
Path or URL to a pretrained model checkpoint. If omitted, returns an untrained model.
map_location : str or torch.device, optional
Device mapping passed to `torch.load` when deserializing the checkpoint.
mode : {"train", "eval"}, default="train"
Desired mode in which to initialize the model.
Returns¶
LightningModule
A fully initialized SRGAN Lightning module, optionally loaded with pretrained weights.
Raises¶
FileNotFoundError
If the provided configuration file does not exist.
RuntimeError
If checkpoint deserialization or state restoration fails.
Notes¶
- If the checkpoint contains an exponential moving average (EMA) state, it is restored as well.
- The model is returned in evaluation mode regardless of `mode` to prevent accidental training until explicitly switched with `model.train()`.
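Because the factory hands the model back in evaluation mode even when mode="train" is requested, fine-tuning callers must opt back in explicitly. A minimal sketch with placeholder paths, importing from the `_factory` module shown below:

>>> from opensr_srgan._factory import load_from_config
>>> model = load_from_config(
...     "configs/config_20m.yaml",
...     checkpoint_uri="checkpoints/srgan-20m-6band/last.ckpt",
...     map_location="cpu",
...     mode="train",
... )
>>> model.training  # returned in eval mode regardless of `mode`
False
>>> _ = model.train()  # switch explicitly before fine-tuning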
Source code in opensr_srgan/_factory.py
def load_from_config(
    config_path: Union[str, Path],
    checkpoint_uri: Optional[Union[str, Path]] = None,
    *,
    map_location: Optional[Union[str, torch.device]] = None,
    mode: str = "train",
) -> LightningModule:
    """Instantiate an :class:`SRGAN_model` from a configuration and optional checkpoint.

    Parameters
    ----------
    config_path : str or Path
        Filesystem path to the YAML configuration describing generator/discriminator
        architecture and training parameters.
    checkpoint_uri : str or Path, optional
        Path or URL to a pretrained model checkpoint. If omitted, returns an untrained model.
    map_location : str or torch.device, optional
        Device mapping passed to :func:`torch.load` when deserializing the checkpoint.
    mode : {"train", "eval"}, default="train"
        Desired mode in which to initialize the model.

    Returns
    -------
    LightningModule
        A fully initialized SRGAN Lightning module, optionally loaded with pretrained weights.

    Raises
    ------
    FileNotFoundError
        If the provided configuration file does not exist.
    RuntimeError
        If checkpoint deserialization or state restoration fails.

    Notes
    -----
    - If the checkpoint contains an exponential moving average (EMA) state, it is restored as well.
    - The model is returned in evaluation mode regardless of `mode` to prevent accidental training
      until explicitly switched with ``model.train()``.
    """
    config_path = Path(config_path)
    if not config_path.is_file():
        raise FileNotFoundError(f"Config file '{config_path}' could not be located.")
    model = SRGAN_model(config=config_path, mode=mode)
    if checkpoint_uri is not None:
        with _maybe_download(checkpoint_uri) as resolved_path:
            checkpoint = torch.load(str(resolved_path), map_location=map_location)
        state_dict = checkpoint.get("state_dict", checkpoint)
        model.load_state_dict(state_dict, strict=False)
        if model.ema is not None and "ema_state" in checkpoint:
            model.ema.load_state_dict(checkpoint["ema_state"])
    model.eval()
    return model
load_inference_model(preset, *, cache_dir=None, map_location=None)¶
Load a pretrained SRGAN model from the Hugging Face Hub.
Downloads both the configuration and checkpoint associated with the requested preset, instantiates the model, restores weights, and returns it ready for inference.
Parameters¶
preset : {"RGB-NIR", "SWIR"}
Name of the pretrained model variant to load.
cache_dir : str or Path, optional
Directory to cache the downloaded files. Uses the default HF cache if omitted.
map_location : str or torch.device, optional
Device mapping for checkpoint deserialization.
Returns¶
LightningModule
Pretrained SRGAN model ready for inference.
Raises¶
ValueError
If an unknown preset name is provided.
ImportError
If huggingface_hub is not installed.
FileNotFoundError
If download or local resolution fails.
Examples¶
>>> model = load_inference_model("RGB-NIR")
>>> model.eval()
>>> x = torch.randn(1, 4, 64, 64)
>>> y = model(x)
>>> print(y.shape)
torch.Size([1, 4, 256, 256])
Source code in opensr_srgan/_factory.py
def load_inference_model(
    preset: str,
    *,
    cache_dir: Optional[Union[str, Path]] = None,
    map_location: Optional[Union[str, torch.device]] = None,
) -> LightningModule:
    """Load a pretrained SRGAN model from the Hugging Face Hub.

    Downloads both the configuration and checkpoint associated with the requested
    preset, instantiates the model, restores weights, and returns it ready for inference.

    Parameters
    ----------
    preset : {"RGB-NIR", "SWIR"}
        Name of the pretrained model variant to load.
    cache_dir : str or Path, optional
        Directory to cache the downloaded files. Uses the default HF cache if omitted.
    map_location : str or torch.device, optional
        Device mapping for checkpoint deserialization.

    Returns
    -------
    LightningModule
        Pretrained SRGAN model ready for inference.

    Raises
    ------
    ValueError
        If an unknown preset name is provided.
    ImportError
        If ``huggingface_hub`` is not installed.
    FileNotFoundError
        If download or local resolution fails.

    Examples
    --------
    >>> model = load_inference_model("RGB-NIR")
    >>> model.eval()
    >>> x = torch.randn(1, 4, 64, 64)
    >>> y = model(x)
    >>> print(y.shape)
    torch.Size([1, 4, 256, 256])
    """
    key = preset.strip().replace("_", "-").upper()
    try:
        preset_meta = _PRESETS[key]
    except KeyError as err:
        valid = ", ".join(sorted(_PRESETS))
        raise ValueError(
            f"Unknown preset '{preset}'. Available options: {valid}."
        ) from err

    try:  # pragma: no cover - import guard only used at runtime
        from huggingface_hub import hf_hub_download
    except ImportError as exc:  # pragma: no cover - dependency guard
        raise ImportError(
            "huggingface_hub is required for load_inference_model. "
            "Install the project extras or run 'pip install huggingface-hub'."
        ) from exc

    config_path = hf_hub_download(
        repo_id=preset_meta.repo_id,
        filename=preset_meta.config_filename,
        cache_dir=None if cache_dir is None else str(cache_dir),
    )
    checkpoint_path = hf_hub_download(
        repo_id=preset_meta.repo_id,
        filename=preset_meta.checkpoint_filename,
        cache_dir=None if cache_dir is None else str(cache_dir),
    )
    return load_from_config(
        config_path,
        checkpoint_path,
        map_location=map_location,
        mode="eval",
    )
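The two optional keywords can be combined, for instance to keep Hub downloads in a project-local cache and deserialize weights onto CPU. A short sketch using the import path from the module's usage example above; the cache directory is a placeholder:

>>> from opensr_srgan.model.loading import load_inference_model
>>> model = load_inference_model(
...     "SWIR",
...     cache_dir="./hf_cache",
...     map_location="cpu",
... )
>>> _ = model.eval()

Note that preset lookup normalizes case and underscores (via `strip`, `replace("_", "-")`, and `upper`), so a name like "rgb_nir" resolves to the RGB-NIR preset.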