"""Utilities to train a :class:`gradnet.GradNet` with PyTorch Lightning.
This module provides a thin Lightning wrapper and a convenience function
(:func:`fit`) to optimize a ``GradNet`` for a fixed number of
updates.
"""
from __future__ import annotations
from typing import Dict, Optional, Tuple, Union, Mapping, Any, Protocol
import logging
import os
import warnings
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# Set before importing pytorch_lightning so that any import-time read of the
# variable would already see it.
# NOTE(review): "LIGHTING_USE_RICH" looks like a misspelling of "LIGHTNING";
# confirm which environment variable (if any) the installed Lightning version
# actually reads before relying on this to disable rich output.
os.environ.setdefault("LIGHTING_USE_RICH", "0")
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.logger import Logger as LightningLoggerBase
from pytorch_lightning.callbacks import Callback

# Warning category silenced around the ``tqdm.auto`` import further below.
try:
    from tqdm import TqdmWarning  # type: ignore[attr-defined]
except (ImportError, AttributeError):
    # fallback when tqdm lacks TqdmWarning; note that with this fallback *all*
    # warnings are ignored while ``tqdm.auto`` is being imported below.
    TqdmWarning = Warning

try:  # PL >= 1.6-ish
    from pytorch_lightning.utilities.warnings import PossibleUserWarning
except Exception:  # Fallback for older PL where it's just a UserWarning
    PossibleUserWarning = UserWarning

# Process-wide (import-time) filters for two Lightning hints that are expected
# noise here: ``fit`` always trains from an in-memory, single-item DataLoader
# with ``num_workers=0``, and the GPU hint presumably fires when a GPU exists
# but the run is configured for CPU.
for warning_pattern in (
    r"The 'train_dataloader' does not have many workers.*",
    r"GPU available but not used.*",
):
    warnings.filterwarnings("ignore", message=warning_pattern, category=PossibleUserWarning)

# Prefer the notebook-aware progress bar when the stack supports it; any
# TqdmWarning raised while importing ``tqdm.auto`` is suppressed locally.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=TqdmWarning)
    try:
        from tqdm.auto import tqdm
    except Exception:
        from tqdm import tqdm  # noqa: F401 # CLI fallback without warnings
from .utils import _to_like_struct
from .gradnet import GradNet
[docs]
class LossFn(Protocol):
    """Structural type describing the loss callables accepted by :func:`fit`.

    A conforming callable receives the :class:`gradnet.GradNet` being trained
    as its first positional argument, together with any keyword arguments
    forwarded through ``loss_kwargs``. It must return one of:

    * a scalar loss tensor, or
    * a ``(loss, metrics_dict)`` pair, where ``metrics_dict`` maps metric
      names to floats, ints, or tensors.
    """

    def __call__(
        self,
        model: GradNet,
        **loss_kwargs: Any,
    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, float | int | torch.Tensor]]:
        ...
