trainer (module)
Lightning helpers and convenience routines bundled with gradnet.
Utilities to train a gradnet.GradNet with PyTorch Lightning.
This module provides a thin Lightning wrapper and a convenience function
(fit()) to optimize a GradNet for a fixed number of
updates.
- class gradnet.trainer.LossFn(*args, **kwargs)[source]
Bases: Protocol
Protocol for loss functions used with fit().
Implementations must accept a gradnet.GradNet and may accept arbitrary keyword arguments. They should return either a scalar loss tensor, or a tuple (loss, metrics_dict) where metrics_dict maps metric names to floats/ints/tensors.
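The two accepted return forms can be illustrated with a small, library-free sketch. Everything here is hypothetical: `LossFnSketch`, `split_loss_output`, `loss_only`, and `loss_with_metrics` are illustrative names, not part of gradnet, and plain floats stand in for scalar loss tensors.

```python
from collections.abc import Mapping
from typing import Any, Dict, Protocol, Tuple, Union

Metrics = Dict[str, Any]


class LossFnSketch(Protocol):
    """Hypothetical stand-in for the LossFn protocol (not gradnet's definition)."""

    def __call__(self, gn: Any, **loss_kwargs: Any) -> Union[Any, Tuple[Any, Metrics]]:
        ...


def split_loss_output(out: Any) -> Tuple[Any, Metrics]:
    """Normalise either accepted return form to a (loss, metrics_dict) pair."""
    if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], Mapping):
        return out[0], dict(out[1])
    # A bare loss carries no extra metrics.
    return out, {}


def loss_only(gn: Any, **loss_kwargs: Any) -> float:
    # Form 1: return just the scalar loss (a float stands in for a tensor).
    return float(loss_kwargs.get("offset", 0.0)) + 1.0


def loss_with_metrics(gn: Any, *, target: float = 0.0) -> Tuple[float, Metrics]:
    # Form 2: return (loss, metrics_dict); `target` is an illustrative kwarg.
    err = (2.0 - target) ** 2
    return err, {"abs_err": abs(2.0 - target)}
```

With a real model, the first argument would be the GradNet being optimised and the loss would be a scalar tensor computed from it.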
- class gradnet.trainer.GradNetLightning(*, gn, loss_fn, loss_kwargs=None, optim_cls, optim_kwargs, sched_cls=None, sched_kwargs=None, grad_clip_val=0.0, post_step_renorm=True, monitor_key='loss', compile_model=False)[source]
Bases: LightningModule
LightningModule wrapper around a GradNet and a user loss.
This module performs manual optimisation: it evaluates loss_fn to obtain a scalar loss (and optional metrics), applies gradient clipping (if configured), steps the optimizer, optionally renormalises the model parameters, and logs metrics under monitor_key.
- Parameters:
gn (torch.nn.Module) – Model to optimise. Typically a gradnet.GradNet; any nn.Module is accepted. If the module exposes should_renorm_after_step(), it is queried to decide whether per-step renormalization should run (subject to post_step_renorm=True). If the module exposes renorm_params(), that method is invoked when renormalization is enabled for the current step.
loss_fn (LossFn) – Callable evaluated on every optimisation step as loss_fn(gn, **loss_kwargs). Must return either a scalar loss tensor or a (loss, metrics_dict) tuple.
loss_kwargs (Mapping[str, Any] | None, optional) – Extra keyword arguments forwarded to loss_fn. When passed through fit(), tensors and arrays are coerced to gn’s device/dtype via gradnet.utils._to_like_struct() before being stored here.
optim_cls (type[torch.optim.Optimizer]) – Optimiser class instantiated over gn.parameters().
optim_kwargs (dict, optional) – Arguments passed to optim_cls (e.g., {"lr": 1e-2}).
sched_cls (type | None, optional) – Optional LR scheduler class applied on top of the optimiser.
sched_kwargs (dict | None, optional) – Keyword arguments for sched_cls.
grad_clip_val (float, optional) – Gradient-norm clipping threshold. 0.0 disables clipping.
post_step_renorm (bool, optional) – Master switch for post-step renormalization. When True, the module’s should_renorm_after_step() policy is used if available; otherwise renormalization runs whenever renorm_params() is available.
monitor_key (str, optional) – Metric name under which the primary loss is logged.
compile_model (bool, optional) – Attempt to wrap the model with torch.compile() during setup; fall back silently when compilation fails.
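The renormalization rules described for gn and post_step_renorm can be read as the decision procedure below. This is a minimal sketch of that reading, not the library's implementation: `maybe_renorm` and the two toy model classes are hypothetical names, and only the method names should_renorm_after_step() and renorm_params() come from the documentation.

```python
from typing import Any


def maybe_renorm(gn: Any, post_step_renorm: bool = True) -> bool:
    """Run gn.renorm_params() if the documented rules allow it.

    Returns True when renormalization was performed for this step.
    """
    if not post_step_renorm:
        # Master switch off: never renormalise.
        return False
    renorm = getattr(gn, "renorm_params", None)
    if not callable(renorm):
        # Nothing to call on this module.
        return False
    policy = getattr(gn, "should_renorm_after_step", None)
    if callable(policy) and not policy():
        # The module's own policy vetoes renormalization for this step.
        return False
    renorm()
    return True


class AlwaysRenorm:
    """Toy model exposing only renorm_params()."""

    def __init__(self) -> None:
        self.calls = 0

    def renorm_params(self) -> None:
        self.calls += 1


class PolicyRenorm(AlwaysRenorm):
    """Toy model whose policy allows renormalization only when `enabled`."""

    def __init__(self, enabled: bool) -> None:
        super().__init__()
        self.enabled = enabled

    def should_renorm_after_step(self) -> bool:
        return self.enabled
```

In the wrapper, a check of this kind would run after the optimizer step, so a module without a policy method is renormalised on every step while the master switch is on.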
- gradnet.trainer.fit(*, gn, loss_fn, loss_kwargs=None, num_updates, optim_cls=<class 'torch.optim.adam.Adam'>, optim_kwargs=None, sched_cls=None, sched_kwargs=None, precision='32-true', accelerator='auto', logger=False, log_dir=None, enable_checkpointing=False, checkpoint_dir=None, checkpoint_every_n=None, save_last=False, callbacks=None, max_time=None, grad_clip_val=0.0, post_step_renorm=True, compile_model=False, seed=None, deterministic=None, verbose=True)[source]
Optimise a gradnet.GradNet for a fixed number of updates.
One trainer epoch corresponds to a single optimiser step, so num_updates equals the number of optimisation steps executed. Each step evaluates loss_fn(gn, **loss_kwargs) and drives manual optimisation via GradNetLightning.
Notes
When called from a Jupyter notebook, PyTorch Lightning writes progress and hardware summaries to stderr. Jupyter renders those lines in red by default, but the colour does not indicate an error.
- Parameters:
gn (GradNet) – Network to optimise.
loss_fn (LossFn) – Callable invoked as loss_fn(gn, **loss_kwargs) and returning either a scalar loss tensor or (loss, metrics_dict).
loss_kwargs (Mapping[str, Any] | None, optional) – Extra keyword arguments forwarded to loss_fn. When provided, tensors and arrays are coerced to gn’s device/dtype via gradnet.utils._to_like_struct().
num_updates (int) – Number of optimisation steps to run.
optim_cls (type[torch.optim.Optimizer], optional) – Optimiser class constructed as optim_cls(gn.parameters(), **optim_kwargs).
optim_kwargs (dict | None, optional) – Keyword arguments for the optimiser. Defaults to {"lr": 1e-2} when None.
sched_cls (type | None, optional) – Optional learning-rate scheduler applied to the optimiser.
sched_kwargs (dict | None, optional) – Keyword arguments for sched_cls.
precision (str | int, optional) – Forwarded to pl.Trainer(precision=...) (e.g., "32-true", 16).
accelerator (str, optional) – Passed to pl.Trainer(accelerator=...) ("auto", "cpu", "gpu", etc.).
logger (LightningLoggerBase | bool | None, optional) – Logger configuration forwarded to pl.Trainer. Use True for the default logger (falls back to CSVLogger when TensorBoard is unavailable), False to disable logging, or supply a Lightning logger instance.
log_dir (str | None, optional) – Directory for logs when logger is True (default logging requested). The value is treated as the final directory path (e.g., "./lightning_logs/gradnet"). When None, defaults to "./lightning_logs/gradnet". Ignored when a custom logger instance is provided or logging is disabled.
enable_checkpointing (bool, optional) – Enable the default ModelCheckpoint callback. When True it monitors loss in min mode to keep the best checkpoint, can store the last checkpoint when save_last=True, and optionally adds periodic checkpoints via checkpoint_every_n.
checkpoint_dir (str | None, optional) – Directory used by ModelCheckpoint when checkpointing is enabled.
checkpoint_every_n (int | None, optional) – Save an additional checkpoint every checkpoint_every_n epochs (updates). Provide an integer greater than or equal to 1 to enable periodic saves, or None to disable them.
save_last (bool, optional) – Whether to always save the final checkpoint in addition to the best.
callbacks (list[pl.Callback] | None, optional) – Additional Lightning callbacks to register.
max_time (str | None, optional) – Training time limit forwarded to pl.Trainer(max_time=...).
grad_clip_val (float, optional) – Gradient-norm clipping threshold applied before optimiser steps.
post_step_renorm (bool, optional) – Master switch for post-step renormalization. When enabled, if gn exposes should_renorm_after_step(), that policy decides whether gn.renorm_params() runs after each optimizer step.
compile_model (bool, optional) – Attempt to wrap gn with torch.compile() during setup.
seed (int | None, optional) – When provided, seeds PyTorch Lightning via pl.seed_everything.
deterministic (bool | str | None, optional) – If not None, passed to torch.use_deterministic_algorithms.
verbose (bool, optional) – Show progress via tqdm.auto.tqdm.
- Returns:
The configured trainer and the best checkpoint path (None when checkpointing is disabled).
- Return type:
- Raises:
TypeError – If loss_kwargs is neither None nor a mapping.
Examples
>>> from pytorch_lightning.loggers import TensorBoardLogger
>>> logger = TensorBoardLogger(save_dir="logs", name="demo")
>>> trainer, best_ckpt = fit(
...     gn=model,
...     loss_fn=loss,
...     num_updates=100,
...     logger=logger,
... )
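The documented handling of loss_kwargs and optim_kwargs (a TypeError for a non-mapping loss_kwargs, and a default of {"lr": 1e-2} when optim_kwargs is None) can be sketched as follows. `normalise_fit_kwargs` is a hypothetical helper written for illustration. It is not a gradnet function, it does not perform the device/dtype coercion described above, and treating a loss_kwargs of None as an empty dict is an assumption.

```python
from collections.abc import Mapping
from typing import Any, Dict, Optional, Tuple


def normalise_fit_kwargs(
    loss_kwargs: Optional[Mapping] = None,
    optim_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Validate loss_kwargs and apply the documented optimiser default."""
    if loss_kwargs is None:
        # Assumption: no extra arguments are forwarded to loss_fn.
        loss_kw: Dict[str, Any] = {}
    elif isinstance(loss_kwargs, Mapping):
        loss_kw = dict(loss_kwargs)
    else:
        # Mirrors the documented Raises entry.
        raise TypeError("loss_kwargs must be None or a mapping")

    # Documented default: {"lr": 1e-2} when optim_kwargs is None.
    optim_kw = {"lr": 1e-2} if optim_kwargs is None else dict(optim_kwargs)
    return loss_kw, optim_kw
```

Passing a list of pairs such as [("target", 1.0)] would raise TypeError under this sketch, while {"target": 1.0} is copied through unchanged.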
See also
PyTorch Optimizers (torch.optim), PyTorch LR Schedulers (torch.optim.lr_scheduler), and PyTorch Lightning’s Trainer and Callbacks documentation for accepted values of precision and accelerator and for available callback types.