Architecture¶
This document outlines how ESA OpenSR organises its super-resolution GAN, the major components that make up the model, and how each piece interacts during training and inference.
Background¶
OpenSR-SRGAN follows the single-image super-resolution (SISR) formulation in which the generator learns a mapping from a low-resolution observation $x$ to a plausible high-resolution reconstruction $x'$. The generator head widens the receptive field, a configurable trunk of $N$ residual-style blocks extracts features, and an upsampling tail increases spatial resolution. The residual fusion keeps skip connections active so the network focuses on high-frequency corrections rather than relearning the full signal: $$ x' = \mathrm{Upsample}\!\left( \mathrm{Conv}_{\text{tail}}\!\left(\mathrm{Body}(x_{\text{head}})\right) + x_{\text{head}} \right), $$ where $x_{\text{head}}$ is the feature map produced by the head convolution. Because every generator variant (residual, RCAB, RRDB, large-kernel attention, ESRGAN, or stochastic conditional) shares this template, you can swap block implementations without altering the training pipeline or configuration schema.
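The template can be sketched in a few lines of PyTorch. The snippet below is purely illustrative and is not the library's generator code; the class and argument names (`TemplateGenerator`, `n_blocks`, `n_channels`, `scale`) are placeholders chosen for this example.

```python
import torch.nn as nn


class ResBlock(nn.Module):
    """Plain residual block: two 3x3 convolutions with a skip connection."""

    def __init__(self, channels: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1),
        )

    def forward(self, x):
        return x + self.body(x)


class TemplateGenerator(nn.Module):
    """Head -> body -> tail template shared by the generator variants (illustrative)."""

    def __init__(self, in_bands: int = 4, n_channels: int = 64, n_blocks: int = 8, scale: int = 4):
        super().__init__()
        self.head = nn.Conv2d(in_bands, n_channels, 9, padding=4)      # widen receptive field
        self.body = nn.Sequential(*[ResBlock(n_channels) for _ in range(n_blocks)])
        self.tail = nn.Conv2d(n_channels, n_channels, 3, padding=1)    # fuse trunk features
        self.upsample = nn.Sequential(                                  # pixel-shuffle upsampling
            nn.Conv2d(n_channels, n_channels * scale ** 2, 3, padding=1),
            nn.PixelShuffle(scale),
            nn.Conv2d(n_channels, in_bands, 3, padding=1),
        )

    def forward(self, x):
        x_head = self.head(x)
        # residual fusion: the tail refines high-frequency corrections on top of x_head
        return self.upsample(self.tail(self.body(x_head)) + x_head)
```

Swapping `ResBlock` for an RCAB or RRDB implementation changes only the body; the head, tail, and upsampling path stay fixed.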
SRGAN Lightning module¶
`opensr_srgan/model/SRGAN.py` defines `SRGAN_model`, a `pytorch_lightning.LightningModule` that encapsulates the full adversarial workflow. The module is initialised from a YAML configuration file and covers the following responsibilities:
- Configuration ingestion. Uses OmegaConf to load hyperparameters, dataset choices, and logging options. Convenience helpers such as `_pretrain_check()` and `_compute_adv_loss_weight()` translate config values into runtime behaviour.
- Model factory. `get_models()` builds the generator and discriminator at runtime via the generator factory using `Generator.model_type`/`block_type` and `Discriminator.model_type`. Unsupported combinations fail fast with clear error messages.
- Loss construction. `GeneratorContentLoss` (from `opensr_srgan.model.loss`) provides L1, spectral angle mapper (SAM), perceptual, and total-variation terms. Adversarial supervision uses `torch.nn.BCEWithLogitsLoss` with optional label smoothing.
- Optimiser scheduling. `configure_optimizers()` returns paired Adam optimisers (generator + discriminator) with `ReduceLROnPlateau` schedulers that monitor a configurable validation metric.
- Training orchestration. `training_step()` alternates discriminator (`optimizer_idx == 0`) and generator (`optimizer_idx == 1`) updates (see the sketch after this list). During the warm-up period configured by `Training.pretrain_g_only`, discriminator weights are frozen via `on_train_batch_start()` and a dedicated `pretraining_training_step()` computes purely content-driven updates.
- Validation and logging. `validation_step()` computes the same content metrics, logs discriminator diagnostics, and pushes qualitative image panels to Weights & Biases according to `Logging.num_val_images`.
- Inference pipeline. `predict_step()` automatically normalises Sentinel-2 style 0–10000 inputs, runs the generator, histogram-matches the result to the low-resolution source, and denormalises if necessary.
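The alternating update and paired-optimiser pattern can be illustrated with a stripped-down LightningModule. This is a sketch of the general two-optimiser GAN recipe, not the actual `SRGAN_model` implementation; it assumes Lightning's pre-2.0 multi-optimizer API (where `training_step` receives an `optimizer_idx`), and the fixed `1e-3` adversarial weight stands in for the ramped schedule described below.

```python
import torch
import pytorch_lightning as pl


class MiniSRGAN(pl.LightningModule):
    """Minimal sketch of alternating discriminator/generator updates."""

    def __init__(self, generator, discriminator):
        super().__init__()
        self.G, self.D = generator, discriminator
        self.adv_criterion = torch.nn.BCEWithLogitsLoss()

    def training_step(self, batch, batch_idx, optimizer_idx):
        lr_imgs, hr_imgs = batch
        sr_imgs = self.G(lr_imgs)

        if optimizer_idx == 0:  # discriminator step: real vs. fake logits
            real_logits = self.D(hr_imgs)
            fake_logits = self.D(sr_imgs.detach())
            return (self.adv_criterion(real_logits, torch.ones_like(real_logits))
                    + self.adv_criterion(fake_logits, torch.zeros_like(fake_logits)))

        # generator step: content loss plus a (here fixed) adversarial term
        fake_logits = self.D(sr_imgs)
        content = torch.nn.functional.l1_loss(sr_imgs, hr_imgs)
        adv = self.adv_criterion(fake_logits, torch.ones_like(fake_logits))
        return content + 1e-3 * adv

    def configure_optimizers(self):
        opt_d = torch.optim.Adam(self.D.parameters(), lr=1e-4)
        opt_g = torch.optim.Adam(self.G.parameters(), lr=1e-4)
        sched_d = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_d, mode="min")
        sched_g = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_g, mode="min")
        return (
            [opt_d, opt_g],
            [{"scheduler": sched_d, "monitor": "val_loss"},
             {"scheduler": sched_g, "monitor": "val_loss"}],
        )
```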
Key helper methods¶
| Method | Purpose |
|---|---|
| `_pretrain_check()` | Determines whether the generator-only warm-up is active. |
| `_compute_adv_loss_weight()` | Produces the ramped adversarial weight using linear or cosine schedules (sketched below). |
| `_log_generator_content_loss()` and `_log_adv_loss_weight()` | Centralise logging so metrics remain consistent across phases. |
| `on_fit_start()` | Prints informative status messages when training begins. |
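For illustration, the adversarial weight ramp might look like the following. This is a hypothetical re-implementation rather than the library code; the parameter names (`step`, `ramp_steps`, `max_weight`, `schedule`) are assumptions.

```python
import math


def adv_loss_weight(step: int, ramp_steps: int, max_weight: float, schedule: str = "linear") -> float:
    """Ramp the adversarial loss weight from 0 to max_weight over ramp_steps steps."""
    progress = min(max(step / ramp_steps, 0.0), 1.0)
    if schedule == "cosine":
        # cosine ramp: slow start, slow finish
        progress = 0.5 * (1.0 - math.cos(math.pi * progress))
    return max_weight * progress
```

Both schedules keep the adversarial pressure near zero early in training so the generator first learns a faithful content mapping.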
Generator options¶
The generator zoo lives under `opensr_srgan/model/generators/` and can be selected via `Generator.model_type` in the configuration.
- SRResNet (`srresnet.py`). Classic residual blocks with pixel-shuffle upsampling. Ideal for baseline experiments or when a lightweight architecture is required.
- Flexible residual families (`flexible_generator.py`). Parameterised factory that instantiates residual, RCAB, RRDB, or large-kernel attention blocks while reusing the same interface. Channel counts, block depth, kernel sizes, and scaling factor are all read from the YAML file.
- Stochastic GAN generator (`cgan_generator.py`). Extends the flexible generator with conditioning inputs and latent noise, enabling experiments where auxiliary metadata influences the super-resolution output.
- ESRGAN generator (`esrgan.py`). Implements the RRDBNet trunk introduced with ESRGAN, exposing `n_blocks`, `growth_channels`, and `res_scale` so you can dial in deeper receptive fields and sharper textures.
- Advanced variants (`SRGAN_advanced.py`). Provides additional block implementations and compatibility aliases exposed in `__init__.py` for backwards compatibility.
Common traits across generators include configurable input channel counts (`Model.in_bands`), support for upscaling factors from 2× to 8×, and residual scaling to stabilise deeper networks.
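As a hypothetical example of how these options come together, the snippet below builds an OmegaConf configuration using the keys named on this page (`Model.in_bands`, `Generator.model_type`/`block_type`, `Discriminator.model_type`/`n_blocks`). The nesting and the specific values shown are assumptions; consult the YAML files shipped with the repository for the authoritative schema.

```python
from omegaconf import OmegaConf

# Hypothetical excerpt: key nesting mirrors the names used on this page, but the
# exact accepted values live in the YAML files shipped with the repository.
yaml_text = """
Model:
  in_bands: 4            # number of input spectral bands
Generator:
  model_type: flexible   # which generator family to instantiate
  block_type: RRDB       # e.g. residual / RCAB / RRDB / large-kernel attention
Discriminator:
  model_type: patchgan
  n_blocks: 3
"""
config = OmegaConf.create(yaml_text)
print(config.Generator.block_type)  # -> "RRDB"
```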
Discriminator options¶
`opensr_srgan/model/discriminators/` exposes three complementary discriminators:
- Standard SRGAN discriminator (`srgan_discriminator.py`). Deep convolutional stack tailored for multispectral imagery. The number of convolutional blocks is configurable through `Discriminator.n_blocks`.
- PatchGAN discriminator (`patchgan.py`). Operates on local patches, which can improve high-frequency fidelity when training with large images. The depth is controlled by `n_blocks` and defaults to three layers.
- ESRGAN discriminator (`esrgan.py`). Deep VGG-style stack with configurable `base_channels` and `linear_size`; pairs well with RRDB generators when perceptual sharpness is the priority.
All three discriminators use LeakyReLU activations and strided convolutions to progressively downsample the input until a real/fake logit map is produced.
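That downsampling pattern can be sketched with a generic strided-convolution stack. The class below is an illustrative PatchGAN-style discriminator, not the repository's exact implementation; channel counts, kernel sizes, and depth are placeholder choices.

```python
import torch.nn as nn


def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    """Strided convolution + LeakyReLU: halves the spatial resolution."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2, inplace=True),
    )


class PatchDiscriminator(nn.Module):
    """Sketch of a PatchGAN-style discriminator producing a real/fake logit map."""

    def __init__(self, in_bands: int = 4, base_channels: int = 64, n_blocks: int = 3):
        super().__init__()
        layers = [conv_block(in_bands, base_channels)]
        channels = base_channels
        for _ in range(n_blocks - 1):
            layers.append(conv_block(channels, channels * 2))
            channels *= 2
        layers.append(nn.Conv2d(channels, 1, kernel_size=3, padding=1))  # per-patch logits
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)  # shape: (B, 1, H / 2**n_blocks, W / 2**n_blocks)
```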
Loss suite and metrics¶
`opensr_srgan/model/loss` contains the perceptual and pixel-based criteria applied to the generator outputs. The primary entry point is `GeneratorContentLoss`, which supports:
- L1 reconstruction over all spectral bands.
- Spectral Angle Mapper (SAM) to preserve spectral signatures.
- Perceptual similarity via VGG or LPIPS feature spaces, depending on `Training.Losses.perceptual_metric`.
- Total-variation regularisation for smoothing when `tv_weight` is non-zero.
The same module exposes `return_metrics()` so validation can log PSNR/SSIM-style diagnostics without recomputing forward passes.
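A minimal sketch of how such terms can be combined is shown below; it is not the `GeneratorContentLoss` implementation, the weights are placeholders, and the perceptual (VGG/LPIPS) term is omitted for brevity.

```python
import torch
import torch.nn.functional as F


def spectral_angle(sr: torch.Tensor, hr: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Mean spectral angle (radians) between per-pixel spectra of (B, C, H, W) tensors."""
    cos = (sr * hr).sum(dim=1) / (sr.norm(dim=1) * hr.norm(dim=1) + eps)
    return torch.acos(torch.clamp(cos, -1.0, 1.0)).mean()


def total_variation(x: torch.Tensor) -> torch.Tensor:
    """Anisotropic total variation: encourages spatial smoothness."""
    return ((x[..., :, 1:] - x[..., :, :-1]).abs().mean()
            + (x[..., 1:, :] - x[..., :-1, :]).abs().mean())


def content_loss(sr, hr, l1_weight=1.0, sam_weight=0.1, tv_weight=0.0):
    """Illustrative weighted sum of L1, SAM, and optional TV terms."""
    loss = l1_weight * F.l1_loss(sr, hr) + sam_weight * spectral_angle(sr, hr)
    if tv_weight:  # only add TV when the weight is non-zero
        loss = loss + tv_weight * total_variation(sr)
    return loss
```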
Data flow and normalisation¶
The Lightning module expects batches of `(lr_imgs, hr_imgs)` tensors supplied by the LightningDataModule returned from `opensr_srgan/data/dataset_selector.py`. `predict_step()` and the validation hooks rely on two utilities from `opensr_srgan.utils.spectral_helpers`:

- `normalise_10k`: converts Sentinel-2 style reflectance values between `[0, 10000]` and `[0, 1]`.
- `histogram`: matches the SR histogram to the LR reference to minimise domain gaps during inference.
These helpers allow the generator to operate in a normalised space while still reporting outputs in physical units when needed.
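The round trip can be sketched as follows. This is an assumption-laden illustration rather than the actual `predict_step()` code: `match_histograms` from scikit-image stands in for the repository's histogram helper, the `> 1.0` check is a simplified stand-in for the module's input-range detection, and a batch size of one is assumed for brevity.

```python
import torch
from skimage.exposure import match_histograms  # stand-in for the repo's histogram helper


def sr_inference(generator: torch.nn.Module, lr: torch.Tensor) -> torch.Tensor:
    """Sketch of predict-time flow: normalise, super-resolve, histogram-match, denormalise."""
    looks_like_reflectance = lr.max() > 1.0      # crude check for 0-10000 style inputs
    x = lr / 10_000.0 if looks_like_reflectance else lr

    with torch.no_grad():
        sr = generator(x)

    # Match the SR histogram to the LR reference band by band (arrays shaped C, H, W).
    matched = match_histograms(sr[0].cpu().numpy(), x[0].cpu().numpy(), channel_axis=0)
    sr = torch.from_numpy(matched).unsqueeze(0)

    return sr * 10_000.0 if looks_like_reflectance else sr
```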
Putting it together¶
- `opensr_srgan/train.py` loads the YAML configuration and instantiates `SRGAN_model`.
- The model initialises the selected generator/discriminator, prepares losses, and prints a summary via `opensr_srgan.utils.model_descriptions.print_model_summary`.
- During each training batch, the discriminator receives real HR crops and fake SR predictions, while the generator combines content loss and a ramped adversarial term.
- Validation reuses the same modules to compute quantitative metrics and log qualitative examples.
- When exported, `predict_step()` can be called directly or wrapped in a Lightning `Trainer.predict()` loop for large-scale inference.
This modular design keeps the research workflow flexible: swap components with configuration changes, extend the factories with new architectures, or plug in custom losses without touching the training loop itself.
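As a closing illustration, a minimal training-and-inference sketch might look like the following. The constructor argument, the `Trainer` settings, and the datamodule placeholder are assumptions for this example; `opensr_srgan/train.py` remains the authoritative entry point.

```python
import pytorch_lightning as pl
from opensr_srgan.model.SRGAN import SRGAN_model

# Assumption: the LightningModule is built from a YAML config path as described above.
model = SRGAN_model("configs/my_experiment.yaml")

# Placeholder: build the LightningDataModule via opensr_srgan/data/dataset_selector.py.
datamodule = ...

trainer = pl.Trainer(max_epochs=100, accelerator="auto", devices=1)
trainer.fit(model, datamodule=datamodule)

# predict_step() handles normalisation and histogram matching internally.
predictions = trainer.predict(model, datamodule=datamodule)
```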