class _OneItem(Dataset):
"""A trivial dataset that always yields a single empty batch.
Used to drive the Lightning training loop with one update per epoch
without relying on external data.
"""
def __len__(self): return 1
def __getitem__(self, idx): return {}
[docs]
class GradNetLightning(pl.LightningModule):
    """LightningModule wrapper around a ``GradNet`` and a user loss.

    This module performs manual optimisation. On every step it:

    * evaluates ``loss_fn`` to obtain a scalar loss (and optional metrics);
    * applies gradient clipping (if configured);
    * steps the optimizer and, when one is configured, the LR scheduler;
    * optionally renormalises the model parameters;
    * logs metrics under ``monitor_key``.

    Parameters
    ----------
    gn : torch.nn.Module
        Model to optimise. Typically a :class:`gradnet.GradNet`; any ``nn.Module``
        is accepted. If the module exposes ``should_renorm_after_step()``,
        it is queried to decide whether per-step renormalization should run
        (subject to ``post_step_renorm=True``). If the module exposes
        ``renorm_params()``, that method is invoked when renormalization is
        enabled for the current step.
    loss_fn : LossFn
        Callable evaluated on every optimisation step as
        ``loss_fn(gn, **loss_kwargs)``. Must return either a scalar loss tensor or
        a ``(loss, metrics_dict)`` tuple.
    loss_kwargs : Mapping[str, Any] | None, optional
        Extra keyword arguments forwarded to ``loss_fn``. When passed through
        :func:`fit`, tensors and arrays are coerced to ``gn``'s device/dtype
        via :func:`gradnet.utils._to_like_struct` before being stored here.
    optim_cls : type[torch.optim.Optimizer]
        Optimiser class, instantiated in :meth:`configure_optimizers` as
        ``optim_cls(gn.parameters(), **optim_kwargs)``.
    optim_kwargs : dict, optional
        Arguments passed to ``optim_cls`` (e.g., ``{"lr": 1e-2}``). ``None`` is
        treated as ``{}``.
    sched_cls : type | None, optional
        Optional LR scheduler class, instantiated as
        ``sched_cls(optimizer, **sched_kwargs)``. Because optimisation is
        manual, the scheduler is stepped explicitly once per update.
        ``torch.optim.lr_scheduler.ReduceLROnPlateau`` receives the current
        loss as its metric.
    sched_kwargs : dict | None, optional
        Keyword arguments for ``sched_cls``.
    grad_clip_val : float, optional
        Gradient-norm clipping threshold. ``0.0`` disables clipping.
    post_step_renorm : bool, optional
        Master switch for post-step renormalization. When ``True``, the module's
        ``should_renorm_after_step()`` policy is used if available; otherwise
        renormalization runs whenever ``renorm_params()`` is available.
    monitor_key : str, optional
        Metric name under which the primary loss is logged.
    compile_model : bool, optional
        Attempt to wrap the model with :func:`torch.compile` during ``setup``.
        If wrapping fails, a warning is emitted and training continues
        uncompiled.
    """

    def __init__(
        self,
        *,
        gn: nn.Module,
        loss_fn: LossFn,
        loss_kwargs: Mapping[str, Any] | None = None,  # kwargs for the loss function
        optim_cls: type[torch.optim.Optimizer],
        optim_kwargs: dict,
        sched_cls: Optional[type] = None,
        sched_kwargs: Optional[dict] = None,
        grad_clip_val: float = 0.0,
        post_step_renorm: bool = True,
        monitor_key: str = "loss",
        compile_model: bool = False,
    ):
        super().__init__()
        # Reconstruction metadata written into checkpoints (GradNet only).
        self._gradnet_config = gn.export_config() if isinstance(gn, GradNet) else None
        self.gn = gn
        self.loss_fn = loss_fn
        self.loss_kwargs = {} if loss_kwargs is None else loss_kwargs
        self.optim_cls = optim_cls
        self.optim_kwargs = optim_kwargs
        self.sched_cls = sched_cls
        self.sched_kwargs = sched_kwargs or {}
        self.grad_clip_val = float(grad_clip_val)
        self.post_step_renorm = bool(post_step_renorm)
        self.monitor_key = monitor_key
        self.compile_model = bool(compile_model)
        self.automatic_optimization = False  # manual optimization

    def setup(self, stage: Optional[str] = None):
        """Optionally compile the wrapped model."""
        if self.compile_model:
            try:
                self.gn = torch.compile(self.gn)  # type: ignore[attr-defined]
            except Exception as e:
                pl.utilities.rank_zero.rank_zero_warn(f"torch.compile failed; continuing uncompiled. Error: {e}")

    def configure_optimizers(self):
        """Instantiate ``optim_cls`` (and optionally ``sched_cls``) over ``gn``'s parameters.

        Returns
        -------
        torch.optim.Optimizer | dict
            The optimiser alone, or ``{"optimizer": ..., "lr_scheduler": ...}``
            when a scheduler class was supplied.
        """
        optimizer = self.optim_cls(self.gn.parameters(), **(self.optim_kwargs or {}))
        if self.sched_cls is None:
            return optimizer
        scheduler = self.sched_cls(optimizer, **self.sched_kwargs)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization update using ``loss_fn``.

        ``batch`` and ``batch_idx`` are ignored. The loss depends only on the
        model and ``loss_kwargs``.
        """
        # compute loss (+ optional metrics)
        out = self.loss_fn(self.gn, **self.loss_kwargs)
        loss, metrics = (out, {}) if isinstance(out, torch.Tensor) else out

        opt = self.optimizers()
        self.manual_backward(loss)
        if self.grad_clip_val > 0:
            self.clip_gradients(opt, gradient_clip_val=self.grad_clip_val, gradient_clip_algorithm="norm")
        opt.step()
        opt.zero_grad(set_to_none=True)
        # Lightning does not step schedulers under manual optimisation.
        self._step_lr_scheduler(loss)

        # optional: renormalize after each update according to model policy
        if self.post_step_renorm and hasattr(self.gn, "renorm_params"):
            should_renorm = True
            policy = getattr(self.gn, "should_renorm_after_step", None)
            if callable(policy):
                should_renorm = bool(policy())
            if should_renorm:
                self.gn.renorm_params()

        self.log(self.monitor_key, loss, prog_bar=True, on_epoch=True, on_step=False, sync_dist=True, batch_size=1)
        for k, v in (metrics or {}).items():
            v = v if isinstance(v, torch.Tensor) else torch.tensor(float(v), device=loss.device)
            self.log(k, v, prog_bar=False, on_epoch=True, on_step=False, sync_dist=True, batch_size=1)
        return loss.detach()

    def _step_lr_scheduler(self, loss: torch.Tensor) -> None:
        """Advance the configured LR scheduler (if any) by one update."""
        scheduler = self.lr_schedulers()
        if scheduler is None:
            return
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # Plateau schedulers need the monitored value explicitly.
            scheduler.step(loss.detach())
        else:
            scheduler.step()

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Persist GradNet reconstruction metadata when available."""
        if self._gradnet_config is not None:
            checkpoint["gradnet_config"] = self._gradnet_config
