gradnet.trainer.fit

gradnet.trainer.fit(*, gn, loss_fn, loss_kwargs=None, num_updates, optim_cls=<class 'torch.optim.adam.Adam'>, optim_kwargs=None, sched_cls=None, sched_kwargs=None, precision='32-true', accelerator='auto', logger=False, log_dir=None, enable_checkpointing=False, checkpoint_dir=None, checkpoint_every_n=None, save_last=False, callbacks=None, max_time=None, grad_clip_val=0.0, post_step_renorm=True, compile_model=False, seed=None, deterministic=None, verbose=True)

Optimise a gradnet.GradNet for a fixed number of updates.

One trainer epoch corresponds to a single optimiser step, so num_updates is the number of optimisation steps executed (unless max_time stops training earlier). Each step evaluates loss_fn(gn, **loss_kwargs) and applies the update through Lightning's manual optimisation, driven by GradNetLightning.
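The update loop that fit() drives can be pictured with the plain-PyTorch sketch below. It is not the GradNetLightning implementation: train_sketch is a hypothetical name, an nn.Linear stands in for a GradNet, the clipping rule and the renormalisation fallback (when gn has no should_renorm_after_step()) are assumptions, and Lightning's logging, checkpointing and precision handling are left out.

```python
import torch
from torch import nn


def train_sketch(gn, loss_fn, num_updates, loss_kwargs=None,
                 optim_cls=torch.optim.Adam, optim_kwargs=None,
                 grad_clip_val=0.0, post_step_renorm=True):
    """Hypothetical plain-PyTorch loop: one iteration is one update (one trainer epoch)."""
    loss_kwargs = dict(loss_kwargs or {})
    optim_kwargs = {"lr": 1e-2} if optim_kwargs is None else dict(optim_kwargs)
    optimizer = optim_cls(gn.parameters(), **optim_kwargs)
    history = []
    for _ in range(num_updates):
        optimizer.zero_grad()
        out = loss_fn(gn, **loss_kwargs)
        # loss_fn may return a scalar tensor or a (loss, metrics_dict) tuple.
        loss = out[0] if isinstance(out, tuple) else out
        loss.backward()
        if grad_clip_val > 0:
            # Assumption: a positive threshold clips the total gradient norm; 0.0 disables it.
            torch.nn.utils.clip_grad_norm_(gn.parameters(), grad_clip_val)
        optimizer.step()
        if post_step_renorm and hasattr(gn, "renorm_params"):
            policy = getattr(gn, "should_renorm_after_step", None)
            # Assumption: with no policy method, renormalise after every step.
            if policy is None or policy():
                gn.renorm_params()
        history.append(loss.item())  # loss measured before this update
    return history


# Usage with a stand-in model (nn.Linear instead of a GradNet).
torch.manual_seed(0)
model = nn.Linear(3, 1)
x, y = torch.randn(16, 3), torch.randn(16, 1)


def mse(gn, x, y):
    return ((gn(x) - y) ** 2).mean()


history = train_sketch(model, mse, num_updates=50, loss_kwargs={"x": x, "y": y})
print(f"{history[0]:.3f} -> {history[-1]:.3f}")
```

Keeping the loss, backward pass, optional clipping, optimiser step and renormalisation in one iteration mirrors the "one epoch equals one update" rule stated above.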

Notes

When called from a Jupyter notebook, PyTorch Lightning writes progress and hardware summaries to stderr. Jupyter renders those lines in red by default, but the colour does not indicate an error.

Parameters:
  • gn (GradNet) – Network to optimise.

  • loss_fn (LossFn) – Callable invoked as loss_fn(gn, **loss_kwargs) and returning either a scalar loss tensor or (loss, metrics_dict).

  • loss_kwargs (Mapping[str, Any] | None, optional) – Extra keyword arguments forwarded to loss_fn. When provided, tensors and arrays are coerced to gn’s device/dtype via gradnet.utils._to_like_struct().

  • num_updates (int) – Number of optimisation steps to run.

  • optim_cls (type[torch.optim.Optimizer], optional) – Optimiser class constructed as optim_cls(gn.parameters(), **optim_kwargs).

  • optim_kwargs (dict | None, optional) – Keyword arguments for the optimiser. Defaults to {"lr": 1e-2} when None.

  • sched_cls (type | None, optional) – Learning-rate scheduler class applied to the optimiser; None (the default) disables scheduling.

  • sched_kwargs (dict | None, optional) – Keyword arguments for sched_cls.

  • precision (str | int, optional) – Forwarded to pl.Trainer(precision=...) (e.g., "32-true", 16).

  • accelerator (str, optional) – Passed to pl.Trainer(accelerator=...) ("auto", "cpu", "gpu", etc.).

  • logger (LightningLoggerBase | bool | None, optional) – Logger configuration forwarded to pl.Trainer. Use True for the default logger (falls back to CSVLogger when TensorBoard is unavailable), False to disable logging, or supply a Lightning logger instance.

  • log_dir (str | None, optional) – Log directory used when logger is True (default logging requested). The value is treated as the final directory path (e.g., "./lightning_logs/gradnet") and defaults to "./lightning_logs/gradnet" when None. Ignored when a custom logger instance is provided or logging is disabled.

  • enable_checkpointing (bool, optional) – Enable the default ModelCheckpoint callback. When True, it monitors loss in min mode to keep the best checkpoint, additionally stores the last checkpoint when save_last=True, and adds periodic checkpoints when checkpoint_every_n is set.

  • checkpoint_dir (str | None, optional) – Directory used by ModelCheckpoint when checkpointing is enabled.

  • checkpoint_every_n (int | None, optional) – Save an additional checkpoint every checkpoint_every_n epochs (updates). Provide an integer greater than or equal to 1 to enable periodic saves, or None to disable them.

  • save_last (bool, optional) – Whether to always save the final checkpoint in addition to the best.

  • callbacks (list[pl.Callback] | None, optional) – Additional Lightning callbacks to register.

  • max_time (str | None, optional) – Training time limit forwarded to pl.Trainer(max_time=...), e.g. "00:12:00:00" for twelve hours in Lightning's DD:HH:MM:SS format.

  • grad_clip_val (float, optional) – Gradient-norm clipping threshold applied before optimiser steps.

  • post_step_renorm (bool, optional) – Master switch for post-step renormalisation. When True and gn exposes should_renorm_after_step(), that method decides whether gn.renorm_params() runs after each optimiser step.

  • compile_model (bool, optional) – Attempt to wrap gn with torch.compile() during setup.

  • seed (int | None, optional) – When provided, seeds the Python, NumPy and PyTorch random number generators via pl.seed_everything(seed).

  • deterministic (bool | str | None, optional) – If not None, passed to torch.use_deterministic_algorithms.

  • verbose (bool, optional) – Show progress via tqdm.auto.tqdm.
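The optimiser arguments combine as described for optim_cls and optim_kwargs. The sketch below uses an nn.Linear as a stand-in for a GradNet, and it assumes sched_cls is called as sched_cls(optimizer, **sched_kwargs), which is how torch.optim.lr_scheduler classes are constructed but is not stated above.

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import StepLR

gn = nn.Linear(4, 2)  # stand-in for a GradNet

optim_cls, optim_kwargs = torch.optim.SGD, {"lr": 0.1, "momentum": 0.9}
sched_cls, sched_kwargs = StepLR, {"step_size": 10, "gamma": 0.5}

# Documented: the optimiser is built as optim_cls(gn.parameters(), **optim_kwargs).
optimizer = optim_cls(gn.parameters(), **optim_kwargs)
# Assumed: the scheduler is built the standard torch way, sched_cls(optimizer, **sched_kwargs).
scheduler = sched_cls(optimizer, **sched_kwargs)

for _ in range(10):
    optimizer.step()   # no gradients were computed, so parameters stay unchanged
    scheduler.step()   # StepLR multiplies the learning rate by gamma every 10 steps

print(optimizer.param_groups[0]["lr"])  # 0.05
```

Passed to fit, the same choice would read optim_cls=torch.optim.SGD, optim_kwargs={"lr": 0.1, "momentum": 0.9}, sched_cls=StepLR, sched_kwargs={"step_size": 10, "gamma": 0.5}; how often fit steps the scheduler is not specified here.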
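Because one epoch is one update, checkpoint_every_n counts updates. This pure-Python sketch lists which updates would trigger a periodic save, assuming Lightning's epoch-based rule of saving when the 1-based epoch count is a multiple of n; periodic_save_updates is a hypothetical helper, not part of gradnet.

```python
def periodic_save_updates(num_updates, checkpoint_every_n):
    """Hypothetical helper: 1-based update counts at which a periodic checkpoint is written."""
    if checkpoint_every_n is None:
        return []  # periodic saves disabled
    if checkpoint_every_n < 1:
        # Sketch-only validation mirroring the ">= 1 or None" rule above.
        raise ValueError("checkpoint_every_n must be >= 1 or None")
    return [u for u in range(1, num_updates + 1) if u % checkpoint_every_n == 0]


print(periodic_save_updates(100, 25))   # [25, 50, 75, 100]
print(periodic_save_updates(10, None))  # []
```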

Returns:

The configured trainer and the best checkpoint path (None when checkpointing is disabled).

Return type:

tuple[pl.Trainer, str | None]

Raises:

TypeError – If loss_kwargs is neither None nor a mapping.
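A pure-Python sketch of the documented check on loss_kwargs; normalise_loss_kwargs is a hypothetical name, treating None as an empty mapping is an assumption, and the device/dtype coercion performed by gradnet.utils._to_like_struct() is omitted.

```python
from collections.abc import Mapping


def normalise_loss_kwargs(loss_kwargs):
    """Hypothetical sketch: accept None or a mapping, reject anything else."""
    if loss_kwargs is None:
        return {}  # assumption: no extra keyword arguments
    if not isinstance(loss_kwargs, Mapping):
        raise TypeError(
            f"loss_kwargs must be a mapping or None, got {type(loss_kwargs).__name__}"
        )
    return dict(loss_kwargs)


print(normalise_loss_kwargs(None))           # {}
print(normalise_loss_kwargs({"beta": 0.5}))  # {'beta': 0.5}
try:
    normalise_loss_kwargs([("beta", 0.5)])   # a list of pairs is not a mapping
except TypeError as exc:
    print(exc)  # loss_kwargs must be a mapping or None, got list
```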

Examples

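>>> from pytorch_lightning.loggers import TensorBoardLogger
>>> from gradnet.trainer import fit
>>> # model: an existing gradnet.GradNet; loss: a callable following the loss_fn contract
>>> logger = TensorBoardLogger(save_dir="logs", name="demo")
>>> trainer, best_ckpt = fit(
...     gn=model,
...     loss_fn=loss,
...     num_updates=100,
...     logger=logger,
... )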
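The example above assumes that model and loss already exist. Below is a sketch of a callable satisfying the loss_fn contract by returning (loss, metrics_dict). The parameter-distance objective, the target keyword and the param_norm metric are illustrative assumptions, and an nn.Linear stands in for a GradNet so the snippet runs on plain PyTorch.

```python
import torch
from torch import nn


def loss(gn, target=0.0):
    """Example loss_fn: squared distance of every parameter from `target`.

    Returns (loss, metrics_dict), the second return form loss_fn may use.
    """
    params = list(gn.parameters())
    value = sum(((p - target) ** 2).sum() for p in params)
    with torch.no_grad():
        param_norm = torch.sqrt(sum((p ** 2).sum() for p in params))
    return value, {"param_norm": param_norm}


model = nn.Linear(2, 2)  # stand-in; fit() itself expects a GradNet
value, metrics = loss(model, target=1.0)
value.backward()  # the scalar loss is differentiable w.r.t. the parameters
print(value.ndim, sorted(metrics))  # 0 ['param_norm']
```

With fit, this would be passed as loss_fn=loss together with loss_kwargs={"target": 1.0}.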

See also

PyTorch Optimizers (torch.optim), PyTorch LR Schedulers (torch.optim.lr_scheduler), and PyTorch Lightning’s Trainer and Callbacks documentation for accepted values of precision and accelerator and for available callback types.