Model Components

Classes and building blocks used to assemble SRGAN training and inference graphs.

Lightning module

SRGAN_model

Bases: LightningModule

SRGAN_model

A flexible, PyTorch Lightning–based SRGAN implementation for single-image super-resolution in remote sensing and general imaging. The model supports multiple generator backbones (SRResNet, RCAB, RRDB, LKA, and a flexible registry-driven variant) and discriminator types (standard, PatchGAN), optional generator-only pretraining, adversarial loss ramp-up, and an Exponential Moving Average (EMA) of generator weights for more stable evaluation.

Key features

  • Backbone flexibility: Select generator/discriminator architectures via config.
  • Training modes: Generator pretraining, adversarial training, and LR warm-up.
  • PL compatibility: Automatic optimization for PL < 2.0; manual optimization for PL ≥ 2.0.
  • EMA support: Optional EMA tracking with delayed activation and device placement.
  • Metrics & logging: Content/perceptual metrics, LR logging, and optional W&B image logs.
  • Inference helpers: Normalization/denormalization for 0–10000 reflectance and histogram matching.

Args

config : str | pathlib.Path | dict | omegaconf.DictConfig, optional
    Path to a YAML file, an in-memory dict, or an OmegaConf config with sections defining Generator/Discriminator, Training, Optimizers, Schedulers, and Logging. Defaults to "config.yaml".
mode : {"train", "eval"}, optional
    Build both G and D in "train" mode; only G in "eval". Defaults to "train".

Configuration overview (minimal)

  • Model: in_bands (int)
    • Generator: model_type ("SRResNet", "stochastic_gan", "esrgan"), optional block_type for SRResNet variants ("standard", "res", "rcab", "rrdb", "lka"), n_channels, n_blocks, large_kernel_size, small_kernel_size, scaling_factor, plus ESRGAN-specific knobs (growth_channels, res_scale, out_channels).
  • Discriminator: model_type ("standard", "patchgan", "esrgan"), n_blocks (optional), ESRGAN extras (base_channels, linear_size).
  • Training:
    • pretrain_g_only (bool), g_pretrain_steps (int)
    • adv_loss_ramp_steps (int), label_smoothing (bool)
    • Losses.adv_loss_beta (float), Losses.adv_loss_schedule ("linear"|"cosine")
    • EMA.enabled (bool), EMA.decay (float), EMA.use_num_updates (bool), EMA.update_after_step (int), EMA.device (str|None)
  • Optimizers: optim_g_lr (float), optim_d_lr (float)
  • Schedulers: factor_g, factor_d, patience_g, patience_d, metric, optional g_warmup_steps, g_warmup_type ("linear"|"cosine")
  • Logging: wandb.enabled (bool), num_val_images (int)
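
The overview above corresponds to a nested configuration object. Below is a minimal, illustrative sketch built in memory with OmegaConf: the field names follow the overview, but all values are placeholders, and the exact section nesting (for example whether Generator sits under Model or at the top level) and any additional required fields (e.g., for the content loss) should be verified against the packaged config.yaml.

from omegaconf import OmegaConf

# Illustrative only: placeholder values; verify section layout (e.g., Generator placement)
# and required fields against the shipped config.yaml.
config = OmegaConf.create({
    "Model": {"in_bands": 4},
    "Generator": {"model_type": "SRResNet", "block_type": "rrdb",
                  "n_channels": 64, "n_blocks": 16, "scaling_factor": 4},
    "Discriminator": {"model_type": "patchgan", "n_blocks": 3},
    "Training": {
        "pretrain_g_only": True, "g_pretrain_steps": 10000,
        "adv_loss_ramp_steps": 20000, "label_smoothing": True,
        "Losses": {"adv_loss_beta": 1e-3, "adv_loss_schedule": "cosine"},
        "EMA": {"enabled": True, "decay": 0.999, "use_num_updates": True,
                "update_after_step": 0, "device": None},
    },
    "Optimizers": {"optim_g_lr": 1e-4, "optim_d_lr": 5e-5},
    "Schedulers": {"factor_g": 0.5, "factor_d": 0.5,
                   "patience_g": 5, "patience_d": 5, "metric": "val_g_loss"},
    "Logging": {"wandb": {"enabled": False}, "num_val_images": 4},
})

model = SRGAN_model(config=config, mode="train")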

Behavior & versioning

  • PL ≥ 2.0: Manual optimization (automatic_optimization = False). The bound training_step_PL2 performs explicit zero_grad/step calls and handles EMA updates.
  • PL < 2.0: Automatic optimization. The legacy training_step_PL1 is used, and optimizer_step coordinates stepping and EMA after generator updates.
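
For PL ≥ 2.0, the manual-optimization path follows the standard Lightning GAN pattern. The sketch below is a simplified illustration of such a step, not the actual training_step_PL2, which additionally handles generator pretraining, label smoothing, the Wasserstein/R1 options, and detailed logging; the direct call to content_loss_criterion is an assumed convenience interface.

import torch

def manual_gan_step(self, batch):
    """Simplified manual-optimization step (illustrative; see training_step_PL2 for the real logic)."""
    opt_d, opt_g = self.optimizers()       # configure_optimizers returns [D, G]
    lr_imgs, hr_imgs = batch
    sr_imgs = self.generator(lr_imgs)

    # Discriminator update: real HR vs. detached SR
    d_real = self.discriminator(hr_imgs)
    d_fake = self.discriminator(sr_imgs.detach())
    d_loss = self.adversarial_loss_criterion(d_real, torch.ones_like(d_real)) \
           + self.adversarial_loss_criterion(d_fake, torch.zeros_like(d_fake))
    opt_d.zero_grad()
    self.manual_backward(d_loss)
    opt_d.step()

    # Generator update: content loss plus ramped adversarial term
    g_fake = self.discriminator(sr_imgs)
    adv_loss = self.adversarial_loss_criterion(g_fake, torch.ones_like(g_fake))
    content_loss = self.content_loss_criterion(sr_imgs, hr_imgs)  # assumed callable signature
    g_loss = content_loss + self._adv_loss_weight() * adv_loss
    opt_g.zero_grad()
    self.manual_backward(g_loss)
    opt_g.step()

    # EMA tracking after the generator step, once the activation delay has passed
    if self.ema is not None and self.global_step >= self._ema_update_after_step:
        self.ema.update(self.generator)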

Created attributes (non-exhaustive)

generator : torch.nn.Module
    Super-resolution generator.
discriminator : torch.nn.Module | None
    Only present in "train" mode.
ema : ExponentialMovingAverage | None
    EMA tracker for the generator (if enabled).
content_loss_criterion : torch.nn.Module
    Perceptual/pixel content loss wrapper used in training/validation.
adversarial_loss_criterion : torch.nn.Module
    BCEWithLogits loss for adversarial training (D/G).

Input/Output conventions

  • Forward: lr_imgs of shape (B, C, H, W) → SR output with spatial scale set by Generator.scaling_factor.
  • Predict: Applies optional normalization (0–10000 → 0–1), EMA averaging, histogram matching, and denormalization back to the input range; returns CPU tensors.

Example

model = SRGAN_model(config="config.yaml", mode="train")

# Trainer handles fit; inference via .predict_step() or forward():
model.eval()
with torch.no_grad():
    sr = model(lr_imgs)
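
For end-to-end inference on reflectance data, predict_step wraps the forward pass with the configured normalization, optional EMA weights, histogram matching, and denormalization back to the input range. A minimal sketch, assuming a Sentinel-2-style input in the 0–10000 range (the band count, tensor shape, and random values are purely illustrative):

import torch

model = SRGAN_model(config="config.yaml", mode="eval")
model.eval()

# Illustrative input: one 4-band LR patch with 0–10000 reflectance values
lr_imgs = torch.randint(0, 10000, (1, 4, 128, 128)).float()

sr_imgs = model.predict_step(lr_imgs)  # normalize -> SR -> histogram match -> denormalize
print(sr_imgs.shape)                   # upscaled spatial dims, returned on CPU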

Notes

  • Discriminator is frozen during generator pretraining (pretrain_g_only=True and global_step < g_pretrain_steps).
  • The adversarial loss contribution is ramped from 0 to adv_loss_beta over adv_loss_ramp_steps with linear or cosine schedule.
  • Learning rate warm-up for the generator is supported via a per-step LambdaLR.
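
The ramp-up described in the notes corresponds to the weight computed by _compute_adv_loss_weight in the source below; as a standalone sketch of the same schedule:

import math

def adv_loss_weight(step, g_pretrain_steps, ramp_steps, beta, schedule="cosine"):
    """Adversarial loss weight at a given global step (mirrors _compute_adv_loss_weight)."""
    if step < g_pretrain_steps:
        return 0.0                              # no adversarial pressure during G pretraining
    if ramp_steps <= 0 or step >= g_pretrain_steps + ramp_steps:
        return beta                             # fully ramped
    progress = (step - g_pretrain_steps) / ramp_steps
    if schedule == "linear":
        return beta * progress
    return beta * 0.5 * (1.0 - math.cos(math.pi * progress))  # cosine ramp
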
Source code in opensr_srgan/model/SRGAN.py
class SRGAN_model(pl.LightningModule):
    """
    SRGAN_model
    ===========

    A flexible, PyTorch Lightning–based SRGAN implementation for single-image super-resolution
    in remote sensing and general imaging. The model supports multiple generator backbones
    (SRResNet, RCAB, RRDB, LKA, and a flexible registry-driven variant) and discriminator
    types (standard, PatchGAN), optional generator-only pretraining, adversarial loss ramp-up,
    and an Exponential Moving Average (EMA) of generator weights for more stable evaluation.

    Key features
    ------------
    - **Backbone flexibility:** Select generator/discriminator architectures via config.
    - **Training modes:** Generator pretraining, adversarial training, and LR warm-up.
    - **PL compatibility:** Automatic optimization for PL < 2.0; manual optimization for PL ≥ 2.0.
    - **EMA support:** Optional EMA tracking with delayed activation and device placement.
    - **Metrics & logging:** Content/perceptual metrics, LR logging, and optional W&B image logs.
    - **Inference helpers:** Normalization/denormalization for 0–10000 reflectance and histogram matching.

    Args
    ----
    config : str | pathlib.Path | dict | omegaconf.DictConfig, optional
        Path to a YAML file, an in-memory dict, or an OmegaConf config with sections
        defining Generator/Discriminator, Training, Optimizers, Schedulers, and Logging.
        Defaults to `"config.yaml"`.
    mode : {"train", "eval"}, optional
        Build both G and D in `"train"` mode; only G in `"eval"`. Defaults to `"train"`.

    Configuration overview (minimal)
    --------------------------------
    - **Model**: `in_bands` (int)
        - **Generator**: `model_type` (`"SRResNet"`, `"stochastic_gan"`, `"esrgan"`),
        optional `block_type` for SRResNet variants (`"standard"`, `"res"`, `"rcab"`,
        `"rrdb"`, `"lka"`), `n_channels`, `n_blocks`, `large_kernel_size`,
        `small_kernel_size`, `scaling_factor`, plus ESRGAN-specific knobs
        (`growth_channels`, `res_scale`, `out_channels`).
    - **Discriminator**: `model_type` (`"standard"`, `"patchgan"`, `"esrgan"`), `n_blocks`
        (optional), ESRGAN extras (`base_channels`, `linear_size`).
    - **Training**:
    - `pretrain_g_only` (bool), `g_pretrain_steps` (int)
    - `adv_loss_ramp_steps` (int), `label_smoothing` (bool)
    - `Losses.adv_loss_beta` (float), `Losses.adv_loss_schedule` (`"linear"`|`"cosine"`)
    - `EMA.enabled` (bool), `EMA.decay` (float), `EMA.use_num_updates` (bool),
        `EMA.update_after_step` (int), `EMA.device` (str|None)
    - **Optimizers**: `optim_g_lr` (float), `optim_d_lr` (float)
    - **Schedulers**: `factor_g`, `factor_d`, `patience_g`, `patience_d`, `metric`,
    optional `g_warmup_steps`, `g_warmup_type` (`"linear"`|`"cosine"`)
    - **Logging**: `wandb.enabled` (bool), `num_val_images` (int)

    Behavior & versioning
    ---------------------
    - **PL ≥ 2.0**: Manual optimization (`automatic_optimization = False`). The bound
    `training_step_PL2` performs explicit `zero_grad/step` calls and handles EMA updates.
    - **PL < 2.0**: Automatic optimization. The legacy `training_step_PL1` is used, and
    `optimizer_step` coordinates stepping and EMA after generator updates.

    Created attributes (non-exhaustive)
    -----------------------------------
    generator : torch.nn.Module
        Super-resolution generator.
    discriminator : torch.nn.Module | None
        Only present in `"train"` mode.
    ema : ExponentialMovingAverage | None
        EMA tracker for the generator (if enabled).
    content_loss_criterion : torch.nn.Module
        Perceptual/pixel content loss wrapper used in training/validation.
    adversarial_loss_criterion : torch.nn.Module
        BCEWithLogits loss for adversarial training (D/G).

    Input/Output conventions
    ------------------------
    - **Forward**: `lr_imgs` of shape `(B, C, H, W)` → SR output with spatial scale set by
    `Generator.scaling_factor`.
    - **Predict**: Applies optional normalization (0–10000 → 0–1), EMA averaging, histogram
    matching, and denormalization back to the input range; returns CPU tensors.

    Example
    -------
    >>> model = SRGAN_model(config="config.yaml", mode="train")
    >>> # Trainer handles fit; inference via .predict_step() or forward():
    >>> model.eval()
    >>> with torch.no_grad():
    ...     sr = model(lr_imgs)

    Notes
    -----
    - Discriminator is frozen during generator pretraining (`pretrain_g_only=True` and
    `global_step < g_pretrain_steps`).
    - The adversarial loss contribution is ramped from 0 to `adv_loss_beta` over
    `adv_loss_ramp_steps` with linear or cosine schedule.
    - Learning rate warm-up for the generator is supported via a per-step LambdaLR.
    """

    def __init__(self, config="config.yaml", mode="train"):
        super(SRGAN_model, self).__init__()

        # ======================================================================
        # SECTION: Load Configuration
        # Purpose: Load and parse model/training hyperparameters from YAML file.
        # ======================================================================
        if isinstance(config, str) or isinstance(config, Path):
            config = OmegaConf.load(config)
        elif isinstance(config, dict):
            config = OmegaConf.create(config)
        elif OmegaConf.is_config(config):
            pass
        else:
            raise TypeError(
                "Config must be a filepath (str or Path), dict, or OmegaConf object."
            )
        assert mode in {
            "train",
            "eval",
        }, "Mode must be 'train' or 'eval'"  # validate mode

        # ======================================================================
        # SECTION: Set Variables
        # Purpose: Set config and mode variables model-wide, including PL version.
        # ======================================================================
        self.config = config
        self.mode = mode
        self.pl_version = tuple(int(x) for x in pl.__version__.split("."))
        self.normalizer = Normalizer(self.config)

        # ======================================================================
        # SECTION: Get Training settings
        # Purpose: Define model variables to enable training strategies.
        # ======================================================================
        self.pretrain_g_only = bool(
            getattr(self.config.Training, "pretrain_g_only", False)
        )  # pretrain generator only (default False)
        self.g_pretrain_steps = int(
            getattr(self.config.Training, "g_pretrain_steps", 0)
        )  # number of steps for G pretraining
        self.adv_loss_ramp_steps = int(
            getattr(self.config.Training, "adv_loss_ramp_steps", 20000)
        )  # linear ramp-up steps for adversarial loss
        self.adv_target = (
            0.9 if getattr(self.config.Training, "label_smoothing", False) else 1.0
        )  # use 0.9 if label smoothing enabled, else 1.0
        self.adv_loss_type = str(
            getattr(self.config.Training.Losses, "adv_loss_type", "bce")
        ).lower()
        if self.adv_loss_type not in {"bce", "wasserstein"}:
            raise ValueError(
                "Training.Losses.adv_loss_type must be either 'bce' or 'wasserstein'"
            )
        self.r1_gamma = float(
            getattr(self.config.Training.Losses, "r1_gamma", 0.0)
        )  # R1 gradient penalty strength (0 disables)

        # ======================================================================
        # SECTION: Set up Training Strategy
        # Purpose: Depending on PL version, set up optimizers, schedulers, etc.
        # ======================================================================
        self.setup_lightning()  # dynamically builds and attaches generator + discriminator

        # ======================================================================
        # SECTION: Initialize Generator
        # Purpose: Build generator network depending on selected architecture.
        # ======================================================================
        self.get_models(
            mode=self.mode
        )  # dynamically builds and attaches generator + discriminator

        # ======================================================================
        # SECTION: Initialize EMA
        # Purpose: Optional exponential moving average (EMA) tracking for generator weights
        # ======================================================================
        self.initialize_ema()

        # ======================================================================
        # SECTION: Define Loss Functions
        # Purpose: Configure generator content loss and discriminator adversarial loss.
        # ======================================================================
        if self.mode == "train":
            from opensr_srgan.model.loss import GeneratorContentLoss

            self.content_loss_criterion = GeneratorContentLoss(
                self.config
            )  # perceptual loss (VGG + pixel)
            if self.adv_loss_type == "bce":  # check for WS GAN or BCE
                self.adversarial_loss_criterion = torch.nn.BCEWithLogitsLoss()
            else:
                self.adversarial_loss_criterion = None

    def get_models(self, mode):
        """Initialize and attach the Generator and (optionally) Discriminator models.

        This method builds the generator and discriminator architectures based on
        the configuration provided in `self.config`. It supports multiple generator
        backbones (e.g., SRResNet, RCAB, RRDB, LKA) and discriminator types
        (standard, PatchGAN). The discriminator is only initialized when the mode
        is set to `"train"`.

        Args:
            mode (str): Operational mode of the model. Must be one of:
                - `"train"`: Initializes both generator and discriminator.
                - Any other value: Initializes only the generator.

        Raises:
            ValueError: If an unknown generator or discriminator type is specified
                in the configuration.

        Attributes:
            generator (nn.Module): The initialized generator network instance.
            discriminator (nn.Module, optional): The initialized discriminator
                network instance (only present if `mode == "train"`).
        """

        # ======================================================================
        # SECTION: Initialize Generator
        # Purpose: Build generator network depending on selected architecture.
        # ======================================================================
        self.generator = build_generator(self.config)

        if mode == "train":  # only get discriminator in training mode
            # ======================================================================
            # SECTION: Initialize Discriminator
            # Purpose: Build discriminator network for adversarial training.
            # ======================================================================
            raw_discriminator_type = getattr(
                self.config.Discriminator, "model_type", "standard"
            )
            discriminator_type = str(raw_discriminator_type).strip().lower()
            n_blocks = getattr(self.config.Discriminator, "n_blocks", None)

            if discriminator_type == "standard":
                from opensr_srgan.model.discriminators.srgan_discriminator import (
                    Discriminator,
                )

                discriminator_kwargs = {
                    "in_channels": self.config.Model.in_bands,
                }
                if n_blocks is not None:
                    discriminator_kwargs["n_blocks"] = n_blocks

                # pass spectral norm option
                use_spectral_norm = getattr(
                    self.config.Discriminator, "use_spectral_norm", True
                )
                discriminator_kwargs["use_spectral_norm"] = bool(use_spectral_norm)

                self.discriminator = Discriminator(**discriminator_kwargs)
            elif discriminator_type == "patchgan":
                from opensr_srgan.model.discriminators.patchgan import (
                    PatchGANDiscriminator,
                )

                patchgan_layers = n_blocks if n_blocks is not None else 3
                self.discriminator = PatchGANDiscriminator(
                    input_nc=self.config.Model.in_bands,
                    n_layers=patchgan_layers,
                )
            elif discriminator_type == "esrgan":
                from opensr_srgan.model.discriminators.esrgan import (
                    ESRGANDiscriminator,
                )

                ignored_options = []
                if n_blocks is not None:
                    ignored_options.append("n_blocks")
                if ignored_options:
                    ignored_joined = ", ".join(sorted(ignored_options))
                    print(
                        f"[Discriminator:esrgan] Ignoring unsupported configuration options: {ignored_joined}."
                    )

                base_channels = getattr(self.config.Discriminator, "base_channels", 64)
                linear_size = getattr(self.config.Discriminator, "linear_size", 1024)
                self.discriminator = ESRGANDiscriminator(
                    in_channels=self.config.Model.in_bands,
                    base_channels=int(base_channels),
                    linear_size=int(linear_size),
                )
            else:
                raise ValueError(
                    f"Unknown discriminator model type: {raw_discriminator_type}"
                )

    def setup_lightning(self):
        """Configure PyTorch Lightning behavior based on the detected version.

        This method ensures compatibility between different versions of
        PyTorch Lightning (PL) by setting appropriate optimization modes
        and binding the correct training step implementation.

        - For PL ≥ 2.0: Enables **manual optimization**, required for GAN training.
        - For PL < 2.0: Uses **automatic optimization** and the legacy training step.

        The selected training step function (`training_step_PL1` or `training_step_PL2`)
        is dynamically attached to the model as `_training_step_implementation`.

        Raises:
            AssertionError: If `automatic_optimization` is incorrectly set for PL < 2.0.
            RuntimeError: If the detected PyTorch Lightning version is unsupported.

        Attributes:
            automatic_optimization (bool): Indicates whether Lightning manages
                optimizer steps automatically.
            _training_step_implementation (Callable): Bound training step function
                corresponding to the active PL version.
        """
        # Check for PL version - Define PL Hooks accordingly
        if self.pl_version >= (2, 0, 0):
            self.automatic_optimization = False  # manual optimization for PL 2.x
            # Set up Training Step
            from opensr_srgan.model.training_step_PL import training_step_PL2

            self._training_step_implementation = MethodType(training_step_PL2, self)
        elif self.pl_version < (2, 0, 0):
            assert (
                self.automatic_optimization is True
            ), "For PL <2.0, automatic_optimization must be True."
            # Set up Training Step
            from opensr_srgan.model.training_step_PL import training_step_PL1

            self._training_step_implementation = MethodType(training_step_PL1, self)
        else:
            raise RuntimeError(
                f"Unsupported PyTorch Lightning version: {pl.__version__}"
            )

    def initialize_ema(self):
        """Initialize the Exponential Moving Average (EMA) mechanism for the generator.

        This method sets up an EMA shadow copy of the generator parameters to
        stabilize training and improve the quality of generated outputs. EMA is
        enabled only if specified in the training configuration.

        The EMA model tracks the moving average of generator weights with a
        configurable decay factor and update schedule.

        Configuration fields under `config.Training.EMA`:
            - `enabled` (bool): Whether to enable EMA tracking.
            - `decay` (float): Exponential decay factor for weight averaging (default: 0.999).
            - `device` (str | None): Device to store the EMA weights on.
            - `use_num_updates` (bool): Whether to use step-based update counting.
            - `update_after_step` (int): Number of steps to wait before starting updates.

        Attributes:
            ema (ExponentialMovingAverage | None): EMA object tracking generator parameters.
            _ema_update_after_step (int): Step count threshold before EMA updates begin.
            _ema_applied (bool): Indicates whether EMA weights are currently applied to the generator.
        """
        ema_cfg = getattr(self.config.Training, "EMA", None)
        self.ema: ExponentialMovingAverage | None = None
        self._ema_update_after_step = 0
        self._ema_applied = False
        if ema_cfg is not None and getattr(ema_cfg, "enabled", False):
            ema_decay = float(getattr(ema_cfg, "decay", 0.999))
            ema_device = getattr(ema_cfg, "device", None)
            use_num_updates = bool(getattr(ema_cfg, "use_num_updates", True))
            self.ema = ExponentialMovingAverage(
                self.generator,
                decay=ema_decay,
                use_num_updates=use_num_updates,
            )
            self._ema_update_after_step = int(getattr(ema_cfg, "update_after_step", 0))

    def forward(self, lr_imgs):
        """Forward pass through the generator network.

        Takes a batch of low-resolution (LR) input images and produces
        their corresponding super-resolved (SR) outputs using the generator model.

        Args:
            lr_imgs (torch.Tensor): Batch of input low-resolution images
                with shape `(B, C, H, W)` where:
                - `B`: batch size
                - `C`: number of channels
                - `H`, `W`: spatial dimensions.

        Returns:
            torch.Tensor: Super-resolved output images with increased spatial resolution,
            typically scaled by the model's configured upsampling factor.
        """
        sr_imgs = self.generator(lr_imgs)  # pass LR input through generator network
        return sr_imgs  # return super-resolved output

    @torch.no_grad()
    def predict_step(self, lr_imgs):
        """Run a single super-resolution inference step.

        Performs forward inference using the generator (optionally under EMA weights)
        to produce super-resolved (SR) outputs from low-resolution (LR) inputs.
        The method normalizes input values using the configured strategy (e.g., raw
        Sentinel-2 reflectance via ``normalise_10k``), applies histogram matching, and
        denormalizes the outputs back to their original scale.

        Args:
            lr_imgs (torch.Tensor): Batch of input low-resolution images
                with shape `(B, C, H, W)`. Pixel value ranges may vary depending
                on preprocessing (e.g., 0–10000 for Sentinel-2 reflectance).

        Returns:
            torch.Tensor: Super-resolved output images with matched histograms
            and restored value range, detached from the computation graph and
            placed on CPU memory.

        Raises:
            AssertionError: If the generator is not in evaluation mode (`.eval()`).
        """
        assert (
            self.generator.training is False
        ), "Generator must be in eval mode for prediction."  # ensure eval mode
        lr_imgs = lr_imgs.to(self.device)  # move to device (GPU or CPU)

        # --- Normalize inputs according to configuration ---
        normalized_lr = self.normalizer.normalize(lr_imgs)

        # --- Perform super-resolution (optionally using EMA weights) ---
        context = (
            self.ema.average_parameters(self.generator)
            if self.ema is not None
            else nullcontext()
        )
        with context:
            sr_imgs = self.generator(normalized_lr)  # forward pass (SR prediction)

        # --- Histogram match SR to LR ---
        sr_imgs = histogram_match(normalized_lr, sr_imgs)  # match distributions

        # --- Denormalize output back to original range ---
        sr_imgs = self.normalizer.denormalize(sr_imgs)

        # --- Move to CPU and return ---
        sr_imgs = sr_imgs.cpu().detach()  # detach from graph for inference output
        return sr_imgs

    def training_step(
        self, batch, batch_idx, optimizer_idx: Optional[int] = None, *args
    ):
        """Dispatch the correct training step implementation based on PyTorch Lightning version.

        This method acts as a compatibility layer between different PyTorch Lightning
        versions that handle multi-optimizer GAN training differently.

        - For PL ≥ 2.0: Manual optimization is used, and the optimizer index is not passed.
        - For PL < 2.0: Automatic optimization is used, and the optimizer index is passed
        to handle generator/discriminator updates separately.

        Args:
            batch (Any): A batch of training data (input tensors and targets as defined by the DataModule).
            batch_idx (int): Index of the current batch within the epoch.
            optimizer_idx (int | None, optional): Index of the active optimizer (0 for generator,
                1 for discriminator) when using PL < 2.0.
            *args: Additional arguments that may be passed by older Lightning versions.

        Returns:
            Any: The output of the active training step implementation, loss value.
        """
        # Depending on PL version, and depending on the manual optimization
        if self.pl_version >= (2, 0, 0):
            # In PL2.x, optimizer_idx is not passed, manual optimization is performed
            return self._training_step_implementation(batch, batch_idx)  # no optim_idx
        else:
            # In PL 1.x, optimizer_idx arrives twice and is passed on
            return self._training_step_implementation(
                batch, batch_idx, optimizer_idx
            )  # pass optim_idx

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx=None,
        optimizer_closure=None,
        **kwargs,  # absorbs on_tpu/using_lbfgs/etc across PL versions
    ):
        """Custom optimizer step handling for PL 1.x automatic optimization.

        This method ensures correct behavior across different PyTorch Lightning
        versions and training modes. It is invoked automatically during training
        in PL < 2.0 when `automatic_optimization=True`. For PL ≥ 2.0, where manual
        optimization is used, this function is effectively bypassed.

        - In **PL ≥ 2.0 (manual optimization)**: The optimizer step is explicitly
        called within `training_step_PL2()`, including EMA updates.
        - In **PL < 2.0 (automatic optimization)**: This function manages optimizer
        stepping, gradient zeroing, and optional EMA updates after generator steps.

        Args:
            epoch (int): Current training epoch.
            batch_idx (int): Index of the current batch.
            optimizer (torch.optim.Optimizer): The active optimizer instance.
            optimizer_idx (int, optional): Index of the optimizer being stepped
                (e.g., 0 for discriminator, 1 for generator).
            optimizer_closure (Callable, optional): Closure for re-evaluating the
                model and loss before optimizer step (used with some optimizers).
            **kwargs: Additional arguments passed by PL depending on backend
                (e.g., TPU flags, LBFGS options).

        Notes:
            - EMA updates are performed only after generator steps (optimizer_idx == 1).
            - The update starts after `self._ema_update_after_step` global steps.

        """
        # If we're in manual optimization (PL >=2 path), do nothing special.
        if not self.automatic_optimization:
            # In manual mode we call opt.step()/zero_grad() in training_step_PL2.
            # In manual mode, we update EMA weights manually in training step too.
            return super().optimizer_step(
                epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **kwargs
            )

        # ---- PL 1.x auto-optimization path ----
        if optimizer_closure is not None:
            optimizer.step(closure=optimizer_closure)
        else:
            optimizer.step()
        optimizer.zero_grad()

        # EMA after the generator step (assumes G is optimizer_idx == 1)
        if (
            self.ema is not None
            and optimizer_idx == 1
            and self.global_step >= self._ema_update_after_step
        ):
            self.ema.update(self.generator)

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Run the validation loop for a single batch.

        This method performs super-resolution inference on validation data,
        computes image quality metrics (e.g., PSNR, SSIM), logs them, and
        optionally visualizes SR–HR–LR triplets. It also evaluates the
        discriminator’s adversarial response if applicable.

        Workflow:
            1. Forward pass (LR → SR) through the generator.
            2. Compute content-based validation metrics.
            3. Optionally log visual examples to the logger (e.g., Weights & Biases).
            4. Compute and log discriminator metrics, unless in pretraining mode.

        Args:
            batch (Tuple[torch.Tensor, torch.Tensor]): A tuple `(lr_imgs, hr_imgs)` of
                low-resolution and high-resolution tensors with shape `(B, C, H, W)`.
            batch_idx (int): Index of the current validation batch.

        Returns:
            None: Metrics and images are logged via Lightning’s logger interface.

        Raises:
            AssertionError: If an unexpected number of bands or invalid visualization
                configuration is encountered.

        Notes:
            - Validation is executed without gradient tracking.
            - Only the first `config.Logging.num_val_images` batches are visualized.
            - If EMA is enabled, the generator predictions reflect the current EMA state.
        """
        # ======================================================================
        # SECTION: Forward pass — Generate SR prediction from LR input
        # Purpose: Run model inference on validation batch without gradient tracking.
        # ======================================================================
        """ 1. Extract and Predict """
        lr_imgs, hr_imgs = batch  # unpack LR and HR tensors
        sr_imgs = self.forward(lr_imgs)  # run generator to produce SR prediction

        # ======================================================================
        # SECTION: Compute and log validation metrics
        # Purpose: measure content-based metrics (PSNR/SSIM/etc.) on SR vs HR.
        # ======================================================================
        """ 2. Log Generator Metrics """
        metrics_hr_img = torch.clone(
            hr_imgs
        )  # clone to avoid in-place ops on autograd graph
        metrics_sr_img = torch.clone(sr_imgs)  # same for SR
        # metrics = calculate_metrics(metrics_sr_img, metrics_hr_img, phase="val_metrics")
        metrics = self.content_loss_criterion.return_metrics(
            metrics_sr_img, metrics_hr_img, prefix="val_metrics/"
        )  # compute metrics using loss criterion helper
        del metrics_hr_img, metrics_sr_img  # free cloned tensors from GPU memory

        for key, value in metrics.items():  # iterate over metrics dict
            self.log(
                f"{key}", value, sync_dist=True
            )  # log each metric to logger (e.g., W&B, TensorBoard)

        # ======================================================================
        # SECTION: Optional visualization — Log example SR/HR/LR images
        # Purpose: visually track qualitative progress of the model.
        # ======================================================================
        # only perform image logging for first N batches to avoid logging all 200 images
        if batch_idx < self.config.Logging.num_val_images:
            base_lr = lr_imgs  # use original LR for visualization

            # --- Select visualization bands (if multispectral) ---
            if self.config.Model.in_bands < 3:
                # show only first band
                lr_vis = base_lr[:, :1, :, :]  # e.g., single-band input
                hr_vis = hr_imgs[:, :1, :, :]  # subset HR
                sr_vis = sr_imgs[:, :1, :, :]  # subset SR
            elif self.config.Model.in_bands == 3:
                # RGB input: visualize all three bands directly
                lr_vis = base_lr
                hr_vis = hr_imgs
                sr_vis = sr_imgs
            elif self.config.Model.in_bands == 4:
                # assume its RGB-NIR, show RGB
                lr_vis = base_lr[:, :3, :, :]  # e.g., Sentinel-2 RGB
                hr_vis = hr_imgs[:, :3, :, :]  # subset HR
                sr_vis = sr_imgs[:, :3, :, :]  # subset SR
            elif self.config.Model.in_bands > 4:  # e.g., Sentinel-2 with >3 channels
                # random selection of bands
                idx = np.random.choice(
                    sr_imgs.shape[1], 3, replace=False
                )  # randomly select 3 bands
                lr_vis = base_lr[:, idx, :, :]  # subset LR
                hr_vis = hr_imgs[:, idx, :, :]  # subset HR
                sr_vis = sr_imgs[:, idx, :, :]  # subset SR
            else:
                # should not happen
                pass

            # --- Clone tensors for plotting to avoid affecting main tensors ---
            plot_lr_img = lr_vis.clone()
            plot_hr_img = hr_vis.clone()
            plot_sr_img = sr_vis.clone()

            # --- Generate matplotlib visualization (LR, SR, HR side-by-side) ---
            val_img = plot_tensors(plot_lr_img, plot_sr_img, plot_hr_img, title="Val")

            # --- Cleanup ---
            del plot_lr_img, plot_hr_img, plot_sr_img  # free memory after plotting

            # --- Log image to WandB (or compatible logger), if wanted ---
            if self.config.Logging.wandb.enabled:
                self.logger.experiment.log(
                    {"Val SR": wandb.Image(val_img)}
                )  # upload to dashboard

            """ 3. Log Discriminator metrics """
            # If in pretraining, discard D metrics
            if self._pretrain_check():  # check if we're in pretrain phase
                self.log(
                    "discriminator/adversarial_loss",
                    torch.zeros(1, device=lr_imgs.device),
                    prog_bar=False,
                    sync_dist=True,
                )
            else:
                # run discriminator and get loss between pred labels and true labels
                hr_discriminated = self.discriminator(hr_imgs)
                sr_discriminated = self.discriminator(sr_imgs)

                # Run loss depending on type
                if self.adv_loss_type == "wasserstein":
                    adversarial_loss = sr_discriminated.mean() - hr_discriminated.mean()
                else:
                    adversarial_loss = self.adversarial_loss_criterion(
                        sr_discriminated, torch.zeros_like(sr_discriminated)
                    ) + self.adversarial_loss_criterion(
                        hr_discriminated, torch.ones_like(hr_discriminated)
                    )

                # Log image
                self.log(
                    "validation/DISC_adversarial_loss", adversarial_loss, sync_dist=True
                )

    def on_validation_epoch_start(self):
        """Hook executed at the start of each validation epoch.

        Applies the Exponential Moving Average (EMA) weights to the generator
        before running validation to ensure evaluation uses the smoothed model
        parameters.

        Notes:
            - Calls the parent hook via `super().on_validation_epoch_start()`.
            - Restores original weights at the end of validation.
        """
        super().on_validation_epoch_start()
        self._apply_generator_ema_weights()

    def on_validation_epoch_end(self):
        """Hook executed at the end of each validation epoch.

        Restores the generator’s original (non-EMA) weights after validation.
        Ensures subsequent training or testing uses up-to-date parameters.

        Notes:
            - Calls the parent hook via `super().on_validation_epoch_end()`.
        """
        self._restore_generator_weights()
        super().on_validation_epoch_end()

    def on_test_epoch_start(self):
        """Hook executed at the start of each testing epoch.

        Applies the Exponential Moving Average (EMA) weights to the generator
        before running tests to ensure consistent evaluation with the
        smoothed model parameters.

        Notes:
            - Calls the parent hook via `super().on_test_epoch_start()`.
            - Restores original weights at the end of testing.
        """
        super().on_test_epoch_start()
        self._apply_generator_ema_weights()

    def on_test_epoch_end(self):
        """Hook executed at the end of each testing epoch.

        Restores the generator’s original (non-EMA) weights after testing.
        Ensures the model is reset to its latest training state.

        Notes:
            - Calls the parent hook via `super().on_test_epoch_end()`.
        """
        self._restore_generator_weights()
        super().on_test_epoch_end()

    def configure_optimizers(self):
        """
        Robust optimizers & schedulers for GANs (PL1 & PL2 compatible).

        - TTUR by default (D lr <= G lr)
        - Adam with GAN-friendly betas/eps
        - Exclude norm/affine/bias params from weight decay
        - Separate Plateau schedulers for G and D (with cooldown/min_lr)
        - Optional step-wise warmup for G (linear/cosine), no LR jump at handoff
        - Returned order is [D, G] to match your training_step expectations
        """
        import math
        import torch
        import torch.nn as nn
        from torch.optim import Adam
        from torch.optim.lr_scheduler import ReduceLROnPlateau

        cfg_opt = self.config.Optimizers
        cfg_sch = self.config.Schedulers

        # ---------- helpers ----------
        def _split_wd_params(model):
            """Return two lists: params_with_wd, params_without_wd."""
            wd, no_wd = [], []
            norm_like = (
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.GroupNorm,
                nn.LayerNorm,
                nn.InstanceNorm1d,
                nn.InstanceNorm2d,
                nn.InstanceNorm3d,
            )
            for m in model.modules():
                for n, p in m.named_parameters(recurse=False):
                    if not p.requires_grad:
                        continue
                    if n.endswith("bias") or isinstance(m, norm_like):
                        no_wd.append(p)
                    else:
                        wd.append(p)
            # catch any top-level params not in modules (rare)
            seen = set(map(id, wd + no_wd))
            for n, p in model.named_parameters():
                if p.requires_grad and id(p) not in seen:
                    (no_wd if n.endswith("bias") else wd).append(p)
            return wd, no_wd

        def _adam(params, lr):
            # GAN-friendly defaults; tune in config if needed
            betas = getattr(cfg_opt, "betas", (0.0, 0.99))
            eps = getattr(cfg_opt, "eps", 1e-7)
            return Adam(params, lr=lr, betas=betas, eps=eps)

        # ---------- LRs (TTUR) ----------
        lr_g = float(getattr(cfg_opt, "optim_g_lr", 1e-4))
        lr_d = float(
            getattr(cfg_opt, "optim_d_lr", max(lr_g * 0.5, 1e-6))
        )  # default: D slower than G

        # weight decay (only on non-norm, non-bias)
        wd_g = float(getattr(cfg_opt, "weight_decay_g", 0.0))
        wd_d = float(getattr(cfg_opt, "weight_decay_d", 0.0))

        # ---------- build optimizers with clean param groups ----------
        g_wd, g_no = _split_wd_params(self.generator)
        d_wd, d_no = _split_wd_params(self.discriminator)

        optimizer_g = _adam(
            [
                {"params": g_wd, "weight_decay": wd_g},
                {"params": g_no, "weight_decay": 0.0},
            ],
            lr=lr_g,
        )
        optimizer_d = _adam(
            [
                {"params": d_wd, "weight_decay": wd_d},
                {"params": d_no, "weight_decay": 0.0},
            ],
            lr=lr_d,
        )

        # ---------- schedulers ----------
        # Use distinct monitors for clarity (recommend: log these in validation)
        monitor_g = getattr(
            cfg_sch, "metric_g", getattr(cfg_sch, "metric", "val_g_loss")
        )
        monitor_d = getattr(
            cfg_sch, "metric_d", getattr(cfg_sch, "metric", "val_d_loss")
        )

        sched_kwargs = dict(
            mode="min",
            factor=float(getattr(cfg_sch, "factor_g", 0.5)),
            patience=int(getattr(cfg_sch, "patience_g", 5)),
            threshold=float(getattr(cfg_sch, "threshold", 1e-3)),
            threshold_mode="rel",
            cooldown=int(getattr(cfg_sch, "cooldown", 0)),
            min_lr=float(getattr(cfg_sch, "min_lr", 1e-7)),
            verbose=bool(getattr(cfg_sch, "verbose", False)),
        )
        # D can have its own factor/patience; fall back to G’s if not set
        sched_kwargs_d = dict(sched_kwargs)
        sched_kwargs_d["factor"] = float(
            getattr(cfg_sch, "factor_d", sched_kwargs["factor"])
        )
        sched_kwargs_d["patience"] = int(
            getattr(cfg_sch, "patience_d", sched_kwargs["patience"])
        )

        scheduler_g = ReduceLROnPlateau(optimizer_g, **sched_kwargs)
        scheduler_d = ReduceLROnPlateau(optimizer_d, **sched_kwargs_d)

        sch_configs = [
            {
                "scheduler": scheduler_d,
                "monitor": monitor_d,
                "reduce_on_plateau": True,
                "interval": "epoch",
                "frequency": 1,
                "name": "plateau_d",
            },
            {
                "scheduler": scheduler_g,
                "monitor": monitor_g,
                "reduce_on_plateau": True,
                "interval": "epoch",
                "frequency": 1,
                "name": "plateau_g",
            },
        ]

        # ---------- optional warmup for G (step-wise, multiplicative) ----------
        warmup_steps = int(getattr(cfg_sch, "g_warmup_steps", 0))
        warmup_type = str(getattr(cfg_sch, "g_warmup_type", "none")).lower()
        if warmup_steps > 0 and warmup_type in {"linear", "cosine"}:

            def _g_warmup_lambda(step: int) -> float:
                if step >= warmup_steps:
                    return 1.0
                t = (step + 1) / max(1, warmup_steps)
                return (
                    t
                    if warmup_type == "linear"
                    else 0.5 * (1.0 - math.cos(math.pi * t))
                )

            warmup_g = torch.optim.lr_scheduler.LambdaLR(
                optimizer_g, lr_lambda=_g_warmup_lambda
            )
            # Runs every step; multiplies base LR so there is no jump at the end
            sch_configs.append(
                {
                    "scheduler": warmup_g,
                    "interval": "step",
                    "frequency": 1,
                    "name": "warmup_g",
                }
            )

        # Return order [D, G] to match your training_step
        return [optimizer_d, optimizer_g], sch_configs

    def on_train_batch_start(
        self, batch, batch_idx
    ):  # called before each training batch
        """Hook executed before each training batch.

        Freezes or unfreezes discriminator parameters depending on the
        current training phase. During pretraining, the discriminator is
        frozen to allow the generator to learn reconstruction without
        adversarial pressure.

        Args:
            batch (Any): The current batch of training data.
            batch_idx (int): Index of the current batch in the epoch.
        """
        pre = self._pretrain_check()  # check if currently in pretraining phase
        for p in self.discriminator.parameters():  # loop over all discriminator params
            p.requires_grad = not pre  # freeze D during pretrain, unfreeze otherwise

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Hook executed after each training batch.

        Logs the current learning rates for all active optimizers to
        the logger for monitoring and debugging purposes.

        Args:
            outputs (Any): Outputs returned by `training_step`.
            batch (Any): The batch of data processed.
            batch_idx (int): Index of the current batch in the epoch.
        """
        self._log_lrs()  # log LR's on each batch end

    def on_fit_start(self):  # called once at the start of training
        """Hook executed once at the beginning of model fitting.

        Performs setup tasks that must occur before training starts:
        - Moves EMA weights to the correct device.
        - Prints a model summary (only from global rank 0 in DDP setups).

        Notes:
            - Calls `super().on_fit_start()` to preserve Lightning’s default behavior.
            - The model summary is only printed by the global zero process
            to avoid duplicated output in distributed training.
        """
        super().on_fit_start()
        if self.ema is not None and self.ema.device is None:  # move ema weights
            self.ema.to(self.device)
        from opensr_srgan.utils.gpu_rank import _is_global_zero

        if _is_global_zero():
            print_model_summary(self)  # print model summary to console

    def _log_generator_content_loss(self, content_loss: torch.Tensor) -> None:
        """Helper to consistently log the generator content loss across training phases."""
        self.log(
            "generator/content_loss",
            content_loss,
            prog_bar=True,
            sync_dist=True,
        )

    def _log_ema_setup_metrics(self) -> None:
        """Log static Exponential Moving Average (EMA) configuration parameters.

        Records whether EMA is enabled, along with its core hyperparameters
        (decay rate, activation delay, update mode). This information is
        logged once when training begins to help track model configuration.

        Notes:
            - If EMA is disabled, logs `"EMA/enabled" = 0.0`.
            - Called after the trainer is initialized to ensure logging context.
        """
        if getattr(self, "trainer", None) is None:
            return

        if self.ema is None:
            self.log(
                "EMA/enabled",
                0.0,
                on_step=False,
                on_epoch=True,
                prog_bar=False,
                sync_dist=True,
            )
            return

        self.log(
            "EMA/enabled",
            1.0,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=True,
        )
        self.log(
            "EMA/decay",
            float(self.ema.decay),
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=True,
        )
        self.log(
            "EMA/update_after_step",
            float(self._ema_update_after_step),
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=True,
        )
        self.log(
            "EMA/use_num_updates",
            1.0 if self.ema.num_updates is not None else 0.0,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=True,
        )

    def _log_ema_step_metrics(self, *, updated: bool) -> None:
        """Log dynamic EMA statistics during training.

        Tracks per-step EMA state, including whether an update occurred,
        how many steps remain until activation, and the most recent decay value.
        These metrics provide insight into EMA behavior over time.

        Args:
            updated (bool): Whether EMA weights were updated in the current step.

        Notes:
            - If EMA is disabled, this function exits without logging.
            - Logs include:
                - `"EMA/is_active"`: Indicates if EMA is currently updating.
                - `"EMA/steps_until_activation"`: Steps remaining before EMA starts updating.
                - `"EMA/last_decay"`: Latest applied decay value.
                - `"EMA/num_updates"`: Total number of EMA updates performed.
        """
        if self.ema is None:
            return

        self.log(
            "EMA/is_active",
            1.0 if updated else 0.0,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            sync_dist=True,
        )

        steps_until_active = max(0, self._ema_update_after_step - self.global_step)
        self.log(
            "EMA/steps_until_activation",
            float(steps_until_active),
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            sync_dist=True,
        )

        if not updated:
            return

        self.log(
            "EMA/last_decay",
            float(self.ema.last_decay),
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            sync_dist=True,
        )

        if self.ema.num_updates is not None:
            self.log(
                "EMA/num_updates",
                float(self.ema.num_updates),
                on_step=True,
                on_epoch=False,
                prog_bar=False,
                sync_dist=True,
            )

    def _pretrain_check(self) -> bool:
        """Check whether the model is still in the generator pretraining phase.

        Returns:
            bool: True if the generator-only pretraining phase is active
            (i.e., `global_step` < `g_pretrain_steps`), otherwise False.

        Notes:
            - During pretraining, the discriminator is frozen and only the
            generator is updated.
        """
        if (
            self.pretrain_g_only and self.global_step < self.g_pretrain_steps
        ):  # true if pretraining active
            return True
        else:
            return False  # false once pretrain steps are exceeded

    def _compute_adv_loss_weight(self) -> float:
        """Compute the current adversarial loss weighting factor.

        Determines how strongly the adversarial loss contributes to the total
        generator loss, following a configurable ramp-up schedule. This helps
        stabilize early training by gradually increasing the influence of the
        discriminator.

        Returns:
            float: The current adversarial loss weight for the active step.

        Configuration Fields:
            config.Training.Losses:
                - `adv_loss_beta` (float): Maximum scaling factor for adversarial loss.
                - `adv_loss_schedule` (str): Type of ramp schedule (`"linear"` or `"cosine"`).

        Notes:
            - Returns `0.0` during pretraining steps (`global_step < g_pretrain_steps`).
            - After the ramp-up phase, the weight saturates at `beta`.
            - Cosine schedule provides a smoother ramp-up than linear.

        Raises:
            ValueError: If an unknown schedule type is provided in the configuration.
        """
        """Compute the current adversarial loss weight using the configured ramp schedule."""
        beta = float(self.config.Training.Losses.adv_loss_beta)
        schedule = getattr(
            self.config.Training.Losses,
            "adv_loss_schedule",
            "cosine",
        ).lower()

        # Handle pretraining and edge cases early
        if self.global_step < self.g_pretrain_steps:
            return 0.0

        if (
            self.adv_loss_ramp_steps <= 0
            or self.global_step >= self.g_pretrain_steps + self.adv_loss_ramp_steps
        ):
            return beta

        # Normalize progress to [0, 1]
        progress = (self.global_step - self.g_pretrain_steps) / self.adv_loss_ramp_steps
        progress = max(0.0, min(progress, 1.0))

        if schedule == "linear":
            return progress * beta

        if schedule == "cosine":
            # Cosine ramp to match the generator warmup behaviour
            return 0.5 * (1.0 - math.cos(math.pi * progress)) * beta

        raise ValueError(
            f"Unknown adversarial loss schedule '{schedule}'. Expected 'linear' or 'cosine'."
        )

    def _log_adv_loss_weight(self, adv_weight: float) -> None:
        """Log the current adversarial loss weight.

        Args:
            adv_weight (float): Scalar multiplier applied to the adversarial loss term.
        """
        self.log("training/adv_loss_weight", adv_weight, sync_dist=True)

    def _adv_loss_weight(self) -> float:
        """Compute and log the current adversarial loss weight.

        Calls the internal scheduler/heuristic to obtain the adversarial loss weight,
        logs it, and returns the value.

        Returns:
            float: The computed adversarial loss weight for the current step/epoch.
        """
        adv_weight = self._compute_adv_loss_weight()
        self._log_adv_loss_weight(adv_weight)
        return adv_weight

    def _apply_generator_ema_weights(self) -> None:
        """Swap the generator's parameters to their EMA-smoothed counterparts.

        Applies EMA weights to the generator for evaluation (e.g., val/test). A no-op if
        EMA is disabled or already applied. Moves EMA to the correct device if needed.

        Notes:
            - Sets an internal flag to avoid double application during the same phase.
        """
        if self.ema is None or self._ema_applied:
            return
        if self.ema.device is None:
            self.ema.to(self.device)
        self.ema.apply_to(self.generator)
        self._ema_applied = True

    def _restore_generator_weights(self) -> None:
        """Restore the generator's original (non-EMA) parameters.

        Reverts the parameter swap performed by `_apply_generator_ema_weights()`.
        A no-op if EMA is disabled or not currently applied.

        Notes:
            - Clears the internal "applied" flag to enable future swaps.
        """
        if self.ema is None or not self._ema_applied:
            return
        self.ema.restore(self.generator)
        self._ema_applied = False

    def on_save_checkpoint(self, checkpoint: dict) -> None:
        """Augment the checkpoint with EMA state, if available.

        Adds the EMA buffer/metadata to the checkpoint so that EMA can be restored upon load.

        Args:
            checkpoint (dict): Mutable checkpoint dictionary provided by Lightning.
        """
        super().on_save_checkpoint(checkpoint)
        if self.ema is not None:
            checkpoint["ema_state"] = self.ema.state_dict()

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        """Restore EMA state from a checkpoint, if present.

        Args:
            checkpoint (dict): Checkpoint dictionary provided by Lightning containing
                model state and optional `"ema_state"` entry.
        """
        super().on_load_checkpoint(checkpoint)
        if self.ema is not None and "ema_state" in checkpoint:
            self.ema.load_state_dict(checkpoint["ema_state"])

    def _log_lrs(self) -> None:
        """Log learning rates for discriminator and generator optimizers.

        Notes:
            - Assumes optimizers are ordered as `[optimizer_d, optimizer_g]` in the trainer.
            - Logs both on-step and on-epoch for easier tracking.
        """
        # order matches your return: [optimizer_d, optimizer_g]
        opt_d = self.trainer.optimizers[0]
        opt_g = self.trainer.optimizers[1]
        self.log(
            "lr_discriminator",
            opt_d.param_groups[0]["lr"],
            on_step=True,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "lr_generator",
            opt_g.param_groups[0]["lr"],
            on_step=True,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

    def load_from_checkpoint(self, ckpt_path) -> None:
        """Load model weights from a PyTorch Lightning checkpoint file.

        Loads the `state_dict` from the given checkpoint and maps it to the current device.

        Args:
            ckpt_path (str | pathlib.Path): Path to the `.ckpt` file saved by Lightning.

        Raises:
            FileNotFoundError: If the checkpoint path does not exist.
            KeyError: If the checkpoint does not contain a `'state_dict'` entry.
        """
        # load ckpt
        ckpt = torch.load(ckpt_path, map_location=self.device)
        self.load_state_dict(ckpt["state_dict"])
        print(f"Loaded checkpoint from {ckpt_path}")

get_models(mode)

Initialize and attach the Generator and (optionally) Discriminator models.

This method builds the generator and discriminator architectures based on the configuration provided in self.config. It supports multiple generator backbones (e.g., SRResNet, RCAB, RRDB, LKA) and discriminator types (standard, PatchGAN). The discriminator is only initialized when the mode is set to "train".

Parameters:

  • mode (str, required): Operational mode of the model. "train" initializes both generator and discriminator; any other value initializes only the generator.

Raises:

  • ValueError: If an unknown generator or discriminator type is specified in the configuration.

Attributes:

  • generator (Module): The initialized generator network instance.
  • discriminator (Module): The initialized discriminator network instance (only present if mode == "train").

Source code in opensr_srgan/model/SRGAN.py
def get_models(self, mode):
    """Initialize and attach the Generator and (optionally) Discriminator models.

    This method builds the generator and discriminator architectures based on
    the configuration provided in `self.config`. It supports multiple generator
    backbones (e.g., SRResNet, RCAB, RRDB, LKA) and discriminator types
    (standard, PatchGAN). The discriminator is only initialized when the mode
    is set to `"train"`.

    Args:
        mode (str): Operational mode of the model. Must be one of:
            - `"train"`: Initializes both generator and discriminator.
            - Any other value: Initializes only the generator.

    Raises:
        ValueError: If an unknown generator or discriminator type is specified
            in the configuration.

    Attributes:
        generator (nn.Module): The initialized generator network instance.
        discriminator (nn.Module, optional): The initialized discriminator
            network instance (only present if `mode == "train"`).
    """

    # ======================================================================
    # SECTION: Initialize Generator
    # Purpose: Build generator network depending on selected architecture.
    # ======================================================================
    self.generator = build_generator(self.config)

    if mode == "train":  # only get discriminator in training mode
        # ======================================================================
        # SECTION: Initialize Discriminator
        # Purpose: Build discriminator network for adversarial training.
        # ======================================================================
        raw_discriminator_type = getattr(
            self.config.Discriminator, "model_type", "standard"
        )
        discriminator_type = str(raw_discriminator_type).strip().lower()
        n_blocks = getattr(self.config.Discriminator, "n_blocks", None)

        if discriminator_type == "standard":
            from opensr_srgan.model.discriminators.srgan_discriminator import (
                Discriminator,
            )

            discriminator_kwargs = {
                "in_channels": self.config.Model.in_bands,
            }
            if n_blocks is not None:
                discriminator_kwargs["n_blocks"] = n_blocks

            # pass spectral norm option
            use_spectral_norm = getattr(
                self.config.Discriminator, "use_spectral_norm", True
            )
            discriminator_kwargs["use_spectral_norm"] = bool(use_spectral_norm)

            self.discriminator = Discriminator(**discriminator_kwargs)
        elif discriminator_type == "patchgan":
            from opensr_srgan.model.discriminators.patchgan import (
                PatchGANDiscriminator,
            )

            patchgan_layers = n_blocks if n_blocks is not None else 3
            self.discriminator = PatchGANDiscriminator(
                input_nc=self.config.Model.in_bands,
                n_layers=patchgan_layers,
            )
        elif discriminator_type == "esrgan":
            from opensr_srgan.model.discriminators.esrgan import (
                ESRGANDiscriminator,
            )

            ignored_options = []
            if n_blocks is not None:
                ignored_options.append("n_blocks")
            if ignored_options:
                ignored_joined = ", ".join(sorted(ignored_options))
                print(
                    f"[Discriminator:esrgan] Ignoring unsupported configuration options: {ignored_joined}."
                )

            base_channels = getattr(self.config.Discriminator, "base_channels", 64)
            linear_size = getattr(self.config.Discriminator, "linear_size", 1024)
            self.discriminator = ESRGANDiscriminator(
                in_channels=self.config.Model.in_bands,
                base_channels=int(base_channels),
                linear_size=int(linear_size),
            )
        else:
            raise ValueError(
                f"Unknown discriminator model type: {raw_discriminator_type}"
            )
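For reference, the discriminator-related configuration keys read by this method can be expressed as a nested dict/OmegaConf structure. A hedged sketch (only keys that appear in `get_models` are shown; your actual `config.yaml` will contain more sections):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "Model": {"in_bands": 4},  # input channels fed to the discriminator
        "Discriminator": {
            "model_type": "patchgan",   # "standard", "patchgan", or "esrgan"
            "n_blocks": 3,              # optional; ignored (with a console notice) by "esrgan"
            "use_spectral_norm": True,  # only read by the "standard" discriminator
            # "base_channels": 64,      # esrgan-only
            # "linear_size": 1024,      # esrgan-only
        },
    }
)
print(cfg.Discriminator.model_type)  # patchgan
```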

setup_lightning()

Configure PyTorch Lightning behavior based on the detected version.

This method ensures compatibility between different versions of PyTorch Lightning (PL) by setting appropriate optimization modes and binding the correct training step implementation.

  • For PL ≥ 2.0: Enables manual optimization, required for GAN training.
  • For PL < 2.0: Uses automatic optimization and the legacy training step.

The selected training step function (training_step_PL1 or training_step_PL2) is dynamically attached to the model as _training_step_implementation.

Raises:

  • AssertionError: If automatic_optimization is incorrectly set for PL < 2.0.
  • RuntimeError: If the detected PyTorch Lightning version is unsupported.

Attributes:

  • automatic_optimization (bool): Indicates whether Lightning manages optimizer steps automatically.
  • _training_step_implementation (Callable): Bound training step function corresponding to the active PL version.

Source code in opensr_srgan/model/SRGAN.py
def setup_lightning(self):
    """Configure PyTorch Lightning behavior based on the detected version.

    This method ensures compatibility between different versions of
    PyTorch Lightning (PL) by setting appropriate optimization modes
    and binding the correct training step implementation.

    - For PL ≥ 2.0: Enables **manual optimization**, required for GAN training.
    - For PL < 2.0: Uses **automatic optimization** and the legacy training step.

    The selected training step function (`training_step_PL1` or `training_step_PL2`)
    is dynamically attached to the model as `_training_step_implementation`.

    Raises:
        AssertionError: If `automatic_optimization` is incorrectly set for PL < 2.0.
        RuntimeError: If the detected PyTorch Lightning version is unsupported.

    Attributes:
        automatic_optimization (bool): Indicates whether Lightning manages
            optimizer steps automatically.
        _training_step_implementation (Callable): Bound training step function
            corresponding to the active PL version.
    """
    # Check for PL version - Define PL Hooks accordingly
    if self.pl_version >= (2, 0, 0):
        self.automatic_optimization = False  # manual optimization for PL 2.x
        # Set up Training Step
        from opensr_srgan.model.training_step_PL import training_step_PL2

        self._training_step_implementation = MethodType(training_step_PL2, self)
    elif self.pl_version < (2, 0, 0):
        assert (
            self.automatic_optimization is True
        ), "For PL <2.0, automatic_optimization must be True."
        # Set up Training Step
        from opensr_srgan.model.training_step_PL import training_step_PL1

        self._training_step_implementation = MethodType(training_step_PL1, self)
    else:
        raise RuntimeError(
            f"Unsupported PyTorch Lightning version: {pl.__version__}"
        )
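The comparison against `(2, 0, 0)` assumes `self.pl_version` holds a comparable version tuple. A minimal sketch of how such a tuple could be derived from the installed Lightning release (the attribute name and the exact parsing used inside the class are assumptions here):

```python
import pytorch_lightning as pl

def version_tuple(version: str) -> tuple[int, ...]:
    """Turn '2.1.3' (or '1.9.5.post0') into a comparable integer tuple."""
    parts = []
    for chunk in version.split(".")[:3]:
        digits = "".join(ch for ch in chunk if ch.isdigit())
        parts.append(int(digits) if digits else 0)
    return tuple(parts)

pl_version = version_tuple(pl.__version__)
print(pl_version >= (2, 0, 0))  # True on Lightning 2.x, False on 1.x
```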

initialize_ema()

Initialize the Exponential Moving Average (EMA) mechanism for the generator.

This method sets up an EMA shadow copy of the generator parameters to stabilize training and improve the quality of generated outputs. EMA is enabled only if specified in the training configuration.

The EMA model tracks the moving average of generator weights with a configurable decay factor and update schedule.

Configuration fields under config.Training.EMA:

  • enabled (bool): Whether to enable EMA tracking.
  • decay (float): Exponential decay factor for weight averaging (default: 0.999).
  • device (str | None): Device to store the EMA weights on.
  • use_num_updates (bool): Whether to use step-based update counting.
  • update_after_step (int): Number of steps to wait before starting updates.

Attributes:

  • ema (ExponentialMovingAverage | None): EMA object tracking generator parameters.
  • _ema_update_after_step (int): Step count threshold before EMA updates begin.
  • _ema_applied (bool): Indicates whether EMA weights are currently applied to the generator.

Source code in opensr_srgan/model/SRGAN.py
def initialize_ema(self):
    """Initialize the Exponential Moving Average (EMA) mechanism for the generator.

    This method sets up an EMA shadow copy of the generator parameters to
    stabilize training and improve the quality of generated outputs. EMA is
    enabled only if specified in the training configuration.

    The EMA model tracks the moving average of generator weights with a
    configurable decay factor and update schedule.

    Configuration fields under `config.Training.EMA`:
        - `enabled` (bool): Whether to enable EMA tracking.
        - `decay` (float): Exponential decay factor for weight averaging (default: 0.999).
        - `device` (str | None): Device to store the EMA weights on.
        - `use_num_updates` (bool): Whether to use step-based update counting.
        - `update_after_step` (int): Number of steps to wait before starting updates.

    Attributes:
        ema (ExponentialMovingAverage | None): EMA object tracking generator parameters.
        _ema_update_after_step (int): Step count threshold before EMA updates begin.
        _ema_applied (bool): Indicates whether EMA weights are currently applied to the generator.
    """
    ema_cfg = getattr(self.config.Training, "EMA", None)
    self.ema: ExponentialMovingAverage | None = None
    self._ema_update_after_step = 0
    self._ema_applied = False
    if ema_cfg is not None and getattr(ema_cfg, "enabled", False):
        ema_decay = float(getattr(ema_cfg, "decay", 0.999))
        ema_device = getattr(ema_cfg, "device", None)
        use_num_updates = bool(getattr(ema_cfg, "use_num_updates", True))
        self.ema = ExponentialMovingAverage(
            self.generator,
            decay=ema_decay,
            use_num_updates=use_num_updates,
        )
        self._ema_update_after_step = int(getattr(ema_cfg, "update_after_step", 0))
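The EMA fields read above live under the `Training.EMA` section of the config. A hedged sketch using only the keys referenced by this method:

```python
from omegaconf import OmegaConf

training_cfg = OmegaConf.create(
    {
        "Training": {
            "EMA": {
                "enabled": True,            # turn EMA tracking on
                "decay": 0.999,             # smoothing coefficient
                "use_num_updates": True,    # adapt decay during early updates
                "update_after_step": 1000,  # skip EMA updates for the first N steps
                "device": None,             # keep EMA weights on the training device
            }
        }
    }
)
```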

forward(lr_imgs)

Forward pass through the generator network.

Takes a batch of low-resolution (LR) input images and produces their corresponding super-resolved (SR) outputs using the generator model.

Parameters:

  • lr_imgs (Tensor, required): Batch of input low-resolution images with shape (B, C, H, W), where B is the batch size, C the number of channels, and H, W the spatial dimensions.

Returns:

  • torch.Tensor: Super-resolved output images with increased spatial resolution, typically scaled by the model's configured upsampling factor.

Source code in opensr_srgan/model/SRGAN.py
def forward(self, lr_imgs):
    """Forward pass through the generator network.

    Takes a batch of low-resolution (LR) input images and produces
    their corresponding super-resolved (SR) outputs using the generator model.

    Args:
        lr_imgs (torch.Tensor): Batch of input low-resolution images
            with shape `(B, C, H, W)` where:
            - `B`: batch size
            - `C`: number of channels
            - `H`, `W`: spatial dimensions.

    Returns:
        torch.Tensor: Super-resolved output images with increased spatial resolution,
        typically scaled by the model's configured upsampling factor.
    """
    sr_imgs = self.generator(lr_imgs)  # pass LR input through generator network
    return sr_imgs  # return super-resolved output
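A minimal usage sketch, assuming `SRGAN_model` is importable from `opensr_srgan.model.SRGAN` and a valid config file exists; the band count and scaling factor below are illustrative:

```python
import torch
from opensr_srgan.model.SRGAN import SRGAN_model  # assumed import path

model = SRGAN_model(config="config.yaml", mode="eval")
model.eval()

lr_imgs = torch.rand(2, 4, 32, 32)   # (B, C, H, W) in the model's normalized range
with torch.no_grad():
    sr_imgs = model(lr_imgs)         # __call__ -> forward -> generator
print(sr_imgs.shape)                 # e.g. torch.Size([2, 4, 128, 128]) for a 4x model
```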

predict_step(lr_imgs)

Run a single super-resolution inference step.

Performs forward inference using the generator (optionally under EMA weights) to produce super-resolved (SR) outputs from low-resolution (LR) inputs. The method normalizes input values using the configured strategy (e.g., raw Sentinel-2 reflectance via normalise_10k), applies histogram matching, and denormalizes the outputs back to their original scale.

Parameters:

  • lr_imgs (Tensor, required): Batch of input low-resolution images with shape (B, C, H, W). Pixel value ranges may vary depending on preprocessing (e.g., 0–10000 for Sentinel-2 reflectance).

Returns:

  • torch.Tensor: Super-resolved output images with matched histograms and restored value range, detached from the computation graph and placed on CPU memory.

Raises:

  • AssertionError: If the generator is not in evaluation mode (.eval()).

Source code in opensr_srgan/model/SRGAN.py
@torch.no_grad()
def predict_step(self, lr_imgs):
    """Run a single super-resolution inference step.

    Performs forward inference using the generator (optionally under EMA weights)
    to produce super-resolved (SR) outputs from low-resolution (LR) inputs.
    The method normalizes input values using the configured strategy (e.g., raw
    Sentinel-2 reflectance via ``normalise_10k``), applies histogram matching, and
    denormalizes the outputs back to their original scale.

    Args:
        lr_imgs (torch.Tensor): Batch of input low-resolution images
            with shape `(B, C, H, W)`. Pixel value ranges may vary depending
            on preprocessing (e.g., 0–10000 for Sentinel-2 reflectance).

    Returns:
        torch.Tensor: Super-resolved output images with matched histograms
        and restored value range, detached from the computation graph and
        placed on CPU memory.

    Raises:
        AssertionError: If the generator is not in evaluation mode (`.eval()`).
    """
    assert (
        self.generator.training is False
    ), "Generator must be in eval mode for prediction."  # ensure eval mode
    lr_imgs = lr_imgs.to(self.device)  # move to device (GPU or CPU)

    # --- Normalize inputs according to configuration ---
    normalized_lr = self.normalizer.normalize(lr_imgs)

    # --- Perform super-resolution (optionally using EMA weights) ---
    context = (
        self.ema.average_parameters(self.generator)
        if self.ema is not None
        else nullcontext()
    )
    with context:
        sr_imgs = self.generator(normalized_lr)  # forward pass (SR prediction)

    # --- Histogram match SR to LR ---
    sr_imgs = histogram_match(normalized_lr, sr_imgs)  # match distributions

    # --- Denormalize output back to original range ---
    sr_imgs = self.normalizer.denormalize(sr_imgs)

    # --- Move to CPU and return ---
    sr_imgs = sr_imgs.cpu().detach()  # detach from graph for inference output
    return sr_imgs
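A hedged inference sketch for reflectance-scaled inputs (import path, band count, and tile size are illustrative):

```python
import torch
from opensr_srgan.model.SRGAN import SRGAN_model  # assumed import path

model = SRGAN_model(config="config.yaml", mode="eval")
model.eval()  # predict_step asserts that the generator is in eval mode

# Sentinel-2-like reflectance values in the 0-10000 range; normalization,
# optional EMA averaging, histogram matching, and denormalization all
# happen inside predict_step.
lr_imgs = torch.randint(0, 10_000, (1, 4, 64, 64)).float()
sr_imgs = model.predict_step(lr_imgs)
print(sr_imgs.shape, sr_imgs.device)  # SR tensor, detached and on CPU
```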

training_step(batch, batch_idx, optimizer_idx=None, *args)

Dispatch the correct training step implementation based on PyTorch Lightning version.

This method acts as a compatibility layer between different PyTorch Lightning versions that handle multi-optimizer GAN training differently.

  • For PL ≥ 2.0: Manual optimization is used, and the optimizer index is not passed.
  • For PL < 2.0: Automatic optimization is used, and the optimizer index is passed to handle generator/discriminator updates separately.

Parameters:

  • batch (Any, required): A batch of training data (input tensors and targets as defined by the DataModule).
  • batch_idx (int, required): Index of the current batch within the epoch.
  • optimizer_idx (int | None, default None): Index of the active optimizer (0 for discriminator, 1 for generator) when using PL < 2.0.
  • *args: Additional arguments that may be passed by older Lightning versions.

Returns:

  • Any: The output of the active training step implementation (typically a loss value).

Source code in opensr_srgan/model/SRGAN.py
def training_step(
    self, batch, batch_idx, optimizer_idx: Optional[int] = None, *args
):
    """Dispatch the correct training step implementation based on PyTorch Lightning version.

    This method acts as a compatibility layer between different PyTorch Lightning
    versions that handle multi-optimizer GAN training differently.

    - For PL ≥ 2.0: Manual optimization is used, and the optimizer index is not passed.
    - For PL < 2.0: Automatic optimization is used, and the optimizer index is passed
    to handle generator/discriminator updates separately.

    Args:
        batch (Any): A batch of training data (input tensors and targets as defined by the DataModule).
        batch_idx (int): Index of the current batch within the epoch.
        optimizer_idx (int | None, optional): Index of the active optimizer (0 for discriminator,
            1 for generator) when using PL < 2.0.
        *args: Additional arguments that may be passed by older Lightning versions.

    Returns:
        Any: The output of the active training step implementation (typically a loss value).
    """
    # Depending on PL version, and depending on the manual optimization
    if self.pl_version >= (2, 0, 0):
        # In PL2.x, optimizer_idx is not passed, manual optimization is performed
        return self._training_step_implementation(batch, batch_idx)  # no optim_idx
    else:
        # In PL 1.x, optimizer_idx arrives twice and is passed on
        return self._training_step_implementation(
            batch, batch_idx, optimizer_idx
        )  # pass optim_idx

optimizer_step(epoch, batch_idx, optimizer, optimizer_idx=None, optimizer_closure=None, **kwargs)

Custom optimizer step handling for PL 1.x automatic optimization.

This method ensures correct behavior across different PyTorch Lightning versions and training modes. It is invoked automatically during training in PL < 2.0 when automatic_optimization=True. For PL ≥ 2.0, where manual optimization is used, this function is effectively bypassed.

  • In PL ≥ 2.0 (manual optimization): The optimizer step is explicitly called within training_step_PL2(), including EMA updates.
  • In PL < 2.0 (automatic optimization): This function manages optimizer stepping, gradient zeroing, and optional EMA updates after generator steps.

Parameters:

  • epoch (int, required): Current training epoch.
  • batch_idx (int, required): Index of the current batch.
  • optimizer (Optimizer, required): The active optimizer instance.
  • optimizer_idx (int, default None): Index of the optimizer being stepped (e.g., 0 for discriminator, 1 for generator).
  • optimizer_closure (Callable, default None): Closure for re-evaluating the model and loss before the optimizer step (used with some optimizers).
  • **kwargs: Additional arguments passed by PL depending on backend (e.g., TPU flags, LBFGS options).

Notes
  • EMA updates are performed only after generator steps (optimizer_idx == 1).
  • The update starts after self._ema_update_after_step global steps.
Source code in opensr_srgan/model/SRGAN.py
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx=None,
    optimizer_closure=None,
    **kwargs,  # absorbs on_tpu/using_lbfgs/etc across PL versions
):
    """Custom optimizer step handling for PL 1.x automatic optimization.

    This method ensures correct behavior across different PyTorch Lightning
    versions and training modes. It is invoked automatically during training
    in PL < 2.0 when `automatic_optimization=True`. For PL ≥ 2.0, where manual
    optimization is used, this function is effectively bypassed.

    - In **PL ≥ 2.0 (manual optimization)**: The optimizer step is explicitly
    called within `training_step_PL2()`, including EMA updates.
    - In **PL < 2.0 (automatic optimization)**: This function manages optimizer
    stepping, gradient zeroing, and optional EMA updates after generator steps.

    Args:
        epoch (int): Current training epoch.
        batch_idx (int): Index of the current batch.
        optimizer (torch.optim.Optimizer): The active optimizer instance.
        optimizer_idx (int, optional): Index of the optimizer being stepped
            (e.g., 0 for discriminator, 1 for generator).
        optimizer_closure (Callable, optional): Closure for re-evaluating the
            model and loss before optimizer step (used with some optimizers).
        **kwargs: Additional arguments passed by PL depending on backend
            (e.g., TPU flags, LBFGS options).

    Notes:
        - EMA updates are performed only after generator steps (optimizer_idx == 1).
        - The update starts after `self._ema_update_after_step` global steps.

    """
    # If we're in manual optimization (PL >=2 path), do nothing special.
    if not self.automatic_optimization:
        # In manual mode we call opt.step()/zero_grad() in training_step_PL2.
        # In manual mode, we update EMA weights manually in training step too.
        return super().optimizer_step(
            epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **kwargs
        )

    # ---- PL 1.x auto-optimization path ----
    if optimizer_closure is not None:
        optimizer.step(closure=optimizer_closure)
    else:
        optimizer.step()
    optimizer.zero_grad()

    # EMA after the generator step (assumes G is optimizer_idx == 1)
    if (
        self.ema is not None
        and optimizer_idx == 1
        and self.global_step >= self._ema_update_after_step
    ):
        self.ema.update(self.generator)

validation_step(batch, batch_idx)

Run the validation loop for a single batch.

This method performs super-resolution inference on validation data, computes image quality metrics (e.g., PSNR, SSIM), logs them, and optionally visualizes SR–HR–LR triplets. It also evaluates the discriminator’s adversarial response if applicable.

Workflow
  1. Forward pass (LR → SR) through the generator.
  2. Compute content-based validation metrics.
  3. Optionally log visual examples to the logger (e.g., Weights & Biases).
  4. Compute and log discriminator metrics, unless in pretraining mode.

Parameters:

  • batch (Tuple[Tensor, Tensor], required): A tuple (lr_imgs, hr_imgs) of low-resolution and high-resolution tensors with shape (B, C, H, W).
  • batch_idx (int, required): Index of the current validation batch.

Returns:

  • None: Metrics and images are logged via Lightning’s logger interface.

Raises:

  • AssertionError: If an unexpected number of bands or invalid visualization configuration is encountered.

Notes
  • Validation is executed without gradient tracking.
  • Only the first config.Logging.num_val_images batches are visualized.
  • If EMA is enabled, the generator predictions reflect the current EMA state.
Source code in opensr_srgan/model/SRGAN.py
@torch.no_grad()
def validation_step(self, batch, batch_idx):
    """Run the validation loop for a single batch.

    This method performs super-resolution inference on validation data,
    computes image quality metrics (e.g., PSNR, SSIM), logs them, and
    optionally visualizes SR–HR–LR triplets. It also evaluates the
    discriminator’s adversarial response if applicable.

    Workflow:
        1. Forward pass (LR → SR) through the generator.
        2. Compute content-based validation metrics.
        3. Optionally log visual examples to the logger (e.g., Weights & Biases).
        4. Compute and log discriminator metrics, unless in pretraining mode.

    Args:
        batch (Tuple[torch.Tensor, torch.Tensor]): A tuple `(lr_imgs, hr_imgs)` of
            low-resolution and high-resolution tensors with shape `(B, C, H, W)`.
        batch_idx (int): Index of the current validation batch.

    Returns:
        None: Metrics and images are logged via Lightning’s logger interface.

    Raises:
        AssertionError: If an unexpected number of bands or invalid visualization
            configuration is encountered.

    Notes:
        - Validation is executed without gradient tracking.
        - Only the first `config.Logging.num_val_images` batches are visualized.
        - If EMA is enabled, the generator predictions reflect the current EMA state.
    """
    # ======================================================================
    # SECTION: Forward pass — Generate SR prediction from LR input
    # Purpose: Run model inference on validation batch without gradient tracking.
    # ======================================================================
    """ 1. Extract and Predict """
    lr_imgs, hr_imgs = batch  # unpack LR and HR tensors
    sr_imgs = self.forward(lr_imgs)  # run generator to produce SR prediction

    # ======================================================================
    # SECTION: Compute and log validation metrics
    # Purpose: measure content-based metrics (PSNR/SSIM/etc.) on SR vs HR.
    # ======================================================================
    """ 2. Log Generator Metrics """
    metrics_hr_img = torch.clone(
        hr_imgs
    )  # clone to avoid in-place ops on autograd graph
    metrics_sr_img = torch.clone(sr_imgs)  # same for SR
    # metrics = calculate_metrics(metrics_sr_img, metrics_hr_img, phase="val_metrics")
    metrics = self.content_loss_criterion.return_metrics(
        metrics_sr_img, metrics_hr_img, prefix="val_metrics/"
    )  # compute metrics using loss criterion helper
    del metrics_hr_img, metrics_sr_img  # free cloned tensors from GPU memory

    for key, value in metrics.items():  # iterate over metrics dict
        self.log(
            f"{key}", value, sync_dist=True
        )  # log each metric to logger (e.g., W&B, TensorBoard)

    # ======================================================================
    # SECTION: Optional visualization — Log example SR/HR/LR images
    # Purpose: visually track qualitative progress of the model.
    # ======================================================================
    # only perform image logging for first N batches to avoid logging all 200 images
    if batch_idx < self.config.Logging.num_val_images:
        base_lr = lr_imgs  # use original LR for visualization

        # --- Select visualization bands (if multispectral) ---
        if self.config.Model.in_bands < 3:
            # show only first band
            lr_vis = base_lr[:, :1, :, :]  # e.g., single-band input
            hr_vis = hr_imgs[:, :1, :, :]  # subset HR
            sr_vis = sr_imgs[:, :1, :, :]  # subset SR
        elif self.config.Model.in_bands == 3:
            # RGB input: visualize all three bands as-is
            lr_vis = base_lr
            hr_vis = hr_imgs
            sr_vis = sr_imgs
        elif self.config.Model.in_bands == 4:
            # assume its RGB-NIR, show RGB
            lr_vis = base_lr[:, :3, :, :]  # e.g., Sentinel-2 RGB
            hr_vis = hr_imgs[:, :3, :, :]  # subset HR
            sr_vis = sr_imgs[:, :3, :, :]  # subset SR
        elif self.config.Model.in_bands > 4:  # e.g., Sentinel-2 with >3 channels
            # random selection of bands
            idx = np.random.choice(
                sr_imgs.shape[1], 3, replace=False
            )  # randomly select 3 bands
            lr_vis = base_lr[:, idx, :, :]  # subset LR
            hr_vis = hr_imgs[:, idx, :, :]  # subset HR
            sr_vis = sr_imgs[:, idx, :, :]  # subset SR
        else:
            # should not happen
            pass

        # --- Clone tensors for plotting to avoid affecting main tensors ---
        plot_lr_img = lr_vis.clone()
        plot_hr_img = hr_vis.clone()
        plot_sr_img = sr_vis.clone()

        # --- Generate matplotlib visualization (LR, SR, HR side-by-side) ---
        val_img = plot_tensors(plot_lr_img, plot_sr_img, plot_hr_img, title="Val")

        # --- Cleanup ---
        del plot_lr_img, plot_hr_img, plot_sr_img  # free memory after plotting

        # --- Log image to WandB (or compatible logger), if wanted ---
        if self.config.Logging.wandb.enabled:
            self.logger.experiment.log(
                {"Val SR": wandb.Image(val_img)}
            )  # upload to dashboard

        """ 3. Log Discriminator metrics """
        # If in pretraining, discard D metrics
        if self._pretrain_check():  # check if we're in the pretrain phase
            self.log(
                "discriminator/adversarial_loss",
                torch.zeros(1, device=lr_imgs.device),
                prog_bar=False,
                sync_dist=True,
            )
        else:
            # run discriminator and get loss between pred labels and true labels
            hr_discriminated = self.discriminator(hr_imgs)
            sr_discriminated = self.discriminator(sr_imgs)

            # Run loss depending on type
            if self.adv_loss_type == "wasserstein":
                adversarial_loss = sr_discriminated.mean() - hr_discriminated.mean()
            else:
                adversarial_loss = self.adversarial_loss_criterion(
                    sr_discriminated, torch.zeros_like(sr_discriminated)
                ) + self.adversarial_loss_criterion(
                    hr_discriminated, torch.ones_like(hr_discriminated)
                )

            # Log image
            self.log(
                "validation/DISC_adversarial_loss", adversarial_loss, sync_dist=True
            )

on_validation_epoch_start()

Hook executed at the start of each validation epoch.

Applies the Exponential Moving Average (EMA) weights to the generator before running validation to ensure evaluation uses the smoothed model parameters.

Notes
  • Calls the parent hook via super().on_validation_epoch_start().
  • Restores original weights at the end of validation.
Source code in opensr_srgan/model/SRGAN.py
def on_validation_epoch_start(self):
    """Hook executed at the start of each validation epoch.

    Applies the Exponential Moving Average (EMA) weights to the generator
    before running validation to ensure evaluation uses the smoothed model
    parameters.

    Notes:
        - Calls the parent hook via `super().on_validation_epoch_start()`.
        - Restores original weights at the end of validation.
    """
    super().on_validation_epoch_start()
    self._apply_generator_ema_weights()

on_validation_epoch_end()

Hook executed at the end of each validation epoch.

Restores the generator’s original (non-EMA) weights after validation. Ensures subsequent training or testing uses up-to-date parameters.

Notes
  • Calls the parent hook via super().on_validation_epoch_end().
Source code in opensr_srgan/model/SRGAN.py
def on_validation_epoch_end(self):
    """Hook executed at the end of each validation epoch.

    Restores the generator’s original (non-EMA) weights after validation.
    Ensures subsequent training or testing uses up-to-date parameters.

    Notes:
        - Calls the parent hook via `super().on_validation_epoch_end()`.
    """
    self._restore_generator_weights()
    super().on_validation_epoch_end()

on_test_epoch_start()

Hook executed at the start of each testing epoch.

Applies the Exponential Moving Average (EMA) weights to the generator before running tests to ensure consistent evaluation with the smoothed model parameters.

Notes
  • Calls the parent hook via super().on_test_epoch_start().
  • Restores original weights at the end of testing.
Source code in opensr_srgan/model/SRGAN.py
def on_test_epoch_start(self):
    """Hook executed at the start of each testing epoch.

    Applies the Exponential Moving Average (EMA) weights to the generator
    before running tests to ensure consistent evaluation with the
    smoothed model parameters.

    Notes:
        - Calls the parent hook via `super().on_test_epoch_start()`.
        - Restores original weights at the end of testing.
    """
    super().on_test_epoch_start()
    self._apply_generator_ema_weights()

on_test_epoch_end()

Hook executed at the end of each testing epoch.

Restores the generator’s original (non-EMA) weights after testing. Ensures the model is reset to its latest training state.

Notes
  • Calls the parent hook via super().on_test_epoch_end().
Source code in opensr_srgan/model/SRGAN.py
def on_test_epoch_end(self):
    """Hook executed at the end of each testing epoch.

    Restores the generator’s original (non-EMA) weights after testing.
    Ensures the model is reset to its latest training state.

    Notes:
        - Calls the parent hook via `super().on_test_epoch_end()`.
    """
    self._restore_generator_weights()
    super().on_test_epoch_end()

configure_optimizers()

Robust optimizers & schedulers for GANs (PL1 & PL2 compatible).

  • TTUR by default (D lr <= G lr)
  • Adam with GAN-friendly betas/eps
  • Exclude norm/affine/bias params from weight decay
  • Separate Plateau schedulers for G and D (with cooldown/min_lr)
  • Optional step-wise warmup for G (linear/cosine), no LR jump at handoff
  • Returned order is [D, G] to match your training_step expectations
Source code in opensr_srgan/model/SRGAN.py
def configure_optimizers(self):
    """
    Robust optimizers & schedulers for GANs (PL1 & PL2 compatible).

    - TTUR by default (D lr <= G lr)
    - Adam with GAN-friendly betas/eps
    - Exclude norm/affine/bias params from weight decay
    - Separate Plateau schedulers for G and D (with cooldown/min_lr)
    - Optional step-wise warmup for G (linear/cosine), no LR jump at handoff
    - Returned order is [D, G] to match your training_step expectations
    """
    import math
    import torch
    import torch.nn as nn
    from torch.optim import Adam
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    cfg_opt = self.config.Optimizers
    cfg_sch = self.config.Schedulers

    # ---------- helpers ----------
    def _split_wd_params(model):
        """Return two lists: params_with_wd, params_without_wd."""
        wd, no_wd = [], []
        norm_like = (
            nn.BatchNorm1d,
            nn.BatchNorm2d,
            nn.BatchNorm3d,
            nn.GroupNorm,
            nn.LayerNorm,
            nn.InstanceNorm1d,
            nn.InstanceNorm2d,
            nn.InstanceNorm3d,
        )
        for m in model.modules():
            for n, p in m.named_parameters(recurse=False):
                if not p.requires_grad:
                    continue
                if n.endswith("bias") or isinstance(m, norm_like):
                    no_wd.append(p)
                else:
                    wd.append(p)
        # catch any top-level params not in modules (rare)
        seen = set(map(id, wd + no_wd))
        for n, p in model.named_parameters():
            if p.requires_grad and id(p) not in seen:
                (no_wd if n.endswith("bias") else wd).append(p)
        return wd, no_wd

    def _adam(params, lr):
        # GAN-friendly defaults; tune in config if needed
        betas = getattr(cfg_opt, "betas", (0.0, 0.99))
        eps = getattr(cfg_opt, "eps", 1e-7)
        return Adam(params, lr=lr, betas=betas, eps=eps)

    # ---------- LRs (TTUR) ----------
    lr_g = float(getattr(cfg_opt, "optim_g_lr", 1e-4))
    lr_d = float(
        getattr(cfg_opt, "optim_d_lr", max(lr_g * 0.5, 1e-6))
    )  # default: D slower than G

    # weight decay (only on non-norm, non-bias)
    wd_g = float(getattr(cfg_opt, "weight_decay_g", 0.0))
    wd_d = float(getattr(cfg_opt, "weight_decay_d", 0.0))

    # ---------- build optimizers with clean param groups ----------
    g_wd, g_no = _split_wd_params(self.generator)
    d_wd, d_no = _split_wd_params(self.discriminator)

    optimizer_g = _adam(
        [
            {"params": g_wd, "weight_decay": wd_g},
            {"params": g_no, "weight_decay": 0.0},
        ],
        lr=lr_g,
    )
    optimizer_d = _adam(
        [
            {"params": d_wd, "weight_decay": wd_d},
            {"params": d_no, "weight_decay": 0.0},
        ],
        lr=lr_d,
    )

    # ---------- schedulers ----------
    # Use distinct monitors for clarity (recommend: log these in validation)
    monitor_g = getattr(
        cfg_sch, "metric_g", getattr(cfg_sch, "metric", "val_g_loss")
    )
    monitor_d = getattr(
        cfg_sch, "metric_d", getattr(cfg_sch, "metric", "val_d_loss")
    )

    sched_kwargs = dict(
        mode="min",
        factor=float(getattr(cfg_sch, "factor_g", 0.5)),
        patience=int(getattr(cfg_sch, "patience_g", 5)),
        threshold=float(getattr(cfg_sch, "threshold", 1e-3)),
        threshold_mode="rel",
        cooldown=int(getattr(cfg_sch, "cooldown", 0)),
        min_lr=float(getattr(cfg_sch, "min_lr", 1e-7)),
        verbose=bool(getattr(cfg_sch, "verbose", False)),
    )
    # D can have its own factor/patience; fall back to G’s if not set
    sched_kwargs_d = dict(sched_kwargs)
    sched_kwargs_d["factor"] = float(
        getattr(cfg_sch, "factor_d", sched_kwargs["factor"])
    )
    sched_kwargs_d["patience"] = int(
        getattr(cfg_sch, "patience_d", sched_kwargs["patience"])
    )

    scheduler_g = ReduceLROnPlateau(optimizer_g, **sched_kwargs)
    scheduler_d = ReduceLROnPlateau(optimizer_d, **sched_kwargs_d)

    sch_configs = [
        {
            "scheduler": scheduler_d,
            "monitor": monitor_d,
            "reduce_on_plateau": True,
            "interval": "epoch",
            "frequency": 1,
            "name": "plateau_d",
        },
        {
            "scheduler": scheduler_g,
            "monitor": monitor_g,
            "reduce_on_plateau": True,
            "interval": "epoch",
            "frequency": 1,
            "name": "plateau_g",
        },
    ]

    # ---------- optional warmup for G (step-wise, multiplicative) ----------
    warmup_steps = int(getattr(cfg_sch, "g_warmup_steps", 0))
    warmup_type = str(getattr(cfg_sch, "g_warmup_type", "none")).lower()
    if warmup_steps > 0 and warmup_type in {"linear", "cosine"}:

        def _g_warmup_lambda(step: int) -> float:
            if step >= warmup_steps:
                return 1.0
            t = (step + 1) / max(1, warmup_steps)
            return (
                t
                if warmup_type == "linear"
                else 0.5 * (1.0 - math.cos(math.pi * t))
            )

        warmup_g = torch.optim.lr_scheduler.LambdaLR(
            optimizer_g, lr_lambda=_g_warmup_lambda
        )
        # Runs every step; multiplies base LR so there is no jump at the end
        sch_configs.append(
            {
                "scheduler": warmup_g,
                "interval": "step",
                "frequency": 1,
                "name": "warmup_g",
            }
        )

    # Return order [D, G] to match your training_step
    return [optimizer_d, optimizer_g], sch_configs
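The hyperparameters read via `getattr` above map onto the `Optimizers` and `Schedulers` config sections roughly as follows. A hedged sketch showing only keys referenced in `configure_optimizers`; values are the code's fallbacks except where noted:

```python
from omegaconf import OmegaConf

opt_cfg = OmegaConf.create(
    {
        "Optimizers": {
            "optim_g_lr": 1e-4,         # generator LR
            "optim_d_lr": 5e-5,         # discriminator LR (TTUR: D <= G)
            "betas": [0.0, 0.99],       # Adam betas
            "eps": 1e-7,
            "weight_decay_g": 0.0,
            "weight_decay_d": 0.0,
        },
        "Schedulers": {
            "metric": "val_g_loss",     # shared plateau monitor; metric_g/metric_d override per network
            "factor_g": 0.5,
            "patience_g": 5,
            "factor_d": 0.5,            # falls back to factor_g if omitted
            "patience_d": 5,            # falls back to patience_g if omitted
            "g_warmup_steps": 2000,     # illustrative; 0 disables the step-wise warmup
            "g_warmup_type": "cosine",  # "linear" or "cosine"
        },
    }
)
```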

on_train_batch_start(batch, batch_idx)

Hook executed before each training batch.

Freezes or unfreezes discriminator parameters depending on the current training phase. During pretraining, the discriminator is frozen to allow the generator to learn reconstruction without adversarial pressure.

Parameters:

  • batch (Any, required): The current batch of training data.
  • batch_idx (int, required): Index of the current batch in the epoch.
Source code in opensr_srgan/model/SRGAN.py
def on_train_batch_start(
    self, batch, batch_idx
):  # called before each training batch
    """Hook executed before each training batch.

    Freezes or unfreezes discriminator parameters depending on the
    current training phase. During pretraining, the discriminator is
    frozen to allow the generator to learn reconstruction without
    adversarial pressure.

    Args:
        batch (Any): The current batch of training data.
        batch_idx (int): Index of the current batch in the epoch.
    """
    pre = self._pretrain_check()  # check if currently in pretraining phase
    for p in self.discriminator.parameters():  # loop over all discriminator params
        p.requires_grad = not pre  # freeze D during pretrain, unfreeze otherwise

on_train_batch_end(outputs, batch, batch_idx)

Hook executed after each training batch.

Logs the current learning rates for all active optimizers to the logger for monitoring and debugging purposes.

Parameters:

  • outputs (Any, required): Outputs returned by training_step.
  • batch (Any, required): The batch of data processed.
  • batch_idx (int, required): Index of the current batch in the epoch.
Source code in opensr_srgan/model/SRGAN.py
def on_train_batch_end(self, outputs, batch, batch_idx):
    """Hook executed after each training batch.

    Logs the current learning rates for all active optimizers to
    the logger for monitoring and debugging purposes.

    Args:
        outputs (Any): Outputs returned by `training_step`.
        batch (Any): The batch of data processed.
        batch_idx (int): Index of the current batch in the epoch.
    """
    self._log_lrs()  # log LR's on each batch end

on_fit_start()

Hook executed once at the beginning of model fitting.

Performs setup tasks that must occur before training starts:
  • Moves EMA weights to the correct device.
  • Prints a model summary (only from global rank 0 in DDP setups).

Notes
  • Calls super().on_fit_start() to preserve Lightning’s default behavior.
  • The model summary is only printed by the global zero process to avoid duplicated output in distributed training.
Source code in opensr_srgan/model/SRGAN.py
def on_fit_start(self):  # called once at the start of training
    """Hook executed once at the beginning of model fitting.

    Performs setup tasks that must occur before training starts:
    - Moves EMA weights to the correct device.
    - Prints a model summary (only from global rank 0 in DDP setups).

    Notes:
        - Calls `super().on_fit_start()` to preserve Lightning’s default behavior.
        - The model summary is only printed by the global zero process
        to avoid duplicated output in distributed training.
    """
    super().on_fit_start()
    if self.ema is not None and self.ema.device is None:  # move ema weights
        self.ema.to(self.device)
    from opensr_srgan.utils.gpu_rank import _is_global_zero

    if _is_global_zero():
        print_model_summary(self)  # print model summary to console

on_save_checkpoint(checkpoint)

Augment the checkpoint with EMA state, if available.

Adds the EMA buffer/metadata to the checkpoint so that EMA can be restored upon load.

Parameters:

  • checkpoint (dict, required): Mutable checkpoint dictionary provided by Lightning.
Source code in opensr_srgan/model/SRGAN.py
def on_save_checkpoint(self, checkpoint: dict) -> None:
    """Augment the checkpoint with EMA state, if available.

    Adds the EMA buffer/metadata to the checkpoint so that EMA can be restored upon load.

    Args:
        checkpoint (dict): Mutable checkpoint dictionary provided by Lightning.
    """
    super().on_save_checkpoint(checkpoint)
    if self.ema is not None:
        checkpoint["ema_state"] = self.ema.state_dict()

on_load_checkpoint(checkpoint)

Restore EMA state from a checkpoint, if present.

Parameters:

  • checkpoint (dict, required): Checkpoint dictionary provided by Lightning containing model state and optional "ema_state" entry.
Source code in opensr_srgan/model/SRGAN.py
def on_load_checkpoint(self, checkpoint: dict) -> None:
    """Restore EMA state from a checkpoint, if present.

    Args:
        checkpoint (dict): Checkpoint dictionary provided by Lightning containing
            model state and optional `"ema_state"` entry.
    """
    super().on_load_checkpoint(checkpoint)
    if self.ema is not None and "ema_state" in checkpoint:
        self.ema.load_state_dict(checkpoint["ema_state"])

load_from_checkpoint(ckpt_path)

Load model weights from a PyTorch Lightning checkpoint file.

Loads the state_dict from the given checkpoint and maps it to the current device.

Parameters:

  • ckpt_path (str | Path, required): Path to the .ckpt file saved by Lightning.

Raises:

  • FileNotFoundError: If the checkpoint path does not exist.
  • KeyError: If the checkpoint does not contain a 'state_dict' entry.

Source code in opensr_srgan/model/SRGAN.py
def load_from_checkpoint(self, ckpt_path) -> None:
    """Load model weights from a PyTorch Lightning checkpoint file.

    Loads the `state_dict` from the given checkpoint and maps it to the current device.

    Args:
        ckpt_path (str | pathlib.Path): Path to the `.ckpt` file saved by Lightning.

    Raises:
        FileNotFoundError: If the checkpoint path does not exist.
        KeyError: If the checkpoint does not contain a `'state_dict'` entry.
    """
    # load ckpt
    ckpt = torch.load(ckpt_path, map_location=self.device)
    self.load_state_dict(ckpt["state_dict"])
    print(f"Loaded checkpoint from {ckpt_path}")

Building blocks

Reusable building blocks for SRGAN family models.

ConvolutionalBlock

Bases: Module

A convolutional block comprised of Conv → (BN) → (Activation).

Source code in opensr_srgan/model/model_blocks/__init__.py
class ConvolutionalBlock(nn.Module):
    """A convolutional block comprised of Conv → (BN) → (Activation)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        batch_norm: bool = False,
        activation: str | None = None,
    ) -> None:
        super().__init__()

        act = activation.lower() if activation is not None else None
        if act is not None:
            if act not in {"prelu", "leakyrelu", "tanh"}:
                raise AssertionError("activation must be one of {'prelu', 'leakyrelu', 'tanh'}")

        layers: list[nn.Module] = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=kernel_size // 2,
            )
        ]

        if batch_norm:
            layers.append(nn.BatchNorm2d(num_features=out_channels))

        if act == "prelu":
            layers.append(nn.PReLU())
        elif act == "leakyrelu":
            layers.append(nn.LeakyReLU(0.2))
        elif act == "tanh":
            layers.append(nn.Tanh())

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv_block(x)
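A quick usage sketch (padding of kernel_size // 2 keeps the spatial size unchanged):

```python
import torch
from opensr_srgan.model.model_blocks import ConvolutionalBlock

# Conv -> BatchNorm -> LeakyReLU
block = ConvolutionalBlock(
    in_channels=3, out_channels=64, kernel_size=3,
    stride=1, batch_norm=True, activation="leakyrelu",
)
x = torch.rand(1, 3, 32, 32)
print(block(x).shape)  # torch.Size([1, 64, 32, 32])
```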

SubPixelConvolutionalBlock

Bases: Module

Conv → PixelShuffle → PReLU upsampling block.

Source code in opensr_srgan/model/model_blocks/__init__.py
class SubPixelConvolutionalBlock(nn.Module):
    """Conv → PixelShuffle → PReLU upsampling block."""

    def __init__(
        self,
        kernel_size: int = 3,
        n_channels: int = 64,
        scaling_factor: int = 2,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels=n_channels,
            out_channels=n_channels * (scaling_factor**2),
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=scaling_factor)
        self.prelu = nn.PReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x
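The convolution expands the channel count by scaling_factor², and PixelShuffle then trades those channels for spatial resolution:

```python
import torch
from opensr_srgan.model.model_blocks import SubPixelConvolutionalBlock

up = SubPixelConvolutionalBlock(kernel_size=3, n_channels=64, scaling_factor=2)
x = torch.rand(1, 64, 24, 24)
print(up(x).shape)  # torch.Size([1, 64, 48, 48]) -- channels preserved, H and W doubled
```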

ResidualBlock

Bases: Module

BN-enabled residual block used in the original SRResNet.

Source code in opensr_srgan/model/model_blocks/__init__.py
class ResidualBlock(nn.Module):
    """BN-enabled residual block used in the original SRResNet."""

    def __init__(self, kernel_size: int = 3, n_channels: int = 64) -> None:
        super().__init__()
        self.conv_block1 = ConvolutionalBlock(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=kernel_size,
            batch_norm=True,
            activation="PReLu",
        )
        self.conv_block2 = ConvolutionalBlock(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=kernel_size,
            batch_norm=True,
            activation=None,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        return x + residual

ResidualBlockNoBN

Bases: Module

Residual block variant without batch norm and with residual scaling.

Source code in opensr_srgan/model/model_blocks/__init__.py
class ResidualBlockNoBN(nn.Module):
    """Residual block variant without batch norm and with residual scaling."""

    def __init__(
        self,
        n_channels: int = 64,
        kernel_size: int = 3,
        res_scale: float = 0.2,
    ) -> None:
        super().__init__()
        padding = kernel_size // 2
        self.body = nn.Sequential(
            nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding),
            nn.PReLU(),
            nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding),
        )
        self.res_scale = res_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.res_scale * self.body(x)

RCAB

Bases: Module

Residual Channel Attention Block (no BN).

Source code in opensr_srgan/model/model_blocks/__init__.py
class RCAB(nn.Module):
    """Residual Channel Attention Block (no BN)."""

    def __init__(
        self,
        n_channels: int = 64,
        kernel_size: int = 3,
        reduction: int = 16,
        res_scale: float = 0.2,
    ) -> None:
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
        self.act = nn.PReLU()
        self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(n_channels, n_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_channels // reduction, n_channels, 1),
            nn.Sigmoid(),
        )
        self.res_scale = res_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv1(x)
        y = self.act(y)
        y = self.conv2(y)
        w = self.se(y)
        return x + self.res_scale * (y * w)

DenseBlock5

Bases: Module

ESRGAN-style dense block with five convolutions.

Source code in opensr_srgan/model/model_blocks/__init__.py
class DenseBlock5(nn.Module):
    """ESRGAN-style dense block with five convolutions."""

    def __init__(self, n_features: int = 64, growth_channels: int = 32, kernel_size: int = 3) -> None:
        super().__init__()
        padding = kernel_size // 2
        self.act = nn.LeakyReLU(0.2, inplace=True)
        self.c1 = nn.Conv2d(n_features, growth_channels, kernel_size, padding=padding)
        self.c2 = nn.Conv2d(n_features + growth_channels, growth_channels, kernel_size, padding=padding)
        self.c3 = nn.Conv2d(n_features + 2 * growth_channels, growth_channels, kernel_size, padding=padding)
        self.c4 = nn.Conv2d(n_features + 3 * growth_channels, growth_channels, kernel_size, padding=padding)
        self.c5 = nn.Conv2d(n_features + 4 * growth_channels, n_features, kernel_size, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.act(self.c1(x))
        x2 = self.act(self.c2(torch.cat([x, x1], dim=1)))
        x3 = self.act(self.c3(torch.cat([x, x1, x2], dim=1)))
        x4 = self.act(self.c4(torch.cat([x, x1, x2, x3], dim=1)))
        x5 = self.c5(torch.cat([x, x1, x2, x3, x4], dim=1))
        return x5

RRDB

Bases: Module

Residual-in-Residual Dense Block.

Source code in opensr_srgan/model/model_blocks/__init__.py
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block."""

    def __init__(self, n_features: int = 64, growth_channels: int = 32, res_scale: float = 0.2) -> None:
        super().__init__()
        self.db1 = DenseBlock5(n_features, growth_channels)
        self.db2 = DenseBlock5(n_features, growth_channels)
        self.db3 = DenseBlock5(n_features, growth_channels)
        self.res_scale = res_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.db1(x)
        y = self.db2(y)
        y = self.db3(y)
        return x + self.res_scale * y

LKA

Bases: Module

Lightweight Large-Kernel Attention module.

Source code in opensr_srgan/model/model_blocks/__init__.py
class LKA(nn.Module):
    """Lightweight Large-Kernel Attention module."""

    def __init__(self, n_channels: int = 64) -> None:
        super().__init__()
        self.dw5 = nn.Conv2d(n_channels, n_channels, 5, padding=2, groups=n_channels)
        self.dw7d = nn.Conv2d(n_channels, n_channels, 7, padding=9, dilation=3, groups=n_channels)
        self.pw = nn.Conv2d(n_channels, n_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn = self.dw5(x)
        attn = self.dw7d(attn)
        attn = self.pw(attn)
        return x * torch.sigmoid(attn)

LKAResBlock

Bases: Module

Residual block incorporating Large-Kernel Attention.

Source code in opensr_srgan/model/model_blocks/__init__.py
class LKAResBlock(nn.Module):
    """Residual block incorporating Large-Kernel Attention."""

    def __init__(self, n_channels: int = 64, kernel_size: int = 3, res_scale: float = 0.2) -> None:
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
        self.act = nn.PReLU()
        self.lka = LKA(n_channels)
        self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
        self.res_scale = res_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv1(x)
        y = self.act(y)
        y = self.lka(y)
        y = self.conv2(y)
        return x + self.res_scale * y

ExponentialMovingAverage

Maintain an exponential moving average (EMA) of a model’s parameters and buffers.

This class provides a self-contained implementation of parameter smoothing via EMA, commonly used to stabilize training and improve generalization in deep generative models. It tracks both model parameters and registered buffers (e.g., batch norm statistics), maintains a decayed running average, and allows temporary swapping of model weights for evaluation or checkpointing.

EMA is updated with each training step:

shadow = decay * shadow + (1 - decay) * parameter
where decay is typically close to 1.0 (e.g., 0.999–0.9999).

The class includes
  • On-the-fly registration of parameters/buffers from an existing model.
  • Safe apply/restore methods to temporarily replace model weights.
  • Device management for multi-GPU and CPU environments.
  • Full checkpoint serialization support.

Parameters:

  • model (Module, required): The model whose parameters are to be tracked.
  • decay (float, default 0.999): Smoothing coefficient (0 ≤ decay ≤ 1). Higher values make EMA updates slower.
  • use_num_updates (bool, default True): Whether to adapt decay during early updates (useful for warm-up).
  • device (str | device | None, default None): Optional target device for storing EMA parameters (e.g., "cpu" for offloading).

Attributes:

  • decay (float): EMA smoothing coefficient.
  • num_updates (int | None): Counter of EMA updates, used to adapt decay.
  • device (device | None): Device where EMA tensors are stored.
  • shadow_params (Dict[str, Tensor]): Smoothed parameter tensors.
  • shadow_buffers (Dict[str, Tensor]): Smoothed buffer tensors.
  • collected_params (Dict[str, Tensor]): Temporary cache for original parameters during apply/restore operations.
  • collected_buffers (Dict[str, Tensor]): Temporary cache for original buffers during apply/restore operations.

Source code in opensr_srgan/model/model_blocks/EMA.py
class ExponentialMovingAverage:
    """Maintain an exponential moving average (EMA) of a model’s parameters and buffers.

    This class provides a self-contained implementation of parameter smoothing
    via EMA, commonly used to stabilize training and improve generalization in
    deep generative models. It tracks both model parameters and registered
    buffers (e.g., batch norm statistics), maintains a decayed running average,
    and allows temporary swapping of model weights for evaluation or checkpointing.

    EMA is updated with each training step:
    ```
    shadow = decay * shadow + (1 - decay) * parameter
    ```
    where ``decay`` is typically close to 1.0 (e.g., 0.999–0.9999).

    The class includes:
        - On-the-fly registration of parameters/buffers from an existing model.
        - Safe apply/restore methods to temporarily replace model weights.
        - Device management for multi-GPU and CPU environments.
        - Full checkpoint serialization support.

    Args:
        model (nn.Module): The model whose parameters are to be tracked.
        decay (float, optional): Smoothing coefficient (0 ≤ decay ≤ 1).
            Higher values make EMA updates slower. Default is 0.999.
        use_num_updates (bool, optional): Whether to adapt decay during early
            updates (useful for warm-up). Default is True.
        device (str | torch.device | None, optional): Optional target device for
            storing EMA parameters (e.g., "cpu" for offloading). Default is None.

    Attributes:
        decay (float): EMA smoothing coefficient.
        num_updates (int | None): Counter of EMA updates, used to adapt decay.
        device (torch.device | None): Device where EMA tensors are stored.
        shadow_params (Dict[str, torch.Tensor]): Smoothed parameter tensors.
        shadow_buffers (Dict[str, torch.Tensor]): Smoothed buffer tensors.
        collected_params (Dict[str, torch.Tensor]): Temporary cache for original
            parameters during apply/restore operations.
        collected_buffers (Dict[str, torch.Tensor]): Temporary cache for original
            buffers during apply/restore operations.
    """

    def __init__(
        self,
        model: nn.Module,
        decay: float = 0.999,
        *,
        use_num_updates: bool = True,
        device: str | torch.device | None = None,
    ) -> None:
        if not 0.0 <= decay <= 1.0:
            raise ValueError("decay must be between 0 and 1 (inclusive)")

        self.decay = float(decay)
        self.num_updates = 0 if use_num_updates else None
        self.device = torch.device(device) if device is not None else None

        self.shadow_params: Dict[str, torch.Tensor] = {}
        self.shadow_buffers: Dict[str, torch.Tensor] = {}
        self.collected_params: Dict[str, torch.Tensor] = {}
        self.collected_buffers: Dict[str, torch.Tensor] = {}

        self._register(model)

    def _register(self, model: nn.Module) -> None:
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            shadow = param.detach().clone()
            if self.device is not None:
                shadow = shadow.to(self.device)
            self.shadow_params[name] = shadow

        for name, buffer in model.named_buffers():
            shadow = buffer.detach().clone()
            if self.device is not None:
                shadow = shadow.to(self.device)
            self.shadow_buffers[name] = shadow

    def update(self, model: nn.Module) -> None:
        """Update the EMA weights using the latest parameters from ``model``.

        Performs an in-place exponential moving average update on all
        trainable parameters and buffers tracked in ``shadow_params`` and
        ``shadow_buffers``. If ``use_num_updates=True``, adapts the decay
        coefficient during early steps for smoother warm-up.

        Args:
            model (nn.Module): Model whose parameters and buffers are used to
                update the EMA state.

        Notes:
            - Dynamically adds new parameters or buffers if they were not
              present during initialization.
            - Operates in ``torch.no_grad()`` context to avoid gradient tracking.
        """
        if self.num_updates is not None:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        else:
            decay = self.decay

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue

                if name not in self.shadow_params:
                    # Parameter may have been added dynamically (rare, but safeguard)
                    shadow = param.detach().clone()
                    if self.device is not None:
                        shadow = shadow.to(self.device)
                    self.shadow_params[name] = shadow

                shadow_param = self.shadow_params[name]
                param_data = param.detach()
                if param_data.device != shadow_param.device:
                    param_data = param_data.to(shadow_param.device)

                shadow_param.lerp_(param_data, one_minus_decay)

            for name, buffer in model.named_buffers():
                if name not in self.shadow_buffers:
                    shadow = buffer.detach().clone()
                    if self.device is not None:
                        shadow = shadow.to(self.device)
                    self.shadow_buffers[name] = shadow

                shadow_buffer = self.shadow_buffers[name]
                buffer_data = buffer.detach()
                if buffer_data.device != shadow_buffer.device:
                    buffer_data = buffer_data.to(shadow_buffer.device)
                shadow_buffer.copy_(buffer_data)

    def apply_to(self, model: nn.Module) -> None:
        """Replace model parameters with EMA-smoothed versions (in-place).

        Temporarily swaps the current model parameters and buffers with their
        EMA counterparts for evaluation or checkpoint export. The original
        tensors are cached internally and can be restored later with
        :meth:`restore`.

        Args:
            model (nn.Module): Model whose parameters will be replaced.

        Raises:
            RuntimeError: If EMA weights are already applied and not yet restored.
        """
        if self.collected_params or self.collected_buffers:
            raise RuntimeError("EMA weights already applied; call restore() before reapplying.")

        for name, param in model.named_parameters():
            if not param.requires_grad or name not in self.shadow_params:
                continue
            self.collected_params[name] = param.detach().clone()
            param.data.copy_(self.shadow_params[name].to(param.device))

        for name, buffer in model.named_buffers():
            if name not in self.shadow_buffers:
                continue
            self.collected_buffers[name] = buffer.detach().clone()
            buffer.data.copy_(self.shadow_buffers[name].to(buffer.device))

    def restore(self, model: nn.Module) -> None:
        """Restore the model’s original parameters after an EMA swap.

        Reverts the parameter and buffer changes made by :meth:`apply_to`
        by restoring the cached tensors. This is a no-op if EMA weights
        were never applied.

        Args:
            model (nn.Module): Model whose parameters will be restored.
        """
        for name, param in model.named_parameters():
            cached = self.collected_params.pop(name, None)
            if cached is None:
                continue
            param.data.copy_(cached.to(param.device))

        for name, buffer in model.named_buffers():
            cached = self.collected_buffers.pop(name, None)
            if cached is None:
                continue
            buffer.data.copy_(cached.to(buffer.device))

    @contextmanager
    def average_parameters(self, model: nn.Module) -> Iterator[None]:
        """Context manager to temporarily apply EMA weights to ``model``.

        This convenience wrapper allows for automatic restoration after use.
        Example:
        ```python
        with ema.average_parameters(model):
            validate(model)
        ```

        Args:
            model (nn.Module): The model to temporarily replace parameters for.

        Yields:
            None: Executes the body of the context with EMA weights applied.
        """
        self.apply_to(model)
        try:
            yield
        finally:
            self.restore(model)

    def to(self, device: str | torch.device) -> None:
        """Move EMA-tracked tensors to a target device.

        Transfers all shadow parameters and buffers to the specified device,
        updating the internal ``self.device`` reference.

        Args:
            device (str | torch.device): Target device (e.g., "cuda", "cpu").
        """
        target_device = torch.device(device)
        for name, tensor in list(self.shadow_params.items()):
            self.shadow_params[name] = tensor.to(target_device)
        for name, tensor in list(self.shadow_buffers.items()):
            self.shadow_buffers[name] = tensor.to(target_device)
        self.device = target_device

    def state_dict(self) -> Dict[str, object]:
        """Return a serializable state dictionary for checkpointing.

        Packages all relevant EMA state into a plain dictionary, compatible
        with PyTorch’s standard checkpoint format. Converts all tensors to CPU
        for safe serialization.

        Returns:
            Dict[str, object]: Dictionary containing EMA decay, update count,
            device info, and copies of shadow parameters/buffers.
        """
        return {
            "decay": self.decay,
            "num_updates": self.num_updates,
            "device": str(self.device) if self.device is not None else None,
            "shadow_params": {k: v.detach().cpu() for k, v in self.shadow_params.items()},
            "shadow_buffers": {k: v.detach().cpu() for k, v in self.shadow_buffers.items()},
        }

    def load_state_dict(self, state_dict: Dict[str, object]) -> None:
        """Load EMA state from a previously saved checkpoint.

        Reconstructs the EMA tracking state from a saved dictionary, restoring
        all tracked parameters, buffers, and metadata such as decay, device,
        and update count.

        Args:
            state_dict (Dict[str, object]): Dictionary as produced by
                :meth:`state_dict`.

        Notes:
            - Tensors are moved to the current or saved device automatically.
            - Clears existing collected (applied) caches to avoid stale state.
        """
        self.decay = float(state_dict["decay"])
        self.num_updates = state_dict["num_updates"]
        device_str = state_dict.get("device", None)
        self.device = torch.device(device_str) if device_str is not None else None

        self.shadow_params = {
            name: tensor.clone().to(self.device) if self.device is not None else tensor.clone()
            for name, tensor in state_dict.get("shadow_params", {}).items()
        }
        self.shadow_buffers = {
            name: tensor.clone().to(self.device) if self.device is not None else tensor.clone()
            for name, tensor in state_dict.get("shadow_buffers", {}).items()
        }

        self.collected_params = {}
        self.collected_buffers = {}
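
A minimal usage sketch, assuming the class is importable from the module path implied by the source file above (`opensr_srgan/model/model_blocks/EMA.py`); the toy module and training loop are purely illustrative, not the Lightning module's actual hook wiring:

```python
import torch
import torch.nn as nn

# Assumed import path, derived from the source file location above.
from opensr_srgan.model.model_blocks.EMA import ExponentialMovingAverage

generator = nn.Conv2d(3, 3, 3, padding=1)              # stand-in for any nn.Module
ema = ExponentialMovingAverage(generator, decay=0.999)
optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)

for _ in range(10):                                    # illustrative training loop
    x = torch.randn(2, 3, 32, 32)
    loss = generator(x).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(generator)                              # track smoothed weights after each optimizer step

with ema.average_parameters(generator):                # evaluate with EMA weights ...
    _ = generator(torch.randn(1, 3, 32, 32))
# ... the original weights are restored automatically on exit
```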

update(model)

Update the EMA weights using the latest parameters from model.

Performs an in-place exponential moving average update on all trainable parameters and buffers tracked in shadow_params and shadow_buffers. If use_num_updates=True, adapts the decay coefficient during early steps for smoother warm-up.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | Model whose parameters and buffers are used to update the EMA state. | *required* |
Notes
  • Dynamically adds new parameters or buffers if they were not present during initialization.
  • Operates in torch.no_grad() context to avoid gradient tracking.
Source code in opensr_srgan/model/model_blocks/EMA.py
def update(self, model: nn.Module) -> None:
    """Update the EMA weights using the latest parameters from ``model``.

    Performs an in-place exponential moving average update on all
    trainable parameters and buffers tracked in ``shadow_params`` and
    ``shadow_buffers``. If ``use_num_updates=True``, adapts the decay
    coefficient during early steps for smoother warm-up.

    Args:
        model (nn.Module): Model whose parameters and buffers are used to
            update the EMA state.

    Notes:
        - Dynamically adds new parameters or buffers if they were not
          present during initialization.
        - Operates in ``torch.no_grad()`` context to avoid gradient tracking.
    """
    if self.num_updates is not None:
        self.num_updates += 1
        decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
    else:
        decay = self.decay

    one_minus_decay = 1.0 - decay

    with torch.no_grad():
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue

            if name not in self.shadow_params:
                # Parameter may have been added dynamically (rare, but safeguard)
                shadow = param.detach().clone()
                if self.device is not None:
                    shadow = shadow.to(self.device)
                self.shadow_params[name] = shadow

            shadow_param = self.shadow_params[name]
            param_data = param.detach()
            if param_data.device != shadow_param.device:
                param_data = param_data.to(shadow_param.device)

            shadow_param.lerp_(param_data, one_minus_decay)

        for name, buffer in model.named_buffers():
            if name not in self.shadow_buffers:
                shadow = buffer.detach().clone()
                if self.device is not None:
                    shadow = shadow.to(self.device)
                self.shadow_buffers[name] = shadow

            shadow_buffer = self.shadow_buffers[name]
            buffer_data = buffer.detach()
            if buffer_data.device != shadow_buffer.device:
                buffer_data = buffer_data.to(shadow_buffer.device)
            shadow_buffer.copy_(buffer_data)
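
With `use_num_updates=True`, the effective decay used above is `min(decay, (1 + n) / (10 + n))`, so the EMA tracks the raw weights closely at first and only gradually approaches the configured decay. A quick illustration (plain Python, no project code required):

```python
decay = 0.999
for n in (1, 10, 100, 1000, 10_000):
    effective = min(decay, (1 + n) / (10 + n))
    print(n, round(effective, 4))
# 1 -> 0.1818, 10 -> 0.55, 100 -> 0.9182, 1000 -> 0.9911, 10000 -> 0.999
```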

apply_to(model)

Replace model parameters with EMA-smoothed versions (in-place).

Temporarily swaps the current model parameters and buffers with their EMA counterparts for evaluation or checkpoint export. The original tensors are cached internally and can be restored later with restore().

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | Model whose parameters will be replaced. | *required* |

Raises:

| Type | Description |
| --- | --- |
| `RuntimeError` | If EMA weights are already applied and not yet restored. |

Source code in opensr_srgan/model/model_blocks/EMA.py
def apply_to(self, model: nn.Module) -> None:
    """Replace model parameters with EMA-smoothed versions (in-place).

    Temporarily swaps the current model parameters and buffers with their
    EMA counterparts for evaluation or checkpoint export. The original
    tensors are cached internally and can be restored later with
    :meth:`restore`.

    Args:
        model (nn.Module): Model whose parameters will be replaced.

    Raises:
        RuntimeError: If EMA weights are already applied and not yet restored.
    """
    if self.collected_params or self.collected_buffers:
        raise RuntimeError("EMA weights already applied; call restore() before reapplying.")

    for name, param in model.named_parameters():
        if not param.requires_grad or name not in self.shadow_params:
            continue
        self.collected_params[name] = param.detach().clone()
        param.data.copy_(self.shadow_params[name].to(param.device))

    for name, buffer in model.named_buffers():
        if name not in self.shadow_buffers:
            continue
        self.collected_buffers[name] = buffer.detach().clone()
        buffer.data.copy_(self.shadow_buffers[name].to(buffer.device))

restore(model)

Restore the model’s original parameters after an EMA swap.

Reverts the parameter and buffer changes made by apply_to() by restoring the cached tensors. This is a no-op if EMA weights were never applied.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | Model whose parameters will be restored. | *required* |
Source code in opensr_srgan/model/model_blocks/EMA.py
def restore(self, model: nn.Module) -> None:
    """Restore the model’s original parameters after an EMA swap.

    Reverts the parameter and buffer changes made by :meth:`apply_to`
    by restoring the cached tensors. This is a no-op if EMA weights
    were never applied.

    Args:
        model (nn.Module): Model whose parameters will be restored.
    """
    for name, param in model.named_parameters():
        cached = self.collected_params.pop(name, None)
        if cached is None:
            continue
        param.data.copy_(cached.to(param.device))

    for name, buffer in model.named_buffers():
        cached = self.collected_buffers.pop(name, None)
        if cached is None:
            continue
        buffer.data.copy_(cached.to(buffer.device))

average_parameters(model)

Context manager to temporarily apply EMA weights to model.

This convenience wrapper allows for automatic restoration after use. Example:

with ema.average_parameters(model):
    validate(model)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The model to temporarily replace parameters for. | *required* |

Yields:

| Type | Description |
| --- | --- |
| `None` | Executes the body of the context with EMA weights applied. |

Source code in opensr_srgan/model/model_blocks/EMA.py
@contextmanager
def average_parameters(self, model: nn.Module) -> Iterator[None]:
    """Context manager to temporarily apply EMA weights to ``model``.

    This convenience wrapper allows for automatic restoration after use.
    Example:
    ```python
    with ema.average_parameters(model):
        validate(model)
    ```

    Args:
        model (nn.Module): The model to temporarily replace parameters for.

    Yields:
        None: Executes the body of the context with EMA weights applied.
    """
    self.apply_to(model)
    try:
        yield
    finally:
        self.restore(model)

to(device)

Move EMA-tracked tensors to a target device.

Transfers all shadow parameters and buffers to the specified device, updating the internal self.device reference.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `device` | `str \| device` | Target device (e.g., `"cuda"`, `"cpu"`). | *required* |
Source code in opensr_srgan/model/model_blocks/EMA.py
def to(self, device: str | torch.device) -> None:
    """Move EMA-tracked tensors to a target device.

    Transfers all shadow parameters and buffers to the specified device,
    updating the internal ``self.device`` reference.

    Args:
        device (str | torch.device): Target device (e.g., "cuda", "cpu").
    """
    target_device = torch.device(device)
    for name, tensor in list(self.shadow_params.items()):
        self.shadow_params[name] = tensor.to(target_device)
    for name, tensor in list(self.shadow_buffers.items()):
        self.shadow_buffers[name] = tensor.to(target_device)
    self.device = target_device

state_dict()

Return a serializable state dictionary for checkpointing.

Packages all relevant EMA state into a plain dictionary, compatible with PyTorch’s standard checkpoint format. Converts all tensors to CPU for safe serialization.

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, object]` | Dictionary containing EMA decay, update count, device info, and copies of shadow parameters/buffers. |

Source code in opensr_srgan/model/model_blocks/EMA.py
def state_dict(self) -> Dict[str, object]:
    """Return a serializable state dictionary for checkpointing.

    Packages all relevant EMA state into a plain dictionary, compatible
    with PyTorch’s standard checkpoint format. Converts all tensors to CPU
    for safe serialization.

    Returns:
        Dict[str, object]: Dictionary containing EMA decay, update count,
        device info, and copies of shadow parameters/buffers.
    """
    return {
        "decay": self.decay,
        "num_updates": self.num_updates,
        "device": str(self.device) if self.device is not None else None,
        "shadow_params": {k: v.detach().cpu() for k, v in self.shadow_params.items()},
        "shadow_buffers": {k: v.detach().cpu() for k, v in self.shadow_buffers.items()},
    }

load_state_dict(state_dict)

Load EMA state from a previously saved checkpoint.

Reconstructs the EMA tracking state from a saved dictionary, restoring all tracked parameters, buffers, and metadata such as decay, device, and update count.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state_dict` | `Dict[str, object]` | Dictionary as produced by `state_dict()`. | *required* |
Notes
  • Tensors are moved to the current or saved device automatically.
  • Clears existing collected (applied) caches to avoid stale state.
Source code in opensr_srgan/model/model_blocks/EMA.py
def load_state_dict(self, state_dict: Dict[str, object]) -> None:
    """Load EMA state from a previously saved checkpoint.

    Reconstructs the EMA tracking state from a saved dictionary, restoring
    all tracked parameters, buffers, and metadata such as decay, device,
    and update count.

    Args:
        state_dict (Dict[str, object]): Dictionary as produced by
            :meth:`state_dict`.

    Notes:
        - Tensors are moved to the current or saved device automatically.
        - Clears existing collected (applied) caches to avoid stale state.
    """
    self.decay = float(state_dict["decay"])
    self.num_updates = state_dict["num_updates"]
    device_str = state_dict.get("device", None)
    self.device = torch.device(device_str) if device_str is not None else None

    self.shadow_params = {
        name: tensor.clone().to(self.device) if self.device is not None else tensor.clone()
        for name, tensor in state_dict.get("shadow_params", {}).items()
    }
    self.shadow_buffers = {
        name: tensor.clone().to(self.device) if self.device is not None else tensor.clone()
        for name, tensor in state_dict.get("shadow_buffers", {}).items()
    }

    self.collected_params = {}
    self.collected_buffers = {}
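
A checkpointing sketch using the two methods above; the file name and the stand-in module are illustrative assumptions, and only the `state_dict()`/`load_state_dict()` calls come from this class:

```python
import torch
import torch.nn as nn

# Assumed import path, derived from the source file location above.
from opensr_srgan.model.model_blocks.EMA import ExponentialMovingAverage

model = nn.Conv2d(3, 3, 3, padding=1)                  # stand-in for the generator
ema = ExponentialMovingAverage(model, decay=0.999)

# Save model weights and EMA shadow state side by side (illustrative file name).
torch.save({"model": model.state_dict(), "ema": ema.state_dict()}, "checkpoint.pt")

# Later: rebuild both objects and restore the EMA shadow weights.
ckpt = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])
ema = ExponentialMovingAverage(model, decay=0.999)
ema.load_state_dict(ckpt["ema"])
```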

make_upsampler(n_channels, scale)

Create a pixel-shuffle upsampler matching the flexible generator implementation.

Source code in opensr_srgan/model/model_blocks/__init__.py
def make_upsampler(n_channels: int, scale: int) -> nn.Sequential:
    """Create a pixel-shuffle upsampler matching the flexible generator implementation."""

    stages: list[nn.Module] = []
    for _ in range(int(math.log2(scale))):
        stages.extend(
            [
                nn.Conv2d(n_channels, n_channels * 4, 3, padding=1),
                nn.PixelShuffle(2),
                nn.PReLU(),
            ]
        )
    return nn.Sequential(*stages)
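
A quick shape check of the returned module, assuming `make_upsampler` is importable from `opensr_srgan.model.model_blocks` (the source file shown above):

```python
import torch

# Assumed import path, derived from the source file location above.
from opensr_srgan.model.model_blocks import make_upsampler

up = make_upsampler(n_channels=64, scale=4)   # two ×2 pixel-shuffle stages
x = torch.randn(1, 64, 32, 32)
y = up(x)
print(y.shape)  # torch.Size([1, 64, 128, 128]) — channels preserved, spatial size ×4
```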

Generator architectures

Generator architectures for SRGAN.

SRResNet

Bases: Module

Canonical SRResNet generator for single-image super-resolution.

Implements the SRResNet backbone with:

1. Large-kernel stem conv (+PReLU)
2. N × residual blocks (small-kernel, no upsampling)
3. Conv for global residual fusion
4. log2(scaling_factor) × SubPixelConvolutionalBlock (×2 each)
5. Large-kernel output conv (+Tanh)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `in_channels` | `int` | Input channels (e.g., 3 for RGB). | `3` |
| `large_kernel_size` | `int` | Kernel size for head/tail convolutions. | `9` |
| `small_kernel_size` | `int` | Kernel size used inside residual/upsampling blocks. | `3` |
| `n_channels` | `int` | Feature width across the network. | `64` |
| `n_blocks` | `int` | Number of residual blocks in the trunk. | `16` |
| `scaling_factor` | `int` | Upscale factor (must be one of {2, 4, 8}). | `4` |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | Super-resolved image of shape `(B, in_channels, H*scale, W*scale)`. |

Notes
  • The network uses a global skip connection around the residual stack.
  • Upsampling is performed by PixelShuffle via sub-pixel convolution blocks.
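
A usage sketch, assuming the import path implied by the source file referenced below:

```python
import torch

# Assumed import path, derived from the source file location below.
from opensr_srgan.model.generators.srresnet import SRResNet

net = SRResNet(in_channels=3, n_channels=64, n_blocks=16, scaling_factor=4)
lr = torch.randn(1, 3, 64, 64)
sr = net(lr)
print(sr.shape)  # torch.Size([1, 3, 256, 256]) — ×4 upscaling, Tanh-bounded output
```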
Source code in opensr_srgan/model/generators/srresnet.py
class SRResNet(nn.Module):
    """Canonical SRResNet generator for single-image super-resolution.

    Implements the SRResNet backbone with:
        1) Large-kernel stem conv (+PReLU)
        2) N × residual blocks (small-kernel, no upsampling)
        3) Conv for global residual fusion
        4) Log2(scaling_factor) × SubPixelConvolutionalBlock (×2 each)
        5) Large-kernel output conv (+Tanh)

    Args:
        in_channels (int): Input channels (e.g., 3 for RGB).
        large_kernel_size (int): Kernel size for head/tail convolutions.
        small_kernel_size (int): Kernel size used inside residual/upsampling blocks.
        n_channels (int): Feature width across the network.
        n_blocks (int): Number of residual blocks in the trunk.
        scaling_factor (int): Upscale factor (must be one of {2, 4, 8}).

    Returns:
        torch.Tensor: Super-resolved image of shape (B, in_channels, H*scale, W*scale).

    Notes:
        - The network uses a global skip connection around the residual stack.
        - Upsampling is performed by PixelShuffle via sub-pixel convolution blocks.
    """

    def __init__(
        self,
        in_channels: int = 3,
        large_kernel_size: int = 9,
        small_kernel_size: int = 3,
        n_channels: int = 64,
        n_blocks: int = 16,
        scaling_factor: int = 4,
    ) -> None:
        """Build the canonical SRResNet generator network."""
        super().__init__()

        scaling_factor = int(scaling_factor)
        if scaling_factor not in {2, 4, 8}:
            raise AssertionError("The scaling factor must be 2, 4, or 8!")

        self.conv_block1 = ConvolutionalBlock(
            in_channels=in_channels,
            out_channels=n_channels,
            kernel_size=large_kernel_size,
            batch_norm=False,
            activation="PReLu",
        )

        self.residual_blocks = nn.Sequential(
            *[
                ResidualBlock(
                    kernel_size=small_kernel_size,
                    n_channels=n_channels,
                )
                for _ in range(n_blocks)
            ]
        )

        self.conv_block2 = ConvolutionalBlock(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=small_kernel_size,
            batch_norm=True,
            activation=None,
        )

        n_subpixel_convolution_blocks = int(math.log2(scaling_factor))
        self.subpixel_convolutional_blocks = nn.Sequential(
            *[
                SubPixelConvolutionalBlock(
                    kernel_size=small_kernel_size,
                    n_channels=n_channels,
                    scaling_factor=2,
                )
                for _ in range(n_subpixel_convolution_blocks)
            ]
        )

        self.conv_block3 = ConvolutionalBlock(
            in_channels=n_channels,
            out_channels=in_channels,
            kernel_size=large_kernel_size,
            batch_norm=False,
            activation="Tanh",
        )

    def forward(self, lr_imgs: torch.Tensor) -> torch.Tensor:
        """Forward propagation through SRResNet."""
        output = self.conv_block1(lr_imgs)
        residual = output
        output = self.residual_blocks(output)
        output = self.conv_block2(output)
        output = output + residual
        output = self.subpixel_convolutional_blocks(output)
        sr_imgs = self.conv_block3(output)
        return sr_imgs

forward(lr_imgs)

Forward propagation through SRResNet.

Source code in opensr_srgan/model/generators/srresnet.py
def forward(self, lr_imgs: torch.Tensor) -> torch.Tensor:
    """Forward propagation through SRResNet."""
    output = self.conv_block1(lr_imgs)
    residual = output
    output = self.residual_blocks(output)
    output = self.conv_block2(output)
    output = output + residual
    output = self.subpixel_convolutional_blocks(output)
    sr_imgs = self.conv_block3(output)
    return sr_imgs

Generator

Bases: Module

SRGAN generator wrapper around SRResNet.

Provides a thin adapter that:
  • Builds an internal SRResNet with the given hyperparameters.
  • Optionally initializes weights from a pretrained SRResNet checkpoint.
  • Exposes a unified forward for SRGAN pipelines.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `in_channels` | `int` | Input channels (e.g., 3 for RGB). | `3` |
| `large_kernel_size` | `int` | Kernel size for head/tail convolutions. | `9` |
| `small_kernel_size` | `int` | Kernel size used inside residual/upsampling blocks. | `3` |
| `n_channels` | `int` | Feature width across the network. | `64` |
| `n_blocks` | `int` | Number of residual blocks in the trunk. | `16` |
| `scaling_factor` | `int` | Upscale factor (must be one of {2, 4, 8}). | `4` |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | Super-resolved image produced by the wrapped SRResNet. |

Source code in opensr_srgan/model/generators/srresnet.py
class Generator(nn.Module):
    """SRGAN generator wrapper around :class:`SRResNet`.

    Provides a thin adapter that:
        - Builds an internal :class:`SRResNet` with the given hyperparameters.
        - Optionally initializes weights from a pretrained SRResNet checkpoint.
        - Exposes a unified forward for SRGAN pipelines.

    Args:
        in_channels (int): Input channels (e.g., 3 for RGB).
        large_kernel_size (int): Kernel size for head/tail convolutions.
        small_kernel_size (int): Kernel size used inside residual/upsampling blocks.
        n_channels (int): Feature width across the network.
        n_blocks (int): Number of residual blocks in the trunk.
        scaling_factor (int): Upscale factor (must be one of {2, 4, 8}).

    Returns:
        torch.Tensor: Super-resolved image produced by the wrapped SRResNet.
    """

    def __init__(
        self,
        in_channels: int = 3,
        large_kernel_size: int = 9,
        small_kernel_size: int = 3,
        n_channels: int = 64,
        n_blocks: int = 16,
        scaling_factor: int = 4,
    ) -> None:
        super().__init__()

        self.net = SRResNet(
            in_channels=in_channels,
            large_kernel_size=large_kernel_size,
            small_kernel_size=small_kernel_size,
            n_channels=n_channels,
            n_blocks=n_blocks,
            scaling_factor=scaling_factor,
        )

    def initialize_with_srresnet(self, srresnet_checkpoint: str) -> None:
        """Initialize the generator weights from a pretrained SRResNet checkpoint."""
        srresnet = torch.load(srresnet_checkpoint)["model"]
        self.net.load_state_dict(srresnet.state_dict())
        print("\nLoaded weights from pre-trained SRResNet.\n")

    def forward(self, lr_imgs: torch.Tensor) -> torch.Tensor:
        """Forward propagation via the wrapped SRResNet."""
        return self.net(lr_imgs)
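
A minimal sketch of building the wrapper and optionally seeding it from a pretrained SRResNet checkpoint; the import path and the checkpoint filename are assumptions for illustration:

```python
import torch

# Assumed import path, derived from the source file location above.
from opensr_srgan.model.generators.srresnet import Generator

g = Generator(in_channels=3, scaling_factor=4)
# Optional: the checkpoint must store a full SRResNet under the "model" key
# (see initialize_with_srresnet below); the path here is a placeholder.
# g.initialize_with_srresnet("srresnet_checkpoint.pth")
sr = g(torch.randn(1, 3, 64, 64))   # (1, 3, 256, 256)
```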

initialize_with_srresnet(srresnet_checkpoint)

Initialize the generator weights from a pretrained SRResNet checkpoint.

Source code in opensr_srgan/model/generators/srresnet.py
def initialize_with_srresnet(self, srresnet_checkpoint: str) -> None:
    """Initialize the generator weights from a pretrained SRResNet checkpoint."""
    srresnet = torch.load(srresnet_checkpoint)["model"]
    self.net.load_state_dict(srresnet.state_dict())
    print("\nLoaded weights from pre-trained SRResNet.\n")

forward(lr_imgs)

Forward propagation via the wrapped SRResNet.

Source code in opensr_srgan/model/generators/srresnet.py
def forward(self, lr_imgs: torch.Tensor) -> torch.Tensor:
    """Forward propagation via the wrapped SRResNet."""
    return self.net(lr_imgs)

StochasticGenerator

Bases: Module

Stochastic generator with latent noise modulation for super-resolution.

Extends a standard SR generator by injecting stochastic latent noise through NoiseResBlocks, enabling diverse texture generation conditioned on the same low-resolution (LR) input. When no latent vector is provided, one is sampled internally from a standard normal distribution, allowing both deterministic and stochastic inference modes.

The architecture follows a residual backbone with
  • A wide receptive-field head convolution.
  • Multiple noise-modulated residual blocks.
  • A tail convolution with learnable upsampling.
  • Configurable scaling factor (×2, ×4, or ×8).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `in_channels` | `int` | Number of input channels (e.g., RGB+NIR = 4 or 6). | `6` |
| `n_channels` | `int` | Base number of feature channels in the generator. | `96` |
| `n_blocks` | `int` | Number of noise-modulated residual blocks. | `16` |
| `small_kernel` | `int` | Kernel size for body convolutions. | `3` |
| `large_kernel` | `int` | Kernel size for head/tail convolutions. | `9` |
| `scale` | `int` | Upscaling factor (must be one of {2, 4, 8}). | `4` |
| `noise_dim` | `int` | Dimensionality of the latent vector z. | `128` |
| `res_scale` | `float` | Residual scaling factor for block stability. | `0.2` |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `noise_dim` | `int` | Dimensionality of latent vector z. |
| `head` | `Sequential` | Initial convolutional stem. |
| `body` | `ModuleList` | Sequence of NoiseResBlocks. |
| `upsampler` | `Module` | PixelShuffle-based upsampling module. |
| `tail` | `Conv2d` | Final convolution projecting to output space. |

Example

>>> g = StochasticGenerator(in_channels=3, scale=4)
>>> lr = torch.randn(1, 3, 64, 64)
>>> sr, noise = g(lr, return_noise=True)

Source code in opensr_srgan/model/generators/cgan_generator.py
class StochasticGenerator(nn.Module):
    """Stochastic generator with latent noise modulation for super-resolution.

    Extends a standard SR generator by injecting stochastic latent noise through
    `NoiseResBlock`s, enabling diverse texture generation conditioned on the same
    low-resolution (LR) input. When no latent vector is provided, one is sampled
    internally from a standard normal distribution, allowing both deterministic
    and stochastic inference modes.

    The architecture follows a residual backbone with:
        - A wide receptive-field head convolution.
        - Multiple noise-modulated residual blocks.
        - A tail convolution with learnable upsampling.
        - Configurable scaling factor (×2, ×4, or ×8).

    Args:
        in_channels (int): Number of input channels (e.g., RGB+NIR = 4 or 6).
        n_channels (int): Base number of feature channels in the generator.
        n_blocks (int): Number of noise-modulated residual blocks.
        small_kernel (int): Kernel size for body convolutions.
        large_kernel (int): Kernel size for head/tail convolutions.
        scale (int): Upscaling factor (must be one of {2, 4, 8}).
        noise_dim (int): Dimensionality of the latent vector z.
        res_scale (float): Residual scaling factor for block stability.

    Attributes:
        noise_dim (int): Dimensionality of latent vector z.
        head (nn.Sequential): Initial convolutional stem.
        body (nn.ModuleList): Sequence of `NoiseResBlock`s.
        upsampler (nn.Module): PixelShuffle-based upsampling module.
        tail (nn.Conv2d): Final convolution projecting to output space.

    Example:
        >>> g = StochasticGenerator(in_channels=3, scale=4)
        >>> lr = torch.randn(1, 3, 64, 64)
        >>> sr, noise = g(lr, return_noise=True)
    """

    def __init__(
        self,
        in_channels: int = 6,
        n_channels: int = 96,
        n_blocks: int = 16,
        small_kernel: int = 3,
        large_kernel: int = 9,
        scale: int = 4,
        noise_dim: int = 128,
        res_scale: float = 0.2,
    ) -> None:
        super().__init__()

        if scale not in {2, 4, 8}:
            raise ValueError("scale must be one of {2, 4, 8}")

        self.noise_dim = noise_dim
        self.scale = scale

        padding_large = large_kernel // 2
        self.head = nn.Sequential(
            nn.Conv2d(in_channels, n_channels, large_kernel, padding=padding_large),
            nn.PReLU(),
        )

        self.body = nn.ModuleList(
            [
                NoiseResBlock(n_channels, small_kernel, noise_dim, res_scale)
                for _ in range(n_blocks)
            ]
        )
        self.body_tail = nn.Conv2d(
            n_channels,
            n_channels,
            small_kernel,
            padding=small_kernel // 2,
        )
        self.upsampler = make_upsampler(n_channels, scale)
        self.tail = nn.Conv2d(
            n_channels, in_channels, large_kernel, padding=padding_large
        )

    def sample_noise(
        self,
        batch_size: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        """
        Sample a latent noise tensor consistent with the generator configuration.

        Parameters
        ----------
        batch_size : int
            Number of noise vectors to generate.
        device : torch.device, optional
            Device on which to allocate the tensor. Defaults to the model's current device.
        dtype : torch.dtype, optional
            Tensor dtype. Defaults to the model's parameter dtype.

        Returns
        -------
        torch.Tensor
            Random latent tensor sampled from a standard normal distribution,
            shape (B, D) where B = `batch_size` and D = `self.noise_dim`.

        Notes
        -----
        - Used for stochastic generation when no latent vector is provided to ``forward()``.
        - Ensures type and device consistency with the current model parameters.
        - The resulting noise can be reused to reproduce identical stochastic outputs.
        """
        if device is None:
            device = next(self.parameters()).device
        if dtype is None:
            dtype = next(self.parameters()).dtype
        return torch.randn(batch_size, self.noise_dim, device=device, dtype=dtype)

    def forward(
        self,
        lr: torch.Tensor,
        noise: Optional[torch.Tensor] = None,
        return_noise: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the stochastic super-resolution generator.

        Parameters
        ----------
        lr : torch.Tensor
            Low-resolution input tensor of shape (B, C, H, W).
        noise : torch.Tensor, optional
            Latent noise tensor of shape (B, D), where D = `noise_dim`.
            If None, a random vector is sampled internally.
        return_noise : bool, default=False
            If True, returns both the super-resolved image and the
            latent noise used for generation.

        Returns
        -------
        torch.Tensor or (torch.Tensor, torch.Tensor)
            - ``sr``: Super-resolved output image of shape (B, C, sH, sW),
              where s is the upscaling factor.
            - Optionally, ``noise``: The latent vector used (if `return_noise=True`).

        Notes
        -----
        - The latent code is broadcast and applied to all residual blocks.
        - Supports both deterministic (fixed noise) and stochastic (random noise)
          inference modes.
        - The upsampling is performed via sub-pixel convolution (PixelShuffle)
          defined in `make_upsampler()`.
        """
        if noise is None:
            noise = torch.randn(
                lr.size(0),
                self.noise_dim,
                device=lr.device,
                dtype=lr.dtype,
            )

        features = self.head(lr)
        residual = features
        for block in self.body:
            residual = block(residual, noise)
        residual = self.body_tail(residual)
        features = features + residual
        features = self.upsampler(features)
        sr = self.tail(features)

        if return_noise:
            return sr, noise
        return sr

sample_noise(batch_size, device=None, dtype=None)

Sample a latent noise tensor consistent with the generator configuration.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch_size` | `int` | Number of noise vectors to generate. | *required* |
| `device` | `torch.device` | Device on which to allocate the tensor. Defaults to the model's current device. | `None` |
| `dtype` | `torch.dtype` | Tensor dtype. Defaults to the model's parameter dtype. | `None` |

Returns

| Type | Description |
| --- | --- |
| `torch.Tensor` | Random latent tensor sampled from a standard normal distribution, shape (B, D) where B = `batch_size` and D = `self.noise_dim`. |

Notes
  • Used for stochastic generation when no latent vector is provided to forward().
  • Ensures type and device consistency with the current model parameters.
  • The resulting noise can be reused to reproduce identical stochastic outputs.
Source code in opensr_srgan/model/generators/cgan_generator.py
def sample_noise(
    self,
    batch_size: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Sample a latent noise tensor consistent with the generator configuration.

    Parameters
    ----------
    batch_size : int
        Number of noise vectors to generate.
    device : torch.device, optional
        Device on which to allocate the tensor. Defaults to the model's current device.
    dtype : torch.dtype, optional
        Tensor dtype. Defaults to the model's parameter dtype.

    Returns
    -------
    torch.Tensor
        Random latent tensor sampled from a standard normal distribution,
        shape (B, D) where B = `batch_size` and D = `self.noise_dim`.

    Notes
    -----
    - Used for stochastic generation when no latent vector is provided to ``forward()``.
    - Ensures type and device consistency with the current model parameters.
    - The resulting noise can be reused to reproduce identical stochastic outputs.
    """
    if device is None:
        device = next(self.parameters()).device
    if dtype is None:
        dtype = next(self.parameters()).dtype
    return torch.randn(batch_size, self.noise_dim, device=device, dtype=dtype)

forward(lr, noise=None, return_noise=False)

Forward pass of the stochastic super-resolution generator.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `lr` | `torch.Tensor` | Low-resolution input tensor of shape (B, C, H, W). | *required* |
| `noise` | `torch.Tensor` | Latent noise tensor of shape (B, D), where D = `noise_dim`. If None, a random vector is sampled internally. | `None` |
| `return_noise` | `bool` | If True, returns both the super-resolved image and the latent noise used for generation. | `False` |

Returns

| Type | Description |
| --- | --- |
| `torch.Tensor` or `(torch.Tensor, torch.Tensor)` | `sr`: super-resolved output image of shape (B, C, sH, sW), where s is the upscaling factor. Optionally, `noise`: the latent vector used (if `return_noise=True`). |

Notes
  • The latent code is broadcast and applied to all residual blocks.
  • Supports both deterministic (fixed noise) and stochastic (random noise) inference modes.
  • The upsampling is performed via sub-pixel convolution (PixelShuffle) defined in make_upsampler().
Source code in opensr_srgan/model/generators/cgan_generator.py
def forward(
    self,
    lr: torch.Tensor,
    noise: Optional[torch.Tensor] = None,
    return_noise: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Forward pass of the stochastic super-resolution generator.

    Parameters
    ----------
    lr : torch.Tensor
        Low-resolution input tensor of shape (B, C, H, W).
    noise : torch.Tensor, optional
        Latent noise tensor of shape (B, D), where D = `noise_dim`.
        If None, a random vector is sampled internally.
    return_noise : bool, default=False
        If True, returns both the super-resolved image and the
        latent noise used for generation.

    Returns
    -------
    torch.Tensor or (torch.Tensor, torch.Tensor)
        - ``sr``: Super-resolved output image of shape (B, C, sH, sW),
          where s is the upscaling factor.
        - Optionally, ``noise``: The latent vector used (if `return_noise=True`).

    Notes
    -----
    - The latent code is broadcast and applied to all residual blocks.
    - Supports both deterministic (fixed noise) and stochastic (random noise)
      inference modes.
    - The upsampling is performed via sub-pixel convolution (PixelShuffle)
      defined in `make_upsampler()`.
    """
    if noise is None:
        noise = torch.randn(
            lr.size(0),
            self.noise_dim,
            device=lr.device,
            dtype=lr.dtype,
        )

    features = self.head(lr)
    residual = features
    for block in self.body:
        residual = block(residual, noise)
    residual = self.body_tail(residual)
    features = features + residual
    features = self.upsampler(features)
    sr = self.tail(features)

    if return_noise:
        return sr, noise
    return sr
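
As the notes above state, re-using the same latent vector reproduces the same output. A small sketch, assuming the import path implied by the source file above:

```python
import torch

# Assumed import path, derived from the source file location above.
from opensr_srgan.model.generators.cgan_generator import StochasticGenerator

g = StochasticGenerator(in_channels=3, scale=4).eval()
lr = torch.randn(1, 3, 32, 32)
z = g.sample_noise(batch_size=1)

with torch.no_grad():
    a = g(lr, noise=z)   # deterministic: same z -> same output
    b = g(lr, noise=z)
    c = g(lr)            # stochastic: a fresh z is sampled inside forward()
print(torch.allclose(a, b))  # True
```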

ESRGANGenerator

Bases: Module

ESRGAN generator network for single-image super-resolution.

This implementation follows the design proposed in "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks" (Wang et al., 2018). It replaces traditional residual blocks with Residual-in-Residual Dense Blocks (RRDBs), omits batch normalization, and applies residual scaling for improved stability and perceptual quality.

The architecture can be summarized as

Input → Conv(3×3) → [RRDB × N] → Conv(3×3) → (PixelShuffle upsampling × scale) → Conv(3×3) → LeakyReLU → Conv(3×3) → Output

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `in_channels` | `int` | Number of input channels (e.g., 3 for RGB or 4 for RGB-NIR). | `3` |
| `out_channels` | `int` or `None` | Number of output channels. If None, defaults to `in_channels`. | `None` |
| `n_features` | `int` | Number of feature maps in the base convolutional layers. | `64` |
| `n_blocks` | `int` | Number of Residual-in-Residual Dense Blocks (RRDBs) stacked in the network body. | `23` |
| `growth_channels` | `int` | Number of intermediate growth channels used within each RRDB. | `32` |
| `res_scale` | `float` | Residual scaling factor applied to stabilize deep residual learning. | `0.2` |
| `scale` | `int` | Upscaling factor. Must be a power of two (1, 2, 4, 8, ...). | `4` |

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `conv_first` | `nn.Conv2d` | Initial 3×3 convolutional layer that extracts shallow features from the LR input. |
| `body` | `nn.Sequential` | Sequential container of RRDB blocks performing deep feature extraction. |
| `conv_body` | `nn.Conv2d` | 3×3 convolution applied after the RRDB stack to merge body features. |
| `upsampler` | `nn.Module` | PixelShuffle-based upsampling module (from `make_upsampler`) for the configured scale. If `scale == 1`, replaced by an identity mapping. |
| `conv_hr` | `nn.Conv2d` | 3×3 convolution used for high-resolution refinement. |
| `activation` | `nn.LeakyReLU` | LeakyReLU activation (slope=0.2) applied after `conv_hr`. |
| `conv_last` | `nn.Conv2d` | Final 3×3 projection layer mapping features back to output space. |
| `scale` | `int` | Model’s upscaling factor. |
| `n_blocks` | `int` | Number of RRDB blocks in the body. |
| `n_features` | `int` | Number of feature maps in the feature extraction layers. |
| `growth_channels` | `int` | Growth channels per RRDB. |

Raises

| Type | Description |
| --- | --- |
| `ValueError` | If `scale` is not a power of two or if `n_blocks < 1`. |

Examples

>>> from opensr_srgan.model.generators.esrgan_generator import ESRGANGenerator
>>> model = ESRGANGenerator(in_channels=3, scale=4)
>>> x = torch.randn(1, 3, 64, 64)
>>> y = model(x)
>>> y.shape
torch.Size([1, 3, 256, 256])

References

  • Wang et al., *ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks*,
Source code in opensr_srgan/model/generators/esrgan.py
class ESRGANGenerator(nn.Module):
    """
    ESRGAN generator network for single-image super-resolution.

    This implementation follows the design proposed in
    *"ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks"*
    (Wang et al., 2018). It replaces traditional residual blocks with
    Residual-in-Residual Dense Blocks (RRDBs), omits batch normalization,
    and applies residual scaling for improved stability and perceptual quality.

    The architecture can be summarized as:
        Input → Conv(3×3) → [RRDB × N] → Conv(3×3)
        → (PixelShuffle upsampling × scale) → Conv(3×3) → LeakyReLU → Conv(3×3) → Output

    Parameters
    ----------
    in_channels : int, default=3
        Number of input channels (e.g., 3 for RGB or 4 for RGB-NIR).
    out_channels : int or None, default=None
        Number of output channels. If ``None``, defaults to ``in_channels``.
    n_features : int, default=64
        Number of feature maps in the base convolutional layers.
    n_blocks : int, default=23
        Number of Residual-in-Residual Dense Blocks (RRDBs) stacked in the network body.
    growth_channels : int, default=32
        Number of intermediate growth channels used within each RRDB.
    res_scale : float, default=0.2
        Residual scaling factor applied to stabilize deep residual learning.
    scale : int, default=4
        Upscaling factor. Must be a power of two (1, 2, 4, 8, ...).

    Attributes
    ----------
    conv_first : nn.Conv2d
        Initial 3×3 convolutional layer that extracts shallow features from the LR input.
    body : nn.Sequential
        Sequential container of ``RRDB`` blocks performing deep feature extraction.
    conv_body : nn.Conv2d
        3×3 convolution applied after the RRDB stack to merge body features.
    upsampler : nn.Module
        PixelShuffle-based upsampling module (from ``make_upsampler``) for the configured scale.
        If ``scale == 1``, replaced by an identity mapping.
    conv_hr : nn.Conv2d
        3×3 convolution used for high-resolution refinement.
    activation : nn.LeakyReLU
        LeakyReLU activation (slope=0.2) applied after ``conv_hr``.
    conv_last : nn.Conv2d
        Final 3×3 projection layer mapping features back to output space.
    scale : int
        Model’s upscaling factor.
    n_blocks : int
        Number of RRDB blocks in the body.
    n_features : int
        Number of feature maps in the feature extraction layers.
    growth_channels : int
        Growth channels per RRDB.

    Raises
    ------
    ValueError
        If ``scale`` is not a power of two or if ``n_blocks < 1``.

    Examples
    --------
    >>> from opensr_srgan.model.generators.esrgan_generator import ESRGANGenerator
    >>> model = ESRGANGenerator(in_channels=3, scale=4)
    >>> x = torch.randn(1, 3, 64, 64)
    >>> y = model(x)
    >>> y.shape
    torch.Size([1, 3, 256, 256])

    References
    ----------
    - Wang et al., *ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks*,
    """

    def __init__(
        self,
        *,
        in_channels: int = 3,
        out_channels: int | None = None,
        n_features: int = 64,
        n_blocks: int = 23,
        growth_channels: int = 32,
        res_scale: float = 0.2,
        scale: int = 4,
    ) -> None:
        super().__init__()

        if scale < 1 or scale & (scale - 1) != 0:
            raise ValueError("ESRGANGenerator only supports power-of-two scales (1, 2, 4, 8, ...).")

        if n_blocks < 1:
            raise ValueError("ESRGANGenerator requires at least one RRDB block.")

        if out_channels is None:
            out_channels = in_channels

        self.scale = scale
        self.n_blocks = n_blocks
        self.n_features = n_features
        self.growth_channels = growth_channels

        body_blocks = [RRDB(n_features, growth_channels, res_scale=res_scale) for _ in range(n_blocks)]

        self.conv_first = nn.Conv2d(in_channels, n_features, 3, padding=1)
        self.body = nn.Sequential(*body_blocks)
        self.conv_body = nn.Conv2d(n_features, n_features, 3, padding=1)
        self.upsampler = nn.Identity() if scale == 1 else make_upsampler(n_features, scale)
        self.conv_hr = nn.Conv2d(n_features, n_features, 3, padding=1)
        self.activation = nn.LeakyReLU(0.2, inplace=True)
        self.conv_last = nn.Conv2d(n_features, out_channels, 3, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the ESRGAN generator.

        Parameters
        ----------
        x : torch.Tensor
            Low-resolution input image tensor of shape (B, C, H, W).

        Returns
        -------
        torch.Tensor
            Super-resolved output tensor of shape (B, C, sH, sW),
            where `s` is the upscaling factor defined by ``self.scale``.

        Notes
        -----
        - The generator first encodes input features, processes them through
          a sequence of Residual-in-Residual Dense Blocks (RRDBs), and then
          performs upsampling via sub-pixel convolutions.
        - A long skip connection adds the shallow features from the input
          stem to the deep body output before upsampling.
        - The final activation is a LeakyReLU followed by a 3×3 convolution
          that projects features back to image space.
        - When ``scale == 1``, the upsampling block is replaced by an identity.
        """
        first = self.conv_first(x)
        trunk = self.body(first)
        body_out = self.conv_body(trunk)
        feat = first + body_out
        feat = self.upsampler(feat)
        feat = self.activation(self.conv_hr(feat))
        return self.conv_last(feat)

forward(x)

Forward pass of the ESRGAN generator.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `torch.Tensor` | Low-resolution input image tensor of shape (B, C, H, W). | *required* |

Returns

| Type | Description |
| --- | --- |
| `torch.Tensor` | Super-resolved output tensor of shape (B, C, sH, sW), where `s` is the upscaling factor defined by `self.scale`. |

Notes
  • The generator first encodes input features, processes them through a sequence of Residual-in-Residual Dense Blocks (RRDBs), and then performs upsampling via sub-pixel convolutions.
  • A long skip connection adds the shallow features from the input stem to the deep body output before upsampling.
  • The final activation is a LeakyReLU followed by a 3×3 convolution that projects features back to image space.
  • When scale == 1, the upsampling block is replaced by an identity.
Source code in opensr_srgan/model/generators/esrgan.py
def forward(self, x: Tensor) -> Tensor:
    """
    Forward pass of the ESRGAN generator.

    Parameters
    ----------
    x : torch.Tensor
        Low-resolution input image tensor of shape (B, C, H, W).

    Returns
    -------
    torch.Tensor
        Super-resolved output tensor of shape (B, C, sH, sW),
        where `s` is the upscaling factor defined by ``self.scale``.

    Notes
    -----
    - The generator first encodes input features, processes them through
      a sequence of Residual-in-Residual Dense Blocks (RRDBs), and then
      performs upsampling via sub-pixel convolutions.
    - A long skip connection adds the shallow features from the input
      stem to the deep body output before upsampling.
    - The final activation is a LeakyReLU followed by a 3×3 convolution
      that projects features back to image space.
    - When ``scale == 1``, the upsampling block is replaced by an identity.
    """
    first = self.conv_first(x)
    trunk = self.body(first)
    body_out = self.conv_body(trunk)
    feat = first + body_out
    feat = self.upsampler(feat)
    feat = self.activation(self.conv_hr(feat))
    return self.conv_last(feat)
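
The `scale == 1` branch mentioned in the notes simply skips upsampling; a small sketch, assuming the import path implied by the source file above (the docstring example uses a slightly different module name):

```python
import torch

# Assumed import path, derived from the source file location above.
from opensr_srgan.model.generators.esrgan import ESRGANGenerator

# scale=1 replaces the upsampler with nn.Identity, so the spatial size is preserved.
net = ESRGANGenerator(in_channels=3, n_blocks=2, scale=1)
x = torch.randn(1, 3, 32, 32)
print(net(x).shape)  # torch.Size([1, 3, 32, 32])
```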

FlexibleGenerator

Bases: Module

Modular super-resolution generator with pluggable residual blocks.

Provides a single, drop-in generator backbone that can be instantiated with different residual block families—res, rcab, rrdb, or lka—all built from a shared interface. The network follows a head → body → tail design with learnable upsampling:

Head (large-kernel conv) → N × Block(type=block_type) → Body tail conv
→ Skip add → Upsampler (×2/×4/×8) → Output conv

Use this when you want to compare architectural choices or sweep hyperparameters without changing call sites.

Parameters:

Name Type Description Default
in_channels int

Number of input channels (e.g., RGB=3, RGB-NIR=4/6).

6
n_channels int

Base feature width used throughout the backbone.

96
n_blocks int

Number of residual blocks in the body.

32
small_kernel int

Kernel size for body/ tail convolutions.

3
large_kernel int

Kernel size for head/ output convolutions.

9
scale int

Upscaling factor; one of {2, 4, 8}.

8
block_type str

Residual block family in {"res","rcab","rrdb","lka"}.

'rcab'

Attributes:

  • head (Sequential): Large-receptive-field stem conv + activation.
  • body (Sequential): Sequence of residual blocks of the selected type.
  • body_tail (Conv2d): Fusion conv after the residual stack.
  • upsampler (Module): PixelShuffle-style learnable upsampling to scale.
  • tail (Conv2d): Final projection to in_channels.

Raises:

  • ValueError: If scale is not in {2, 4, 8} or block_type is unknown.

Example

>>> g = FlexibleGenerator(in_channels=3, block_type="rcab", scale=4)
>>> x = torch.randn(1, 3, 64, 64)
>>> y = g(x)  # (1, 3, 256, 256)

Source code in opensr_srgan/model/generators/flexible_generator.py
class FlexibleGenerator(nn.Module):
    """Modular super-resolution generator with pluggable residual blocks.

    Provides a single, drop-in generator backbone that can be instantiated with
    different residual block families—**res**, **rcab**, **rrdb**, or **lka**—all
    built from a shared interface. The network follows a head → body → tail
    design with learnable upsampling:

        Head (large-kernel conv) → N × Block(type=block_type) → Body tail conv
        → Skip add → Upsampler (×2/×4/×8) → Output conv

    Use this when you want to compare architectural choices or sweep hyper-
    parameters without changing call sites.

    Args:
        in_channels (int): Number of input channels (e.g., RGB=3, RGB-NIR=4/6).
        n_channels (int): Base feature width used throughout the backbone.
        n_blocks (int): Number of residual blocks in the body.
        small_kernel (int): Kernel size for body/ tail convolutions.
        large_kernel (int): Kernel size for head/ output convolutions.
        scale (int): Upscaling factor; one of {2, 4, 8}.
        block_type (str): Residual block family in {"res","rcab","rrdb","lka"}.

    Attributes:
        head (nn.Sequential): Large-receptive-field stem conv + activation.
        body (nn.Sequential): Sequence of residual blocks of the selected type.
        body_tail (nn.Conv2d): Fusion conv after the residual stack.
        upsampler (nn.Module): PixelShuffle-style learnable upsampling to `scale`.
        tail (nn.Conv2d): Final projection to `in_channels`.

    Raises:
        ValueError: If `scale` is not in {2, 4, 8} or `block_type` is unknown.

    Example:
        >>> g = FlexibleGenerator(in_channels=3, block_type="rcab", scale=4)
        >>> x = torch.randn(1, 3, 64, 64)
        >>> y = g(x)  # (1, 3, 256, 256)
    """

    def __init__(
        self,
        in_channels: int = 6,
        n_channels: int = 96,
        n_blocks: int = 32,
        small_kernel: int = 3,
        large_kernel: int = 9,
        scale: int = 8,
        block_type: str = "rcab",
    ) -> None:
        super().__init__()

        if scale not in {2, 4, 8}:
            raise ValueError("scale must be one of {2, 4, 8}")

        block_key = block_type.lower()
        if block_key not in _BLOCK_REGISTRY:
            raise ValueError("block_type must be one of {'res', 'rcab', 'rrdb', 'lka'}")

        self.scale = scale

        padding_large = large_kernel // 2
        self.head = nn.Sequential(
            nn.Conv2d(in_channels, n_channels, large_kernel, padding=padding_large),
            nn.PReLU(),
        )

        block_factory = _BLOCK_REGISTRY[block_key]
        self.body = nn.Sequential(
            *[block_factory(n_channels, small_kernel) for _ in range(n_blocks)]
        )
        self.body_tail = nn.Conv2d(
            n_channels, n_channels, small_kernel, padding=small_kernel // 2
        )
        self.upsampler = make_upsampler(n_channels, scale)
        self.tail = nn.Conv2d(
            n_channels, in_channels, large_kernel, padding=padding_large
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the flexible SR generator.

        Parameters
        ----------
        x : torch.Tensor
            Low-resolution input tensor of shape (B, C, H, W).

        Returns
        -------
        torch.Tensor
            Super-resolved output tensor of shape (B, C, sH, sW),
            where `s` ∈ {2, 4, 8} is the configured upscaling factor.

        Workflow
        --------
        1) Extract shallow features with the head conv.
        2) Transform via a stack of residual blocks (type = `block_type`).
        3) Fuse with a 3×3 body tail conv and add a long skip.
        4) Upsample via PixelShuffle-based `make_upsampler`.
        5) Project back to image space with the output conv.
        """
        feat = self.head(x)
        res = self.body(feat)
        res = self.body_tail(res)
        feat = feat + res
        feat = self.upsampler(feat)
        return self.tail(feat)

forward(x)

Forward pass of the flexible SR generator.

Parameters

x : torch.Tensor Low-resolution input tensor of shape (B, C, H, W).

Returns

torch.Tensor Super-resolved output tensor of shape (B, C, sH, sW), where s ∈ {2, 4, 8} is the configured upscaling factor.

Workflow

1) Extract shallow features with the head conv.
2) Transform via a stack of residual blocks (type = block_type).
3) Fuse with a 3×3 body tail conv and add a long skip.
4) Upsample via PixelShuffle-based make_upsampler.
5) Project back to image space with the output conv.

Source code in opensr_srgan/model/generators/flexible_generator.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass of the flexible SR generator.

    Parameters
    ----------
    x : torch.Tensor
        Low-resolution input tensor of shape (B, C, H, W).

    Returns
    -------
    torch.Tensor
        Super-resolved output tensor of shape (B, C, sH, sW),
        where `s` ∈ {2, 4, 8} is the configured upscaling factor.

    Workflow
    --------
    1) Extract shallow features with the head conv.
    2) Transform via a stack of residual blocks (type = `block_type`).
    3) Fuse with a 3×3 body tail conv and add a long skip.
    4) Upsample via PixelShuffle-based `make_upsampler`.
    5) Project back to image space with the output conv.
    """
    feat = self.head(x)
    res = self.body(feat)
    res = self.body_tail(res)
    feat = feat + res
    feat = self.upsampler(feat)
    return self.tail(feat)
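
Because all block families share one constructor interface, sweeping architectures amounts to changing block_type. The loop below is an illustrative sketch only (tiny widths and depths so it runs quickly); parameter counts will differ from any configuration you would actually train.

import torch
from opensr_srgan.model.generators.flexible_generator import FlexibleGenerator

x = torch.randn(1, 3, 32, 32)
for block_type in ("res", "rcab", "rrdb", "lka"):
    g = FlexibleGenerator(in_channels=3, n_channels=32, n_blocks=4,
                          scale=2, block_type=block_type)
    n_params = sum(p.numel() for p in g.parameters())
    # Every variant maps (1, 3, 32, 32) -> (1, 3, 64, 64) at scale=2.
    print(f"{block_type}: {tuple(g(x).shape)}, {n_params / 1e6:.2f}M params")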

ConvolutionalBlock

Bases: Module

A convolutional block comprised of Conv → (BN) → (Activation).

Source code in opensr_srgan/model/model_blocks/__init__.py
class ConvolutionalBlock(nn.Module):
    """A convolutional block comprised of Conv → (BN) → (Activation)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        batch_norm: bool = False,
        activation: str | None = None,
    ) -> None:
        super().__init__()

        act = activation.lower() if activation is not None else None
        if act is not None:
            if act not in {"prelu", "leakyrelu", "tanh"}:
                raise AssertionError("activation must be one of {'prelu', 'leakyrelu', 'tanh'}")

        layers: list[nn.Module] = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=kernel_size // 2,
            )
        ]

        if batch_norm:
            layers.append(nn.BatchNorm2d(num_features=out_channels))

        if act == "prelu":
            layers.append(nn.PReLU())
        elif act == "leakyrelu":
            layers.append(nn.LeakyReLU(0.2))
        elif act == "tanh":
            layers.append(nn.Tanh())

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv_block(x)
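
A short sketch of the Conv → BN → activation composition, using a stride-2 block of the kind stacked inside the SRGAN discriminator (the concrete numbers are arbitrary):

import torch
from opensr_srgan.model.model_blocks import ConvolutionalBlock

block = ConvolutionalBlock(in_channels=64, out_channels=128, kernel_size=3,
                           stride=2, batch_norm=True, activation="LeakyReLU")
x = torch.randn(2, 64, 48, 48)
y = block(x)
print(y.shape)  # torch.Size([2, 128, 24, 24]): the stride halves H and W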

SubPixelConvolutionalBlock

Bases: Module

Conv → PixelShuffle → PReLU upsampling block.

Source code in opensr_srgan/model/model_blocks/__init__.py
class SubPixelConvolutionalBlock(nn.Module):
    """Conv → PixelShuffle → PReLU upsampling block."""

    def __init__(
        self,
        kernel_size: int = 3,
        n_channels: int = 64,
        scaling_factor: int = 2,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels=n_channels,
            out_channels=n_channels * (scaling_factor**2),
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=scaling_factor)
        self.prelu = nn.PReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x
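
The block widens channels by scaling_factor² with a convolution and then rearranges them into space via PixelShuffle, so the channel count is preserved while resolution grows; a quick sketch:

import torch
from opensr_srgan.model.model_blocks import SubPixelConvolutionalBlock

up = SubPixelConvolutionalBlock(kernel_size=3, n_channels=64, scaling_factor=2)
x = torch.randn(1, 64, 32, 32)
print(up(x).shape)  # torch.Size([1, 64, 64, 64]): channels preserved, H and W doubled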

ResidualBlock

Bases: Module

BN-enabled residual block used in the original SRResNet.

Source code in opensr_srgan/model/model_blocks/__init__.py
class ResidualBlock(nn.Module):
    """BN-enabled residual block used in the original SRResNet."""

    def __init__(self, kernel_size: int = 3, n_channels: int = 64) -> None:
        super().__init__()
        self.conv_block1 = ConvolutionalBlock(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=kernel_size,
            batch_norm=True,
            activation="PReLu",
        )
        self.conv_block2 = ConvolutionalBlock(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=kernel_size,
            batch_norm=True,
            activation=None,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        return x + residual

ResidualBlockNoBN

Bases: Module

Residual block variant without batch norm and with residual scaling.

Source code in opensr_srgan/model/model_blocks/__init__.py
class ResidualBlockNoBN(nn.Module):
    """Residual block variant without batch norm and with residual scaling."""

    def __init__(
        self,
        n_channels: int = 64,
        kernel_size: int = 3,
        res_scale: float = 0.2,
    ) -> None:
        super().__init__()
        padding = kernel_size // 2
        self.body = nn.Sequential(
            nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding),
            nn.PReLU(),
            nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding),
        )
        self.res_scale = res_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.res_scale * self.body(x)

RCAB

Bases: Module

Residual Channel Attention Block (no BN).

Source code in opensr_srgan/model/model_blocks/__init__.py
class RCAB(nn.Module):
    """Residual Channel Attention Block (no BN)."""

    def __init__(
        self,
        n_channels: int = 64,
        kernel_size: int = 3,
        reduction: int = 16,
        res_scale: float = 0.2,
    ) -> None:
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
        self.act = nn.PReLU()
        self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(n_channels, n_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_channels // reduction, n_channels, 1),
            nn.Sigmoid(),
        )
        self.res_scale = res_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv1(x)
        y = self.act(y)
        y = self.conv2(y)
        w = self.se(y)
        return x + self.res_scale * (y * w)

DenseBlock5

Bases: Module

ESRGAN-style dense block with five convolutions.

Source code in opensr_srgan/model/model_blocks/__init__.py
class DenseBlock5(nn.Module):
    """ESRGAN-style dense block with five convolutions."""

    def __init__(self, n_features: int = 64, growth_channels: int = 32, kernel_size: int = 3) -> None:
        super().__init__()
        padding = kernel_size // 2
        self.act = nn.LeakyReLU(0.2, inplace=True)
        self.c1 = nn.Conv2d(n_features, growth_channels, kernel_size, padding=padding)
        self.c2 = nn.Conv2d(n_features + growth_channels, growth_channels, kernel_size, padding=padding)
        self.c3 = nn.Conv2d(n_features + 2 * growth_channels, growth_channels, kernel_size, padding=padding)
        self.c4 = nn.Conv2d(n_features + 3 * growth_channels, growth_channels, kernel_size, padding=padding)
        self.c5 = nn.Conv2d(n_features + 4 * growth_channels, n_features, kernel_size, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.act(self.c1(x))
        x2 = self.act(self.c2(torch.cat([x, x1], dim=1)))
        x3 = self.act(self.c3(torch.cat([x, x1, x2], dim=1)))
        x4 = self.act(self.c4(torch.cat([x, x1, x2, x3], dim=1)))
        x5 = self.c5(torch.cat([x, x1, x2, x3, x4], dim=1))
        return x5

RRDB

Bases: Module

Residual-in-Residual Dense Block.

Source code in opensr_srgan/model/model_blocks/__init__.py
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block."""

    def __init__(self, n_features: int = 64, growth_channels: int = 32, res_scale: float = 0.2) -> None:
        super().__init__()
        self.db1 = DenseBlock5(n_features, growth_channels)
        self.db2 = DenseBlock5(n_features, growth_channels)
        self.db3 = DenseBlock5(n_features, growth_channels)
        self.res_scale = res_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.db1(x)
        y = self.db2(y)
        y = self.db3(y)
        return x + self.res_scale * y
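
Both blocks preserve the input shape: DenseBlock5 grows an internal feature stack through concatenation but projects back to n_features in its fifth convolution, and RRDB chains three such blocks behind a scaled residual. A minimal sanity sketch:

import torch
from opensr_srgan.model.model_blocks import DenseBlock5, RRDB

x = torch.randn(1, 64, 24, 24)
dense = DenseBlock5(n_features=64, growth_channels=32)
rrdb = RRDB(n_features=64, growth_channels=32, res_scale=0.2)
print(dense(x).shape, rrdb(x).shape)  # both torch.Size([1, 64, 24, 24])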

LKA

Bases: Module

Lightweight Large-Kernel Attention module.

Source code in opensr_srgan/model/model_blocks/__init__.py
class LKA(nn.Module):
    """Lightweight Large-Kernel Attention module."""

    def __init__(self, n_channels: int = 64) -> None:
        super().__init__()
        self.dw5 = nn.Conv2d(n_channels, n_channels, 5, padding=2, groups=n_channels)
        self.dw7d = nn.Conv2d(n_channels, n_channels, 7, padding=9, dilation=3, groups=n_channels)
        self.pw = nn.Conv2d(n_channels, n_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn = self.dw5(x)
        attn = self.dw7d(attn)
        attn = self.pw(attn)
        return x * torch.sigmoid(attn)
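
LKA builds a large effective receptive field from a 5×5 depthwise conv, a dilated 7×7 depthwise conv (dilation 3), and a 1×1 pointwise conv, then applies the result as a sigmoid gate on the input; the sketch below just confirms the gating behaviour on random data.

import torch
from opensr_srgan.model.model_blocks import LKA

attn = LKA(n_channels=64)
x = torch.randn(1, 64, 40, 40)
y = attn(x)
print(y.shape)                           # torch.Size([1, 64, 40, 40])
print(bool((y.abs() <= x.abs()).all()))  # True: the sigmoid gate only attenuates magnitudes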

LKAResBlock

Bases: Module

Residual block incorporating Large-Kernel Attention.

Source code in opensr_srgan/model/model_blocks/__init__.py
class LKAResBlock(nn.Module):
    """Residual block incorporating Large-Kernel Attention."""

    def __init__(self, n_channels: int = 64, kernel_size: int = 3, res_scale: float = 0.2) -> None:
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
        self.act = nn.PReLU()
        self.lka = LKA(n_channels)
        self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
        self.res_scale = res_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv1(x)
        y = self.act(y)
        y = self.lka(y)
        y = self.conv2(y)
        return x + self.res_scale * y

build_generator(config)

Instantiate a generator module from the provided configuration.

Parameters

config : Any Resolved configuration object with at least:
  • config.Model.in_bands
  • config.Generator.model_type (or an SRResNet block alias)
  • additional generator-specific fields (e.g., n_blocks, n_channels, scaling_factor).

Returns

torch.nn.Module A fully constructed generator (SRResNet/Flexible, ESRGAN, or StochasticGenerator).

Raises

ValueError If the model type or block variant cannot be resolved to a known implementation.

Notes

  • Accepts legacy configs where Generator.model_type was a block alias (e.g., "rcab").
  • Options that do not apply to the selected model type are reported as warnings rather than treated as errors.
Source code in opensr_srgan/model/generators/factory.py
def build_generator(config: Any) -> nn.Module:
    """
    Instantiate a generator module from the provided configuration.

    Parameters
    ----------
    config : Any
        Resolved configuration object with at least:
          - `config.Model.in_bands`
          - `config.Generator.model_type` (or SRResNet block alias)
          - Additional generator-specific fields (e.g., n_blocks, n_channels, scaling_factor).

    Returns
    -------
    torch.nn.Module
        A fully constructed generator (SRResNet/Flexible, ESRGAN, or StochasticGenerator).

    Raises
    ------
    ValueError
        If the model type or block variant cannot be resolved to a known implementation.

    Notes
    -----
    - Accepts legacy configs where `Generator.model_type` was a block alias (e.g., "rcab").
    - Non-applicable options are reported (not fatal) for clarity.
    """

    generator_cfg = config.Generator
    model_cfg = config.Model

    raw_model_type = str(getattr(generator_cfg, "model_type", "srresnet"))
    model_type = _match_alias(raw_model_type, _MODEL_TYPE_ALIASES)

    # Legacy support: configs that directly specified the block variant inside
    # ``model_type`` (e.g., "rcab") should still build the flexible generator.
    if model_type is None:
        srresnet_block = _match_alias(raw_model_type, _SRRESNET_BLOCK_ALIASES)
        if srresnet_block is not None:
            model_type = "srresnet"
        else:
            raise ValueError(
                "Unknown generator model type '{model_type}'. Expected one of: {options}.".format(
                    model_type=raw_model_type,
                    options=", ".join(sorted(_MODEL_TYPE_ALIASES)),
                )
            )
    else:
        srresnet_block = None

    in_channels = int(getattr(model_cfg, "in_bands"))
    scale = int(getattr(generator_cfg, "scaling_factor", 4))

    if model_type == "srresnet":
        large_kernel = int(getattr(generator_cfg, "large_kernel_size", 9))
        small_kernel = int(getattr(generator_cfg, "small_kernel_size", 3))
        n_channels = int(getattr(generator_cfg, "n_channels", 64))
        n_blocks = int(getattr(generator_cfg, "n_blocks", 16))

        block_variant = _resolve_srresnet_block(generator_cfg, srresnet_block)

        if block_variant == "standard":
            return SRResNetGenerator(
                in_channels=in_channels,
                large_kernel_size=large_kernel,
                small_kernel_size=small_kernel,
                n_channels=n_channels,
                n_blocks=n_blocks,
                scaling_factor=scale,
            )

        return FlexibleGenerator(
            in_channels=in_channels,
            n_channels=n_channels,
            n_blocks=n_blocks,
            small_kernel=small_kernel,
            large_kernel=large_kernel,
            scale=scale,
            block_type=block_variant,
        )

    if model_type == "stochastic_gan":
        large_kernel = int(getattr(generator_cfg, "large_kernel_size", 9))
        small_kernel = int(getattr(generator_cfg, "small_kernel_size", 3))
        n_channels = int(getattr(generator_cfg, "n_channels", 64))
        n_blocks = int(getattr(generator_cfg, "n_blocks", 16))
        noise_dim = int(getattr(generator_cfg, "noise_dim", 128))
        res_scale = float(getattr(generator_cfg, "res_scale", 0.2))

        _warn_overridden_options(
            "Generator",
            "stochastic_gan",
            _collect_overridden(generator_cfg, "block_type"),
        )

        return StochasticGenerator(
            in_channels=in_channels,
            n_channels=n_channels,
            n_blocks=n_blocks,
            small_kernel=small_kernel,
            large_kernel=large_kernel,
            scale=scale,
            noise_dim=noise_dim,
            res_scale=res_scale,
        )

    if model_type == "esrgan":
        n_channels = int(getattr(generator_cfg, "n_channels", 64))
        n_rrdb = int(getattr(generator_cfg, "n_blocks", 23))
        growth_channels = int(getattr(generator_cfg, "growth_channels", 32))
        res_scale = float(getattr(generator_cfg, "res_scale", 0.2))
        out_channels = int(getattr(generator_cfg, "out_channels", in_channels))

        _warn_overridden_options(
            "Generator",
            "esrgan",
            _collect_overridden(
                generator_cfg,
                "block_type",
                "large_kernel_size",
                "small_kernel_size",
                "noise_dim",
            ),
        )

        return ESRGANGenerator(
            in_channels=in_channels,
            out_channels=out_channels,
            n_features=n_channels,
            n_blocks=n_rrdb,
            growth_channels=growth_channels,
            res_scale=res_scale,
            scale=scale,
        )

    raise ValueError(
        "Unhandled generator model type '{model_type}'. Expected one of: {options}.".format(
            model_type=raw_model_type,
            options=", ".join(sorted(_MODEL_TYPE_ALIASES)),
        )
    )
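
A hedged end-to-end sketch of driving the factory from an in-memory OmegaConf config. The keys are exactly the fields read above; a real training config will carry more sections (Training, Optimizers, ...), and the "SRResNet"/"rcab" spellings are assumed to be accepted by the alias matching.

from omegaconf import OmegaConf
from opensr_srgan.model.generators.factory import build_generator

# Minimal config: only the fields build_generator reads for an SRResNet/RCAB variant.
cfg = OmegaConf.create({
    "Model": {"in_bands": 4},
    "Generator": {
        "model_type": "SRResNet",
        "block_type": "rcab",
        "n_channels": 64,
        "n_blocks": 16,
        "large_kernel_size": 9,
        "small_kernel_size": 3,
        "scaling_factor": 4,
    },
})
generator = build_generator(cfg)
print(type(generator).__name__)  # expected: FlexibleGenerator, since block_type is not "standard"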

make_upsampler(n_channels, scale)

Create a pixel-shuffle upsampler matching the flexible generator implementation.

Source code in opensr_srgan/model/model_blocks/__init__.py
def make_upsampler(n_channels: int, scale: int) -> nn.Sequential:
    """Create a pixel-shuffle upsampler matching the flexible generator implementation."""

    stages: list[nn.Module] = []
    for _ in range(int(math.log2(scale))):
        stages.extend(
            [
                nn.Conv2d(n_channels, n_channels * 4, 3, padding=1),
                nn.PixelShuffle(2),
                nn.PReLU(),
            ]
        )
    return nn.Sequential(*stages)
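
Each stage is Conv → PixelShuffle(2) → PReLU and doubles the resolution, so scale=4 yields two stages and scale=8 three; a quick check:

import torch
from opensr_srgan.model.model_blocks import make_upsampler

up = make_upsampler(n_channels=64, scale=4)
print(len(up))                  # 6 modules: 2 stages of (Conv2d, PixelShuffle, PReLU)
x = torch.randn(1, 64, 16, 16)
print(up(x).shape)              # torch.Size([1, 64, 64, 64])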

Discriminator architectures

Discriminator architectures for SRGAN.

Discriminator

Bases: Module

Standard SRGAN discriminator as defined in the original paper.

The network alternates between stride-1 and stride-2 convolutional blocks, gradually reducing spatial resolution while increasing channel depth. The resulting features are globally pooled and classified via two fully connected layers to produce a single realism score.

Parameters

  • in_channels (int, default=3): Number of input channels (e.g., 3 for RGB, 4 for multispectral SR).
  • n_blocks (int, default=8): Number of convolutional blocks to stack. Must be >= 1.
  • use_spectral_norm (bool, default=True): Whether to apply spectral normalization to convolutional and linear layers for improved Lipschitz control and training stability.

Attributes

  • conv_blocks (nn.Sequential): Sequential stack of convolutional feature extraction blocks.
  • adaptive_pool (nn.AdaptiveAvgPool2d): Global pooling layer reducing spatial dimensions to 6×6.
  • fc1 (nn.Linear): Fully connected layer mapping pooled features to a hidden representation.
  • fc2 (nn.Linear): Output layer producing a single realism score.
  • leaky_relu (nn.LeakyReLU): Activation used in the fully connected layers.
  • base_channels (int): Number of channels in the first convolutional layer (default 64).
  • kernel_size (int): Kernel size used in all convolutional blocks (default 3).
  • fc_size (int): Hidden dimension of the fully connected layer (default 1024).
  • n_blocks (int): Total number of convolutional blocks.
  • use_spectral_norm (bool): Whether spectral normalization wraps the convolutional and linear layers.

Raises

ValueError If n_blocks < 1.

Examples

>>> disc = Discriminator(in_channels=3, n_blocks=8)
>>> x = torch.randn(4, 3, 96, 96)
>>> y = disc(x)
>>> y.shape
torch.Size([4, 1])

Source code in opensr_srgan/model/discriminators/srgan_discriminator.py
class Discriminator(nn.Module):
    """
    Standard SRGAN discriminator as defined in the original paper.

    The network alternates between stride-1 and stride-2 convolutional
    blocks, gradually reducing spatial resolution while increasing
    channel depth. The resulting features are globally pooled and
    classified via two fully connected layers to produce a single
    realism score.

    Parameters
    ----------
    in_channels : int, default=3
        Number of input channels (e.g., 3 for RGB, 4 for multispectral SR).
    n_blocks : int, default=8
        Number of convolutional blocks to stack. Must be >= 1.
    use_spectral_norm : bool, default=True
        Whether to apply spectral normalization to convolutional and linear layers
        for improved Lipschitz control and training stability.

    Attributes
    ----------
    conv_blocks : nn.Sequential
        Sequential stack of convolutional feature extraction blocks.
    adaptive_pool : nn.AdaptiveAvgPool2d
        Global pooling layer reducing spatial dimensions to 6×6.
    fc1 : nn.Linear
        Fully connected layer mapping pooled features to hidden representation.
    fc2 : nn.Linear
        Output layer producing a single realism score.
    leaky_relu : nn.LeakyReLU
        Activation used in fully connected layers.
    base_channels : int
        Number of channels in the first convolutional layer (default 64).
    kernel_size : int
        Kernel size used in all convolutional blocks (default 3).
    fc_size : int
        Hidden dimension of the fully connected layer (default 1024).
    n_blocks : int
        Total number of convolutional blocks.
    use_spectral_norm : bool
        Indicates whether spectral normalization wraps the convolutional and linear layers.

    Raises
    ------
    ValueError
        If `n_blocks` < 1.

    Examples
    --------
    >>> disc = Discriminator(in_channels=3, n_blocks=8)
    >>> x = torch.randn(4, 3, 96, 96)
    >>> y = disc(x)
    >>> y.shape
    torch.Size([4, 1])
    """

    def __init__(
        self,
        in_channels: int = 3,
        n_blocks: int = 8,
        use_spectral_norm: bool = True,
    ) -> None:
        super().__init__()

        if n_blocks < 1:
            raise ValueError("The SRGAN discriminator requires at least one block.")

        kernel_size = 3
        base_channels = 64
        fc_size = 1024

        conv_blocks: list[nn.Module] = []
        current_in = in_channels
        out_channels = base_channels
        for i in range(n_blocks):
            if i == 0:
                out_channels = base_channels
            elif i % 2 == 0:
                out_channels = current_in * 2
            else:
                out_channels = current_in

            block = ConvolutionalBlock(
                in_channels=current_in,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1 if i % 2 == 0 else 2,
                batch_norm=i != 0,
                activation="LeakyReLu",
            )
            if use_spectral_norm:
                self._apply_spectral_norm(block)

            conv_blocks.append(block)
            current_in = out_channels

        self.conv_blocks = nn.Sequential(*conv_blocks)

        self.adaptive_pool = nn.AdaptiveAvgPool2d((6, 6))
        self.fc1 = nn.Linear(out_channels * 6 * 6, fc_size)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(fc_size, 1)

        if use_spectral_norm:
            self.fc1 = spectral_norm(self.fc1)
            self.fc2 = spectral_norm(self.fc2)

        self.base_channels = base_channels
        self.kernel_size = kernel_size
        self.fc_size = fc_size
        self.n_blocks = n_blocks
        self.use_spectral_norm = use_spectral_norm

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the SRGAN discriminator.

        Parameters
        ----------
        imgs : torch.Tensor
            Input image tensor of shape (B, C, H, W).

        Returns
        -------
        torch.Tensor
            Realism logits of shape (B, 1), where higher values
            indicate greater likelihood of being a real image.
        """
        batch_size = imgs.size(0)
        feats = self.conv_blocks(imgs)
        pooled = self.adaptive_pool(feats)
        flat = pooled.view(batch_size, -1)
        hidden = self.leaky_relu(self.fc1(flat))
        return self.fc2(hidden)

    @staticmethod
    def _apply_spectral_norm(module: nn.Module) -> None:
        for submodule in module.modules():
            if isinstance(submodule, (nn.Conv2d, nn.Linear)):
                spectral_norm(submodule)

forward(imgs)

Forward pass through the SRGAN discriminator.

Parameters

imgs : torch.Tensor Input image tensor of shape (B, C, H, W).

Returns

torch.Tensor Realism logits of shape (B, 1), where higher values indicate greater likelihood of being a real image.

Source code in opensr_srgan/model/discriminators/srgan_discriminator.py
def forward(self, imgs: torch.Tensor) -> torch.Tensor:
    """
    Forward pass through the SRGAN discriminator.

    Parameters
    ----------
    imgs : torch.Tensor
        Input image tensor of shape (B, C, H, W).

    Returns
    -------
    torch.Tensor
        Realism logits of shape (B, 1), where higher values
        indicate greater likelihood of being a real image.
    """
    batch_size = imgs.size(0)
    feats = self.conv_blocks(imgs)
    pooled = self.adaptive_pool(feats)
    flat = pooled.view(batch_size, -1)
    hidden = self.leaky_relu(self.fc1(flat))
    return self.fc2(hidden)

PatchGANDiscriminator

Bases: Module

High-level convenience wrapper for the N-layer PatchGAN discriminator.

Parameters

  • input_nc (int): Number of input channels.
  • n_layers (int, default=3): Number of convolutional layers.
  • norm_type ({"batch", "instance", "none"}, default="instance"): Normalization strategy.

Attributes

  • model (NLayerDiscriminator): Underlying PatchGAN model.
  • base_channels (int): Number of base feature channels (default 64).
  • kernel_size (int): Kernel size for convolutions (default 4).
  • n_layers (int): Number of downsampling layers.

Raises

ValueError If n_layers < 1.

Examples

>>> disc = PatchGANDiscriminator(input_nc=3, n_layers=3)
>>> x = torch.randn(4, 3, 128, 128)
>>> y = disc(x)
>>> y.shape
torch.Size([4, 1, 14, 14])

Source code in opensr_srgan/model/discriminators/patchgan.py
class PatchGANDiscriminator(nn.Module):
    """
    High-level convenience wrapper for the N-layer PatchGAN discriminator.

    Parameters
    ----------
    input_nc : int
        Number of input channels.
    n_layers : int, default=3
        Number of convolutional layers.
    norm_type : {"batch", "instance", "none"}, default="instance"
        Normalization strategy.

    Attributes
    ----------
    model : NLayerDiscriminator
        Underlying PatchGAN model.
    base_channels : int
        Number of base feature channels (default 64).
    kernel_size : int
        Kernel size for convolutions (default 4).
    n_layers : int
        Number of downsampling layers.

    Raises
    ------
    ValueError
        If `n_layers` < 1.

    Examples
    --------
    >>> disc = PatchGANDiscriminator(input_nc=3, n_layers=3)
    >>> x = torch.randn(4, 3, 128, 128)
    >>> y = disc(x)
    >>> y.shape
    torch.Size([4, 1, 14, 14])
    """
    def __init__(
        self,
        input_nc: int,
        n_layers: int = 3,
        norm_type: str = "instance",
    ) -> None:
        super().__init__()

        if n_layers < 1:
            raise ValueError("PatchGAN discriminator requires at least one layer.")

        ndf = 64
        norm_layer = get_norm_layer(norm_type)
        self.model = NLayerDiscriminator(
            input_nc, ndf=ndf, n_layers=n_layers, norm_layer=norm_layer
        )

        self.base_channels = ndf
        self.kernel_size = 4
        self.n_layers = n_layers

    def forward(self, input):  # type: ignore[override]
        """
        Forward pass through the PatchGAN discriminator.

        Parameters
        ----------
        input : torch.Tensor
            Input image tensor of shape (B, C, H, W).

        Returns
        -------
        torch.Tensor
            Patch-level realism scores, shape (B, 1, H/2ⁿ, W/2ⁿ).
        """
        return self.model(input)

forward(input)

Forward pass through the PatchGAN discriminator.

Parameters

input : torch.Tensor Input image tensor of shape (B, C, H, W).

Returns

torch.Tensor Patch-level realism scores, shape (B, 1, H/2ⁿ, W/2ⁿ).

Source code in opensr_srgan/model/discriminators/patchgan.py
def forward(self, input):  # type: ignore[override]
    """
    Forward pass through the PatchGAN discriminator.

    Parameters
    ----------
    input : torch.Tensor
        Input image tensor of shape (B, C, H, W).

    Returns
    -------
    torch.Tensor
        Patch-level realism scores, shape (B, 1, H/2ⁿ, W/2ⁿ).
    """
    return self.model(input)

ESRGANDiscriminator

Bases: Module

VGG-style discriminator network used in ESRGAN.

This discriminator processes high-resolution image patches and predicts a scalar realism score. It follows the VGG-style design with progressively increasing channel depth and downsampling via strided convolutions.

Parameters

  • in_channels (int, default=3): Number of input channels (e.g., 3 for RGB, 4 for RGB-NIR).
  • base_channels (int, default=64): Number of channels in the first convolutional layer; subsequent layers scale as powers of two.
  • linear_size (int, default=1024): Size of the intermediate fully-connected layer before the output scalar.

Attributes

  • features (nn.Sequential): Convolutional feature extractor backbone.
  • pool (nn.AdaptiveAvgPool2d): Global pooling layer to aggregate spatial features.
  • classifier (nn.Sequential): Fully connected layers producing a single output value.
  • n_layers (int): Total number of convolutional blocks (for metadata/reference only).

Raises

ValueError If base_channels or linear_size is not a positive integer.

Examples

>>> disc = ESRGANDiscriminator(in_channels=3)
>>> x = torch.randn(8, 3, 128, 128)
>>> y = disc(x)
>>> y.shape
torch.Size([8, 1])

Source code in opensr_srgan/model/discriminators/esrgan.py
class ESRGANDiscriminator(nn.Module):
    """
    VGG-style discriminator network used in ESRGAN.

    This discriminator processes high-resolution image patches and predicts
    a scalar realism score. It follows the VGG-style design with progressively
    increasing channel depth and downsampling via strided convolutions.

    Parameters
    ----------
    in_channels : int, default=3
        Number of input channels (e.g., 3 for RGB, 4 for RGB-NIR).
    base_channels : int, default=64
        Number of channels in the first convolutional layer; subsequent layers
        scale as powers of two.
    linear_size : int, default=1024
        Size of the intermediate fully-connected layer before the output scalar.

    Attributes
    ----------
    features : nn.Sequential
        Convolutional feature extractor backbone.
    pool : nn.AdaptiveAvgPool2d
        Global pooling layer to aggregate spatial features.
    classifier : nn.Sequential
        Fully connected layers producing a single output value.
    n_layers : int
        Total number of convolutional blocks (for metadata/reference only).

    Raises
    ------
    ValueError
        If `base_channels` or `linear_size` is not a positive integer.

    Examples
    --------
    >>> disc = ESRGANDiscriminator(in_channels=3)
    >>> x = torch.randn(8, 3, 128, 128)
    >>> y = disc(x)
    >>> y.shape
    torch.Size([8, 1])
    """
    def __init__(
        self,
        *,
        in_channels: int = 3,
        base_channels: int = 64,
        linear_size: int = 1024,
    ) -> None:
        super().__init__()

        if base_channels <= 0:
            raise ValueError("base_channels must be a positive integer.")

        if linear_size <= 0:
            raise ValueError("linear_size must be a positive integer.")

        features: list[nn.Module] = [
            nn.Conv2d(in_channels, base_channels, 3, 1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            _conv_block(base_channels, base_channels, kernel_size=4, stride=2),
            _conv_block(base_channels, base_channels * 2, kernel_size=3, stride=1),
            _conv_block(base_channels * 2, base_channels * 2, kernel_size=4, stride=2),
            _conv_block(base_channels * 2, base_channels * 4, kernel_size=3, stride=1),
            _conv_block(base_channels * 4, base_channels * 4, kernel_size=4, stride=2),
            _conv_block(base_channels * 4, base_channels * 8, kernel_size=3, stride=1),
            _conv_block(base_channels * 8, base_channels * 8, kernel_size=4, stride=2),
            _conv_block(base_channels * 8, base_channels * 16, kernel_size=3, stride=1),
            _conv_block(base_channels * 16, base_channels * 16, kernel_size=4, stride=2),
        ]

        self.features = nn.Sequential(*features)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(base_channels * 16, linear_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(linear_size, 1),
        )

        self.base_channels = base_channels
        self.linear_size = linear_size
        self.n_layers = 1 + 10  # first conv + stacked blocks

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the ESRGAN discriminator.

        Parameters
        ----------
        x : torch.Tensor
            Input image tensor of shape (B, C, H, W).

        Returns
        -------
        torch.Tensor
            Discriminator logits of shape (B, 1), where higher values
            indicate more realistic images.
        """
        feats = self.features(x)
        pooled = self.pool(feats).view(x.size(0), -1)
        return self.classifier(pooled)

forward(x)

Forward pass through the ESRGAN discriminator.

Parameters

x : torch.Tensor Input image tensor of shape (B, C, H, W).

Returns

torch.Tensor Discriminator logits of shape (B, 1), where higher values indicate more realistic images.

Source code in opensr_srgan/model/discriminators/esrgan.py
def forward(self, x: Tensor) -> Tensor:
    """
    Forward pass through the ESRGAN discriminator.

    Parameters
    ----------
    x : torch.Tensor
        Input image tensor of shape (B, C, H, W).

    Returns
    -------
    torch.Tensor
        Discriminator logits of shape (B, 1), where higher values
        indicate more realistic images.
    """
    feats = self.features(x)
    pooled = self.pool(feats).view(x.size(0), -1)
    return self.classifier(pooled)

Loss components

GeneratorContentLoss

Bases: Module

Composite generator content loss with perceptual metric selection.

Combines multiple terms to form the generator's content objective: total = l1_w * L1 + sam_w * SAM + perc_w * Perceptual + tv_w * TV. Also computes auxiliary quality metrics (PSNR/SSIM) for logging/evaluation.

Loss weights, max value, window sizes, band selection settings, and the perceptual backend (VGG or LPIPS) are read from the provided config.

Parameters:

  • cfg: Configuration object (OmegaConf/dict-like) with fields under:
      - Training.Losses.{l1_weight, sam_weight, perceptual_weight, tv_weight}
      - Training.Losses.{max_val, ssim_win, randomize_bands, fixed_idx}
      - Training.Losses.perceptual_metric in {"vgg", "lpips"}
      - TruncatedVGG.{i, j} (if VGG is used)
  • testing (bool, default=False): If True, do not load pretrained VGG weights (avoids downloads in CI/tests).

Attributes:

  • l1_w, sam_w, perc_w, tv_w (float): Loss term weights.
  • max_val (float): Dynamic range for PSNR/SSIM.
  • ssim_win (int): Window size for SSIM.
  • randomize_bands (bool): Whether to sample random bands for perceptual loss.
  • fixed_idx (LongTensor | None): Fixed 3-channel indices if not randomized.
  • perc_metric (str): Selected perceptual backend ("vgg" or "lpips").
  • perceptual_model (Module): Back-end feature network/metric.
  • normalizer (Normalizer): Shared normalizer for evaluation metrics.

Source code in opensr_srgan/model/loss/loss.py
class GeneratorContentLoss(nn.Module):
    """Composite generator content loss with perceptual metric selection.

    Combines multiple terms to form the generator's content objective:
    ``total = l1_w * L1 + sam_w * SAM + perc_w * Perceptual + tv_w * TV``.
    Also computes auxiliary quality metrics (PSNR/SSIM) for logging/evaluation.

    Loss weights, max value, window sizes, band selection settings, and the
    perceptual backend (VGG or LPIPS) are read from the provided config.

    Args:
        cfg: Configuration object (OmegaConf/dict-like) with fields under:
            - ``Training.Losses.{l1_weight,sam_weight,perceptual_weight,tv_weight}``
            - ``Training.Losses.{max_val,ssim_win,randomize_bands,fixed_idx}``
            - ``Training.Losses.perceptual_metric`` in {``"vgg"``, ``"lpips"``}
            - ``TruncatedVGG.{i,j}`` (if VGG is used)
        testing (bool, optional): If True, do not load pretrained VGG weights
            (avoids downloads in CI/tests). Defaults to False.

    Attributes:
        l1_w, sam_w, perc_w, tv_w (float): Loss term weights.
        max_val (float): Dynamic range for PSNR/SSIM.
        ssim_win (int): Window size for SSIM.
        randomize_bands (bool): Whether to sample random bands for perceptual loss.
        fixed_idx (torch.LongTensor|None): Fixed 3-channel indices if not randomized.
        perc_metric (str): Selected perceptual backend (``"vgg"`` or ``"lpips"``).
        perceptual_model (nn.Module): Back-end feature network/metric.
        normalizer (Normalizer): Shared normalizer for evaluation metrics.
    """

    def __init__(self, cfg, testing=False):
        super().__init__()
        self.cfg = cfg

        # --- weights & settings from config ---
        # (fallback to deprecated Training.perceptual_loss_weight if needed)
        self.l1_w = float(_cfg_get(cfg, ["Training", "Losses", "l1_weight"], 1.0))
        self.sam_w = float(_cfg_get(cfg, ["Training", "Losses", "sam_weight"], 0.05))
        perc_w_cfg = _cfg_get(
            cfg,
            ["Training", "Losses", "perceptual_weight"],
            _cfg_get(cfg, ["Training", "perceptual_loss_weight"], 0.1),
        )
        self.perc_w = float(perc_w_cfg)
        self.tv_w = float(_cfg_get(cfg, ["Training", "Losses", "tv_weight"], 0.0))

        self.max_val = float(_cfg_get(cfg, ["Training", "Losses", "max_val"], 1.0))
        self.ssim_win = int(_cfg_get(cfg, ["Training", "Losses", "ssim_win"], 11))

        fixed_idx = _cfg_get(cfg, ["Training", "Losses", "fixed_idx"], None)
        if fixed_idx is not None:
            fixed_idx = torch.as_tensor(fixed_idx, dtype=torch.long)
            assert fixed_idx.numel() == 3, "fixed_idx must have length 3"
        self.register_buffer(
            "fixed_idx", fixed_idx if fixed_idx is not None else None, persistent=False
        )

        # Only init model if perceptual weight > 0
        if self.perc_w != 0.0:
            # --- configure perceptual metric ---
            self.perc_metric = str(
                _cfg_get(cfg, ["Training", "Losses", "perceptual_metric"], "vgg")
            ).lower()

            if self.perc_metric == "vgg":
                from .vgg import TruncatedVGG19

                i = int(_cfg_get(cfg, ["TruncatedVGG", "i"], 5))
                j = int(_cfg_get(cfg, ["TruncatedVGG", "j"], 4))
                self.perceptual_model = TruncatedVGG19(i=i, j=j, weights=not testing)
            elif self.perc_metric == "lpips":
                import lpips

                self.perceptual_model = lpips.LPIPS(net="alex")
            else:
                raise ValueError(f"Unsupported perceptual metric: {self.perc_metric}")

            # Set to eval and freeze params
            for p in self.perceptual_model.parameters():
                p.requires_grad = False
            self.perceptual_model.eval()
        else:  # Set to None when no perc. wanted
            self.perceptual_model = None
            self.perc_metric = None

        # Shared normalizer for computing evaluation metrics
        self.normalizer = Normalizer(cfg)

    # ---------- public API ----------
    def return_loss(
        self, sr: torch.Tensor, hr: torch.Tensor
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Compute the weighted content loss and return raw component metrics.

        Builds the autograd graph for terms with non-zero weights and returns the
        scalar total along with a dict of unweighted component values.

        Args:
            sr (torch.Tensor): Super-resolved prediction, shape ``(B, C, H, W)``.
            hr (torch.Tensor): High-resolution target, shape ``(B, C, H, W)``.

        Returns:
            Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
                - Total loss tensor (requires grad).
                - Dict with component tensors: ``{"l1","sam","perceptual","tv", "psnr","ssim"}``
                (component values detached except the ones used in the graph).
        """
        comps = self._compute_components(sr, hr, build_graph=True)
        loss = (
            self.l1_w * comps["l1"]
            + self.sam_w * comps["sam"]
            + self.perc_w * comps["perceptual"]
            + self.tv_w * comps["tv"]
        )
        loss = _ensure_finite(loss)
        metrics = {k: v.detach() for k, v in comps.items()}
        return loss, metrics

    @torch.no_grad()
    def return_metrics(
        self, sr: torch.Tensor, hr: torch.Tensor, prefix: str = ""
    ) -> dict[str, torch.Tensor]:
        """
        Compute all unweighted metric components and (optionally) prefix their keys.

        Args:
            sr, hr: tensors in the same range as the generator output/HR targets.
            prefix: key prefix like 'train/' or 'val'. If non-empty and doesn't end
                    with '/', a '/' is added automatically.

        Returns:
            dict mapping metric names -> tensors (detached), e.g. {'train/l1': ...}.
            Includes raw PSNR/SSIM metrics computed on stretched/clipped inputs.
        """
        comps = self._compute_components(sr, hr, build_graph=False)
        p = (prefix + "/") if prefix and not prefix.endswith("/") else prefix
        return {f"{p}{k}": v.detach() for k, v in comps.items()}

    # ---------- internals ----------
    @staticmethod
    def _tv_loss(x: torch.Tensor) -> torch.Tensor:
        """Total variation (TV) regularizer.

        Computes the L1 norm of first-order finite differences along height and width.

        Args:
            x (torch.Tensor): Input tensor, shape ``(B, C, H, W)``.

        Returns:
            torch.Tensor: Scalar TV loss (mean over batch/channels/pixels).
        """
        dh = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()
        dw = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
        return _ensure_finite(dh + dw)

    @staticmethod
    def _sam_loss(
        sr: torch.Tensor, hr: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        """Spectral Angle Mapper (SAM) in radians.

        Flattens spatial dims and computes the mean angle between spectral vectors
        of SR and HR across all pixels.

        Args:
            sr (torch.Tensor): Super-resolved tensor, shape ``(B, C, H, W)``.
            hr (torch.Tensor): Target tensor, shape ``(B, C, H, W)``.
            eps (float, optional): Numerical stability epsilon. Defaults to ``1e-8``.

        Returns:
            torch.Tensor: Scalar mean SAM (radians).
        """
        B, C, H, W = sr.shape
        sr_f = sr.view(B, C, -1)
        hr_f = hr.view(B, C, -1)
        dot = (sr_f * hr_f).sum(dim=1)
        if torch.is_floating_point(sr):
            dtype_eps = torch.finfo(sr.dtype).eps
        else:
            dtype_eps = torch.finfo(torch.float32).eps
        eps = max(eps, dtype_eps)
        sr_n = sr_f.norm(dim=1).clamp_min(eps)
        hr_n = hr_f.norm(dim=1).clamp_min(eps)
        denom = torch.clamp(sr_n * hr_n, min=eps)
        cos = torch.clamp(dot / denom, -1 + 1e-7, 1 - 1e-7)
        ang = torch.acos(cos)
        return _ensure_finite(ang.mean())

    def _prepare_perceptual_input(
        self, sr: torch.Tensor, hr: torch.Tensor
    ) -> torch.Tensor:
        """Select three channels for perceptual computation.

        If the input has exactly 3 channels, returns them unchanged. Otherwise,
        selects either random unique indices  or
        the fixed indices stored in ``self.fixed_idx``.

        Args:
            x (torch.Tensor): Input tensor, shape ``(B, C, H, W)``.

        Returns:
            torch.Tensor: Tensor with three channels, shape ``(B, 3, H, W)``.
        """
        B, C, H, W = sr.shape
        if C == 1:
            # repeat single channel 3 times
            sr = sr.repeat(1, 3, 1, 1)
            hr = hr.repeat(1, 3, 1, 1)
            return sr, hr
        elif C == 2:
            # repeat first channel to make 3
            sr = torch.cat([sr, sr[:, :1, :, :]], dim=1)
            hr = torch.cat([hr, hr[:, :1, :, :]], dim=1)
            return sr, hr
        elif C == 3:
            # already 3 channels, return as is
            return sr, hr
        else:
            # when over 3 channels, randomly select 3 unique indices
            idx = torch.randperm(C, device=sr.device)[:3]
            sr = sr[:, idx, :, :]
            hr = hr[:, idx, :, :]
            return sr, hr

    def _perceptual_distance(
        self, sr_3: torch.Tensor, hr_3: torch.Tensor, *, build_graph: bool
    ) -> torch.Tensor:
        """Compute perceptual distance between SR and HR (3-channel inputs).

        Uses the configured backend:
            - ``"vgg"``: MSE between intermediate VGG features.
            - ``"lpips"``: Learned LPIPS distance (expects inputs in [-1, 1]).

        The computation detaches HR features and optionally detaches SR path if
        ``build_graph`` is False or the perceptual weight is zero.

        Args:
            sr_3 (torch.Tensor): SR slice with 3 channels, shape ``(B, 3, H, W)``, values in [0, 1].
            hr_3 (torch.Tensor): HR slice with 3 channels, shape ``(B, 3, H, W)``, values in [0, 1].
            build_graph (bool): Whether to keep gradients for SR.

        Returns:
            torch.Tensor: Scalar perceptual distance (mean over batch/spatial).
        """
        requires_grad = build_graph and self.perc_w != 0.0

        if self.perc_metric == "vgg":
            if requires_grad:
                sr_features = self.perceptual_model(sr_3)
            else:
                with torch.no_grad():
                    sr_features = self.perceptual_model(sr_3)
            with torch.no_grad():
                hr_features = self.perceptual_model(hr_3)
            distance = F.mse_loss(sr_features, hr_features)
        elif self.perc_metric == "lpips":
            sr_norm = sr_3.mul(2.0).sub(1.0)
            hr_norm = hr_3.mul(2.0).sub(1.0).detach()
            if requires_grad:
                distance = self.perceptual_model(sr_norm, hr_norm)
            else:
                with torch.no_grad():
                    distance = self.perceptual_model(sr_norm, hr_norm)
            distance = distance.mean()
        else:
            raise RuntimeError(f"Unhandled perceptual metric: {self.perc_metric}")

        if not requires_grad:
            distance = distance.detach()
        return _ensure_finite(distance)

    def _compute_components(
        self, sr: torch.Tensor, hr: torch.Tensor, *, build_graph: bool
    ) -> dict[str, torch.Tensor]:
        """Compute individual content components and auxiliary quality metrics.

        Produces a dictionary with: L1, SAM, Perceptual, TV (optionally with grads),
        and PSNR/SSIM (always without grads). Per-component autograd is enabled only
        if ``build_graph`` is True and the corresponding weight is non-zero.

        Args:
            sr (torch.Tensor): Super-resolved prediction, shape ``(B, C, H, W)``.
            hr (torch.Tensor): High-resolution target, shape ``(B, C, H, W)``.
            build_graph (bool): Whether to allow gradients for weighted components.

        Returns:
            Dict[str, torch.Tensor]: Keys ``{"l1","sam","perceptual","tv","psnr","ssim"}``.
            Component tensors are scalar means; PSNR/SSIM are detached.
        """
        comps: dict[str, torch.Tensor] = {}

        def _compute(weight: float, fn) -> torch.Tensor:
            requires_grad = build_graph and weight != 0.0
            if requires_grad:
                return fn()
            with torch.no_grad():
                return fn().detach()

        # Core reconstruction metrics (always unweighted)
        comps["l1"] = _compute(self.l1_w, lambda: _ensure_finite(F.l1_loss(sr, hr)))
        comps["sam"] = _compute(self.sam_w, lambda: self._sam_loss(sr, hr))

        # Perceptual distance on 3 selected bands
        if self.perceptual_model is not None and self.perc_w != 0.0:
            sr_3, hr_3 = self._prepare_perceptual_input(sr=sr, hr=hr)
            comps["perceptual"] = self._perceptual_distance(
                sr_3, hr_3, build_graph=build_graph
            )
        else:
            comps["perceptual"] = torch.tensor(0.0, device=sr.device)

        # Total variation
        comps["tv"] = _compute(self.tv_w, lambda: self._tv_loss(sr))

        # --- Quality metrics ---
        with torch.no_grad():
            safe_max_val = max(self.max_val, LOSS_EPS)
            sr_metric = torch.clamp(sr, 0.0, safe_max_val)
            hr_metric = torch.clamp(hr, 0.0, safe_max_val)
            psnr = km.psnr(sr_metric, hr_metric, max_val=safe_max_val)
            ssim = km.ssim(
                sr_metric,
                hr_metric,
                window_size=self.ssim_win,
                max_val=safe_max_val,
            )

            if psnr.dim() > 0:
                psnr = psnr.mean()
            if ssim.dim() > 0:
                ssim = ssim.mean()

            comps["psnr"] = _ensure_finite(psnr).detach()
            comps["ssim"] = _ensure_finite(ssim).detach()

        return comps
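
To make the weighting explicit without pulling in the VGG/LPIPS backends or the config machinery, the sketch below recomputes the L1, SAM, and TV terms with plain torch ops and combines them exactly as return_loss does. The weights are the config defaults quoted above; this is an illustration of the objective, not the class itself.

import torch
import torch.nn.functional as F

def sam(sr, hr, eps=1e-8):
    # Spectral Angle Mapper: mean angle (radians) between per-pixel spectra.
    b, c, _, _ = sr.shape
    sr_f, hr_f = sr.view(b, c, -1), hr.view(b, c, -1)
    cos = (sr_f * hr_f).sum(1) / (sr_f.norm(dim=1) * hr_f.norm(dim=1)).clamp_min(eps)
    return torch.acos(cos.clamp(-1 + 1e-7, 1 - 1e-7)).mean()

def tv(x):
    # Total variation: L1 norm of finite differences along H and W.
    return (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean() + \
           (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()

sr = torch.rand(2, 4, 32, 32, requires_grad=True)
hr = torch.rand(2, 4, 32, 32)

l1_w, sam_w, perc_w, tv_w = 1.0, 0.05, 0.1, 0.0   # defaults read from the config above
total = l1_w * F.l1_loss(sr, hr) + sam_w * sam(sr, hr) + tv_w * tv(sr)
# The perceptual term would add perc_w * distance(sr_3, hr_3) on three selected bands.
total.backward()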

return_loss(sr, hr)

Compute the weighted content loss and return raw component metrics.

Builds the autograd graph for terms with non-zero weights and returns the scalar total along with a dict of unweighted component values.

Parameters:

  • sr (Tensor): Super-resolved prediction, shape (B, C, H, W). Required.
  • hr (Tensor): High-resolution target, shape (B, C, H, W). Required.

Returns:

  • tuple[Tensor, dict[str, Tensor]]:
      - Total loss tensor (requires grad).
      - Dict with component tensors: {"l1", "sam", "perceptual", "tv", "psnr", "ssim"} (component values detached except the ones used in the graph).

Source code in opensr_srgan/model/loss/loss.py
def return_loss(
    self, sr: torch.Tensor, hr: torch.Tensor
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Compute the weighted content loss and return raw component metrics.

    Builds the autograd graph for terms with non-zero weights and returns the
    scalar total along with a dict of unweighted component values.

    Args:
        sr (torch.Tensor): Super-resolved prediction, shape ``(B, C, H, W)``.
        hr (torch.Tensor): High-resolution target, shape ``(B, C, H, W)``.

    Returns:
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
            - Total loss tensor (requires grad).
            - Dict with component tensors: ``{"l1","sam","perceptual","tv", "psnr","ssim"}``
            (component values detached except the ones used in the graph).
    """
    comps = self._compute_components(sr, hr, build_graph=True)
    loss = (
        self.l1_w * comps["l1"]
        + self.sam_w * comps["sam"]
        + self.perc_w * comps["perceptual"]
        + self.tv_w * comps["tv"]
    )
    loss = _ensure_finite(loss)
    metrics = {k: v.detach() for k, v in comps.items()}
    return loss, metrics

return_metrics(sr, hr, prefix='')

Compute all unweighted metric components and (optionally) prefix their keys.

Parameters:

  • sr, hr (Tensor): Tensors in the same range as the generator output/HR targets. Required.
  • prefix (str, default=''): Key prefix like 'train/' or 'val'. If non-empty and it doesn't end with '/', a '/' is added automatically.

Returns:

  • dict[str, Tensor]: Mapping of metric names to detached tensors, e.g. {'train/l1': ...}. Includes raw PSNR/SSIM metrics computed on stretched/clipped inputs.

Source code in opensr_srgan/model/loss/loss.py
@torch.no_grad()
def return_metrics(
    self, sr: torch.Tensor, hr: torch.Tensor, prefix: str = ""
) -> dict[str, torch.Tensor]:
    """
    Compute all unweighted metric components and (optionally) prefix their keys.

    Args:
        sr, hr: tensors in the same range as the generator output/HR targets.
        prefix: key prefix like 'train/' or 'val'. If non-empty and doesn't end
                with '/', a '/' is added automatically.

    Returns:
        dict mapping metric names -> tensors (detached), e.g. {'train/l1': ...}.
        Includes raw PSNR/SSIM metrics computed on stretched/clipped inputs.
    """
    comps = self._compute_components(sr, hr, build_graph=False)
    p = (prefix + "/") if prefix and not prefix.endswith("/") else prefix
    return {f"{p}{k}": v.detach() for k, v in comps.items()}

VGG-based perceptual feature extractor utilities.

TruncatedVGG19

Bases: Module

A truncated VGG-19 feature extractor for perceptual loss computation.

This class wraps a pretrained VGG-19 network from torchvision.models and truncates it at a specific convolutional layer (i, j) within the feature hierarchy, following the convention in perceptual and style-transfer literature (e.g., layer relu{i}_{j}).

The truncated model outputs intermediate feature maps that capture perceptual similarity more effectively than raw pixel losses. These feature activations are typically used in content or perceptual loss terms such as:

L_perceptual = || Φ_j(x_sr) − Φ_j(x_hr) ||_1
where Φ_j denotes the truncated VGG feature extractor.

Parameters:

Name Type Description Default
i int

The convolutional block index (1-5) at which to truncate. Each block corresponds to a region between max-pooling layers. Defaults to 5.

5
j int

The convolution layer index within the chosen block. Defaults to 4.

4
weights bool

Whether to load pretrained ImageNet weights. If False, the model is initialized without downloading weights (useful for testing environments). Defaults to True.

True

Attributes:

Name Type Description
truncated_vgg19 Sequential

Sequential container of layers up to the specified truncation point.

Raises:

Type Description
AssertionError

If the provided (i, j) combination does not match a valid convolutional layer in VGG-19.

Example

vgg = TruncatedVGG19(i=5, j=4)
feats = vgg(img_batch)  # [B, C, H, W] feature map

Source code in opensr_srgan/model/loss/vgg.py
class TruncatedVGG19(nn.Module):
    """A truncated VGG-19 feature extractor for perceptual loss computation.

    This class wraps a pretrained VGG-19 network from ``torchvision.models``
    and truncates it at a specific convolutional layer ``(i, j)`` within
    the feature hierarchy, following the convention in perceptual and
    style-transfer literature (e.g., layer *relu{i}_{j}*).

    The truncated model outputs intermediate feature maps that capture
    perceptual similarity more effectively than raw pixel losses. These
    feature activations are typically used in content or perceptual loss
    terms such as:
    ```
    L_perceptual = || Φ_j(x_sr) − Φ_j(x_hr) ||_1
    ```
    where ``Φ_j`` denotes the truncated VGG feature extractor.

    Args:
        i (int, optional): The convolutional block index (1-5) at which to
            truncate. Each block corresponds to a region between max-pooling
            layers. Defaults to ``5``.
        j (int, optional): The convolution layer index within the chosen block.
            Defaults to ``4``.
        weights (bool, optional): Whether to load pretrained ImageNet weights.
            If ``False``, the model is initialized without downloading weights
            (useful for testing environments). Defaults to ``True``.

    Attributes:
        truncated_vgg19 (nn.Sequential): Sequential container of layers up to
        the specified truncation point.

    Raises:
        AssertionError: If the provided ``(i, j)`` combination does not match a
            valid convolutional layer in VGG-19.

    Example:
        >>> vgg = TruncatedVGG19(i=5, j=4)
        >>> feats = vgg(img_batch)  # [B, C, H, W] feature map
    """
    def __init__(self, i: int = 5, j: int = 4, weights=True) -> None:
        super().__init__()

        if weights:  # set weights=False to skip downloading pretrained weights (e.g., in tests)
            vgg19 = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT)
        else:
            vgg19 = torchvision.models.vgg19(weights=None)

        maxpool_counter = 0
        conv_counter = 0
        truncate_at = 0
        for layer in vgg19.features.children():
            truncate_at += 1

            if isinstance(layer, nn.Conv2d):
                conv_counter += 1
            if isinstance(layer, nn.MaxPool2d):
                maxpool_counter += 1
                conv_counter = 0

            if maxpool_counter == i - 1 and conv_counter == j:
                break

        if not (maxpool_counter == i - 1 and conv_counter == j):
            raise AssertionError(
                f"One or both of i={i} and j={j} are not valid choices for the VGG19!"
            )

        self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[: truncate_at + 1])

    def forward(self, input):  # type: ignore[override]
        """Compute VGG-19 features up to the configured truncation layer.

        Args:
            input (torch.Tensor): Input tensor of shape ``(B, 3, H, W)``
                with values normalized to ImageNet statistics (mean/std).

        Returns:
            torch.Tensor: The feature map extracted from the specified
            intermediate layer of VGG-19.
        """
        return self.truncated_vgg19(input)

forward(input)

Compute VGG-19 features up to the configured truncation layer.

Parameters:

Name Type Description Default
input Tensor

Input tensor of shape (B, 3, H, W) with values normalized to ImageNet statistics (mean/std).

required

Returns:

Type Description

torch.Tensor: The feature map extracted from the specified

intermediate layer of VGG-19.

Source code in opensr_srgan/model/loss/vgg.py
def forward(self, input):  # type: ignore[override]
    """Compute VGG-19 features up to the configured truncation layer.

    Args:
        input (torch.Tensor): Input tensor of shape ``(B, 3, H, W)``
            with values normalized to ImageNet statistics (mean/std).

    Returns:
        torch.Tensor: The feature map extracted from the specified
        intermediate layer of VGG-19.
    """
    return self.truncated_vgg19(input)
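
To use the extractor for a perceptual term, the 3-band inputs are normalized with the standard ImageNet statistics before the forward pass. The sketch below assumes inputs already scaled to [0, 1] and imports the class from the module path shown above; it is an illustration, not a drop-in for the model's content loss.

import torch
import torch.nn.functional as F
from opensr_srgan.model.loss.vgg import TruncatedVGG19

vgg = TruncatedVGG19(i=5, j=4).eval()   # pass weights=False to skip the download
for p in vgg.parameters():
    p.requires_grad_(False)             # keep the feature extractor frozen

# Standard ImageNet normalization statistics.
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def perceptual_l1(sr_rgb: torch.Tensor, hr_rgb: torch.Tensor) -> torch.Tensor:
    # ||phi_j(x_sr) - phi_j(x_hr)||_1 on truncated VGG-19 features.
    sr_feats = vgg((sr_rgb - mean) / std)
    hr_feats = vgg((hr_rgb - mean) / std)
    return F.l1_loss(sr_feats, hr_feats)

sr = torch.rand(1, 3, 96, 96)  # toy 3-band patches in [0, 1]
hr = torch.rand(1, 3, 96, 96)
print(perceptual_l1(sr, hr))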

Training step helpers

training_step_PL1(self, batch, batch_idx, optimizer_idx)

One training step for PL < 2.0 using automatic optimization and multi-optimizers.

Implements GAN training with two optimizers (D first, then G) and a pretraining gate. During the pretraining phase, only the generator (optimizer_idx == 1) is optimized with content loss; the discriminator branch returns a dummy loss and logs zeros. During adversarial training, the discriminator minimizes BCE on real HR vs. fake SR logits, and the generator minimizes content loss plus a ramped adversarial loss.

Parameters:

Name Type Description Default
batch Tuple[Tensor, Tensor]

(lr_imgs, hr_imgs) with shape (B, C, H, W).

required
batch_idx int

Global batch index for the current epoch.

required
optimizer_idx int

Active optimizer index provided by Lightning: 0 = Discriminator step, 1 = Generator step.

required

Returns:

Type Description

torch.Tensor:
  • Pretraining:
    - optimizer_idx == 1: content loss tensor for the generator.
    - optimizer_idx == 0: dummy scalar tensor with requires_grad=True.
  • Adversarial training:
    - optimizer_idx == 0: discriminator BCE loss (real + fake).
    - optimizer_idx == 1: generator total loss = content + λ_adv · BCE(G).

Logged Metrics (selection):
  • "training/pretrain_phase": 1.0 during pretraining (logged on G step).
  • "train_metrics/*": content metrics from the content loss criterion.
  • "generator/content_loss", "generator/adversarial_loss", "generator/total_loss".
  • "discriminator/adversarial_loss", "discriminator/D(y)_prob", "discriminator/D(G(x))_prob".
  • "training/adv_loss_weight": current λ_adv from the ramp scheduler.

Notes
  • Discriminator step uses sr_imgs.detach() to prevent G gradients.
  • Adversarial loss weight λ_adv ramps from 0 → adv_loss_beta per the configured schedule (see the sketch below).
  • Assumes optimizers are ordered as [D, G] in configure_optimizers().
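
The actual λ_adv is produced by the model's ramp helper (see _compute_adv_loss_weight / _adv_loss_weight in the source below). Purely as an illustration of what a linear or cosine ramp from 0 to adv_loss_beta over adv_loss_ramp_steps could look like:

import math

def adv_loss_weight(step: int, beta: float, ramp_steps: int, schedule: str = "linear") -> float:
    # Illustrative ramp only; the model computes its own schedule internally.
    if ramp_steps <= 0:
        return beta
    t = min(max(step / ramp_steps, 0.0), 1.0)    # ramp progress in [0, 1]
    if schedule == "cosine":
        t = 0.5 * (1.0 - math.cos(math.pi * t))  # smooth ease-in/ease-out
    return beta * t

print(adv_loss_weight(0, beta=1e-3, ramp_steps=1000))     # 0.0
print(adv_loss_weight(500, beta=1e-3, ramp_steps=1000))   # 0.0005 at the linear midpoint
print(adv_loss_weight(2000, beta=1e-3, ramp_steps=1000))  # 0.001 once fully ramped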
Source code in opensr_srgan/model/training_step_PL.py
def training_step_PL1(self, batch, batch_idx, optimizer_idx):
    """One training step for PL < 2.0 using automatic optimization and multi-optimizers.

    Implements GAN training with two optimizers (D first, then G) and a
    pretraining gate. During the **pretraining phase**, only the generator
    (optimizer_idx == 1) is optimized with content loss; the discriminator
    branch returns a dummy loss and logs zeros. During **adversarial training**,
    the discriminator minimizes BCE on real HR vs. fake SR logits, and the
    generator minimizes content loss plus a ramped adversarial loss.

    Args:
        batch (Tuple[torch.Tensor, torch.Tensor]): `(lr_imgs, hr_imgs)` with shape `(B, C, H, W)`.
        batch_idx (int): Global batch index for the current epoch.
        optimizer_idx (int): Active optimizer index provided by Lightning:
            - `0`: Discriminator step.
            - `1`: Generator step.

    Returns:
        torch.Tensor:
            - **Pretraining**:
              - `optimizer_idx == 1`: content loss tensor for the generator.
              - `optimizer_idx == 0`: dummy scalar tensor with `requires_grad=True`.
            - **Adversarial training**:
              - `optimizer_idx == 0`: discriminator BCE loss (real + fake).
              - `optimizer_idx == 1`: generator total loss = content + λ_adv · BCE(G).

    Logged Metrics (selection):
        - `"training/pretrain_phase"`: 1.0 during pretraining (logged on G step).
        - `"train_metrics/*"`: content metrics from the content loss criterion.
        - `"generator/content_loss"`, `"generator/adversarial_loss"`, `"generator/total_loss"`.
        - `"discriminator/adversarial_loss"`, `"discriminator/D(y)_prob"`,
          `"discriminator/D(G(x))_prob"`.
        - `"training/adv_loss_weight"`: current λ_adv from the ramp scheduler.

    Notes:
        - Discriminator step uses `sr_imgs.detach()` to prevent G gradients.
        - Adversarial loss weight λ_adv ramps from 0 → `adv_loss_beta` per configured schedule.
        - Assumes optimizers are ordered as `[D, G]` in `configure_optimizers()`.
    """

    # -------- CREATE SR DATA --------
    lr_imgs, hr_imgs = batch  # unpack LR/HR tensors from dataloader batch
    sr_imgs = self.forward(
        lr_imgs
    )  # forward pass of the generator to produce SR from LR

    # Default to standard GAN loss if adv_loss_type is not defined (e.g., lightweight
    # harnesses in tests). The real model sets this attribute during init.
    use_wasserstein = getattr(self, "adv_loss_type", "gan") == "wasserstein"

    # ======================================================================
    # SECTION: Pretraining phase gate
    # Purpose: decide if we are in the content-only pretrain stage.
    # ======================================================================

    # -------- DETERMINE PRETRAINING --------
    pretrain_phase = (
        self._pretrain_check()
    )  # check schedule: True => content-only pretraining
    if optimizer_idx == 1:  # log whether pretraining is active or not
        self.log(
            "training/pretrain_phase",
            float(pretrain_phase),
            prog_bar=False,
            sync_dist=True,
        )  # log once per G step to track phase state

    # ======================================================================
    # SECTION: Pretraining branch (delegated)
    # Purpose: during pretrain, only content loss for G and dummy logging for D.
    # ======================================================================

    # -------- IF PRETRAIN: delegate --------
    if pretrain_phase:
        # run pretrain step separately and return loss here
        if optimizer_idx == 1:
            content_loss, metrics = self.content_loss_criterion.return_loss(
                sr_imgs, hr_imgs
            )  # compute perceptual/content loss (e.g., VGG or L1)
            self._log_generator_content_loss(
                content_loss
            )  # log content loss for G (consistent args)
            for key, value in metrics.items():
                self.log(
                    f"train_metrics/{key}", value, sync_dist=True
                )  # reuse computed metrics for logging

            # Ensure adversarial weight is logged even when not used during pretraining
            adv_weight = self._compute_adv_loss_weight()
            self._log_adv_loss_weight(adv_weight)
            return content_loss  # return loss for optimizer step (G only)

        # ======================================================================
        # SECTION: Discriminator (D) pretraining step
        # Purpose: no real training — just log zeros and return dummy loss to satisfy closure.
        # ======================================================================
        elif optimizer_idx == 0:
            device, dtype = (
                hr_imgs.device,
                hr_imgs.dtype,
            )  # get tensor device and dtype for consistency
            zero = torch.tensor(
                0.0, device=device, dtype=dtype
            )  # define reusable zero tensor

            # --- Log dummy discriminator "opinions" (always zero during pretrain) ---
            self.log(
                "discriminator/D(y)_prob", zero, prog_bar=True, sync_dist=True
            )  # fake real-prob (always 0)
            self.log(
                "discriminator/D(G(x))_prob", zero, prog_bar=True, sync_dist=True
            )  # fake fake-prob (always 0)

            # --- Create dummy scalar loss (ensures PL closure runs) ---
            dummy = torch.zeros(
                (), device=device, dtype=dtype, requires_grad=True
            )  # dummy value with grad for optimizer compatibility
            self.log(
                "discriminator/adversarial_loss", dummy, sync_dist=True
            )  # log dummy adversarial loss (always 0)
            return dummy
    # -------- END PRETRAIN --------

    # ======================================================================
    # SECTION: Adversarial training — Discriminator step
    # Purpose: update D to distinguish HR (real) vs SR (fake).
    # ======================================================================

    # -------- Normal Train: Discriminator Step  --------
    if optimizer_idx == 0:
        r1_gamma = getattr(self, "r1_gamma", 0.0)  # default to 0.0 when R1 penalty is not configured
        hr_imgs.requires_grad_(r1_gamma > 0)  # enable grad for R1 penalty if needed

        # run discriminator and get loss between pred labels and true labels
        hr_discriminated = self.discriminator(hr_imgs)  # D(real): logits for HR images
        sr_discriminated = self.discriminator(
            sr_imgs.detach()
        )  # detach so G doesn’t get gradients from D’s step

        # Check for WS GAN loss
        if use_wasserstein:  # Wasserstein GAN loss
            loss_real = -hr_discriminated.mean()
            loss_fake = sr_discriminated.mean()
        else:  # Standard GAN loss (BCE)
            real_target = torch.full_like(
                hr_discriminated, self.adv_target
            )  # real labels (optionally smoothed via self.adv_target)
            fake_target = torch.zeros_like(
                sr_discriminated
            )  # zero labels for generated (fake) samples

            # Binary Cross-Entropy loss
            loss_real = self.adversarial_loss_criterion(
                hr_discriminated, real_target
            )  # BCEWithLogitsLoss for D(y) on real HR
            loss_fake = self.adversarial_loss_criterion(
                sr_discriminated, fake_target
            )  # BCEWithLogitsLoss for D(G(x)) on fake SR

        # R1 Gradient Penalty (if enabled)
        r1_penalty = torch.zeros((), device=hr_imgs.device, dtype=hr_imgs.dtype)
        if r1_gamma > 0:
            grad_real = torch.autograd.grad(
                outputs=hr_discriminated.sum(),
                inputs=hr_imgs,
                create_graph=True,
                retain_graph=True,
            )[0]
            grad_penalty = grad_real.pow(2).reshape(grad_real.size(0), -1).sum(dim=1)
            r1_penalty = 0.5 * r1_gamma * grad_penalty.mean()

        # Sum up losses
        adversarial_loss = (
            loss_real + loss_fake + r1_penalty
        )  # add 0s for R1 if disabled
        self.log(
            "discriminator/adversarial_loss", adversarial_loss, sync_dist=True
        )  # log weighted loss
        self.log(
            "discriminator/r1_penalty", r1_penalty.detach(), sync_dist=True
        )  # log R1 penalty regardless; it is 0 when disabled

        # [LOG-B] Always log D opinions: real probs in normal training
        with torch.no_grad():
            d_real_prob = torch.sigmoid(
                hr_discriminated
            ).mean()  # estimate mean real probability
            d_fake_prob = torch.sigmoid(
                sr_discriminated
            ).mean()  # estimate mean fake probability
        self.log(
            "discriminator/D(y)_prob", d_real_prob, prog_bar=True, sync_dist=True
        )  # log D(real) confidence
        self.log(
            "discriminator/D(G(x))_prob", d_fake_prob, prog_bar=True, sync_dist=True
        )  # log D(fake) confidence

        # return weighted discriminator loss
        return adversarial_loss  # PL will use this to step the D optimizer

    # ======================================================================
    # SECTION: Adversarial training — Generator step
    # Purpose: update G to minimize content loss + (weighted) adversarial loss.
    # ======================================================================

    # -------- Normal Train: Generator Step  --------
    if optimizer_idx == 1:

        """1. Get VGG space loss"""
        # encode images
        content_loss, metrics = self.content_loss_criterion.return_loss(
            sr_imgs, hr_imgs
        )  # perceptual/content criterion (e.g., VGG)
        self._log_generator_content_loss(
            content_loss
        )  # log content loss for G (consistent args)
        for key, value in metrics.items():
            self.log(
                f"train_metrics/{key}", value, sync_dist=True
            )  # log detailed metrics without extra forward passes

        """ 2. Get Discriminator Opinion and loss """
        # run discriminator and get loss between pred labels and true labels
        sr_discriminated = self.discriminator(
            sr_imgs
        )  # D(SR): logits for generator outputs
        if use_wasserstein:  # Wasserstein GAN loss
            adversarial_loss = -sr_discriminated.mean()
        else:  # Standard GAN loss (BCE)
            adversarial_loss = self.adversarial_loss_criterion(
                sr_discriminated, torch.ones_like(sr_discriminated)
            )  # keep targets at 1.0 for the G loss
        self.log(
            "generator/adversarial_loss", adversarial_loss, sync_dist=True
        )  # log unweighted adversarial loss

        """ 3. Weight the losses"""
        adv_weight = (
            self._adv_loss_weight()
        )  # get adversarial weight based on current step
        adversarial_loss_weighted = (
            adversarial_loss * adv_weight
        )  # weight adversarial loss
        total_loss = content_loss + adversarial_loss_weighted  # total generator loss (content + weighted adversarial)
        self.log(
            "generator/total_loss", total_loss, sync_dist=True
        )  # log combined objective (content + λ_adv * adv)

        # return Generator loss
        return total_loss
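
Both discriminator branches optionally add an R1 gradient penalty on real samples whenever r1_gamma > 0 (the R1 regularizer of Mescheder et al., 2018). In the same plain notation as the perceptual formula above, the term computed in the discriminator step is

R1 = 0.5 · r1_gamma · E[ ‖∇_x D(x)‖² ],  x = real HR batch

i.e. the squared gradient norm of the real logits with respect to the HR inputs, summed over all non-batch dimensions, averaged over the batch, and scaled by 0.5 · r1_gamma; the code logs a zero-valued placeholder for this term when the penalty is disabled.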

training_step_PL2(self, batch, batch_idx)

Manual-optimization training step for PyTorch Lightning ≥ 2.0.

Mirrors the PL1.x logic with explicit optimizer control:
  • Pretraining phase: the Discriminator logs dummies; the Generator is optimized with content loss only (no adversarial term), and EMA optionally updates.
  • Adversarial phase: performs a Discriminator step (real vs. fake BCE), followed by a Generator step (content + λ_adv · BCE against ones). Uses the same log keys and ordering as the PL1.x path.

Assumptions
  • self.automatic_optimization is False (manual opt).
  • configure_optimizers() returns optimizers in order [opt_d, opt_g].
  • EMA updates occur after self._ema_update_after_step.

Parameters:

Name Type Description Default
batch Tuple[Tensor, Tensor]

(lr_imgs, hr_imgs) tensors with shape (B, C, H, W).

required
batch_idx int

Index of the current batch.

required

Returns:

Type Description

torch.Tensor:
  • Pretraining: content loss (Generator-only step).
  • Adversarial: total generator loss = content + λ_adv · BCE(G).

Logged metrics (selection):
  • "training/pretrain_phase" (0/1)
  • "train_metrics/*" (from content criterion)
  • "generator/content_loss", "generator/adversarial_loss", "generator/total_loss"
  • "discriminator/adversarial_loss", "discriminator/D(y)_prob", "discriminator/D(G(x))_prob"
  • "training/adv_loss_weight" (λ_adv from ramp schedule)

Raises:

Type Description
AssertionError

If PL version < 2.0 or automatic_optimization is True.

Source code in opensr_srgan/model/training_step_PL.py
def training_step_PL2(self, batch, batch_idx):
    """Manual-optimization training step for PyTorch Lightning ≥ 2.0.

    Mirrors the PL1.x logic with explicit optimizer control:
    - **Pretraining phase**: Discriminator logs dummies; Generator is optimized with
      content loss only (no adversarial term), and EMA optionally updates.
    - **Adversarial phase**: Performs a Discriminator step (real vs. fake BCE),
      followed by a Generator step (content + λ_adv · BCE against ones). Uses the
      same log keys and ordering as the PL1.x path.

    Assumptions:
        - `self.automatic_optimization` is `False` (manual opt).
        - `configure_optimizers()` returns optimizers in order `[opt_d, opt_g]`.
        - EMA updates occur after `self._ema_update_after_step`.

    Args:
        batch (Tuple[torch.Tensor, torch.Tensor]): `(lr_imgs, hr_imgs)` tensors with shape `(B, C, H, W)`.
        batch_idx (int): Index of the current batch.

    Returns:
        torch.Tensor:
            - **Pretraining**: content loss (Generator-only step).
            - **Adversarial**: total generator loss = content + λ_adv · BCE(G).

    Logged metrics (selection):
        - `"training/pretrain_phase"` (0/1)
        - `"train_metrics/*"` (from content criterion)
        - `"generator/content_loss"`, `"generator/adversarial_loss"`, `"generator/total_loss"`
        - `"discriminator/adversarial_loss"`, `"discriminator/D(y)_prob"`, `"discriminator/D(G(x))_prob"`
        - `"training/adv_loss_weight"` (λ_adv from ramp schedule)

    Raises:
        AssertionError: If PL version < 2.0 or `automatic_optimization` is True.
    """
    assert self.pl_version >= (
        2,
        0,
        0,
    ), "training_step_PL2 requires PyTorch Lightning >= 2.x."
    assert (
        self.automatic_optimization is False
    ), "training_step_PL2 requires manual_optimization."

    # -------- CREATE SR DATA --------
    lr_imgs, hr_imgs = batch
    sr_imgs = self.forward(lr_imgs)
    use_wasserstein = getattr(self, "adv_loss_type", "gan") == "wasserstein"

    # --- helper to resolve adv-weight function name mismatches ---
    def _adv_weight():
        if hasattr(self, "_adv_loss_weight"):
            return self._adv_loss_weight()
        return self._compute_adv_loss_weight()

    # fetch optimizers (expects two)
    opt_d, opt_g = self.optimizers()

    # optional gradient clipping support (norm-based)
    try:
        gradient_clip_val = self.config.Schedulers.gradient_clip_val
    except AttributeError:
        gradient_clip_val = 0.0

    def _maybe_clip_gradients(module, optimizer=None):
        if gradient_clip_val > 0.0 and module is not None:
            precision_plugin = getattr(self.trainer, "precision_plugin", None)
            if (
                optimizer is not None
                and precision_plugin is not None
                and hasattr(precision_plugin, "unscale_optimizer")
            ):
                precision_plugin.unscale_optimizer(optimizer)
            torch.nn.utils.clip_grad_norm_(module.parameters(), gradient_clip_val)

    # ======================================================================
    # SECTION: Pretraining phase gate
    # ======================================================================
    pretrain_phase = self._pretrain_check()
    # in PL1.x this is logged only on the G step; here it is logged once per batch
    self.log(
        "training/pretrain_phase", float(pretrain_phase), prog_bar=False, sync_dist=True
    )

    # ======================================================================
    # SECTION: Pretraining branch (content-only on G; D logs dummies)
    # ======================================================================
    if pretrain_phase:
        # --- D dummy logs (no step), mirroring the PL1.x optimizer_idx==0 branch ---
        with torch.no_grad():
            zero = torch.tensor(0.0, device=hr_imgs.device, dtype=hr_imgs.dtype)
            self.log("discriminator/D(y)_prob", zero, prog_bar=True, sync_dist=True)
            self.log("discriminator/D(G(x))_prob", zero, prog_bar=True, sync_dist=True)
            self.log("discriminator/adversarial_loss", zero, sync_dist=True)

        # --- G step: content loss only (identical to the PL1.x optimizer_idx==1 pretrain) ---
        content_loss, metrics = self.content_loss_criterion.return_loss(
            sr_imgs, hr_imgs
        )
        self._log_generator_content_loss(content_loss)
        for key, value in metrics.items():
            self.log(f"train_metrics/{key}", value, sync_dist=True)

        # ensure the adversarial weight is logged even during pretraining
        self._log_adv_loss_weight(_adv_weight())

        # manual optimize G
        if hasattr(self, "toggle_optimizer"):
            self.toggle_optimizer(opt_g)
        opt_g.zero_grad()
        self.manual_backward(content_loss)
        _maybe_clip_gradients(self.generator, opt_g)
        opt_g.step()
        if hasattr(self, "untoggle_optimizer"):
            self.untoggle_optimizer(opt_g)

        # EMA in PL2 manual mode
        if self.ema is not None and self.global_step >= self._ema_update_after_step:
            self.ema.update(self.generator)

        # return the same scalar as the PL1.x path (content loss)
        return content_loss

    # ======================================================================
    # SECTION: Adversarial training — Discriminator step
    # ======================================================================
    if hasattr(self, "toggle_optimizer"):
        self.toggle_optimizer(opt_d)
    opt_d.zero_grad()

    # Get R1 Gamma
    r1_gamma = getattr(self, "r1_gamma", 0.0)
    hr_imgs.requires_grad_(r1_gamma > 0)  # enable grad for R1 penalty if needed

    hr_discriminated = self.discriminator(hr_imgs)  # D(y)
    sr_discriminated = self.discriminator(sr_imgs.detach())  # D(G(x)) w/o grad to G

    if use_wasserstein:  # Wasserstein GAN loss
        loss_real = -hr_discriminated.mean()
        loss_fake = sr_discriminated.mean()
    else:  # Standard GAN loss (BCE)
        real_target = torch.full_like(hr_discriminated, self.adv_target)
        fake_target = torch.zeros_like(sr_discriminated)

        loss_real = self.adversarial_loss_criterion(hr_discriminated, real_target)
        loss_fake = self.adversarial_loss_criterion(sr_discriminated, fake_target)

    # R1 Gradient Penalty
    r1_penalty = torch.zeros((), device=hr_imgs.device, dtype=hr_imgs.dtype)
    if r1_gamma > 0:
        grad_real = torch.autograd.grad(
            outputs=hr_discriminated.sum(),
            inputs=hr_imgs,
            create_graph=True,
            retain_graph=True,
        )[0]
        grad_penalty = grad_real.pow(2).reshape(grad_real.size(0), -1).sum(dim=1)
        r1_penalty = 0.5 * r1_gamma * grad_penalty.mean()

    adversarial_loss = (
        loss_real + loss_fake + r1_penalty
    )  # sum up loss with R1 (0 when turned off)
    self.log("discriminator/adversarial_loss", adversarial_loss, sync_dist=True)
    self.log(
        "discriminator/r1_penalty", r1_penalty.detach(), sync_dist=True
    )  # log R1 penalty regardless, is 0 when turned off

    with torch.no_grad():
        d_real_prob = torch.sigmoid(hr_discriminated).mean()
        d_fake_prob = torch.sigmoid(sr_discriminated).mean()
    self.log("discriminator/D(y)_prob", d_real_prob, prog_bar=True, sync_dist=True)
    self.log("discriminator/D(G(x))_prob", d_fake_prob, prog_bar=True, sync_dist=True)

    self.manual_backward(adversarial_loss)
    _maybe_clip_gradients(self.discriminator, opt_d)
    opt_d.step()
    if hasattr(self, "untoggle_optimizer"):
        self.untoggle_optimizer(opt_d)

    # ======================================================================
    # SECTION: Adversarial training — Generator step
    # ======================================================================
    if hasattr(self, "toggle_optimizer"):
        self.toggle_optimizer(opt_g)
    opt_g.zero_grad()

    # 1) content loss (identical to original)
    content_loss, metrics = self.content_loss_criterion.return_loss(sr_imgs, hr_imgs)
    self._log_generator_content_loss(content_loss)
    for key, value in metrics.items():
        self.log(f"train_metrics/{key}", value, sync_dist=True)

    # 2) adversarial loss against ones
    sr_discriminated_for_g = self.discriminator(sr_imgs)
    if use_wasserstein:  # Wasserstein GAN loss
        g_adv = -sr_discriminated_for_g.mean()
    else:  # Standard GAN loss (BCE)
        g_adv = self.adversarial_loss_criterion(
            sr_discriminated_for_g, torch.ones_like(sr_discriminated_for_g)
        )
    self.log("generator/adversarial_loss", g_adv, sync_dist=True)

    # 3) weighted total
    adv_weight = _adv_weight()
    total_loss = content_loss + (g_adv * adv_weight)
    self.log("generator/total_loss", total_loss, sync_dist=True)

    self.manual_backward(total_loss)
    _maybe_clip_gradients(self.generator, opt_g)
    opt_g.step()
    if hasattr(self, "untoggle_optimizer"):
        self.untoggle_optimizer(opt_g)

    # EMA in PL2 manual mode
    if self.ema is not None and self.global_step >= self._ema_update_after_step:
        self.ema.update(self.generator)

    # return the same scalar as the PL1.x G path
    return total_loss
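
The EMA update at the end of each generator step maintains a slowly moving shadow copy of the generator weights. A minimal, self-contained sketch of that idea (a toy class, not the EMA tracker the model actually uses via self.ema):

import copy
import torch

class ToyEMA:
    # Minimal shadow-weight EMA for illustration only.
    def __init__(self, module: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = copy.deepcopy(module).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, module: torch.nn.Module):
        for s, p in zip(self.shadow.parameters(), module.parameters()):
            s.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

gen = torch.nn.Conv2d(3, 3, 3, padding=1)        # stand-in generator
ema = ToyEMA(gen, decay=0.99)
opt = torch.optim.SGD(gen.parameters(), lr=0.1)
for _ in range(5):
    loss = gen(torch.rand(1, 3, 8, 8)).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(gen)  # mirrors self.ema.update(self.generator) after each G step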