class _EpochTQDM(Callback):
    """Minimal epoch-wise TQDM progress bar callback.

    The bar's total is ``trainer.max_epochs``. It advances by one after every
    training epoch and shows the numeric entries of ``trainer.callback_metrics``
    as its postfix.

    Only the global-zero process draws a visible bar, so multi-process runs do
    not print duplicates. The bar is closed both on normal completion and when
    training ends with an exception.
    """

    def __init__(self) -> None:
        super().__init__()
        self.bar = None

    def on_fit_start(self, trainer, *_):
        self._close()  # drop any bar left over from a previous run
        self.bar = tqdm(
            total=trainer.max_epochs,
            desc="Updates",
            dynamic_ncols=True,
            disable=not trainer.is_global_zero,
        )

    def on_train_epoch_end(self, trainer, *_):
        if self.bar is None:
            return
        self.bar.set_postfix({k: v.item() if hasattr(v, "item") else v
                              for k, v in trainer.callback_metrics.items()
                              if isinstance(v, (int, float)) or hasattr(v, "item")})
        self.bar.update(1)

    def on_exception(self, *_):
        # on_fit_end is not reached when training raises; close the bar here too.
        self._close()

    def on_fit_end(self, *_):
        self._close()

    def _close(self) -> None:
        """Close the bar if one is open; safe to call repeatedly."""
        if self.bar is not None:
            self.bar.close()
            self.bar = None
def _make_checkpoint(**kwargs) -> ModelCheckpoint:
    """Build a ``ModelCheckpoint``, requesting epoch-end saving when possible.

    The constructor is first called with ``save_on_train_epoch_end=True`` in
    addition to ``kwargs``. If that raises ``TypeError`` (as older Lightning
    releases without the argument do), the checkpoint is rebuilt from
    ``kwargs`` alone. Any other ``TypeError`` from the first attempt also
    triggers the retry, e.g. a duplicate ``save_on_train_epoch_end`` in
    ``kwargs``.
    """
    try:
        checkpoint = ModelCheckpoint(save_on_train_epoch_end=True, **kwargs)
    except TypeError:
        # Retry without the optional argument for older Lightning versions.
        checkpoint = ModelCheckpoint(**kwargs)
    return checkpoint
