gradnet.trainer.fit_gradnet
- gradnet.trainer.fit(*, gn, loss_fn, loss_kwargs=None, num_updates, optim_cls=<class 'torch.optim.adam.Adam'>, optim_kwargs=None, sched_cls=None, sched_kwargs=None, precision='32-true', accelerator='auto', logger=False, log_dir=None, enable_checkpointing=False, checkpoint_dir=None, checkpoint_every_n=None, save_last=False, callbacks=None, max_time=None, grad_clip_val=0.0, post_step_renorm=True, compile_model=False, seed=None, deterministic=None, verbose=True)
Optimise a gradnet.GradNet for a fixed number of updates.

One trainer epoch corresponds to a single optimiser step, so num_updates equals the number of optimisation steps executed. Each step evaluates loss_fn(gn, **loss_kwargs) and drives manual optimisation via GradNetLightning.

Notes

When called from a Jupyter notebook, PyTorch Lightning writes progress and hardware summaries to stderr. Jupyter renders those lines in red by default, but the colour does not indicate an error.

- Parameters:
gn (GradNet) – Network to optimise.
loss_fn (LossFn) – Callable invoked as loss_fn(gn, **loss_kwargs) and returning either a scalar loss tensor or (loss, metrics_dict).
loss_kwargs (Mapping[str, Any] | None, optional) – Extra keyword arguments forwarded to loss_fn. When provided, tensors and arrays are coerced to gn’s device/dtype via gradnet.utils._to_like_struct().
num_updates (int) – Number of optimisation steps to run.
optim_cls (type[torch.optim.Optimizer], optional) – Optimiser class constructed as optim_cls(gn.parameters(), **optim_kwargs).
optim_kwargs (dict | None, optional) – Keyword arguments for the optimiser. Defaults to {"lr": 1e-2} when None.
sched_cls (type | None, optional) – Optional learning-rate scheduler applied to the optimiser.
sched_kwargs (dict | None, optional) – Keyword arguments for sched_cls.
precision (str | int, optional) – Forwarded to pl.Trainer(precision=...) (e.g., "32-true", 16).
accelerator (str, optional) – Passed to pl.Trainer(accelerator=...) ("auto", "cpu", "gpu", etc.).
logger (LightningLoggerBase | bool | None, optional) – Logger configuration forwarded to pl.Trainer. Use True for the default logger (falls back to CSVLogger when TensorBoard is unavailable), False to disable logging, or supply a Lightning logger instance.
log_dir (str | None, optional) – When logger is True (request default logging), use this directory for logs. Treats the value as the final directory path (e.g., "./lightning_logs/gradnet"). When None, defaults to "./lightning_logs/gradnet". Ignored when a custom logger instance is provided or logging is disabled.
enable_checkpointing (bool, optional) – Enable the default ModelCheckpoint callback. When True it monitors loss in min mode to keep the best checkpoint, can store the last checkpoint when save_last=True, and optionally adds periodic checkpoints via checkpoint_every_n.
checkpoint_dir (str | None, optional) – Directory used by ModelCheckpoint when checkpointing is enabled.
checkpoint_every_n (int | None, optional) – Save an additional checkpoint every checkpoint_every_n epochs (updates). Provide an integer greater than or equal to 1 to enable periodic saves, or None to disable them.
save_last (bool, optional) – Whether to always save the final checkpoint in addition to the best.
callbacks (list[pl.Callback] | None, optional) – Additional Lightning callbacks to register.
max_time (str | None, optional) – Training time limit forwarded to pl.Trainer(max_time=...).
grad_clip_val (float, optional) – Gradient-norm clipping threshold applied before optimiser steps.
post_step_renorm (bool, optional) – Master switch for post-step renormalization. When enabled, if gn exposes should_renorm_after_step(), that policy decides whether gn.renorm_params() runs after each optimizer step.
compile_model (bool, optional) – Attempt to wrap gn with torch.compile() during setup.
seed (int | None, optional) – When provided, seeds PyTorch Lightning via pl.seed_everything.
deterministic (bool | str | None, optional) – If not None, passed to torch.use_deterministic_algorithms.
verbose (bool, optional) – Show progress via tqdm.auto.tqdm.
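The loss_fn contract described above allows two return shapes. The following is a minimal sketch of both, assuming only PyTorch is installed. ToyNet is a hypothetical stand-in for a GradNet (its weights attribute and the target keyword are illustrative names, not part of the gradnet API).

```python
import torch


class ToyNet(torch.nn.Module):
    """Hypothetical stand-in for a GradNet; it only holds trainable weights."""

    def __init__(self, n: int = 4) -> None:
        super().__init__()
        self.weights = torch.nn.Parameter(torch.ones(n, n))


def loss_only(gn, *, target):
    """First allowed shape: return a scalar loss tensor."""
    return ((gn.weights - target) ** 2).mean()


def loss_with_metrics(gn, *, target):
    """Second allowed shape: return (loss, metrics_dict)."""
    err = gn.weights - target
    loss = (err ** 2).mean()
    # Metrics are detached so they do not keep the autograd graph alive.
    metrics = {"max_abs_err": err.abs().max().detach()}
    return loss, metrics


gn = ToyNet()
# Extra keyword arguments are forwarded as loss_fn(gn, **loss_kwargs).
loss_kwargs = {"target": torch.zeros(4, 4)}

loss = loss_only(gn, **loss_kwargs)
loss2, metrics = loss_with_metrics(gn, **loss_kwargs)
```

Either function could be passed as loss_fn, with loss_kwargs supplied alongside it.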
- Returns:
The configured trainer and the best checkpoint path (None when checkpointing is disabled).
- Return type:
- Raises:
TypeError – If loss_kwargs is neither None nor a mapping.
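The TypeError condition above can be illustrated with a small validation helper. This is a sketch of the documented rule only, not the library's actual implementation. The name check_loss_kwargs and the error message are hypothetical.

```python
from collections.abc import Mapping
from typing import Any, Optional


def check_loss_kwargs(loss_kwargs: Optional[Mapping[str, Any]]) -> dict:
    """Hypothetical helper: accept None or a mapping, reject anything else."""
    if loss_kwargs is None:
        # No extra arguments: loss_fn would be called as loss_fn(gn).
        return {}
    if not isinstance(loss_kwargs, Mapping):
        raise TypeError(
            f"loss_kwargs must be a mapping or None, got {type(loss_kwargs).__name__}"
        )
    # Copy into a plain dict so the caller's object is not mutated later.
    return dict(loss_kwargs)


ok_none = check_loss_kwargs(None)
ok_dict = check_loss_kwargs({"target": 1.0})
```

Under this rule, a list of pairs such as [("target", 1.0)] is rejected even though dict() could convert it, because it is not a mapping.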
Examples
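>>> from pytorch_lightning.loggers import TensorBoardLogger
>>> from gradnet.trainer import fit
>>> # model (a GradNet) and loss (a LossFn) are assumed to be defined elsewhere.
>>> logger = TensorBoardLogger(save_dir="logs", name="demo")
>>> trainer, best_ckpt = fit(
...     gn=model,
...     loss_fn=loss,
...     num_updates=100,
...     logger=logger,
... )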
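A second configuration, sketched here as a plain dictionary of keyword arguments, enables checkpointing. Every key is a parameter documented above. The values (directory name, step counts, seed) are illustrative assumptions.

```python
# Keyword arguments for fit(); the values are illustrative only.
fit_kwargs = {
    "num_updates": 500,
    "optim_kwargs": {"lr": 1e-2},      # same as the documented default
    "enable_checkpointing": True,      # keep the best checkpoint by loss
    "checkpoint_dir": "checkpoints",   # hypothetical directory name
    "checkpoint_every_n": 50,          # periodic save; must be >= 1 or None
    "save_last": True,                 # also keep the final checkpoint
    "seed": 0,
}

# With gradnet installed and model/loss defined as in the example above:
# trainer, best_ckpt = fit(gn=model, loss_fn=loss, **fit_kwargs)
```

With checkpointing enabled, the second return value should be the best checkpoint path instead of None.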
See also
PyTorch Optimizers (torch.optim), PyTorch LR Schedulers (torch.optim.lr_scheduler), and PyTorch Lightning’s Trainer and Callbacks documentation for accepted values of precision and accelerator and for available callback types.
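As a rough illustration of how optim_cls, optim_kwargs, sched_cls and sched_kwargs fit together, the sketch below builds an optimiser and scheduler directly with PyTorch. The optimiser construction and the {"lr": 1e-2} default follow the parameter descriptions above. Constructing the scheduler as sched_cls(optimizer, **sched_kwargs) and stepping it once per update are assumptions about the trainer's behaviour.

```python
import torch

# Stand-in parameters; inside fit() these would come from gn.parameters().
params = [torch.nn.Parameter(torch.zeros(3))]

optim_cls = torch.optim.Adam
optim_kwargs = None
# Documented default when optim_kwargs is None.
resolved_optim_kwargs = {"lr": 1e-2} if optim_kwargs is None else dict(optim_kwargs)
optimizer = optim_cls(params, **resolved_optim_kwargs)

# Assumption: the scheduler is built as sched_cls(optimizer, **sched_kwargs).
sched_cls = torch.optim.lr_scheduler.StepLR
sched_kwargs = {"step_size": 10, "gamma": 0.5}
scheduler = sched_cls(optimizer, **sched_kwargs)

# Assumption: one scheduler step per optimiser update.
for _ in range(10):
    optimizer.step()   # no gradients are set, so parameters stay unchanged
    scheduler.step()

current_lr = optimizer.param_groups[0]["lr"]
```

After ten updates StepLR has halved the learning rate once, so current_lr is 0.005.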