trainer (module)

Lightning helpers and convenience routines bundled with gradnet.

Utilities to train a gradnet.GradNet with PyTorch Lightning.

This module provides a thin Lightning wrapper and a convenience function (fit()) to optimize a GradNet for a fixed number of updates.

class gradnet.trainer.LossFn(*args, **kwargs)

Bases: Protocol

Protocol for loss functions used with fit().

Implementations must accept a gradnet.GradNet as their first positional argument and may accept arbitrary keyword arguments. They should return either a scalar loss tensor or a tuple (loss, metrics_dict), where metrics_dict maps metric names to floats, ints, or tensors.
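The two return forms allowed by this protocol can be illustrated without PyTorch. The sketch below is an assumption-laden stand-in: plain floats replace both the GradNet and the loss tensor, and LossFnLike, scalar_only, with_metrics and split_loss_output are hypothetical names, not part of gradnet. It shows how a caller might normalise either form into a (loss, metrics) pair.

```python
from typing import Any, Dict, List, Mapping, Protocol, Tuple, Union


class LossFnLike(Protocol):
    """Hypothetical mirror of the LossFn contract: loss_fn(gn, **kwargs)."""

    def __call__(self, gn: Any, **kwargs: Any) -> Union[Any, Tuple[Any, Mapping[str, Any]]]: ...


def scalar_only(gn: Any, **kwargs: Any) -> float:
    # Form 1: a bare scalar loss. A float stands in for gn and for the loss tensor.
    target = kwargs.get("target", 0.0)
    return (gn - target) ** 2


def with_metrics(gn: Any, **kwargs: Any) -> Tuple[float, Dict[str, float]]:
    # Form 2: (loss, metrics_dict) carrying extra values to log.
    target = kwargs.get("target", 0.0)
    return (gn - target) ** 2, {"abs_err": abs(gn - target)}


def split_loss_output(out: Any) -> Tuple[Any, Dict[str, Any]]:
    """Normalise either return form into (loss, metrics)."""
    if isinstance(out, tuple):
        loss, metrics = out  # raises ValueError unless exactly two items
        return loss, dict(metrics)
    return out, {}  # bare loss: no extra metrics


# Both functions accept gn positionally and arbitrary keyword arguments.
loss_fns: List[LossFnLike] = [scalar_only, with_metrics]
results = [split_loss_output(fn(3.0, target=1.0)) for fn in loss_fns]
```

Accepting **kwargs in each implementation keeps it compatible with whatever loss_kwargs a caller forwards.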

class gradnet.trainer.GradNetLightning(*, gn, loss_fn, loss_kwargs=None, optim_cls, optim_kwargs, sched_cls=None, sched_kwargs=None, grad_clip_val=0.0, post_step_renorm=True, monitor_key='loss', compile_model=False)

Bases: LightningModule

LightningModule wrapper around a GradNet and a user loss.

This module performs manual optimisation: on each step it evaluates loss_fn to obtain a scalar loss (and optional metrics), backpropagates the loss, applies gradient clipping (if configured), steps the optimizer, optionally renormalises the model parameters, and logs the loss under monitor_key together with any returned metrics.
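The order of operations in one update can be sketched without Lightning or autograd. This is a framework-free illustration under stated assumptions, not GradNetLightning's implementation: parameters are a list of floats, an explicit grad_fn replaces backpropagation, a plain gradient-descent update replaces optimizer.step(), and an optional renorm callable stands in for renorm_params(). manual_step and its arguments are hypothetical.

```python
import math
from typing import Callable, Dict, List, Optional


def manual_step(
    params: List[float],
    grad_fn: Callable[[List[float]], List[float]],
    loss_fn: Callable[[List[float]], float],
    lr: float,
    grad_clip_val: float = 0.0,
    renorm: Optional[Callable[[List[float]], List[float]]] = None,
) -> Dict[str, float]:
    """One hypothetical update on a list of floats, mutated in place."""
    # 1. Evaluate the loss and its gradient (autograd's job in the real module).
    loss = loss_fn(params)
    grads = grad_fn(params)
    # 2. Clip by global L2 norm; 0.0 disables clipping, as with grad_clip_val.
    norm = math.sqrt(sum(g * g for g in grads))
    if grad_clip_val > 0.0 and norm > grad_clip_val:
        scale = grad_clip_val / norm
        grads = [g * scale for g in grads]
    # 3. Plain gradient-descent step in place of optimizer.step().
    params[:] = [p - lr * g for p, g in zip(params, grads)]
    # 4. Optional post-step renormalisation in place of renorm_params().
    if renorm is not None:
        params[:] = renorm(params)
    # 5. Metrics to log; "loss" mirrors the default monitor_key.
    return {"loss": loss, "grad_norm": norm}


# Usage: minimise sum(x**2) starting from [3, 4], whose gradient norm is 10.
params = [3.0, 4.0]
metrics = manual_step(
    params,
    grad_fn=lambda p: [2.0 * x for x in p],
    loss_fn=lambda p: sum(x * x for x in p),
    lr=0.5,
    grad_clip_val=1.0,                         # the norm of 10 is clipped to 1
    renorm=lambda p: [x / sum(p) for x in p],  # rescale to unit sum
)
```

Clipping happens after the gradient is computed and before the step, and renormalisation only after the step, so the renormalised parameters are what the next loss evaluation sees.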

Parameters:
  • gn (torch.nn.Module) – Model to optimise. Typically a gradnet.GradNet; any nn.Module is accepted. If the module exposes should_renorm_after_step(), it is queried to decide whether per-step renormalization should run (subject to post_step_renorm=True). If the module exposes renorm_params(), that method is invoked when renormalization is enabled for the current step.

  • loss_fn (LossFn) – Callable evaluated on every optimisation step as loss_fn(gn, **loss_kwargs). Must return either a scalar loss tensor or a (loss, metrics_dict) tuple.

  • loss_kwargs (Mapping[str, Any] | None, optional) – Extra keyword arguments forwarded to loss_fn. When passed through fit(), tensors and arrays are coerced to gn’s device/dtype via gradnet.utils._to_like_struct() before being stored here.

  • optim_cls (type[torch.optim.Optimizer]) – Optimiser class instantiated over gn.parameters().

  • optim_kwargs (dict) – Keyword arguments passed to optim_cls (e.g., {"lr": 1e-2}).

  • sched_cls (type | None, optional) – Optional LR scheduler class applied on top of the optimiser.

  • sched_kwargs (dict | None, optional) – Keyword arguments for sched_cls.

  • grad_clip_val (float, optional) – Gradient-norm clipping threshold. 0.0 disables clipping.

  • post_step_renorm (bool, optional) – Master switch for post-step renormalization. When True, the module’s should_renorm_after_step() policy is used if available; otherwise renormalization runs whenever renorm_params() is available.

  • monitor_key (str, optional) – Metric name under which the primary loss is logged.

  • compile_model (bool, optional) – Attempt to wrap the model with torch.compile() during setup; fall back silently when compilation fails.
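The renormalisation rules described under gn and post_step_renorm reduce to a small decision function. The sketch below restates that documented behaviour, not gradnet's code: should_renorm is a hypothetical name, and the two toy classes exist only to exercise the optional hooks via getattr.

```python
from typing import Any


def should_renorm(module: Any, post_step_renorm: bool) -> bool:
    """Decide whether module.renorm_params() should run after an optimiser step."""
    if not post_step_renorm:
        return False  # master switch off: never renormalise, policy not queried
    if not callable(getattr(module, "renorm_params", None)):
        return False  # nothing to call
    policy = getattr(module, "should_renorm_after_step", None)
    if callable(policy):
        return bool(policy())  # the module's own policy decides
    return True  # no policy: renormalise after every step


class FixedNorm:
    """Toy module with a renorm hook but no policy."""

    def renorm_params(self) -> None:
        pass


class EveryOther:
    """Toy module whose policy allows renormalisation on every second step."""

    def __init__(self) -> None:
        self.step = 0

    def should_renorm_after_step(self) -> bool:
        self.step += 1
        return self.step % 2 == 0

    def renorm_params(self) -> None:
        pass
```

With post_step_renorm=True, FixedNorm is renormalised on every step, while EveryOther is renormalised on steps 2, 4, and so on.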

setup(stage=None)

Optionally compile the wrapped model.

training_step(batch, batch_idx)

Run one manual-optimization update using loss_fn.

configure_optimizers()

Build the optimizer (and optional scheduler) config for Lightning.

on_save_checkpoint(checkpoint)

Persist GradNet reconstruction metadata when available.

gradnet.trainer.fit(*, gn, loss_fn, loss_kwargs=None, num_updates, optim_cls=torch.optim.Adam, optim_kwargs=None, sched_cls=None, sched_kwargs=None, precision='32-true', accelerator='auto', logger=False, log_dir=None, enable_checkpointing=False, checkpoint_dir=None, checkpoint_every_n=None, save_last=False, callbacks=None, max_time=None, grad_clip_val=0.0, post_step_renorm=True, compile_model=False, seed=None, deterministic=None, verbose=True)

Optimise a gradnet.GradNet for a fixed number of updates.

One trainer epoch corresponds to a single optimiser step, so num_updates equals the number of optimisation steps executed. Each step evaluates loss_fn(gn, **loss_kwargs) and drives manual optimisation via GradNetLightning.
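Because each epoch is exactly one update, epoch-based settings such as checkpoint_every_n read directly as update counts. The sketch below lists the update counts after which periodic checkpoints would be written, assuming a save at the end of every n-th epoch counted from the first. periodic_save_points is a hypothetical helper, and rejecting values below 1 is this sketch's own choice.

```python
from typing import List, Optional


def periodic_save_points(num_updates: int, checkpoint_every_n: Optional[int]) -> List[int]:
    """Update counts after which a periodic checkpoint would be written (sketch).

    Assumes one update per epoch and a save at the end of every n-th epoch.
    """
    if checkpoint_every_n is None:
        return []  # periodic saves disabled
    if checkpoint_every_n < 1:
        raise ValueError("checkpoint_every_n must be >= 1 or None")
    # Epoch e (0-based) ends after e + 1 updates; save when (e + 1) % n == 0.
    return [e + 1 for e in range(num_updates) if (e + 1) % checkpoint_every_n == 0]
```

For example, num_updates=100 with checkpoint_every_n=25 gives periodic saves after updates 25, 50, 75 and 100.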

Notes

When called from a Jupyter notebook, PyTorch Lightning writes progress and hardware summaries to stderr. Jupyter renders those lines in red by default, but the colour does not indicate an error.

Parameters:
  • gn (GradNet) – Network to optimise.

  • loss_fn (LossFn) – Callable invoked as loss_fn(gn, **loss_kwargs) and returning either a scalar loss tensor or (loss, metrics_dict).

  • loss_kwargs (Mapping[str, Any] | None, optional) – Extra keyword arguments forwarded to loss_fn. When provided, tensors and arrays are coerced to gn’s device/dtype via gradnet.utils._to_like_struct().

  • num_updates (int) – Number of optimisation steps to run.

  • optim_cls (type[torch.optim.Optimizer], optional) – Optimiser class constructed as optim_cls(gn.parameters(), **optim_kwargs).

  • optim_kwargs (dict | None, optional) – Keyword arguments for the optimiser. Defaults to {"lr": 1e-2} when None.

  • sched_cls (type | None, optional) – Optional learning-rate scheduler applied to the optimiser.

  • sched_kwargs (dict | None, optional) – Keyword arguments for sched_cls.

  • precision (str | int, optional) – Forwarded to pl.Trainer(precision=...) (e.g., "32-true", 16).

  • accelerator (str, optional) – Passed to pl.Trainer(accelerator=...) ("auto", "cpu", "gpu", etc.).

  • logger (LightningLoggerBase | bool | None, optional) – Logger configuration forwarded to pl.Trainer. Use True for the default logger (falls back to CSVLogger when TensorBoard is unavailable), False to disable logging, or supply a Lightning logger instance.

  • log_dir (str | None, optional) – Log directory used when logger=True (default logging requested). The value is treated as the final directory path (e.g., "./lightning_logs/gradnet") and defaults to "./lightning_logs/gradnet" when None. Ignored when a custom logger instance is provided or logging is disabled.

  • enable_checkpointing (bool, optional) – Enable the default ModelCheckpoint callback. When True, it monitors the logged loss in "min" mode to keep the best checkpoint, additionally stores the last checkpoint when save_last=True, and adds periodic checkpoints when checkpoint_every_n is set.

  • checkpoint_dir (str | None, optional) – Directory used by ModelCheckpoint when checkpointing is enabled.

  • checkpoint_every_n (int | None, optional) – Save an additional checkpoint every checkpoint_every_n epochs (updates). Provide an integer greater than or equal to 1 to enable periodic saves, or None to disable them.

  • save_last (bool, optional) – Whether to always save the final checkpoint in addition to the best.

  • callbacks (list[pl.Callback] | None, optional) – Additional Lightning callbacks to register.

  • max_time (str | None, optional) – Training time limit forwarded to pl.Trainer(max_time=...).

  • grad_clip_val (float, optional) – Gradient-norm clipping threshold applied before optimiser steps.

  • post_step_renorm (bool, optional) – Master switch for post-step renormalization. When enabled and gn exposes should_renorm_after_step(), that policy decides whether gn.renorm_params() runs after each optimizer step; without such a policy, gn.renorm_params() runs after every step whenever it is available.

  • compile_model (bool, optional) – Attempt to wrap gn with torch.compile() during setup.

  • seed (int | None, optional) – When provided, calls pl.seed_everything(seed), which seeds the Python, NumPy, and PyTorch random number generators.

  • deterministic (bool | str | None, optional) – If not None, passed to torch.use_deterministic_algorithms.

  • verbose (bool, optional) – Show progress via tqdm.auto.tqdm.
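The documented defaults for optim_kwargs and log_dir amount to a small resolution step before the trainer is built. This is a sketch of that documented behaviour only; resolve_fit_defaults, DEFAULT_OPTIM_KWARGS and DEFAULT_LOG_DIR are hypothetical names, and returning None for the log directory is this sketch's way of expressing "ignored".

```python
from typing import Any, Dict, Optional, Tuple

DEFAULT_OPTIM_KWARGS: Dict[str, Any] = {"lr": 1e-2}
DEFAULT_LOG_DIR = "./lightning_logs/gradnet"


def resolve_fit_defaults(
    optim_kwargs: Optional[Dict[str, Any]],
    logger: Any,
    log_dir: Optional[str],
) -> Tuple[Dict[str, Any], Optional[str]]:
    """Return (optimiser kwargs, log directory or None when log_dir is ignored)."""
    # optim_kwargs=None falls back to {"lr": 1e-2}; copy so callers' dicts stay untouched.
    kwargs = dict(DEFAULT_OPTIM_KWARGS) if optim_kwargs is None else dict(optim_kwargs)
    # log_dir only matters when logger is exactly True (default logging requested).
    if logger is True:
        return kwargs, (log_dir if log_dir is not None else DEFAULT_LOG_DIR)
    return kwargs, None
```

Passing a logger instance or False therefore leaves log_dir unused, which matches the parameter description.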

Returns:

The configured trainer and the best checkpoint path (None when checkpointing is disabled).

Return type:

tuple[pl.Trainer, str | None]

Raises:

TypeError – If loss_kwargs is neither None nor a mapping.
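The TypeError above, together with the loss_kwargs coercion described in the parameter list, can be sketched as "reject anything that is not None or a mapping, then walk the structure and convert its leaves". This framework-free sketch is an assumption, not gradnet.utils._to_like_struct(), which is private and may handle other container types; map_struct and prepare_loss_kwargs are hypothetical names, and the convert callable stands in for moving tensors/arrays to gn's device and dtype.

```python
from collections.abc import Mapping
from typing import Any, Callable, Dict


def map_struct(obj: Any, convert: Callable[[Any], Any]) -> Any:
    """Apply convert to every leaf of nested mappings, lists and tuples."""
    if isinstance(obj, Mapping):
        return {k: map_struct(v, convert) for k, v in obj.items()}
    if isinstance(obj, list):
        return [map_struct(v, convert) for v in obj]
    if isinstance(obj, tuple):
        return tuple(map_struct(v, convert) for v in obj)
    return convert(obj)  # leaf: e.g. a tensor/array moved to gn's device/dtype


def prepare_loss_kwargs(loss_kwargs: Any, convert: Callable[[Any], Any]) -> Dict[str, Any]:
    """Reject non-mappings as fit() documents, then convert every leaf."""
    if loss_kwargs is None:
        return {}
    if not isinstance(loss_kwargs, Mapping):
        raise TypeError(
            f"loss_kwargs must be a mapping or None, got {type(loss_kwargs).__name__}"
        )
    return map_struct(loss_kwargs, convert)
```

Note that a list of pairs such as [("target", 1.0)] is rejected even though dict() could build a mapping from it.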

Examples

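>>> from pytorch_lightning.loggers import TensorBoardLogger
>>> from gradnet.trainer import fit
>>> # Assumes `model` is a gradnet.GradNet and `loss` follows the LossFn protocol.
>>> logger = TensorBoardLogger(save_dir="logs", name="demo")
>>> trainer, best_ckpt = fit(
...     gn=model,
...     loss_fn=loss,
...     num_updates=100,
...     logger=logger,
... )
>>> best_ckpt is None  # enable_checkpointing defaults to False
True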

See also

PyTorch Optimizers (torch.optim), PyTorch LR Schedulers (torch.optim.lr_scheduler), and PyTorch Lightning’s Trainer and Callbacks documentation for accepted values of precision and accelerator and for available callback types.