def _resolve_logger(
    logger: LightningLoggerBase | list[LightningLoggerBase] | tuple[LightningLoggerBase, ...] | bool | None,
    *,
    verbose: bool,
    log_dir: Optional[str] = None,
) -> LightningLoggerBase | list[LightningLoggerBase] | bool:
    """Return the logger configuration to hand to ``pl.Trainer``.

    Parameters
    ----------
    logger : Logger | list[Logger] | tuple[Logger, ...] | bool | None
        One of the following:

        * A Lightning logger instance, returned unchanged.
        * A non-empty list or tuple made only of logger instances, returned
          as a list (``pl.Trainer`` accepts a collection of loggers).
        * A falsy value (``False``, ``None``, an empty collection), which
          disables logging.
        * Any other truthy value (e.g. ``True``), which requests the default
          logger.
    verbose : bool
        Currently unused by this function.
    log_dir : str | None, optional
        Final log directory for the default logger. It is split into
        ``(save_dir, name)``; when ``None`` it defaults to
        ``"lightning_logs/gradnet"``.

    Returns
    -------
    Logger | list[Logger] | bool
        The logger(s) to use, or ``False`` when logging is disabled.

    Notes
    -----
    For the default logger this attempts to build a ``TensorBoardLogger``. If
    that fails, it warns with ``RuntimeWarning`` and falls back to
    ``CSVLogger``.
    """
    if isinstance(logger, LightningLoggerBase):
        return logger
    if (
        isinstance(logger, (list, tuple))
        and logger
        and all(isinstance(item, LightningLoggerBase) for item in logger)
    ):
        # Pass user-supplied logger collections through instead of silently
        # replacing them with the default logger below.
        return list(logger)
    if not logger:  # covers False, None and empty collections
        return False

    # determine save_dir and name based on provided `log_dir` or defaults
    if log_dir is None:
        save_dir, name = "lightning_logs", "gradnet"
    else:
        # treat `log_dir` as the final directory path (e.g., "./lightning_logs/gradnet")
        # and split into (parent, basename) for Lightning loggers
        norm = os.path.normpath(log_dir)
        save_dir, name = os.path.dirname(norm) or ".", os.path.basename(norm) or "gradnet"

    try:
        from pytorch_lightning.loggers import TensorBoardLogger
        return TensorBoardLogger(save_dir=save_dir, name=name)
    except Exception as exc:  # pragma: no cover - depends on optional dependency
        warnings.warn(
            "TensorBoard logger unavailable; using CSVLogger instead. "
            "Install the `tensorboard` package to re-enable TensorBoard logging. "
            f"Original error: {exc}",
            RuntimeWarning,
        )
        from pytorch_lightning.loggers import CSVLogger
        return CSVLogger(save_dir=save_dir, name=name)
[docs]
def _apply_deterministic(deterministic: Optional[Union[bool, str]]) -> None:
    """Forward ``deterministic`` to ``torch.use_deterministic_algorithms``.

    * ``None`` leaves PyTorch's global setting untouched.
    * ``"warn"`` enables deterministic algorithms in warn-only mode, matching
      Lightning's ``Trainer(deterministic="warn")``.
    * Any other value is interpreted by its truthiness.
    """
    if deterministic is None:
        return
    if deterministic == "warn":
        torch.use_deterministic_algorithms(True, warn_only=True)
    else:
        torch.use_deterministic_algorithms(bool(deterministic))


def _silence_lightning_loggers() -> dict[str, int]:
    """Raise Lightning's loggers to ``ERROR`` and return their previous levels."""
    previous: dict[str, int] = {}
    for name in ("pytorch_lightning", "lightning", "lightning_fabric"):
        lg = logging.getLogger(name)
        previous[name] = lg.level
        lg.setLevel(logging.ERROR)
    return previous


def _checkpoint_callbacks(
    *,
    checkpoint_dir: Optional[str],
    checkpoint_every_n: Optional[int],
    save_last: bool,
) -> tuple[ModelCheckpoint, list[pl.Callback]]:
    """Build the best-loss checkpoint plus an optional periodic one.

    Returns the metric-based checkpoint (whose ``best_model_path`` is reported
    by :func:`fit`) together with every checkpoint callback to register.
    """
    best = _make_checkpoint(
        dirpath=checkpoint_dir,
        filename="gn-{epoch:05d}",
        monitor="loss",
        mode="min",
        save_top_k=1,
        save_last=save_last,
        auto_insert_metric_name=False,
    )
    made = [best]
    if checkpoint_every_n is not None:
        # keep regularly spaced checkpoints alongside the metric-based best
        made.append(
            _make_checkpoint(
                dirpath=checkpoint_dir,
                filename="gn-periodic-{epoch:05d}",
                monitor=None,
                save_top_k=-1,
                save_last=False,
                every_n_epochs=checkpoint_every_n,
                auto_insert_metric_name=False,
            )
        )
    return best, made


