Utility Helpers
Reusable helpers for logging, radiometric preprocessing, tensor conversion, and distributed-safe side effects.
Package exports
Utility helpers for the OpenSR SRGAN project.
plot_tensors(lr, sr, hr, title='Train')
Render LR–SR–HR triplets from a batch as a single PIL image.
Performs percentile-based min–max stretching and clamping to [0, 1] on each
tensor, converts channel-first images to numpy for plotting, and arranges up
to two samples (rows) with three columns (LR, SR, HR) into a matplotlib
figure. The figure is then rasterized and returned as a PIL.Image.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `lr` | `Tensor` | Low-resolution batch tensor of shape `(B, C, H, W)`; values may be in an arbitrary range (stretched internally). | required |
| `sr` | `Tensor` | Super-resolved batch tensor of shape `(B, C, H, W)`. | required |
| `hr` | `Tensor` | High-resolution (target) batch tensor of shape `(B, C, H, W)`. | required |
| `title` | `str` | Figure title placed above the grid. | `'Train'` |
Returns:
| Type | Description |
|---|---|
| `PIL.Image.Image` | RGB image containing a grid with columns LR, SR, and HR and up to two rows (the first two items of the batch). |
Notes
- Only the first two items of the batch are visualized to avoid large figures.
- Supports grayscale, RGB, and RGBA inputs (for RGBA only the first 3 channels are shown). Other channel counts are passed to matplotlib as `(H, W, C)` arrays, which `imshow` rejects.
- This function is side-effect free for tensors (uses `.detach()` and plots copies), and closes the matplotlib figure after rendering.
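The row limit and figure sizing used by `plot_tensors` can be sketched as a small pure-Python helper. `figure_layout` is a hypothetical function written for illustration, not part of the package; it assumes the constants from the source listing (at most two rows, a 15-inch-wide figure, and 15 / 3 = 5 inches of height per row).

```python
def figure_layout(batch_size: int, max_rows: int = 2) -> tuple:
    """Return (rows, figsize) for an LR/SR/HR grid (hypothetical helper)."""
    # Only the first `max_rows` samples are drawn, one sample per row
    rows = min(batch_size, max_rows)
    # Fixed width for the three columns; each row adds fixed_width / 3 inches
    fixed_width = 15
    height = (fixed_width / 3) * rows
    return rows, (fixed_width, height)


for b in (1, 2, 8):
    print(b, figure_layout(b))
```

A batch of 8 therefore produces the same 15 x 10 inch figure as a batch of 2, which keeps logged images a predictable size.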
Source code in opensr_srgan/utils/logging_helpers.py
def plot_tensors(lr, sr, hr, title="Train"):
    """Render LR–SR–HR triplets from a batch as a single PIL image.

    Performs percentile-based min–max stretching and clamping to [0, 1] on each
    tensor, converts channel-first images to numpy for plotting, and arranges up
    to two samples (rows) with three columns (LR, SR, HR) into a matplotlib
    figure. The figure is then rasterized and returned as a `PIL.Image`.

    Args:
        lr (torch.Tensor): Low-resolution batch tensor of shape `(B, C, H, W)`,
            values expected in an arbitrary range (stretched internally).
        sr (torch.Tensor): Super-resolved batch tensor of shape `(B, C, H, W)`.
        hr (torch.Tensor): High-resolution (target) batch tensor of shape `(B, C, H, W)`.
        title (str, optional): Figure title placed above the grid. Defaults to `"Train"`.

    Returns:
        PIL.Image.Image: RGB image containing a grid with columns `[LR | SR | HR]`
        and up to two rows (first two items of the batch).

    Notes:
        - Only the first two items of the batch are visualized to avoid large figures.
        - Supports grayscale, RGB, and RGBA inputs (for RGBA the first 3 channels
          are shown). Other channel counts reach `imshow` as (H, W, C) arrays,
          which matplotlib rejects.
        - This function is side-effect free for tensors (uses `.detach()` and
          plots copies), and closes the matplotlib figure after rendering.
    """
    # --- percentile stretch to roughly [0, 1] ---
    lr = minmax_percentile(lr)
    sr = minmax_percentile(sr)
    hr = minmax_percentile(hr)

    # clamp to [0, 1] (out-of-place, returns new tensors)
    lr, sr, hr = lr.clamp(0, 1), sr.clamp(0, 1), hr.clamp(0, 1)

    # shapes
    B, C, H, W = lr.shape  # (B,C,H,W)

    # Use a grayscale colormap for single-channel images
    # (imshow ignores cmap for RGB/RGBA data)
    if C == 1:
        cmap = "bone"
    else:
        cmap = None

    # limit to max_n samples (rows stacked on top of each other)
    max_n = 2
    if B > max_n:
        lr = lr[:max_n]
        sr = sr[:max_n]
        hr = hr[:max_n]
        B = max_n

    # figure/axes: always 2D array even for B==1
    fixed_width = 15
    variable_height = (15 / 3) * B
    fig, axes = plt.subplots(
        B, 3, figsize=(fixed_width, variable_height), squeeze=False
    )

    # loop over batch
    with torch.no_grad():
        for i in range(B):
            img_lr = _to_numpy_img(lr[i].detach().cpu())
            img_sr = _to_numpy_img(sr[i].detach().cpu())
            img_hr = _to_numpy_img(hr[i].detach().cpu())

            axes[i, 0].imshow(img_lr, cmap=cmap)
            axes[i, 0].set_title("LR")
            axes[i, 0].axis("off")

            axes[i, 1].imshow(img_sr, cmap=cmap)
            axes[i, 1].set_title("SR")
            axes[i, 1].axis("off")

            axes[i, 2].imshow(img_hr, cmap=cmap)
            axes[i, 2].set_title("HR")
            axes[i, 2].axis("off")

    fig.suptitle(title)
    fig.tight_layout()

    # render to PIL
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    pil_image = Image.open(buf).convert("RGB").copy()
    buf.close()
    plt.close(fig)
    return pil_image
print_model_summary(self)
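Prints a structured summary of the SRGAN configuration to stdout: generator and discriminator architecture, super-resolution factor, training schedule (generator pretraining, adversarial ramp, label smoothing), EMA settings when configured, loss functions and weights, and parameter counts in millions.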
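The parameter counts are sums of element counts divided by 1e6. The framework-free sketch below reproduces that arithmetic, assuming each parameter is described by a hypothetical `(shape, requires_grad)` pair in place of iterating over a PyTorch module's `.parameters()`; `count_params_millions` is illustrative only.

```python
from math import prod


def count_params_millions(params, trainable_only=False):
    """Sum element counts of parameter shapes, in millions (hypothetical helper)."""
    total = 0
    for shape, requires_grad in params:
        if trainable_only and not requires_grad:
            continue  # frozen weights are excluded from the trainable count
        total += prod(shape)  # equals Tensor.numel() for a tensor of this shape
    return total / 1e6


# Hypothetical model: a 64->64 3x3 conv with bias, plus a frozen 1000x1000 matrix
gen = [((64, 64, 3, 3), True), ((64,), True), ((1000, 1000), False)]
print(f"{count_params_millions(gen, trainable_only=True):.2f} M trainable")
print(f"{count_params_millions(gen):.2f} M total")
```

Frozen weights explain why the summary can report different "Total Params" and "Total Trainable Params" values.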
Source code in opensr_srgan/utils/model_descriptions.py
def print_model_summary(self):
    """
    Prints a detailed and visually structured summary of the SRGAN configuration.
    Includes architecture info, resolution scale, training parameters, loss weights, and model sizes.
    """

    # --- helpers (millions) ---
    def count_trainable_params(m):
        return sum(p.numel() for p in m.parameters() if p.requires_grad) / 1e6

    def count_total_params(m):
        return sum(p.numel() for p in m.parameters()) / 1e6

    # --- counts ---
    g_trainable = count_trainable_params(self.generator)
    g_total = count_total_params(self.generator)
    d_trainable = count_trainable_params(self.discriminator)
    d_total = count_total_params(self.discriminator)

    # expose the exact names you want to use elsewhere
    g_params = g_trainable  # alias for “trainable G params (M)”
    d_params = d_trainable  # alias for “trainable D params (M)”
    total_trainable = g_trainable + d_trainable
    total_all = g_total + d_total

    # ------------------------------------------------------------------
    # Derive human-readable generator description
    # ------------------------------------------------------------------
    g_type = getattr(self.config.Generator, "model_type", "SRResNet")
    g_type_norm = str(g_type).lower().replace("-", "_")
    block_type = getattr(self.config.Generator, "block_type", None)
    if g_type_norm in {"res", "rcab", "rrdb", "lka"} and block_type is None:
        block_type = g_type_norm
        g_type_norm = "srresnet"

    block_desc_map = {
        "standard": "Residual Blocks with BatchNorm",
        "res": "Residual Blocks without BatchNorm",
        "rcab": "RCAB Blocks with Channel Attention",
        "rrdb": "RRDB Dense Residual Blocks",
        "lka": "LKA Large-Kernel Attention Blocks",
    }

    if g_type_norm in {"srresnet", "sr_resnet"}:
        block_key = "standard" if block_type is None else str(block_type).lower()
        desc = block_desc_map.get(block_key, f"Custom Block Variant: {block_key}")
        g_desc = f"SRResNet ({desc})"
    elif g_type_norm in {"stochastic_gan", "cgan", "conditional_cgan"}:
        g_desc = "Stochastic SRGAN (Latent-modulated Residual Blocks)"
    elif g_type_norm == "esrgan":
        g_desc = "ESRGAN (RRDB Residual-in-Residual Dense Network)"
    else:
        g_desc = f"Custom Generator Type: {g_type}"

    # ------------------------------------------------------------------
    # Resolution info (input → output)
    # ------------------------------------------------------------------
    scale_factor = getattr(
        self.config.Generator, "scaling_factor", getattr(self.generator, "scale", None)
    )
    if scale_factor is not None:
        res_str = f"Super-Resolution Factor: ×{scale_factor}"
    else:
        res_str = "Super-Resolution Factor: Unknown"

    # ------------------------------------------------------------------
    # Retrieve loss weights if available
    # ------------------------------------------------------------------
    loss_cfg = getattr(self.config.Training, "Losses", {})
    content_w = getattr(loss_cfg, "content_loss_weight", 1.0)
    adv_w = getattr(loss_cfg, "adv_loss_beta", 1.0)
    perceptual_w = getattr(loss_cfg, "perceptual_loss_weight", None)
    total_w_str = f" • Content: {content_w} | Adversarial: {adv_w}" + (
        f" | Perceptual: {perceptual_w}" if perceptual_w is not None else ""
    )

    print("\n" + "=" * 90)
    print("🚀 SRGAN Model Summary")
    print("=" * 90)

    # ------------------------------------------------------------------
    # Generator Info
    # ------------------------------------------------------------------
    print(f"🧩 Generator")
    print(f" • Architecture: {g_desc}")
    print(f" • Resolution: {res_str}")
    print(f" • Input Channels: {self.config.Model.in_bands}")
    feature_channels = getattr(self.config.Generator, "n_channels", None)
    if feature_channels is not None:
        print(f" • Feature Channels: {feature_channels}")
    block_count = getattr(
        self.config.Generator,
        "n_blocks",
        getattr(self.generator, "n_blocks", None),
    )
    if block_count is not None:
        print(f" • Residual Blocks: {block_count}")
    small_kernel = getattr(self.config.Generator, "small_kernel_size", None)
    large_kernel = getattr(self.config.Generator, "large_kernel_size", None)
    if small_kernel is not None or large_kernel is not None:
        print(
            " • Kernel Sizes: small={small}, large={large}".format(
                small=small_kernel if small_kernel is not None else "N/A",
                large=large_kernel if large_kernel is not None else "N/A",
            )
        )
    print(f" • Params: {g_params:.2f} M\n")

    # ------------------------------------------------------------------
    # Discriminator Info
    # ------------------------------------------------------------------
    d_type = getattr(self.config.Discriminator, "model_type", "standard")
    d_blocks = getattr(self.config.Discriminator, "n_blocks", None)
    effective_blocks = getattr(
        self.discriminator,
        "n_blocks",
        getattr(self.discriminator, "n_layers", d_blocks),
    )
    base_channels = getattr(self.discriminator, "base_channels", "N/A")
    kernel_size = getattr(self.discriminator, "kernel_size", "N/A")
    fc_size = getattr(self.discriminator, "fc_size", None)
    if d_type == "patchgan":
        d_desc = "PatchGAN"
    elif d_type == "esrgan":
        d_desc = "ESRGAN"
    else:
        d_desc = "SRGAN"
    print(f"🧠 Discriminator")
    print(f" • Architecture: {d_desc}")
    if effective_blocks is not None:
        print(f" • Blocks/Layers: {effective_blocks}")
    print(f" • Base Channels: {base_channels}")
    print(f" • Kernel Size: {kernel_size}")
    if fc_size is not None:
        print(f" • FC Layer Size: {fc_size}")
    print(f" • Params: {d_params:.2f} M\n")

    # ------------------------------------------------------------------
    # Training Setup
    # ------------------------------------------------------------------
    print(f"⚙️ Training Configuration")
    print(f" • Pretrain Generator: {self.pretrain_g_only}")
    print(f" • Pretrain Steps: {self.g_pretrain_steps}")
    print(f" • Adv. Ramp Steps: {self.adv_loss_ramp_steps}")
    print(f" • Label Smoothing: {self.adv_target < 1.0}")
    print(f" • Adv. Target Label: {self.adv_target}\n")

    # ------------------------------------------------------------------
    # EMA Configuration (if present)
    # ------------------------------------------------------------------
    ema_cfg = getattr(getattr(self.config.Training, "EMA", None), "__dict__", None)
    if ema_cfg or hasattr(self.config.Training, "EMA"):
        ema = self.config.Training.EMA
        enabled = getattr(ema, "enabled", False)
        decay = getattr(ema, "decay", None)
        update_after = getattr(ema, "update_after_step", None)
        use_num_updates = getattr(ema, "use_num_updates", None)
        print(f"🔁 Exponential Moving Average (EMA)")
        print(f" • Enabled: {enabled}")
        if enabled:
            print(f" • Decay: {decay}")
            print(f" • Update After Step: {update_after}")
            print(f" • Use Num Updates: {use_num_updates}")
        print()  # newline for spacing

    # ------------------------------------------------------------------
    # Loss Functions
    # ------------------------------------------------------------------
    print(f"📉 Loss Functions")
    print(f" • Content Loss: {type(self.content_loss_criterion).__name__}")
    print(f" • Adversarial Loss: {type(self.adversarial_loss_criterion).__name__}")
    print(total_w_str + "\n")

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    print(f"📊 Model Summary")
    print(f" • Total Params: {total_all:.2f} M")
    print(f" • Total Trainable Params: {total_trainable:.2f} M")
    print(
        f" • Device: {self.device if hasattr(self, 'device') else 'Not set'}"
    )
    print("=" * 90 + "\n")
histogram(reference, target)
Perform per-channel histogram matching of target → reference.
Each channel in the target image is adjusted so that its cumulative distribution function (CDF) matches that of the corresponding channel in the reference image. This preserves overall color/radiometric tone relationships, but aligns the pixel intensity distributions more precisely than simple moment matching.
Supports both single images and batched tensors.
Parameters
reference : torch.Tensor
Reference image or batch, shape (C, H, W) or (B, C, H, W).
Its histogram will be used as the target distribution.
target : torch.Tensor
Target image or batch to be adjusted, shape (C, H, W) or (B, C, H, W).
Must have the same number of channels as reference. If the spatial sizes differ, reference is first resized to the size of target with bilinear interpolation.
Returns
torch.Tensor
Histogram-matched version of target, with the same shape, dtype, and device as target. Pixels that are non-finite (NaN or Inf) in either input are left unchanged. If target is a single image, the result has shape (C, H, W).
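The per-channel step delegates to `skimage.exposure.match_histograms`. The NumPy sketch below shows the standard quantile-mapping formulation of histogram matching for one channel: each target value's empirical CDF position is mapped onto the reference value at the same position. `match_histogram_1ch` is a hypothetical illustration, not the package's code path, and it assumes finite inputs (which the NaN/Inf mask in `histogram` provides).

```python
import numpy as np


def match_histogram_1ch(target: np.ndarray, reference: np.ndarray) -> np.ndarray:
    """Map target values so their empirical CDF follows the reference CDF."""
    t = target.ravel()
    r = reference.ravel()
    # Unique target values, the index of each pixel's value, and value counts
    t_vals, t_idx, t_counts = np.unique(t, return_inverse=True, return_counts=True)
    r_vals, r_counts = np.unique(r, return_counts=True)
    # Empirical CDFs: quantile position of each unique value
    t_cdf = np.cumsum(t_counts) / t.size
    r_cdf = np.cumsum(r_counts) / r.size
    # For each target quantile, look up the reference value at the same quantile
    matched_vals = np.interp(t_cdf, r_cdf, r_vals)
    return matched_vals[t_idx].reshape(target.shape)


rng = np.random.default_rng(0)
tgt = rng.uniform(0.0, 1.0, size=(32, 32))   # flat histogram
ref = rng.normal(0.2, 0.05, size=(32, 32))   # bell-shaped histogram
out = match_histogram_1ch(tgt, ref)
print(out.mean(), ref.mean())  # out now follows ref's value distribution
```

The mapping is monotonic, so the spatial ordering of brightness in the target is preserved while its value distribution takes on the reference's shape.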
Source code in opensr_srgan/utils/radiometrics.py
def histogram(reference: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Perform **per-channel histogram matching** of `target` → `reference`.

    Each channel in the target image is adjusted so that its cumulative
    distribution function (CDF) matches that of the corresponding channel
    in the reference image.
    This preserves overall color/radiometric tone relationships, but aligns
    the pixel intensity distributions more precisely than simple moment matching.
    Supports both single images and batched tensors.

    Parameters
    ----------
    reference : torch.Tensor
        Reference image or batch, shape (C, H, W) or (B, C, H, W).
        Its histogram will be used as the target distribution.
    target : torch.Tensor
        Target image or batch to be adjusted, shape (C, H, W) or (B, C, H, W).
        Must have the same number of channels as `reference`.

    Returns
    -------
    torch.Tensor
        Histogram-matched version of the target, with the same shape and dtype
        as the input. If a single image is given, returns shape (C, H, W).
    """
    # Ensure both inputs have correct dimensionality: either (C,H,W) or (B,C,H,W)
    assert target.ndim in (3, 4) and reference.ndim in (
        3,
        4,
    ), "Expected (C,H,W) or (B,C,H,W) for both reference and target"

    # Save device/dtype for conversion back later
    device, dtype = target.device, target.dtype

    # --- Normalize to batch form (always 4D: B,C,H,W) ---
    # If inputs are 3D (single images), temporarily add batch dimension
    ref = reference.unsqueeze(0) if reference.ndim == 3 else reference
    tgt = target.unsqueeze(0) if target.ndim == 3 else target

    # Extract shapes
    B_ref, C_ref, H_ref, W_ref = ref.shape
    B_tgt, C_tgt, H_tgt, W_tgt = tgt.shape

    # Channel sanity check
    assert C_ref == C_tgt, f"Channel mismatch: reference={C_ref}, target={C_tgt}"

    # --- Resize reference spatially to match target ---
    # Uses bilinear interpolation, no corner alignment, safe for float data
    if (H_ref, W_ref) != (H_tgt, W_tgt):
        ref = F.interpolate(
            ref.to(dtype=torch.float32),
            size=(H_tgt, W_tgt),
            mode="bilinear",
            align_corners=False,
        )

    # Convert to NumPy for histogram matching operations
    ref_np = tensor_to_numpy(ref)
    tgt_np = tensor_to_numpy(tgt)
    out_np = np.empty_like(tgt_np)  # preallocate output array

    # --- Loop over batches and channels ---
    for b in range(B_tgt):
        # Cycle through reference samples; with B_ref == 1 the single
        # reference is reused for every target
        rb = b % B_ref
        for c in range(C_tgt):
            ref_ch = ref_np[rb, c]  # reference channel
            tgt_ch = tgt_np[b, c]  # target channel

            # Mask invalid pixels (NaN or Inf)
            mask = np.isfinite(tgt_ch) & np.isfinite(ref_ch)
            if mask.any():
                # Perform per-channel histogram matching
                matched = exposure.match_histograms(tgt_ch[mask], ref_ch[mask])
                out = tgt_ch.copy()
                out[mask] = matched
                out_np[b, c] = out
            else:
                # If no valid pixels, copy target as-is
                out_np[b, c] = tgt_ch

    # Convert back to torch tensor on the original device/dtype
    out = torch.from_numpy(out_np).to(device=device, dtype=dtype)

    # If the original input was 3D, remove the temporary batch dimension
    return out[0] if target.ndim == 3 else out
normalise_10k(im, stage='norm')
Normalize or denormalize Sentinel-2 data scaled in units of 10,000.
This is the most common scaling for Sentinel-2 L2A reflectance data, where reflectance = DN / 10000.
Parameters
im : torch.Tensor
Input tensor (any shape), expected to contain DN values in roughly [0, 10000].
stage : {"norm", "denorm"}
- "norm" → divide by 10,000 and clamp to [0, 1]
- "denorm" → multiply by 10,000 and clamp to [0, 10000], restoring the original scale
Returns
torch.Tensor
Scaled tensor.
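A NumPy analogue makes the clamping visible. `normalise_10k_np` is a hypothetical illustration of the same arithmetic, not the package function; it assumes NumPy arrays instead of tensors.

```python
import numpy as np


def normalise_10k_np(im, stage="norm"):
    """NumPy analogue of normalise_10k (illustrative sketch)."""
    assert stage in ("norm", "denorm")
    if stage == "norm":
        return np.clip(im / 10000.0, 0.0, 1.0)   # DN -> reflectance in [0, 1]
    return np.clip(im * 10000.0, 0.0, 10000.0)   # reflectance -> DN in [0, 10000]


dn = np.array([0, 2500, 10000, 12000])
refl = normalise_10k_np(dn)                # 12000 saturates at 1.0
back = normalise_10k_np(refl, "denorm")    # 12000 comes back as 10000
print(refl, back)
```

Because of the clamp, the round trip is lossless only for inputs already inside [0, 10000].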
Source code in opensr_srgan/utils/radiometrics.py
def normalise_10k(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """
    Normalize or denormalize Sentinel-2 data scaled in units of 10,000.

    This is the most common scaling for Sentinel-2 L2A reflectance data,
    where reflectance = DN / 10000.

    Parameters
    ----------
    im : torch.Tensor
        Input tensor (any shape), expected to contain DN values ~[0, 10000].
    stage : {"norm", "denorm"}
        - "norm" → divide by 10,000 to map to [0, 1]
        - "denorm" → multiply by 10,000 to restore original scale

    Returns
    -------
    torch.Tensor
        Scaled tensor.
    """
    assert stage in ["norm", "denorm"]

    if stage == "norm":
        im = im / 10000.0
        im = torch.clamp(im, 0, 1)
    else:  # "denorm"
        im = im * 10000.0
        im = torch.clamp(im, 0, 10000)

    return im
normalise_10k_signed(im, stage='norm')
Normalize Sentinel-2 DN values to the symmetric [-1, 1] range.
This helper chains `normalise_10k` ([0, 10000] → [0, 1]) with `zero_one_signed` ([0, 1] → [-1, 1]) so models that expect signed inputs receive values in the conventional [-1, 1] range.
Parameters
im : torch.Tensor
Input tensor containing reflectance-like values expressed as [0, 10000] (for "norm") or [-1, 1] (for "denorm").
stage : {"norm", "denorm"}
Normalisation stage. "norm" maps [0, 10000] to [-1, 1]; "denorm" restores the original [0, 10000] range.
Returns
torch.Tensor
Tensor in the requested range.
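The two chained steps can be traced with a NumPy analogue. `normalise_10k_signed_np` is a hypothetical sketch that inlines both stages (including their clamps); it is not the package function.

```python
import numpy as np


def normalise_10k_signed_np(im, stage="norm"):
    """NumPy analogue of normalise_10k_signed (illustrative sketch)."""
    assert stage in ("norm", "denorm")
    if stage == "norm":
        scaled = np.clip(im / 10000.0, 0.0, 1.0)        # [0, 10000] -> [0, 1]
        return np.clip(scaled * 2.0 - 1.0, -1.0, 1.0)   # [0, 1] -> [-1, 1]
    scaled = np.clip((im + 1.0) / 2.0, 0.0, 1.0)        # [-1, 1] -> [0, 1]
    return np.clip(scaled * 10000.0, 0.0, 10000.0)      # [0, 1] -> [0, 10000]


dn = np.array([0.0, 5000.0, 10000.0])
signed = normalise_10k_signed_np(dn)   # DN 5000 (reflectance 0.5) maps to 0
print(signed, normalise_10k_signed_np(signed, "denorm"))
```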
Source code in opensr_srgan/utils/radiometrics.py
def normalise_10k_signed(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """Normalize Sentinel-2 DN values to the symmetric ``[-1, 1]`` range.

    This helper chains :func:`normalise_10k` (``[0, 10000]`` → ``[0, 1]``) with
    :func:`zero_one_signed` (``[0, 1]`` → ``[-1, 1]``) so models that expect
    signed inputs receive the conventional ``[-1, 1]`` distribution.

    Parameters
    ----------
    im : torch.Tensor
        Input tensor containing reflectance-like values expressed as
        ``[0, 10000]``.
    stage : {"norm", "denorm"}
        Normalisation stage. ``"denorm"`` restores the original ``[0, 10000]``
        range.

    Returns
    -------
    torch.Tensor
        Tensor in the requested range.
    """
    assert stage in ["norm", "denorm"]

    if stage == "norm":
        scaled = normalise_10k(im, stage="norm")
        return zero_one_signed(scaled, stage="norm")

    scaled = zero_one_signed(im, stage="denorm")
    return normalise_10k(scaled, stage="denorm")
normalise_s2(im, stage='norm')
Normalize or denormalize Sentinel-2 image values.
This function linearly maps reflectance in [0, 0.3] onto [-1, 1] for model input (reflectance above 0.3 saturates at 1 after clamping) and reverses the mapping for visualization or saving.
Parameters
im : torch.Tensor
Input image tensor (any shape), typically reflectance-scaled.
stage : {"norm", "denorm"}
- "norm" → map reflectance in [0, 0.3] to [-1, 1]
- "denorm" → map [-1, 1] back to reflectance in [0, 0.3]
Returns
torch.Tensor
The normalized or denormalized image tensor.
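Because the scaling constant is 3.0, `norm` sends reflectance 0.3 (not 1.0) to +1, and `denorm` returns values in [0, 0.3]. The NumPy sketch below reproduces the arithmetic to make that range explicit; `normalise_s2_np` is a hypothetical illustration, not the package function.

```python
import numpy as np

VALUE = 3.0  # same reference constant as normalise_s2


def normalise_s2_np(im, stage="norm"):
    """NumPy analogue of normalise_s2 (illustrative sketch)."""
    assert stage in ("norm", "denorm")
    if stage == "norm":
        # [0, 0.3] -> [0, 1] -> [-1, 1]; reflectance above 0.3 clips to +1
        return np.clip(im * (10.0 / VALUE) * 2 - 1, -1, 1)
    # [-1, 1] -> [0, 1] -> [0, 0.3]
    return np.clip((im + 1) / 2 * (VALUE / 10.0), 0, 1)


refl = np.array([0.0, 0.15, 0.3, 0.6])
x = normalise_s2_np(refl)             # 0.6 saturates at +1
back = normalise_s2_np(x, "denorm")   # 0.6 comes back as 0.3
print(x, back)
```

Bright surfaces such as clouds or snow (reflectance well above 0.3) are therefore indistinguishable after `norm`.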
Source code in opensr_srgan/utils/radiometrics.py
def normalise_s2(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """
    Normalize or denormalize Sentinel-2 image values.

    This function applies a linear scaling that maps reflectance in [0, 0.3]
    to the range [-1, 1] for model input, and reverses it for visualization
    or saving.

    Parameters
    ----------
    im : torch.Tensor
        Input image tensor (any shape), typically reflectance-scaled.
    stage : {"norm", "denorm"}
        - "norm" → normalize image to [-1, 1]
        - "denorm" → reverse normalization back to [0, value/10], i.e. [0, 0.3]

    Returns
    -------
    torch.Tensor
        The normalized or denormalized image tensor.
    """
    assert stage in ["norm", "denorm"]
    value = 3.0  # reference scaling factor

    if stage == "norm":
        # Scale roughly from [0, value/10] → [0, 1] → [-1, 1]
        im = im * (10.0 / value)
        im = (im * 2) - 1
        im = torch.clamp(im, -1, 1)
    else:  # stage == "denorm"
        # Reverse mapping: [-1, 1] → [0, 1] → [0, value/10]
        im = (im + 1) / 2
        im = im * (value / 10.0)
        im = torch.clamp(im, 0, 1)

    return im
zero_one_signed(im, stage='norm')
Convert values between the [0, 1] and [-1, 1] ranges.
Parameters
im : torch.Tensor
Tensor containing values already scaled to [0, 1] (for "norm") or [-1, 1] (for "denorm").
stage : {"norm", "denorm"}
- "norm" → map [0, 1] → [-1, 1] using im * 2 - 1, clamped to [-1, 1].
- "denorm" → map [-1, 1] → [0, 1] using (im + 1) / 2, clamped to [0, 1].
Returns
torch.Tensor
Range-adjusted tensor.
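The clamps mean out-of-range inputs do not survive a round trip. A NumPy analogue (`zero_one_signed_np`, a hypothetical illustration rather than the package function) shows this:

```python
import numpy as np


def zero_one_signed_np(im, stage="norm"):
    """NumPy analogue of zero_one_signed (illustrative sketch)."""
    assert stage in ("norm", "denorm")
    if stage == "norm":
        return np.clip(im * 2.0 - 1.0, -1.0, 1.0)   # [0, 1] -> [-1, 1]
    return np.clip((im + 1.0) / 2.0, 0.0, 1.0)      # [-1, 1] -> [0, 1]


x = np.array([0.0, 0.25, 1.0, 1.5])               # 1.5 is out of range on purpose
signed = zero_one_signed_np(x)
restored = zero_one_signed_np(signed, "denorm")   # 1.5 comes back as 1.0
print(signed, restored)
```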
Source code in opensr_srgan/utils/radiometrics.py
def zero_one_signed(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """Convert values between the ``[0, 1]`` and ``[-1, 1]`` ranges.

    Parameters
    ----------
    im : torch.Tensor
        Tensor containing values already scaled to ``[0, 1]``.
    stage : {"norm", "denorm"}
        - ``"norm"`` → map ``[0, 1]`` → ``[-1, 1]`` using ``im * 2 - 1``.
        - ``"denorm"`` → map ``[-1, 1]`` → ``[0, 1]`` using ``(im + 1) / 2``.

    Returns
    -------
    torch.Tensor
        Range-adjusted tensor.
    """
    assert stage in ["norm", "denorm"]

    if stage == "norm":
        return torch.clamp((im * 2.0) - 1.0, -1.0, 1.0)

    return torch.clamp((im + 1.0) / 2.0, 0.0, 1.0)
Logging helpers
_tensor_to_plot_data(t)
Convert a CPU tensor to a NumPy array suitable for matplotlib visualization.
Ensures contiguous memory layout before conversion and detaches the tensor from the computation graph. Typically used for final image preparation in validation or inference visualizations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `t` | `Tensor` | Input tensor to convert. Can have arbitrary shape. | required |
Returns:
| Type | Description |
|---|---|
| `numpy.ndarray` | NumPy array representation of the tensor, compatible with matplotlib plotting functions. |
Source code in opensr_srgan/utils/logging_helpers.py
def _tensor_to_plot_data(t: torch.Tensor):
    """Convert a CPU tensor to a NumPy array suitable for matplotlib visualization.

    Ensures contiguous memory layout before conversion and detaches the tensor
    from the computation graph. Typically used for final image preparation
    in validation or inference visualizations.

    Args:
        t (torch.Tensor): Input tensor to convert. Can have arbitrary shape.

    Returns:
        numpy.ndarray: NumPy array representation of the tensor, compatible
        with matplotlib plotting functions.
    """
    return tensor_to_numpy(t.contiguous())
_to_numpy_img(t)
Convert a normalized image tensor in (C, H, W) format to a NumPy array.
Values are first clamped to [0, 1]. Grayscale input is returned as a 2-D (H, W) array, RGB and RGBA inputs are reduced to their first three channels and returned as (H, W, 3), and any other channel count is permuted to (H, W, C) unchanged.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `t` | `Tensor` | Input image tensor in channel-first (C, H, W) format, with pixel values expected in the range [0, 1]. | required |
Returns:
| Type | Description |
|---|---|
| `numpy.ndarray` | Array suitable for matplotlib visualization: (H, W) for grayscale, (H, W, 3) for RGB/RGBA, (H, W, C) otherwise. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If the input tensor does not have exactly three dimensions. |
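The channel handling can be reproduced in NumPy to check which output shape each input produces. `to_numpy_img_np` is a hypothetical illustration that assumes NumPy arrays in place of tensors; it is not part of the package.

```python
import numpy as np


def to_numpy_img_np(t: np.ndarray) -> np.ndarray:
    """NumPy analogue of _to_numpy_img's channel handling (illustrative)."""
    if t.ndim != 3:
        raise ValueError(f"Expected (C,H,W), got {t.shape}")
    c = t.shape[0]
    t = np.clip(t, 0, 1)
    if c == 1:
        return t[0]                              # grayscale -> (H, W)
    if c in (3, 4):
        return np.transpose(t[:3], (1, 2, 0))    # RGB/RGBA -> (H, W, 3), alpha dropped
    return np.transpose(t, (1, 2, 0))            # any other C -> (H, W, C)


for c in (1, 3, 4, 6):
    print(c, to_numpy_img_np(np.zeros((c, 8, 10))).shape)
```

matplotlib's `imshow` accepts (H, W), (H, W, 3), and (H, W, 4) arrays, so the last branch yields an array that plotting rejects when C is 2 or greater than 4.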
Source code in opensr_srgan/utils/logging_helpers.py
def _to_numpy_img(t: torch.Tensor):
    """Convert a normalized image tensor in (C, H, W) format to a NumPy array.

    Supports grayscale, RGB, and RGBA images, as well as arbitrary
    multichannel tensors (which are converted by permuting to H×W×C).

    Args:
        t (torch.Tensor): Input image tensor in channel-first (C, H, W) format,
            with pixel values expected in the range [0, 1].

    Returns:
        numpy.ndarray: NumPy array in channel-last format (H, W, C) suitable
        for matplotlib visualization.

    Raises:
        ValueError: If the input tensor does not have exactly three dimensions.
    """
    if t.dim() != 3:
        raise ValueError(f"Expected (C,H,W), got {tuple(t.shape)}")

    C, H, W = t.shape
    t = t.detach().clamp(0, 1)

    if C == 1:
        out = _tensor_to_plot_data(t[0])
        return out  # grayscale

    if C in (3, 4):
        rgb = t[:3]
        out = _tensor_to_plot_data(rgb.permute(1, 2, 0))
        return out

    return _tensor_to_plot_data(t.permute(1, 2, 0))
Radiometric transforms
normalise_s2(im, stage='norm')
¶
Normalize or denormalize Sentinel-2 image values.
This function applies a symmetric scaling to map reflectance-like values to the range [-1, 1] for model input, and reverses it for visualization or saving.
Parameters¶
im : torch.Tensor Input image tensor (any shape), typically reflectance-scaled. stage : {"norm", "denorm"} - "norm" → normalize image to [-1, 1] - "denorm" → reverse normalization back to [0, 1]
Returns¶
torch.Tensor The normalized or denormalized image tensor.
Source code in opensr_srgan/utils/radiometrics.py
def normalise_s2(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
"""
Normalize or denormalize Sentinel-2 image values.
This function applies a symmetric scaling to map reflectance-like values
to the range [-1, 1] for model input, and reverses it for visualization
or saving.
Parameters
----------
im : torch.Tensor
Input image tensor (any shape), typically reflectance-scaled.
stage : {"norm", "denorm"}
- "norm" → normalize image to [-1, 1]
- "denorm" → reverse normalization back to [0, 1]
Returns
-------
torch.Tensor
The normalized or denormalized image tensor.
"""
assert stage in ["norm", "denorm"]
value = 3.0 # reference scaling factor
if stage == "norm":
# Scale roughly from [0, value/10] → [0, 1] → [-1, 1]
im = im * (10.0 / value)
im = (im * 2) - 1
im = torch.clamp(im, -1, 1)
else: # stage == "denorm"
# Reverse mapping: [-1, 1] → [0, 1] → [0, value/10]
im = (im + 1) / 2
im = im * (value / 10.0)
im = torch.clamp(im, 0, 1)
return im
normalise_10k(im, stage='norm')
¶
Normalize or denormalize Sentinel-2 data scaled in units of 10,000.
This is the most common scaling for Sentinel-2 L2A reflectance data, where reflectance = DN / 10000.
Parameters¶
im : torch.Tensor Input tensor (any shape), expected to contain DN values ~[0, 10000]. stage : {"norm", "denorm"} - "norm" → divide by 10,000 to map to [0, 1] - "denorm" → multiply by 10,000 to restore original scale
Returns¶
torch.Tensor Scaled tensor.
Source code in opensr_srgan/utils/radiometrics.py
def normalise_10k(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
"""
Normalize or denormalize Sentinel-2 data scaled in units of 10,000.
This is the most common scaling for Sentinel-2 L2A reflectance data,
where reflectance = DN / 10000.
Parameters
----------
im : torch.Tensor
Input tensor (any shape), expected to contain DN values ~[0, 10000].
stage : {"norm", "denorm"}
- "norm" → divide by 10,000 to map to [0, 1]
- "denorm" → multiply by 10,000 to restore original scale
Returns
-------
torch.Tensor
Scaled tensor.
"""
assert stage in ["norm", "denorm"]
if stage == "norm":
im = im / 10000.0
im = torch.clamp(im, 0, 1)
else: # "denorm"
im = im * 10000.0
im = torch.clamp(im, 0, 10000)
return im
zero_one_signed(im, stage='norm')
¶
Convert values between the [0, 1] and [-1, 1] ranges.
Parameters¶
im : torch.Tensor
Tensor containing values already scaled to [0, 1].
stage : {"norm", "denorm"}
- "norm" → map [0, 1] → [-1, 1] using im * 2 - 1.
- "denorm" → map [-1, 1] → [0, 1] using (im + 1) / 2.
Returns¶
torch.Tensor Range-adjusted tensor.
Source code in opensr_srgan/utils/radiometrics.py
def zero_one_signed(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
"""Convert values between the ``[0, 1]`` and ``[-1, 1]`` ranges.
Parameters
----------
im : torch.Tensor
Tensor containing values already scaled to ``[0, 1]``.
stage : {"norm", "denorm"}
- ``"norm"`` → map ``[0, 1]`` → ``[-1, 1]`` using ``im * 2 - 1``.
- ``"denorm"`` → map ``[-1, 1]`` → ``[0, 1]`` using ``(im + 1) / 2``.
Returns
-------
torch.Tensor
Range-adjusted tensor.
"""
assert stage in ["norm", "denorm"]
if stage == "norm":
return torch.clamp((im * 2.0) - 1.0, -1.0, 1.0)
return torch.clamp((im + 1.0) / 2.0, 0.0, 1.0)
normalise_10k_signed(im, stage='norm')
¶
Normalize Sentinel-2 DN values to the symmetric [-1, 1] range.
This helper chains :func:normalise_10k ([0, 10000] → [0, 1]) with
:func:zero_one_signed ([0, 1] → [-1, 1]) so models that expect
signed inputs receive the conventional [-1, 1] distribution.
Parameters¶
im : torch.Tensor
Input tensor containing reflectance-like values expressed as
[0, 10000].
stage : {"norm", "denorm"}
Normalisation stage. "denorm" restores the original [0, 10000]
range.
Returns¶
torch.Tensor Tensor in the requested range.
Source code in opensr_srgan/utils/radiometrics.py
def normalise_10k_signed(im: torch.Tensor, stage: str = "norm") -> torch.Tensor:
    """Normalize Sentinel-2 DN values to the symmetric ``[-1, 1]`` range.
    This helper chains :func:`normalise_10k` (``[0, 10000]`` → ``[0, 1]``) with
    :func:`zero_one_signed` (``[0, 1]`` → ``[-1, 1]``) so models that expect
    signed inputs receive the conventional ``[-1, 1]`` distribution.
    Parameters
    ----------
    im : torch.Tensor
        Input tensor containing reflectance-like values expressed as
        ``[0, 10000]``.
    stage : {"norm", "denorm"}
        Normalisation stage. ``"denorm"`` restores the original ``[0, 10000]``
        range.
    Returns
    -------
    torch.Tensor
        Tensor in the requested range.
    """
    assert stage in ["norm", "denorm"]
    if stage == "norm":
        scaled = normalise_10k(im, stage="norm")
        return zero_one_signed(scaled, stage="norm")
    scaled = zero_one_signed(im, stage="denorm")
    return normalise_10k(scaled, stage="denorm")
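Tracing a few DN values through the chain makes the mapping concrete. This NumPy sketch assumes that normalise_10k divides by 10 000 and clamps to [0, 1] for "norm", and multiplies by 10 000 and clamps to [0, 10000] for "denorm". The norm-stage behaviour is inferred from its documented [0, 10000] → [0, 1] mapping, and the denorm clamp from the end of its source. All *_np helpers are hypothetical.

```python
import numpy as np


# Hypothetical stand-ins. normalise_10k_np ASSUMES divide-by-10000 plus clamping.
def normalise_10k_np(x, stage="norm"):
    if stage == "norm":
        return np.clip(x / 10000.0, 0.0, 1.0)
    return np.clip(x * 10000.0, 0.0, 10000.0)


def zero_one_signed_np(x, stage="norm"):
    if stage == "norm":
        return np.clip(x * 2.0 - 1.0, -1.0, 1.0)
    return np.clip((x + 1.0) / 2.0, 0.0, 1.0)


def normalise_10k_signed_np(x, stage="norm"):
    # Same order as normalise_10k_signed: scale, then sign; reversed for "denorm".
    if stage == "norm":
        return zero_one_signed_np(normalise_10k_np(x, "norm"), "norm")
    return normalise_10k_np(zero_one_signed_np(x, "denorm"), "denorm")


dn = np.array([0.0, 2500.0, 5000.0, 10000.0, 12000.0])
signed = normalise_10k_signed_np(dn)  # DN 5000 lands at 0.0
restored = normalise_10k_signed_np(signed, "denorm")
```

Under this assumption, DN values above 10 000 do not survive the round trip and come back as 10 000.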
sen2_stretch(im)
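Apply a simple contrast stretch to Sentinel-2 data.
Multiplies values by 10/3 ≈ 3.33 and clamps the result to [0, 1], brightening dark imagery for visualization or augmentation purposes. The clamp to [0, 1] suggests the input is expected to be reflectance already scaled to [0, 1] (for example DN / 10000). Any value at or above 0.3 saturates to 1.
Parameters
im : torch.Tensor Sentinel-2 tensor (any shape), presumably scaled to [0, 1].
Returns
torch.Tensor Contrast-stretched tensor, clamped to [0, 1].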
Source code in opensr_srgan/utils/radiometrics.py
def sen2_stretch(im: torch.Tensor) -> torch.Tensor:
    """
    Apply a simple contrast stretch to Sentinel-2 data.
    Multiplies reflectance values by (10/3) ≈ 3.33 to increase dynamic range
    for visualization or augmentation purposes.
    Parameters
    ----------
    im : torch.Tensor
        Sentinel-2 tensor (any shape).
    Returns
    -------
    torch.Tensor
        Contrast-stretched image tensor.
    """
    stretched = im * (10 / 3.0)
    return torch.clamp(stretched, 0.0, 1.0)
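The effect of the fixed factor is easiest to see on a few values. The NumPy mirror below is a hypothetical stand-in that applies the same multiply-and-clamp: 0.15 becomes 0.5, and anything at or above 0.3 becomes 1.

```python
import numpy as np


def sen2_stretch_np(x: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy mirror of sen2_stretch: scale by 10/3, clamp to [0, 1]."""
    return np.clip(x * (10 / 3.0), 0.0, 1.0)


reflectance = np.array([0.0, 0.06, 0.15, 0.3, 0.5])
stretched = sen2_stretch_np(reflectance)
```

The fixed factor is cheap and deterministic. Unlike minmax_percentile, however, it does not adapt to the scene, so every surface with reflectance of 0.3 or more renders as white.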
minmax_percentile(tensor, pmin=2, pmax=98)
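Perform percentile-based min-max normalization.
Uses quantiles instead of the absolute minimum and maximum to reduce outlier influence: the pmin quantile maps to 0 and the pmax quantile maps to 1. Quantiles are computed over all elements of the tensor, not per channel or per sample. The result is not clamped, so values below the pmin quantile become negative and values above the pmax quantile exceed 1. If both quantiles are equal (e.g. a constant tensor), the division produces NaN or inf.
Parameters
tensor : torch.Tensor
Input tensor (any shape). torch.quantile requires a float or double dtype.
pmin : float
Lower percentile, in percent (default 2).
pmax : float
Upper percentile, in percent (default 98).
Returns
torch.Tensor Tensor rescaled by the percentile range (approximately [0, 1], unclamped).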
Source code in opensr_srgan/utils/radiometrics.py
def minmax_percentile(
    tensor: torch.Tensor, pmin: float = 2, pmax: float = 98
) -> torch.Tensor:
    """
    Perform percentile-based min-max normalization to [0, 1].
    Uses quantiles instead of absolute min/max to reduce outlier influence.
    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor (any shape).
    pmin : float
        Lower percentile (default 2%).
    pmax : float
        Upper percentile (default 98%).
    Returns
    -------
    torch.Tensor
        Tensor scaled to [0, 1] based on percentile range.
    """
    min_val = torch.quantile(tensor, pmin / 100.0)
    max_val = torch.quantile(tensor, pmax / 100.0)
    tensor = (tensor - min_val) / (max_val - min_val)
    return tensor
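Because the result is not clamped, values outside the percentile band leave [0, 1]. This NumPy sketch uses the same formula with np.quantile, whose default linear interpolation matches the torch.quantile default. The clamp and eps options are hypothetical additions for illustration and are not part of the source.

```python
import numpy as np


def minmax_percentile_np(x, pmin=2.0, pmax=98.0, clamp=False, eps=0.0):
    """Hypothetical NumPy mirror of minmax_percentile.

    clamp and eps are illustrative extras that the source does not have.
    """
    lo = np.quantile(x, pmin / 100.0)  # quantiles over ALL elements, like the source
    hi = np.quantile(x, pmax / 100.0)
    out = (x - lo) / (hi - lo + eps)
    return np.clip(out, 0.0, 1.0) if clamp else out


values = np.arange(101, dtype=np.float64)  # 0, 1, ..., 100
raw = minmax_percentile_np(values)  # 2 -> 0.0, 98 -> 1.0; 0 and 100 fall outside
safe = minmax_percentile_np(values, clamp=True)
```

plot_tensors, documented at the top of this page, clamps to [0, 1] after its percentile stretch. That is one way to handle the overshoot shown in raw.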
minmax(img)
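Standard min-max normalization to [0, 1] using the global minimum and maximum of the entire tensor (not per channel).
Parameters
img : torch.Tensor Input tensor (any shape).
Returns
torch.Tensor Min-max normalized tensor. A constant tensor has max - min = 0, so the division is 0/0 and the result is NaN.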
Source code in opensr_srgan/utils/radiometrics.py
def minmax(img: torch.Tensor) -> torch.Tensor:
    """
    Standard min-max normalization to [0, 1] over the entire tensor.
    Parameters
    ----------
    img : torch.Tensor
        Input tensor (any shape).
    Returns
    -------
    torch.Tensor
        Min-max normalized tensor.
    """
    min_val = torch.min(img)
    max_val = torch.max(img)
    normalized_img = (img - min_val) / (max_val - min_val)
    return normalized_img
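A constant input has zero range, so the division is 0/0. The sketch below reproduces that case in NumPy and shows one possible guard. minmax_np and its eps argument are hypothetical.

```python
import numpy as np


def minmax_np(x: np.ndarray, eps: float = 0.0) -> np.ndarray:
    """Hypothetical NumPy mirror of minmax; eps is an illustrative guard, not in the source."""
    lo, hi = x.min(), x.max()  # global extrema, as in the source
    return (x - lo) / (hi - lo + eps)


ramp = minmax_np(np.array([10.0, 15.0, 20.0]))  # [0.0, 0.5, 1.0]

flat = np.full(4, 7.0)
with np.errstate(invalid="ignore"):  # silence the 0/0 RuntimeWarning
    unguarded = minmax_np(flat)  # all NaN
guarded = minmax_np(flat, eps=1e-12)  # all 0.0
```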
histogram(reference, target)
Perform per-channel histogram matching of target → reference.
Each channel in the target image is adjusted so that its cumulative distribution function (CDF) matches that of the corresponding channel in the reference image. Compared with simple moment matching (see moment), this aligns the whole intensity distribution of each channel rather than only its mean and standard deviation.
Supports both single images and batched tensors. If the spatial sizes differ, the reference is first resized bilinearly to the target's height and width. Reference batch items are reused cyclically (index b % B_ref), so a single reference image is applied to every item of a target batch. Pixels that are NaN or infinite in either tensor are excluded from matching and keep their original target values.
Parameters
reference : torch.Tensor
Reference image or batch, shape (C, H, W) or (B, C, H, W).
Its per-channel histograms define the desired distributions.
target : torch.Tensor
Target image or batch to be adjusted, shape (C, H, W) or (B, C, H, W).
Must have the same number of channels as reference.
Returns
torch.Tensor Histogram-matched version of target, with the same shape, dtype, and device as target. If target is a single image, the result has shape (C, H, W).
Source code in opensr_srgan/utils/radiometrics.py
def histogram(reference: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Perform **per-channel histogram matching** of `target` → `reference`.
    Each channel in the target image is adjusted so that its cumulative
    distribution function (CDF) matches that of the corresponding channel
    in the reference image.
    This preserves overall color/radiometric tone relationships, but aligns
    the pixel intensity distributions more precisely than simple moment matching.
    Supports both single images and batched tensors.
    Parameters
    ----------
    reference : torch.Tensor
        Reference image or batch, shape (C, H, W) or (B, C, H, W).
        Its histogram will be used as the target distribution.
    target : torch.Tensor
        Target image or batch to be adjusted, shape (C, H, W) or (B, C, H, W).
        Must have the same number of channels as `reference`.
    Returns
    -------
    torch.Tensor
        Histogram-matched version of the target, with the same shape and dtype
        as the input. If a single image is given, returns shape (C, H, W).
    """
    # Ensure both inputs have correct dimensionality: either (C,H,W) or (B,C,H,W)
    assert target.ndim in (3, 4) and reference.ndim in (
        3,
        4,
    ), "Expected (C,H,W) or (B,C,H,W) for both reference and target"
    # Save device/dtype for conversion back later
    device, dtype = target.device, target.dtype
    # --- Normalize to batch form (always 4D: B,C,H,W) ---
    # If inputs are 3D (single images), temporarily add batch dimension
    ref = reference.unsqueeze(0) if reference.ndim == 3 else reference
    tgt = target.unsqueeze(0) if target.ndim == 3 else target
    # Extract shapes
    B_ref, C_ref, H_ref, W_ref = ref.shape
    B_tgt, C_tgt, H_tgt, W_tgt = tgt.shape
    # Channel sanity check
    assert C_ref == C_tgt, f"Channel mismatch: reference={C_ref}, target={C_tgt}"
    # --- Resize reference spatially to match target ---
    # Uses bilinear interpolation, no corner alignment, safe for float data
    if (H_ref, W_ref) != (H_tgt, W_tgt):
        ref = F.interpolate(
            ref.to(dtype=torch.float32),
            size=(H_tgt, W_tgt),
            mode="bilinear",
            align_corners=False,
        )
    # Convert to NumPy for histogram matching operations
    ref_np = tensor_to_numpy(ref)
    tgt_np = tensor_to_numpy(tgt)
    out_np = np.empty_like(tgt_np)  # preallocate output array
    # --- Loop over batches and channels ---
    for b in range(B_tgt):
        # If reference has only one batch (B_ref=1), broadcast it to all targets
        rb = b % B_ref
        for c in range(C_tgt):
            ref_ch = ref_np[rb, c]  # reference channel
            tgt_ch = tgt_np[b, c]  # target channel
            # Mask invalid pixels (NaN or Inf)
            mask = np.isfinite(tgt_ch) & np.isfinite(ref_ch)
            if mask.any():
                # Perform per-channel histogram matching
                matched = exposure.match_histograms(tgt_ch[mask], ref_ch[mask])
                out = tgt_ch.copy()
                out[mask] = matched
                out_np[b, c] = out
            else:
                # If no valid pixels, copy target as-is
                out_np[b, c] = tgt_ch
    # Convert back to torch tensor on the original device/dtype
    out = torch.from_numpy(out_np).to(device=device, dtype=dtype)
    # If the original input was 3D, remove the temporary batch dimension
    return out[0] if target.ndim == 3 else out
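The per-channel step that the source delegates to skimage.exposure.match_histograms can be sketched as quantile mapping: each target value is replaced by the reference value at the same cumulative frequency. The NumPy functions below are a simplified, hypothetical re-implementation for illustration. They include the finite-pixel mask from the source but are not the skimage code itself, which may differ in detail.

```python
import numpy as np


def match_cdf(source: np.ndarray, reference: np.ndarray) -> np.ndarray:
    """Map each source value to the reference value at the same CDF level."""
    src_vals, src_idx, src_counts = np.unique(
        source.ravel(), return_inverse=True, return_counts=True
    )
    ref_vals, ref_counts = np.unique(reference.ravel(), return_counts=True)
    src_cdf = np.cumsum(src_counts) / source.size
    ref_cdf = np.cumsum(ref_counts) / reference.size
    # Linearly interpolate the reference quantile function at the source CDF levels.
    mapped = np.interp(src_cdf, ref_cdf, ref_vals)
    return mapped[src_idx].reshape(source.shape)


def histogram_np(reference: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Hypothetical per-channel loop for (C, H, W) arrays of equal spatial size."""
    out = target.astype(np.float64)  # astype returns a copy
    for c in range(target.shape[0]):
        mask = np.isfinite(target[c]) & np.isfinite(reference[c])
        if mask.any():  # pixels outside the mask keep their target value
            out[c][mask] = match_cdf(target[c][mask], reference[c][mask])
    return out


tgt = np.array([[[3.0, 1.0], [2.0, 4.0]]])  # one channel, 2x2
ref = np.array([[[10.0, 40.0], [20.0, 30.0]]])
matched = histogram_np(ref, tgt)  # rank order of tgt, values of ref
```

With equally sized, tie-free inputs, the mapping simply hands out the sorted reference values in the target's rank order.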
moment(reference, target)
Perform moment matching between two multispectral image tensors.
Each channel in the target image is rescaled to match the mean and
standard deviation (first and second moments) of the corresponding channel
in the reference image.
This operation effectively transfers the global radiometric statistics
(brightness and contrast) from reference to target.
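Parameters
reference : torch.Tensor
Reference image whose per-channel statistics will be matched (e.g. Sentinel-2), shape (C, H, W).
target : torch.Tensor
Target image to be adjusted (e.g. SPOT-6), shape (C, H, W). Only per-channel statistics are compared, so target need not have the same spatial size as reference. Channels are paired with zip, so both tensors should have the same number of channels. A mismatch is not checked, and any extra channels are silently dropped.
Returns
torch.Tensor Moment-matched image with shape (C, H, W), on target's device and with target's dtype. Each target channel now has the same mean and standard deviation as the corresponding reference channel. A constant or near-constant target channel has a standard deviation at or near zero, which makes the division unstable (NaN, inf, or extreme values).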
Source code in opensr_srgan/utils/radiometrics.py
def moment(reference: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Perform **moment matching** between two multispectral image tensors.
    Each channel in the `target` image is rescaled to match the mean and
    standard deviation (first and second moments) of the corresponding channel
    in the `reference` image.
    This operation effectively transfers the global radiometric statistics
    (brightness and contrast) from `reference` to `target`.
    Parameters
    ----------
    reference : torch.Tensor
        Reference image whose per-channel statistics will be matched (e.g. Sentinel-2),
        shape (C, H, W).
    target : torch.Tensor
        Target image to be adjusted (e.g. SPOT-6),
        shape (C, H, W).
    reference_amount : float, optional
        Currently unused. Can later be used to control blending strength between
        the original target and the moment-matched output (0–1).
    Returns
    -------
    torch.Tensor
        Moment-matched image with shape (C, H, W),
        where each target channel now has the same mean and standard deviation
        as the corresponding reference channel.
    """
    device, dtype = target.device, target.dtype
    # Convert to NumPy arrays for easier numerical processing
    reference_np = tensor_to_numpy(reference)
    target_np = tensor_to_numpy(target)
    matched_channels = []
    # Iterate channel-wise through reference and target
    for ref_ch, tgt_ch in zip(reference_np, target_np):
        # --- Compute per-channel mean and std ---
        ref_mean = np.mean(ref_ch)
        tgt_mean = np.mean(tgt_ch)
        ref_std = np.std(ref_ch)
        tgt_std = np.std(tgt_ch)
        # --- Apply moment matching formula ---
        # Normalize target → scale by reference std → shift by reference mean
        matched_channel = (((tgt_ch - tgt_mean) / tgt_std) * ref_std) + ref_mean
        matched_channels.append(matched_channel)
    matched_np = np.stack(matched_channels, axis=0)
    # Convert back to PyTorch tensor with channel-first format (C, H, W)
    matched = torch.from_numpy(matched_np).to(device=device, dtype=dtype)
    return matched
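The per-channel formula, (t - mean_t) / std_t * std_r + mean_r, can be checked directly. The NumPy sketch below is a hypothetical version of the loop in moment. It adds two things the source does not have: an explicit channel-count check (the source silently truncates via zip) and an eps guard for constant channels.

```python
import numpy as np


def moment_np(reference: np.ndarray, target: np.ndarray, eps: float = 0.0) -> np.ndarray:
    """Hypothetical NumPy mirror of moment for (C, H, W) arrays; eps is an illustrative guard."""
    if reference.shape[0] != target.shape[0]:
        raise ValueError("reference and target need the same number of channels")
    out = []
    for ref_ch, tgt_ch in zip(reference, target):
        z = (tgt_ch - tgt_ch.mean()) / (tgt_ch.std() + eps)  # standardize target channel
        out.append(z * ref_ch.std() + ref_ch.mean())  # rescale and shift to reference stats
    return np.stack(out, axis=0)


rng = np.random.default_rng(0)
reference = rng.normal(loc=0.2, scale=0.05, size=(3, 32, 32))
target = rng.normal(loc=0.5, scale=0.20, size=(3, 64, 64))  # different H, W is fine
matched = moment_np(reference, target)
```

Both sketches use np.std with its default ddof=0, as the source does, so the matched standard deviations agree exactly with the population statistics of reference.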
Tensor conversions
Fallback helpers for converting Torch tensors to NumPy arrays.
tensor_to_numpy(tensor)
Convert a PyTorch tensor to a NumPy array safely across environments.
Provides a robust fallback for minimal or CPU-only PyTorch builds that lack
NumPy bindings (where tensor.numpy() raises RuntimeError: Numpy is not available).
Ensures tensors are detached, moved to CPU, and contiguous before conversion.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Input tensor to convert. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | numpy.ndarray: NumPy array representation of the tensor. Falls back to tensor.tolist() conversion if direct bindings are unavailable. |
Raises:
| Type | Description |
|---|---|
| RuntimeError | Re-raises conversion errors not related to missing NumPy bindings. |
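Notes
- In the fallback path, the torch dtype is mapped to a NumPy dtype through a lookup table (_TORCH_TO_NUMPY_DTYPE). Dtypes missing from the table are passed as None, which lets NumPy infer the dtype.
- The direct path returns an array that shares memory with the CPU tensor. For an input that is already on the CPU and contiguous, in-place changes to the array therefore also modify the original tensor. The list fallback always returns a copy.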
Source code in opensr_srgan/utils/tensor_conversions.py
def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a PyTorch tensor to a NumPy array safely across environments.
    Provides a robust fallback for minimal or CPU-only PyTorch builds that lack
    NumPy bindings (where ``tensor.numpy()`` raises ``RuntimeError: Numpy is not available``).
    Ensures tensors are detached, moved to CPU, and contiguous before conversion.
    Args:
        tensor (torch.Tensor): Input tensor to convert.
    Returns:
        numpy.ndarray: NumPy array representation of the tensor.
            Falls back to ``tensor.tolist()`` conversion if direct bindings are unavailable.
    Raises:
        RuntimeError: Re-raises conversion errors not related to missing NumPy bindings.
    Notes:
        - Keeps dtype consistency between PyTorch and NumPy via a lookup table.
        - Returns a view when possible, or a copy when conversion via list fallback is used.
    """
    tensor_cpu = tensor.detach().cpu()
    if not tensor_cpu.is_contiguous():
        tensor_cpu = tensor_cpu.contiguous()
    try:
        return tensor_cpu.numpy()
    except RuntimeError as exc:
        if "Numpy is not available" not in str(exc):
            raise
        dtype = _TORCH_TO_NUMPY_DTYPE.get(tensor_cpu.dtype)
        return np.asarray(tensor_cpu.tolist(), dtype=dtype)
Model summarisation
print_model_summary(self)
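Print a structured summary of the SRGAN configuration to stdout. The summary covers generator and discriminator architecture, the super-resolution factor, and the training schedule (generator pretraining, adversarial ramp, label smoothing). It also reports EMA settings when an EMA section is configured, the loss functions and weights, parameter counts in millions, and the device. The function takes self and reads the model's generator, discriminator, config, loss criteria, and training-schedule attributes, so it is meant to be called on an SRGAN model instance.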
Source code in opensr_srgan/utils/model_descriptions.py
def print_model_summary(self):
    """
    Prints a detailed and visually structured summary of the SRGAN configuration.
    Includes architecture info, resolution scale, training parameters, loss weights, and model sizes.
    """
    # --- helpers (millions) ---
    def count_trainable_params(m):
        return sum(p.numel() for p in m.parameters() if p.requires_grad) / 1e6
    def count_total_params(m):
        return sum(p.numel() for p in m.parameters()) / 1e6
    # --- counts ---
    g_trainable = count_trainable_params(self.generator)
    g_total = count_total_params(self.generator)
    d_trainable = count_trainable_params(self.discriminator)
    d_total = count_total_params(self.discriminator)
    # expose the exact names you want to use elsewhere
    g_params = g_trainable  # alias for “trainable G params (M)”
    d_params = d_trainable  # alias for “trainable D params (M)”
    total_trainable = g_trainable + d_trainable
    total_all = g_total + d_total
    # ------------------------------------------------------------------
    # Derive human-readable generator description
    # ------------------------------------------------------------------
    g_type = getattr(self.config.Generator, "model_type", "SRResNet")
    g_type_norm = str(g_type).lower().replace("-", "_")
    block_type = getattr(self.config.Generator, "block_type", None)
    if g_type_norm in {"res", "rcab", "rrdb", "lka"} and block_type is None:
        block_type = g_type_norm
        g_type_norm = "srresnet"
    block_desc_map = {
        "standard": "Residual Blocks with BatchNorm",
        "res": "Residual Blocks without BatchNorm",
        "rcab": "RCAB Blocks with Channel Attention",
        "rrdb": "RRDB Dense Residual Blocks",
        "lka": "LKA Large-Kernel Attention Blocks",
    }
    if g_type_norm in {"srresnet", "sr_resnet"}:
        block_key = "standard" if block_type is None else str(block_type).lower()
        desc = block_desc_map.get(block_key, f"Custom Block Variant: {block_key}")
        g_desc = f"SRResNet ({desc})"
    elif g_type_norm in {"stochastic_gan", "cgan", "conditional_cgan"}:
        g_desc = "Stochastic SRGAN (Latent-modulated Residual Blocks)"
    elif g_type_norm == "esrgan":
        g_desc = "ESRGAN (RRDB Residual-in-Residual Dense Network)"
    else:
        g_desc = f"Custom Generator Type: {g_type}"
    # ------------------------------------------------------------------
    # Resolution info (input → output)
    # ------------------------------------------------------------------
    scale_factor = getattr(
        self.config.Generator, "scaling_factor", getattr(self.generator, "scale", None)
    )
    if scale_factor is not None:
        res_str = f"Super-Resolution Factor: ×{scale_factor}"
    else:
        res_str = "Super-Resolution Factor: Unknown"
    # ------------------------------------------------------------------
    # Retrieve loss weights if available
    # ------------------------------------------------------------------
    loss_cfg = getattr(self.config.Training, "Losses", {})
    content_w = getattr(loss_cfg, "content_loss_weight", 1.0)
    adv_w = getattr(loss_cfg, "adv_loss_beta", 1.0)
    perceptual_w = getattr(loss_cfg, "perceptual_loss_weight", None)
    total_w_str = f" • Content: {content_w} | Adversarial: {adv_w}" + (
        f" | Perceptual: {perceptual_w}" if perceptual_w is not None else ""
    )
    print("\n" + "=" * 90)
    print("🚀 SRGAN Model Summary")
    print("=" * 90)
    # ------------------------------------------------------------------
    # Generator Info
    # ------------------------------------------------------------------
    print(f"🧩 Generator")
    print(f" • Architecture: {g_desc}")
    print(f" • Resolution: {res_str}")
    print(f" • Input Channels: {self.config.Model.in_bands}")
    feature_channels = getattr(self.config.Generator, "n_channels", None)
    if feature_channels is not None:
        print(f" • Feature Channels: {feature_channels}")
    block_count = getattr(
        self.config.Generator,
        "n_blocks",
        getattr(self.generator, "n_blocks", None),
    )
    if block_count is not None:
        print(f" • Residual Blocks: {block_count}")
    small_kernel = getattr(self.config.Generator, "small_kernel_size", None)
    large_kernel = getattr(self.config.Generator, "large_kernel_size", None)
    if small_kernel is not None or large_kernel is not None:
        print(
            " • Kernel Sizes: small={small}, large={large}".format(
                small=small_kernel if small_kernel is not None else "N/A",
                large=large_kernel if large_kernel is not None else "N/A",
            )
        )
    print(f" • Params: {g_params:.2f} M\n")
    # ------------------------------------------------------------------
    # Discriminator Info
    # ------------------------------------------------------------------
    d_type = getattr(self.config.Discriminator, "model_type", "standard")
    d_blocks = getattr(self.config.Discriminator, "n_blocks", None)
    effective_blocks = getattr(
        self.discriminator,
        "n_blocks",
        getattr(self.discriminator, "n_layers", d_blocks),
    )
    base_channels = getattr(self.discriminator, "base_channels", "N/A")
    kernel_size = getattr(self.discriminator, "kernel_size", "N/A")
    fc_size = getattr(self.discriminator, "fc_size", None)
    if d_type == "patchgan":
        d_desc = "PatchGAN"
    elif d_type == "esrgan":
        d_desc = "ESRGAN"
    else:
        d_desc = "SRGAN"
    print(f"🧠 Discriminator")
    print(f" • Architecture: {d_desc}")
    if effective_blocks is not None:
        print(f" • Blocks/Layers: {effective_blocks}")
    print(f" • Base Channels: {base_channels}")
    print(f" • Kernel Size: {kernel_size}")
    if fc_size is not None:
        print(f" • FC Layer Size: {fc_size}")
    print(f" • Params: {d_params:.2f} M\n")
    # ------------------------------------------------------------------
    # Training Setup
    # ------------------------------------------------------------------
    print(f"⚙️ Training Configuration")
    print(f" • Pretrain Generator: {self.pretrain_g_only}")
    print(f" • Pretrain Steps: {self.g_pretrain_steps}")
    print(f" • Adv. Ramp Steps: {self.adv_loss_ramp_steps}")
    print(f" • Label Smoothing: {self.adv_target < 1.0}")
    print(f" • Adv. Target Label: {self.adv_target}\n")
    # ------------------------------------------------------------------
    # EMA Configuration (if present)
    # ------------------------------------------------------------------
    ema_cfg = getattr(getattr(self.config.Training, "EMA", None), "__dict__", None)
    if ema_cfg or hasattr(self.config.Training, "EMA"):
        ema = self.config.Training.EMA
        enabled = getattr(ema, "enabled", False)
        decay = getattr(ema, "decay", None)
        update_after = getattr(ema, "update_after_step", None)
        use_num_updates = getattr(ema, "use_num_updates", None)
        print(f"🔁 Exponential Moving Average (EMA)")
        print(f" • Enabled: {enabled}")
        if enabled:
            print(f" • Decay: {decay}")
            print(f" • Update After Step: {update_after}")
            print(f" • Use Num Updates: {use_num_updates}")
        print()  # newline for spacing
    # ------------------------------------------------------------------
    # Loss Functions
    # ------------------------------------------------------------------
    print(f"📉 Loss Functions")
    print(f" • Content Loss: {type(self.content_loss_criterion).__name__}")
    print(f" • Adversarial Loss: {type(self.adversarial_loss_criterion).__name__}")
    print(total_w_str + "\n")
    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    print(f"📊 Model Summary")
    print(f" • Total Params: {total_all:.2f} M")
    print(f" • Total Trainable Params: {total_trainable:.2f} M")
    print(
        f" • Device: {self.device if hasattr(self, 'device') else 'Not set'}"
    )
    print("=" * 90 + "\n")
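The generator-description branch is the least obvious part of the summary: a block name given as model_type (for example "rcab") is treated as SRResNet with that block type. The standalone function below reproduces that mapping so it can be checked without building a model. describe_generator is a hypothetical name, not part of the package.

```python
from typing import Optional

_BLOCK_DESC = {
    "standard": "Residual Blocks with BatchNorm",
    "res": "Residual Blocks without BatchNorm",
    "rcab": "RCAB Blocks with Channel Attention",
    "rrdb": "RRDB Dense Residual Blocks",
    "lka": "LKA Large-Kernel Attention Blocks",
}


def describe_generator(model_type: str = "SRResNet", block_type: Optional[str] = None) -> str:
    """Hypothetical extraction of the generator-description logic in print_model_summary."""
    norm = str(model_type).lower().replace("-", "_")
    # A bare block name as model_type means "SRResNet with that block".
    if norm in {"res", "rcab", "rrdb", "lka"} and block_type is None:
        block_type, norm = norm, "srresnet"
    if norm in {"srresnet", "sr_resnet"}:
        key = "standard" if block_type is None else str(block_type).lower()
        return "SRResNet ({})".format(_BLOCK_DESC.get(key, f"Custom Block Variant: {key}"))
    if norm in {"stochastic_gan", "cgan", "conditional_cgan"}:
        return "Stochastic SRGAN (Latent-modulated Residual Blocks)"
    if norm == "esrgan":
        return "ESRGAN (RRDB Residual-in-Residual Dense Network)"
    return f"Custom Generator Type: {model_type}"
```

The generator type is lower-cased before comparison, but the discriminator branch compares model_type exactly against "patchgan" and "esrgan". A config value such as "PatchGAN" is therefore reported as SRGAN in the summary.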
Distributed coordination
_is_global_zero()
Return True if this process is the global rank zero (primary) worker.
Used to gate singleton side effects (e.g., logging, file writes, directory
creation) in both single- and multi-process environments. Supports plain CPU,
single-GPU, torch.distributed (torchrun), and SLURM-style launches. The environment-variable fallback reads only RANK and WORLD_SIZE (not SLURM_PROCID). Under SLURM, ranks are therefore detected correctly only when torch.distributed is initialized or the launcher exports those variables.
Detection order
1. Use torch.distributed if available and initialized.
2. Fall back to the environment variables RANK and WORLD_SIZE.
3. Default to True for non-distributed (single-process) runs.
Returns:
| Type | Description |
|---|---|
| bool | True if this process is rank zero or if distributed training is not active (single process). False otherwise. |
Source code in opensr_srgan/utils/gpu_rank.py
def _is_global_zero() -> bool:
    """Return True if this process is the global rank zero (primary) worker.
    Used to gate singleton side effects (e.g., logging, file writes, directory
    creation) in both single- and multi-process environments. Supports plain CPU,
    single-GPU, torch.distributed (``torchrun``), and SLURM-style launches.
    Detection order:
        1. Use ``torch.distributed`` if available and initialized.
        2. Fall back to environment variables ``RANK`` and ``WORLD_SIZE``.
        3. Default to True for non-distributed (single-process) runs.
    Returns:
        bool: True if this process is rank zero or if distributed training
            is not active (single process). False otherwise.
    """
    # Prefer torch.distributed if available
    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
    except Exception:
        pass
    # Fallback to env vars commonly set by torchrun/SLURM
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return rank == 0 or world_size == 1
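A typical use is to wrap a side effect so that only one process performs it. The sketch below reimplements only the environment-variable fallback as a pure function and uses it to gate a directory creation. The function takes a mapping instead of reading os.environ, which makes it easy to test. rank_zero_from_env and make_run_dir are hypothetical names.

```python
import os
import tempfile
from typing import Mapping


def rank_zero_from_env(env: Mapping[str, str]) -> bool:
    """Hypothetical mirror of the RANK/WORLD_SIZE fallback in _is_global_zero."""
    rank = int(env.get("RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    return rank == 0 or world_size == 1


def make_run_dir(path: str, env: Mapping[str, str]) -> bool:
    """Ensure path exists on the primary process only; return whether this process did it."""
    if not rank_zero_from_env(env):
        return False
    os.makedirs(path, exist_ok=True)
    return True


base = tempfile.mkdtemp()
created = make_run_dir(os.path.join(base, "logs"), {"RANK": "0", "WORLD_SIZE": "4"})
skipped = make_run_dir(os.path.join(base, "other"), {"RANK": "3", "WORLD_SIZE": "4"})
```

In a real multi-process run, the other ranks usually need to wait until rank zero has finished (for example with torch.distributed.barrier()) before they use the directory.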