Model Components¶
Classes and building blocks used to assemble SRGAN training and inference graphs.
Lightning module¶
SRGAN_model¶
Bases: LightningModule
SRGAN_model¶
A flexible, PyTorch Lightning–based SRGAN implementation for single-image super-resolution in remote sensing and general imaging. The model supports multiple generator backbones (SRResNet, RCAB, RRDB, LKA, and a flexible registry-driven variant) and discriminator types (standard, PatchGAN), optional generator-only pretraining, adversarial loss ramp-up, and an Exponential Moving Average (EMA) of generator weights for more stable evaluation.
Key features¶
- Backbone flexibility: Select generator/discriminator architectures via config.
- Training modes: Generator pretraining, adversarial training, and LR warm-up.
- PL compatibility: Automatic optimization for PL < 2.0; manual optimization for PL ≥ 2.0.
- EMA support: Optional EMA tracking with delayed activation and device placement.
- Metrics & logging: Content/perceptual metrics, LR logging, and optional W&B image logs.
- Inference helpers: Normalization/denormalization for 0–10000 reflectance and histogram matching.
Args¶
config : str | pathlib.Path | dict | omegaconf.DictConfig, optional
Path to a YAML file, an in-memory dict, or an OmegaConf config with sections
defining Generator/Discriminator, Training, Optimizers, Schedulers, and Logging.
Defaults to "config.yaml".
mode : {"train", "eval"}, optional
Build both G and D in "train" mode; only G in "eval". Defaults to "train".
Configuration overview (minimal)¶
- Model: in_bands (int)
- Generator: model_type ("SRResNet", "stochastic_gan", "esrgan"), optional block_type for SRResNet variants ("standard", "res", "rcab", "rrdb", "lka"), n_channels, n_blocks, large_kernel_size, small_kernel_size, scaling_factor, plus ESRGAN-specific knobs (growth_channels, res_scale, out_channels).
- Discriminator: model_type ("standard", "patchgan", "esrgan"), n_blocks (optional), ESRGAN extras (base_channels, linear_size).
- Training: pretrain_g_only (bool), g_pretrain_steps (int), adv_loss_ramp_steps (int), label_smoothing (bool), Losses.adv_loss_beta (float), Losses.adv_loss_schedule ("linear" | "cosine"), EMA.enabled (bool), EMA.decay (float), EMA.use_num_updates (bool), EMA.update_after_step (int), EMA.device (str | None)
- Optimizers: optim_g_lr (float), optim_d_lr (float)
- Schedulers: factor_g, factor_d, patience_g, patience_d, metric, optional g_warmup_steps, g_warmup_type ("linear" | "cosine")
- Logging: wandb.enabled (bool), num_val_images (int)
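The same configuration can be supplied as an in-memory dict (or OmegaConf object) instead of a YAML path. The sketch below is illustrative only: it uses the field names listed above with placeholder values, the import path is assumed from the source location shown further down, and additional sections (e.g., for input normalization) may be required depending on your setup.

from opensr_srgan.model.SRGAN import SRGAN_model  # import path assumed from the source location below

config = {
    "Model": {"in_bands": 4},
    "Generator": {
        "model_type": "SRResNet",
        "block_type": "rrdb",
        "n_channels": 64,
        "n_blocks": 16,
        "large_kernel_size": 9,
        "small_kernel_size": 3,
        "scaling_factor": 4,
    },
    "Discriminator": {"model_type": "patchgan", "n_blocks": 3},
    "Training": {
        "pretrain_g_only": True,
        "g_pretrain_steps": 10_000,
        "adv_loss_ramp_steps": 20_000,
        "label_smoothing": True,
        "Losses": {"adv_loss_beta": 1e-3, "adv_loss_schedule": "cosine"},
        "EMA": {"enabled": True, "decay": 0.999, "update_after_step": 5_000},
    },
    "Optimizers": {"optim_g_lr": 1e-4, "optim_d_lr": 5e-5},
    "Schedulers": {"factor_g": 0.5, "factor_d": 0.5, "patience_g": 5,
                   "patience_d": 5, "metric": "val_metrics/l1"},  # monitored metric name is illustrative
    "Logging": {"wandb": {"enabled": False}, "num_val_images": 2},
}

model = SRGAN_model(config=config, mode="train")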
Behavior & versioning¶
- PL ≥ 2.0: Manual optimization (automatic_optimization = False). The bound training_step_PL2 performs explicit zero_grad/step calls and handles EMA updates.
- PL < 2.0: Automatic optimization. The legacy training_step_PL1 is used, and optimizer_step coordinates stepping and EMA after generator updates.
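For orientation, the sketch below shows the generic manual-optimization pattern that PL ≥ 2.0 GAN training follows. It is a simplified stand-in, not the repository's training_step_PL2, which additionally gates the discriminator during pretraining, ramps the adversarial weight, and updates the EMA.

import torch
import pytorch_lightning as pl

class ManualOptGAN(pl.LightningModule):
    """Toy sketch of the PL >= 2.0 manual-optimization pattern; not training_step_PL2."""

    def __init__(self, generator, discriminator):
        super().__init__()
        self.automatic_optimization = False  # required for manual stepping in PL 2.x
        self.generator, self.discriminator = generator, discriminator
        self.bce = torch.nn.BCEWithLogitsLoss()

    def training_step(self, batch, batch_idx):
        opt_d, opt_g = self.optimizers()  # order matches configure_optimizers: [D, G]
        lr_imgs, hr_imgs = batch

        # --- Discriminator step (detach SR so the generator receives no gradient here) ---
        sr_imgs = self.generator(lr_imgs)
        d_real = self.discriminator(hr_imgs)
        d_fake = self.discriminator(sr_imgs.detach())
        d_loss = self.bce(d_real, torch.ones_like(d_real)) + self.bce(d_fake, torch.zeros_like(d_fake))
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

        # --- Generator step (content loss omitted for brevity) ---
        g_fake = self.discriminator(sr_imgs)
        g_loss = self.bce(g_fake, torch.ones_like(g_fake))
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

    def configure_optimizers(self):
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
        return [opt_d, opt_g]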
Created attributes (non-exhaustive)¶
generator : torch.nn.Module
Super-resolution generator.
discriminator : torch.nn.Module | None
Only present in "train" mode.
ema : ExponentialMovingAverage | None
EMA tracker for the generator (if enabled).
content_loss_criterion : torch.nn.Module
Perceptual/pixel content loss wrapper used in training/validation.
adversarial_loss_criterion : torch.nn.Module
BCEWithLogits loss for adversarial training (D/G).
Input/Output conventions¶
- Forward: lr_imgs of shape (B, C, H, W) → SR output with the spatial scale set by Generator.scaling_factor.
- Predict: Applies optional normalization (0–10000 → 0–1), EMA averaging, histogram matching, and denormalization back to the input range; returns CPU tensors.
Example¶
model = SRGAN_model(config="config.yaml", mode="train")
# Trainer handles fit; inference via .predict_step() or forward():
model.eval()
with torch.no_grad():
    sr = model(lr_imgs)
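For batched inference on reflectance-scaled rasters, predict_step can also be called directly. A minimal sketch, assuming the config declares 4 input bands and inputs in the 0–10000 range:

import torch

model = SRGAN_model(config="config.yaml", mode="eval")
model.eval()  # predict_step asserts that the generator is in eval mode

# Hypothetical 0–10000 reflectance batch of shape (B, C, H, W)
lr_imgs = torch.randint(0, 10_000, (1, 4, 128, 128)).float()

# Normalizes, runs the generator (under EMA weights if enabled),
# histogram-matches SR to LR, denormalizes, and returns a CPU tensor.
sr_imgs = model.predict_step(lr_imgs)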
Notes¶
- The discriminator is frozen during generator pretraining (pretrain_g_only=True and global_step < g_pretrain_steps).
- The adversarial loss contribution is ramped from 0 to adv_loss_beta over adv_loss_ramp_steps with a linear or cosine schedule.
- Learning-rate warm-up for the generator is supported via a per-step LambdaLR.
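The ramp described in the notes corresponds to the weighting below (a standalone restatement of _compute_adv_loss_weight; the step counts and beta in the usage line are illustrative):

import math

def adv_loss_weight(step: int, pretrain_steps: int, ramp_steps: int,
                    beta: float, schedule: str = "cosine") -> float:
    """Weight applied to the adversarial term of the generator loss."""
    if step < pretrain_steps:
        return 0.0  # pure content-loss pretraining
    if ramp_steps <= 0 or step >= pretrain_steps + ramp_steps:
        return beta  # ramp finished: full adversarial weight
    progress = (step - pretrain_steps) / ramp_steps
    if schedule == "linear":
        return progress * beta
    return 0.5 * (1.0 - math.cos(math.pi * progress)) * beta  # cosine ramp

# e.g. halfway through a 20k-step ramp after 10k pretraining steps -> 0.5 * beta
print(adv_loss_weight(step=20_000, pretrain_steps=10_000, ramp_steps=20_000, beta=1e-3))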
Source code in opensr_srgan/model/SRGAN.py
class SRGAN_model(pl.LightningModule):
"""
SRGAN_model
===========
A flexible, PyTorch Lightning–based SRGAN implementation for single-image super-resolution
in remote sensing and general imaging. The model supports multiple generator backbones
(SRResNet, RCAB, RRDB, LKA, and a flexible registry-driven variant) and discriminator
types (standard, PatchGAN), optional generator-only pretraining, adversarial loss ramp-up,
and an Exponential Moving Average (EMA) of generator weights for more stable evaluation.
Key features
------------
- **Backbone flexibility:** Select generator/discriminator architectures via config.
- **Training modes:** Generator pretraining, adversarial training, and LR warm-up.
- **PL compatibility:** Automatic optimization for PL < 2.0; manual optimization for PL ≥ 2.0.
- **EMA support:** Optional EMA tracking with delayed activation and device placement.
- **Metrics & logging:** Content/perceptual metrics, LR logging, and optional W&B image logs.
- **Inference helpers:** Normalization/denormalization for 0–10000 reflectance and histogram matching.
Args
----
config : str | pathlib.Path | dict | omegaconf.DictConfig, optional
Path to a YAML file, an in-memory dict, or an OmegaConf config with sections
defining Generator/Discriminator, Training, Optimizers, Schedulers, and Logging.
Defaults to `"config.yaml"`.
mode : {"train", "eval"}, optional
Build both G and D in `"train"` mode; only G in `"eval"`. Defaults to `"train"`.
Configuration overview (minimal)
--------------------------------
- **Model**: `in_bands` (int)
- **Generator**: `model_type` (`"SRResNet"`, `"stochastic_gan"`, `"esrgan"`),
optional `block_type` for SRResNet variants (`"standard"`, `"res"`, `"rcab"`,
`"rrdb"`, `"lka"`), `n_channels`, `n_blocks`, `large_kernel_size`,
`small_kernel_size`, `scaling_factor`, plus ESRGAN-specific knobs
(`growth_channels`, `res_scale`, `out_channels`).
- **Discriminator**: `model_type` (`"standard"`, `"patchgan"`, `"esrgan"`), `n_blocks`
(optional), ESRGAN extras (`base_channels`, `linear_size`).
- **Training**:
- `pretrain_g_only` (bool), `g_pretrain_steps` (int)
- `adv_loss_ramp_steps` (int), `label_smoothing` (bool)
- `Losses.adv_loss_beta` (float), `Losses.adv_loss_schedule` (`"linear"`|`"cosine"`)
- `EMA.enabled` (bool), `EMA.decay` (float), `EMA.use_num_updates` (bool),
`EMA.update_after_step` (int), `EMA.device` (str|None)
- **Optimizers**: `optim_g_lr` (float), `optim_d_lr` (float)
- **Schedulers**: `factor_g`, `factor_d`, `patience_g`, `patience_d`, `metric`,
optional `g_warmup_steps`, `g_warmup_type` (`"linear"`|`"cosine"`)
- **Logging**: `wandb.enabled` (bool), `num_val_images` (int)
Behavior & versioning
---------------------
- **PL ≥ 2.0**: Manual optimization (`automatic_optimization = False`). The bound
`training_step_PL2` performs explicit `zero_grad/step` calls and handles EMA updates.
- **PL < 2.0**: Automatic optimization. The legacy `training_step_PL1` is used, and
`optimizer_step` coordinates stepping and EMA after generator updates.
Created attributes (non-exhaustive)
-----------------------------------
generator : torch.nn.Module
Super-resolution generator.
discriminator : torch.nn.Module | None
Only present in `"train"` mode.
ema : ExponentialMovingAverage | None
EMA tracker for the generator (if enabled).
content_loss_criterion : torch.nn.Module
Perceptual/pixel content loss wrapper used in training/validation.
adversarial_loss_criterion : torch.nn.Module
BCEWithLogits loss for adversarial training (D/G).
Input/Output conventions
------------------------
- **Forward**: `lr_imgs` of shape `(B, C, H, W)` → SR output with spatial scale set by
`Generator.scaling_factor`.
- **Predict**: Applies optional normalization (0–10000 → 0–1), EMA averaging, histogram
matching, and denormalization back to the input range; returns CPU tensors.
Example
-------
>>> model = SRGAN_model(config="config.yaml", mode="train")
>>> # Trainer handles fit; inference via .predict_step() or forward():
>>> model.eval()
>>> with torch.no_grad():
... sr = model(lr_imgs)
Notes
-----
- Discriminator is frozen during generator pretraining (`pretrain_g_only=True` and
`global_step < g_pretrain_steps`).
- The adversarial loss contribution is ramped from 0 to `adv_loss_beta` over
`adv_loss_ramp_steps` with linear or cosine schedule.
- Learning rate warm-up for the generator is supported via a per-step LambdaLR.
"""
def __init__(self, config="config.yaml", mode="train"):
super(SRGAN_model, self).__init__()
# ======================================================================
# SECTION: Load Configuration
# Purpose: Load and parse model/training hyperparameters from YAML file.
# ======================================================================
if isinstance(config, str) or isinstance(config, Path):
config = OmegaConf.load(config)
elif isinstance(config, dict):
config = OmegaConf.create(config)
elif OmegaConf.is_config(config):
pass
else:
raise TypeError(
"Config must be a filepath (str or Path), dict, or OmegaConf object."
)
assert mode in {
"train",
"eval",
}, "Mode must be 'train' or 'eval'" # validate mode
# ======================================================================
# SECTION: Set Variables
# Purpose: Set config and mode variables model-wide, including PL version.
# ======================================================================
self.config = config
self.mode = mode
self.pl_version = tuple(int(x) for x in pl.__version__.split("."))
self.normalizer = Normalizer(self.config)
# ======================================================================
# SECTION: Get Training settings
# Purpose: Define model variables to enable training strategies.
# ======================================================================
self.pretrain_g_only = bool(
getattr(self.config.Training, "pretrain_g_only", False)
) # pretrain generator only (default False)
self.g_pretrain_steps = int(
getattr(self.config.Training, "g_pretrain_steps", 0)
) # number of steps for G pretraining
self.adv_loss_ramp_steps = int(
getattr(self.config.Training, "adv_loss_ramp_steps", 20000)
) # linear ramp-up steps for adversarial loss
self.adv_target = (
0.9 if getattr(self.config.Training, "label_smoothing", False) else 1.0
) # use 0.9 if label smoothing enabled, else 1.0
self.adv_loss_type = str(
getattr(self.config.Training.Losses, "adv_loss_type", "bce")
).lower()
if self.adv_loss_type not in {"bce", "wasserstein"}:
raise ValueError(
"Training.Losses.adv_loss_type must be either 'bce' or 'wasserstein'"
)
self.r1_gamma = float(
getattr(self.config.Training.Losses, "r1_gamma", 0.0)
) # R1 gradient penalty strength (0 disables)
# ======================================================================
# SECTION: Set up Training Strategy
# Purpose: Depending on PL version, set up optimizers, schedulers, etc.
# ======================================================================
self.setup_lightning() # dynamically builds and attaches generator + discriminator
# ======================================================================
# SECTION: Initialize Generator
# Purpose: Build generator network depending on selected architecture.
# ======================================================================
self.get_models(
mode=self.mode
) # dynamically builds and attaches generator + discriminator
# ======================================================================
# SECTION: Initialize EMA
# Purpose: Optional exponential moving average (EMA) tracking for generator weights
# ======================================================================
self.initialize_ema()
# ======================================================================
# SECTION: Define Loss Functions
# Purpose: Configure generator content loss and discriminator adversarial loss.
# ======================================================================
if self.mode == "train":
from opensr_srgan.model.loss import GeneratorContentLoss
self.content_loss_criterion = GeneratorContentLoss(
self.config
) # perceptual loss (VGG + pixel)
if self.adv_loss_type == "bce": # check for WS GAN or BCE
self.adversarial_loss_criterion = torch.nn.BCEWithLogitsLoss()
else:
self.adversarial_loss_criterion = None
def get_models(self, mode):
"""Initialize and attach the Generator and (optionally) Discriminator models.
This method builds the generator and discriminator architectures based on
the configuration provided in `self.config`. It supports multiple generator
backbones (e.g., SRResNet, RCAB, RRDB, LKA) and discriminator types
(standard, PatchGAN). The discriminator is only initialized when the mode
is set to `"train"`.
Args:
mode (str): Operational mode of the model. Must be one of:
- `"train"`: Initializes both generator and discriminator.
- Any other value: Initializes only the generator.
Raises:
ValueError: If an unknown generator or discriminator type is specified
in the configuration.
Attributes:
generator (nn.Module): The initialized generator network instance.
discriminator (nn.Module, optional): The initialized discriminator
network instance (only present if `mode == "train"`).
"""
# ======================================================================
# SECTION: Initialize Generator
# Purpose: Build generator network depending on selected architecture.
# ======================================================================
self.generator = build_generator(self.config)
if mode == "train": # only get discriminator in training mode
# ======================================================================
# SECTION: Initialize Discriminator
# Purpose: Build discriminator network for adversarial training.
# ======================================================================
raw_discriminator_type = getattr(
self.config.Discriminator, "model_type", "standard"
)
discriminator_type = str(raw_discriminator_type).strip().lower()
n_blocks = getattr(self.config.Discriminator, "n_blocks", None)
if discriminator_type == "standard":
from opensr_srgan.model.discriminators.srgan_discriminator import (
Discriminator,
)
discriminator_kwargs = {
"in_channels": self.config.Model.in_bands,
}
if n_blocks is not None:
discriminator_kwargs["n_blocks"] = n_blocks
# pass spectral norm option
use_spectral_norm = getattr(
self.config.Discriminator, "use_spectral_norm", True
)
discriminator_kwargs["use_spectral_norm"] = bool(use_spectral_norm)
self.discriminator = Discriminator(**discriminator_kwargs)
elif discriminator_type == "patchgan":
from opensr_srgan.model.discriminators.patchgan import (
PatchGANDiscriminator,
)
patchgan_layers = n_blocks if n_blocks is not None else 3
self.discriminator = PatchGANDiscriminator(
input_nc=self.config.Model.in_bands,
n_layers=patchgan_layers,
)
elif discriminator_type == "esrgan":
from opensr_srgan.model.discriminators.esrgan import (
ESRGANDiscriminator,
)
ignored_options = []
if n_blocks is not None:
ignored_options.append("n_blocks")
if ignored_options:
ignored_joined = ", ".join(sorted(ignored_options))
print(
f"[Discriminator:esrgan] Ignoring unsupported configuration options: {ignored_joined}."
)
base_channels = getattr(self.config.Discriminator, "base_channels", 64)
linear_size = getattr(self.config.Discriminator, "linear_size", 1024)
self.discriminator = ESRGANDiscriminator(
in_channels=self.config.Model.in_bands,
base_channels=int(base_channels),
linear_size=int(linear_size),
)
else:
raise ValueError(
f"Unknown discriminator model type: {raw_discriminator_type}"
)
def setup_lightning(self):
"""Configure PyTorch Lightning behavior based on the detected version.
This method ensures compatibility between different versions of
PyTorch Lightning (PL) by setting appropriate optimization modes
and binding the correct training step implementation.
- For PL ≥ 2.0: Enables **manual optimization**, required for GAN training.
- For PL < 2.0: Uses **automatic optimization** and the legacy training step.
The selected training step function (`training_step_PL1` or `training_step_PL2`)
is dynamically attached to the model as `_training_step_implementation`.
Raises:
AssertionError: If `automatic_optimization` is incorrectly set for PL < 2.0.
RuntimeError: If the detected PyTorch Lightning version is unsupported.
Attributes:
automatic_optimization (bool): Indicates whether Lightning manages
optimizer steps automatically.
_training_step_implementation (Callable): Bound training step function
corresponding to the active PL version.
"""
# Check for PL version - Define PL Hooks accordingly
if self.pl_version >= (2, 0, 0):
self.automatic_optimization = False # manual optimization for PL 2.x
# Set up Training Step
from opensr_srgan.model.training_step_PL import training_step_PL2
self._training_step_implementation = MethodType(training_step_PL2, self)
elif self.pl_version < (2, 0, 0):
assert (
self.automatic_optimization is True
), "For PL <2.0, automatic_optimization must be True."
# Set up Training Step
from opensr_srgan.model.training_step_PL import training_step_PL1
self._training_step_implementation = MethodType(training_step_PL1, self)
else:
raise RuntimeError(
f"Unsupported PyTorch Lightning version: {pl.__version__}"
)
def initialize_ema(self):
"""Initialize the Exponential Moving Average (EMA) mechanism for the generator.
This method sets up an EMA shadow copy of the generator parameters to
stabilize training and improve the quality of generated outputs. EMA is
enabled only if specified in the training configuration.
The EMA model tracks the moving average of generator weights with a
configurable decay factor and update schedule.
Configuration fields under `config.Training.EMA`:
- `enabled` (bool): Whether to enable EMA tracking.
- `decay` (float): Exponential decay factor for weight averaging (default: 0.999).
- `device` (str | None): Device to store the EMA weights on.
- `use_num_updates` (bool): Whether to use step-based update counting.
- `update_after_step` (int): Number of steps to wait before starting updates.
Attributes:
ema (ExponentialMovingAverage | None): EMA object tracking generator parameters.
_ema_update_after_step (int): Step count threshold before EMA updates begin.
_ema_applied (bool): Indicates whether EMA weights are currently applied to the generator.
"""
ema_cfg = getattr(self.config.Training, "EMA", None)
self.ema: ExponentialMovingAverage | None = None
self._ema_update_after_step = 0
self._ema_applied = False
if ema_cfg is not None and getattr(ema_cfg, "enabled", False):
ema_decay = float(getattr(ema_cfg, "decay", 0.999))
ema_device = getattr(ema_cfg, "device", None)
use_num_updates = bool(getattr(ema_cfg, "use_num_updates", True))
self.ema = ExponentialMovingAverage(
self.generator,
decay=ema_decay,
use_num_updates=use_num_updates,
)
self._ema_update_after_step = int(getattr(ema_cfg, "update_after_step", 0))
def forward(self, lr_imgs):
"""Forward pass through the generator network.
Takes a batch of low-resolution (LR) input images and produces
their corresponding super-resolved (SR) outputs using the generator model.
Args:
lr_imgs (torch.Tensor): Batch of input low-resolution images
with shape `(B, C, H, W)` where:
- `B`: batch size
- `C`: number of channels
- `H`, `W`: spatial dimensions.
Returns:
torch.Tensor: Super-resolved output images with increased spatial resolution,
typically scaled by the model's configured upsampling factor.
"""
sr_imgs = self.generator(lr_imgs) # pass LR input through generator network
return sr_imgs # return super-resolved output
@torch.no_grad()
def predict_step(self, lr_imgs):
"""Run a single super-resolution inference step.
Performs forward inference using the generator (optionally under EMA weights)
to produce super-resolved (SR) outputs from low-resolution (LR) inputs.
The method normalizes input values using the configured strategy (e.g., raw
Sentinel-2 reflectance via ``normalise_10k``), applies histogram matching, and
denormalizes the outputs back to their original scale.
Args:
lr_imgs (torch.Tensor): Batch of input low-resolution images
with shape `(B, C, H, W)`. Pixel value ranges may vary depending
on preprocessing (e.g., 0–10000 for Sentinel-2 reflectance).
Returns:
torch.Tensor: Super-resolved output images with matched histograms
and restored value range, detached from the computation graph and
placed on CPU memory.
Raises:
AssertionError: If the generator is not in evaluation mode (`.eval()`).
"""
assert (
self.generator.training is False
), "Generator must be in eval mode for prediction." # ensure eval mode
lr_imgs = lr_imgs.to(self.device) # move to device (GPU or CPU)
# --- Normalize inputs according to configuration ---
normalized_lr = self.normalizer.normalize(lr_imgs)
# --- Perform super-resolution (optionally using EMA weights) ---
context = (
self.ema.average_parameters(self.generator)
if self.ema is not None
else nullcontext()
)
with context:
sr_imgs = self.generator(normalized_lr) # forward pass (SR prediction)
# --- Histogram match SR to LR ---
sr_imgs = histogram_match(normalized_lr, sr_imgs) # match distributions
# --- Denormalize output back to original range ---
sr_imgs = self.normalizer.denormalize(sr_imgs)
# --- Move to CPU and return ---
sr_imgs = sr_imgs.cpu().detach() # detach from graph for inference output
return sr_imgs
def training_step(
self, batch, batch_idx, optimizer_idx: Optional[int] = None, *args
):
"""Dispatch the correct training step implementation based on PyTorch Lightning version.
This method acts as a compatibility layer between different PyTorch Lightning
versions that handle multi-optimizer GAN training differently.
- For PL ≥ 2.0: Manual optimization is used, and the optimizer index is not passed.
- For PL < 2.0: Automatic optimization is used, and the optimizer index is passed
to handle generator/discriminator updates separately.
Args:
batch (Any): A batch of training data (input tensors and targets as defined by the DataModule).
batch_idx (int): Index of the current batch within the epoch.
optimizer_idx (int | None, optional): Index of the active optimizer (0 for generator,
1 for discriminator) when using PL < 2.0.
*args: Additional arguments that may be passed by older Lightning versions.
Returns:
Any: The output of the active training step implementation, loss value.
"""
# Depending on PL version, and depending on the manual optimization
if self.pl_version >= (2, 0, 0):
# In PL2.x, optimizer_idx is not passed, manual optimization is performed
return self._training_step_implementation(batch, batch_idx) # no optim_idx
else:
# In Pl1.x, optimizer_idx arrives twice and is passed on
return self._training_step_implementation(
batch, batch_idx, optimizer_idx
) # pass optim_idx
def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_idx=None,
optimizer_closure=None,
**kwargs, # absorbs on_tpu/using_lbfgs/etc across PL versions
):
"""Custom optimizer step handling for PL 1.x automatic optimization.
This method ensures correct behavior across different PyTorch Lightning
versions and training modes. It is invoked automatically during training
in PL < 2.0 when `automatic_optimization=True`. For PL ≥ 2.0, where manual
optimization is used, this function is effectively bypassed.
- In **PL ≥ 2.0 (manual optimization)**: The optimizer step is explicitly
called within `training_step_PL2()`, including EMA updates.
- In **PL < 2.0 (automatic optimization)**: This function manages optimizer
stepping, gradient zeroing, and optional EMA updates after generator steps.
Args:
epoch (int): Current training epoch.
batch_idx (int): Index of the current batch.
optimizer (torch.optim.Optimizer): The active optimizer instance.
optimizer_idx (int, optional): Index of the optimizer being stepped
(e.g., 0 for discriminator, 1 for generator).
optimizer_closure (Callable, optional): Closure for re-evaluating the
model and loss before optimizer step (used with some optimizers).
**kwargs: Additional arguments passed by PL depending on backend
(e.g., TPU flags, LBFGS options).
Notes:
- EMA updates are performed only after generator steps (optimizer_idx == 1).
- The update starts after `self._ema_update_after_step` global steps.
"""
# If we're in manual optimization (PL >=2 path), do nothing special.
if not self.automatic_optimization:
# In manual mode we call opt.step()/zero_grad() in training_step_PL2.
# In manual mode, we update EMA weights manually in training step too.
return super().optimizer_step(
epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **kwargs
)
# ---- PL 1.x auto-optimization path ----
if optimizer_closure is not None:
optimizer.step(closure=optimizer_closure)
else:
optimizer.step()
optimizer.zero_grad()
# EMA after the generator step (assumes G is optimizer_idx == 1)
if (
self.ema is not None
and optimizer_idx == 1
and self.global_step >= self._ema_update_after_step
):
self.ema.update(self.generator)
@torch.no_grad()
def validation_step(self, batch, batch_idx):
"""Run the validation loop for a single batch.
This method performs super-resolution inference on validation data,
computes image quality metrics (e.g., PSNR, SSIM), logs them, and
optionally visualizes SR–HR–LR triplets. It also evaluates the
discriminator’s adversarial response if applicable.
Workflow:
1. Forward pass (LR → SR) through the generator.
2. Compute content-based validation metrics.
3. Optionally log visual examples to the logger (e.g., Weights & Biases).
4. Compute and log discriminator metrics, unless in pretraining mode.
Args:
batch (Tuple[torch.Tensor, torch.Tensor]): A tuple `(lr_imgs, hr_imgs)` of
low-resolution and high-resolution tensors with shape `(B, C, H, W)`.
batch_idx (int): Index of the current validation batch.
Returns:
None: Metrics and images are logged via Lightning’s logger interface.
Raises:
AssertionError: If an unexpected number of bands or invalid visualization
configuration is encountered.
Notes:
- Validation is executed without gradient tracking.
- Only the first `config.Logging.num_val_images` batches are visualized.
- If EMA is enabled, the generator predictions reflect the current EMA state.
"""
# ======================================================================
# SECTION: Forward pass — Generate SR prediction from LR input
# Purpose: Run model inference on validation batch without gradient tracking.
# ======================================================================
""" 1. Extract and Predict """
lr_imgs, hr_imgs = batch # unpack LR and HR tensors
sr_imgs = self.forward(lr_imgs) # run generator to produce SR prediction
# ======================================================================
# SECTION: Compute and log validation metrics
# Purpose: measure content-based metrics (PSNR/SSIM/etc.) on SR vs HR.
# ======================================================================
""" 2. Log Generator Metrics """
metrics_hr_img = torch.clone(
hr_imgs
) # clone to avoid in-place ops on autograd graph
metrics_sr_img = torch.clone(sr_imgs) # same for SR
# metrics = calculate_metrics(metrics_sr_img, metrics_hr_img, phase="val_metrics")
metrics = self.content_loss_criterion.return_metrics(
metrics_sr_img, metrics_hr_img, prefix="val_metrics/"
) # compute metrics using loss criterion helper
del metrics_hr_img, metrics_sr_img # free cloned tensors from GPU memory
for key, value in metrics.items(): # iterate over metrics dict
self.log(
f"{key}", value, sync_dist=True
) # log each metric to logger (e.g., W&B, TensorBoard)
# ======================================================================
# SECTION: Optional visualization — Log example SR/HR/LR images
# Purpose: visually track qualitative progress of the model.
# ======================================================================
# only perform image logging for first N batches to avoid logging all 200 images
if batch_idx < self.config.Logging.num_val_images:
base_lr = lr_imgs # use original LR for visualization
# --- Select visualization bands (if multispectral) ---
if self.config.Model.in_bands < 3:
# show only first band
lr_vis = base_lr[:, :1, :, :] # e.g., single-band input
hr_vis = hr_imgs[:, :1, :, :] # subset HR
sr_vis = sr_imgs[:, :1, :, :] # subset SR
            elif self.config.Model.in_bands == 3:
                # 3-band input: visualize all bands as-is
                lr_vis = base_lr
                hr_vis = hr_imgs
                sr_vis = sr_imgs
elif self.config.Model.in_bands == 4:
# assume its RGB-NIR, show RGB
lr_vis = base_lr[:, :3, :, :] # e.g., Sentinel-2 RGB
hr_vis = hr_imgs[:, :3, :, :] # subset HR
sr_vis = sr_imgs[:, :3, :, :] # subset SR
elif self.config.Model.in_bands > 4: # e.g., Sentinel-2 with >3 channels
# random selection of bands
idx = np.random.choice(
sr_imgs.shape[1], 3, replace=False
) # randomly select 3 bands
lr_vis = base_lr[:, idx, :, :] # subset LR
hr_vis = hr_imgs[:, idx, :, :] # subset HR
sr_vis = sr_imgs[:, idx, :, :] # subset SR
else:
# should not happen
pass
# --- Clone tensors for plotting to avoid affecting main tensors ---
plot_lr_img = lr_vis.clone()
plot_hr_img = hr_vis.clone()
plot_sr_img = sr_vis.clone()
# --- Generate matplotlib visualization (LR, SR, HR side-by-side) ---
val_img = plot_tensors(plot_lr_img, plot_sr_img, plot_hr_img, title="Val")
# --- Cleanup ---
del plot_lr_img, plot_hr_img, plot_sr_img # free memory after plotting
# --- Log image to WandB (or compatible logger), if wanted ---
if self.config.Logging.wandb.enabled:
self.logger.experiment.log(
{"Val SR": wandb.Image(val_img)}
) # upload to dashboard
""" 3. Log Discriminator metrics """
# If in pretraining, discard D metrics
        if self._pretrain_check():  # check if we're in pretrain phase
self.log(
"discriminator/adversarial_loss",
torch.zeros(1, device=lr_imgs.device),
prog_bar=False,
sync_dist=True,
)
else:
# run discriminator and get loss between pred labels and true labels
hr_discriminated = self.discriminator(hr_imgs)
sr_discriminated = self.discriminator(sr_imgs)
# Run loss depending on type
if self.adv_loss_type == "wasserstein":
adversarial_loss = sr_discriminated.mean() - hr_discriminated.mean()
else:
adversarial_loss = self.adversarial_loss_criterion(
sr_discriminated, torch.zeros_like(sr_discriminated)
) + self.adversarial_loss_criterion(
hr_discriminated, torch.ones_like(hr_discriminated)
)
# Log image
self.log(
"validation/DISC_adversarial_loss", adversarial_loss, sync_dist=True
)
def on_validation_epoch_start(self):
"""Hook executed at the start of each validation epoch.
Applies the Exponential Moving Average (EMA) weights to the generator
before running validation to ensure evaluation uses the smoothed model
parameters.
Notes:
- Calls the parent hook via `super().on_validation_epoch_start()`.
- Restores original weights at the end of validation.
"""
super().on_validation_epoch_start()
self._apply_generator_ema_weights()
def on_validation_epoch_end(self):
"""Hook executed at the end of each validation epoch.
Restores the generator’s original (non-EMA) weights after validation.
Ensures subsequent training or testing uses up-to-date parameters.
Notes:
- Calls the parent hook via `super().on_validation_epoch_end()`.
"""
self._restore_generator_weights()
super().on_validation_epoch_end()
def on_test_epoch_start(self):
"""Hook executed at the start of each testing epoch.
Applies the Exponential Moving Average (EMA) weights to the generator
before running tests to ensure consistent evaluation with the
smoothed model parameters.
Notes:
- Calls the parent hook via `super().on_test_epoch_start()`.
- Restores original weights at the end of testing.
"""
super().on_test_epoch_start()
self._apply_generator_ema_weights()
def on_test_epoch_end(self):
"""Hook executed at the end of each testing epoch.
Restores the generator’s original (non-EMA) weights after testing.
Ensures the model is reset to its latest training state.
Notes:
- Calls the parent hook via `super().on_test_epoch_end()`.
"""
self._restore_generator_weights()
super().on_test_epoch_end()
def configure_optimizers(self):
"""
Robust optimizers & schedulers for GANs (PL1 & PL2 compatible).
- TTUR by default (D lr <= G lr)
- Adam with GAN-friendly betas/eps
- Exclude norm/affine/bias params from weight decay
- Separate Plateau schedulers for G and D (with cooldown/min_lr)
- Optional step-wise warmup for G (linear/cosine), no LR jump at handoff
- Returned order is [D, G] to match your training_step expectations
"""
import math
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
cfg_opt = self.config.Optimizers
cfg_sch = self.config.Schedulers
# ---------- helpers ----------
def _split_wd_params(model):
"""Return two lists: params_with_wd, params_without_wd."""
wd, no_wd = [], []
norm_like = (
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.GroupNorm,
nn.LayerNorm,
nn.InstanceNorm1d,
nn.InstanceNorm2d,
nn.InstanceNorm3d,
)
for m in model.modules():
for n, p in m.named_parameters(recurse=False):
if not p.requires_grad:
continue
if n.endswith("bias") or isinstance(m, norm_like):
no_wd.append(p)
else:
wd.append(p)
# catch any top-level params not in modules (rare)
seen = set(map(id, wd + no_wd))
for n, p in model.named_parameters():
if p.requires_grad and id(p) not in seen:
(no_wd if n.endswith("bias") else wd).append(p)
return wd, no_wd
def _adam(params, lr):
# GAN-friendly defaults; tune in config if needed
betas = getattr(cfg_opt, "betas", (0.0, 0.99))
eps = getattr(cfg_opt, "eps", 1e-7)
return Adam(params, lr=lr, betas=betas, eps=eps)
# ---------- LRs (TTUR) ----------
lr_g = float(getattr(cfg_opt, "optim_g_lr", 1e-4))
lr_d = float(
getattr(cfg_opt, "optim_d_lr", max(lr_g * 0.5, 1e-6))
) # default: D slower than G
# weight decay (only on non-norm, non-bias)
wd_g = float(getattr(cfg_opt, "weight_decay_g", 0.0))
wd_d = float(getattr(cfg_opt, "weight_decay_d", 0.0))
# ---------- build optimizers with clean param groups ----------
g_wd, g_no = _split_wd_params(self.generator)
d_wd, d_no = _split_wd_params(self.discriminator)
optimizer_g = _adam(
[
{"params": g_wd, "weight_decay": wd_g},
{"params": g_no, "weight_decay": 0.0},
],
lr=lr_g,
)
optimizer_d = _adam(
[
{"params": d_wd, "weight_decay": wd_d},
{"params": d_no, "weight_decay": 0.0},
],
lr=lr_d,
)
# ---------- schedulers ----------
# Use distinct monitors for clarity (recommend: log these in validation)
monitor_g = getattr(
cfg_sch, "metric_g", getattr(cfg_sch, "metric", "val_g_loss")
)
monitor_d = getattr(
cfg_sch, "metric_d", getattr(cfg_sch, "metric", "val_d_loss")
)
sched_kwargs = dict(
mode="min",
factor=float(getattr(cfg_sch, "factor_g", 0.5)),
patience=int(getattr(cfg_sch, "patience_g", 5)),
threshold=float(getattr(cfg_sch, "threshold", 1e-3)),
threshold_mode="rel",
cooldown=int(getattr(cfg_sch, "cooldown", 0)),
min_lr=float(getattr(cfg_sch, "min_lr", 1e-7)),
verbose=bool(getattr(cfg_sch, "verbose", False)),
)
# D can have its own factor/patience; fall back to G’s if not set
sched_kwargs_d = dict(sched_kwargs)
sched_kwargs_d["factor"] = float(
getattr(cfg_sch, "factor_d", sched_kwargs["factor"])
)
sched_kwargs_d["patience"] = int(
getattr(cfg_sch, "patience_d", sched_kwargs["patience"])
)
scheduler_g = ReduceLROnPlateau(optimizer_g, **sched_kwargs)
scheduler_d = ReduceLROnPlateau(optimizer_d, **sched_kwargs_d)
sch_configs = [
{
"scheduler": scheduler_d,
"monitor": monitor_d,
"reduce_on_plateau": True,
"interval": "epoch",
"frequency": 1,
"name": "plateau_d",
},
{
"scheduler": scheduler_g,
"monitor": monitor_g,
"reduce_on_plateau": True,
"interval": "epoch",
"frequency": 1,
"name": "plateau_g",
},
]
# ---------- optional warmup for G (step-wise, multiplicative) ----------
warmup_steps = int(getattr(cfg_sch, "g_warmup_steps", 0))
warmup_type = str(getattr(cfg_sch, "g_warmup_type", "none")).lower()
if warmup_steps > 0 and warmup_type in {"linear", "cosine"}:
def _g_warmup_lambda(step: int) -> float:
if step >= warmup_steps:
return 1.0
t = (step + 1) / max(1, warmup_steps)
return (
t
if warmup_type == "linear"
else 0.5 * (1.0 - math.cos(math.pi * t))
)
warmup_g = torch.optim.lr_scheduler.LambdaLR(
optimizer_g, lr_lambda=_g_warmup_lambda
)
# Runs every step; multiplies base LR so there is no jump at the end
sch_configs.append(
{
"scheduler": warmup_g,
"interval": "step",
"frequency": 1,
"name": "warmup_g",
}
)
# Return order [D, G] to match your training_step
return [optimizer_d, optimizer_g], sch_configs
def on_train_batch_start(
self, batch, batch_idx
): # called before each training batch
"""Hook executed before each training batch.
Freezes or unfreezes discriminator parameters depending on the
current training phase. During pretraining, the discriminator is
frozen to allow the generator to learn reconstruction without
adversarial pressure.
Args:
batch (Any): The current batch of training data.
batch_idx (int): Index of the current batch in the epoch.
"""
pre = self._pretrain_check() # check if currently in pretraining phase
for p in self.discriminator.parameters(): # loop over all discriminator params
p.requires_grad = not pre # freeze D during pretrain, unfreeze otherwise
def on_train_batch_end(self, outputs, batch, batch_idx):
"""Hook executed after each training batch.
Logs the current learning rates for all active optimizers to
the logger for monitoring and debugging purposes.
Args:
outputs (Any): Outputs returned by `training_step`.
batch (Any): The batch of data processed.
batch_idx (int): Index of the current batch in the epoch.
"""
self._log_lrs() # log LR's on each batch end
def on_fit_start(self): # called once at the start of training
"""Hook executed once at the beginning of model fitting.
Performs setup tasks that must occur before training starts:
- Moves EMA weights to the correct device.
- Prints a model summary (only from global rank 0 in DDP setups).
Notes:
- Calls `super().on_fit_start()` to preserve Lightning’s default behavior.
- The model summary is only printed by the global zero process
to avoid duplicated output in distributed training.
"""
super().on_fit_start()
if self.ema is not None and self.ema.device is None: # move ema weights
self.ema.to(self.device)
from opensr_srgan.utils.gpu_rank import _is_global_zero
if _is_global_zero():
print_model_summary(self) # print model summary to console
def _log_generator_content_loss(self, content_loss: torch.Tensor) -> None:
"""Helper to consistently log the generator content loss across training phases."""
self.log(
"generator/content_loss",
content_loss,
prog_bar=True,
sync_dist=True,
)
def _log_ema_setup_metrics(self) -> None:
"""Log static Exponential Moving Average (EMA) configuration parameters.
Records whether EMA is enabled, along with its core hyperparameters
(decay rate, activation delay, update mode). This information is
logged once when training begins to help track model configuration.
Notes:
- If EMA is disabled, logs `"EMA/enabled" = 0.0`.
- Called after the trainer is initialized to ensure logging context.
"""
if getattr(self, "trainer", None) is None:
return
if self.ema is None:
self.log(
"EMA/enabled",
0.0,
on_step=False,
on_epoch=True,
prog_bar=False,
sync_dist=True,
)
return
self.log(
"EMA/enabled",
1.0,
on_step=False,
on_epoch=True,
prog_bar=False,
sync_dist=True,
)
self.log(
"EMA/decay",
float(self.ema.decay),
on_step=False,
on_epoch=True,
prog_bar=False,
sync_dist=True,
)
self.log(
"EMA/update_after_step",
float(self._ema_update_after_step),
on_step=False,
on_epoch=True,
prog_bar=False,
sync_dist=True,
)
self.log(
"EMA/use_num_updates",
1.0 if self.ema.num_updates is not None else 0.0,
on_step=False,
on_epoch=True,
prog_bar=False,
sync_dist=True,
)
def _log_ema_step_metrics(self, *, updated: bool) -> None:
"""Log dynamic EMA statistics during training.
Tracks per-step EMA state, including whether an update occurred,
how many steps remain until activation, and the most recent decay value.
These metrics provide insight into EMA behavior over time.
Args:
updated (bool): Whether EMA weights were updated in the current step.
Notes:
- If EMA is disabled, this function exits without logging.
- Logs include:
- `"EMA/is_active"`: Indicates if EMA is currently updating.
- `"EMA/steps_until_activation"`: Steps remaining before EMA starts updating.
- `"EMA/last_decay"`: Latest applied decay value.
- `"EMA/num_updates"`: Total number of EMA updates performed.
"""
if self.ema is None:
return
self.log(
"EMA/is_active",
1.0 if updated else 0.0,
on_step=True,
on_epoch=False,
prog_bar=False,
sync_dist=True,
)
steps_until_active = max(0, self._ema_update_after_step - self.global_step)
self.log(
"EMA/steps_until_activation",
float(steps_until_active),
on_step=True,
on_epoch=False,
prog_bar=False,
sync_dist=True,
)
if not updated:
return
self.log(
"EMA/last_decay",
float(self.ema.last_decay),
on_step=True,
on_epoch=False,
prog_bar=False,
sync_dist=True,
)
if self.ema.num_updates is not None:
self.log(
"EMA/num_updates",
float(self.ema.num_updates),
on_step=True,
on_epoch=False,
prog_bar=False,
sync_dist=True,
)
def _pretrain_check(self) -> bool:
"""Check whether the model is still in the generator pretraining phase.
Returns:
bool: True if the generator-only pretraining phase is active
(i.e., `global_step` < `g_pretrain_steps`), otherwise False.
Notes:
- During pretraining, the discriminator is frozen and only the
generator is updated.
"""
if (
self.pretrain_g_only and self.global_step < self.g_pretrain_steps
): # true if pretraining active
return True
else:
return False # false once pretrain steps are exceeded
def _compute_adv_loss_weight(self) -> float:
"""Compute the current adversarial loss weighting factor.
Determines how strongly the adversarial loss contributes to the total
generator loss, following a configurable ramp-up schedule. This helps
stabilize early training by gradually increasing the influence of the
discriminator.
Returns:
float: The current adversarial loss weight for the active step.
Configuration Fields:
config.Training.Losses:
- `adv_loss_beta` (float): Maximum scaling factor for adversarial loss.
- `adv_loss_schedule` (str): Type of ramp schedule (`"linear"` or `"cosine"`).
Notes:
- Returns `0.0` during pretraining steps (`global_step < g_pretrain_steps`).
- After the ramp-up phase, the weight saturates at `beta`.
- Cosine schedule provides a smoother ramp-up than linear.
Raises:
ValueError: If an unknown schedule type is provided in the configuration.
"""
"""Compute the current adversarial loss weight using the configured ramp schedule."""
beta = float(self.config.Training.Losses.adv_loss_beta)
schedule = getattr(
self.config.Training.Losses,
"adv_loss_schedule",
"cosine",
).lower()
# Handle pretraining and edge cases early
if self.global_step < self.g_pretrain_steps:
return 0.0
if (
self.adv_loss_ramp_steps <= 0
or self.global_step >= self.g_pretrain_steps + self.adv_loss_ramp_steps
):
return beta
# Normalize progress to [0, 1]
progress = (self.global_step - self.g_pretrain_steps) / self.adv_loss_ramp_steps
progress = max(0.0, min(progress, 1.0))
if schedule == "linear":
return progress * beta
if schedule == "cosine":
# Cosine ramp to match the generator warmup behaviour
return 0.5 * (1.0 - math.cos(math.pi * progress)) * beta
raise ValueError(
f"Unknown adversarial loss schedule '{schedule}'. Expected 'linear' or 'cosine'."
)
def _log_adv_loss_weight(self, adv_weight: float) -> None:
"""Log the current adversarial loss weight.
Args:
adv_weight (float): Scalar multiplier applied to the adversarial loss term.
"""
self.log("training/adv_loss_weight", adv_weight, sync_dist=True)
def _adv_loss_weight(self) -> float:
"""Compute and log the current adversarial loss weight.
Calls the internal scheduler/heuristic to obtain the adversarial loss weight,
logs it, and returns the value.
Returns:
float: The computed adversarial loss weight for the current step/epoch.
"""
adv_weight = self._compute_adv_loss_weight()
self._log_adv_loss_weight(adv_weight)
return adv_weight
def _apply_generator_ema_weights(self) -> None:
"""Swap the generator's parameters to their EMA-smoothed counterparts.
Applies EMA weights to the generator for evaluation (e.g., val/test). A no-op if
EMA is disabled or already applied. Moves EMA to the correct device if needed.
Notes:
- Sets an internal flag to avoid double application during the same phase.
"""
if self.ema is None or self._ema_applied:
return
if self.ema.device is None:
self.ema.to(self.device)
self.ema.apply_to(self.generator)
self._ema_applied = True
def _restore_generator_weights(self) -> None:
"""Restore the generator's original (non-EMA) parameters.
Reverts the parameter swap performed by `_apply_generator_ema_weights()`.
A no-op if EMA is disabled or not currently applied.
Notes:
- Clears the internal "applied" flag to enable future swaps.
"""
if self.ema is None or not self._ema_applied:
return
self.ema.restore(self.generator)
self._ema_applied = False
def on_save_checkpoint(self, checkpoint: dict) -> None:
"""Augment the checkpoint with EMA state, if available.
Adds the EMA buffer/metadata to the checkpoint so that EMA can be restored upon load.
Args:
checkpoint (dict): Mutable checkpoint dictionary provided by Lightning.
"""
super().on_save_checkpoint(checkpoint)
if self.ema is not None:
checkpoint["ema_state"] = self.ema.state_dict()
def on_load_checkpoint(self, checkpoint: dict) -> None:
"""Restore EMA state from a checkpoint, if present.
Args:
checkpoint (dict): Checkpoint dictionary provided by Lightning containing
model state and optional `"ema_state"` entry.
"""
super().on_load_checkpoint(checkpoint)
if self.ema is not None and "ema_state" in checkpoint:
self.ema.load_state_dict(checkpoint["ema_state"])
def _log_lrs(self) -> None:
"""Log learning rates for discriminator and generator optimizers.
Notes:
- Assumes optimizers are ordered as `[optimizer_d, optimizer_g]` in the trainer.
- Logs both on-step and on-epoch for easier tracking.
"""
# order matches your return: [optimizer_d, optimizer_g]
opt_d = self.trainer.optimizers[0]
opt_g = self.trainer.optimizers[1]
self.log(
"lr_discriminator",
opt_d.param_groups[0]["lr"],
on_step=True,
on_epoch=True,
prog_bar=False,
logger=True,
sync_dist=True,
)
self.log(
"lr_generator",
opt_g.param_groups[0]["lr"],
on_step=True,
on_epoch=True,
prog_bar=False,
logger=True,
sync_dist=True,
)
def load_from_checkpoint(self, ckpt_path) -> None:
"""Load model weights from a PyTorch Lightning checkpoint file.
Loads the `state_dict` from the given checkpoint and maps it to the current device.
Args:
ckpt_path (str | pathlib.Path): Path to the `.ckpt` file saved by Lightning.
Raises:
FileNotFoundError: If the checkpoint path does not exist.
KeyError: If the checkpoint does not contain a `'state_dict'` entry.
"""
# load ckpt
ckpt = torch.load(ckpt_path, map_location=self.device)
self.load_state_dict(ckpt["state_dict"])
print(f"Loaded checkpoint from {ckpt_path}")
get_models(mode)¶
Initialize and attach the Generator and (optionally) Discriminator models.
This method builds the generator and discriminator architectures based on
the configuration provided in self.config. It supports multiple generator
backbones (e.g., SRResNet, RCAB, RRDB, LKA) and discriminator types
(standard, PatchGAN). The discriminator is only initialized when the mode
is set to "train".
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| mode | str | Operational mode of the model. "train" initializes both generator and discriminator; any other value initializes only the generator. | required |

Raises:

| Type | Description |
|---|---|
| ValueError | If an unknown generator or discriminator type is specified in the configuration. |

Attributes:

| Name | Type | Description |
|---|---|---|
| generator | Module | The initialized generator network instance. |
| discriminator | Module | The initialized discriminator network instance (only present if mode == "train"). |
Source code in opensr_srgan/model/SRGAN.py
def get_models(self, mode):
"""Initialize and attach the Generator and (optionally) Discriminator models.
This method builds the generator and discriminator architectures based on
the configuration provided in `self.config`. It supports multiple generator
backbones (e.g., SRResNet, RCAB, RRDB, LKA) and discriminator types
(standard, PatchGAN). The discriminator is only initialized when the mode
is set to `"train"`.
Args:
mode (str): Operational mode of the model. Must be one of:
- `"train"`: Initializes both generator and discriminator.
- Any other value: Initializes only the generator.
Raises:
ValueError: If an unknown generator or discriminator type is specified
in the configuration.
Attributes:
generator (nn.Module): The initialized generator network instance.
discriminator (nn.Module, optional): The initialized discriminator
network instance (only present if `mode == "train"`).
"""
# ======================================================================
# SECTION: Initialize Generator
# Purpose: Build generator network depending on selected architecture.
# ======================================================================
self.generator = build_generator(self.config)
if mode == "train": # only get discriminator in training mode
# ======================================================================
# SECTION: Initialize Discriminator
# Purpose: Build discriminator network for adversarial training.
# ======================================================================
raw_discriminator_type = getattr(
self.config.Discriminator, "model_type", "standard"
)
discriminator_type = str(raw_discriminator_type).strip().lower()
n_blocks = getattr(self.config.Discriminator, "n_blocks", None)
if discriminator_type == "standard":
from opensr_srgan.model.discriminators.srgan_discriminator import (
Discriminator,
)
discriminator_kwargs = {
"in_channels": self.config.Model.in_bands,
}
if n_blocks is not None:
discriminator_kwargs["n_blocks"] = n_blocks
# pass spectral norm option
use_spectral_norm = getattr(
self.config.Discriminator, "use_spectral_norm", True
)
discriminator_kwargs["use_spectral_norm"] = bool(use_spectral_norm)
self.discriminator = Discriminator(**discriminator_kwargs)
elif discriminator_type == "patchgan":
from opensr_srgan.model.discriminators.patchgan import (
PatchGANDiscriminator,
)
patchgan_layers = n_blocks if n_blocks is not None else 3
self.discriminator = PatchGANDiscriminator(
input_nc=self.config.Model.in_bands,
n_layers=patchgan_layers,
)
elif discriminator_type == "esrgan":
from opensr_srgan.model.discriminators.esrgan import (
ESRGANDiscriminator,
)
ignored_options = []
if n_blocks is not None:
ignored_options.append("n_blocks")
if ignored_options:
ignored_joined = ", ".join(sorted(ignored_options))
print(
f"[Discriminator:esrgan] Ignoring unsupported configuration options: {ignored_joined}."
)
base_channels = getattr(self.config.Discriminator, "base_channels", 64)
linear_size = getattr(self.config.Discriminator, "linear_size", 1024)
self.discriminator = ESRGANDiscriminator(
in_channels=self.config.Model.in_bands,
base_channels=int(base_channels),
linear_size=int(linear_size),
)
else:
raise ValueError(
f"Unknown discriminator model type: {raw_discriminator_type}"
)
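The three supported discriminators are selected purely through the Discriminator section of the config. Illustrative snippets (all values are placeholders):

from omegaconf import OmegaConf

# Standard SRGAN discriminator with optional spectral norm
standard_cfg = OmegaConf.create(
    {"Discriminator": {"model_type": "standard", "n_blocks": 8, "use_spectral_norm": True}}
)

# PatchGAN discriminator; n_blocks maps to the number of conv layers (defaults to 3)
patchgan_cfg = OmegaConf.create(
    {"Discriminator": {"model_type": "patchgan", "n_blocks": 3}}
)

# ESRGAN discriminator; n_blocks is ignored and reported as unsupported
esrgan_cfg = OmegaConf.create(
    {"Discriminator": {"model_type": "esrgan", "base_channels": 64, "linear_size": 1024}}
)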
setup_lightning()¶
Configure PyTorch Lightning behavior based on the detected version.
This method ensures compatibility between different versions of PyTorch Lightning (PL) by setting appropriate optimization modes and binding the correct training step implementation.
- For PL ≥ 2.0: Enables manual optimization, required for GAN training.
- For PL < 2.0: Uses automatic optimization and the legacy training step.
The selected training step function (training_step_PL1 or training_step_PL2)
is dynamically attached to the model as _training_step_implementation.
Raises:

| Type | Description |
|---|---|
| AssertionError | If automatic_optimization is incorrectly set for PL < 2.0. |
| RuntimeError | If the detected PyTorch Lightning version is unsupported. |

Attributes:

| Name | Type | Description |
|---|---|---|
| automatic_optimization | bool | Indicates whether Lightning manages optimizer steps automatically. |
| _training_step_implementation | Callable | Bound training step function corresponding to the active PL version. |
Source code in opensr_srgan/model/SRGAN.py
def setup_lightning(self):
"""Configure PyTorch Lightning behavior based on the detected version.
This method ensures compatibility between different versions of
PyTorch Lightning (PL) by setting appropriate optimization modes
and binding the correct training step implementation.
- For PL ≥ 2.0: Enables **manual optimization**, required for GAN training.
- For PL < 2.0: Uses **automatic optimization** and the legacy training step.
The selected training step function (`training_step_PL1` or `training_step_PL2`)
is dynamically attached to the model as `_training_step_implementation`.
Raises:
AssertionError: If `automatic_optimization` is incorrectly set for PL < 2.0.
RuntimeError: If the detected PyTorch Lightning version is unsupported.
Attributes:
automatic_optimization (bool): Indicates whether Lightning manages
optimizer steps automatically.
_training_step_implementation (Callable): Bound training step function
corresponding to the active PL version.
"""
# Check for PL version - Define PL Hooks accordingly
if self.pl_version >= (2, 0, 0):
self.automatic_optimization = False # manual optimization for PL 2.x
# Set up Training Step
from opensr_srgan.model.training_step_PL import training_step_PL2
self._training_step_implementation = MethodType(training_step_PL2, self)
elif self.pl_version < (2, 0, 0):
assert (
self.automatic_optimization is True
), "For PL <2.0, automatic_optimization must be True."
# Set up Training Step
from opensr_srgan.model.training_step_PL import training_step_PL1
self._training_step_implementation = MethodType(training_step_PL1, self)
else:
raise RuntimeError(
f"Unsupported PyTorch Lightning version: {pl.__version__}"
)
initialize_ema()¶
Initialize the Exponential Moving Average (EMA) mechanism for the generator.
This method sets up an EMA shadow copy of the generator parameters to stabilize training and improve the quality of generated outputs. EMA is enabled only if specified in the training configuration.
The EMA model tracks the moving average of generator weights with a configurable decay factor and update schedule.
Configuration fields under config.Training.EMA:
- enabled (bool): Whether to enable EMA tracking.
- decay (float): Exponential decay factor for weight averaging (default: 0.999).
- device (str | None): Device to store the EMA weights on.
- use_num_updates (bool): Whether to use step-based update counting.
- update_after_step (int): Number of steps to wait before starting updates.
Attributes:

| Name | Type | Description |
|---|---|---|
| ema | ExponentialMovingAverage or None | EMA object tracking generator parameters. |
| _ema_update_after_step | int | Step count threshold before EMA updates begin. |
| _ema_applied | bool | Indicates whether EMA weights are currently applied to the generator. |
Source code in opensr_srgan/model/SRGAN.py
def initialize_ema(self):
"""Initialize the Exponential Moving Average (EMA) mechanism for the generator.
This method sets up an EMA shadow copy of the generator parameters to
stabilize training and improve the quality of generated outputs. EMA is
enabled only if specified in the training configuration.
The EMA model tracks the moving average of generator weights with a
configurable decay factor and update schedule.
Configuration fields under `config.Training.EMA`:
- `enabled` (bool): Whether to enable EMA tracking.
- `decay` (float): Exponential decay factor for weight averaging (default: 0.999).
- `device` (str | None): Device to store the EMA weights on.
- `use_num_updates` (bool): Whether to use step-based update counting.
- `update_after_step` (int): Number of steps to wait before starting updates.
Attributes:
ema (ExponentialMovingAverage | None): EMA object tracking generator parameters.
_ema_update_after_step (int): Step count threshold before EMA updates begin.
_ema_applied (bool): Indicates whether EMA weights are currently applied to the generator.
"""
ema_cfg = getattr(self.config.Training, "EMA", None)
self.ema: ExponentialMovingAverage | None = None
self._ema_update_after_step = 0
self._ema_applied = False
if ema_cfg is not None and getattr(ema_cfg, "enabled", False):
ema_decay = float(getattr(ema_cfg, "decay", 0.999))
ema_device = getattr(ema_cfg, "device", None)
use_num_updates = bool(getattr(ema_cfg, "use_num_updates", True))
self.ema = ExponentialMovingAverage(
self.generator,
decay=ema_decay,
use_num_updates=use_num_updates,
)
self._ema_update_after_step = int(getattr(ema_cfg, "update_after_step", 0))
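EMA is enabled purely through the Training.EMA block. A minimal illustrative snippet (values are placeholders):

from omegaconf import OmegaConf

ema_block = OmegaConf.create({
    "Training": {
        "EMA": {
            "enabled": True,             # switch EMA tracking on
            "decay": 0.999,              # exponential decay factor for the shadow weights
            "use_num_updates": True,     # step-based update counting (see docstring above)
            "update_after_step": 5_000,  # skip EMA updates for the first training steps
            "device": None,              # keep EMA weights on the training device
        }
    }
})
# During training, ema.update(generator) runs after generator optimizer steps;
# validation/test hooks swap the averaged weights in and restore the originals afterwards.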
forward(lr_imgs)¶
Forward pass through the generator network.
Takes a batch of low-resolution (LR) input images and produces their corresponding super-resolved (SR) outputs using the generator model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lr_imgs | Tensor | Batch of input low-resolution images with shape (B, C, H, W), where B is the batch size, C the number of channels, and H, W the spatial dimensions. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Super-resolved output images with increased spatial resolution, typically scaled by the model's configured upsampling factor. |
Source code in opensr_srgan/model/SRGAN.py
def forward(self, lr_imgs):
"""Forward pass through the generator network.
Takes a batch of low-resolution (LR) input images and produces
their corresponding super-resolved (SR) outputs using the generator model.
Args:
lr_imgs (torch.Tensor): Batch of input low-resolution images
with shape `(B, C, H, W)` where:
- `B`: batch size
- `C`: number of channels
- `H`, `W`: spatial dimensions.
Returns:
torch.Tensor: Super-resolved output images with increased spatial resolution,
typically scaled by the model's configured upsampling factor.
"""
sr_imgs = self.generator(lr_imgs) # pass LR input through generator network
return sr_imgs # return super-resolved output
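A minimal usage sketch of the forward pass, assuming the class is importable from `opensr_srgan.model.SRGAN` (per the source location above) and a 4× model configured for 3 input bands; the config path is a placeholder.
```python
import torch
from opensr_srgan.model.SRGAN import SRGAN_model  # import path per the source location above

model = SRGAN_model(config="config.yaml", mode="eval")  # placeholder config path
model.eval()
lr_imgs = torch.rand(1, 3, 64, 64)   # (B, C, H, W) low-resolution batch
with torch.no_grad():
    sr_imgs = model(lr_imgs)         # calls forward() -> generator output
print(sr_imgs.shape)                 # e.g. torch.Size([1, 3, 256, 256]) for scaling_factor=4
```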
predict_step(lr_imgs)
¶
Run a single super-resolution inference step.
Performs forward inference using the generator (optionally under EMA weights)
to produce super-resolved (SR) outputs from low-resolution (LR) inputs.
The method normalizes input values using the configured strategy (e.g., raw
Sentinel-2 reflectance via normalise_10k), applies histogram matching, and
denormalizes the outputs back to their original scale.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `lr_imgs` | `Tensor` | Batch of input low-resolution images with shape `(B, C, H, W)`. Pixel value ranges may vary depending on preprocessing (e.g., 0–10000 for Sentinel-2 reflectance). | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Super-resolved output images with matched histograms and restored value range, detached from the computation graph and placed on CPU memory. |

Raises:

| Type | Description |
|---|---|
| `AssertionError` | If the generator is not in evaluation mode (`.eval()`). |
Source code in opensr_srgan/model/SRGAN.py
@torch.no_grad()
def predict_step(self, lr_imgs):
"""Run a single super-resolution inference step.
Performs forward inference using the generator (optionally under EMA weights)
to produce super-resolved (SR) outputs from low-resolution (LR) inputs.
The method normalizes input values using the configured strategy (e.g., raw
Sentinel-2 reflectance via ``normalise_10k``), applies histogram matching, and
denormalizes the outputs back to their original scale.
Args:
lr_imgs (torch.Tensor): Batch of input low-resolution images
with shape `(B, C, H, W)`. Pixel value ranges may vary depending
on preprocessing (e.g., 0–10000 for Sentinel-2 reflectance).
Returns:
torch.Tensor: Super-resolved output images with matched histograms
and restored value range, detached from the computation graph and
placed on CPU memory.
Raises:
AssertionError: If the generator is not in evaluation mode (`.eval()`).
"""
assert (
self.generator.training is False
), "Generator must be in eval mode for prediction." # ensure eval mode
lr_imgs = lr_imgs.to(self.device) # move to device (GPU or CPU)
# --- Normalize inputs according to configuration ---
normalized_lr = self.normalizer.normalize(lr_imgs)
# --- Perform super-resolution (optionally using EMA weights) ---
context = (
self.ema.average_parameters(self.generator)
if self.ema is not None
else nullcontext()
)
with context:
sr_imgs = self.generator(normalized_lr) # forward pass (SR prediction)
# --- Histogram match SR to LR ---
sr_imgs = histogram_match(normalized_lr, sr_imgs) # match distributions
# --- Denormalize output back to original range ---
sr_imgs = self.normalizer.denormalize(sr_imgs)
# --- Move to CPU and return ---
sr_imgs = sr_imgs.cpu().detach() # detach from graph for inference output
return sr_imgs
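A minimal inference sketch with reflectance-scaled input; the tensor shapes and config path are illustrative placeholders.
```python
import torch
from opensr_srgan.model.SRGAN import SRGAN_model  # import path per the source location above

model = SRGAN_model(config="config.yaml", mode="eval")   # placeholder config path
model.eval()                                             # predict_step asserts eval mode
lr_imgs = torch.randint(0, 10_000, (1, 4, 128, 128)).float()  # raw 0-10000 reflectance, (B, C, H, W)
sr_imgs = model.predict_step(lr_imgs)                    # normalize -> SR -> histogram match -> denormalize
print(sr_imgs.shape, sr_imgs.device)                     # output is detached and returned on CPU
```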
training_step(batch, batch_idx, optimizer_idx=None, *args)
¶
Dispatch the correct training step implementation based on PyTorch Lightning version.
This method acts as a compatibility layer between different PyTorch Lightning versions that handle multi-optimizer GAN training differently.
- For PL ≥ 2.0: Manual optimization is used, and the optimizer index is not passed.
- For PL < 2.0: Automatic optimization is used, and the optimizer index is passed to handle generator/discriminator updates separately.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Any` | A batch of training data (input tensors and targets as defined by the DataModule). | required |
| `batch_idx` | `int` | Index of the current batch within the epoch. | required |
| `optimizer_idx` | `int \| None` | Index of the active optimizer (0 for the discriminator, 1 for the generator) when using PL < 2.0. | `None` |
| `*args` | | Additional arguments that may be passed by older Lightning versions. | `()` |

Returns:

| Type | Description |
|---|---|
| `Any` | The output of the active training step implementation (typically the loss value). |
Source code in opensr_srgan/model/SRGAN.py
def training_step(
self, batch, batch_idx, optimizer_idx: Optional[int] = None, *args
):
"""Dispatch the correct training step implementation based on PyTorch Lightning version.
This method acts as a compatibility layer between different PyTorch Lightning
versions that handle multi-optimizer GAN training differently.
- For PL ≥ 2.0: Manual optimization is used, and the optimizer index is not passed.
- For PL < 2.0: Automatic optimization is used, and the optimizer index is passed
to handle generator/discriminator updates separately.
Args:
batch (Any): A batch of training data (input tensors and targets as defined by the DataModule).
batch_idx (int): Index of the current batch within the epoch.
optimizer_idx (int | None, optional): Index of the active optimizer (0 for the discriminator,
1 for the generator) when using PL < 2.0.
*args: Additional arguments that may be passed by older Lightning versions.
Returns:
Any: The output of the active training step implementation, loss value.
"""
# Depending on PL version, and depending on the manual optimization
if self.pl_version >= (2, 0, 0):
# In PL2.x, optimizer_idx is not passed, manual optimization is performed
return self._training_step_implementation(batch, batch_idx) # no optim_idx
else:
# In Pl1.x, optimizer_idx arrives twice and is passed on
return self._training_step_implementation(
batch, batch_idx, optimizer_idx
) # pass optim_idx
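The comparison against `(2, 0, 0)` assumes `self.pl_version` is a tuple of integers. A standalone sketch of how such a tuple can be derived from `pl.__version__` (the helper name is hypothetical):
```python
import pytorch_lightning as pl

def parse_pl_version(version: str = pl.__version__) -> tuple[int, ...]:
    """Turn a version string like '2.1.3' into a comparable (major, minor, patch) tuple."""
    return tuple(int(p) if p.isdigit() else 0 for p in version.split(".")[:3])

print(parse_pl_version("1.9.5") >= (2, 0, 0))  # False -> legacy training_step_PL1 path
print(parse_pl_version("2.2.0") >= (2, 0, 0))  # True  -> manual-optimization training_step_PL2 path
```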
optimizer_step(epoch, batch_idx, optimizer, optimizer_idx=None, optimizer_closure=None, **kwargs)
¶
Custom optimizer step handling for PL 1.x automatic optimization.
This method ensures correct behavior across different PyTorch Lightning
versions and training modes. It is invoked automatically during training
in PL < 2.0 when automatic_optimization=True. For PL ≥ 2.0, where manual
optimization is used, this function is effectively bypassed.
- In PL ≥ 2.0 (manual optimization): The optimizer step is explicitly called within `training_step_PL2()`, including EMA updates.
- In PL < 2.0 (automatic optimization): This function manages optimizer stepping, gradient zeroing, and optional EMA updates after generator steps.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `epoch` | `int` | Current training epoch. | required |
| `batch_idx` | `int` | Index of the current batch. | required |
| `optimizer` | `Optimizer` | The active optimizer instance. | required |
| `optimizer_idx` | `int` | Index of the optimizer being stepped (e.g., 0 for the discriminator, 1 for the generator). | `None` |
| `optimizer_closure` | `Callable` | Closure for re-evaluating the model and loss before the optimizer step (used with some optimizers). | `None` |
| `**kwargs` | | Additional arguments passed by PL depending on backend (e.g., TPU flags, LBFGS options). | `{}` |

Notes
- EMA updates are performed only after generator steps (`optimizer_idx == 1`).
- The update starts after `self._ema_update_after_step` global steps.
Source code in opensr_srgan/model/SRGAN.py
def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_idx=None,
optimizer_closure=None,
**kwargs, # absorbs on_tpu/using_lbfgs/etc across PL versions
):
"""Custom optimizer step handling for PL 1.x automatic optimization.
This method ensures correct behavior across different PyTorch Lightning
versions and training modes. It is invoked automatically during training
in PL < 2.0 when `automatic_optimization=True`. For PL ≥ 2.0, where manual
optimization is used, this function is effectively bypassed.
- In **PL ≥ 2.0 (manual optimization)**: The optimizer step is explicitly
called within `training_step_PL2()`, including EMA updates.
- In **PL < 2.0 (automatic optimization)**: This function manages optimizer
stepping, gradient zeroing, and optional EMA updates after generator steps.
Args:
epoch (int): Current training epoch.
batch_idx (int): Index of the current batch.
optimizer (torch.optim.Optimizer): The active optimizer instance.
optimizer_idx (int, optional): Index of the optimizer being stepped
(e.g., 0 for discriminator, 1 for generator).
optimizer_closure (Callable, optional): Closure for re-evaluating the
model and loss before optimizer step (used with some optimizers).
**kwargs: Additional arguments passed by PL depending on backend
(e.g., TPU flags, LBFGS options).
Notes:
- EMA updates are performed only after generator steps (optimizer_idx == 1).
- The update starts after `self._ema_update_after_step` global steps.
"""
# If we're in manual optimization (PL >=2 path), do nothing special.
if not self.automatic_optimization:
# In manual mode we call opt.step()/zero_grad() in training_step_PL2.
# In manual mode, we update EMA weights manually in training step too.
return super().optimizer_step(
epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **kwargs
)
# ---- PL 1.x auto-optimization path ----
if optimizer_closure is not None:
optimizer.step(closure=optimizer_closure)
else:
optimizer.step()
optimizer.zero_grad()
# EMA after the generator step (assumes G is optimizer_idx == 1)
if (
self.ema is not None
and optimizer_idx == 1
and self.global_step >= self._ema_update_after_step
):
self.ema.update(self.generator)
validation_step(batch, batch_idx)
¶
Run the validation loop for a single batch.
This method performs super-resolution inference on validation data, computes image quality metrics (e.g., PSNR, SSIM), logs them, and optionally visualizes SR–HR–LR triplets. It also evaluates the discriminator’s adversarial response if applicable.
Workflow
- Forward pass (LR → SR) through the generator.
- Compute content-based validation metrics.
- Optionally log visual examples to the logger (e.g., Weights & Biases).
- Compute and log discriminator metrics, unless in pretraining mode.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Tuple[Tensor, Tensor]` | A tuple `(lr_imgs, hr_imgs)` of low-resolution and high-resolution tensors with shape `(B, C, H, W)`. | required |
| `batch_idx` | `int` | Index of the current validation batch. | required |

Returns:

| Type | Description |
|---|---|
| `None` | Metrics and images are logged via Lightning’s logger interface. |

Raises:

| Type | Description |
|---|---|
| `AssertionError` | If an unexpected number of bands or invalid visualization configuration is encountered. |

Notes
- Validation is executed without gradient tracking.
- Only the first `config.Logging.num_val_images` batches are visualized.
- If EMA is enabled, the generator predictions reflect the current EMA state.
Source code in opensr_srgan/model/SRGAN.py
@torch.no_grad()
def validation_step(self, batch, batch_idx):
"""Run the validation loop for a single batch.
This method performs super-resolution inference on validation data,
computes image quality metrics (e.g., PSNR, SSIM), logs them, and
optionally visualizes SR–HR–LR triplets. It also evaluates the
discriminator’s adversarial response if applicable.
Workflow:
1. Forward pass (LR → SR) through the generator.
2. Compute content-based validation metrics.
3. Optionally log visual examples to the logger (e.g., Weights & Biases).
4. Compute and log discriminator metrics, unless in pretraining mode.
Args:
batch (Tuple[torch.Tensor, torch.Tensor]): A tuple `(lr_imgs, hr_imgs)` of
low-resolution and high-resolution tensors with shape `(B, C, H, W)`.
batch_idx (int): Index of the current validation batch.
Returns:
None: Metrics and images are logged via Lightning’s logger interface.
Raises:
AssertionError: If an unexpected number of bands or invalid visualization
configuration is encountered.
Notes:
- Validation is executed without gradient tracking.
- Only the first `config.Logging.num_val_images` batches are visualized.
- If EMA is enabled, the generator predictions reflect the current EMA state.
"""
# ======================================================================
# SECTION: Forward pass — Generate SR prediction from LR input
# Purpose: Run model inference on validation batch without gradient tracking.
# ======================================================================
""" 1. Extract and Predict """
lr_imgs, hr_imgs = batch # unpack LR and HR tensors
sr_imgs = self.forward(lr_imgs) # run generator to produce SR prediction
# ======================================================================
# SECTION: Compute and log validation metrics
# Purpose: measure content-based metrics (PSNR/SSIM/etc.) on SR vs HR.
# ======================================================================
""" 2. Log Generator Metrics """
metrics_hr_img = torch.clone(
hr_imgs
) # clone to avoid in-place ops on autograd graph
metrics_sr_img = torch.clone(sr_imgs) # same for SR
# metrics = calculate_metrics(metrics_sr_img, metrics_hr_img, phase="val_metrics")
metrics = self.content_loss_criterion.return_metrics(
metrics_sr_img, metrics_hr_img, prefix="val_metrics/"
) # compute metrics using loss criterion helper
del metrics_hr_img, metrics_sr_img # free cloned tensors from GPU memory
for key, value in metrics.items(): # iterate over metrics dict
self.log(
f"{key}", value, sync_dist=True
) # log each metric to logger (e.g., W&B, TensorBoard)
# ======================================================================
# SECTION: Optional visualization — Log example SR/HR/LR images
# Purpose: visually track qualitative progress of the model.
# ======================================================================
# only perform image logging for first N batches to avoid logging all 200 images
if batch_idx < self.config.Logging.num_val_images:
base_lr = lr_imgs # use original LR for visualization
# --- Select visualization bands (if multispectral) ---
if self.config.Model.in_bands < 3:
# show only first band
lr_vis = base_lr[:, :1, :, :] # e.g., single-band input
hr_vis = hr_imgs[:, :1, :, :] # subset HR
sr_vis = sr_imgs[:, :1, :, :] # subset SR
elif self.config.Model.in_bands == 3:
# RGB-like input: visualize all three bands directly
lr_vis = base_lr # use full LR tensor
hr_vis = hr_imgs # use full HR tensor
sr_vis = sr_imgs # use full SR tensor
elif self.config.Model.in_bands == 4:
# assume its RGB-NIR, show RGB
lr_vis = base_lr[:, :3, :, :] # e.g., Sentinel-2 RGB
hr_vis = hr_imgs[:, :3, :, :] # subset HR
sr_vis = sr_imgs[:, :3, :, :] # subset SR
elif self.config.Model.in_bands > 4: # e.g., Sentinel-2 with >3 channels
# random selection of bands
idx = np.random.choice(
sr_imgs.shape[1], 3, replace=False
) # randomly select 3 bands
lr_vis = base_lr[:, idx, :, :] # subset LR
hr_vis = hr_imgs[:, idx, :, :] # subset HR
sr_vis = sr_imgs[:, idx, :, :] # subset SR
else:
# should not happen
pass
# --- Clone tensors for plotting to avoid affecting main tensors ---
plot_lr_img = lr_vis.clone()
plot_hr_img = hr_vis.clone()
plot_sr_img = sr_vis.clone()
# --- Generate matplotlib visualization (LR, SR, HR side-by-side) ---
val_img = plot_tensors(plot_lr_img, plot_sr_img, plot_hr_img, title="Val")
# --- Cleanup ---
del plot_lr_img, plot_hr_img, plot_sr_img # free memory after plotting
# --- Log image to WandB (or compatible logger), if wanted ---
if self.config.Logging.wandb.enabled:
self.logger.experiment.log(
{"Val SR": wandb.Image(val_img)}
) # upload to dashboard
""" 3. Log Discriminator metrics """
# If in pretraining, discard D metrics
if self._pretrain_check(): # check if we're in pretrain phase
self.log(
"discriminator/adversarial_loss",
torch.zeros(1, device=lr_imgs.device),
prog_bar=False,
sync_dist=True,
)
else:
# run discriminator and get loss between pred labels and true labels
hr_discriminated = self.discriminator(hr_imgs)
sr_discriminated = self.discriminator(sr_imgs)
# Run loss depending on type
if self.adv_loss_type == "wasserstein":
adversarial_loss = sr_discriminated.mean() - hr_discriminated.mean()
else:
adversarial_loss = self.adversarial_loss_criterion(
sr_discriminated, torch.zeros_like(sr_discriminated)
) + self.adversarial_loss_criterion(
hr_discriminated, torch.ones_like(hr_discriminated)
)
# Log image
self.log(
"validation/DISC_adversarial_loss", adversarial_loss, sync_dist=True
)
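The band-selection logic above can be summarized as a standalone helper. This is an illustrative sketch, not part of the class; `select_vis_bands` is a hypothetical name.
```python
import numpy as np
import torch

def select_vis_bands(img: torch.Tensor, in_bands: int) -> torch.Tensor:
    """Pick up to three channels of a (B, C, H, W) tensor for RGB-style previews."""
    if in_bands < 3:
        return img[:, :1]        # single-band preview
    if in_bands == 3:
        return img               # already RGB-like, show as-is
    if in_bands == 4:
        return img[:, :3]        # assume RGB-NIR ordering, keep RGB
    idx = np.random.choice(img.shape[1], 3, replace=False)
    return img[:, idx]           # random 3-band subset for >4-band imagery

x = torch.rand(2, 10, 32, 32)                    # e.g. a 10-band patch
print(select_vis_bands(x, in_bands=10).shape)    # torch.Size([2, 3, 32, 32])
```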
on_validation_epoch_start()
¶
Hook executed at the start of each validation epoch.
Applies the Exponential Moving Average (EMA) weights to the generator before running validation to ensure evaluation uses the smoothed model parameters.
Notes
- Calls the parent hook via `super().on_validation_epoch_start()`.
- Restores original weights at the end of validation.
Source code in opensr_srgan/model/SRGAN.py
def on_validation_epoch_start(self):
"""Hook executed at the start of each validation epoch.
Applies the Exponential Moving Average (EMA) weights to the generator
before running validation to ensure evaluation uses the smoothed model
parameters.
Notes:
- Calls the parent hook via `super().on_validation_epoch_start()`.
- Restores original weights at the end of validation.
"""
super().on_validation_epoch_start()
self._apply_generator_ema_weights()
on_validation_epoch_end()
¶
Hook executed at the end of each validation epoch.
Restores the generator’s original (non-EMA) weights after validation. Ensures subsequent training or testing uses up-to-date parameters.
Notes
- Calls the parent hook via `super().on_validation_epoch_end()`.
Source code in opensr_srgan/model/SRGAN.py
def on_validation_epoch_end(self):
"""Hook executed at the end of each validation epoch.
Restores the generator’s original (non-EMA) weights after validation.
Ensures subsequent training or testing uses up-to-date parameters.
Notes:
- Calls the parent hook via `super().on_validation_epoch_end()`.
"""
self._restore_generator_weights()
super().on_validation_epoch_end()
on_test_epoch_start()
¶
Hook executed at the start of each testing epoch.
Applies the Exponential Moving Average (EMA) weights to the generator before running tests to ensure consistent evaluation with the smoothed model parameters.
Notes
- Calls the parent hook via `super().on_test_epoch_start()`.
- Restores original weights at the end of testing.
Source code in opensr_srgan/model/SRGAN.py
def on_test_epoch_start(self):
"""Hook executed at the start of each testing epoch.
Applies the Exponential Moving Average (EMA) weights to the generator
before running tests to ensure consistent evaluation with the
smoothed model parameters.
Notes:
- Calls the parent hook via `super().on_test_epoch_start()`.
- Restores original weights at the end of testing.
"""
super().on_test_epoch_start()
self._apply_generator_ema_weights()
on_test_epoch_end()
¶
Hook executed at the end of each testing epoch.
Restores the generator’s original (non-EMA) weights after testing. Ensures the model is reset to its latest training state.
Notes
- Calls the parent hook via `super().on_test_epoch_end()`.
Source code in opensr_srgan/model/SRGAN.py
def on_test_epoch_end(self):
"""Hook executed at the end of each testing epoch.
Restores the generator’s original (non-EMA) weights after testing.
Ensures the model is reset to its latest training state.
Notes:
- Calls the parent hook via `super().on_test_epoch_end()`.
"""
self._restore_generator_weights()
super().on_test_epoch_end()
configure_optimizers()
¶
Robust optimizers & schedulers for GANs (PL1 & PL2 compatible).
- TTUR by default (D lr <= G lr)
- Adam with GAN-friendly betas/eps
- Exclude norm/affine/bias params from weight decay
- Separate Plateau schedulers for G and D (with cooldown/min_lr)
- Optional step-wise warmup for G (linear/cosine), no LR jump at handoff
- Returned order is [D, G] to match your training_step expectations
Source code in opensr_srgan/model/SRGAN.py
def configure_optimizers(self):
"""
Robust optimizers & schedulers for GANs (PL1 & PL2 compatible).
- TTUR by default (D lr <= G lr)
- Adam with GAN-friendly betas/eps
- Exclude norm/affine/bias params from weight decay
- Separate Plateau schedulers for G and D (with cooldown/min_lr)
- Optional step-wise warmup for G (linear/cosine), no LR jump at handoff
- Returned order is [D, G] to match your training_step expectations
"""
import math
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
cfg_opt = self.config.Optimizers
cfg_sch = self.config.Schedulers
# ---------- helpers ----------
def _split_wd_params(model):
"""Return two lists: params_with_wd, params_without_wd."""
wd, no_wd = [], []
norm_like = (
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.GroupNorm,
nn.LayerNorm,
nn.InstanceNorm1d,
nn.InstanceNorm2d,
nn.InstanceNorm3d,
)
for m in model.modules():
for n, p in m.named_parameters(recurse=False):
if not p.requires_grad:
continue
if n.endswith("bias") or isinstance(m, norm_like):
no_wd.append(p)
else:
wd.append(p)
# catch any top-level params not in modules (rare)
seen = set(map(id, wd + no_wd))
for n, p in model.named_parameters():
if p.requires_grad and id(p) not in seen:
(no_wd if n.endswith("bias") else wd).append(p)
return wd, no_wd
def _adam(params, lr):
# GAN-friendly defaults; tune in config if needed
betas = getattr(cfg_opt, "betas", (0.0, 0.99))
eps = getattr(cfg_opt, "eps", 1e-7)
return Adam(params, lr=lr, betas=betas, eps=eps)
# ---------- LRs (TTUR) ----------
lr_g = float(getattr(cfg_opt, "optim_g_lr", 1e-4))
lr_d = float(
getattr(cfg_opt, "optim_d_lr", max(lr_g * 0.5, 1e-6))
) # default: D slower than G
# weight decay (only on non-norm, non-bias)
wd_g = float(getattr(cfg_opt, "weight_decay_g", 0.0))
wd_d = float(getattr(cfg_opt, "weight_decay_d", 0.0))
# ---------- build optimizers with clean param groups ----------
g_wd, g_no = _split_wd_params(self.generator)
d_wd, d_no = _split_wd_params(self.discriminator)
optimizer_g = _adam(
[
{"params": g_wd, "weight_decay": wd_g},
{"params": g_no, "weight_decay": 0.0},
],
lr=lr_g,
)
optimizer_d = _adam(
[
{"params": d_wd, "weight_decay": wd_d},
{"params": d_no, "weight_decay": 0.0},
],
lr=lr_d,
)
# ---------- schedulers ----------
# Use distinct monitors for clarity (recommend: log these in validation)
monitor_g = getattr(
cfg_sch, "metric_g", getattr(cfg_sch, "metric", "val_g_loss")
)
monitor_d = getattr(
cfg_sch, "metric_d", getattr(cfg_sch, "metric", "val_d_loss")
)
sched_kwargs = dict(
mode="min",
factor=float(getattr(cfg_sch, "factor_g", 0.5)),
patience=int(getattr(cfg_sch, "patience_g", 5)),
threshold=float(getattr(cfg_sch, "threshold", 1e-3)),
threshold_mode="rel",
cooldown=int(getattr(cfg_sch, "cooldown", 0)),
min_lr=float(getattr(cfg_sch, "min_lr", 1e-7)),
verbose=bool(getattr(cfg_sch, "verbose", False)),
)
# D can have its own factor/patience; fall back to G’s if not set
sched_kwargs_d = dict(sched_kwargs)
sched_kwargs_d["factor"] = float(
getattr(cfg_sch, "factor_d", sched_kwargs["factor"])
)
sched_kwargs_d["patience"] = int(
getattr(cfg_sch, "patience_d", sched_kwargs["patience"])
)
scheduler_g = ReduceLROnPlateau(optimizer_g, **sched_kwargs)
scheduler_d = ReduceLROnPlateau(optimizer_d, **sched_kwargs_d)
sch_configs = [
{
"scheduler": scheduler_d,
"monitor": monitor_d,
"reduce_on_plateau": True,
"interval": "epoch",
"frequency": 1,
"name": "plateau_d",
},
{
"scheduler": scheduler_g,
"monitor": monitor_g,
"reduce_on_plateau": True,
"interval": "epoch",
"frequency": 1,
"name": "plateau_g",
},
]
# ---------- optional warmup for G (step-wise, multiplicative) ----------
warmup_steps = int(getattr(cfg_sch, "g_warmup_steps", 0))
warmup_type = str(getattr(cfg_sch, "g_warmup_type", "none")).lower()
if warmup_steps > 0 and warmup_type in {"linear", "cosine"}:
def _g_warmup_lambda(step: int) -> float:
if step >= warmup_steps:
return 1.0
t = (step + 1) / max(1, warmup_steps)
return (
t
if warmup_type == "linear"
else 0.5 * (1.0 - math.cos(math.pi * t))
)
warmup_g = torch.optim.lr_scheduler.LambdaLR(
optimizer_g, lr_lambda=_g_warmup_lambda
)
# Runs every step; multiplies base LR so there is no jump at the end
sch_configs.append(
{
"scheduler": warmup_g,
"interval": "step",
"frequency": 1,
"name": "warmup_g",
}
)
# Return order [D, G] to match your training_step
return [optimizer_d, optimizer_g], sch_configs
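To illustrate the generator warm-up described above, here is the multiplier schedule in isolation. It mirrors `_g_warmup_lambda`; the step counts and warm-up settings are illustrative.
```python
import math

warmup_steps, warmup_type = 500, "cosine"   # illustrative values for g_warmup_steps / g_warmup_type

def g_warmup_lambda(step: int) -> float:
    if step >= warmup_steps:
        return 1.0                           # after warm-up the base LR is used unchanged
    t = (step + 1) / max(1, warmup_steps)
    return t if warmup_type == "linear" else 0.5 * (1.0 - math.cos(math.pi * t))

for step in (0, 249, 499, 1000):
    print(step, round(g_warmup_lambda(step), 4))  # multiplier ramps smoothly from ~0 to 1.0
```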
on_train_batch_start(batch, batch_idx)
¶
Hook executed before each training batch.
Freezes or unfreezes discriminator parameters depending on the current training phase. During pretraining, the discriminator is frozen to allow the generator to learn reconstruction without adversarial pressure.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Any` | The current batch of training data. | required |
| `batch_idx` | `int` | Index of the current batch in the epoch. | required |
Source code in opensr_srgan/model/SRGAN.py
def on_train_batch_start(
self, batch, batch_idx
): # called before each training batch
"""Hook executed before each training batch.
Freezes or unfreezes discriminator parameters depending on the
current training phase. During pretraining, the discriminator is
frozen to allow the generator to learn reconstruction without
adversarial pressure.
Args:
batch (Any): The current batch of training data.
batch_idx (int): Index of the current batch in the epoch.
"""
pre = self._pretrain_check() # check if currently in pretraining phase
for p in self.discriminator.parameters(): # loop over all discriminator params
p.requires_grad = not pre # freeze D during pretrain, unfreeze otherwise
on_train_batch_end(outputs, batch, batch_idx)
¶
Hook executed after each training batch.
Logs the current learning rates for all active optimizers to the logger for monitoring and debugging purposes.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `outputs` | `Any` | Outputs returned by `training_step`. | required |
| `batch` | `Any` | The batch of data processed. | required |
| `batch_idx` | `int` | Index of the current batch in the epoch. | required |
Source code in opensr_srgan/model/SRGAN.py
def on_train_batch_end(self, outputs, batch, batch_idx):
"""Hook executed after each training batch.
Logs the current learning rates for all active optimizers to
the logger for monitoring and debugging purposes.
Args:
outputs (Any): Outputs returned by `training_step`.
batch (Any): The batch of data processed.
batch_idx (int): Index of the current batch in the epoch.
"""
self._log_lrs() # log LR's on each batch end
on_fit_start()
¶
Hook executed once at the beginning of model fitting.
Performs setup tasks that must occur before training starts:
- Moves EMA weights to the correct device.
- Prints a model summary (only from global rank 0 in DDP setups).
Notes
- Calls `super().on_fit_start()` to preserve Lightning’s default behavior.
- The model summary is only printed by the global zero process to avoid duplicated output in distributed training.
Source code in opensr_srgan/model/SRGAN.py
def on_fit_start(self): # called once at the start of training
"""Hook executed once at the beginning of model fitting.
Performs setup tasks that must occur before training starts:
- Moves EMA weights to the correct device.
- Prints a model summary (only from global rank 0 in DDP setups).
Notes:
- Calls `super().on_fit_start()` to preserve Lightning’s default behavior.
- The model summary is only printed by the global zero process
to avoid duplicated output in distributed training.
"""
super().on_fit_start()
if self.ema is not None and self.ema.device is None: # move ema weights
self.ema.to(self.device)
from opensr_srgan.utils.gpu_rank import _is_global_zero
if _is_global_zero():
print_model_summary(self) # print model summary to console
on_save_checkpoint(checkpoint)
¶
Augment the checkpoint with EMA state, if available.
Adds the EMA buffer/metadata to the checkpoint so that EMA can be restored upon load.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `checkpoint` | `dict` | Mutable checkpoint dictionary provided by Lightning. | required |
Source code in opensr_srgan/model/SRGAN.py
def on_save_checkpoint(self, checkpoint: dict) -> None:
"""Augment the checkpoint with EMA state, if available.
Adds the EMA buffer/metadata to the checkpoint so that EMA can be restored upon load.
Args:
checkpoint (dict): Mutable checkpoint dictionary provided by Lightning.
"""
super().on_save_checkpoint(checkpoint)
if self.ema is not None:
checkpoint["ema_state"] = self.ema.state_dict()
on_load_checkpoint(checkpoint)
¶
Restore EMA state from a checkpoint, if present.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `checkpoint` | `dict` | Checkpoint dictionary provided by Lightning containing model state and an optional `"ema_state"` entry. | required |
Source code in opensr_srgan/model/SRGAN.py
def on_load_checkpoint(self, checkpoint: dict) -> None:
"""Restore EMA state from a checkpoint, if present.
Args:
checkpoint (dict): Checkpoint dictionary provided by Lightning containing
model state and optional `"ema_state"` entry.
"""
super().on_load_checkpoint(checkpoint)
if self.ema is not None and "ema_state" in checkpoint:
self.ema.load_state_dict(checkpoint["ema_state"])
load_from_checkpoint(ckpt_path)
¶
Load model weights from a PyTorch Lightning checkpoint file.
Loads the state_dict from the given checkpoint and maps it to the current device.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ckpt_path` | `str \| Path` | Path to the `.ckpt` file saved by Lightning. | required |

Raises:

| Type | Description |
|---|---|
| `FileNotFoundError` | If the checkpoint path does not exist. |
| `KeyError` | If the checkpoint does not contain a `'state_dict'` entry. |
Source code in opensr_srgan/model/SRGAN.py
def load_from_checkpoint(self, ckpt_path) -> None:
"""Load model weights from a PyTorch Lightning checkpoint file.
Loads the `state_dict` from the given checkpoint and maps it to the current device.
Args:
ckpt_path (str | pathlib.Path): Path to the `.ckpt` file saved by Lightning.
Raises:
FileNotFoundError: If the checkpoint path does not exist.
KeyError: If the checkpoint does not contain a `'state_dict'` entry.
"""
# load ckpt
ckpt = torch.load(ckpt_path, map_location=self.device)
self.load_state_dict(ckpt["state_dict"])
print(f"Loaded checkpoint from {ckpt_path}")
Building blocks¶
Reusable building blocks for SRGAN family models.
ConvolutionalBlock
¶
Bases: Module
A convolutional block comprised of Conv → (BN) → (Activation).
Source code in opensr_srgan/model/model_blocks/__init__.py
class ConvolutionalBlock(nn.Module):
"""A convolutional block comprised of Conv → (BN) → (Activation)."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
batch_norm: bool = False,
activation: str | None = None,
) -> None:
super().__init__()
act = activation.lower() if activation is not None else None
if act is not None:
if act not in {"prelu", "leakyrelu", "tanh"}:
raise AssertionError("activation must be one of {'prelu', 'leakyrelu', 'tanh'}")
layers: list[nn.Module] = [
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2,
)
]
if batch_norm:
layers.append(nn.BatchNorm2d(num_features=out_channels))
if act == "prelu":
layers.append(nn.PReLU())
elif act == "leakyrelu":
layers.append(nn.LeakyReLU(0.2))
elif act == "tanh":
layers.append(nn.Tanh())
self.conv_block = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv_block(x)
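A quick shape check for the block (values are illustrative; the import path follows the source location above).
```python
import torch
from opensr_srgan.model.model_blocks import ConvolutionalBlock

block = ConvolutionalBlock(
    in_channels=3, out_channels=64, kernel_size=9,
    batch_norm=False, activation="PReLU",
)
x = torch.rand(1, 3, 64, 64)
print(block(x).shape)  # torch.Size([1, 64, 64, 64]) -- "same" padding keeps H and W unchanged
```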
SubPixelConvolutionalBlock
¶
Bases: Module
Conv → PixelShuffle → PReLU upsampling block.
Source code in opensr_srgan/model/model_blocks/__init__.py
class SubPixelConvolutionalBlock(nn.Module):
"""Conv → PixelShuffle → PReLU upsampling block."""
def __init__(
self,
kernel_size: int = 3,
n_channels: int = 64,
scaling_factor: int = 2,
) -> None:
super().__init__()
self.conv = nn.Conv2d(
in_channels=n_channels,
out_channels=n_channels * (scaling_factor**2),
kernel_size=kernel_size,
padding=kernel_size // 2,
)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor=scaling_factor)
self.prelu = nn.PReLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.prelu(x)
return x
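A quick shape check showing the sub-pixel upsampling behavior (illustrative values).
```python
import torch
from opensr_srgan.model.model_blocks import SubPixelConvolutionalBlock

up = SubPixelConvolutionalBlock(kernel_size=3, n_channels=64, scaling_factor=2)
x = torch.rand(1, 64, 32, 32)
print(up(x).shape)  # torch.Size([1, 64, 64, 64]) -- channels preserved, H and W doubled
```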
ResidualBlock
¶
Bases: Module
BN-enabled residual block used in the original SRResNet.
Source code in opensr_srgan/model/model_blocks/__init__.py
class ResidualBlock(nn.Module):
"""BN-enabled residual block used in the original SRResNet."""
def __init__(self, kernel_size: int = 3, n_channels: int = 64) -> None:
super().__init__()
self.conv_block1 = ConvolutionalBlock(
in_channels=n_channels,
out_channels=n_channels,
kernel_size=kernel_size,
batch_norm=True,
activation="PReLu",
)
self.conv_block2 = ConvolutionalBlock(
in_channels=n_channels,
out_channels=n_channels,
kernel_size=kernel_size,
batch_norm=True,
activation=None,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.conv_block1(x)
x = self.conv_block2(x)
return x + residual
ResidualBlockNoBN
¶
Bases: Module
Residual block variant without batch norm and with residual scaling.
Source code in opensr_srgan/model/model_blocks/__init__.py
class ResidualBlockNoBN(nn.Module):
"""Residual block variant without batch norm and with residual scaling."""
def __init__(
self,
n_channels: int = 64,
kernel_size: int = 3,
res_scale: float = 0.2,
) -> None:
super().__init__()
padding = kernel_size // 2
self.body = nn.Sequential(
nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding),
nn.PReLU(),
nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding),
)
self.res_scale = res_scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.res_scale * self.body(x)
RCAB
¶
Bases: Module
Residual Channel Attention Block (no BN).
Source code in opensr_srgan/model/model_blocks/__init__.py
class RCAB(nn.Module):
"""Residual Channel Attention Block (no BN)."""
def __init__(
self,
n_channels: int = 64,
kernel_size: int = 3,
reduction: int = 16,
res_scale: float = 0.2,
) -> None:
super().__init__()
padding = kernel_size // 2
self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
self.act = nn.PReLU()
self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
self.se = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(n_channels, n_channels // reduction, 1),
nn.ReLU(inplace=True),
nn.Conv2d(n_channels // reduction, n_channels, 1),
nn.Sigmoid(),
)
self.res_scale = res_scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = self.conv1(x)
y = self.act(y)
y = self.conv2(y)
w = self.se(y)
return x + self.res_scale * (y * w)
DenseBlock5
¶
Bases: Module
ESRGAN-style dense block with five convolutions.
Source code in opensr_srgan/model/model_blocks/__init__.py
class DenseBlock5(nn.Module):
"""ESRGAN-style dense block with five convolutions."""
def __init__(self, n_features: int = 64, growth_channels: int = 32, kernel_size: int = 3) -> None:
super().__init__()
padding = kernel_size // 2
self.act = nn.LeakyReLU(0.2, inplace=True)
self.c1 = nn.Conv2d(n_features, growth_channels, kernel_size, padding=padding)
self.c2 = nn.Conv2d(n_features + growth_channels, growth_channels, kernel_size, padding=padding)
self.c3 = nn.Conv2d(n_features + 2 * growth_channels, growth_channels, kernel_size, padding=padding)
self.c4 = nn.Conv2d(n_features + 3 * growth_channels, growth_channels, kernel_size, padding=padding)
self.c5 = nn.Conv2d(n_features + 4 * growth_channels, n_features, kernel_size, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x1 = self.act(self.c1(x))
x2 = self.act(self.c2(torch.cat([x, x1], dim=1)))
x3 = self.act(self.c3(torch.cat([x, x1, x2], dim=1)))
x4 = self.act(self.c4(torch.cat([x, x1, x2, x3], dim=1)))
x5 = self.c5(torch.cat([x, x1, x2, x3, x4], dim=1))
return x5
RRDB
¶
Bases: Module
Residual-in-Residual Dense Block.
Source code in opensr_srgan/model/model_blocks/__init__.py
class RRDB(nn.Module):
"""Residual-in-Residual Dense Block."""
def __init__(self, n_features: int = 64, growth_channels: int = 32, res_scale: float = 0.2) -> None:
super().__init__()
self.db1 = DenseBlock5(n_features, growth_channels)
self.db2 = DenseBlock5(n_features, growth_channels)
self.db3 = DenseBlock5(n_features, growth_channels)
self.res_scale = res_scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = self.db1(x)
y = self.db2(y)
y = self.db3(y)
return x + self.res_scale * y
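A quick sanity check showing that the block is shape-preserving (illustrative values).
```python
import torch
from opensr_srgan.model.model_blocks import RRDB

block = RRDB(n_features=64, growth_channels=32, res_scale=0.2)
x = torch.rand(1, 64, 48, 48)
print(block(x).shape)  # torch.Size([1, 64, 48, 48]) -- residual output keeps the input shape
```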
LKA
¶
Bases: Module
Lightweight Large-Kernel Attention module.
Source code in opensr_srgan/model/model_blocks/__init__.py
class LKA(nn.Module):
"""Lightweight Large-Kernel Attention module."""
def __init__(self, n_channels: int = 64) -> None:
super().__init__()
self.dw5 = nn.Conv2d(n_channels, n_channels, 5, padding=2, groups=n_channels)
self.dw7d = nn.Conv2d(n_channels, n_channels, 7, padding=9, dilation=3, groups=n_channels)
self.pw = nn.Conv2d(n_channels, n_channels, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
attn = self.dw5(x)
attn = self.dw7d(attn)
attn = self.pw(attn)
return x * torch.sigmoid(attn)
LKAResBlock
¶
Bases: Module
Residual block incorporating Large-Kernel Attention.
Source code in opensr_srgan/model/model_blocks/__init__.py
class LKAResBlock(nn.Module):
"""Residual block incorporating Large-Kernel Attention."""
def __init__(self, n_channels: int = 64, kernel_size: int = 3, res_scale: float = 0.2) -> None:
super().__init__()
padding = kernel_size // 2
self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
self.act = nn.PReLU()
self.lka = LKA(n_channels)
self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
self.res_scale = res_scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = self.conv1(x)
y = self.act(y)
y = self.lka(y)
y = self.conv2(y)
return x + self.res_scale * y
ExponentialMovingAverage
¶
Maintain an exponential moving average (EMA) of a model’s parameters and buffers.
This class provides a self-contained implementation of parameter smoothing via EMA, commonly used to stabilize training and improve generalization in deep generative models. It tracks both model parameters and registered buffers (e.g., batch norm statistics), maintains a decayed running average, and allows temporary swapping of model weights for evaluation or checkpointing.
EMA is updated with each training step:

    shadow = decay * shadow + (1 - decay) * parameter

where `decay` is typically close to 1.0 (e.g., 0.999–0.9999).
The class includes:
- On-the-fly registration of parameters/buffers from an existing model.
- Safe apply/restore methods to temporarily replace model weights.
- Device management for multi-GPU and CPU environments.
- Full checkpoint serialization support.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The model whose parameters are to be tracked. | required |
| `decay` | `float` | Smoothing coefficient (0 ≤ decay ≤ 1). Higher values make EMA updates slower. | `0.999` |
| `use_num_updates` | `bool` | Whether to adapt decay during early updates (useful for warm-up). | `True` |
| `device` | `str \| device \| None` | Optional target device for storing EMA parameters (e.g., "cpu" for offloading). | `None` |

Attributes:

| Name | Type | Description |
|---|---|---|
| `decay` | `float` | EMA smoothing coefficient. |
| `num_updates` | `int \| None` | Counter of EMA updates, used to adapt decay. |
| `device` | `device \| None` | Device where EMA tensors are stored. |
| `shadow_params` | `Dict[str, Tensor]` | Smoothed parameter tensors. |
| `shadow_buffers` | `Dict[str, Tensor]` | Smoothed buffer tensors. |
| `collected_params` | `Dict[str, Tensor]` | Temporary cache for original parameters during apply/restore operations. |
| `collected_buffers` | `Dict[str, Tensor]` | Temporary cache for original buffers during apply/restore operations. |
Source code in opensr_srgan/model/model_blocks/EMA.py
class ExponentialMovingAverage:
"""Maintain an exponential moving average (EMA) of a model’s parameters and buffers.
This class provides a self-contained implementation of parameter smoothing
via EMA, commonly used to stabilize training and improve generalization in
deep generative models. It tracks both model parameters and registered
buffers (e.g., batch norm statistics), maintains a decayed running average,
and allows temporary swapping of model weights for evaluation or checkpointing.
EMA is updated with each training step:
```
shadow = decay * shadow + (1 - decay) * parameter
```
where ``decay`` is typically close to 1.0 (e.g., 0.999–0.9999).
The class includes:
- On-the-fly registration of parameters/buffers from an existing model.
- Safe apply/restore methods to temporarily replace model weights.
- Device management for multi-GPU and CPU environments.
- Full checkpoint serialization support.
Args:
model (nn.Module): The model whose parameters are to be tracked.
decay (float, optional): Smoothing coefficient (0 ≤ decay ≤ 1).
Higher values make EMA updates slower. Default is 0.999.
use_num_updates (bool, optional): Whether to adapt decay during early
updates (useful for warm-up). Default is True.
device (str | torch.device | None, optional): Optional target device for
storing EMA parameters (e.g., "cpu" for offloading). Default is None.
Attributes:
decay (float): EMA smoothing coefficient.
num_updates (int | None): Counter of EMA updates, used to adapt decay.
device (torch.device | None): Device where EMA tensors are stored.
shadow_params (Dict[str, torch.Tensor]): Smoothed parameter tensors.
shadow_buffers (Dict[str, torch.Tensor]): Smoothed buffer tensors.
collected_params (Dict[str, torch.Tensor]): Temporary cache for original
parameters during apply/restore operations.
collected_buffers (Dict[str, torch.Tensor]): Temporary cache for original
buffers during apply/restore operations.
"""
def __init__(
self,
model: nn.Module,
decay: float = 0.999,
*,
use_num_updates: bool = True,
device: str | torch.device | None = None,
) -> None:
if not 0.0 <= decay <= 1.0:
raise ValueError("decay must be between 0 and 1 (inclusive)")
self.decay = float(decay)
self.num_updates = 0 if use_num_updates else None
self.device = torch.device(device) if device is not None else None
self.shadow_params: Dict[str, torch.Tensor] = {}
self.shadow_buffers: Dict[str, torch.Tensor] = {}
self.collected_params: Dict[str, torch.Tensor] = {}
self.collected_buffers: Dict[str, torch.Tensor] = {}
self._register(model)
def _register(self, model: nn.Module) -> None:
for name, param in model.named_parameters():
if not param.requires_grad:
continue
shadow = param.detach().clone()
if self.device is not None:
shadow = shadow.to(self.device)
self.shadow_params[name] = shadow
for name, buffer in model.named_buffers():
shadow = buffer.detach().clone()
if self.device is not None:
shadow = shadow.to(self.device)
self.shadow_buffers[name] = shadow
def update(self, model: nn.Module) -> None:
"""Update the EMA weights using the latest parameters from ``model``.
Performs an in-place exponential moving average update on all
trainable parameters and buffers tracked in ``shadow_params`` and
``shadow_buffers``. If ``use_num_updates=True``, adapts the decay
coefficient during early steps for smoother warm-up.
Args:
model (nn.Module): Model whose parameters and buffers are used to
update the EMA state.
Notes:
- Dynamically adds new parameters or buffers if they were not
present during initialization.
- Operates in ``torch.no_grad()`` context to avoid gradient tracking.
"""
if self.num_updates is not None:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
else:
decay = self.decay
one_minus_decay = 1.0 - decay
with torch.no_grad():
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if name not in self.shadow_params:
# Parameter may have been added dynamically (rare, but safeguard)
shadow = param.detach().clone()
if self.device is not None:
shadow = shadow.to(self.device)
self.shadow_params[name] = shadow
shadow_param = self.shadow_params[name]
param_data = param.detach()
if param_data.device != shadow_param.device:
param_data = param_data.to(shadow_param.device)
shadow_param.lerp_(param_data, one_minus_decay)
for name, buffer in model.named_buffers():
if name not in self.shadow_buffers:
shadow = buffer.detach().clone()
if self.device is not None:
shadow = shadow.to(self.device)
self.shadow_buffers[name] = shadow
shadow_buffer = self.shadow_buffers[name]
buffer_data = buffer.detach()
if buffer_data.device != shadow_buffer.device:
buffer_data = buffer_data.to(shadow_buffer.device)
shadow_buffer.copy_(buffer_data)
def apply_to(self, model: nn.Module) -> None:
"""Replace model parameters with EMA-smoothed versions (in-place).
Temporarily swaps the current model parameters and buffers with their
EMA counterparts for evaluation or checkpoint export. The original
tensors are cached internally and can be restored later with
:meth:`restore`.
Args:
model (nn.Module): Model whose parameters will be replaced.
Raises:
RuntimeError: If EMA weights are already applied and not yet restored.
"""
if self.collected_params or self.collected_buffers:
raise RuntimeError("EMA weights already applied; call restore() before reapplying.")
for name, param in model.named_parameters():
if not param.requires_grad or name not in self.shadow_params:
continue
self.collected_params[name] = param.detach().clone()
param.data.copy_(self.shadow_params[name].to(param.device))
for name, buffer in model.named_buffers():
if name not in self.shadow_buffers:
continue
self.collected_buffers[name] = buffer.detach().clone()
buffer.data.copy_(self.shadow_buffers[name].to(buffer.device))
def restore(self, model: nn.Module) -> None:
"""Restore the model’s original parameters after an EMA swap.
Reverts the parameter and buffer changes made by :meth:`apply_to`
by restoring the cached tensors. This is a no-op if EMA weights
were never applied.
Args:
model (nn.Module): Model whose parameters will be restored.
"""
for name, param in model.named_parameters():
cached = self.collected_params.pop(name, None)
if cached is None:
continue
param.data.copy_(cached.to(param.device))
for name, buffer in model.named_buffers():
cached = self.collected_buffers.pop(name, None)
if cached is None:
continue
buffer.data.copy_(cached.to(buffer.device))
@contextmanager
def average_parameters(self, model: nn.Module) -> Iterator[None]:
"""Context manager to temporarily apply EMA weights to ``model``.
This convenience wrapper allows for automatic restoration after use.
Example:
```python
with ema.average_parameters(model):
validate(model)
```
Args:
model (nn.Module): The model to temporarily replace parameters for.
Yields:
None: Executes the body of the context with EMA weights applied.
"""
self.apply_to(model)
try:
yield
finally:
self.restore(model)
def to(self, device: str | torch.device) -> None:
"""Move EMA-tracked tensors to a target device.
Transfers all shadow parameters and buffers to the specified device,
updating the internal ``self.device`` reference.
Args:
device (str | torch.device): Target device (e.g., "cuda", "cpu").
"""
target_device = torch.device(device)
for name, tensor in list(self.shadow_params.items()):
self.shadow_params[name] = tensor.to(target_device)
for name, tensor in list(self.shadow_buffers.items()):
self.shadow_buffers[name] = tensor.to(target_device)
self.device = target_device
def state_dict(self) -> Dict[str, object]:
"""Return a serializable state dictionary for checkpointing.
Packages all relevant EMA state into a plain dictionary, compatible
with PyTorch’s standard checkpoint format. Converts all tensors to CPU
for safe serialization.
Returns:
Dict[str, object]: Dictionary containing EMA decay, update count,
device info, and copies of shadow parameters/buffers.
"""
return {
"decay": self.decay,
"num_updates": self.num_updates,
"device": str(self.device) if self.device is not None else None,
"shadow_params": {k: v.detach().cpu() for k, v in self.shadow_params.items()},
"shadow_buffers": {k: v.detach().cpu() for k, v in self.shadow_buffers.items()},
}
def load_state_dict(self, state_dict: Dict[str, object]) -> None:
"""Load EMA state from a previously saved checkpoint.
Reconstructs the EMA tracking state from a saved dictionary, restoring
all tracked parameters, buffers, and metadata such as decay, device,
and update count.
Args:
state_dict (Dict[str, object]): Dictionary as produced by
:meth:`state_dict`.
Notes:
- Tensors are moved to the current or saved device automatically.
- Clears existing collected (applied) caches to avoid stale state.
"""
self.decay = float(state_dict["decay"])
self.num_updates = state_dict["num_updates"]
device_str = state_dict.get("device", None)
self.device = torch.device(device_str) if device_str is not None else None
self.shadow_params = {
name: tensor.clone().to(self.device) if self.device is not None else tensor.clone()
for name, tensor in state_dict.get("shadow_params", {}).items()
}
self.shadow_buffers = {
name: tensor.clone().to(self.device) if self.device is not None else tensor.clone()
for name, tensor in state_dict.get("shadow_buffers", {}).items()
}
self.collected_params = {}
self.collected_buffers = {}
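An end-to-end sketch with a toy module (illustrative only; not tied to the SRGAN generator). The import path follows the source location above.
```python
import torch
import torch.nn as nn
from opensr_srgan.model.model_blocks.EMA import ExponentialMovingAverage

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.PReLU())
ema = ExponentialMovingAverage(net, decay=0.999, use_num_updates=True)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

for _ in range(10):                              # a few toy training steps
    loss = net(torch.rand(2, 3, 16, 16)).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(net)                              # track smoothed weights after each optimizer step

with ema.average_parameters(net):                # evaluate with EMA weights, then auto-restore
    out = net(torch.rand(1, 3, 16, 16))
print(out.shape)                                 # torch.Size([1, 8, 16, 16])
```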
update(model)
¶
Update the EMA weights using the latest parameters from model.
Performs an in-place exponential moving average update on all
trainable parameters and buffers tracked in shadow_params and
shadow_buffers. If use_num_updates=True, adapts the decay
coefficient during early steps for smoother warm-up.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Model whose parameters and buffers are used to update the EMA state. | required |

Notes
- Dynamically adds new parameters or buffers if they were not present during initialization.
- Operates in a `torch.no_grad()` context to avoid gradient tracking.
Source code in opensr_srgan/model/model_blocks/EMA.py
def update(self, model: nn.Module) -> None:
"""Update the EMA weights using the latest parameters from ``model``.
Performs an in-place exponential moving average update on all
trainable parameters and buffers tracked in ``shadow_params`` and
``shadow_buffers``. If ``use_num_updates=True``, adapts the decay
coefficient during early steps for smoother warm-up.
Args:
model (nn.Module): Model whose parameters and buffers are used to
update the EMA state.
Notes:
- Dynamically adds new parameters or buffers if they were not
present during initialization.
- Operates in ``torch.no_grad()`` context to avoid gradient tracking.
"""
if self.num_updates is not None:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
else:
decay = self.decay
one_minus_decay = 1.0 - decay
with torch.no_grad():
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if name not in self.shadow_params:
# Parameter may have been added dynamically (rare, but safeguard)
shadow = param.detach().clone()
if self.device is not None:
shadow = shadow.to(self.device)
self.shadow_params[name] = shadow
shadow_param = self.shadow_params[name]
param_data = param.detach()
if param_data.device != shadow_param.device:
param_data = param_data.to(shadow_param.device)
shadow_param.lerp_(param_data, one_minus_decay)
for name, buffer in model.named_buffers():
if name not in self.shadow_buffers:
shadow = buffer.detach().clone()
if self.device is not None:
shadow = shadow.to(self.device)
self.shadow_buffers[name] = shadow
shadow_buffer = self.shadow_buffers[name]
buffer_data = buffer.detach()
if buffer_data.device != shadow_buffer.device:
buffer_data = buffer_data.to(shadow_buffer.device)
shadow_buffer.copy_(buffer_data)
apply_to(model)
¶
Replace model parameters with EMA-smoothed versions (in-place).
Temporarily swaps the current model parameters and buffers with their
EMA counterparts for evaluation or checkpoint export. The original
tensors are cached internally and can be restored later with `restore()`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Model whose parameters will be replaced. | required |

Raises:

| Type | Description |
|---|---|
| `RuntimeError` | If EMA weights are already applied and not yet restored. |
Source code in opensr_srgan/model/model_blocks/EMA.py
def apply_to(self, model: nn.Module) -> None:
"""Replace model parameters with EMA-smoothed versions (in-place).
Temporarily swaps the current model parameters and buffers with their
EMA counterparts for evaluation or checkpoint export. The original
tensors are cached internally and can be restored later with
:meth:`restore`.
Args:
model (nn.Module): Model whose parameters will be replaced.
Raises:
RuntimeError: If EMA weights are already applied and not yet restored.
"""
if self.collected_params or self.collected_buffers:
raise RuntimeError("EMA weights already applied; call restore() before reapplying.")
for name, param in model.named_parameters():
if not param.requires_grad or name not in self.shadow_params:
continue
self.collected_params[name] = param.detach().clone()
param.data.copy_(self.shadow_params[name].to(param.device))
for name, buffer in model.named_buffers():
if name not in self.shadow_buffers:
continue
self.collected_buffers[name] = buffer.detach().clone()
buffer.data.copy_(self.shadow_buffers[name].to(buffer.device))
restore(model)
¶
Restore the model’s original parameters after an EMA swap.
Reverts the parameter and buffer changes made by `apply_to()`
by restoring the cached tensors. This is a no-op if EMA weights
were never applied.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Model whose parameters will be restored. | required |
Source code in opensr_srgan/model/model_blocks/EMA.py
def restore(self, model: nn.Module) -> None:
"""Restore the model’s original parameters after an EMA swap.
Reverts the parameter and buffer changes made by :meth:`apply_to`
by restoring the cached tensors. This is a no-op if EMA weights
were never applied.
Args:
model (nn.Module): Model whose parameters will be restored.
"""
for name, param in model.named_parameters():
cached = self.collected_params.pop(name, None)
if cached is None:
continue
param.data.copy_(cached.to(param.device))
for name, buffer in model.named_buffers():
cached = self.collected_buffers.pop(name, None)
if cached is None:
continue
buffer.data.copy_(cached.to(buffer.device))
average_parameters(model)
¶
Context manager to temporarily apply EMA weights to model.
This convenience wrapper allows for automatic restoration after use. Example:

    with ema.average_parameters(model):
        validate(model)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The model to temporarily replace parameters for. | required |

Yields:

| Type | Description |
|---|---|
| `None` | Executes the body of the context with EMA weights applied. |
Source code in opensr_srgan/model/model_blocks/EMA.py
@contextmanager
def average_parameters(self, model: nn.Module) -> Iterator[None]:
"""Context manager to temporarily apply EMA weights to ``model``.
This convenience wrapper allows for automatic restoration after use.
Example:
```python
with ema.average_parameters(model):
validate(model)
```
Args:
model (nn.Module): The model to temporarily replace parameters for.
Yields:
None: Executes the body of the context with EMA weights applied.
"""
self.apply_to(model)
try:
yield
finally:
self.restore(model)
to(device)
¶
Move EMA-tracked tensors to a target device.
Transfers all shadow parameters and buffers to the specified device,
updating the internal self.device reference.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `device` | `str \| device` | Target device (e.g., "cuda", "cpu"). | required |
Source code in opensr_srgan/model/model_blocks/EMA.py
def to(self, device: str | torch.device) -> None:
"""Move EMA-tracked tensors to a target device.
Transfers all shadow parameters and buffers to the specified device,
updating the internal ``self.device`` reference.
Args:
device (str | torch.device): Target device (e.g., "cuda", "cpu").
"""
target_device = torch.device(device)
for name, tensor in list(self.shadow_params.items()):
self.shadow_params[name] = tensor.to(target_device)
for name, tensor in list(self.shadow_buffers.items()):
self.shadow_buffers[name] = tensor.to(target_device)
self.device = target_device
state_dict()
¶
Return a serializable state dictionary for checkpointing.
Packages all relevant EMA state into a plain dictionary, compatible with PyTorch’s standard checkpoint format. Converts all tensors to CPU for safe serialization.
Returns:

| Type | Description |
|---|---|
| `Dict[str, object]` | Dictionary containing EMA decay, update count, device info, and copies of shadow parameters/buffers. |
Source code in opensr_srgan/model/model_blocks/EMA.py
def state_dict(self) -> Dict[str, object]:
"""Return a serializable state dictionary for checkpointing.
Packages all relevant EMA state into a plain dictionary, compatible
with PyTorch’s standard checkpoint format. Converts all tensors to CPU
for safe serialization.
Returns:
Dict[str, object]: Dictionary containing EMA decay, update count,
device info, and copies of shadow parameters/buffers.
"""
return {
"decay": self.decay,
"num_updates": self.num_updates,
"device": str(self.device) if self.device is not None else None,
"shadow_params": {k: v.detach().cpu() for k, v in self.shadow_params.items()},
"shadow_buffers": {k: v.detach().cpu() for k, v in self.shadow_buffers.items()},
}
load_state_dict(state_dict)
¶
Load EMA state from a previously saved checkpoint.
Reconstructs the EMA tracking state from a saved dictionary, restoring all tracked parameters, buffers, and metadata such as decay, device, and update count.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state_dict | Dict[str, object] | Dictionary as produced by state_dict(). | required |
Notes
- Tensors are moved to the current or saved device automatically.
- Clears existing collected (applied) caches to avoid stale state.
Source code in opensr_srgan/model/model_blocks/EMA.py
def load_state_dict(self, state_dict: Dict[str, object]) -> None:
"""Load EMA state from a previously saved checkpoint.
Reconstructs the EMA tracking state from a saved dictionary, restoring
all tracked parameters, buffers, and metadata such as decay, device,
and update count.
Args:
state_dict (Dict[str, object]): Dictionary as produced by
:meth:`state_dict`.
Notes:
- Tensors are moved to the current or saved device automatically.
- Clears existing collected (applied) caches to avoid stale state.
"""
self.decay = float(state_dict["decay"])
self.num_updates = state_dict["num_updates"]
device_str = state_dict.get("device", None)
self.device = torch.device(device_str) if device_str is not None else None
self.shadow_params = {
name: tensor.clone().to(self.device) if self.device is not None else tensor.clone()
for name, tensor in state_dict.get("shadow_params", {}).items()
}
self.shadow_buffers = {
name: tensor.clone().to(self.device) if self.device is not None else tensor.clone()
for name, tensor in state_dict.get("shadow_buffers", {}).items()
}
self.collected_params = {}
self.collected_buffers = {}
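Taken together, state_dict() and load_state_dict() let the EMA shadow weights travel with ordinary PyTorch checkpoints. A minimal sketch, assuming ema is an ExponentialMovingAverage instance already tracking the generator; generator and lr_imgs are placeholders, not part of this API:

```python
import torch

# Persist the EMA state; state_dict() already moves all shadow tensors to CPU.
torch.save({"ema": ema.state_dict()}, "ema_state.pt")

# Later: rebuild the EMA object the same way it was created, then restore its state.
payload = torch.load("ema_state.pt", map_location="cpu")
ema.load_state_dict(payload["ema"])

# Evaluate with the averaged weights temporarily swapped into the generator.
with ema.average_parameters(generator):
    sr = generator(lr_imgs)  # `generator` and `lr_imgs` are placeholders
```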
make_upsampler(n_channels, scale)
¶
Create a pixel-shuffle upsampler matching the flexible generator implementation.
Source code in opensr_srgan/model/model_blocks/__init__.py
def make_upsampler(n_channels: int, scale: int) -> nn.Sequential:
"""Create a pixel-shuffle upsampler matching the flexible generator implementation."""
stages: list[nn.Module] = []
for _ in range(int(math.log2(scale))):
stages.extend(
[
nn.Conv2d(n_channels, n_channels * 4, 3, padding=1),
nn.PixelShuffle(2),
nn.PReLU(),
]
)
return nn.Sequential(*stages)
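For intuition, a scale of 4 yields two Conv → PixelShuffle(2) → PReLU stages, so the spatial size grows by 4× while the channel count is preserved. A small sketch (the import path follows the source listing above):

```python
import torch
from opensr_srgan.model.model_blocks import make_upsampler

up = make_upsampler(n_channels=64, scale=4)   # two ×2 pixel-shuffle stages
x = torch.randn(1, 64, 32, 32)
y = up(x)
print(y.shape)  # expected: torch.Size([1, 64, 128, 128])
```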
Generator architectures¶
Generator architectures for SRGAN.
SRResNet
¶
Bases: Module
Canonical SRResNet generator for single-image super-resolution.
Implements the SRResNet backbone with:

1) Large-kernel stem conv (+PReLU)
2) N × residual blocks (small-kernel, no upsampling)
3) Conv for global residual fusion
4) Log2(scaling_factor) × SubPixelConvolutionalBlock (×2 each)
5) Large-kernel output conv (+Tanh)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Input channels (e.g., 3 for RGB). | 3 |
| large_kernel_size | int | Kernel size for head/tail convolutions. | 9 |
| small_kernel_size | int | Kernel size used inside residual/upsampling blocks. | 3 |
| n_channels | int | Feature width across the network. | 64 |
| n_blocks | int | Number of residual blocks in the trunk. | 16 |
| scaling_factor | int | Upscale factor (must be one of {2, 4, 8}). | 4 |

Returns:

| Type | Description |
|---|---|
| torch.Tensor | Super-resolved image of shape (B, in_channels, H×scale, W×scale). |
Notes
- The network uses a global skip connection around the residual stack.
- Upsampling is performed by PixelShuffle via sub-pixel convolution blocks.
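A brief usage sketch (the import path is inferred from the source listing below and may differ in your install):

```python
import torch
from opensr_srgan.model.generators.srresnet import SRResNet

model = SRResNet(in_channels=3, n_blocks=16, scaling_factor=4)
lr = torch.randn(1, 3, 64, 64)
sr = model(lr)
print(sr.shape)  # expected: torch.Size([1, 3, 256, 256])
```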
Source code in opensr_srgan/model/generators/srresnet.py
class SRResNet(nn.Module):
"""Canonical SRResNet generator for single-image super-resolution.
Implements the SRResNet backbone with:
1) Large-kernel stem conv (+PReLU)
2) N × residual blocks (small-kernel, no upsampling)
3) Conv for global residual fusion
4) Log2(scaling_factor) × SubPixelConvolutionalBlock (×2 each)
5) Large-kernel output conv (+Tanh)
Args:
in_channels (int): Input channels (e.g., 3 for RGB).
large_kernel_size (int): Kernel size for head/tail convolutions.
small_kernel_size (int): Kernel size used inside residual/upsampling blocks.
n_channels (int): Feature width across the network.
n_blocks (int): Number of residual blocks in the trunk.
scaling_factor (int): Upscale factor (must be one of {2, 4, 8}).
Returns:
torch.Tensor: Super-resolved image of shape (B, in_channels, H*scale, W*scale).
Notes:
- The network uses a global skip connection around the residual stack.
- Upsampling is performed by PixelShuffle via sub-pixel convolution blocks.
"""
def __init__(
self,
in_channels: int = 3,
large_kernel_size: int = 9,
small_kernel_size: int = 3,
n_channels: int = 64,
n_blocks: int = 16,
scaling_factor: int = 4,
) -> None:
"""Build the canonical SRResNet generator network."""
super().__init__()
scaling_factor = int(scaling_factor)
if scaling_factor not in {2, 4, 8}:
raise AssertionError("The scaling factor must be 2, 4, or 8!")
self.conv_block1 = ConvolutionalBlock(
in_channels=in_channels,
out_channels=n_channels,
kernel_size=large_kernel_size,
batch_norm=False,
activation="PReLu",
)
self.residual_blocks = nn.Sequential(
*[
ResidualBlock(
kernel_size=small_kernel_size,
n_channels=n_channels,
)
for _ in range(n_blocks)
]
)
self.conv_block2 = ConvolutionalBlock(
in_channels=n_channels,
out_channels=n_channels,
kernel_size=small_kernel_size,
batch_norm=True,
activation=None,
)
n_subpixel_convolution_blocks = int(math.log2(scaling_factor))
self.subpixel_convolutional_blocks = nn.Sequential(
*[
SubPixelConvolutionalBlock(
kernel_size=small_kernel_size,
n_channels=n_channels,
scaling_factor=2,
)
for _ in range(n_subpixel_convolution_blocks)
]
)
self.conv_block3 = ConvolutionalBlock(
in_channels=n_channels,
out_channels=in_channels,
kernel_size=large_kernel_size,
batch_norm=False,
activation="Tanh",
)
def forward(self, lr_imgs: torch.Tensor) -> torch.Tensor:
"""Forward propagation through SRResNet."""
output = self.conv_block1(lr_imgs)
residual = output
output = self.residual_blocks(output)
output = self.conv_block2(output)
output = output + residual
output = self.subpixel_convolutional_blocks(output)
sr_imgs = self.conv_block3(output)
return sr_imgs
forward(lr_imgs)
¶
Forward propagation through SRResNet.
Source code in opensr_srgan/model/generators/srresnet.py
def forward(self, lr_imgs: torch.Tensor) -> torch.Tensor:
"""Forward propagation through SRResNet."""
output = self.conv_block1(lr_imgs)
residual = output
output = self.residual_blocks(output)
output = self.conv_block2(output)
output = output + residual
output = self.subpixel_convolutional_blocks(output)
sr_imgs = self.conv_block3(output)
return sr_imgs
Generator
¶
Bases: Module
SRGAN generator wrapper around SRResNet.
Provides a thin adapter that:
- Builds an internal SRResNet with the given hyperparameters.
- Optionally initializes weights from a pretrained SRResNet checkpoint.
- Exposes a unified forward for SRGAN pipelines.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Input channels (e.g., 3 for RGB). | 3 |
| large_kernel_size | int | Kernel size for head/tail convolutions. | 9 |
| small_kernel_size | int | Kernel size used inside residual/upsampling blocks. | 3 |
| n_channels | int | Feature width across the network. | 64 |
| n_blocks | int | Number of residual blocks in the trunk. | 16 |
| scaling_factor | int | Upscale factor (must be one of {2, 4, 8}). | 4 |

Returns:

| Type | Description |
|---|---|
| torch.Tensor | Super-resolved image produced by the wrapped SRResNet. |
Source code in opensr_srgan/model/generators/srresnet.py
class Generator(nn.Module):
"""SRGAN generator wrapper around :class:`SRResNet`.
Provides a thin adapter that:
- Builds an internal :class:`SRResNet` with the given hyperparameters.
- Optionally initializes weights from a pretrained SRResNet checkpoint.
- Exposes a unified forward for SRGAN pipelines.
Args:
in_channels (int): Input channels (e.g., 3 for RGB).
large_kernel_size (int): Kernel size for head/tail convolutions.
small_kernel_size (int): Kernel size used inside residual/upsampling blocks.
n_channels (int): Feature width across the network.
n_blocks (int): Number of residual blocks in the trunk.
scaling_factor (int): Upscale factor (must be one of {2, 4, 8}).
Returns:
torch.Tensor: Super-resolved image produced by the wrapped SRResNet.
"""
def __init__(
self,
in_channels: int = 3,
large_kernel_size: int = 9,
small_kernel_size: int = 3,
n_channels: int = 64,
n_blocks: int = 16,
scaling_factor: int = 4,
) -> None:
super().__init__()
self.net = SRResNet(
in_channels=in_channels,
large_kernel_size=large_kernel_size,
small_kernel_size=small_kernel_size,
n_channels=n_channels,
n_blocks=n_blocks,
scaling_factor=scaling_factor,
)
def initialize_with_srresnet(self, srresnet_checkpoint: str) -> None:
"""Initialize the generator weights from a pretrained SRResNet checkpoint."""
srresnet = torch.load(srresnet_checkpoint)["model"]
self.net.load_state_dict(srresnet.state_dict())
print("\nLoaded weights from pre-trained SRResNet.\n")
def forward(self, lr_imgs: torch.Tensor) -> torch.Tensor:
"""Forward propagation via the wrapped SRResNet."""
return self.net(lr_imgs)
initialize_with_srresnet(srresnet_checkpoint)
¶
Initialize the generator weights from a pretrained SRResNet checkpoint.
Source code in opensr_srgan/model/generators/srresnet.py
def initialize_with_srresnet(self, srresnet_checkpoint: str) -> None:
"""Initialize the generator weights from a pretrained SRResNet checkpoint."""
srresnet = torch.load(srresnet_checkpoint)["model"]
self.net.load_state_dict(srresnet.state_dict())
print("\nLoaded weights from pre-trained SRResNet.\n")
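A hedged example of the pretraining hand-off (the checkpoint path is a placeholder; as the loader above assumes, the checkpoint is expected to store the full SRResNet module under the "model" key):

```python
import torch
from opensr_srgan.model.generators.srresnet import Generator

g = Generator(in_channels=3, scaling_factor=4)
# Hypothetical checkpoint produced by an SRResNet-only pretraining run.
g.initialize_with_srresnet("checkpoints/srresnet_pretrained.pt")
sr = g(torch.randn(1, 3, 64, 64))
```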
forward(lr_imgs)
¶
Forward propagation via the wrapped SRResNet.
StochasticGenerator
¶
Bases: Module
Stochastic generator with latent noise modulation for super-resolution.
Extends a standard SR generator by injecting stochastic latent noise through
NoiseResBlock modules, enabling diverse texture generation conditioned on the same
low-resolution (LR) input. When no latent vector is provided, one is sampled
internally from a standard normal distribution, allowing both deterministic
and stochastic inference modes.
The architecture follows a residual backbone with:
- A wide receptive-field head convolution.
- Multiple noise-modulated residual blocks.
- A tail convolution with learnable upsampling.
- Configurable scaling factor (×2, ×4, or ×8).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels (e.g., RGB+NIR = 4 or 6). | 6 |
| n_channels | int | Base number of feature channels in the generator. | 96 |
| n_blocks | int | Number of noise-modulated residual blocks. | 16 |
| small_kernel | int | Kernel size for body convolutions. | 3 |
| large_kernel | int | Kernel size for head/tail convolutions. | 9 |
| scale | int | Upscaling factor (must be one of {2, 4, 8}). | 4 |
| noise_dim | int | Dimensionality of the latent vector z. | 128 |
| res_scale | float | Residual scaling factor for block stability. | 0.2 |

Attributes:

| Name | Type | Description |
|---|---|---|
| noise_dim | int | Dimensionality of latent vector z. |
| head | Sequential | Initial convolutional stem. |
| body | ModuleList | Sequence of NoiseResBlock modules. |
| upsampler | Module | PixelShuffle-based upsampling module. |
| tail | Conv2d | Final convolution projecting to output space. |

Example:
>>> g = StochasticGenerator(in_channels=3, scale=4)
>>> lr = torch.randn(1, 3, 64, 64)
>>> sr, noise = g(lr, return_noise=True)
Source code in opensr_srgan/model/generators/cgan_generator.py
class StochasticGenerator(nn.Module):
"""Stochastic generator with latent noise modulation for super-resolution.
Extends a standard SR generator by injecting stochastic latent noise through
`NoiseResBlock`s, enabling diverse texture generation conditioned on the same
low-resolution (LR) input. When no latent vector is provided, one is sampled
internally from a standard normal distribution, allowing both deterministic
and stochastic inference modes.
The architecture follows a residual backbone with:
- A wide receptive-field head convolution.
- Multiple noise-modulated residual blocks.
- A tail convolution with learnable upsampling.
- Configurable scaling factor (×2, ×4, or ×8).
Args:
in_channels (int): Number of input channels (e.g., RGB+NIR = 4 or 6).
n_channels (int): Base number of feature channels in the generator.
n_blocks (int): Number of noise-modulated residual blocks.
small_kernel (int): Kernel size for body convolutions.
large_kernel (int): Kernel size for head/tail convolutions.
scale (int): Upscaling factor (must be one of {2, 4, 8}).
noise_dim (int): Dimensionality of the latent vector z.
res_scale (float): Residual scaling factor for block stability.
Attributes:
noise_dim (int): Dimensionality of latent vector z.
head (nn.Sequential): Initial convolutional stem.
body (nn.ModuleList): Sequence of `NoiseResBlock`s.
upsampler (nn.Module): PixelShuffle-based upsampling module.
tail (nn.Conv2d): Final convolution projecting to output space.
Example:
>>> g = StochasticGenerator(in_channels=3, scale=4)
>>> lr = torch.randn(1, 3, 64, 64)
>>> sr, noise = g(lr, return_noise=True)
"""
def __init__(
self,
in_channels: int = 6,
n_channels: int = 96,
n_blocks: int = 16,
small_kernel: int = 3,
large_kernel: int = 9,
scale: int = 4,
noise_dim: int = 128,
res_scale: float = 0.2,
) -> None:
super().__init__()
if scale not in {2, 4, 8}:
raise ValueError("scale must be one of {2, 4, 8}")
self.noise_dim = noise_dim
self.scale = scale
padding_large = large_kernel // 2
self.head = nn.Sequential(
nn.Conv2d(in_channels, n_channels, large_kernel, padding=padding_large),
nn.PReLU(),
)
self.body = nn.ModuleList(
[
NoiseResBlock(n_channels, small_kernel, noise_dim, res_scale)
for _ in range(n_blocks)
]
)
self.body_tail = nn.Conv2d(
n_channels,
n_channels,
small_kernel,
padding=small_kernel // 2,
)
self.upsampler = make_upsampler(n_channels, scale)
self.tail = nn.Conv2d(
n_channels, in_channels, large_kernel, padding=padding_large
)
def sample_noise(
self,
batch_size: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""
Sample a latent noise tensor consistent with the generator configuration.
Parameters
----------
batch_size : int
Number of noise vectors to generate.
device : torch.device, optional
Device on which to allocate the tensor. Defaults to the model's current device.
dtype : torch.dtype, optional
Tensor dtype. Defaults to the model's parameter dtype.
Returns
-------
torch.Tensor
Random latent tensor sampled from a standard normal distribution,
shape (B, D) where B = `batch_size` and D = `self.noise_dim`.
Notes
-----
- Used for stochastic generation when no latent vector is provided to ``forward()``.
- Ensures type and device consistency with the current model parameters.
- The resulting noise can be reused to reproduce identical stochastic outputs.
"""
if device is None:
device = next(self.parameters()).device
if dtype is None:
dtype = next(self.parameters()).dtype
return torch.randn(batch_size, self.noise_dim, device=device, dtype=dtype)
def forward(
self,
lr: torch.Tensor,
noise: Optional[torch.Tensor] = None,
return_noise: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of the stochastic super-resolution generator.
Parameters
----------
lr : torch.Tensor
Low-resolution input tensor of shape (B, C, H, W).
noise : torch.Tensor, optional
Latent noise tensor of shape (B, D), where D = `noise_dim`.
If None, a random vector is sampled internally.
return_noise : bool, default=False
If True, returns both the super-resolved image and the
latent noise used for generation.
Returns
-------
torch.Tensor or (torch.Tensor, torch.Tensor)
- ``sr``: Super-resolved output image of shape (B, C, sH, sW),
where s is the upscaling factor.
- Optionally, ``noise``: The latent vector used (if `return_noise=True`).
Notes
-----
- The latent code is broadcast and applied to all residual blocks.
- Supports both deterministic (fixed noise) and stochastic (random noise)
inference modes.
- The upsampling is performed via sub-pixel convolution (PixelShuffle)
defined in `make_upsampler()`.
"""
if noise is None:
noise = torch.randn(
lr.size(0),
self.noise_dim,
device=lr.device,
dtype=lr.dtype,
)
features = self.head(lr)
residual = features
for block in self.body:
residual = block(residual, noise)
residual = self.body_tail(residual)
features = features + residual
features = self.upsampler(features)
sr = self.tail(features)
if return_noise:
return sr, noise
return sr
sample_noise(batch_size, device=None, dtype=None)
¶
Sample a latent noise tensor consistent with the generator configuration.
Parameters¶
batch_size : int
    Number of noise vectors to generate.
device : torch.device, optional
    Device on which to allocate the tensor. Defaults to the model's current device.
dtype : torch.dtype, optional
    Tensor dtype. Defaults to the model's parameter dtype.
Returns¶
torch.Tensor
    Random latent tensor sampled from a standard normal distribution,
    shape (B, D) where B = batch_size and D = self.noise_dim.
Notes¶
- Used for stochastic generation when no latent vector is provided to forward().
- Ensures type and device consistency with the current model parameters.
- The resulting noise can be reused to reproduce identical stochastic outputs.
Source code in opensr_srgan/model/generators/cgan_generator.py
def sample_noise(
self,
batch_size: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""
Sample a latent noise tensor consistent with the generator configuration.
Parameters
----------
batch_size : int
Number of noise vectors to generate.
device : torch.device, optional
Device on which to allocate the tensor. Defaults to the model's current device.
dtype : torch.dtype, optional
Tensor dtype. Defaults to the model's parameter dtype.
Returns
-------
torch.Tensor
Random latent tensor sampled from a standard normal distribution,
shape (B, D) where B = `batch_size` and D = `self.noise_dim`.
Notes
-----
- Used for stochastic generation when no latent vector is provided to ``forward()``.
- Ensures type and device consistency with the current model parameters.
- The resulting noise can be reused to reproduce identical stochastic outputs.
"""
if device is None:
device = next(self.parameters()).device
if dtype is None:
dtype = next(self.parameters()).dtype
return torch.randn(batch_size, self.noise_dim, device=device, dtype=dtype)
forward(lr, noise=None, return_noise=False)
¶
Forward pass of the stochastic super-resolution generator.
Parameters¶
lr : torch.Tensor
Low-resolution input tensor of shape (B, C, H, W).
noise : torch.Tensor, optional
Latent noise tensor of shape (B, D), where D = noise_dim.
If None, a random vector is sampled internally.
return_noise : bool, default=False
If True, returns both the super-resolved image and the
latent noise used for generation.
Returns¶
torch.Tensor or (torch.Tensor, torch.Tensor)
- sr: Super-resolved output image of shape (B, C, sH, sW),
where s is the upscaling factor.
- Optionally, noise: The latent vector used (if return_noise=True).
Notes¶
- The latent code is broadcast and applied to all residual blocks.
- Supports both deterministic (fixed noise) and stochastic (random noise) inference modes.
- The upsampling is performed via sub-pixel convolution (PixelShuffle) defined in make_upsampler().
Source code in opensr_srgan/model/generators/cgan_generator.py
def forward(
self,
lr: torch.Tensor,
noise: Optional[torch.Tensor] = None,
return_noise: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of the stochastic super-resolution generator.
Parameters
----------
lr : torch.Tensor
Low-resolution input tensor of shape (B, C, H, W).
noise : torch.Tensor, optional
Latent noise tensor of shape (B, D), where D = `noise_dim`.
If None, a random vector is sampled internally.
return_noise : bool, default=False
If True, returns both the super-resolved image and the
latent noise used for generation.
Returns
-------
torch.Tensor or (torch.Tensor, torch.Tensor)
- ``sr``: Super-resolved output image of shape (B, C, sH, sW),
where s is the upscaling factor.
- Optionally, ``noise``: The latent vector used (if `return_noise=True`).
Notes
-----
- The latent code is broadcast and applied to all residual blocks.
- Supports both deterministic (fixed noise) and stochastic (random noise)
inference modes.
- The upsampling is performed via sub-pixel convolution (PixelShuffle)
defined in `make_upsampler()`.
"""
if noise is None:
noise = torch.randn(
lr.size(0),
self.noise_dim,
device=lr.device,
dtype=lr.dtype,
)
features = self.head(lr)
residual = features
for block in self.body:
residual = block(residual, noise)
residual = self.body_tail(residual)
features = features + residual
features = self.upsampler(features)
sr = self.tail(features)
if return_noise:
return sr, noise
return sr
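Because the latent vector fully determines the stochastic component, reusing the same noise reproduces identical outputs, while fresh samples yield diverse textures. A small sketch (import path taken from the source listing above; sizes are illustrative):

```python
import torch
from opensr_srgan.model.generators.cgan_generator import StochasticGenerator

g = StochasticGenerator(in_channels=3, n_blocks=4, scale=4).eval()
lr = torch.randn(1, 3, 32, 32)

z = g.sample_noise(batch_size=1)               # fixed latent code
with torch.no_grad():
    sr_a = g(lr, noise=z)
    sr_b = g(lr, noise=z)                      # same noise -> identical output
    sr_c, z_new = g(lr, return_noise=True)     # fresh noise -> different texture

print(torch.allclose(sr_a, sr_b))  # True
```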
ESRGANGenerator
¶
Bases: Module
ESRGAN generator network for single-image super-resolution.
This implementation follows the design proposed in "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks" (Wang et al., 2018). It replaces traditional residual blocks with Residual-in-Residual Dense Blocks (RRDBs), omits batch normalization, and applies residual scaling for improved stability and perceptual quality.
The architecture can be summarized as:
Input → Conv(3×3) → [RRDB × N] → Conv(3×3) → (PixelShuffle upsampling × scale) → Conv(3×3) → LeakyReLU → Conv(3×3) → Output
Parameters¶
in_channels : int, default=3
Number of input channels (e.g., 3 for RGB or 4 for RGB-NIR).
out_channels : int or None, default=None
Number of output channels. If None, defaults to in_channels.
n_features : int, default=64
Number of feature maps in the base convolutional layers.
n_blocks : int, default=23
Number of Residual-in-Residual Dense Blocks (RRDBs) stacked in the network body.
growth_channels : int, default=32
Number of intermediate growth channels used within each RRDB.
res_scale : float, default=0.2
Residual scaling factor applied to stabilize deep residual learning.
scale : int, default=4
Upscaling factor. Must be a power of two (1, 2, 4, 8, ...).
Attributes¶
conv_first : nn.Conv2d
Initial 3×3 convolutional layer that extracts shallow features from the LR input.
body : nn.Sequential
Sequential container of RRDB blocks performing deep feature extraction.
conv_body : nn.Conv2d
3×3 convolution applied after the RRDB stack to merge body features.
upsampler : nn.Module
PixelShuffle-based upsampling module (from make_upsampler) for the configured scale.
If scale == 1, replaced by an identity mapping.
conv_hr : nn.Conv2d
3×3 convolution used for high-resolution refinement.
activation : nn.LeakyReLU
LeakyReLU activation (slope=0.2) applied after conv_hr.
conv_last : nn.Conv2d
Final 3×3 projection layer mapping features back to output space.
scale : int
Model’s upscaling factor.
n_blocks : int
Number of RRDB blocks in the body.
n_features : int
Number of feature maps in the feature extraction layers.
growth_channels : int
Growth channels per RRDB.
Raises¶
ValueError
If scale is not a power of two or if n_blocks < 1.
Examples¶
>>> from opensr_srgan.model.generators.esrgan_generator import ESRGANGenerator
>>> model = ESRGANGenerator(in_channels=3, scale=4)
>>> x = torch.randn(1, 3, 64, 64)
>>> y = model(x)
>>> y.shape
torch.Size([1, 3, 256, 256])
References¶
- Wang et al., "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks," ECCV Workshops, 2018.
Source code in opensr_srgan/model/generators/esrgan.py
class ESRGANGenerator(nn.Module):
"""
ESRGAN generator network for single-image super-resolution.
This implementation follows the design proposed in
*"ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks"*
(Wang et al., 2018). It replaces traditional residual blocks with
Residual-in-Residual Dense Blocks (RRDBs), omits batch normalization,
and applies residual scaling for improved stability and perceptual quality.
The architecture can be summarized as:
Input → Conv(3×3) → [RRDB × N] → Conv(3×3)
→ (PixelShuffle upsampling × scale) → Conv(3×3) → LeakyReLU → Conv(3×3) → Output
Parameters
----------
in_channels : int, default=3
Number of input channels (e.g., 3 for RGB or 4 for RGB-NIR).
out_channels : int or None, default=None
Number of output channels. If ``None``, defaults to ``in_channels``.
n_features : int, default=64
Number of feature maps in the base convolutional layers.
n_blocks : int, default=23
Number of Residual-in-Residual Dense Blocks (RRDBs) stacked in the network body.
growth_channels : int, default=32
Number of intermediate growth channels used within each RRDB.
res_scale : float, default=0.2
Residual scaling factor applied to stabilize deep residual learning.
scale : int, default=4
Upscaling factor. Must be a power of two (1, 2, 4, 8, ...).
Attributes
----------
conv_first : nn.Conv2d
Initial 3×3 convolutional layer that extracts shallow features from the LR input.
body : nn.Sequential
Sequential container of ``RRDB`` blocks performing deep feature extraction.
conv_body : nn.Conv2d
3×3 convolution applied after the RRDB stack to merge body features.
upsampler : nn.Module
PixelShuffle-based upsampling module (from ``make_upsampler``) for the configured scale.
If ``scale == 1``, replaced by an identity mapping.
conv_hr : nn.Conv2d
3×3 convolution used for high-resolution refinement.
activation : nn.LeakyReLU
LeakyReLU activation (slope=0.2) applied after ``conv_hr``.
conv_last : nn.Conv2d
Final 3×3 projection layer mapping features back to output space.
scale : int
Model’s upscaling factor.
n_blocks : int
Number of RRDB blocks in the body.
n_features : int
Number of feature maps in the feature extraction layers.
growth_channels : int
Growth channels per RRDB.
Raises
------
ValueError
If ``scale`` is not a power of two or if ``n_blocks < 1``.
Examples
--------
>>> from opensr_srgan.model.generators.esrgan_generator import ESRGANGenerator
>>> model = ESRGANGenerator(in_channels=3, scale=4)
>>> x = torch.randn(1, 3, 64, 64)
>>> y = model(x)
>>> y.shape
torch.Size([1, 3, 256, 256])
References
----------
- Wang et al., *ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks*,
"""
def __init__(
self,
*,
in_channels: int = 3,
out_channels: int | None = None,
n_features: int = 64,
n_blocks: int = 23,
growth_channels: int = 32,
res_scale: float = 0.2,
scale: int = 4,
) -> None:
super().__init__()
if scale < 1 or scale & (scale - 1) != 0:
raise ValueError("ESRGANGenerator only supports power-of-two scales (1, 2, 4, 8, ...).")
if n_blocks < 1:
raise ValueError("ESRGANGenerator requires at least one RRDB block.")
if out_channels is None:
out_channels = in_channels
self.scale = scale
self.n_blocks = n_blocks
self.n_features = n_features
self.growth_channels = growth_channels
body_blocks = [RRDB(n_features, growth_channels, res_scale=res_scale) for _ in range(n_blocks)]
self.conv_first = nn.Conv2d(in_channels, n_features, 3, padding=1)
self.body = nn.Sequential(*body_blocks)
self.conv_body = nn.Conv2d(n_features, n_features, 3, padding=1)
self.upsampler = nn.Identity() if scale == 1 else make_upsampler(n_features, scale)
self.conv_hr = nn.Conv2d(n_features, n_features, 3, padding=1)
self.activation = nn.LeakyReLU(0.2, inplace=True)
self.conv_last = nn.Conv2d(n_features, out_channels, 3, padding=1)
def forward(self, x: Tensor) -> Tensor:
"""
Forward pass of the ESRGAN generator.
Parameters
----------
x : torch.Tensor
Low-resolution input image tensor of shape (B, C, H, W).
Returns
-------
torch.Tensor
Super-resolved output tensor of shape (B, C, sH, sW),
where `s` is the upscaling factor defined by ``self.scale``.
Notes
-----
- The generator first encodes input features, processes them through
a sequence of Residual-in-Residual Dense Blocks (RRDBs), and then
performs upsampling via sub-pixel convolutions.
- A long skip connection adds the shallow features from the input
stem to the deep body output before upsampling.
- The final activation is a LeakyReLU followed by a 3×3 convolution
that projects features back to image space.
- When ``scale == 1``, the upsampling block is replaced by an identity.
"""
first = self.conv_first(x)
trunk = self.body(first)
body_out = self.conv_body(trunk)
feat = first + body_out
feat = self.upsampler(feat)
feat = self.activation(self.conv_hr(feat))
return self.conv_last(feat)
forward(x)
¶
Forward pass of the ESRGAN generator.
Parameters¶
x : torch.Tensor
    Low-resolution input image tensor of shape (B, C, H, W).
Returns¶
torch.Tensor
Super-resolved output tensor of shape (B, C, sH, sW),
where s is the upscaling factor defined by self.scale.
Notes¶
- The generator first encodes input features, processes them through a sequence of Residual-in-Residual Dense Blocks (RRDBs), and then performs upsampling via sub-pixel convolutions.
- A long skip connection adds the shallow features from the input stem to the deep body output before upsampling.
- The final activation is a LeakyReLU followed by a 3×3 convolution that projects features back to image space.
- When scale == 1, the upsampling block is replaced by an identity.
Source code in opensr_srgan/model/generators/esrgan.py
def forward(self, x: Tensor) -> Tensor:
"""
Forward pass of the ESRGAN generator.
Parameters
----------
x : torch.Tensor
Low-resolution input image tensor of shape (B, C, H, W).
Returns
-------
torch.Tensor
Super-resolved output tensor of shape (B, C, sH, sW),
where `s` is the upscaling factor defined by ``self.scale``.
Notes
-----
- The generator first encodes input features, processes them through
a sequence of Residual-in-Residual Dense Blocks (RRDBs), and then
performs upsampling via sub-pixel convolutions.
- A long skip connection adds the shallow features from the input
stem to the deep body output before upsampling.
- The final activation is a LeakyReLU followed by a 3×3 convolution
that projects features back to image space.
- When ``scale == 1``, the upsampling block is replaced by an identity.
"""
first = self.conv_first(x)
trunk = self.body(first)
body_out = self.conv_body(trunk)
feat = first + body_out
feat = self.upsampler(feat)
feat = self.activation(self.conv_hr(feat))
return self.conv_last(feat)
FlexibleGenerator
¶
Bases: Module
Modular super-resolution generator with pluggable residual blocks.
Provides a single, drop-in generator backbone that can be instantiated with different residual block families—res, rcab, rrdb, or lka—all built from a shared interface. The network follows a head → body → tail design with learnable upsampling:
Head (large-kernel conv) → N × Block(type=block_type) → Body tail conv
→ Skip add → Upsampler (×2/×4/×8) → Output conv
Use this when you want to compare architectural choices or sweep hyperparameters without changing call sites.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels (e.g., RGB=3, RGB-NIR=4/6). | 6 |
| n_channels | int | Base feature width used throughout the backbone. | 96 |
| n_blocks | int | Number of residual blocks in the body. | 32 |
| small_kernel | int | Kernel size for body/tail convolutions. | 3 |
| large_kernel | int | Kernel size for head/output convolutions. | 9 |
| scale | int | Upscaling factor; one of {2, 4, 8}. | 8 |
| block_type | str | Residual block family in {"res", "rcab", "rrdb", "lka"}. | 'rcab' |

Attributes:

| Name | Type | Description |
|---|---|---|
| head | Sequential | Large-receptive-field stem conv + activation. |
| body | Sequential | Sequence of residual blocks of the selected type. |
| body_tail | Conv2d | Fusion conv after the residual stack. |
| upsampler | Module | PixelShuffle-style learnable upsampling to scale. |
| tail | Conv2d | Final projection to in_channels. |

Raises:

| Type | Description |
|---|---|
| ValueError | If scale is not in {2, 4, 8} or block_type is unknown. |

Example:
>>> g = FlexibleGenerator(in_channels=3, block_type="rcab", scale=4)
>>> x = torch.randn(1, 3, 64, 64)
>>> y = g(x)  # (1, 3, 256, 256)
Source code in opensr_srgan/model/generators/flexible_generator.py
class FlexibleGenerator(nn.Module):
"""Modular super-resolution generator with pluggable residual blocks.
Provides a single, drop-in generator backbone that can be instantiated with
different residual block families—**res**, **rcab**, **rrdb**, or **lka**—all
built from a shared interface. The network follows a head → body → tail
design with learnable upsampling:
Head (large-kernel conv) → N × Block(type=block_type) → Body tail conv
→ Skip add → Upsampler (×2/×4/×8) → Output conv
Use this when you want to compare architectural choices or sweep hyper-
parameters without changing call sites.
Args:
in_channels (int): Number of input channels (e.g., RGB=3, RGB-NIR=4/6).
n_channels (int): Base feature width used throughout the backbone.
n_blocks (int): Number of residual blocks in the body.
small_kernel (int): Kernel size for body/ tail convolutions.
large_kernel (int): Kernel size for head/ output convolutions.
scale (int): Upscaling factor; one of {2, 4, 8}.
block_type (str): Residual block family in {"res","rcab","rrdb","lka"}.
Attributes:
head (nn.Sequential): Large-receptive-field stem conv + activation.
body (nn.Sequential): Sequence of residual blocks of the selected type.
body_tail (nn.Conv2d): Fusion conv after the residual stack.
upsampler (nn.Module): PixelShuffle-style learnable upsampling to `scale`.
tail (nn.Conv2d): Final projection to `in_channels`.
Raises:
ValueError: If `scale` is not in {2, 4, 8} or `block_type` is unknown.
Example:
>>> g = FlexibleGenerator(in_channels=3, block_type="rcab", scale=4)
>>> x = torch.randn(1, 3, 64, 64)
>>> y = g(x) # (1, 3, 256, 256)
"""
def __init__(
self,
in_channels: int = 6,
n_channels: int = 96,
n_blocks: int = 32,
small_kernel: int = 3,
large_kernel: int = 9,
scale: int = 8,
block_type: str = "rcab",
) -> None:
super().__init__()
if scale not in {2, 4, 8}:
raise ValueError("scale must be one of {2, 4, 8}")
block_key = block_type.lower()
if block_key not in _BLOCK_REGISTRY:
raise ValueError("block_type must be one of {'res', 'rcab', 'rrdb', 'lka'}")
self.scale = scale
padding_large = large_kernel // 2
self.head = nn.Sequential(
nn.Conv2d(in_channels, n_channels, large_kernel, padding=padding_large),
nn.PReLU(),
)
block_factory = _BLOCK_REGISTRY[block_key]
self.body = nn.Sequential(
*[block_factory(n_channels, small_kernel) for _ in range(n_blocks)]
)
self.body_tail = nn.Conv2d(
n_channels, n_channels, small_kernel, padding=small_kernel // 2
)
self.upsampler = make_upsampler(n_channels, scale)
self.tail = nn.Conv2d(
n_channels, in_channels, large_kernel, padding=padding_large
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the flexible SR generator.
Parameters
----------
x : torch.Tensor
Low-resolution input tensor of shape (B, C, H, W).
Returns
-------
torch.Tensor
Super-resolved output tensor of shape (B, C, sH, sW),
where `s` ∈ {2, 4, 8} is the configured upscaling factor.
Workflow
--------
1) Extract shallow features with the head conv.
2) Transform via a stack of residual blocks (type = `block_type`).
3) Fuse with a 3×3 body tail conv and add a long skip.
4) Upsample via PixelShuffle-based `make_upsampler`.
5) Project back to image space with the output conv.
"""
feat = self.head(x)
res = self.body(feat)
res = self.body_tail(res)
feat = feat + res
feat = self.upsampler(feat)
return self.tail(feat)
forward(x)
¶
Forward pass of the flexible SR generator.
Parameters¶
x : torch.Tensor
    Low-resolution input tensor of shape (B, C, H, W).
Returns¶
torch.Tensor
Super-resolved output tensor of shape (B, C, sH, sW),
where s ∈ {2, 4, 8} is the configured upscaling factor.
Workflow¶
1) Extract shallow features with the head conv.
2) Transform via a stack of residual blocks (type = block_type).
3) Fuse with a 3×3 body tail conv and add a long skip.
4) Upsample via PixelShuffle-based make_upsampler.
5) Project back to image space with the output conv.
Source code in opensr_srgan/model/generators/flexible_generator.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the flexible SR generator.
Parameters
----------
x : torch.Tensor
Low-resolution input tensor of shape (B, C, H, W).
Returns
-------
torch.Tensor
Super-resolved output tensor of shape (B, C, sH, sW),
where `s` ∈ {2, 4, 8} is the configured upscaling factor.
Workflow
--------
1) Extract shallow features with the head conv.
2) Transform via a stack of residual blocks (type = `block_type`).
3) Fuse with a 3×3 body tail conv and add a long skip.
4) Upsample via PixelShuffle-based `make_upsampler`.
5) Project back to image space with the output conv.
"""
feat = self.head(x)
res = self.body(feat)
res = self.body_tail(res)
feat = feat + res
feat = self.upsampler(feat)
return self.tail(feat)
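Since all block families share the same interface, swapping block_type is a one-argument change. A sketch comparing the registered variants under identical (deliberately small) settings; the import path follows the source listing above:

```python
import torch
from opensr_srgan.model.generators.flexible_generator import FlexibleGenerator

lr = torch.randn(1, 3, 48, 48)
for block_type in ("res", "rcab", "rrdb", "lka"):
    g = FlexibleGenerator(in_channels=3, n_channels=64, n_blocks=4,
                          scale=2, block_type=block_type)
    sr = g(lr)                                          # (1, 3, 96, 96) at scale=2
    n_params = sum(p.numel() for p in g.parameters())
    print(f"{block_type:>4}: output {tuple(sr.shape)}, params {n_params:,}")
```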
ConvolutionalBlock
¶
Bases: Module
A convolutional block comprised of Conv → (BN) → (Activation).
Source code in opensr_srgan/model/model_blocks/__init__.py
class ConvolutionalBlock(nn.Module):
"""A convolutional block comprised of Conv → (BN) → (Activation)."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
batch_norm: bool = False,
activation: str | None = None,
) -> None:
super().__init__()
act = activation.lower() if activation is not None else None
if act is not None:
if act not in {"prelu", "leakyrelu", "tanh"}:
raise AssertionError("activation must be one of {'prelu', 'leakyrelu', 'tanh'}")
layers: list[nn.Module] = [
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2,
)
]
if batch_norm:
layers.append(nn.BatchNorm2d(num_features=out_channels))
if act == "prelu":
layers.append(nn.PReLU())
elif act == "leakyrelu":
layers.append(nn.LeakyReLU(0.2))
elif act == "tanh":
layers.append(nn.Tanh())
self.conv_block = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv_block(x)
SubPixelConvolutionalBlock
¶
Bases: Module
Conv → PixelShuffle → PReLU upsampling block.
Source code in opensr_srgan/model/model_blocks/__init__.py
class SubPixelConvolutionalBlock(nn.Module):
"""Conv → PixelShuffle → PReLU upsampling block."""
def __init__(
self,
kernel_size: int = 3,
n_channels: int = 64,
scaling_factor: int = 2,
) -> None:
super().__init__()
self.conv = nn.Conv2d(
in_channels=n_channels,
out_channels=n_channels * (scaling_factor**2),
kernel_size=kernel_size,
padding=kernel_size // 2,
)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor=scaling_factor)
self.prelu = nn.PReLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.prelu(x)
return x
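For intuition (a sketch; the block is normally used inside the generators above rather than standalone): the conv expands channels by scaling_factor², and PixelShuffle trades those channels for spatial resolution, so the output keeps the original channel count:

```python
import torch
from opensr_srgan.model.model_blocks import SubPixelConvolutionalBlock

block = SubPixelConvolutionalBlock(kernel_size=3, n_channels=64, scaling_factor=2)
x = torch.randn(1, 64, 32, 32)
y = block(x)
print(y.shape)  # expected: torch.Size([1, 64, 64, 64])
```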
ResidualBlock
¶
Bases: Module
BN-enabled residual block used in the original SRResNet.
Source code in opensr_srgan/model/model_blocks/__init__.py
class ResidualBlock(nn.Module):
"""BN-enabled residual block used in the original SRResNet."""
def __init__(self, kernel_size: int = 3, n_channels: int = 64) -> None:
super().__init__()
self.conv_block1 = ConvolutionalBlock(
in_channels=n_channels,
out_channels=n_channels,
kernel_size=kernel_size,
batch_norm=True,
activation="PReLu",
)
self.conv_block2 = ConvolutionalBlock(
in_channels=n_channels,
out_channels=n_channels,
kernel_size=kernel_size,
batch_norm=True,
activation=None,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.conv_block1(x)
x = self.conv_block2(x)
return x + residual
ResidualBlockNoBN
¶
Bases: Module
Residual block variant without batch norm and with residual scaling.
Source code in opensr_srgan/model/model_blocks/__init__.py
class ResidualBlockNoBN(nn.Module):
"""Residual block variant without batch norm and with residual scaling."""
def __init__(
self,
n_channels: int = 64,
kernel_size: int = 3,
res_scale: float = 0.2,
) -> None:
super().__init__()
padding = kernel_size // 2
self.body = nn.Sequential(
nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding),
nn.PReLU(),
nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding),
)
self.res_scale = res_scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.res_scale * self.body(x)
RCAB
¶
Bases: Module
Residual Channel Attention Block (no BN).
Source code in opensr_srgan/model/model_blocks/__init__.py
class RCAB(nn.Module):
"""Residual Channel Attention Block (no BN)."""
def __init__(
self,
n_channels: int = 64,
kernel_size: int = 3,
reduction: int = 16,
res_scale: float = 0.2,
) -> None:
super().__init__()
padding = kernel_size // 2
self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
self.act = nn.PReLU()
self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
self.se = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(n_channels, n_channels // reduction, 1),
nn.ReLU(inplace=True),
nn.Conv2d(n_channels // reduction, n_channels, 1),
nn.Sigmoid(),
)
self.res_scale = res_scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = self.conv1(x)
y = self.act(y)
y = self.conv2(y)
w = self.se(y)
return x + self.res_scale * (y * w)
DenseBlock5
¶
Bases: Module
ESRGAN-style dense block with five convolutions.
Source code in opensr_srgan/model/model_blocks/__init__.py
class DenseBlock5(nn.Module):
"""ESRGAN-style dense block with five convolutions."""
def __init__(self, n_features: int = 64, growth_channels: int = 32, kernel_size: int = 3) -> None:
super().__init__()
padding = kernel_size // 2
self.act = nn.LeakyReLU(0.2, inplace=True)
self.c1 = nn.Conv2d(n_features, growth_channels, kernel_size, padding=padding)
self.c2 = nn.Conv2d(n_features + growth_channels, growth_channels, kernel_size, padding=padding)
self.c3 = nn.Conv2d(n_features + 2 * growth_channels, growth_channels, kernel_size, padding=padding)
self.c4 = nn.Conv2d(n_features + 3 * growth_channels, growth_channels, kernel_size, padding=padding)
self.c5 = nn.Conv2d(n_features + 4 * growth_channels, n_features, kernel_size, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x1 = self.act(self.c1(x))
x2 = self.act(self.c2(torch.cat([x, x1], dim=1)))
x3 = self.act(self.c3(torch.cat([x, x1, x2], dim=1)))
x4 = self.act(self.c4(torch.cat([x, x1, x2, x3], dim=1)))
x5 = self.c5(torch.cat([x, x1, x2, x3, x4], dim=1))
return x5
RRDB
¶
Bases: Module
Residual-in-Residual Dense Block.
Source code in opensr_srgan/model/model_blocks/__init__.py
class RRDB(nn.Module):
"""Residual-in-Residual Dense Block."""
def __init__(self, n_features: int = 64, growth_channels: int = 32, res_scale: float = 0.2) -> None:
super().__init__()
self.db1 = DenseBlock5(n_features, growth_channels)
self.db2 = DenseBlock5(n_features, growth_channels)
self.db3 = DenseBlock5(n_features, growth_channels)
self.res_scale = res_scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = self.db1(x)
y = self.db2(y)
y = self.db3(y)
return x + self.res_scale * y
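Each RRDB chains three dense blocks and adds the scaled result back to its input, so the block preserves both spatial size and channel count, which is what lets the ESRGAN body stack dozens of them. A quick sketch:

```python
import torch
from opensr_srgan.model.model_blocks import RRDB

block = RRDB(n_features=64, growth_channels=32, res_scale=0.2)
x = torch.randn(1, 64, 24, 24)
y = block(x)
print(y.shape == x.shape)  # True: RRDB is shape-preserving
```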
LKA
¶
Bases: Module
Lightweight Large-Kernel Attention module.
Source code in opensr_srgan/model/model_blocks/__init__.py
class LKA(nn.Module):
"""Lightweight Large-Kernel Attention module."""
def __init__(self, n_channels: int = 64) -> None:
super().__init__()
self.dw5 = nn.Conv2d(n_channels, n_channels, 5, padding=2, groups=n_channels)
self.dw7d = nn.Conv2d(n_channels, n_channels, 7, padding=9, dilation=3, groups=n_channels)
self.pw = nn.Conv2d(n_channels, n_channels, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
attn = self.dw5(x)
attn = self.dw7d(attn)
attn = self.pw(attn)
return x * torch.sigmoid(attn)
LKAResBlock
¶
Bases: Module
Residual block incorporating Large-Kernel Attention.
Source code in opensr_srgan/model/model_blocks/__init__.py
class LKAResBlock(nn.Module):
"""Residual block incorporating Large-Kernel Attention."""
def __init__(self, n_channels: int = 64, kernel_size: int = 3, res_scale: float = 0.2) -> None:
super().__init__()
padding = kernel_size // 2
self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
self.act = nn.PReLU()
self.lka = LKA(n_channels)
self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size, padding=padding)
self.res_scale = res_scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = self.conv1(x)
y = self.act(y)
y = self.lka(y)
y = self.conv2(y)
return x + self.res_scale * y
build_generator(config)
¶
Instantiate a generator module from the provided configuration.
Parameters¶
config : Any
Resolved configuration object with at least:
- config.Model.in_bands
- config.Generator.model_type (or SRResNet block alias)
- Additional generator-specific fields (e.g., n_blocks, n_channels, scaling_factor).
Returns¶
torch.nn.Module
    A fully constructed generator (SRResNet/Flexible, ESRGAN, or StochasticGenerator).
Raises¶
ValueError
    If the model type or block variant cannot be resolved to a known implementation.
Notes¶
- Accepts legacy configs where Generator.model_type was a block alias (e.g., "rcab").
- Non-applicable options are reported (not fatal) for clarity.
Source code in opensr_srgan/model/generators/factory.py
def build_generator(config: Any) -> nn.Module:
"""
Instantiate a generator module from the provided configuration.
Parameters
----------
config : Any
Resolved configuration object with at least:
- `config.Model.in_bands`
- `config.Generator.model_type` (or SRResNet block alias)
- Additional generator-specific fields (e.g., n_blocks, n_channels, scaling_factor).
Returns
-------
torch.nn.Module
A fully constructed generator (SRResNet/Flexible, ESRGAN, or StochasticGenerator).
Raises
------
ValueError
If the model type or block variant cannot be resolved to a known implementation.
Notes
-----
- Accepts legacy configs where `Generator.model_type` was a block alias (e.g., "rcab").
- Non-applicable options are reported (not fatal) for clarity.
"""
generator_cfg = config.Generator
model_cfg = config.Model
raw_model_type = str(getattr(generator_cfg, "model_type", "srresnet"))
model_type = _match_alias(raw_model_type, _MODEL_TYPE_ALIASES)
# Legacy support: configs that directly specified the block variant inside
# ``model_type`` (e.g., "rcab") should still build the flexible generator.
if model_type is None:
srresnet_block = _match_alias(raw_model_type, _SRRESNET_BLOCK_ALIASES)
if srresnet_block is not None:
model_type = "srresnet"
else:
raise ValueError(
"Unknown generator model type '{model_type}'. Expected one of: {options}.".format(
model_type=raw_model_type,
options=", ".join(sorted(_MODEL_TYPE_ALIASES)),
)
)
else:
srresnet_block = None
in_channels = int(getattr(model_cfg, "in_bands"))
scale = int(getattr(generator_cfg, "scaling_factor", 4))
if model_type == "srresnet":
large_kernel = int(getattr(generator_cfg, "large_kernel_size", 9))
small_kernel = int(getattr(generator_cfg, "small_kernel_size", 3))
n_channels = int(getattr(generator_cfg, "n_channels", 64))
n_blocks = int(getattr(generator_cfg, "n_blocks", 16))
block_variant = _resolve_srresnet_block(generator_cfg, srresnet_block)
if block_variant == "standard":
return SRResNetGenerator(
in_channels=in_channels,
large_kernel_size=large_kernel,
small_kernel_size=small_kernel,
n_channels=n_channels,
n_blocks=n_blocks,
scaling_factor=scale,
)
return FlexibleGenerator(
in_channels=in_channels,
n_channels=n_channels,
n_blocks=n_blocks,
small_kernel=small_kernel,
large_kernel=large_kernel,
scale=scale,
block_type=block_variant,
)
if model_type == "stochastic_gan":
large_kernel = int(getattr(generator_cfg, "large_kernel_size", 9))
small_kernel = int(getattr(generator_cfg, "small_kernel_size", 3))
n_channels = int(getattr(generator_cfg, "n_channels", 64))
n_blocks = int(getattr(generator_cfg, "n_blocks", 16))
noise_dim = int(getattr(generator_cfg, "noise_dim", 128))
res_scale = float(getattr(generator_cfg, "res_scale", 0.2))
_warn_overridden_options(
"Generator",
"stochastic_gan",
_collect_overridden(generator_cfg, "block_type"),
)
return StochasticGenerator(
in_channels=in_channels,
n_channels=n_channels,
n_blocks=n_blocks,
small_kernel=small_kernel,
large_kernel=large_kernel,
scale=scale,
noise_dim=noise_dim,
res_scale=res_scale,
)
if model_type == "esrgan":
n_channels = int(getattr(generator_cfg, "n_channels", 64))
n_rrdb = int(getattr(generator_cfg, "n_blocks", 23))
growth_channels = int(getattr(generator_cfg, "growth_channels", 32))
res_scale = float(getattr(generator_cfg, "res_scale", 0.2))
out_channels = int(getattr(generator_cfg, "out_channels", in_channels))
_warn_overridden_options(
"Generator",
"esrgan",
_collect_overridden(
generator_cfg,
"block_type",
"large_kernel_size",
"small_kernel_size",
"noise_dim",
),
)
return ESRGANGenerator(
in_channels=in_channels,
out_channels=out_channels,
n_features=n_channels,
n_blocks=n_rrdb,
growth_channels=growth_channels,
res_scale=res_scale,
scale=scale,
)
raise ValueError(
"Unhandled generator model type '{model_type}'. Expected one of: {options}.".format(
model_type=raw_model_type,
options=", ".join(sorted(_MODEL_TYPE_ALIASES)),
)
)
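A hedged sketch of driving the factory from an in-memory OmegaConf config; the field names follow the configuration sections used throughout this page, and the values are illustrative only:

```python
from omegaconf import OmegaConf
from opensr_srgan.model.generators.factory import build_generator

config = OmegaConf.create({
    "Model": {"in_bands": 4},
    "Generator": {
        "model_type": "SRResNet",
        "block_type": "rcab",       # non-standard block -> FlexibleGenerator variant
        "n_channels": 64,
        "n_blocks": 16,
        "scaling_factor": 4,
    },
})
generator = build_generator(config)
```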
make_upsampler(n_channels, scale)
¶
Create a pixel-shuffle upsampler matching the flexible generator implementation.
Source code in opensr_srgan/model/model_blocks/__init__.py
def make_upsampler(n_channels: int, scale: int) -> nn.Sequential:
"""Create a pixel-shuffle upsampler matching the flexible generator implementation."""
stages: list[nn.Module] = []
for _ in range(int(math.log2(scale))):
stages.extend(
[
nn.Conv2d(n_channels, n_channels * 4, 3, padding=1),
nn.PixelShuffle(2),
nn.PReLU(),
]
)
return nn.Sequential(*stages)
Discriminator architectures¶
Discriminator architectures for SRGAN.
Discriminator
¶
Bases: Module
Standard SRGAN discriminator as defined in the original paper.
The network alternates between stride-1 and stride-2 convolutional blocks, gradually reducing spatial resolution while increasing channel depth. The resulting features are globally pooled and classified via two fully connected layers to produce a single realism score.
Parameters¶
in_channels : int, default=3
    Number of input channels (e.g., 3 for RGB, 4 for multispectral SR).
n_blocks : int, default=8
    Number of convolutional blocks to stack. Must be >= 1.
use_spectral_norm : bool, default=True
    Whether to apply spectral normalization to convolutional and linear layers for improved Lipschitz control and training stability.
Attributes¶
conv_blocks : nn.Sequential
    Sequential stack of convolutional feature extraction blocks.
adaptive_pool : nn.AdaptiveAvgPool2d
    Global pooling layer reducing spatial dimensions to 6×6.
fc1 : nn.Linear
    Fully connected layer mapping pooled features to hidden representation.
fc2 : nn.Linear
    Output layer producing a single realism score.
leaky_relu : nn.LeakyReLU
    Activation used in fully connected layers.
base_channels : int
    Number of channels in the first convolutional layer (default 64).
kernel_size : int
    Kernel size used in all convolutional blocks (default 3).
fc_size : int
    Hidden dimension of the fully connected layer (default 1024).
n_blocks : int
    Total number of convolutional blocks.
use_spectral_norm : bool
    Indicates whether spectral normalization wraps the convolutional and linear layers.
Raises¶
ValueError
If n_blocks < 1.
Examples¶
>>> disc = Discriminator(in_channels=3, n_blocks=8)
>>> x = torch.randn(4, 3, 96, 96)
>>> y = disc(x)
>>> y.shape
torch.Size([4, 1])
Source code in opensr_srgan/model/discriminators/srgan_discriminator.py
class Discriminator(nn.Module):
"""
Standard SRGAN discriminator as defined in the original paper.
The network alternates between stride-1 and stride-2 convolutional
blocks, gradually reducing spatial resolution while increasing
channel depth. The resulting features are globally pooled and
classified via two fully connected layers to produce a single
realism score.
Parameters
----------
in_channels : int, default=3
Number of input channels (e.g., 3 for RGB, 4 for multispectral SR).
n_blocks : int, default=8
Number of convolutional blocks to stack. Must be >= 1.
use_spectral_norm : bool, default=True
Whether to apply spectral normalization to convolutional and linear layers
for improved Lipschitz control and training stability.
Attributes
----------
conv_blocks : nn.Sequential
Sequential stack of convolutional feature extraction blocks.
adaptive_pool : nn.AdaptiveAvgPool2d
Global pooling layer reducing spatial dimensions to 6×6.
fc1 : nn.Linear
Fully connected layer mapping pooled features to hidden representation.
fc2 : nn.Linear
Output layer producing a single realism score.
leaky_relu : nn.LeakyReLU
Activation used in fully connected layers.
base_channels : int
Number of channels in the first convolutional layer (default 64).
kernel_size : int
Kernel size used in all convolutional blocks (default 3).
fc_size : int
Hidden dimension of the fully connected layer (default 1024).
n_blocks : int
Total number of convolutional blocks.
use_spectral_norm : bool
Indicates whether spectral normalization wraps the convolutional and linear layers.
Raises
------
ValueError
If `n_blocks` < 1.
Examples
--------
>>> disc = Discriminator(in_channels=3, n_blocks=8)
>>> x = torch.randn(4, 3, 96, 96)
>>> y = disc(x)
>>> y.shape
torch.Size([4, 1])
"""
def __init__(
self,
in_channels: int = 3,
n_blocks: int = 8,
use_spectral_norm: bool = True,
) -> None:
super().__init__()
if n_blocks < 1:
raise ValueError("The SRGAN discriminator requires at least one block.")
kernel_size = 3
base_channels = 64
fc_size = 1024
conv_blocks: list[nn.Module] = []
current_in = in_channels
out_channels = base_channels
for i in range(n_blocks):
if i == 0:
out_channels = base_channels
elif i % 2 == 0:
out_channels = current_in * 2
else:
out_channels = current_in
block = ConvolutionalBlock(
in_channels=current_in,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1 if i % 2 == 0 else 2,
batch_norm=i != 0,
activation="LeakyReLu",
)
if use_spectral_norm:
self._apply_spectral_norm(block)
conv_blocks.append(block)
current_in = out_channels
self.conv_blocks = nn.Sequential(*conv_blocks)
self.adaptive_pool = nn.AdaptiveAvgPool2d((6, 6))
self.fc1 = nn.Linear(out_channels * 6 * 6, fc_size)
self.leaky_relu = nn.LeakyReLU(0.2)
self.fc2 = nn.Linear(fc_size, 1)
if use_spectral_norm:
self.fc1 = spectral_norm(self.fc1)
self.fc2 = spectral_norm(self.fc2)
self.base_channels = base_channels
self.kernel_size = kernel_size
self.fc_size = fc_size
self.n_blocks = n_blocks
self.use_spectral_norm = use_spectral_norm
def forward(self, imgs: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the SRGAN discriminator.
Parameters
----------
imgs : torch.Tensor
Input image tensor of shape (B, C, H, W).
Returns
-------
torch.Tensor
Realism logits of shape (B, 1), where higher values
indicate greater likelihood of being a real image.
"""
batch_size = imgs.size(0)
feats = self.conv_blocks(imgs)
pooled = self.adaptive_pool(feats)
flat = pooled.view(batch_size, -1)
hidden = self.leaky_relu(self.fc1(flat))
return self.fc2(hidden)
@staticmethod
def _apply_spectral_norm(module: nn.Module) -> None:
for submodule in module.modules():
if isinstance(submodule, (nn.Conv2d, nn.Linear)):
spectral_norm(submodule)
forward(imgs)
¶
Forward pass through the SRGAN discriminator.
Parameters¶
imgs : torch.Tensor
    Input image tensor of shape (B, C, H, W).
Returns¶
torch.Tensor
    Realism logits of shape (B, 1), where higher values indicate greater likelihood of being a real image.
Source code in opensr_srgan/model/discriminators/srgan_discriminator.py
def forward(self, imgs: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the SRGAN discriminator.
Parameters
----------
imgs : torch.Tensor
Input image tensor of shape (B, C, H, W).
Returns
-------
torch.Tensor
Realism logits of shape (B, 1), where higher values
indicate greater likelihood of being a real image.
"""
batch_size = imgs.size(0)
feats = self.conv_blocks(imgs)
pooled = self.adaptive_pool(feats)
flat = pooled.view(batch_size, -1)
hidden = self.leaky_relu(self.fc1(flat))
return self.fc2(hidden)
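The single logit per image pairs naturally with BCEWithLogitsLoss, the usual adversarial criterion for SRGAN-style training. A minimal sketch of scoring real and fake batches (tensor contents are random placeholders):

```python
import torch
import torch.nn as nn
from opensr_srgan.model.discriminators.srgan_discriminator import Discriminator

disc = Discriminator(in_channels=3, n_blocks=8)
criterion = nn.BCEWithLogitsLoss()

real = torch.randn(4, 3, 96, 96)
fake = torch.randn(4, 3, 96, 96)   # stand-in for generator output

# Discriminator loss: push real logits toward 1 and fake logits toward 0.
d_loss = criterion(disc(real), torch.ones(4, 1)) + \
         criterion(disc(fake), torch.zeros(4, 1))
```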
PatchGANDiscriminator
¶
Bases: Module
High-level convenience wrapper for the N-layer PatchGAN discriminator.
Parameters¶
input_nc : int
    Number of input channels.
n_layers : int, default=3
    Number of convolutional layers.
norm_type : {"batch", "instance", "none"}, default="instance"
    Normalization strategy.
Attributes¶
model : NLayerDiscriminator
    Underlying PatchGAN model.
base_channels : int
    Number of base feature channels (default 64).
kernel_size : int
    Kernel size for convolutions (default 4).
n_layers : int
    Number of downsampling layers.
Raises¶
ValueError
If n_layers < 1.
Examples¶
>>> disc = PatchGANDiscriminator(input_nc=3, n_layers=3)
>>> x = torch.randn(4, 3, 128, 128)
>>> y = disc(x)
>>> y.shape
torch.Size([4, 1, 14, 14])
Source code in opensr_srgan/model/discriminators/patchgan.py
class PatchGANDiscriminator(nn.Module):
"""
High-level convenience wrapper for the N-layer PatchGAN discriminator.
Parameters
----------
input_nc : int
Number of input channels.
n_layers : int, default=3
Number of convolutional layers.
norm_type : {"batch", "instance", "none"}, default="instance"
Normalization strategy.
Attributes
----------
model : NLayerDiscriminator
Underlying PatchGAN model.
base_channels : int
Number of base feature channels (default 64).
kernel_size : int
Kernel size for convolutions (default 4).
n_layers : int
Number of downsampling layers.
Raises
------
ValueError
If `n_layers` < 1.
Examples
--------
>>> disc = PatchGANDiscriminator(input_nc=3, n_layers=3)
>>> x = torch.randn(4, 3, 128, 128)
>>> y = disc(x)
>>> y.shape
torch.Size([4, 1, 14, 14])
"""
def __init__(
self,
input_nc: int,
n_layers: int = 3,
norm_type: str = "instance",
) -> None:
super().__init__()
if n_layers < 1:
raise ValueError("PatchGAN discriminator requires at least one layer.")
ndf = 64
norm_layer = get_norm_layer(norm_type)
self.model = NLayerDiscriminator(
input_nc, ndf=ndf, n_layers=n_layers, norm_layer=norm_layer
)
self.base_channels = ndf
self.kernel_size = 4
self.n_layers = n_layers
def forward(self, input): # type: ignore[override]
"""
Forward pass through the PatchGAN discriminator.
Parameters
----------
x : torch.Tensor
Input image tensor of shape (B, C, H, W).
Returns
-------
torch.Tensor
Patch-level realism scores, shape (B, 1, H/2ⁿ, W/2ⁿ).
"""
return self.model(input)
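As a complement to the example above, the sketch below varies `norm_type` and the channel count. The import path is inferred from the "Source code in" note and the expected output shape follows the class example, so treat both as assumptions rather than guarantees.

```python
import torch
from opensr_srgan.model.discriminators.patchgan import PatchGANDiscriminator  # path inferred, see note above

# 4-band input (e.g., RGB-NIR) with batch norm instead of the default instance norm.
disc = PatchGANDiscriminator(input_nc=4, n_layers=3, norm_type="batch")
x = torch.randn(2, 4, 128, 128)
patch_logits = disc(x)          # per-patch realism logits
print(patch_logits.shape)       # expected torch.Size([2, 1, 14, 14]) for this configuration
```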
ESRGANDiscriminator
¶
Bases: Module
VGG-style discriminator network used in ESRGAN.
This discriminator processes high-resolution image patches and predicts a scalar realism score. It follows the VGG-style design with progressively increasing channel depth and downsampling via strided convolutions.
Parameters¶
in_channels : int, default=3
    Number of input channels (e.g., 3 for RGB, 4 for RGB-NIR).
base_channels : int, default=64
    Number of channels in the first convolutional layer; subsequent layers scale as powers of two.
linear_size : int, default=1024
    Size of the intermediate fully-connected layer before the output scalar.
Attributes¶
features : nn.Sequential
    Convolutional feature extractor backbone.
pool : nn.AdaptiveAvgPool2d
    Global pooling layer to aggregate spatial features.
classifier : nn.Sequential
    Fully connected layers producing a single output value.
n_layers : int
    Total number of convolutional blocks (for metadata/reference only).
Raises¶
ValueError
If base_channels or linear_size is not a positive integer.
Examples¶
>>> disc = ESRGANDiscriminator(in_channels=3)
>>> x = torch.randn(8, 3, 128, 128)
>>> y = disc(x)
>>> y.shape
torch.Size([8, 1])
Source code in opensr_srgan/model/discriminators/esrgan.py
class ESRGANDiscriminator(nn.Module):
"""
VGG-style discriminator network used in ESRGAN.
This discriminator processes high-resolution image patches and predicts
a scalar realism score. It follows the VGG-style design with progressively
increasing channel depth and downsampling via strided convolutions.
Parameters
----------
in_channels : int, default=3
Number of input channels (e.g., 3 for RGB, 4 for RGB-NIR).
base_channels : int, default=64
Number of channels in the first convolutional layer; subsequent layers
scale as powers of two.
linear_size : int, default=1024
Size of the intermediate fully-connected layer before the output scalar.
Attributes
----------
features : nn.Sequential
Convolutional feature extractor backbone.
pool : nn.AdaptiveAvgPool2d
Global pooling layer to aggregate spatial features.
classifier : nn.Sequential
Fully connected layers producing a single output value.
n_layers : int
Total number of convolutional blocks (for metadata/reference only).
Raises
------
ValueError
If `base_channels` or `linear_size` is not a positive integer.
Examples
--------
>>> disc = ESRGANDiscriminator(in_channels=3)
>>> x = torch.randn(8, 3, 128, 128)
>>> y = disc(x)
>>> y.shape
torch.Size([8, 1])
"""
def __init__(
self,
*,
in_channels: int = 3,
base_channels: int = 64,
linear_size: int = 1024,
) -> None:
super().__init__()
if base_channels <= 0:
raise ValueError("base_channels must be a positive integer.")
if linear_size <= 0:
raise ValueError("linear_size must be a positive integer.")
features: list[nn.Module] = [
nn.Conv2d(in_channels, base_channels, 3, 1, padding=1),
nn.LeakyReLU(0.2, inplace=True),
_conv_block(base_channels, base_channels, kernel_size=4, stride=2),
_conv_block(base_channels, base_channels * 2, kernel_size=3, stride=1),
_conv_block(base_channels * 2, base_channels * 2, kernel_size=4, stride=2),
_conv_block(base_channels * 2, base_channels * 4, kernel_size=3, stride=1),
_conv_block(base_channels * 4, base_channels * 4, kernel_size=4, stride=2),
_conv_block(base_channels * 4, base_channels * 8, kernel_size=3, stride=1),
_conv_block(base_channels * 8, base_channels * 8, kernel_size=4, stride=2),
_conv_block(base_channels * 8, base_channels * 16, kernel_size=3, stride=1),
_conv_block(base_channels * 16, base_channels * 16, kernel_size=4, stride=2),
]
self.features = nn.Sequential(*features)
self.pool = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Sequential(
nn.Linear(base_channels * 16, linear_size),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(linear_size, 1),
)
self.base_channels = base_channels
self.linear_size = linear_size
self.n_layers = 1 + 10 # first conv + stacked blocks
def forward(self, x: Tensor) -> Tensor:
"""
Forward pass through the ESRGAN discriminator.
Parameters
----------
x : torch.Tensor
Input image tensor of shape (B, C, H, W).
Returns
-------
torch.Tensor
Discriminator logits of shape (B, 1), where higher values
indicate more realistic images.
"""
feats = self.features(x)
pooled = self.pool(feats).view(x.size(0), -1)
return self.classifier(pooled)
forward(x)
¶
Forward pass through the ESRGAN discriminator.
Parameters¶
x : torch.Tensor
    Input image tensor of shape (B, C, H, W).
Returns¶
torch.Tensor
    Discriminator logits of shape (B, 1), where higher values indicate more realistic images.
Source code in opensr_srgan/model/discriminators/esrgan.py
def forward(self, x: Tensor) -> Tensor:
"""
Forward pass through the ESRGAN discriminator.
Parameters
----------
x : torch.Tensor
Input image tensor of shape (B, C, H, W).
Returns
-------
torch.Tensor
Discriminator logits of shape (B, 1), where higher values
indicate more realistic images.
"""
feats = self.features(x)
pooled = self.pool(feats).view(x.size(0), -1)
return self.classifier(pooled)
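For multispectral inputs only `in_channels` changes; thanks to the global average pooling, the classifier input size stays at `base_channels * 16` for any sufficiently large patch. A hedged usage sketch (import path inferred from the "Source code in" note) follows.

```python
import torch
from opensr_srgan.model.discriminators.esrgan import ESRGANDiscriminator  # path inferred, see note above

# 4-band (e.g., RGB-NIR) configuration; constructor arguments are keyword-only.
disc = ESRGANDiscriminator(in_channels=4, base_channels=64, linear_size=1024)
x = torch.randn(2, 4, 128, 128)
logits = disc(x)                # (B, 1) realism logits
print(logits.shape)             # torch.Size([2, 1])
```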
Loss components¶
GeneratorContentLoss
¶
Bases: Module
Composite generator content loss with perceptual metric selection.
Combines multiple terms to form the generator's content objective:
total = l1_w * L1 + sam_w * SAM + perc_w * Perceptual + tv_w * TV.
Also computes auxiliary quality metrics (PSNR/SSIM) for logging/evaluation.
Loss weights, max value, window sizes, band selection settings, and the perceptual backend (VGG or LPIPS) are read from the provided config.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `cfg` | | Configuration object (OmegaConf/dict-like) with fields under `Training.Losses.{l1_weight,sam_weight,perceptual_weight,tv_weight}`, `Training.Losses.{max_val,ssim_win,randomize_bands,fixed_idx}`, `Training.Losses.perceptual_metric` in {`"vgg"`, `"lpips"`}, and `TruncatedVGG.{i,j}` (if VGG is used). | *required* |
| `testing` | `bool` | If True, do not load pretrained VGG weights (avoids downloads in CI/tests). Defaults to False. | `False` |
Attributes:

| Name | Type | Description |
|---|---|---|
| `l1_w`, `sam_w`, `perc_w`, `tv_w` | `float` | Loss term weights. |
| `max_val` | `float` | Dynamic range for PSNR/SSIM. |
| `ssim_win` | `int` | Window size for SSIM. |
| `randomize_bands` | `bool` | Whether to sample random bands for perceptual loss. |
| `fixed_idx` | `LongTensor \| None` | Fixed 3-channel indices if not randomized. |
| `perc_metric` | `str` | Selected perceptual backend (`"vgg"` or `"lpips"`). |
| `perceptual_model` | `Module` | Back-end feature network/metric. |
| `normalizer` | `Normalizer` | Shared normalizer for evaluation metrics. |
Source code in opensr_srgan/model/loss/loss.py
class GeneratorContentLoss(nn.Module):
"""Composite generator content loss with perceptual metric selection.
Combines multiple terms to form the generator's content objective:
``total = l1_w * L1 + sam_w * SAM + perc_w * Perceptual + tv_w * TV``.
Also computes auxiliary quality metrics (PSNR/SSIM) for logging/evaluation.
Loss weights, max value, window sizes, band selection settings, and the
perceptual backend (VGG or LPIPS) are read from the provided config.
Args:
cfg: Configuration object (OmegaConf/dict-like) with fields under:
- ``Training.Losses.{l1_weight,sam_weight,perceptual_weight,tv_weight}``
- ``Training.Losses.{max_val,ssim_win,randomize_bands,fixed_idx}``
- ``Training.Losses.perceptual_metric`` in {``"vgg"``, ``"lpips"``}
- ``TruncatedVGG.{i,j}`` (if VGG is used)
testing (bool, optional): If True, do not load pretrained VGG weights
(avoids downloads in CI/tests). Defaults to False.
Attributes:
l1_w, sam_w, perc_w, tv_w (float): Loss term weights.
max_val (float): Dynamic range for PSNR/SSIM.
ssim_win (int): Window size for SSIM.
randomize_bands (bool): Whether to sample random bands for perceptual loss.
fixed_idx (torch.LongTensor|None): Fixed 3-channel indices if not randomized.
perc_metric (str): Selected perceptual backend (``"vgg"`` or ``"lpips"``).
perceptual_model (nn.Module): Back-end feature network/metric.
normalizer (Normalizer): Shared normalizer for evaluation metrics.
"""
def __init__(self, cfg, testing=False):
super().__init__()
self.cfg = cfg
# --- weights & settings from config ---
# (fallback to deprecated Training.perceptual_loss_weight if needed)
self.l1_w = float(_cfg_get(cfg, ["Training", "Losses", "l1_weight"], 1.0))
self.sam_w = float(_cfg_get(cfg, ["Training", "Losses", "sam_weight"], 0.05))
perc_w_cfg = _cfg_get(
cfg,
["Training", "Losses", "perceptual_weight"],
_cfg_get(cfg, ["Training", "perceptual_loss_weight"], 0.1),
)
self.perc_w = float(perc_w_cfg)
self.tv_w = float(_cfg_get(cfg, ["Training", "Losses", "tv_weight"], 0.0))
self.max_val = float(_cfg_get(cfg, ["Training", "Losses", "max_val"], 1.0))
self.ssim_win = int(_cfg_get(cfg, ["Training", "Losses", "ssim_win"], 11))
fixed_idx = _cfg_get(cfg, ["Training", "Losses", "fixed_idx"], None)
if fixed_idx is not None:
fixed_idx = torch.as_tensor(fixed_idx, dtype=torch.long)
assert fixed_idx.numel() == 3, "fixed_idx must have length 3"
self.register_buffer(
"fixed_idx", fixed_idx if fixed_idx is not None else None, persistent=False
)
# Only init model if perceptual weight > 0
if self.perc_w != 0.0:
# --- configure perceptual metric ---
self.perc_metric = str(
_cfg_get(cfg, ["Training", "Losses", "perceptual_metric"], "vgg")
).lower()
if self.perc_metric == "vgg":
from .vgg import TruncatedVGG19
i = int(_cfg_get(cfg, ["TruncatedVGG", "i"], 5))
j = int(_cfg_get(cfg, ["TruncatedVGG", "j"], 4))
self.perceptual_model = TruncatedVGG19(i=i, j=j, weights=not testing)
elif self.perc_metric == "lpips":
import lpips
self.perceptual_model = lpips.LPIPS(net="alex")
else:
raise ValueError(f"Unsupported perceptual metric: {self.perc_metric}")
# Set to eval and freeze params
for p in self.perceptual_model.parameters():
p.requires_grad = False
self.perceptual_model.eval()
else: # Set to None when no perc. wanted
self.perceptual_model = None
self.perc_metric = None
# Shared normalizer for computing evaluation metrics
self.normalizer = Normalizer(cfg)
# ---------- public API ----------
def return_loss(
self, sr: torch.Tensor, hr: torch.Tensor
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Compute the weighted content loss and return raw component metrics.
Builds the autograd graph for terms with non-zero weights and returns the
scalar total along with a dict of unweighted component values.
Args:
sr (torch.Tensor): Super-resolved prediction, shape ``(B, C, H, W)``.
hr (torch.Tensor): High-resolution target, shape ``(B, C, H, W)``.
Returns:
Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
- Total loss tensor (requires grad).
- Dict with component tensors: ``{"l1","sam","perceptual","tv", "psnr","ssim"}``
(component values detached except the ones used in the graph).
"""
comps = self._compute_components(sr, hr, build_graph=True)
loss = (
self.l1_w * comps["l1"]
+ self.sam_w * comps["sam"]
+ self.perc_w * comps["perceptual"]
+ self.tv_w * comps["tv"]
)
loss = _ensure_finite(loss)
metrics = {k: v.detach() for k, v in comps.items()}
return loss, metrics
@torch.no_grad()
def return_metrics(
self, sr: torch.Tensor, hr: torch.Tensor, prefix: str = ""
) -> dict[str, torch.Tensor]:
"""
Compute all unweighted metric components and (optionally) prefix their keys.
Args:
sr, hr: tensors in the same range as the generator output/HR targets.
prefix: key prefix like 'train/' or 'val'. If non-empty and doesn't end
with '/', a '/' is added automatically.
Returns:
dict mapping metric names -> tensors (detached), e.g. {'train/l1': ...}.
Includes raw PSNR/SSIM metrics computed on stretched/clipped inputs.
"""
comps = self._compute_components(sr, hr, build_graph=False)
p = (prefix + "/") if prefix and not prefix.endswith("/") else prefix
return {f"{p}{k}": v.detach() for k, v in comps.items()}
# ---------- internals ----------
@staticmethod
def _tv_loss(x: torch.Tensor) -> torch.Tensor:
"""Total variation (TV) regularizer.
Computes the L1 norm of first-order finite differences along height and width.
Args:
x (torch.Tensor): Input tensor, shape ``(B, C, H, W)``.
Returns:
torch.Tensor: Scalar TV loss (mean over batch/channels/pixels).
"""
dh = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()
dw = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
return _ensure_finite(dh + dw)
@staticmethod
def _sam_loss(
sr: torch.Tensor, hr: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
"""Spectral Angle Mapper (SAM) in radians.
Flattens spatial dims and computes the mean angle between spectral vectors
of SR and HR across all pixels.
Args:
sr (torch.Tensor): Super-resolved tensor, shape ``(B, C, H, W)``.
hr (torch.Tensor): Target tensor, shape ``(B, C, H, W)``.
eps (float, optional): Numerical stability epsilon. Defaults to ``1e-8``.
Returns:
torch.Tensor: Scalar mean SAM (radians).
"""
B, C, H, W = sr.shape
sr_f = sr.view(B, C, -1)
hr_f = hr.view(B, C, -1)
dot = (sr_f * hr_f).sum(dim=1)
if torch.is_floating_point(sr):
dtype_eps = torch.finfo(sr.dtype).eps
else:
dtype_eps = torch.finfo(torch.float32).eps
eps = max(eps, dtype_eps)
sr_n = sr_f.norm(dim=1).clamp_min(eps)
hr_n = hr_f.norm(dim=1).clamp_min(eps)
denom = torch.clamp(sr_n * hr_n, min=eps)
cos = torch.clamp(dot / denom, -1 + 1e-7, 1 - 1e-7)
ang = torch.acos(cos)
return _ensure_finite(ang.mean())
def _prepare_perceptual_input(
self, sr: torch.Tensor, hr: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare three-channel SR/HR slices for perceptual computation.
Inputs with exactly 3 channels are returned unchanged. Single-channel
inputs are repeated to 3 channels, 2-channel inputs have their first
channel repeated, and inputs with more than 3 channels have 3 unique
band indices sampled at random.
Args:
sr (torch.Tensor): SR tensor, shape ``(B, C, H, W)``.
hr (torch.Tensor): HR tensor, shape ``(B, C, H, W)``.
Returns:
tuple[torch.Tensor, torch.Tensor]: SR and HR tensors with three channels, shape ``(B, 3, H, W)``.
"""
B, C, H, W = sr.shape
if C == 1:
# repeat single channel 3 times
sr = sr.repeat(1, 3, 1, 1)
hr = hr.repeat(1, 3, 1, 1)
return sr, hr
elif C == 2:
# repeat first channel to make 3
sr = torch.cat([sr, sr[:, :1, :, :]], dim=1)
hr = torch.cat([hr, hr[:, :1, :, :]], dim=1)
return sr, hr
elif C == 3:
# already 3 channels, return as is
return sr, hr
else:
# when over 3 channels, randomly select 3 unique indices
idx = torch.randperm(C, device=sr.device)[:3]
sr = sr[:, idx, :, :]
hr = hr[:, idx, :, :]
return sr, hr
def _perceptual_distance(
self, sr_3: torch.Tensor, hr_3: torch.Tensor, *, build_graph: bool
) -> torch.Tensor:
"""Compute perceptual distance between SR and HR (3-channel inputs).
Uses the configured backend:
- ``"vgg"``: MSE between intermediate VGG features.
- ``"lpips"``: Learned LPIPS distance (expects inputs in [-1, 1]).
The computation detaches HR features and optionally detaches SR path if
``build_graph`` is False or the perceptual weight is zero.
Args:
sr_3 (torch.Tensor): SR slice with 3 channels, shape ``(B, 3, H, W)``, values in [0, 1].
hr_3 (torch.Tensor): HR slice with 3 channels, shape ``(B, 3, H, W)``, values in [0, 1].
build_graph (bool): Whether to keep gradients for SR.
Returns:
torch.Tensor: Scalar perceptual distance (mean over batch/spatial).
"""
requires_grad = build_graph and self.perc_w != 0.0
if self.perc_metric == "vgg":
if requires_grad:
sr_features = self.perceptual_model(sr_3)
else:
with torch.no_grad():
sr_features = self.perceptual_model(sr_3)
with torch.no_grad():
hr_features = self.perceptual_model(hr_3)
distance = F.mse_loss(sr_features, hr_features)
elif self.perc_metric == "lpips":
sr_norm = sr_3.mul(2.0).sub(1.0)
hr_norm = hr_3.mul(2.0).sub(1.0).detach()
if requires_grad:
distance = self.perceptual_model(sr_norm, hr_norm)
else:
with torch.no_grad():
distance = self.perceptual_model(sr_norm, hr_norm)
distance = distance.mean()
else:
raise RuntimeError(f"Unhandled perceptual metric: {self.perc_metric}")
if not requires_grad:
distance = distance.detach()
return _ensure_finite(distance)
def _compute_components(
self, sr: torch.Tensor, hr: torch.Tensor, *, build_graph: bool
) -> dict[str, torch.Tensor]:
"""Compute individual content components and auxiliary quality metrics.
Produces a dictionary with: L1, SAM, Perceptual, TV (optionally with grads),
and PSNR/SSIM (always without grads). Per-component autograd is enabled only
if ``build_graph`` is True and the corresponding weight is non-zero.
Args:
sr (torch.Tensor): Super-resolved prediction, shape ``(B, C, H, W)``.
hr (torch.Tensor): High-resolution target, shape ``(B, C, H, W)``.
build_graph (bool): Whether to allow gradients for weighted components.
Returns:
Dict[str, torch.Tensor]: Keys ``{"l1","sam","perceptual","tv","psnr","ssim"}``.
Component tensors are scalar means; PSNR/SSIM are detached.
"""
comps: dict[str, torch.Tensor] = {}
def _compute(weight: float, fn) -> torch.Tensor:
requires_grad = build_graph and weight != 0.0
if requires_grad:
return fn()
with torch.no_grad():
return fn().detach()
# Core reconstruction metrics (always unweighted)
comps["l1"] = _compute(self.l1_w, lambda: _ensure_finite(F.l1_loss(sr, hr)))
comps["sam"] = _compute(self.sam_w, lambda: self._sam_loss(sr, hr))
# Perceptual distance on 3 selected bands
if self.perceptual_model != None and self.perc_w != 0.0:
sr_3, hr_3 = self._prepare_perceptual_input(sr=sr, hr=hr)
comps["perceptual"] = self._perceptual_distance(
sr_3, hr_3, build_graph=build_graph
)
else:
comps["perceptual"] = torch.tensor(0.0, device=sr.device)
# Total variation
comps["tv"] = _compute(self.tv_w, lambda: self._tv_loss(sr))
# --- Quality metrics ---
with torch.no_grad():
safe_max_val = max(self.max_val, LOSS_EPS)
sr_metric = torch.clamp(sr, 0.0, safe_max_val)
hr_metric = torch.clamp(hr, 0.0, safe_max_val)
psnr = km.psnr(sr_metric, hr_metric, max_val=safe_max_val)
ssim = km.ssim(
sr_metric,
hr_metric,
window_size=self.ssim_win,
max_val=safe_max_val,
)
if psnr.dim() > 0:
psnr = psnr.mean()
if ssim.dim() > 0:
ssim = ssim.mean()
comps["psnr"] = _ensure_finite(psnr).detach()
comps["ssim"] = _ensure_finite(ssim).detach()
return comps
return_loss(sr, hr)
¶
Compute the weighted content loss and return raw component metrics.
Builds the autograd graph for terms with non-zero weights and returns the scalar total along with a dict of unweighted component values.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sr` | `Tensor` | Super-resolved prediction, shape `(B, C, H, W)`. | *required* |
| `hr` | `Tensor` | High-resolution target, shape `(B, C, H, W)`. | *required* |

Returns:

| Type | Description |
|---|---|
| `tuple[Tensor, dict[str, Tensor]]` | Total loss tensor (requires grad) and a dict with component tensors `{"l1","sam","perceptual","tv","psnr","ssim"}` (component values detached except the ones used in the graph). |
Source code in opensr_srgan/model/loss/loss.py
def return_loss(
self, sr: torch.Tensor, hr: torch.Tensor
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Compute the weighted content loss and return raw component metrics.
Builds the autograd graph for terms with non-zero weights and returns the
scalar total along with a dict of unweighted component values.
Args:
sr (torch.Tensor): Super-resolved prediction, shape ``(B, C, H, W)``.
hr (torch.Tensor): High-resolution target, shape ``(B, C, H, W)``.
Returns:
Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
- Total loss tensor (requires grad).
- Dict with component tensors: ``{"l1","sam","perceptual","tv", "psnr","ssim"}``
(component values detached except the ones used in the graph).
"""
comps = self._compute_components(sr, hr, build_graph=True)
loss = (
self.l1_w * comps["l1"]
+ self.sam_w * comps["sam"]
+ self.perc_w * comps["perceptual"]
+ self.tv_w * comps["tv"]
)
loss = _ensure_finite(loss)
metrics = {k: v.detach() for k, v in comps.items()}
return loss, metrics
return_metrics(sr, hr, prefix='')
¶
Compute all unweighted metric components and (optionally) prefix their keys.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sr`, `hr` | `Tensor` | Tensors in the same range as the generator output/HR targets. | *required* |
| `prefix` | `str` | Key prefix like `'train/'` or `'val'`. If non-empty and doesn't end with `'/'`, a `'/'` is added automatically. | `''` |

Returns:

| Type | Description |
|---|---|
| `dict[str, Tensor]` | Dict mapping metric names -> tensors (detached), e.g. `{'train/l1': ...}`. Includes raw PSNR/SSIM metrics computed on stretched/clipped inputs. |
Source code in opensr_srgan/model/loss/loss.py
@torch.no_grad()
def return_metrics(
self, sr: torch.Tensor, hr: torch.Tensor, prefix: str = ""
) -> dict[str, torch.Tensor]:
"""
Compute all unweighted metric components and (optionally) prefix their keys.
Args:
sr, hr: tensors in the same range as the generator output/HR targets.
prefix: key prefix like 'train/' or 'val'. If non-empty and doesn't end
with '/', a '/' is added automatically.
Returns:
dict mapping metric names -> tensors (detached), e.g. {'train/l1': ...}.
Includes raw PSNR/SSIM metrics computed on stretched/clipped inputs.
"""
comps = self._compute_components(sr, hr, build_graph=False)
p = (prefix + "/") if prefix and not prefix.endswith("/") else prefix
return {f"{p}{k}": v.detach() for k, v in comps.items()}
VGG-based perceptual feature extractor utilities.
TruncatedVGG19
¶
Bases: Module
A truncated VGG-19 feature extractor for perceptual loss computation.
This class wraps a pretrained VGG-19 network from torchvision.models
and truncates it at a specific convolutional layer (i, j) within
the feature hierarchy, following the convention in perceptual and
style-transfer literature (e.g., layer relu{i}_{j}).
The truncated model outputs intermediate feature maps that capture perceptual similarity more effectively than raw pixel losses. These feature activations are typically used in content or perceptual loss terms such as:
L_perceptual = || Φ_j(x_sr) − Φ_j(x_hr) ||_1
where Φ_j denotes the truncated VGG feature extractor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `i` | `int` | The convolutional block index (1-5) at which to truncate. Each block corresponds to a region between max-pooling layers. | `5` |
| `j` | `int` | The convolution layer index within the chosen block. | `4` |
| `weights` | `bool` | Whether to load pretrained ImageNet weights. If `False`, the model is initialized without downloading weights (useful for testing environments). | `True` |

Attributes:

| Name | Type | Description |
|---|---|---|
| `truncated_vgg19` | `Sequential` | Sequential container of layers up to the specified truncation point. |

Raises:

| Type | Description |
|---|---|
| `AssertionError` | If the provided `(i, j)` combination does not match a valid convolutional layer in VGG-19. |

Example

>>> vgg = TruncatedVGG19(i=5, j=4)
>>> feats = vgg(img_batch)  # [B, C, H, W] feature map
Source code in opensr_srgan/model/loss/vgg.py
class TruncatedVGG19(nn.Module):
"""A truncated VGG-19 feature extractor for perceptual loss computation.
This class wraps a pretrained VGG-19 network from ``torchvision.models``
and truncates it at a specific convolutional layer ``(i, j)`` within
the feature hierarchy, following the convention in perceptual and
style-transfer literature (e.g., layer *relu{i}_{j}*).
The truncated model outputs intermediate feature maps that capture
perceptual similarity more effectively than raw pixel losses. These
feature activations are typically used in content or perceptual loss
terms such as:
```
L_perceptual = || Φ_j(x_sr) − Φ_j(x_hr) ||_1
```
where ``Φ_j`` denotes the truncated VGG feature extractor.
Args:
i (int, optional): The convolutional block index (1-5) at which to
truncate. Each block corresponds to a region between max-pooling
layers. Defaults to ``5``.
j (int, optional): The convolution layer index within the chosen block.
Defaults to ``4``.
weights (bool, optional): Whether to load pretrained ImageNet weights.
If ``False``, the model is initialized without downloading weights
(useful for testing environments). Defaults to ``True``.
Attributes:
truncated_vgg19 (nn.Sequential): Sequential container of layers up to
the specified truncation point.
Raises:
AssertionError: If the provided ``(i, j)`` combination does not match a
valid convolutional layer in VGG-19.
Example:
>>> vgg = TruncatedVGG19(i=5, j=4)
>>> feats = vgg(img_batch) # [B, C, H, W] feature map
"""
def __init__(self, i: int = 5, j: int = 4, weights=True) -> None:
super().__init__()
if weights: # omit downloading for tests
vgg19 = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT)
else:
vgg19 = torchvision.models.vgg19(weights=None)
maxpool_counter = 0
conv_counter = 0
truncate_at = 0
for layer in vgg19.features.children():
truncate_at += 1
if isinstance(layer, nn.Conv2d):
conv_counter += 1
if isinstance(layer, nn.MaxPool2d):
maxpool_counter += 1
conv_counter = 0
if maxpool_counter == i - 1 and conv_counter == j:
break
if not (maxpool_counter == i - 1 and conv_counter == j):
raise AssertionError(
f"One or both of i={i} and j={j} are not valid choices for the VGG19!"
)
self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[: truncate_at + 1])
def forward(self, input): # type: ignore[override]
"""Compute VGG-19 features up to the configured truncation layer.
Args:
input (torch.Tensor): Input tensor of shape ``(B, 3, H, W)``
with values normalized to ImageNet statistics (mean/std).
Returns:
torch.Tensor: The feature map extracted from the specified
intermediate layer of VGG-19.
"""
return self.truncated_vgg19(input)
forward(input)
¶
Compute VGG-19 features up to the configured truncation layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input` | `Tensor` | Input tensor of shape `(B, 3, H, W)` with values normalized to ImageNet statistics (mean/std). | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | The feature map extracted from the specified intermediate layer of VGG-19. |
Source code in opensr_srgan/model/loss/vgg.py
def forward(self, input): # type: ignore[override]
"""Compute VGG-19 features up to the configured truncation layer.
Args:
input (torch.Tensor): Input tensor of shape ``(B, 3, H, W)``
with values normalized to ImageNet statistics (mean/std).
Returns:
torch.Tensor: The feature map extracted from the specified
intermediate layer of VGG-19.
"""
return self.truncated_vgg19(input)
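The forward pass above expects 3-channel inputs normalized to ImageNet statistics; a short sketch of that preprocessing is given below. The mean/std values are the standard torchvision ImageNet constants and the import path is inferred from the "Source code in" note, so both are assumptions rather than something defined on this page.

```python
import torch
from torchvision import transforms
from opensr_srgan.model.loss.vgg import TruncatedVGG19  # path inferred, see note above

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

vgg = TruncatedVGG19(i=5, j=4, weights=False)   # weights=False skips the download, as in tests
imgs = torch.rand(2, 3, 96, 96)                 # values in [0, 1]
feats = vgg(normalize(imgs))                    # intermediate VGG-19 feature map
print(feats.shape)
```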
Training step helpers¶
training_step_PL1(self, batch, batch_idx, optimizer_idx)
¶
One training step for PL < 2.0 using automatic optimization and multi-optimizers.
Implements GAN training with two optimizers (D first, then G) and a pretraining gate. During the pretraining phase, only the generator (optimizer_idx == 1) is optimized with content loss; the discriminator branch returns a dummy loss and logs zeros. During adversarial training, the discriminator minimizes BCE on real HR vs. fake SR logits, and the generator minimizes content loss plus a ramped adversarial loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Tuple[Tensor, Tensor]` | `(lr_imgs, hr_imgs)` with shape `(B, C, H, W)`. | *required* |
| `batch_idx` | `int` | Global batch index for the current epoch. | *required* |
| `optimizer_idx` | `int` | Active optimizer index provided by Lightning: `0` = discriminator step, `1` = generator step. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Pretraining: content loss tensor for the generator (`optimizer_idx == 1`) or a dummy scalar with `requires_grad=True` (`optimizer_idx == 0`). Adversarial training: discriminator BCE loss (real + fake) for `optimizer_idx == 0`, or generator total loss = content + λ_adv · BCE(G) for `optimizer_idx == 1`. |
Logged Metrics (selection):
- "training/pretrain_phase": 1.0 during pretraining (logged on G step).
- "train_metrics/*": content metrics from the content loss criterion.
- "generator/content_loss", "generator/adversarial_loss", "generator/total_loss".
- "discriminator/adversarial_loss", "discriminator/D(y)_prob",
"discriminator/D(G(x))_prob".
- "training/adv_loss_weight": current λ_adv from the ramp scheduler.
Notes
- Discriminator step uses `sr_imgs.detach()` to prevent G gradients.
- Adversarial loss weight λ_adv ramps from 0 → `adv_loss_beta` per configured schedule.
- Assumes optimizers are ordered as `[D, G]` in `configure_optimizers()`.
Source code in opensr_srgan/model/training_step_PL.py
def training_step_PL1(self, batch, batch_idx, optimizer_idx):
"""One training step for PL < 2.0 using automatic optimization and multi-optimizers.
Implements GAN training with two optimizers (D first, then G) and a
pretraining gate. During the **pretraining phase**, only the generator
(optimizer_idx == 1) is optimized with content loss; the discriminator
branch returns a dummy loss and logs zeros. During **adversarial training**,
the discriminator minimizes BCE on real HR vs. fake SR logits, and the
generator minimizes content loss plus a ramped adversarial loss.
Args:
batch (Tuple[torch.Tensor, torch.Tensor]): `(lr_imgs, hr_imgs)` with shape `(B, C, H, W)`.
batch_idx (int): Global batch index for the current epoch.
optimizer_idx (int): Active optimizer index provided by Lightning:
- `0`: Discriminator step.
- `1`: Generator step.
Returns:
torch.Tensor:
- **Pretraining**:
- `optimizer_idx == 1`: content loss tensor for the generator.
- `optimizer_idx == 0`: dummy scalar tensor with `requires_grad=True`.
- **Adversarial training**:
- `optimizer_idx == 0`: discriminator BCE loss (real + fake).
- `optimizer_idx == 1`: generator total loss = content + λ_adv · BCE(G).
Logged Metrics (selection):
- `"training/pretrain_phase"`: 1.0 during pretraining (logged on G step).
- `"train_metrics/*"`: content metrics from the content loss criterion.
- `"generator/content_loss"`, `"generator/adversarial_loss"`, `"generator/total_loss"`.
- `"discriminator/adversarial_loss"`, `"discriminator/D(y)_prob"`,
`"discriminator/D(G(x))_prob"`.
- `"training/adv_loss_weight"`: current λ_adv from the ramp scheduler.
Notes:
- Discriminator step uses `sr_imgs.detach()` to prevent G gradients.
- Adversarial loss weight λ_adv ramps from 0 → `adv_loss_beta` per configured schedule.
- Assumes optimizers are ordered as `[D, G]` in `configure_optimizers()`.
"""
# -------- CREATE SR DATA --------
lr_imgs, hr_imgs = batch # unpack LR/HR tensors from dataloader batch
sr_imgs = self.forward(
lr_imgs
) # forward pass of the generator to produce SR from LR
# Default to standard GAN loss if adv_loss_type is not defined (e.g., lightweight
# harnesses in tests). The real model sets this attribute during init.
use_wasserstein = getattr(self, "adv_loss_type", "gan") == "wasserstein"
# ======================================================================
# SECTION: Pretraining phase gate
# Purpose: decide if we are in the content-only pretrain stage.
# ======================================================================
# -------- DETERMINE PRETRAINING --------
pretrain_phase = (
self._pretrain_check()
) # check schedule: True => content-only pretraining
if optimizer_idx == 1: # log whether pretraining is active or not
self.log(
"training/pretrain_phase",
float(pretrain_phase),
prog_bar=False,
sync_dist=True,
) # log once per G step to track phase state
# ======================================================================
# SECTION: Pretraining branch (delegated)
# Purpose: during pretrain, only content loss for G and dummy logging for D.
# ======================================================================
# -------- IF PRETRAIN: delegate --------
if pretrain_phase:
# run pretrain step separately and return loss here
if optimizer_idx == 1:
content_loss, metrics = self.content_loss_criterion.return_loss(
sr_imgs, hr_imgs
) # compute perceptual/content loss (e.g., VGG or L1)
self._log_generator_content_loss(
content_loss
) # log content loss for G (consistent args)
for key, value in metrics.items():
self.log(
f"train_metrics/{key}", value, sync_dist=True
) # reuse computed metrics for logging
# Ensure adversarial weight is logged even when not used during pretraining
adv_weight = self._compute_adv_loss_weight()
self._log_adv_loss_weight(adv_weight)
return content_loss # return loss for optimizer step (G only)
# ======================================================================
# SECTION: Discriminator (D) pretraining step
# Purpose: no real training — just log zeros and return dummy loss to satisfy closure.
# ======================================================================
elif optimizer_idx == 0:
device, dtype = (
hr_imgs.device,
hr_imgs.dtype,
) # get tensor device and dtype for consistency
zero = torch.tensor(
0.0, device=device, dtype=dtype
) # define reusable zero tensor
# --- Log dummy discriminator "opinions" (always zero during pretrain) ---
self.log(
"discriminator/D(y)_prob", zero, prog_bar=True, sync_dist=True
) # fake real-prob (always 0)
self.log(
"discriminator/D(G(x))_prob", zero, prog_bar=True, sync_dist=True
) # fake fake-prob (always 0)
# --- Create dummy scalar loss (ensures PL closure runs) ---
dummy = torch.zeros(
(), device=device, dtype=dtype, requires_grad=True
) # dummy value with grad for optimizer compatibility
self.log(
"discriminator/adversarial_loss", dummy, sync_dist=True
) # log dummy adversarial loss (always 0)
return dummy
# -------- END PRETRAIN --------
# ======================================================================
# SECTION: Adversarial training — Discriminator step
# Purpose: update D to distinguish HR (real) vs SR (fake).
# ======================================================================
# -------- Normal Train: Discriminator Step --------
if optimizer_idx == 0:
r1_gamma = getattr(self, "r1_gamma", 0.0) # default to 0 (R1 penalty disabled) if unset
hr_imgs.requires_grad_(r1_gamma > 0) # enable grad for R1 penalty if needed
# run discriminator and get loss between pred labels and true labels
hr_discriminated = self.discriminator(hr_imgs) # D(real): logits for HR images
sr_discriminated = self.discriminator(
sr_imgs.detach()
) # detach so G doesn’t get gradients from D’s step
# Check for WS GAN loss
if use_wasserstein: # Wasserstein GAN loss
loss_real = -hr_discriminated.mean()
loss_fake = sr_discriminated.mean()
else: # Standard GAN loss (BCE)
real_target = torch.full_like(
hr_discriminated, self.adv_target
) # get labels/fuzzy labels
fake_target = torch.zeros_like(
sr_discriminated
) # zeros, since generative prediction
# Binary Cross-Entropy loss
loss_real = self.adversarial_loss_criterion(
hr_discriminated, real_target
) # BCEWithLogitsLoss for D(y) vs. real targets
loss_fake = self.adversarial_loss_criterion(
sr_discriminated, fake_target
) # BCEWithLogitsLoss for D(G(x)) vs. fake targets
# R1 Gradient Penalty (if enabled)
r1_penalty = torch.zeros((), device=hr_imgs.device, dtype=hr_imgs.dtype)
if r1_gamma > 0:
grad_real = torch.autograd.grad(
outputs=hr_discriminated.sum(),
inputs=hr_imgs,
create_graph=True,
retain_graph=True,
)[0]
grad_penalty = grad_real.pow(2).reshape(grad_real.size(0), -1).sum(dim=1)
r1_penalty = 0.5 * r1_gamma * grad_penalty.mean()
# Sum up losses
adversarial_loss = (
loss_real + loss_fake + r1_penalty
) # add 0s for R1 if disabled
self.log(
"discriminator/adversarial_loss", adversarial_loss, sync_dist=True
) # log weighted loss
self.log(
"discriminator/r1_penalty", r1_penalty.detach(), sync_dist=True
) # log R1 penalty regardless; it is 0 when turned off
# [LOG-B] Always log D opinions: real probs in normal training
with torch.no_grad():
d_real_prob = torch.sigmoid(
hr_discriminated
).mean() # estimate mean real probability
d_fake_prob = torch.sigmoid(
sr_discriminated
).mean() # estimate mean fake probability
self.log(
"discriminator/D(y)_prob", d_real_prob, prog_bar=True, sync_dist=True
) # log D(real) confidence
self.log(
"discriminator/D(G(x))_prob", d_fake_prob, prog_bar=True, sync_dist=True
) # log D(fake) confidence
# return weighted discriminator loss
return adversarial_loss # PL will use this to step the D optimizer
# ======================================================================
# SECTION: Adversarial training — Generator step
# Purpose: update G to minimize content loss + (weighted) adversarial loss.
# ======================================================================
# -------- Normal Train: Generator Step --------
if optimizer_idx == 1:
"""1. Get VGG space loss"""
# encode images
content_loss, metrics = self.content_loss_criterion.return_loss(
sr_imgs, hr_imgs
) # perceptual/content criterion (e.g., VGG)
self._log_generator_content_loss(
content_loss
) # log content loss for G (consistent args)
for key, value in metrics.items():
self.log(
f"train_metrics/{key}", value, sync_dist=True
) # log detailed metrics without extra forward passes
""" 2. Get Discriminator Opinion and loss """
# run discriminator and get loss between pred labels and true labels
sr_discriminated = self.discriminator(
sr_imgs
) # D(SR): logits for generator outputs
if use_wasserstein: # Wasserstein GAN loss
adversarial_loss = -sr_discriminated.mean()
else: # Standard GAN loss (BCE)
adversarial_loss = self.adversarial_loss_criterion(
sr_discriminated, torch.ones_like(sr_discriminated)
) # keep targets at 1.0 for G loss
self.log(
"generator/adversarial_loss", adversarial_loss, sync_dist=True
) # log unweighted adversarial loss
""" 3. Weight the losses"""
adv_weight = (
self._adv_loss_weight()
) # get adversarial weight based on current step
adversarial_loss_weighted = (
adversarial_loss * adv_weight
) # weight adversarial loss
total_loss = content_loss + adversarial_loss_weighted # total content loss
self.log(
"generator/total_loss", total_loss, sync_dist=True
) # log combined objective (content + λ_adv * adv)
# return Generator loss
return total_loss
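The ramp referenced in the notes above (`_adv_loss_weight` / `_compute_adv_loss_weight`) is not shown on this page; the sketch below only illustrates the documented behaviour, with λ_adv rising from 0 to `adv_loss_beta` over `adv_loss_ramp_steps` on a linear or cosine schedule, and is not the library's implementation.

```python
import math

def adv_loss_weight(step: int, ramp_steps: int, beta: float, schedule: str = "linear") -> float:
    """Illustrative ramp of the adversarial weight from 0 to `beta` over `ramp_steps` steps."""
    if ramp_steps <= 0:
        return beta
    progress = min(max(step / ramp_steps, 0.0), 1.0)
    if schedule == "cosine":
        progress = 0.5 * (1.0 - math.cos(math.pi * progress))  # smooth start and end
    return beta * progress

print(adv_loss_weight(step=500, ramp_steps=1000, beta=1e-3, schedule="cosine"))  # halfway up the ramp
```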
training_step_PL2(self, batch, batch_idx)
¶
Manual-optimization training step for PyTorch Lightning ≥ 2.0.
Mirrors the PL1.x logic with explicit optimizer control:
- Pretraining phase: the Discriminator only logs dummies; the Generator is optimized with content loss alone (no adversarial term), and EMA optionally updates.
- Adversarial phase: performs a Discriminator step (real vs. fake BCE), followed by a Generator step (content + λ_adv · BCE against ones). Uses the same log keys and ordering as the PL1.x path.
Assumptions
- `self.automatic_optimization` is `False` (manual optimization).
- `configure_optimizers()` returns optimizers in the order `[opt_d, opt_g]`.
- EMA updates occur after `self._ema_update_after_step`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Tuple[Tensor, Tensor]` | `(lr_imgs, hr_imgs)` tensors with shape `(B, C, H, W)`. | *required* |
| `batch_idx` | `int` | Index of the current batch. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Pretraining: content loss (Generator-only step). Adversarial: total generator loss = content + λ_adv · BCE(G). |
Logged metrics (selection):
- "training/pretrain_phase" (0/1)
- "train_metrics/*" (from content criterion)
- "generator/content_loss", "generator/adversarial_loss", "generator/total_loss"
- "discriminator/adversarial_loss", "discriminator/D(y)_prob", "discriminator/D(G(x))_prob"
- "training/adv_loss_weight" (λ_adv from ramp schedule)
Raises:

| Type | Description |
|---|---|
| `AssertionError` | If PL version < 2.0 or `automatic_optimization` is True. |
Source code in opensr_srgan/model/training_step_PL.py
def training_step_PL2(self, batch, batch_idx):
"""Manual-optimization training step for PyTorch Lightning ≥ 2.0.
Mirrors the PL1.x logic with explicit optimizer control:
- **Pretraining phase**: Discriminator logs dummies; Generator is optimized with
content loss only (no adversarial term), and EMA optionally updates.
- **Adversarial phase**: Performs a Discriminator step (real vs. fake BCE),
followed by a Generator step (content + λ_adv · BCE against ones). Uses the
same log keys and ordering as the PL1.x path.
Assumptions:
- `self.automatic_optimization` is `False` (manual opt).
- `configure_optimizers()` returns optimizers in order `[opt_d, opt_g]`.
- EMA updates occur after `self._ema_update_after_step`.
Args:
batch (Tuple[torch.Tensor, torch.Tensor]): `(lr_imgs, hr_imgs)` tensors with shape `(B, C, H, W)`.
batch_idx (int): Index of the current batch.
Returns:
torch.Tensor:
- **Pretraining**: content loss (Generator-only step).
- **Adversarial**: total generator loss = content + λ_adv · BCE(G).
Logged metrics (selection):
- `"training/pretrain_phase"` (0/1)
- `"train_metrics/*"` (from content criterion)
- `"generator/content_loss"`, `"generator/adversarial_loss"`, `"generator/total_loss"`
- `"discriminator/adversarial_loss"`, `"discriminator/D(y)_prob"`, `"discriminator/D(G(x))_prob"`
- `"training/adv_loss_weight"` (λ_adv from ramp schedule)
Raises:
AssertionError: If PL version < 2.0 or `automatic_optimization` is True.
"""
assert self.pl_version >= (
2,
0,
0,
), "training_step_PL2 requires PyTorch Lightning >= 2.x."
assert (
self.automatic_optimization is False
), "training_step_PL2 requires manual_optimization."
# -------- CREATE SR DATA --------
lr_imgs, hr_imgs = batch
sr_imgs = self.forward(lr_imgs)
use_wasserstein = getattr(self, "adv_loss_type", "gan") == "wasserstein"
# --- helper to resolve adv-weight function name mismatches ---
def _adv_weight():
if hasattr(self, "_adv_loss_weight"):
return self._adv_loss_weight()
return self._compute_adv_loss_weight()
# fetch optimizers (expects two)
opt_d, opt_g = self.optimizers()
# optional gradient clipping support (norm-based)
try:
gradient_clip_val = self.config.Schedulers.gradient_clip_val
except AttributeError:
gradient_clip_val = 0.0
def _maybe_clip_gradients(module, optimizer=None):
if gradient_clip_val > 0.0 and module is not None:
precision_plugin = getattr(self.trainer, "precision_plugin", None)
if (
optimizer is not None
and precision_plugin is not None
and hasattr(precision_plugin, "unscale_optimizer")
):
precision_plugin.unscale_optimizer(optimizer)
torch.nn.utils.clip_grad_norm_(module.parameters(), gradient_clip_val)
# ======================================================================
# SECTION: Pretraining phase gate
# ======================================================================
pretrain_phase = self._pretrain_check()
# in PL1.x you logged this only on G-step; here we log once per batch
self.log(
"training/pretrain_phase", float(pretrain_phase), prog_bar=False, sync_dist=True
)
# ======================================================================
# SECTION: Pretraining branch (content-only on G; D logs dummies)
# ======================================================================
if pretrain_phase:
# --- D dummy logs (no step) to mimic your optimizer_idx==0 branch ---
with torch.no_grad():
zero = torch.tensor(0.0, device=hr_imgs.device, dtype=hr_imgs.dtype)
self.log("discriminator/D(y)_prob", zero, prog_bar=True, sync_dist=True)
self.log("discriminator/D(G(x))_prob", zero, prog_bar=True, sync_dist=True)
self.log("discriminator/adversarial_loss", zero, sync_dist=True)
# --- G step: content loss only (identical to your optimizer_idx==1 pretrain) ---
content_loss, metrics = self.content_loss_criterion.return_loss(
sr_imgs, hr_imgs
)
self._log_generator_content_loss(content_loss)
for key, value in metrics.items():
self.log(f"train_metrics/{key}", value, sync_dist=True)
# ensure adv-weight is still logged like in pretrain
self._log_adv_loss_weight(_adv_weight())
# manual optimize G
if hasattr(self, "toggle_optimizer"):
self.toggle_optimizer(opt_g)
opt_g.zero_grad()
self.manual_backward(content_loss)
_maybe_clip_gradients(self.generator, opt_g)
opt_g.step()
if hasattr(self, "untoggle_optimizer"):
self.untoggle_optimizer(opt_g)
# EMA in PL2 manual mode
if self.ema is not None and self.global_step >= self._ema_update_after_step:
self.ema.update(self.generator)
# return same scalar you’d have returned in PL1.x (content loss)
return content_loss
# ======================================================================
# SECTION: Adversarial training — Discriminator step
# ======================================================================
if hasattr(self, "toggle_optimizer"):
self.toggle_optimizer(opt_d)
opt_d.zero_grad()
# Get R1 Gamma
r1_gamma = getattr(self, "r1_gamma", 0.0)
hr_imgs.requires_grad_(r1_gamma > 0) # enable grad for R1 penalty if needed
hr_discriminated = self.discriminator(hr_imgs) # D(y)
sr_discriminated = self.discriminator(sr_imgs.detach()) # D(G(x)) w/o grad to G
if use_wasserstein: # Wasserstein GAN loss
loss_real = -hr_discriminated.mean()
loss_fake = sr_discriminated.mean()
else: # Standard GAN loss (BCE)
real_target = torch.full_like(hr_discriminated, self.adv_target)
fake_target = torch.zeros_like(sr_discriminated)
loss_real = self.adversarial_loss_criterion(hr_discriminated, real_target)
loss_fake = self.adversarial_loss_criterion(sr_discriminated, fake_target)
# R1 Gradient Penalty
r1_penalty = torch.zeros((), device=hr_imgs.device, dtype=hr_imgs.dtype)
if r1_gamma > 0:
grad_real = torch.autograd.grad(
outputs=hr_discriminated.sum(),
inputs=hr_imgs,
create_graph=True,
retain_graph=True,
)[0]
grad_penalty = grad_real.pow(2).reshape(grad_real.size(0), -1).sum(dim=1)
r1_penalty = 0.5 * r1_gamma * grad_penalty.mean()
adversarial_loss = (
loss_real + loss_fake + r1_penalty
) # sum up loss with R1 (0 when turned off)
self.log("discriminator/adversarial_loss", adversarial_loss, sync_dist=True)
self.log(
"discriminator/r1_penalty", r1_penalty.detach(), sync_dist=True
) # log R1 penalty regardless, is 0 when turned off
with torch.no_grad():
d_real_prob = torch.sigmoid(hr_discriminated).mean()
d_fake_prob = torch.sigmoid(sr_discriminated).mean()
self.log("discriminator/D(y)_prob", d_real_prob, prog_bar=True, sync_dist=True)
self.log("discriminator/D(G(x))_prob", d_fake_prob, prog_bar=True, sync_dist=True)
self.manual_backward(adversarial_loss)
_maybe_clip_gradients(self.discriminator, opt_d)
opt_d.step()
if hasattr(self, "untoggle_optimizer"):
self.untoggle_optimizer(opt_d)
# ======================================================================
# SECTION: Adversarial training — Generator step
# ======================================================================
if hasattr(self, "toggle_optimizer"):
self.toggle_optimizer(opt_g)
opt_g.zero_grad()
# 1) content loss (identical to original)
content_loss, metrics = self.content_loss_criterion.return_loss(sr_imgs, hr_imgs)
self._log_generator_content_loss(content_loss)
for key, value in metrics.items():
self.log(f"train_metrics/{key}", value, sync_dist=True)
# 2) adversarial loss against ones
sr_discriminated_for_g = self.discriminator(sr_imgs)
if use_wasserstein: # Wasserstein GAN loss
g_adv = -sr_discriminated_for_g.mean()
else: # Standard GAN loss (BCE)
g_adv = self.adversarial_loss_criterion(
sr_discriminated_for_g, torch.ones_like(sr_discriminated_for_g)
)
self.log("generator/adversarial_loss", g_adv, sync_dist=True)
# 3) weighted total
adv_weight = _adv_weight()
total_loss = content_loss + (g_adv * adv_weight)
self.log("generator/total_loss", total_loss, sync_dist=True)
self.manual_backward(total_loss)
_maybe_clip_gradients(self.generator, opt_g)
opt_g.step()
if hasattr(self, "untoggle_optimizer"):
self.untoggle_optimizer(opt_g)
# EMA in PL2 manual mode
if self.ema is not None and self.global_step >= self._ema_update_after_step:
self.ema.update(self.generator)
# return same scalar you return in PL1.x G path
return total_loss
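Finally, a stripped-down skeleton of the wiring that `training_step_PL2` assumes (manual optimization and the `[opt_d, opt_g]` ordering). This is an illustrative stand-in, not the `SRGAN_model` implementation.

```python
import torch
from torch import nn
import pytorch_lightning as pl

class TinyManualGAN(pl.LightningModule):
    """Minimal skeleton showing the optimizer ordering assumed above."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False          # manual optimization, required by training_step_PL2
        self.generator = nn.Conv2d(3, 3, 3, padding=1)
        self.discriminator = nn.Conv2d(3, 1, 3, padding=1)

    def configure_optimizers(self):
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
        return [opt_d, opt_g]                        # order [D, G] matches the unpacking below

    def training_step(self, batch, batch_idx):
        opt_d, opt_g = self.optimizers()             # same unpacking as in training_step_PL2
        # D step: opt_d.zero_grad(); self.manual_backward(d_loss); opt_d.step()
        # G step: opt_g.zero_grad(); self.manual_backward(g_loss); opt_g.step()
        ...
```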