def fit(
    *,
    gn: GradNet,
    loss_fn: LossFn,
    loss_kwargs: Mapping[str, Any] | None = None,  # kwargs for the loss function
    num_updates: int,
    optim_cls: type[torch.optim.Optimizer] = torch.optim.Adam,
    optim_kwargs: Optional[dict] = None,
    sched_cls: Optional[type] = None,
    sched_kwargs: Optional[dict] = None,
    # runtime
    precision: Union[str, int] = "32-true",
    accelerator: str = "auto",
    # logging/ckpt
    logger: LightningLoggerBase | bool | None = False,
    log_dir: Optional[str] = None,
    enable_checkpointing: bool = False,
    checkpoint_dir: Optional[str] = None,
    checkpoint_every_n: Optional[int] = None,
    save_last: bool = False,
    callbacks: Optional[list[pl.Callback]] = None,
    max_time: Optional[str] = None,
    # extras
    grad_clip_val: float = 0.0,
    post_step_renorm: bool = True,
    compile_model: bool = False,
    seed: Optional[int] = None,
    deterministic: Optional[Union[bool, str]] = None,
    verbose: bool = True,
):
    """Optimise a :class:`gradnet.GradNet` for a fixed number of updates.

    One trainer epoch corresponds to a single optimiser step, so
    ``num_updates`` equals the number of optimisation steps executed. Each step
    evaluates ``loss_fn(gn, **loss_kwargs)`` and drives manual optimisation via
    :class:`GradNetLightning`.

    Notes
    -----
    When called from a Jupyter notebook, PyTorch Lightning writes progress and
    hardware summaries to ``stderr``. Jupyter renders those lines in red by
    default, but the colour does not indicate an error.

    Arguments are validated before any global state (random seeds,
    deterministic mode, logger levels) is modified.

    Parameters
    ----------
    gn : GradNet
        Network to optimise.
    loss_fn : LossFn
        Callable invoked as ``loss_fn(gn, **loss_kwargs)`` and returning either a
        scalar loss tensor or ``(loss, metrics_dict)``.
    loss_kwargs : Mapping[str, Any] | None, optional
        Extra keyword arguments forwarded to ``loss_fn``. When provided, tensors
        and arrays are coerced to ``gn``'s device/dtype via
        :func:`gradnet.utils._to_like_struct`.
    num_updates : int
        Number of optimisation steps to run (converted with ``int()``; must be
        non-negative).
    optim_cls : type[torch.optim.Optimizer], optional
        Optimiser class constructed as ``optim_cls(gn.parameters(), **optim_kwargs)``.
    optim_kwargs : dict | None, optional
        Keyword arguments for the optimiser. Defaults to ``{"lr": 1e-2}`` when
        ``None``.
    sched_cls : type | None, optional
        Optional learning-rate scheduler applied to the optimiser.
    sched_kwargs : dict | None, optional
        Keyword arguments for ``sched_cls``.
    precision : str | int, optional
        Forwarded to ``pl.Trainer(precision=...)`` (e.g., ``"32-true"``, ``16``).
    accelerator : str, optional
        Passed to ``pl.Trainer(accelerator=...)`` (``"auto"``, ``"cpu"``, ``"gpu"``, etc.).
    logger : LightningLoggerBase | bool | None, optional
        Logger configuration forwarded to ``pl.Trainer``. Use ``True`` for the
        default logger (falls back to ``CSVLogger`` when TensorBoard is
        unavailable), ``False`` to disable logging, or supply a Lightning logger
        instance.
    log_dir : str | None, optional
        When ``logger`` is ``True`` (request default logging), use this directory
        for logs. Treats the value as the final directory path (e.g.,
        ``"./lightning_logs/gradnet"``). When ``None``, defaults to
        ``"./lightning_logs/gradnet"``. Ignored when a custom ``logger`` instance
        is provided or logging is disabled.
    enable_checkpointing : bool, optional
        Enable the default ``ModelCheckpoint`` callback. When ``True`` it
        monitors ``loss`` in ``min`` mode to keep the best checkpoint. It also
        stores the last checkpoint when ``save_last=True`` and adds periodic
        checkpoints when ``checkpoint_every_n`` is set.
    checkpoint_dir : str | None, optional
        Directory used by ``ModelCheckpoint`` when checkpointing is enabled.
    checkpoint_every_n : int | None, optional
        Save an additional checkpoint every ``checkpoint_every_n`` epochs
        (updates). Provide an integer greater than or equal to 1 to enable
        periodic saves, or ``None`` to disable them.
    save_last : bool, optional
        Whether to always save the final checkpoint in addition to the best.
    callbacks : list[pl.Callback] | None, optional
        Additional Lightning callbacks to register.
    max_time : str | None, optional
        Training time limit forwarded to ``pl.Trainer(max_time=...)``.
    grad_clip_val : float, optional
        Gradient-norm clipping threshold applied before optimiser steps.
    post_step_renorm : bool, optional
        Master switch for post-step renormalization. When enabled, if ``gn``
        exposes ``should_renorm_after_step()``, that policy decides whether
        ``gn.renorm_params()`` runs after each optimizer step.
    compile_model : bool, optional
        Attempt to wrap ``gn`` with :func:`torch.compile` during setup.
    seed : int | None, optional
        When provided, seeds PyTorch Lightning via ``pl.seed_everything``.
    deterministic : bool | str | None, optional
        If not ``None``, configures ``torch.use_deterministic_algorithms``:

        * ``"warn"`` enables deterministic algorithms with ``warn_only=True``;
        * any other value is applied as ``bool(deterministic)``.

        This is a process-wide PyTorch setting and is not reset afterwards.
    verbose : bool, optional
        Show progress via :class:`tqdm.auto.tqdm` and the model summary. When
        ``False``, Lightning's loggers are raised to ``ERROR`` for the duration
        of the call. Their previous levels are restored afterwards, even if
        training raises.

    Returns
    -------
    tuple[pl.Trainer, str | None]
        The configured trainer and the best checkpoint path (``None`` when
        checkpointing is disabled).

    Raises
    ------
    TypeError
        If ``loss_kwargs`` is neither ``None`` nor a mapping.
    ValueError
        If ``num_updates`` is negative, or if checkpointing is enabled and
        ``checkpoint_every_n`` is not ``None`` or an integer >= 1.

    Examples
    --------
    >>> from pytorch_lightning.loggers import TensorBoardLogger
    >>> logger = TensorBoardLogger(save_dir="logs", name="demo")
    >>> trainer, best_ckpt = fit(
    ...     gn=model,
    ...     loss_fn=loss,
    ...     num_updates=100,
    ...     logger=logger,
    ... )

    .. seealso::
        PyTorch Optimizers (``torch.optim``), PyTorch LR Schedulers
        (``torch.optim.lr_scheduler``), and PyTorch Lightning's Trainer and
        Callbacks documentation for accepted values of ``precision`` and
        ``accelerator`` and for available callback types.
    """
    # --- validate inputs before touching any global state -------------------
    max_epochs = int(num_updates)
    if max_epochs < 0:
        # Lightning would read max_epochs=-1 as "train forever".
        raise ValueError("`num_updates` must be a non-negative integer.")
    # params must be kwargs if provided
    if loss_kwargs is not None and not isinstance(loss_kwargs, Mapping):
        raise TypeError("`loss_kwargs` must be a Mapping of keyword arguments (or None).")
    if enable_checkpointing and checkpoint_every_n is not None:
        if (
            isinstance(checkpoint_every_n, bool)
            or not isinstance(checkpoint_every_n, int)
            or checkpoint_every_n < 1
        ):
            raise ValueError("`checkpoint_every_n` must be an integer >= 1 or None.")

    # Silence PL info logs if verbose is False; restored in ``finally`` below.
    prev_levels = {} if verbose else _silence_lightning_loggers()
    try:
        if seed is not None:
            pl.seed_everything(seed, workers=True)
        _apply_deterministic(deterministic)

        module = GradNetLightning(
            gn=gn,
            loss_fn=loss_fn,
            loss_kwargs=_to_like_struct(loss_kwargs, gn) if isinstance(loss_kwargs, Mapping) else {},
            optim_cls=optim_cls,
            optim_kwargs=optim_kwargs or {"lr": 1e-2},
            sched_cls=sched_cls,
            sched_kwargs=sched_kwargs,
            grad_clip_val=grad_clip_val,
            post_step_renorm=post_step_renorm,
            monitor_key="loss",
            compile_model=compile_model,
        )

        all_callbacks = list(callbacks or [])
        ckpt = None
        if enable_checkpointing:
            ckpt, ckpt_callbacks = _checkpoint_callbacks(
                checkpoint_dir=checkpoint_dir,
                checkpoint_every_n=checkpoint_every_n,
                save_last=save_last,
            )
            all_callbacks.extend(ckpt_callbacks)
        # progress bar only when verbose
        if verbose:
            all_callbacks.append(_EpochTQDM())

        trainer = pl.Trainer(
            max_epochs=max_epochs,
            accelerator=accelerator,
            precision=precision,
            logger=_resolve_logger(logger, verbose=verbose, log_dir=log_dir),
            enable_checkpointing=enable_checkpointing,
            callbacks=all_callbacks,
            log_every_n_steps=1,
            max_time=max_time,
            enable_progress_bar=False,
            enable_model_summary=bool(verbose),
        )
        loader = DataLoader(_OneItem(), batch_size=1, shuffle=False, num_workers=0)
        trainer.fit(module, train_dataloaders=loader)
    finally:
        # Restore previous PL logger levels even if training raised.
        for name, level in prev_levels.items():
            logging.getLogger(name).setLevel(level)

    return trainer, (ckpt.best_model_path if ckpt is not None else None)