
Inference API

Utilities for loading pretrained SRGAN checkpoints and running tiled inference on Sentinel-2 imagery.

Core helpers

End-to-end inference utilities for Sentinel-2 Super-Resolution with SRGAN.

This module provides:
  • A helper to build and load an SRGAN model from a configuration and checkpoint.
  • A convenience wrapper to run patch-based Sentinel-2 inference using the opensr_utils.large_file_processing interface.
  • A simple main() entry point for standalone command-line runs.

Typical usage

>>> from opensr_srgan.inference import run_sen2_inference
>>> run_sen2_inference(
...     sen2_path="/data/S2A_MSIL2A_EXAMPLE.SAFE",
...     config_path="configs/config_20m.yaml",
...     ckpt_path="checkpoints/srgan-20m-6band/last.ckpt",
...     gpus=[0]
... )
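The module mentions a main() entry point for standalone runs, but main() uses hard-coded example paths. The following is a minimal sketch of a command-line wrapper that forwards arguments to run_sen2_inference. The build_parser/cli helpers and all flag names are hypothetical and are not part of opensr_srgan.

```python
import argparse
from typing import Optional, Sequence


def build_parser() -> argparse.ArgumentParser:
    """Build a hypothetical CLI around run_sen2_inference (flag names are assumptions)."""
    parser = argparse.ArgumentParser(
        description="Run SRGAN super-resolution on a Sentinel-2 SAFE product."
    )
    parser.add_argument("sen2_path", help="Path to the Sentinel-2 .SAFE folder")
    parser.add_argument("--config", dest="config_path", required=True,
                        help="Model configuration YAML")
    parser.add_argument("--ckpt", dest="ckpt_path", default=None,
                        help="Model checkpoint file")
    parser.add_argument("--gpus", type=int, nargs="*", default=None,
                        help="GPU IDs, e.g. --gpus 0 1")
    parser.add_argument("--factor", type=int, default=4)
    parser.add_argument("--overlap", type=int, default=12)
    return parser


def cli(argv: Optional[Sequence[str]] = None) -> None:
    """Parse arguments and launch inference."""
    args = build_parser().parse_args(argv)
    # Imported lazily so the parser can be built and tested without the package.
    from opensr_srgan.inference import run_sen2_inference

    run_sen2_inference(
        sen2_path=args.sen2_path,
        config_path=args.config_path,
        ckpt_path=args.ckpt_path,
        gpus=args.gpus,
        factor=args.factor,
        overlap=args.overlap,
    )
```

Calling `cli(["/data/S2A_MSIL2A_EXAMPLE.SAFE", "--config", "configs/config_20m.yaml", "--gpus", "0"])` mirrors the usage example above. Unspecified options keep the documented defaults of run_sen2_inference.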

load_model(config_path=None, ckpt_path=None, device=None)

Instantiate an SRGAN model and optionally load pretrained weights.

This helper is safe to call from tests or scripts. It builds the model from a provided configuration file, loads weights from a checkpoint if available, and transfers the model to the selected device.

Parameters

config_path : str or Path, optional
    Path to the YAML configuration file describing the generator/discriminator setup.
ckpt_path : str or Path, optional
    Path to a Lightning checkpoint or raw PyTorch state dictionary.
device : str or torch.device, optional
    Device on which to place the model. Defaults to "cuda" if available, otherwise "cpu".

Returns

model : SRGAN_model
    The loaded and ready-to-infer SRGAN Lightning module.
device : str
    The device string used for model placement (e.g., "cuda" or "cpu").

Notes

  • Automatically switches the model to evaluation mode.
  • Tries the Lightning checkpoint API first. If that raises a TypeError, it falls back to loading a raw state_dict with strict=False, for compatibility with manually saved checkpoints.
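The fallback path calls load_state_dict(..., strict=False), so missing or unexpected keys do not raise an error. load_model also discards the result that PyTorch's load_state_dict returns, which carries missing_keys and unexpected_keys. The following stdlib sketch reproduces the checkpoint unwrapping and that key comparison on plain dicts, so the logic can be checked without building a model. extract_state_dict and key_mismatch are hypothetical helper names.

```python
from typing import Any, Dict, Iterable, List, Tuple


def extract_state_dict(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Return the weights from a Lightning checkpoint or a raw state_dict.

    Mirrors ``state.get("state_dict", state)`` in load_model: Lightning
    checkpoints nest weights under "state_dict", raw state dicts do not.
    """
    return checkpoint.get("state_dict", checkpoint)


def key_mismatch(
    model_keys: Iterable[str], loaded_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Report the keys that strict=False would silently tolerate.

    Returns (missing, unexpected): parameters the model expects but the
    checkpoint lacks, and checkpoint entries the model does not define.
    """
    model_set, loaded_set = set(model_keys), set(loaded_keys)
    return sorted(model_set - loaded_set), sorted(loaded_set - model_set)
```

With a real model, the same information is available directly from `model.load_state_dict(state, strict=False)`. Logging its missing_keys and unexpected_keys makes a silently partial load visible.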
Source code in opensr_srgan/inference.py
def load_model(config_path=None, ckpt_path=None, device=None):
    """Instantiate an SRGAN model and optionally load pretrained weights.

    This helper is safe to call from tests or scripts. It builds the model from
    a provided configuration file, loads weights from a checkpoint if available,
    and transfers the model to the selected device.

    Parameters
    ----------
    config_path : str or Path, optional
        Path to the YAML configuration file describing the generator/discriminator setup.
    ckpt_path : str or Path, optional
        Path to a Lightning checkpoint or raw PyTorch state dictionary.
    device : str or torch.device, optional
        Device on which to place the model. Defaults to `"cuda"` if available.

    Returns
    -------
    model : SRGAN_model
        The loaded and ready-to-infer SRGAN Lightning module.
    device : str
        The device string used for model placement (e.g., `"cuda"` or `"cpu"`).

    Notes
    -----
    - Automatically switches the model to evaluation mode.
    - Tries to use the Lightning checkpoint API first; falls back to a raw `state_dict`
      for compatibility with manually saved checkpoints.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = SRGAN_model(config_file_path=config_path).eval().to(device)

    if ckpt_path:
        # Try Lightning API first (without 'strict'); fall back to raw state_dict
        try:
            model = (
                SRGAN_model.load_from_checkpoint(ckpt_path, map_location=device)
                .eval()
                .to(device)
            )
        except TypeError:
            state = torch.load(ckpt_path, map_location=device)
            state = state.get("state_dict", state)
            model.load_state_dict(state, strict=False)

    return model, device

run_sen2_inference(sen2_path=None, config_path=None, ckpt_path=None, gpus=None, window_size=(128, 128), factor=4, overlap=12, eliminate_border_px=2, save_preview=False, debug=False)

Run super-resolution inference on a Sentinel-2 SAFE product.

This function prepares the SRGAN model, launches patch-based processing of large Sentinel-2 tiles, and orchestrates the full inference pipeline using the opensr_utils.large_file_processing backend. WARNING: only works for RGB-NIR as of now.

Parameters

sen2_path : str or Path, optional
    Path to the Sentinel-2 .SAFE folder to process.
config_path : str or Path, optional
    Path to the model configuration YAML file.
ckpt_path : str or Path, optional
    Path to the model checkpoint file.
gpus : list[int], optional
    GPU IDs to use for inference (e.g., [0, 1]). If None, automatically selects the first GPU if available.
window_size : tuple(int, int), default=(128, 128)
    Size of each input patch (in low-resolution pixels) processed per pass.
factor : int, default=4
    Super-resolution scaling factor.
overlap : int, default=12
    Overlap (in LR pixels) between inference windows to reduce edge artifacts.
eliminate_border_px : int, default=2
    Number of border pixels to remove from each patch during blending.
save_preview : bool, default=False
    If True, saves a small visual preview of the reconstructed image.
debug : bool, default=False
    Enables verbose debug logs for troubleshooting.
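The window_size, overlap, and factor parameters together determine how a large tile is covered. The following is a minimal sketch of one common placement scheme along a single axis. Consecutive windows advance by window - overlap, and the last window is shifted back to end at the border. This is an assumption for illustration and not necessarily how opensr_utils.large_file_processing places its windows.

```python
from typing import List, Tuple


def window_starts(length: int, window: int, overlap: int) -> List[int]:
    """Start offsets (LR pixels) of overlapping windows covering [0, length).

    Windows advance by ``window - overlap``; the final window is clamped so
    that it ends exactly at the image border instead of running past it.
    """
    if window <= overlap:
        raise ValueError("window must be larger than overlap")
    if length <= window:
        return [0]
    stride = window - overlap
    starts = list(range(0, length - window, stride))
    starts.append(length - window)  # clamp the last window to the edge
    return starts


def sr_shape(lr_height: int, lr_width: int, factor: int) -> Tuple[int, int]:
    """Output size after super-resolution by an integer factor."""
    return lr_height * factor, lr_width * factor
```

With the documented defaults (window 128, overlap 12), windows advance by 116 LR pixels. A 300-pixel axis is therefore covered by windows starting at 0, 116, and 172, and each 64 x 64 LR patch becomes 256 x 256 at factor 4.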

Returns

sr_object : opensr_utils.large_file_processing
    The inference handler object that manages tiling, batching, and output writing.

Notes

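  • Sets CUDA_VISIBLE_DEVICES from gpus when gpus is provided and non-empty. Because os.environ.setdefault is used, an existing value of the variable is left unchanged.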
  • Uses the first GPU by default when running on CUDA-enabled systems.
  • The returned object can be further inspected or extended post-run.
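The GPU handling in run_sen2_inference follows two small rules, which can be checked in isolation. The following sketch restates them as hypothetical helpers. set_visible_devices takes an optional mapping so it can be exercised without touching the real environment. CUDA_VISIBLE_DEVICES only has an effect if it is set before CUDA is initialized in the process.

```python
import os
from typing import List, MutableMapping, Optional, Sequence


def resolve_gpus(gpus: Optional[Sequence[int]], device: str) -> List[int]:
    """Return the GPU list that run_sen2_inference forwards to the backend."""
    if gpus is not None:
        return list(gpus)
    # No explicit choice: first GPU on CUDA systems, none otherwise.
    return [0] if device == "cuda" else []


def set_visible_devices(
    gpus: Optional[Sequence[int]],
    env: Optional[MutableMapping[str, str]] = None,
) -> str:
    """Apply the CUDA_VISIBLE_DEVICES rule and return the value in effect.

    Like the source, this uses setdefault, so a value that is already set
    (e.g., exported by a job scheduler) is kept rather than overwritten.
    """
    env = os.environ if env is None else env
    if gpus:
        env.setdefault("CUDA_VISIBLE_DEVICES", ",".join(map(str, gpus)))
    return env.get("CUDA_VISIBLE_DEVICES", "")
```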
Source code in opensr_srgan/inference.py
def run_sen2_inference(
    sen2_path=None,
    config_path=None,
    ckpt_path=None,
    gpus=None,
    window_size=(128, 128),
    factor=4,
    overlap=12,
    eliminate_border_px=2,
    save_preview=False,
    debug=False,
):
    """Run super-resolution inference on a Sentinel-2 SAFE product.

    This function prepares the SRGAN model, launches patch-based processing of
    large Sentinel-2 tiles, and orchestrates the full inference pipeline using
    the `opensr_utils.large_file_processing` backend.
    WARNING: only works for RGB-NIR as of now.

    Parameters
    ----------
    sen2_path : str or Path, optional
        Path to the Sentinel-2 `.SAFE` folder to process.
    config_path : str or Path, optional
        Path to the model configuration YAML file.
    ckpt_path : str or Path, optional
        Path to the model checkpoint file.
    gpus : list[int], optional
        GPU IDs to use for inference (e.g., `[0, 1]`). If `None`, automatically
        selects the first GPU if available.
    window_size : tuple(int, int), default=(128, 128)
        Size of each input patch (in low-resolution pixels) processed per pass.
    factor : int, default=4
        Super-resolution scaling factor.
    overlap : int, default=12
        Overlap (in LR pixels) between inference windows to reduce edge artifacts.
    eliminate_border_px : int, default=2
        Number of border pixels to remove from each patch during blending.
    save_preview : bool, default=False
        If True, saves a small visual preview of the reconstructed image.
    debug : bool, default=False
        Enables verbose debug logs for troubleshooting.

    Returns
    -------
    sr_object : opensr_utils.large_file_processing
        The inference handler object that manages tiling, batching, and output writing.

    Notes
    -----
    - Automatically sets `CUDA_VISIBLE_DEVICES` when `gpus` is provided.
    - Uses the first GPU by default when running on CUDA-enabled systems.
    - The returned object can be further inspected or extended post-run.
    """
    if gpus is not None and len(gpus) > 0:
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", ",".join(map(str, gpus)))

    model, device = load_model(config_path=config_path, ckpt_path=ckpt_path)

    import opensr_utils

    sr_object = opensr_utils.large_file_processing(
        root=sen2_path,
        model=model,
        window_size=window_size,
        factor=factor,
        overlap=overlap,
        eliminate_border_px=eliminate_border_px,
        device=device,
        gpus=gpus if gpus is not None else ([0] if device == "cuda" else []),
        save_preview=save_preview,
        debug=debug,
    )
    sr_object.start_super_resolution()
    return sr_object

main()

Example standalone entry point for testing Sentinel-2 inference.

Sets up example paths for a sample SAFE product, model config, and checkpoint, then runs the super-resolution pipeline on GPU 0.

This function is primarily intended for smoke testing and manual debugging.
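main() does not validate its inputs before starting, and it mixes path bases. The SAFE product and config are resolved relative to the module file, while the checkpoint path is relative to the current working directory. The following preflight check, using hypothetical helper names, reports every missing input at once instead of stopping at the first failing load.

```python
from pathlib import Path
from typing import Dict, List, Union


def missing_inputs(paths: Dict[str, Union[str, Path]]) -> List[str]:
    """Return the labels of inputs that do not exist on disk."""
    return [label for label, path in paths.items() if not Path(path).exists()]


def preflight(paths: Dict[str, Union[str, Path]]) -> None:
    """Raise a single FileNotFoundError naming every missing input."""
    missing = missing_inputs(paths)
    if missing:
        raise FileNotFoundError("Missing example inputs: " + ", ".join(missing))
```

In a smoke test, this would run before run_sen2_inference, e.g. `preflight({"SAFE product": sen2_path, "config": config_path, "checkpoint": ckpt_path})`.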

Source code in opensr_srgan/inference.py
def main():
    """Example standalone entry point for testing Sentinel-2 inference.

    Sets up example paths for a sample SAFE product, model config, and checkpoint,
    then runs the super-resolution pipeline on GPU 0.

    This function is primarily intended for smoke testing and manual debugging.
    """
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    # ---- Define placeholders ----
    sen2_path = Path(__file__).resolve().parent / "data" / "S2A_MSIL2A_EXAMPLE.SAFE"
    config_path = Path(__file__).resolve().parent / "configs" / "config_20m.yaml"
    ckpt_path = "checkpoints/srgan-20m-6band/last.ckpt"
    gpus = [0]

    run_sen2_inference(
        sen2_path=str(sen2_path),
        config_path=str(config_path),
        ckpt_path=ckpt_path,
        gpus=gpus,
    )

Pretrained model factory

Utility helpers to instantiate pretrained SRGAN models.

This module provides convenience functions to:
  • Load SRGAN_model instances from local YAML configuration files and optional checkpoints.
  • Download and instantiate predefined pretrained models from the Hugging Face Hub (e.g., RGB-NIR and SWIR variants).
  • Transparently handle local or remote checkpoints, temporary file storage, and EMA restoration.
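load_from_config resolves checkpoint_uri through a private _maybe_download helper whose source is not shown here. The following is a hypothetical stdlib sketch of that general pattern. A context manager yields local paths unchanged and copies URLs to a temporary file, which is removed when the with block exits. In a design like this, the file must be read inside the with block, which matches how load_from_config calls torch.load.

```python
import contextlib
import shutil
import tempfile
import urllib.parse
import urllib.request
from pathlib import Path
from typing import Iterator, Union

# Schemes treated as remote in this sketch; "file" is included so the
# download branch can be exercised without network access.
_REMOTE_SCHEMES = {"http", "https", "file"}


@contextlib.contextmanager
def maybe_download(uri: Union[str, Path]) -> Iterator[Path]:
    """Yield a local path for ``uri``, fetching URLs into a temporary file."""
    text = str(uri)
    parsed = urllib.parse.urlparse(text)
    if parsed.scheme not in _REMOTE_SCHEMES:
        yield Path(text)  # local path: nothing to download or clean up
        return
    with tempfile.TemporaryDirectory() as tmp_dir:
        target = Path(tmp_dir) / ("checkpoint" + Path(parsed.path).suffix)
        with urllib.request.urlopen(text) as response, open(target, "wb") as fh:
            shutil.copyfileobj(response, fh)
        yield target  # deleted together with tmp_dir when the block exits
```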

Typical usage

>>> import torch
>>> from opensr_srgan.model.loading import load_inference_model
>>> model = load_inference_model("RGB-NIR")
>>> model.eval()
>>> sr = model(torch.randn(1, 4, 64, 64))

load_from_config(config_path, checkpoint_uri=None, *, map_location=None, mode='train')

Instantiate an SRGAN_model from a configuration and optional checkpoint.

Parameters

config_path : str or Path
    Filesystem path to the YAML configuration describing generator/discriminator architecture and training parameters.
checkpoint_uri : str or Path, optional
    Path or URL to a pretrained model checkpoint. If omitted, returns an untrained model.
map_location : str or torch.device, optional
    Device mapping passed to torch.load when deserializing the checkpoint.
mode : {"train", "eval"}, default="train"
    Desired mode in which to initialize the model.

Returns

LightningModule
    A fully initialized SRGAN Lightning module, optionally loaded with pretrained weights.

Raises

FileNotFoundError
    If the provided configuration file does not exist.
RuntimeError
    If checkpoint deserialization or state restoration fails.

Notes

  • If the checkpoint contains an exponential moving average (EMA) state, it is restored as well.
  • The model is returned in evaluation mode regardless of the mode argument, to prevent accidental training until it is explicitly switched with model.train().
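An exponential moving average (EMA) keeps a shadow copy of each parameter, updated as shadow = decay * shadow + (1 - decay) * current. The result is a smoothed version of the weights that is often preferred at inference time. The following sketch shows that update on plain floats. The WeightEMA class, its decay default, and the "decay"/"shadow" layout of its state dict are illustrative assumptions. This page does not document the structure of the real model.ema object or of the "ema_state" checkpoint entry.

```python
from typing import Dict


class WeightEMA:
    """Exponential moving average of named scalar weights (illustrative only)."""

    def __init__(self, weights: Dict[str, float], decay: float = 0.999) -> None:
        if not 0.0 <= decay < 1.0:
            raise ValueError("decay must be in [0, 1)")
        self.decay = decay
        self.shadow = dict(weights)  # start the average at the current weights

    def update(self, weights: Dict[str, float]) -> None:
        # shadow <- decay * shadow + (1 - decay) * current
        for name, value in weights.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * value

    def state_dict(self) -> Dict[str, object]:
        return {"decay": self.decay, "shadow": dict(self.shadow)}

    def load_state_dict(self, state: Dict[str, object]) -> None:
        # Restoring both fields reproduces the averaged weights exactly.
        self.decay = float(state["decay"])
        self.shadow = dict(state["shadow"])
```

The state_dict/load_state_dict pair mirrors the `model.ema.load_state_dict(checkpoint["ema_state"])` call in the source. Restoring the EMA state is what lets a resumed model continue from the same averaged weights.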
Source code in opensr_srgan/_factory.py
def load_from_config(
    config_path: Union[str, Path],
    checkpoint_uri: Optional[Union[str, Path]] = None,
    *,
    map_location: Optional[Union[str, torch.device]] = None,
    mode: str = "train",
) -> LightningModule:
    """Instantiate an :class:`SRGAN_model` from a configuration and optional checkpoint.

    Parameters
    ----------
    config_path : str or Path
        Filesystem path to the YAML configuration describing generator/discriminator
        architecture and training parameters.
    checkpoint_uri : str or Path, optional
        Path or URL to a pretrained model checkpoint. If omitted, returns an untrained model.
    map_location : str or torch.device, optional
        Device mapping passed to :func:`torch.load` when deserializing the checkpoint.
    mode : {"train", "eval"}, default="train"
        Desired mode in which to initialize the model.

    Returns
    -------
    LightningModule
        A fully initialized SRGAN Lightning module, optionally loaded with pretrained weights.

    Raises
    ------
    FileNotFoundError
        If the provided configuration file does not exist.
    RuntimeError
        If checkpoint deserialization or state restoration fails.

    Notes
    -----
    - If the checkpoint contains an exponential moving average (EMA) state, it is restored as well.
    - The model is returned in evaluation mode regardless of `mode` to prevent accidental training
      until explicitly switched with ``model.train()``.
    """

    config_path = Path(config_path)
    if not config_path.is_file():
        raise FileNotFoundError(f"Config file '{config_path}' could not be located.")

    model = SRGAN_model(config=config_path, mode=mode)

    if checkpoint_uri is not None:
        with _maybe_download(checkpoint_uri) as resolved_path:
            checkpoint = torch.load(str(resolved_path), map_location=map_location)
        state_dict = checkpoint.get("state_dict", checkpoint)
        model.load_state_dict(state_dict, strict=False)

        if model.ema is not None and "ema_state" in checkpoint:
            model.ema.load_state_dict(checkpoint["ema_state"])

    model.eval()
    return model

load_inference_model(preset, *, cache_dir=None, map_location=None)

Load a pretrained SRGAN model from the Hugging Face Hub.

Downloads both the configuration and checkpoint associated with the requested preset, instantiates the model, restores weights, and returns it ready for inference.

Parameters

preset : {"RGB-NIR", "SWIR"}
    Name of the pretrained model variant to load. Matching ignores surrounding whitespace and case, and accepts underscores in place of hyphens (e.g., "rgb_nir").
cache_dir : str or Path, optional
    Directory to cache the downloaded files. Uses the default Hugging Face cache if omitted.
map_location : str or torch.device, optional
    Device mapping for checkpoint deserialization.

Returns

LightningModule
    Pretrained SRGAN model ready for inference.

Raises

ValueError
    If an unknown preset name is provided.
ImportError
    If huggingface_hub is not installed.
FileNotFoundError
    If download or local resolution fails.

Examples

>>> model = load_inference_model("RGB-NIR")
>>> model.eval()
>>> x = torch.randn(1, 4, 64, 64)
>>> y = model(x)
>>> print(y.shape)
torch.Size([1, 4, 256, 256])
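load_inference_model normalizes the preset name before looking it up in _PRESETS. The following sketch reproduces that normalization and the error message using a stand-in table. PresetMeta mirrors the three attributes the source reads (repo_id, config_filename, checkpoint_filename). The repository IDs and filenames are placeholders, since the real values are not shown on this page.

```python
from typing import Dict, NamedTuple


class PresetMeta(NamedTuple):
    repo_id: str
    config_filename: str
    checkpoint_filename: str


# Placeholder entries: the real Hugging Face repo IDs and filenames are not listed here.
PRESETS: Dict[str, PresetMeta] = {
    "RGB-NIR": PresetMeta("<org>/<rgb-nir-repo>", "<config>.yaml", "<checkpoint>.ckpt"),
    "SWIR": PresetMeta("<org>/<swir-repo>", "<config>.yaml", "<checkpoint>.ckpt"),
}


def resolve_preset(preset: str) -> PresetMeta:
    """Normalize a preset name the same way load_inference_model does."""
    key = preset.strip().replace("_", "-").upper()
    try:
        return PRESETS[key]
    except KeyError as err:
        valid = ", ".join(sorted(PRESETS))
        raise ValueError(
            f"Unknown preset '{preset}'. Available options: {valid}."
        ) from err
```

Under this rule, " rgb_nir " and "RGB-NIR" resolve to the same entry. "RGB NIR" (with a space) does not, and raises ValueError listing the valid options.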

Source code in opensr_srgan/_factory.py
def load_inference_model(
    preset: str,
    *,
    cache_dir: Optional[Union[str, Path]] = None,
    map_location: Optional[Union[str, torch.device]] = None,
) -> LightningModule:
    """Load a pretrained SRGAN model from the Hugging Face Hub.

    Downloads both the configuration and checkpoint associated with the requested
    preset, instantiates the model, restores weights, and returns it ready for inference.

    Parameters
    ----------
    preset : {"RGB-NIR", "SWIR"}
        Name of the pretrained model variant to load.
    cache_dir : str or Path, optional
        Directory to cache the downloaded files. Uses the default HF cache if omitted.
    map_location : str or torch.device, optional
        Device mapping for checkpoint deserialization.

    Returns
    -------
    LightningModule
        Pretrained SRGAN model ready for inference.

    Raises
    ------
    ValueError
        If an unknown preset name is provided.
    ImportError
        If ``huggingface_hub`` is not installed.
    FileNotFoundError
        If download or local resolution fails.

    Examples
    --------
    >>> model = load_inference_model("RGB-NIR")
    >>> model.eval()
    >>> x = torch.randn(1, 4, 64, 64)
    >>> y = model(x)
    >>> print(y.shape)
    torch.Size([1, 4, 256, 256])
    """

    key = preset.strip().replace("_", "-").upper()
    try:
        preset_meta = _PRESETS[key]
    except KeyError as err:
        valid = ", ".join(sorted(_PRESETS))
        raise ValueError(
            f"Unknown preset '{preset}'. Available options: {valid}."
        ) from err

    try:  # pragma: no cover - import guard only used at runtime
        from huggingface_hub import hf_hub_download
    except ImportError as exc:  # pragma: no cover - dependency guard
        raise ImportError(
            "huggingface_hub is required for load_inference_model. "
            "Install the project extras or run 'pip install huggingface-hub'."
        ) from exc

    config_path = hf_hub_download(
        repo_id=preset_meta.repo_id,
        filename=preset_meta.config_filename,
        cache_dir=None if cache_dir is None else str(cache_dir),
    )
    checkpoint_path = hf_hub_download(
        repo_id=preset_meta.repo_id,
        filename=preset_meta.checkpoint_filename,
        cache_dir=None if cache_dir is None else str(cache_dir),
    )

    return load_from_config(
        config_path,
        checkpoint_path,
        map_location=map_location,
        mode="eval",
    )