Trainer details¶
This page describes the training control flow used by OpenSR-SRGAN on PyTorch Lightning 2+.
Bootstrap sequence¶
- **Validate Lightning version.** `SRGAN_model.setup_lightning()` enforces Lightning >= 2.0.
- **Enable manual optimization.** The model sets `automatic_optimization = False`.
- **Bind training-step helper.** `training_step_PL2` is attached as the active training-step implementation.
- **Build trainer kwargs.** `build_lightning_kwargs()` normalises accelerator/devices and prepares `fit_kwargs` (including `ckpt_path` when resuming).
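As a rough illustration, the bootstrap checks could look like the following sketch. The `check_lightning_version` helper and `BootstrapSketch` class are hypothetical stand-ins for exposition, not the actual OpenSR-SRGAN code:

```python
def check_lightning_version(version: str) -> None:
    # Reject Lightning runtimes older than 2.0, as setup_lightning() does.
    major = int(version.split(".")[0])
    if major < 2:
        raise RuntimeError(f"Lightning >= 2.0 required, found {version}")

class BootstrapSketch:
    # Minimal stand-in for the flags the bootstrap sequence sets.
    def __init__(self, lightning_version: str):
        check_lightning_version(lightning_version)
        # GAN training alternates two optimizers by hand, so Lightning's
        # automatic optimization is disabled.
        self.automatic_optimization = False
        # Bind the Lightning-2 training step as the active implementation.
        self.training_step = self.training_step_PL2

    def training_step_PL2(self, batch, batch_idx):
        raise NotImplementedError
```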
Training-step anatomy¶
`training_step_PL2(self, batch, batch_idx)` always performs manual optimizer control:

```python
opt_d, opt_g = self.optimizers()
pretrain_phase = self._pretrain_check()
self.log("training/pretrain_phase", float(pretrain_phase), sync_dist=True)
```
Pretraining branch¶
When `pretrain_phase` is active:
- Discriminator metrics are logged as zeros (no discriminator optimizer step).
- Generator content loss is computed and logged.
- The generator optimizer performs `zero_grad` -> `manual_backward` -> `step`.
- EMA updates after the generator step when enabled and active.
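The pretraining branch can be sketched as follows. Here `model` is assumed to expose a generator, a content loss, and Lightning's manual-optimization helpers; the attribute names are illustrative, not the library's API:

```python
def pretrain_generator_step(model, batch):
    opt_d, opt_g = model.optimizers()
    lr, hr = batch

    # No discriminator step in this phase; its metrics are logged as zeros.
    model.log("discriminator/loss", 0.0, sync_dist=True)

    # Content-only generator objective.
    sr = model.generator(lr)
    loss = model.content_loss(sr, hr)
    model.log("generator/content_loss", loss, sync_dist=True)

    # Manual optimization: zero_grad -> manual_backward -> step.
    opt_g.zero_grad()
    model.manual_backward(loss)
    opt_g.step()

    # EMA tracks generator weights only after the configured delay.
    if model.ema is not None and model.global_step >= model._ema_update_after_step:
        model.ema.update(model.generator)
```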
Adversarial branch¶
When pretraining is finished:
- Discriminator update
    - Compute `D(hr)` and `D(sr.detach())`.
    - Apply the BCE or Wasserstein objective (plus an optional R1 penalty).
    - Log discriminator losses/probabilities.
    - Run `manual_backward` and `opt_d.step()`.
- Generator update
    - Compute content loss + metrics.
    - Compute the adversarial generator objective from `D(sr)`.
    - Apply the ramped adversarial weight (`training/adv_loss_weight`).
    - Log `generator/content_loss`, `generator/adversarial_loss`, `generator/total_loss`.
    - Run `manual_backward` and `opt_g.step()`.
- EMA updates after the generator step when enabled and active.
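The two alternating updates can be sketched as one function. Helper names such as `model.D`, `model.discriminator_loss`, and `model.adversarial_loss` are illustrative assumptions, not the library's API:

```python
def adversarial_step(model, batch):
    opt_d, opt_g = model.optimizers()
    lr, hr = batch
    sr = model.generator(lr)

    # Discriminator update: score real and detached fake samples, then
    # apply the BCE/Wasserstein objective (optionally plus an R1 penalty).
    d_loss = model.discriminator_loss(model.D(hr), model.D(sr.detach()))
    model.log("discriminator/loss", d_loss, sync_dist=True)
    opt_d.zero_grad()
    model.manual_backward(d_loss)
    opt_d.step()

    # Generator update: content loss plus the ramped adversarial term.
    content = model.content_loss(sr, hr)
    adv = model.adversarial_loss(model.D(sr))
    weight = model._adv_loss_weight()  # logged as training/adv_loss_weight
    total = content + weight * adv
    model.log("generator/content_loss", content, sync_dist=True)
    model.log("generator/adversarial_loss", adv, sync_dist=True)
    model.log("generator/total_loss", total, sync_dist=True)
    opt_g.zero_grad()
    model.manual_backward(total)
    opt_g.step()

    # EMA tracks generator weights after the generator step, when active.
    if model.ema is not None and model.global_step >= model._ema_update_after_step:
        model.ema.update(model.generator)
```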
Resume behavior¶
`Model.continue_training` is passed through `build_lightning_kwargs()` and forwarded to `Trainer.fit()` as `ckpt_path`. This restores optimizer/scheduler state, EMA state, and the global step before training continues.
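A minimal sketch of that forwarding, assuming a helper that assembles `fit_kwargs` (the function name here is hypothetical):

```python
from typing import Optional

def build_fit_kwargs(continue_training: Optional[str]) -> dict:
    # When resuming, Lightning 2.x restores optimizer/scheduler state,
    # callback state (e.g. EMA), and the global step from ckpt_path
    # inside trainer.fit().
    fit_kwargs = {}
    if continue_training:
        fit_kwargs["ckpt_path"] = continue_training
    return fit_kwargs

# usage: trainer.fit(model, datamodule=dm, **build_fit_kwargs(cfg.continue_training))
```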
Runtime checks summary¶
| Check | Source | Purpose |
|---|---|---|
| Lightning >= 2.0 | `SRGAN_model.setup_lightning()` | Reject unsupported runtime versions. |
| Manual optimization enabled | `setup_lightning()` | Ensure GAN optimizer alternation is explicit. |
| Pretraining active? | `_pretrain_check()` | Gate between content-only and adversarial training. |
| Adversarial weight | `_adv_loss_weight()` | Log and apply the current GAN loss multiplier. |
| EMA active? | `self.global_step >= self._ema_update_after_step` | Delay EMA updates until configured step. |
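For intuition, a ramped adversarial weight like the one `_adv_loss_weight()` applies could follow a linear schedule such as this sketch. The ramp boundaries and linear shape are assumptions for illustration; the actual schedule may differ:

```python
def adv_loss_weight(step: int, ramp_start: int, ramp_end: int, target: float) -> float:
    # Zero weight before the ramp, a linear increase during it, and the
    # full target weight once the ramp is complete.
    if step <= ramp_start:
        return 0.0
    if step >= ramp_end:
        return target
    return target * (step - ramp_start) / (ramp_end - ramp_start)
```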
Workflow map¶
See `opensr_srgan/model/training_workflow.txt` for the full text branch map aligned to the current implementation.