
Utility Helpers

Reusable helpers for logging, radiometric preprocessing, tensor conversion, and distributed-safe side effects.

Package exports

Utility helpers for the OpenSR SRGAN project.

plot_tensors(lr, sr, hr, title='Train')

Render LR–SR–HR triplets from a batch as a single PIL image.

Performs percentile-based min–max stretching and clamping to [0, 1] on each tensor, converts channel-first images to numpy for plotting, and arranges up to two samples (rows) with three columns (LR, SR, HR) into a matplotlib figure. The figure is then rasterized and returned as a PIL.Image.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
lr | Tensor | Low-resolution batch tensor of shape (B, C, H, W); values may be in any range (they are stretched internally). | required
sr | Tensor | Super-resolved batch tensor of shape (B, C, H, W). | required
hr | Tensor | High-resolution (target) batch tensor of shape (B, C, H, W). | required
title | str | Figure title placed above the grid. | 'Train'

Returns:

PIL.Image.Image: RGB image containing a grid with columns [LR | SR | HR] and up to two rows (the first two items of the batch).

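Notes
  • Only the first two items of the batch are visualized, to keep the figure small.
  • Supports grayscale (1 channel), RGB (3 channels), and 4-channel (RGBA-shaped) inputs; with 4 channels only the first three are shown. Other channel counts are passed through as (H, W, C) arrays, which matplotlib's imshow cannot display.
  • The input tensors are not modified (the function uses .detach() and plots copies), and the matplotlib figure is closed after rendering.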
Source code in opensr_srgan/utils/logging_helpers.py
def plot_tensors(lr, sr, hr, title="Train"):
    """Render LR–SR–HR triplets from a batch as a single PIL image.

    Performs percentile-based min–max stretching and clamping to [0, 1] on each
    tensor, converts channel-first images to numpy for plotting, and arranges up
    to two samples (rows) with three columns (LR, SR, HR) into a matplotlib
    figure. The figure is then rasterized and returned as a `PIL.Image`.

    Args:
        lr (torch.Tensor): Low-resolution batch tensor of shape `(B, C, H, W)`,
            values expected in an arbitrary range (stretched internally).
        sr (torch.Tensor): Super-resolved batch tensor of shape `(B, C, H, W)`.
        hr (torch.Tensor): High-resolution (target) batch tensor of shape `(B, C, H, W)`.
        title (str, optional): Figure title placed above the grid. Defaults to `"Train"`.

    Returns:
        PIL.Image.Image: RGB image containing a grid with columns `[LR | SR | HR]`
        and up to two rows (first two items of the batch).

    Notes:
        - Only the first two items of the batch are visualized to avoid large figures.
        - Supports grayscale, RGB, RGBA, and generic multi-channel inputs
          (first 3 channels shown).
        - This function is side-effect free for tensors (uses `.detach()` and
          plots copies), and closes the matplotlib figure after rendering.
    """
    # --- percentile-based min–max stretch (no denormalization is applied) ---
    lr = minmax_percentile(lr)
    sr = minmax_percentile(sr)
    hr = minmax_percentile(hr)

    # clamp to [0, 1]; clamp() is out-of-place, so new tensors are returned
    lr, sr, hr = lr.clamp(0, 1), sr.clamp(0, 1), hr.clamp(0, 1)

    # shapes
    B, C, H, W = lr.shape  # (B,C,H,W)

    # Determine colormap for grayscale images
    if C == 1:
        cmap = "bone"
    else:
        cmap = None

    # keep at most max_n samples (rows stacked on top of each other)
    max_n = 2
    if B > max_n:
        lr = lr[:max_n]
        sr = sr[:max_n]
        hr = hr[:max_n]
        B = max_n

    # figure/axes: always 2D array even for B==1
    fixed_width = 15
    variable_height = (15 / 3) * B
    fig, axes = plt.subplots(
        B, 3, figsize=(fixed_width, variable_height), squeeze=False
    )

    # loop over batch
    with torch.no_grad():
        for i in range(B):
            img_lr = _to_numpy_img(lr[i].detach().cpu())
            img_sr = _to_numpy_img(sr[i].detach().cpu())
            img_hr = _to_numpy_img(hr[i].detach().cpu())

            # pass the grayscale colormap; imshow ignores cmap for RGB(A) arrays
            axes[i, 0].imshow(img_lr, cmap=cmap)
            axes[i, 0].set_title("LR")
            axes[i, 0].axis("off")

            axes[i, 1].imshow(img_sr, cmap=cmap)
            axes[i, 1].set_title("SR")
            axes[i, 1].axis("off")

            axes[i, 2].imshow(img_hr, cmap=cmap)
            axes[i, 2].set_title("HR")
            axes[i, 2].axis("off")

    fig.suptitle(title)
    fig.tight_layout()

    # render to PIL
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    pil_image = Image.open(buf).convert("RGB").copy()
    buf.close()
    plt.close(fig)
    return pil_image
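The stretch and layout steps described above can be sketched without torch or matplotlib. This is a minimal NumPy sketch under stated assumptions: `percentile_stretch` is a hypothetical stand-in for `minmax_percentile`, whose actual percentiles are not documented on this page (2 and 98 are assumed here), and `figure_layout` is a hypothetical helper that only reproduces the row-capping and figure-size arithmetic of `plot_tensors`.

```python
import numpy as np


def percentile_stretch(x, p_low=2.0, p_high=98.0):
    """Hypothetical stand-in for minmax_percentile (percentiles are assumed).

    Maps the p_low..p_high percentile range of the whole array to [0, 1]
    and clamps everything outside it, as plot_tensors does after stretching.
    """
    lo, hi = np.percentile(x, [p_low, p_high])
    if hi <= lo:  # constant input: avoid dividing by zero
        return np.zeros_like(x, dtype=np.float64)
    return np.clip((x - lo) / (hi - lo), 0.0, 1.0)


def figure_layout(batch_size, max_n=2, fixed_width=15):
    """Rows and figsize as computed in plot_tensors: at most max_n rows, 5 inches each."""
    rows = min(batch_size, max_n)
    return rows, (fixed_width, (fixed_width / 3) * rows)


# Example: a batch of 4 three-band images with DN-like values.
batch = np.random.default_rng(0).normal(1000.0, 300.0, size=(4, 3, 16, 16))
stretched = percentile_stretch(batch)
print(stretched.min(), stretched.max())  # 0.0 1.0
print(figure_layout(batch.shape[0]))     # (2, (15, 10.0))
```

Because `plot_tensors` stretches LR, SR, and HR independently, each column uses its own full contrast range, so panel brightness is not directly comparable across columns.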

print_model_summary(self)

Prints a detailed, visually structured summary of the SRGAN configuration: generator and discriminator architecture, super-resolution factor, training schedule, EMA settings (when configured), loss functions and weights, parameter counts, and device.

Source code in opensr_srgan/utils/model_descriptions.py
def print_model_summary(self):
    """
    Prints a detailed and visually structured summary of the SRGAN configuration.
    Includes architecture info, resolution scale, training parameters, loss weights, and model sizes.
    """

    # --- helpers (millions) ---
    def count_trainable_params(m):
        return sum(p.numel() for p in m.parameters() if p.requires_grad) / 1e6

    def count_total_params(m):
        return sum(p.numel() for p in m.parameters()) / 1e6

    # --- counts ---
    g_trainable = count_trainable_params(self.generator)
    g_total = count_total_params(self.generator)
    d_trainable = count_trainable_params(self.discriminator)
    d_total = count_total_params(self.discriminator)

    # short aliases used in the per-network "Params" lines printed below
    g_params = g_trainable  # trainable G params (M)
    d_params = d_trainable  # trainable D params (M)

    total_trainable = g_trainable + d_trainable
    total_all = g_total + d_total

    # ------------------------------------------------------------------
    # Derive human-readable generator description
    # ------------------------------------------------------------------
    g_type = getattr(self.config.Generator, "model_type", "SRResNet")
    g_type_norm = str(g_type).lower().replace("-", "_")
    block_type = getattr(self.config.Generator, "block_type", None)

    if g_type_norm in {"res", "rcab", "rrdb", "lka"} and block_type is None:
        block_type = g_type_norm
        g_type_norm = "srresnet"

    block_desc_map = {
        "standard": "Residual Blocks with BatchNorm",
        "res": "Residual Blocks without BatchNorm",
        "rcab": "RCAB Blocks with Channel Attention",
        "rrdb": "RRDB Dense Residual Blocks",
        "lka": "LKA Large-Kernel Attention Blocks",
    }

    if g_type_norm in {"srresnet", "sr_resnet"}:
        block_key = "standard" if block_type is None else str(block_type).lower()
        desc = block_desc_map.get(block_key, f"Custom Block Variant: {block_key}")
        g_desc = f"SRResNet ({desc})"
    elif g_type_norm in {"stochastic_gan", "cgan", "conditional_cgan"}:
        g_desc = "Stochastic SRGAN (Latent-modulated Residual Blocks)"
    elif g_type_norm == "esrgan":
        g_desc = "ESRGAN (RRDB Residual-in-Residual Dense Network)"
    else:
        g_desc = f"Custom Generator Type: {g_type}"

    # ------------------------------------------------------------------
    # Resolution info (input → output)
    # ------------------------------------------------------------------
    scale_factor = getattr(
        self.config.Generator, "scaling_factor", getattr(self.generator, "scale", None)
    )
    if scale_factor is not None:
        res_str = f"Super-Resolution Factor: ×{scale_factor}"
    else:
        res_str = "Super-Resolution Factor: Unknown"

    # ------------------------------------------------------------------
    # Retrieve loss weights if available
    # ------------------------------------------------------------------
    loss_cfg = getattr(self.config.Training, "Losses", {})
    content_w = getattr(loss_cfg, "content_loss_weight", 1.0)
    adv_w = getattr(loss_cfg, "adv_loss_beta", 1.0)
    perceptual_w = getattr(loss_cfg, "perceptual_loss_weight", None)
    total_w_str = f"   • Content: {content_w} | Adversarial: {adv_w}" + (
        f" | Perceptual: {perceptual_w}" if perceptual_w is not None else ""
    )

    print("\n" + "=" * 90)
    print("🚀  SRGAN Model Summary")
    print("=" * 90)

    # ------------------------------------------------------------------
    # Generator Info
    # ------------------------------------------------------------------
    print(f"🧩 Generator")
    print(f"   • Architecture:      {g_desc}")
    print(f"   • Resolution:        {res_str}")
    print(f"   • Input Channels:    {self.config.Model.in_bands}")
    feature_channels = getattr(self.config.Generator, "n_channels", None)
    if feature_channels is not None:
        print(f"   • Feature Channels:  {feature_channels}")

    block_count = getattr(
        self.config.Generator,
        "n_blocks",
        getattr(self.generator, "n_blocks", None),
    )
    if block_count is not None:
        print(f"   • Residual Blocks:   {block_count}")

    small_kernel = getattr(self.config.Generator, "small_kernel_size", None)
    large_kernel = getattr(self.config.Generator, "large_kernel_size", None)
    if small_kernel is not None or large_kernel is not None:
        print(
            "   • Kernel Sizes:      small={small}, large={large}".format(
                small=small_kernel if small_kernel is not None else "N/A",
                large=large_kernel if large_kernel is not None else "N/A",
            )
        )
    print(f"   • Params:            {g_params:.2f} M\n")

    # ------------------------------------------------------------------
    # Discriminator Info
    # ------------------------------------------------------------------
    d_type = getattr(self.config.Discriminator, "model_type", "standard")
    d_blocks = getattr(self.config.Discriminator, "n_blocks", None)
    effective_blocks = getattr(
        self.discriminator,
        "n_blocks",
        getattr(self.discriminator, "n_layers", d_blocks),
    )
    base_channels = getattr(self.discriminator, "base_channels", "N/A")
    kernel_size = getattr(self.discriminator, "kernel_size", "N/A")
    fc_size = getattr(self.discriminator, "fc_size", None)

    if d_type == "patchgan":
        d_desc = "PatchGAN"
    elif d_type == "esrgan":
        d_desc = "ESRGAN"
    else:
        d_desc = "SRGAN"

    print(f"🧠 Discriminator")
    print(f"   • Architecture:     {d_desc}")
    if effective_blocks is not None:
        print(f"   • Blocks/Layers:    {effective_blocks}")
    print(f"   • Base Channels:    {base_channels}")
    print(f"   • Kernel Size:      {kernel_size}")
    if fc_size is not None:
        print(f"   • FC Layer Size:    {fc_size}")
    print(f"   • Params:            {d_params:.2f} M\n")

    # ------------------------------------------------------------------
    # Training Setup
    # ------------------------------------------------------------------
    print(f"⚙️  Training Configuration")
    print(f"   • Pretrain Generator: {self.pretrain_g_only}")
    print(f"   • Pretrain Steps:     {self.g_pretrain_steps}")
    print(f"   • Adv. Ramp Steps:    {self.adv_loss_ramp_steps}")
    print(f"   • Label Smoothing:    {self.adv_target < 1.0}")
    print(f"   • Adv. Target Label:  {self.adv_target}\n")

    # ------------------------------------------------------------------
    # EMA Configuration (if present)
    # ------------------------------------------------------------------
    ema_cfg = getattr(getattr(self.config.Training, "EMA", None), "__dict__", None)
    if ema_cfg or hasattr(self.config.Training, "EMA"):
        ema = self.config.Training.EMA
        enabled = getattr(ema, "enabled", False)
        decay = getattr(ema, "decay", None)
        update_after = getattr(ema, "update_after_step", None)
        use_num_updates = getattr(ema, "use_num_updates", None)

        print(f"🔁 Exponential Moving Average (EMA)")
        print(f"   • Enabled:           {enabled}")
        if enabled:
            print(f"   • Decay:             {decay}")
            print(f"   • Update After Step: {update_after}")
            print(f"   • Use Num Updates:   {use_num_updates}")
        print()  # newline for spacing

    # ------------------------------------------------------------------
    # Loss Functions
    # ------------------------------------------------------------------
    print(f"📉 Loss Functions")
    print(f"   • Content Loss:       {type(self.content_loss_criterion).__name__}")
    print(f"   • Adversarial Loss:   {type(self.adversarial_loss_criterion).__name__}")
    print(total_w_str + "\n")

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    print(f"📊 Model Summary")
    print(f"   • Total Params:           {total_all:.2f} M")
    print(f"   • Total Trainable Params: {total_trainable:.2f} M")
    print(
        f"   • Device:                 {self.device if hasattr(self, 'device') else 'Not set'}"
    )
    print("=" * 90 + "\n")
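The generator description in the summary depends on how `model_type` and `block_type` combine in the config. The following pure-Python sketch mirrors that branch logic so the resolution can be checked without building a model; `describe_generator` is a hypothetical helper, not part of the package.

```python
BLOCK_DESC = {
    "standard": "Residual Blocks with BatchNorm",
    "res": "Residual Blocks without BatchNorm",
    "rcab": "RCAB Blocks with Channel Attention",
    "rrdb": "RRDB Dense Residual Blocks",
    "lka": "LKA Large-Kernel Attention Blocks",
}


def describe_generator(model_type="SRResNet", block_type=None):
    """Hypothetical helper mirroring the generator branch of print_model_summary."""
    g_type_norm = str(model_type).lower().replace("-", "_")

    # A block name given as model_type (with no explicit block_type) means
    # "SRResNet built from that block".
    if g_type_norm in {"res", "rcab", "rrdb", "lka"} and block_type is None:
        block_type = g_type_norm
        g_type_norm = "srresnet"

    if g_type_norm in {"srresnet", "sr_resnet"}:
        block_key = "standard" if block_type is None else str(block_type).lower()
        desc = BLOCK_DESC.get(block_key, f"Custom Block Variant: {block_key}")
        return f"SRResNet ({desc})"
    if g_type_norm in {"stochastic_gan", "cgan", "conditional_cgan"}:
        return "Stochastic SRGAN (Latent-modulated Residual Blocks)"
    if g_type_norm == "esrgan":
        return "ESRGAN (RRDB Residual-in-Residual Dense Network)"
    return f"Custom Generator Type: {model_type}"


print(describe_generator("rcab"))       # SRResNet (RCAB Blocks with Channel Attention)
print(describe_generator("SR-ResNet"))  # SRResNet (Residual Blocks with BatchNorm)
print(describe_generator("ESRGAN"))     # ESRGAN (RRDB Residual-in-Residual Dense Network)
```

One consequence of this logic: setting a block name such as `rcab` as `model_type` together with an explicit `block_type` skips the alias, so the summary reports a custom generator type.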

histogram(reference, target)

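Perform per-channel histogram matching of target → reference.

Each channel in the target image is adjusted so that its cumulative distribution function (CDF) matches that of the corresponding channel in the reference image. This preserves overall color/radiometric tone relationships but aligns the pixel intensity distributions more precisely than simple moment matching.

Supports both single images and batched tensors. If the reference has a different spatial size, it is first resized to the target size with bilinear interpolation; if the reference batch is smaller than the target batch, reference items are reused cyclically (a single reference image is applied to every target). Pixels that are non-finite (NaN or Inf) in either image are excluded from matching and keep their original target values.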

Parameters

reference : torch.Tensor
    Reference image or batch, shape (C, H, W) or (B, C, H, W). Its histogram is used as the target distribution.
target : torch.Tensor
    Target image or batch to be adjusted, shape (C, H, W) or (B, C, H, W). Must have the same number of channels as reference.

Returns

torch.Tensor
    Histogram-matched version of target, with the same shape, dtype, and device as target. If a single (C, H, W) image is given, the result is also (C, H, W).

Source code in opensr_srgan/utils/radiometrics.py
def histogram(reference: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Perform **per-channel histogram matching** of `target` → `reference`.

    Each channel in the target image is adjusted so that its cumulative
    distribution function (CDF) matches that of the corresponding channel
    in the reference image.
    This preserves overall color/radiometric tone relationships, but aligns
    the pixel intensity distributions more precisely than simple moment matching.

    Supports both single images and batched tensors.

    Parameters
    ----------
    reference : torch.Tensor
        Reference image or batch, shape (C, H, W) or (B, C, H, W).
        Its histogram will be used as the target distribution.
    target : torch.Tensor
        Target image or batch to be adjusted, shape (C, H, W) or (B, C, H, W).
        Must have the same number of channels as `reference`.

    Returns
    -------
    torch.Tensor
        Histogram-matched version of the target, with the same shape and dtype
        as the input. If a single image is given, returns shape (C, H, W).
    """

    # Ensure both inputs have correct dimensionality: either (C,H,W) or (B,C,H,W)
    assert target.ndim in (3, 4) and reference.ndim in (
        3,
        4,
    ), "Expected (C,H,W) or (B,C,H,W) for both reference and target"

    # Save device/dtype for conversion back later
    device, dtype = target.device, target.dtype

    # --- Normalize to batch form (always 4D: B,C,H,W) ---
    # If inputs are 3D (single images), temporarily add batch dimension
    ref = reference.unsqueeze(0) if reference.ndim == 3 else reference
    tgt = target.unsqueeze(0) if target.ndim == 3 else target

    # Extract shapes
    B_ref, C_ref, H_ref, W_ref = ref.shape
    B_tgt, C_tgt, H_tgt, W_tgt = tgt.shape

    # Channel sanity check
    assert C_ref == C_tgt, f"Channel mismatch: reference={C_ref}, target={C_tgt}"

    # --- Resize reference spatially to match target ---
    # Uses bilinear interpolation, no corner alignment, safe for float data
    if (H_ref, W_ref) != (H_tgt, W_tgt):
        ref = F.interpolate(
            ref.to(dtype=torch.float32),
            size=(H_tgt, W_tgt),
            mode="bilinear",
            align_corners=False,
        )

    # Convert to NumPy for histogram matching operations
    ref_np = tensor_to_numpy(ref)
    tgt_np = tensor_to_numpy(tgt)
    out_np = np.empty_like(tgt_np)  # preallocate output array

    # --- Loop over batches and channels ---
    for b in range(B_tgt):
        # If reference has only one batch (B_ref=1), broadcast it to all targets
        rb = b % B_ref

        for c in range(C_tgt):
            ref_ch = ref_np[rb, c]  # reference channel
            tgt_ch = tgt_np[b, c]  # target channel

            # Mask invalid pixels (NaN or Inf)
            mask = np.isfinite(tgt_ch) & np.isfinite(ref_ch)

            if mask.any():
                # Perform per-channel histogram matching
                matched = exposure.match_histograms(tgt_ch[mask], ref_ch[mask])
                out = tgt_ch.copy()
                out[mask] = matched
                out_np[b, c] = out
            else:
                # If no valid pixels, copy target as-is
                out_np[b, c] = tgt_ch

    # Convert back to torch tensor on the original device/dtype
    out = torch.from_numpy(out_np).to(device=device, dtype=dtype)

    # If the original input was 3D, remove the temporary batch dimension
    return out[0] if target.ndim == 3 else out
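For a single channel, the matching step that `histogram` delegates to `exposure.match_histograms` (scikit-image) amounts to mapping each target value's empirical CDF position onto the reference value at the same quantile. The sketch below illustrates that idea in plain NumPy with the same joint finite-value mask; `match_channel` is a hypothetical name, it assumes both arrays already have the same shape (`histogram` resizes the reference first), and it is not scikit-image's implementation.

```python
import numpy as np


def match_channel(target, reference):
    """Quantile-map target onto reference's value distribution (one channel).

    Hypothetical sketch: both arrays must have the same shape. Pixels that are
    non-finite in either array are excluded and keep their target value.
    """
    t = np.asarray(target, dtype=np.float64).ravel()
    r = np.asarray(reference, dtype=np.float64).ravel()
    mask = np.isfinite(t) & np.isfinite(r)
    out = t.copy()
    if not mask.any():
        return out.reshape(np.shape(target))

    # Unique values, plus each masked target pixel's index into its unique values.
    t_vals, t_idx, t_counts = np.unique(t[mask], return_inverse=True, return_counts=True)
    r_vals, r_counts = np.unique(r[mask], return_counts=True)

    # Empirical CDFs evaluated at each unique value.
    t_cdf = np.cumsum(t_counts) / t_counts.sum()
    r_cdf = np.cumsum(r_counts) / r_counts.sum()

    # Reference value at the same quantile as each unique target value.
    matched = np.interp(t_cdf, r_cdf, r_vals)
    out[mask] = matched[t_idx]
    return out.reshape(np.shape(target))


tgt = np.array([[3.0, 1.0], [2.0, np.nan]])
ref = np.array([[10.0, 30.0], [20.0, 40.0]])
res = match_channel(tgt, ref)  # [[30.0, 10.0], [20.0, nan]]: ranks kept, NaN untouched
```

To apply it to a (C, H, W) array, loop over channels as `histogram` does, for example `np.stack([match_channel(t[c], r[c]) for c in range(t.shape[0])])`.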

normalise_10k(im, stage='norm')

Normalize or denormalize Sentinel-2 data scaled in units of 10,000.

This is the most common scaling for Sentinel-2 L2A reflectance data, where reflectance = DN / 10000.

Parameters

im : torch.Tensor
    Input tensor (any shape), expected to contain DN values of roughly [0, 10000].
stage : {"norm", "denorm"}
    - "norm" → divide by 10,000 and clip to [0, 1]
    - "denorm" → multiply by 10,000 and clip to [0, 10000]

Returns

torch.Tensor
    Scaled tensor.

Source code in opensr_srgan/utils/radiometrics.py
def normalise_10k(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """
    Normalize or denormalize Sentinel-2 data scaled in units of 10,000.

    This is the most common scaling for Sentinel-2 L2A reflectance data,
    where reflectance = DN / 10000.

    Parameters
    ----------
    im : torch.Tensor
        Input tensor (any shape), expected to contain DN values ~[0, 10000].
    stage : {"norm", "denorm"}
        - "norm"   → divide by 10,000 to map to [0, 1]
        - "denorm" → multiply by 10,000 to restore original scale

    Returns
    -------
    torch.Tensor
        Scaled tensor.
    """
    assert stage in ["norm", "denorm"]

    if stage == "norm":
        im = im / 10000.0
        im = torch.clamp(im, 0, 1)
    else:  # "denorm"
        im = im * 10000.0
        im = torch.clamp(im, 0, 10000)

    return im
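Because both stages clip, the round trip is lossless only for values inside [0, 10000]. The NumPy mirror below makes that concrete; `normalise_10k_np` is a hypothetical name introduced for illustration, since the library function takes torch tensors.

```python
import numpy as np


def normalise_10k_np(im, stage="norm"):
    """Hypothetical NumPy mirror of normalise_10k's arithmetic."""
    assert stage in ("norm", "denorm")
    if stage == "norm":
        return np.clip(im / 10000.0, 0.0, 1.0)   # DN -> reflectance, clipped to [0, 1]
    return np.clip(im * 10000.0, 0.0, 10000.0)   # reflectance -> DN, clipped to [0, 10000]


dn = np.array([-50.0, 0.0, 2500.0, 10000.0, 12000.0])
refl = normalise_10k_np(dn)                   # [0.0, 0.0, 0.25, 1.0, 1.0]
restored = normalise_10k_np(refl, "denorm")   # [0.0, 0.0, 2500.0, 10000.0, 10000.0]
```

The negative and above-range inputs (-50 and 12000) do not survive the round trip.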

normalise_10k_signed(im, stage='norm')

Normalize Sentinel-2 DN values to the symmetric [-1, 1] range.

This helper chains normalise_10k ([0, 10000] → [0, 1]) with zero_one_signed ([0, 1] → [-1, 1]), so models that expect signed inputs receive values in the conventional [-1, 1] range.

Parameters

im : torch.Tensor
    Input tensor containing reflectance-like values on a [0, 10000] scale.
stage : {"norm", "denorm"}
    Normalisation stage. "norm" maps [0, 10000] to [-1, 1]; "denorm" restores the original [0, 10000] range.

Returns

torch.Tensor
    Tensor in the requested range.

Source code in opensr_srgan/utils/radiometrics.py
def normalise_10k_signed(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """Normalize Sentinel-2 DN values to the symmetric ``[-1, 1]`` range.

    This helper chains :func:`normalise_10k` (``[0, 10000]`` → ``[0, 1]``) with
    :func:`zero_one_signed` (``[0, 1]`` → ``[-1, 1]``) so models that expect
    signed inputs receive the conventional ``[-1, 1]`` distribution.

    Parameters
    ----------
    im : torch.Tensor
        Input tensor containing reflectance-like values expressed as
        ``[0, 10000]``.
    stage : {"norm", "denorm"}
        Normalisation stage. ``"denorm"`` restores the original ``[0, 10000]``
        range.

    Returns
    -------
    torch.Tensor
        Tensor in the requested range.
    """

    assert stage in ["norm", "denorm"]

    if stage == "norm":
        scaled = normalise_10k(im, stage="norm")
        return zero_one_signed(scaled, stage="norm")
    scaled = zero_one_signed(im, stage="denorm")
    return normalise_10k(scaled, stage="denorm")
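Composing the two stages gives closed-form expressions for both directions, where clip(v, a, b) denotes the clamping applied inside normalise_10k and zero_one_signed. As a worked example, DN 2500 becomes 2500 / 10000 = 0.25 and then 2 · 0.25 - 1 = -0.5; the inverse maps -0.5 to (-0.5 + 1) / 2 = 0.25 and back to 2500.

```latex
x = 2\,\operatorname{clip}\!\left(\frac{\mathrm{DN}}{10000},\, 0,\, 1\right) - 1 \in [-1, 1],
\qquad
\mathrm{DN} = 10000 \cdot \operatorname{clip}\!\left(\frac{x + 1}{2},\, 0,\, 1\right)
```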

normalise_s2(im, stage='norm')

Normalize or denormalize Sentinel-2 image values.

This function applies a linear scaling that maps reflectance values in [0, 0.3] to [-1, 1] for model input (values above 0.3 are clipped to 1, negative values to -1), and reverses it for visualization or saving.

Parameters

im : torch.Tensor
    Input image tensor (any shape), typically reflectance-scaled.
stage : {"norm", "denorm"}
    - "norm" → map [0, 0.3] to [-1, 1], clipping values outside that range
    - "denorm" → map [-1, 1] back to the reflectance range [0, 0.3]

Returns

torch.Tensor
    The normalized or denormalized image tensor.

Source code in opensr_srgan/utils/radiometrics.py
def normalise_s2(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """
    Normalize or denormalize Sentinel-2 image values.

    This function applies a symmetric scaling to map reflectance-like values
    to the range [-1, 1] for model input, and reverses it for visualization
    or saving.

    Parameters
    ----------
    im : torch.Tensor
        Input image tensor (any shape), typically reflectance-scaled.
    stage : {"norm", "denorm"}
        - "norm"   → normalize image to [-1, 1] (reflectance above 0.3 saturates)
        - "denorm" → reverse normalization back to [0, 0.3] (i.e. [0, value/10])

    Returns
    -------
    torch.Tensor
        The normalized or denormalized image tensor.
    """
    assert stage in ["norm", "denorm"]
    value = 3.0  # reference scaling factor

    if stage == "norm":
        # Scale roughly from [0, value/10] → [0, 1] → [-1, 1]
        im = im * (10.0 / value)
        im = (im * 2) - 1
        im = torch.clamp(im, -1, 1)
    else:  # stage == "denorm"
        # Reverse mapping: [-1, 1] → [0, 1] → [0, value/10]
        im = (im + 1) / 2
        im = im * (value / 10.0)
        im = torch.clamp(im, 0, 1)

    return im
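With `value = 3.0`, only reflectances in [0, 0.3] use the full [-1, 1] range: brighter pixels saturate on "norm", and "denorm" returns values in [0, 0.3] rather than [0, 1]. The NumPy mirror below shows both effects; `normalise_s2_np` is a hypothetical name introduced for illustration, since the library function operates on torch tensors.

```python
import numpy as np

VALUE = 3.0  # same reference scaling factor as normalise_s2


def normalise_s2_np(im, stage="norm"):
    """Hypothetical NumPy mirror of normalise_s2's arithmetic."""
    assert stage in ("norm", "denorm")
    if stage == "norm":
        im = im * (10.0 / VALUE)   # [0, 0.3] -> [0, 1]
        im = im * 2 - 1            # [0, 1]   -> [-1, 1]
        return np.clip(im, -1, 1)  # reflectance above 0.3 saturates at 1
    im = (im + 1) / 2              # [-1, 1] -> [0, 1]
    im = im * (VALUE / 10.0)       # [0, 1]  -> [0, 0.3]
    return np.clip(im, 0, 1)       # no-op for inputs in [-1, 1]


refl = np.array([0.0, 0.15, 0.3, 0.6])
x = normalise_s2_np(refl)            # approximately [-1.0, 0.0, 1.0, 1.0]
back = normalise_s2_np(x, "denorm")  # approximately [0.0, 0.15, 0.3, 0.3]; 0.6 is lost
```

If a downstream step expects [0, 1] after "denorm", the output has to be rescaled separately; the function itself does not do this.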

zero_one_signed(im, stage='norm')

Convert values between the [0, 1] and [-1, 1] ranges.

Parameters

im : torch.Tensor
    Tensor with values in [0, 1] (for "norm") or in [-1, 1] (for "denorm").
stage : {"norm", "denorm"}
    - "norm" → map [0, 1] → [-1, 1] using im * 2 - 1 (result clamped to [-1, 1]).
    - "denorm" → map [-1, 1] → [0, 1] using (im + 1) / 2 (result clamped to [0, 1]).

Returns

torch.Tensor
    Range-adjusted tensor.

Source code in opensr_srgan/utils/radiometrics.py
def zero_one_signed(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """Convert values between the ``[0, 1]`` and ``[-1, 1]`` ranges.

    Parameters
    ----------
    im : torch.Tensor
        Tensor with values in ``[0, 1]`` (``"norm"``) or ``[-1, 1]`` (``"denorm"``).
    stage : {"norm", "denorm"}
        - ``"norm"``   → map ``[0, 1]`` → ``[-1, 1]`` using ``im * 2 - 1``.
        - ``"denorm"`` → map ``[-1, 1]`` → ``[0, 1]`` using ``(im + 1) / 2``.

    Returns
    -------
    torch.Tensor
        Range-adjusted tensor.
    """

    assert stage in ["norm", "denorm"]

    if stage == "norm":
        return torch.clamp((im * 2.0) - 1.0, -1.0, 1.0)
    return torch.clamp((im + 1.0) / 2.0, 0.0, 1.0)

Logging helpers

_tensor_to_plot_data(t)

Convert a CPU tensor to a NumPy array suitable for matplotlib visualization.

Ensures contiguous memory layout before conversion and detaches the tensor from the computation graph. Typically used for final image preparation in validation or inference visualizations.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
t | Tensor | Input tensor to convert. Can have arbitrary shape. | required

Returns:

numpy.ndarray: NumPy array representation of the tensor, compatible with matplotlib plotting functions.

Source code in opensr_srgan/utils/logging_helpers.py
def _tensor_to_plot_data(t: torch.Tensor):
    """Convert a CPU tensor to a NumPy array suitable for matplotlib visualization.

    Ensures contiguous memory layout before conversion and detaches the tensor
    from the computation graph. Typically used for final image preparation
    in validation or inference visualizations.

    Args:
        t (torch.Tensor): Input tensor to convert. Can have arbitrary shape.

    Returns:
        numpy.ndarray: NumPy array representation of the tensor, compatible
        with matplotlib plotting functions.
    """
    return tensor_to_numpy(t.contiguous())

_to_numpy_img(t)

Convert a normalized image tensor in (C, H, W) format to a NumPy array.

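Values are clamped to [0, 1]. Supports grayscale images (returned as H×W), RGB and RGBA images (only the first three channels are kept, returned as H×W×3), and arbitrary multichannel tensors, which are converted by permuting to H×W×C. Note that matplotlib's imshow only displays arrays of shape (H, W), (H, W, 3), or (H, W, 4), so other channel counts cannot be plotted directly.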

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
t | Tensor | Input image tensor in channel-first (C, H, W) format, with pixel values expected in the range [0, 1]. | required

Returns:

numpy.ndarray: Channel-last NumPy array suitable for matplotlib visualization: (H, W) for one channel, (H, W, 3) for three or four channels, and (H, W, C) otherwise.

Raises:

Type | Description
--- | ---
ValueError | If the input tensor does not have exactly three dimensions.

Source code in opensr_srgan/utils/logging_helpers.py
def _to_numpy_img(t: torch.Tensor):
    """Convert a normalized image tensor in (C, H, W) format to a NumPy array.

    Supports grayscale, RGB, and RGBA images, as well as arbitrary
    multichannel tensors (which are converted by permuting to H×W×C).

    Args:
        t (torch.Tensor): Input image tensor in channel-first (C, H, W) format,
            with pixel values expected in the range [0, 1].

    Returns:
        numpy.ndarray: Channel-last NumPy array for matplotlib: (H, W) for
        one channel, (H, W, 3) for three or four channels, else (H, W, C).

    Raises:
        ValueError: If the input tensor does not have exactly three dimensions.
    """
    if t.dim() != 3:
        raise ValueError(f"Expected (C,H,W), got {tuple(t.shape)}")
    C, H, W = t.shape
    t = t.detach().clamp(0, 1)
    if C == 1:
        out = _tensor_to_plot_data(t[0])
        return out  # grayscale

    if C in (3, 4):
        rgb = t[:3]
        out = _tensor_to_plot_data(rgb.permute(1, 2, 0))
        return out

    return _tensor_to_plot_data(t.permute(1, 2, 0))
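The channel handling above can be checked without torch. In the sketch below, `to_plot_array` is a hypothetical NumPy mirror of the branches in `_to_numpy_img`, and `imshow_compatible` is a hypothetical helper encoding the image shapes matplotlib's imshow accepts ((M, N), (M, N, 3), (M, N, 4)); the point is to show which channel counts end up displayable.

```python
import numpy as np


def to_plot_array(img):
    """Hypothetical NumPy mirror of _to_numpy_img for a (C, H, W) array."""
    if img.ndim != 3:
        raise ValueError(f"Expected (C,H,W), got {img.shape}")
    img = np.clip(img, 0.0, 1.0)
    c = img.shape[0]
    if c == 1:
        return img[0]                            # (H, W) grayscale
    if c in (3, 4):
        return np.transpose(img[:3], (1, 2, 0))  # (H, W, 3): 4th channel dropped
    return np.transpose(img, (1, 2, 0))          # (H, W, C): passed through as-is


def imshow_compatible(arr):
    """True if matplotlib's imshow accepts this array shape."""
    return arr.ndim == 2 or (arr.ndim == 3 and arr.shape[2] in (3, 4))


for channels in (1, 2, 3, 4, 6):
    arr = to_plot_array(np.zeros((channels, 8, 8)))
    print(channels, arr.shape, imshow_compatible(arr))
# 1 (8, 8) True
# 2 (8, 8, 2) False
# 3 (8, 8, 3) True
# 4 (8, 8, 3) True
# 6 (8, 8, 6) False
```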


Radiometric transforms

normalise_s2(im, stage='norm')

Normalize or denormalize Sentinel-2 image values.

This function applies a symmetric scaling to map reflectance-like values to the range [-1, 1] for model input, and reverses it for visualization or saving.

Parameters

im : torch.Tensor Input image tensor (any shape), typically reflectance-scaled. stage : {"norm", "denorm"} - "norm" → normalize image to [-1, 1] - "denorm" → reverse normalization back to [0, 1]

Returns

torch.Tensor The normalized or denormalized image tensor.

Source code in opensr_srgan/utils/radiometrics.py
def normalise_s2(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """
    Normalize or denormalize Sentinel-2 image values.

    This function applies a symmetric scaling to map reflectance-like values
    to the range [-1, 1] for model input, and reverses it for visualization
    or saving.

    Parameters
    ----------
    im : torch.Tensor
        Input image tensor (any shape), typically reflectance-scaled.
    stage : {"norm", "denorm"}
        - "norm"   → normalize image to [-1, 1]
        - "denorm" → reverse normalization back to [0, 1]

    Returns
    -------
    torch.Tensor
        The normalized or denormalized image tensor.
    """
    assert stage in ["norm", "denorm"]
    value = 3.0  # reference scaling factor

    if stage == "norm":
        # Scale roughly from [0, value/10] → [0, 1] → [-1, 1]
        im = im * (10.0 / value)
        im = (im * 2) - 1
        im = torch.clamp(im, -1, 1)
    else:  # stage == "denorm"
        # Reverse mapping: [-1, 1] → [0, 1] → [0, value/10]
        im = (im + 1) / 2
        im = im * (value / 10.0)
        im = torch.clamp(im, 0, 1)

    return im
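A scalar restatement of the normalise_s2 arithmetic (value = 3.0) makes the ranges concrete. The helper names s2_norm and s2_denorm are hypothetical and plain floats stand in for tensors; note that "denorm" tops out at 0.3, not 1.

```python
def s2_norm(r: float) -> float:
    """Hypothetical scalar version of normalise_s2(..., stage="norm")."""
    x = r * (10.0 / 3.0)        # [0, 0.3] -> [0, 1]
    x = (x * 2) - 1             # [0, 1] -> [-1, 1]
    return max(-1.0, min(1.0, x))


def s2_denorm(x: float) -> float:
    """Hypothetical scalar version of normalise_s2(..., stage="denorm")."""
    r = (x + 1) / 2             # [-1, 1] -> [0, 1]
    r = r * (3.0 / 10.0)        # [0, 1] -> [0, 0.3]
    return max(0.0, min(1.0, r))


print(s2_norm(0.15))             # ~0.0: mid-range reflectance lands at the centre
print(s2_norm(0.45))             # 1.0: reflectance above 0.3 saturates
print(s2_denorm(1.0))            # 0.3: the largest value "denorm" can return
print(s2_denorm(s2_norm(0.15)))  # ~0.15: round trip for in-range inputs
```

Because "norm" clamps, the round trip is only exact for reflectances inside [0, 0.3].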

normalise_10k(im, stage='norm')

Normalize or denormalize Sentinel-2 data scaled in units of 10,000.

This is the most common scaling for Sentinel-2 L2A reflectance data, where reflectance = DN / 10000.

Parameters

im : torch.Tensor
    Input tensor (any shape), expected to contain DN values in roughly [0, 10000].
stage : {"norm", "denorm"}
    - "norm" → divide by 10,000 and clamp to [0, 1]
    - "denorm" → multiply by 10,000 and clamp to [0, 10000]

Returns

torch.Tensor
    Scaled tensor.

Source code in opensr_srgan/utils/radiometrics.py
def normalise_10k(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """
    Normalize or denormalize Sentinel-2 data scaled in units of 10,000.

    This is the most common scaling for Sentinel-2 L2A reflectance data,
    where reflectance = DN / 10000.

    Parameters
    ----------
    im : torch.Tensor
        Input tensor (any shape), expected to contain DN values ~[0, 10000].
    stage : {"norm", "denorm"}
        - "norm"   → divide by 10,000 to map to [0, 1]
        - "denorm" → multiply by 10,000 to restore original scale

    Returns
    -------
    torch.Tensor
        Scaled tensor.
    """
    assert stage in ["norm", "denorm"]

    if stage == "norm":
        im = im / 10000.0
        im = torch.clamp(im, 0, 1)
    else:  # "denorm"
        im = im * 10000.0
        im = torch.clamp(im, 0, 10000)

    return im
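The clamping in normalise_10k makes the round trip lossy for pixels above 10,000 DN. The plain-Python helpers below (hypothetical names dn_to_unit and unit_to_dn) restate the two stages on scalars:

```python
def dn_to_unit(dn: float) -> float:
    # stage="norm": divide by 10,000, then clamp to [0, 1]
    return max(0.0, min(1.0, dn / 10000.0))


def unit_to_dn(x: float) -> float:
    # stage="denorm": multiply by 10,000, then clamp to [0, 10000]
    return max(0.0, min(10000.0, x * 10000.0))


print(dn_to_unit(2500))               # 0.25
print(dn_to_unit(12000))              # 1.0: saturated by the clamp
print(unit_to_dn(dn_to_unit(12000)))  # 10000.0: the original 12000 is not recovered
```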

zero_one_signed(im, stage='norm')

Convert values between the [0, 1] and [-1, 1] ranges.

Parameters

im : torch.Tensor
    Tensor containing values already scaled to [0, 1].
stage : {"norm", "denorm"}
    - "norm" → map [0, 1] → [-1, 1] using im * 2 - 1.
    - "denorm" → map [-1, 1] → [0, 1] using (im + 1) / 2.

Returns

torch.Tensor
    Range-adjusted tensor.

Source code in opensr_srgan/utils/radiometrics.py
def zero_one_signed(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """Convert values between the ``[0, 1]`` and ``[-1, 1]`` ranges.

    Parameters
    ----------
    im : torch.Tensor
        Tensor containing values already scaled to ``[0, 1]``.
    stage : {"norm", "denorm"}
        - ``"norm"``   → map ``[0, 1]`` → ``[-1, 1]`` using ``im * 2 - 1``.
        - ``"denorm"`` → map ``[-1, 1]`` → ``[0, 1]`` using ``(im + 1) / 2``.

    Returns
    -------
    torch.Tensor
        Range-adjusted tensor.
    """

    assert stage in ["norm", "denorm"]

    if stage == "norm":
        return torch.clamp((im * 2.0) - 1.0, -1.0, 1.0)
    return torch.clamp((im + 1.0) / 2.0, 0.0, 1.0)

normalise_10k_signed(im, stage='norm')

Normalize Sentinel-2 DN values to the symmetric [-1, 1] range.

This helper chains normalise_10k ([0, 10000] → [0, 1]) with zero_one_signed ([0, 1] → [-1, 1]) so that models expecting signed inputs receive values in the conventional [-1, 1] range.

Parameters

im : torch.Tensor
    Input tensor containing reflectance-like DN values in the range [0, 10000].
stage : {"norm", "denorm"}
    Normalisation stage. "norm" maps to [-1, 1]; "denorm" restores the original [0, 10000] range.

Returns

torch.Tensor
    Tensor in the requested range.

Source code in opensr_srgan/utils/radiometrics.py
def normalise_10k_signed(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """Normalize Sentinel-2 DN values to the symmetric ``[-1, 1]`` range.

    This helper chains :func:`normalise_10k` (``[0, 10000]`` → ``[0, 1]``) with
    :func:`zero_one_signed` (``[0, 1]`` → ``[-1, 1]``) so models that expect
    signed inputs receive the conventional ``[-1, 1]`` distribution.

    Parameters
    ----------
    im : torch.Tensor
        Input tensor containing reflectance-like values expressed as
        ``[0, 10000]``.
    stage : {"norm", "denorm"}
        Normalisation stage. ``"denorm"`` restores the original ``[0, 10000]``
        range.

    Returns
    -------
    torch.Tensor
        Tensor in the requested range.
    """

    assert stage in ["norm", "denorm"]

    if stage == "norm":
        scaled = normalise_10k(im, stage="norm")
        return zero_one_signed(scaled, stage="norm")
    scaled = zero_one_signed(im, stage="denorm")
    return normalise_10k(scaled, stage="denorm")
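Composing the two steps on scalars shows where the DN anchors land in the signed range. All helper names here are hypothetical; each mirrors one of the documented stages, including their clamps.

```python
def zero_one_to_signed(x: float) -> float:
    # zero_one_signed(stage="norm"): [0, 1] -> [-1, 1], clamped
    return max(-1.0, min(1.0, x * 2.0 - 1.0))


def signed_to_zero_one(x: float) -> float:
    # zero_one_signed(stage="denorm"): [-1, 1] -> [0, 1], clamped
    return max(0.0, min(1.0, (x + 1.0) / 2.0))


def dn_to_signed(dn: float) -> float:
    # normalise_10k(stage="norm") followed by zero_one_signed(stage="norm")
    return zero_one_to_signed(max(0.0, min(1.0, dn / 10000.0)))


def signed_to_dn(x: float) -> float:
    # zero_one_signed(stage="denorm") followed by normalise_10k(stage="denorm")
    return max(0.0, min(10000.0, signed_to_zero_one(x) * 10000.0))


print(dn_to_signed(0), dn_to_signed(5000), dn_to_signed(10000))  # -1.0 0.0 1.0
print(signed_to_dn(dn_to_signed(2500)))                          # 2500.0
```

The "denorm" stage applies the two inverses in reverse order, which is why the source calls zero_one_signed before normalise_10k on the way back.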

sen2_stretch(im)

Apply a simple contrast stretch to Sentinel-2 data.

Multiplies reflectance values by 10/3 ≈ 3.33 and clamps the result to [0, 1], so reflectances in [0, 0.3] fill the full display range while brighter pixels saturate at 1. Intended for visualization or augmentation.

Parameters

im : torch.Tensor
    Sentinel-2 tensor (any shape).

Returns

torch.Tensor
    Contrast-stretched image tensor.

Source code in opensr_srgan/utils/radiometrics.py
def sen2_stretch(im: torch.Tensor) -> torch.Tensor:
    """
    Apply a simple contrast stretch to Sentinel-2 data.

    Multiplies reflectance values by (10/3) ≈ 3.33 to increase dynamic range
    for visualization or augmentation purposes.

    Parameters
    ----------
    im : torch.Tensor
        Sentinel-2 tensor (any shape).

    Returns
    -------
    torch.Tensor
        Contrast-stretched image tensor.
    """
    stretched = im * (10 / 3.0)
    return torch.clamp(stretched, 0.0, 1.0)
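The saturation point is easy to see on a few scalar reflectances. The helper stretch is a hypothetical scalar restatement of sen2_stretch; it uses the same 10/3 factor that normalise_s2 applies before its shift to [-1, 1].

```python
def stretch(r: float) -> float:
    """Hypothetical scalar version of sen2_stretch: scale by 10/3, clamp to [0, 1]."""
    return max(0.0, min(1.0, r * (10 / 3.0)))


for r in (0.0, 0.15, 0.3, 0.6):
    print(r, round(stretch(r), 6))
```

Any reflectance at or above 0.3 (bright surfaces such as cloud) maps to 1.0, so detail in those pixels is lost in the stretched image.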

minmax_percentile(tensor, pmin=2, pmax=98)

Perform percentile-based min-max normalization to [0, 1].

Uses quantiles instead of absolute min/max to reduce outlier influence. The result is not clamped: values below the pmin percentile become negative and values above the pmax percentile exceed 1.

Parameters

tensor : torch.Tensor
    Input tensor (any shape).
pmin : float
    Lower percentile (default 2%).
pmax : float
    Upper percentile (default 98%).

Returns

torch.Tensor
    Tensor rescaled so that the pmin and pmax percentiles map to 0 and 1.

Source code in opensr_srgan/utils/radiometrics.py
def minmax_percentile(
    tensor: torch.Tensor, pmin: float = 2, pmax: float = 98
) -> torch.Tensor:
    """
    Perform percentile-based min-max normalization to [0, 1].

    Uses quantiles instead of absolute min/max to reduce outlier influence.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor (any shape).
    pmin : float
        Lower percentile (default 2%).
    pmax : float
        Upper percentile (default 98%).

    Returns
    -------
    torch.Tensor
        Tensor scaled to [0, 1] based on percentile range.
    """
    min_val = torch.quantile(tensor, pmin / 100.0)
    max_val = torch.quantile(tensor, pmax / 100.0)
    tensor = (tensor - min_val) / (max_val - min_val)
    return tensor
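A NumPy analogue shows both the outlier resistance and the lack of clamping. The name minmax_percentile_np is hypothetical, and np.percentile (default "linear" method) stands in for torch.quantile (default "linear" interpolation).

```python
import numpy as np


def minmax_percentile_np(x: np.ndarray, pmin: float = 2, pmax: float = 98) -> np.ndarray:
    # Percentile bounds instead of the absolute min and max
    lo = np.percentile(x, pmin)
    hi = np.percentile(x, pmax)
    return (x - lo) / (hi - lo)


rng = np.random.default_rng(0)
x = rng.uniform(0.0, 0.3, size=10_000)
x[0] = 50.0                       # one extreme outlier

y = minmax_percentile_np(x)
print(float(np.median(y)))        # near 0.5: the outlier barely moves the bounds
print(float(y.max()) > 1.0)       # True: values beyond pmax are not clamped

y_vis = np.clip(y, 0.0, 1.0)      # clamp afterwards for display
```

Clipping after the stretch, as plot_tensors does before plotting, brings the output back into [0, 1].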

minmax(img)

Standard min-max normalization to [0, 1] over the entire tensor.

Parameters

img : torch.Tensor
    Input tensor (any shape).

Returns

torch.Tensor
    Min-max normalized tensor. A constant input yields NaN, because the denominator max - min is zero.

Source code in opensr_srgan/utils/radiometrics.py
def minmax(img: torch.Tensor) -> torch.Tensor:
    """
    Standard min-max normalization to [0, 1] over the entire tensor.

    Parameters
    ----------
    img : torch.Tensor
        Input tensor (any shape).

    Returns
    -------
    torch.Tensor
        Min-max normalized tensor.
    """
    min_val = torch.min(img)
    max_val = torch.max(img)
    normalized_img = (img - min_val) / (max_val - min_val)
    return normalized_img
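By contrast, plain min-max scaling lets a single outlier compress everything else, and a constant input divides zero by zero. In the NumPy sketch below, minmax_np mirrors the documented formula and minmax_safe is a hypothetical guarded variant, not part of opensr_srgan.

```python
import numpy as np


def minmax_np(x: np.ndarray) -> np.ndarray:
    # Same formula as minmax(): global min and max over the whole array
    return (x - x.min()) / (x.max() - x.min())


def minmax_safe(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    # Hypothetical guard: return zeros instead of NaN when the input is constant
    span = x.max() - x.min()
    if span < eps:
        return np.zeros_like(x, dtype=float)
    return (x - x.min()) / span


x = np.array([0.10, 0.12, 0.15, 0.20, 5.0])  # one outlier at 5.0
print(minmax_np(x))          # the four ordinary values all fall below 0.03

flat = np.full(4, 0.2)
with np.errstate(invalid="ignore"):          # silence the 0/0 RuntimeWarning
    print(minmax_np(flat))   # [nan nan nan nan]
print(minmax_safe(flat))     # [0. 0. 0. 0.]
```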

histogram(reference, target)

Perform per-channel histogram matching of target → reference.

Each channel in the target image is adjusted so that its cumulative distribution function (CDF) matches that of the corresponding channel in the reference image. This preserves overall color/radiometric tone relationships, but aligns the pixel intensity distributions more precisely than simple moment matching.

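Supports both single images and batched tensors. If the spatial sizes differ, the reference is first resized to the target size with bilinear interpolation. When the reference batch is smaller than the target batch, reference items are reused cyclically, so a single reference is applied to every target. Pixels that are NaN or Inf in either image are excluded from matching and left unchanged.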

Parameters

reference : torch.Tensor
    Reference image or batch, shape (C, H, W) or (B, C, H, W). Its histogram will be used as the target distribution.
target : torch.Tensor
    Target image or batch to be adjusted, shape (C, H, W) or (B, C, H, W). Must have the same number of channels as reference.

Returns

torch.Tensor
    Histogram-matched version of target, with the same shape, dtype, and device as target. A single (C, H, W) input returns shape (C, H, W).

Source code in opensr_srgan/utils/radiometrics.py
def histogram(reference: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Perform **per-channel histogram matching** of `target` → `reference`.

    Each channel in the target image is adjusted so that its cumulative
    distribution function (CDF) matches that of the corresponding channel
    in the reference image.
    This preserves overall color/radiometric tone relationships, but aligns
    the pixel intensity distributions more precisely than simple moment matching.

    Supports both single images and batched tensors.

    Parameters
    ----------
    reference : torch.Tensor
        Reference image or batch, shape (C, H, W) or (B, C, H, W).
        Its histogram will be used as the target distribution.
    target : torch.Tensor
        Target image or batch to be adjusted, shape (C, H, W) or (B, C, H, W).
        Must have the same number of channels as `reference`.

    Returns
    -------
    torch.Tensor
        Histogram-matched version of the target, with the same shape and dtype
        as the input. If a single image is given, returns shape (C, H, W).
    """

    # Ensure both inputs have correct dimensionality: either (C,H,W) or (B,C,H,W)
    assert target.ndim in (3, 4) and reference.ndim in (
        3,
        4,
    ), "Expected (C,H,W) or (B,C,H,W) for both reference and target"

    # Save device/dtype for conversion back later
    device, dtype = target.device, target.dtype

    # --- Normalize to batch form (always 4D: B,C,H,W) ---
    # If inputs are 3D (single images), temporarily add batch dimension
    ref = reference.unsqueeze(0) if reference.ndim == 3 else reference
    tgt = target.unsqueeze(0) if target.ndim == 3 else target

    # Extract shapes
    B_ref, C_ref, H_ref, W_ref = ref.shape
    B_tgt, C_tgt, H_tgt, W_tgt = tgt.shape

    # Channel sanity check
    assert C_ref == C_tgt, f"Channel mismatch: reference={C_ref}, target={C_tgt}"

    # --- Resize reference spatially to match target ---
    # Uses bilinear interpolation, no corner alignment, safe for float data
    if (H_ref, W_ref) != (H_tgt, W_tgt):
        ref = F.interpolate(
            ref.to(dtype=torch.float32),
            size=(H_tgt, W_tgt),
            mode="bilinear",
            align_corners=False,
        )

    # Convert to NumPy for histogram matching operations
    ref_np = tensor_to_numpy(ref)
    tgt_np = tensor_to_numpy(tgt)
    out_np = np.empty_like(tgt_np)  # preallocate output array

    # --- Loop over batches and channels ---
    for b in range(B_tgt):
        # If reference has only one batch (B_ref=1), broadcast it to all targets
        rb = b % B_ref

        for c in range(C_tgt):
            ref_ch = ref_np[rb, c]  # reference channel
            tgt_ch = tgt_np[b, c]  # target channel

            # Mask invalid pixels (NaN or Inf)
            mask = np.isfinite(tgt_ch) & np.isfinite(ref_ch)

            if mask.any():
                # Perform per-channel histogram matching
                matched = exposure.match_histograms(tgt_ch[mask], ref_ch[mask])
                out = tgt_ch.copy()
                out[mask] = matched
                out_np[b, c] = out
            else:
                # If no valid pixels, copy target as-is
                out_np[b, c] = tgt_ch

    # Convert back to torch tensor on the original device/dtype
    out = torch.from_numpy(out_np).to(device=device, dtype=dtype)

    # If the original input was 3D, remove the temporary batch dimension
    return out[0] if target.ndim == 3 else out
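The library delegates the per-channel work to skimage.exposure.match_histograms. As a dependency-free illustration of the underlying CDF-matching idea, the sketch below (match_cdf and match_channels are hypothetical names) maps each target value to the reference value at the same cumulative rank and, like the source, skips non-finite pixels. It assumes (C, H, W) arrays that already share a spatial size.

```python
import numpy as np


def match_cdf(source: np.ndarray, template: np.ndarray) -> np.ndarray:
    """Remap source values so that their empirical CDF follows the template's."""
    _, src_idx, src_counts = np.unique(
        source.ravel(), return_inverse=True, return_counts=True
    )
    tmpl_vals, tmpl_counts = np.unique(template.ravel(), return_counts=True)
    src_q = np.cumsum(src_counts) / source.size      # source CDF at each unique value
    tmpl_q = np.cumsum(tmpl_counts) / template.size  # template CDF at each unique value
    mapped = np.interp(src_q, tmpl_q, tmpl_vals)     # invert the template CDF
    return mapped[src_idx].reshape(source.shape)


def match_channels(reference: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Per-channel matching for (C, H, W) arrays; non-finite pixels stay untouched."""
    out = target.astype(np.float64, copy=True)
    for c in range(target.shape[0]):
        mask = np.isfinite(target[c]) & np.isfinite(reference[c])
        if mask.any():
            out[c][mask] = match_cdf(target[c][mask], reference[c][mask])
    return out


rng = np.random.default_rng(42)
reference = rng.normal(0.2, 0.05, size=(2, 8, 8))
target = rng.uniform(0.0, 1.0, size=(2, 8, 8))
target[0, 0, 0] = np.nan
matched = match_channels(reference, target)
print(np.isnan(matched[0, 0, 0]))   # True: the invalid pixel is left as-is
```

Because the mapping is monotone, the target's pixel ordering is preserved; only the value distribution changes.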

moment(reference, target)

Perform moment matching between two multispectral image tensors.

Each channel in the target image is rescaled to match the mean and standard deviation (first and second moments) of the corresponding channel in the reference image. This operation effectively transfers the global radiometric statistics (brightness and contrast) from reference to target.

Parameters

reference : torch.Tensor
    Reference image whose per-channel statistics will be matched (e.g. Sentinel-2), shape (C, H, W).
target : torch.Tensor
    Target image to be adjusted (e.g. SPOT-6), shape (C, H, W). Batched (B, C, H, W) input is not handled per channel, because the loop runs over the first axis.

Returns

torch.Tensor
    Moment-matched image with shape (C, H, W), where each target channel now has the same mean and standard deviation as the corresponding reference channel. A constant target channel has zero standard deviation, so the division can yield NaN.

Source code in opensr_srgan/utils/radiometrics.py
def moment(reference: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Perform **moment matching** between two multispectral image tensors.

    Each channel in the `target` image is rescaled to match the mean and
    standard deviation (first and second moments) of the corresponding channel
    in the `reference` image.
    This operation effectively transfers the global radiometric statistics
    (brightness and contrast) from `reference` to `target`.

    Parameters
    ----------
    reference : torch.Tensor
        Reference image whose per-channel statistics will be matched (e.g. Sentinel-2),
        shape (C, H, W).
    target : torch.Tensor
        Target image to be adjusted (e.g. SPOT-6),
        shape (C, H, W).

    Returns
    -------
    torch.Tensor
        Moment-matched image with shape (C, H, W),
        where each target channel now has the same mean and standard deviation
        as the corresponding reference channel.
    """

    device, dtype = target.device, target.dtype

    # Convert to NumPy arrays for easier numerical processing
    reference_np = tensor_to_numpy(reference)
    target_np = tensor_to_numpy(target)

    matched_channels = []

    # Iterate channel-wise through reference and target
    for ref_ch, tgt_ch in zip(reference_np, target_np):

        # --- Compute per-channel mean and std ---
        ref_mean = np.mean(ref_ch)
        tgt_mean = np.mean(tgt_ch)
        ref_std = np.std(ref_ch)
        tgt_std = np.std(tgt_ch)

        # --- Apply moment matching formula ---
        # Normalize target → scale by reference std → shift by reference mean
        matched_channel = (((tgt_ch - tgt_mean) / tgt_std) * ref_std) + ref_mean
        matched_channels.append(matched_channel)

    matched_np = np.stack(matched_channels, axis=0)

    # Convert back to PyTorch tensor with channel-first format (C, H, W)
    matched = torch.from_numpy(matched_np).to(device=device, dtype=dtype)

    return matched
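The same transfer can be written without the Python loop by reducing over the spatial axes. This NumPy sketch is an assumption-labeled restatement (moment_match is a hypothetical name); its eps floor on the target standard deviation is an addition that the source does not have.

```python
import numpy as np


def moment_match(reference: np.ndarray, target: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Transfer per-channel mean and std from reference to target, both shaped (C, H, W)."""
    axes = (1, 2)  # statistics over each channel's pixels
    ref_mean = reference.mean(axis=axes, keepdims=True)
    ref_std = reference.std(axis=axes, keepdims=True)  # population std, like np.std in the source
    tgt_mean = target.mean(axis=axes, keepdims=True)
    tgt_std = target.std(axis=axes, keepdims=True)
    # eps floor (sketch only): a constant channel maps to the reference mean instead of NaN
    return (target - tgt_mean) / np.maximum(tgt_std, eps) * ref_std + ref_mean


rng = np.random.default_rng(0)
reference = rng.normal(loc=[[[0.1]], [[0.2]], [[0.3]]], scale=0.02, size=(3, 32, 32))
target = rng.normal(loc=0.5, scale=0.2, size=(3, 32, 32))
matched = moment_match(reference, target)
print(matched.mean(axis=(1, 2)))   # close to reference.mean(axis=(1, 2))
```

Broadcasting the (C, 1, 1) statistics gives the same per-channel result as the loop in the source, apart from the eps guard.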

Tensor conversions

Fallback helpers for converting Torch tensors to NumPy arrays.

tensor_to_numpy(tensor)

Convert a PyTorch tensor to a NumPy array safely across environments.

Provides a fallback for PyTorch builds or environments without working NumPy bindings (where tensor.numpy() raises RuntimeError: Numpy is not available). Ensures tensors are detached, moved to CPU, and made contiguous before conversion.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Input tensor to convert. | required |

Returns:

| Type | Description |
|---|---|
| ndarray | NumPy array representation of the tensor. Falls back to tensor.tolist() conversion if direct NumPy bindings are unavailable. |

Raises:

| Type | Description |
|---|---|
| RuntimeError | Re-raises conversion errors not related to missing NumPy bindings. |

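Notes
  • The dtype lookup table is used only by the list fallback, where it restores the original dtype that tolist() discards.
  • When tensor.numpy() succeeds, the array shares memory with the detached CPU tensor (and therefore with the input, if it was already a contiguous CPU tensor); the list fallback always copies.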
Source code in opensr_srgan/utils/tensor_conversions.py
def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a PyTorch tensor to a NumPy array safely across environments.

    Provides a robust fallback for minimal or CPU-only PyTorch builds that lack
    NumPy bindings (where ``tensor.numpy()`` raises ``RuntimeError: Numpy is not available``).
    Ensures tensors are detached, moved to CPU, and contiguous before conversion.

    Args:
        tensor (torch.Tensor): Input tensor to convert.

    Returns:
        numpy.ndarray: NumPy array representation of the tensor.
        Falls back to ``tensor.tolist()`` conversion if direct bindings are unavailable.

    Raises:
        RuntimeError: Re-raises conversion errors not related to missing NumPy bindings.

    Notes:
        - Keeps dtype consistency between PyTorch and NumPy via a lookup table.
        - Returns a view when possible, or a copy when conversion via list fallback is used.
    """

    tensor_cpu = tensor.detach().cpu()
    if not tensor_cpu.is_contiguous():
        tensor_cpu = tensor_cpu.contiguous()

    try:
        return tensor_cpu.numpy()
    except RuntimeError as exc:
        if "Numpy is not available" not in str(exc):
            raise
        dtype = _TORCH_TO_NUMPY_DTYPE.get(tensor_cpu.dtype)
        return np.asarray(tensor_cpu.tolist(), dtype=dtype)
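The reason the fallback needs a dtype table: tolist() returns plain Python scalars (on torch.Tensor as well as on a NumPy array), so NumPy would otherwise infer float64. In this sketch a NumPy array plays the role of the CPU tensor, and DTYPE_BY_NAME is a hypothetical stand-in for _TORCH_TO_NUMPY_DTYPE, whose contents are not shown on this page.

```python
import numpy as np

# Hypothetical stand-in for _TORCH_TO_NUMPY_DTYPE
DTYPE_BY_NAME = {"float32": np.float32, "int16": np.int16, "uint8": np.uint8}

src = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)  # plays the role of a CPU tensor
as_list = src.tolist()   # nested lists of Python floats: the float32 dtype is lost

naive = np.asarray(as_list)                                          # NumPy infers float64
restored = np.asarray(as_list, dtype=DTYPE_BY_NAME[str(src.dtype)])  # dtype restored via lookup

print(naive.dtype, restored.dtype)       # float64 float32
print(np.shares_memory(restored, src))   # False: the list route always copies
```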

Model summarisation

print_model_summary(self)

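Prints a structured summary of the SRGAN configuration: generator and discriminator architecture, super-resolution factor, training schedule (generator pretraining, adversarial ramp, label smoothing), EMA settings when configured, loss functions and weights, parameter counts, and the device.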

Source code in opensr_srgan/utils/model_descriptions.py
def print_model_summary(self):
    """
    Prints a detailed and visually structured summary of the SRGAN configuration.
    Includes architecture info, resolution scale, training parameters, loss weights, and model sizes.
    """

    # --- helpers (millions) ---
    def count_trainable_params(m):
        return sum(p.numel() for p in m.parameters() if p.requires_grad) / 1e6

    def count_total_params(m):
        return sum(p.numel() for p in m.parameters()) / 1e6

    # --- counts ---
    g_trainable = count_trainable_params(self.generator)
    g_total = count_total_params(self.generator)
    d_trainable = count_trainable_params(self.discriminator)
    d_total = count_total_params(self.discriminator)

    # expose the exact names you want to use elsewhere
    g_params = g_trainable  # alias for “trainable G params (M)”
    d_params = d_trainable  # alias for “trainable D params (M)”

    total_trainable = g_trainable + d_trainable
    total_all = g_total + d_total

    # ------------------------------------------------------------------
    # Derive human-readable generator description
    # ------------------------------------------------------------------
    g_type = getattr(self.config.Generator, "model_type", "SRResNet")
    g_type_norm = str(g_type).lower().replace("-", "_")
    block_type = getattr(self.config.Generator, "block_type", None)

    if g_type_norm in {"res", "rcab", "rrdb", "lka"} and block_type is None:
        block_type = g_type_norm
        g_type_norm = "srresnet"

    block_desc_map = {
        "standard": "Residual Blocks with BatchNorm",
        "res": "Residual Blocks without BatchNorm",
        "rcab": "RCAB Blocks with Channel Attention",
        "rrdb": "RRDB Dense Residual Blocks",
        "lka": "LKA Large-Kernel Attention Blocks",
    }

    if g_type_norm in {"srresnet", "sr_resnet"}:
        block_key = "standard" if block_type is None else str(block_type).lower()
        desc = block_desc_map.get(block_key, f"Custom Block Variant: {block_key}")
        g_desc = f"SRResNet ({desc})"
    elif g_type_norm in {"stochastic_gan", "cgan", "conditional_cgan"}:
        g_desc = "Stochastic SRGAN (Latent-modulated Residual Blocks)"
    elif g_type_norm == "esrgan":
        g_desc = "ESRGAN (RRDB Residual-in-Residual Dense Network)"
    else:
        g_desc = f"Custom Generator Type: {g_type}"

    # ------------------------------------------------------------------
    # Resolution info (input → output)
    # ------------------------------------------------------------------
    scale_factor = getattr(
        self.config.Generator, "scaling_factor", getattr(self.generator, "scale", None)
    )
    if scale_factor is not None:
        res_str = f"Super-Resolution Factor: ×{scale_factor}"
    else:
        res_str = "Super-Resolution Factor: Unknown"

    # ------------------------------------------------------------------
    # Retrieve loss weights if available
    # ------------------------------------------------------------------
    loss_cfg = getattr(self.config.Training, "Losses", {})
    content_w = getattr(loss_cfg, "content_loss_weight", 1.0)
    adv_w = getattr(loss_cfg, "adv_loss_beta", 1.0)
    perceptual_w = getattr(loss_cfg, "perceptual_loss_weight", None)
    total_w_str = f"   • Content: {content_w} | Adversarial: {adv_w}" + (
        f" | Perceptual: {perceptual_w}" if perceptual_w is not None else ""
    )

    print("\n" + "=" * 90)
    print("🚀  SRGAN Model Summary")
    print("=" * 90)

    # ------------------------------------------------------------------
    # Generator Info
    # ------------------------------------------------------------------
    print(f"🧩 Generator")
    print(f"   • Architecture:      {g_desc}")
    print(f"   • Resolution:        {res_str}")
    print(f"   • Input Channels:    {self.config.Model.in_bands}")
    feature_channels = getattr(self.config.Generator, "n_channels", None)
    if feature_channels is not None:
        print(f"   • Feature Channels:  {feature_channels}")

    block_count = getattr(
        self.config.Generator,
        "n_blocks",
        getattr(self.generator, "n_blocks", None),
    )
    if block_count is not None:
        print(f"   • Residual Blocks:   {block_count}")

    small_kernel = getattr(self.config.Generator, "small_kernel_size", None)
    large_kernel = getattr(self.config.Generator, "large_kernel_size", None)
    if small_kernel is not None or large_kernel is not None:
        print(
            "   • Kernel Sizes:      small={small}, large={large}".format(
                small=small_kernel if small_kernel is not None else "N/A",
                large=large_kernel if large_kernel is not None else "N/A",
            )
        )
    print(f"   • Params:            {g_params:.2f} M\n")

    # ------------------------------------------------------------------
    # Discriminator Info
    # ------------------------------------------------------------------
    d_type = getattr(self.config.Discriminator, "model_type", "standard")
    d_blocks = getattr(self.config.Discriminator, "n_blocks", None)
    effective_blocks = getattr(
        self.discriminator,
        "n_blocks",
        getattr(self.discriminator, "n_layers", d_blocks),
    )
    base_channels = getattr(self.discriminator, "base_channels", "N/A")
    kernel_size = getattr(self.discriminator, "kernel_size", "N/A")
    fc_size = getattr(self.discriminator, "fc_size", None)

    if d_type == "patchgan":
        d_desc = "PatchGAN"
    elif d_type == "esrgan":
        d_desc = "ESRGAN"
    else:
        d_desc = "SRGAN"

    print(f"🧠 Discriminator")
    print(f"   • Architecture:     {d_desc}")
    if effective_blocks is not None:
        print(f"   • Blocks/Layers:    {effective_blocks}")
    print(f"   • Base Channels:    {base_channels}")
    print(f"   • Kernel Size:      {kernel_size}")
    if fc_size is not None:
        print(f"   • FC Layer Size:    {fc_size}")
    print(f"   • Params:            {d_params:.2f} M\n")

    # ------------------------------------------------------------------
    # Training Setup
    # ------------------------------------------------------------------
    print(f"⚙️  Training Configuration")
    print(f"   • Pretrain Generator: {self.pretrain_g_only}")
    print(f"   • Pretrain Steps:     {self.g_pretrain_steps}")
    print(f"   • Adv. Ramp Steps:    {self.adv_loss_ramp_steps}")
    print(f"   • Label Smoothing:    {self.adv_target < 1.0}")
    print(f"   • Adv. Target Label:  {self.adv_target}\n")

    # ------------------------------------------------------------------
    # EMA Configuration (if present)
    # ------------------------------------------------------------------
    ema_cfg = getattr(getattr(self.config.Training, "EMA", None), "__dict__", None)
    if ema_cfg or hasattr(self.config.Training, "EMA"):
        ema = self.config.Training.EMA
        enabled = getattr(ema, "enabled", False)
        decay = getattr(ema, "decay", None)
        update_after = getattr(ema, "update_after_step", None)
        use_num_updates = getattr(ema, "use_num_updates", None)

        print(f"🔁 Exponential Moving Average (EMA)")
        print(f"   • Enabled:           {enabled}")
        if enabled:
            print(f"   • Decay:             {decay}")
            print(f"   • Update After Step: {update_after}")
            print(f"   • Use Num Updates:   {use_num_updates}")
        print()  # newline for spacing

    # ------------------------------------------------------------------
    # Loss Functions
    # ------------------------------------------------------------------
    print(f"📉 Loss Functions")
    print(f"   • Content Loss:       {type(self.content_loss_criterion).__name__}")
    print(f"   • Adversarial Loss:   {type(self.adversarial_loss_criterion).__name__}")
    print(total_w_str + "\n")

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    print(f"📊 Model Summary")
    print(f"   • Total Params:           {total_all:.2f} M")
    print(f"   • Total Trainable Params: {total_trainable:.2f} M")
    print(
        f"   • Device:                 {self.device if hasattr(self, 'device') else 'Not set'}"
    )
    print("=" * 90 + "\n")
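The generator label depends on how model_type and block_type combine. Restating that branch as a standalone function makes the accepted spellings easy to check; describe_generator is a hypothetical name, and its body copies the mapping shown in the source above.

```python
from typing import Optional

BLOCK_DESC = {
    "standard": "Residual Blocks with BatchNorm",
    "res": "Residual Blocks without BatchNorm",
    "rcab": "RCAB Blocks with Channel Attention",
    "rrdb": "RRDB Dense Residual Blocks",
    "lka": "LKA Large-Kernel Attention Blocks",
}


def describe_generator(model_type: str = "SRResNet", block_type: Optional[str] = None) -> str:
    """Hypothetical standalone copy of the generator-label logic in print_model_summary."""
    g = str(model_type).lower().replace("-", "_")
    # A bare block name given as model_type is treated as SRResNet with that block type
    if g in {"res", "rcab", "rrdb", "lka"} and block_type is None:
        block_type, g = g, "srresnet"
    if g in {"srresnet", "sr_resnet"}:
        key = "standard" if block_type is None else str(block_type).lower()
        desc = BLOCK_DESC.get(key, f"Custom Block Variant: {key}")
        return f"SRResNet ({desc})"
    if g in {"stochastic_gan", "cgan", "conditional_cgan"}:
        return "Stochastic SRGAN (Latent-modulated Residual Blocks)"
    if g == "esrgan":
        return "ESRGAN (RRDB Residual-in-Residual Dense Network)"
    return f"Custom Generator Type: {model_type}"


print(describe_generator())                    # SRResNet (Residual Blocks with BatchNorm)
print(describe_generator("rrdb"))              # SRResNet (RRDB Dense Residual Blocks)
print(describe_generator("SR-ResNet", "lka"))  # SRResNet (LKA Large-Kernel Attention Blocks)
print(describe_generator("esrgan"))            # ESRGAN (RRDB Residual-in-Residual Dense Network)
```

Note that model_type "rrdb" and model_type "esrgan" produce different labels even though both describe RRDB blocks.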

Distributed coordination

_is_global_zero()

Return True if this process is the global rank zero (primary) worker.

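Used to gate singleton side effects (e.g., logging, file writes, directory creation) in both single- and multi-process environments. Supports plain CPU, single-GPU, torch.distributed (torchrun), and SLURM-style launches. Before torch.distributed is initialized, SLURM ranks are recognized only if RANK and WORLD_SIZE are exported, since SLURM itself sets SLURM_PROCID rather than RANK.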

Detection order
  1. Use torch.distributed if available and initialized.
  2. Fall back to environment variables RANK and WORLD_SIZE.
  3. Default to True for non-distributed (single-process) runs.

Returns:

| Type | Description |
|---|---|
| bool | True if this process is rank zero or if distributed training is not active (single process). False otherwise. |

Source code in opensr_srgan/utils/gpu_rank.py
def _is_global_zero() -> bool:
    """Return True if this process is the global rank zero (primary) worker.

    Used to gate singleton side effects (e.g., logging, file writes, directory
    creation) in both single- and multi-process environments. Supports plain CPU,
    single-GPU, torch.distributed (``torchrun``), and SLURM-style launches.

    Detection order:
        1. Use ``torch.distributed`` if available and initialized.
        2. Fall back to environment variables ``RANK`` and ``WORLD_SIZE``.
        3. Default to True for non-distributed (single-process) runs.

    Returns:
        bool: True if this process is rank zero or if distributed training
        is not active (single process). False otherwise.
    """
    # Prefer torch.distributed if available
    try:
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
    except Exception:
        pass

    # Fallback to env vars commonly set by torchrun/SLURM
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return rank == 0 or world_size == 1
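Steps 2 and 3 of the detection order can be exercised without touching the real environment by reading an explicit mapping. The helper is_rank_zero_from_env is hypothetical and mirrors only the environment-variable fallback, not the torch.distributed check.

```python
from typing import Mapping


def is_rank_zero_from_env(env: Mapping[str, str]) -> bool:
    """Hypothetical helper: the RANK / WORLD_SIZE fallback, reading a mapping instead of os.environ."""
    rank = int(env.get("RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    return rank == 0 or world_size == 1


print(is_rank_zero_from_env({}))                                # True: single process
print(is_rank_zero_from_env({"RANK": "0", "WORLD_SIZE": "4"}))  # True: primary worker
print(is_rank_zero_from_env({"RANK": "2", "WORLD_SIZE": "4"}))  # False
print(is_rank_zero_from_env({"SLURM_PROCID": "3"}))             # True: SLURM_PROCID is not read
```

In application code the real function typically wraps a one-off side effect, for example guarding a directory creation so that only one worker performs it.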