Source code for gradnet.trainer

"""Utilities to train a :class:`gradnet.GradNet` with PyTorch Lightning.

This module provides a thin Lightning wrapper and a convenience function
(:func:`fit`) to optimize a ``GradNet`` for a fixed number of
updates.
"""
from __future__ import annotations
from typing import Dict, Optional, Tuple, Union, Mapping, Any, Protocol
import logging
import os
import warnings
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

os.environ.setdefault("LIGHTING_USE_RICH", "0")

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.logger import Logger as LightningLoggerBase
from pytorch_lightning.callbacks import Callback

try:  # prefer notebook progress bar when the stack supports it
    from tqdm import TqdmWarning  # type: ignore[attr-defined]
except (ImportError, AttributeError):
    TqdmWarning = Warning  # fallback when tqdm lacks TqdmWarning

try:  # PL >= 1.6-ish
    from pytorch_lightning.utilities.warnings import PossibleUserWarning
except Exception:  # Fallback for older PL where it's just a UserWarning
    PossibleUserWarning = UserWarning

# Lightning "possible user" hints that are expected noise for the one-item
# loader driven by ``fit``; each pattern gets its own "ignore" filter.
_PL_WARNINGS_TO_IGNORE = (
    r"The 'train_dataloader' does not have many workers.*",
    r"GPU available but not used.*",
)
for warning_pattern in _PL_WARNINGS_TO_IGNORE:
    warnings.filterwarnings(
        "ignore",
        message=warning_pattern,
        category=PossibleUserWarning,
    )

with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=TqdmWarning)
    try:
        from tqdm.auto import tqdm
    except Exception:
        from tqdm import tqdm  # noqa: F401  # CLI fallback without warnings

from .utils import _to_like_struct
from .gradnet import GradNet


class LossFn(Protocol):
    """Structural type for the objective passed to :func:`fit`.

    Any callable matching this signature qualifies: it receives the
    :class:`gradnet.GradNet` being optimised as its first positional argument
    plus any number of keyword arguments. The result is either

    * a scalar loss tensor, or
    * a pair ``(loss, metrics)`` where ``metrics`` maps metric names to
      floats, ints or tensors.
    """

    def __call__(
        self,
        model: GradNet,
        **loss_kwargs: Any,
    ) -> Union[
        torch.Tensor,
        Tuple[
            torch.Tensor,
            Dict[str, Union[float, int, torch.Tensor]],
        ],
    ]:
        ...
class _OneItem(Dataset): """A trivial dataset that always yields a single empty batch. Used to drive the Lightning training loop with one update per epoch without relying on external data. """ def __len__(self): return 1 def __getitem__(self, idx): return {}
class GradNetLightning(pl.LightningModule):
    """LightningModule wrapper around a ``GradNet`` and a user loss.

    This module performs manual optimisation: it evaluates ``loss_fn`` to
    obtain a scalar loss (and optional metrics), applies gradient clipping (if
    configured), steps the optimizer, optionally renormalises the model
    parameters, and logs metrics under ``monitor_key``.

    Lightning does not step learning-rate schedulers under manual
    optimisation. This module therefore steps any scheduler built from
    ``sched_cls`` itself, once at the end of every training epoch.

    Parameters
    ----------
    gn : torch.nn.Module
        Model to optimise. Typically a :class:`gradnet.GradNet`; any
        ``nn.Module`` is accepted. If the module exposes
        ``should_renorm_after_step()``, it is queried to decide whether
        per-step renormalization should run (subject to
        ``post_step_renorm=True``). If the module exposes ``renorm_params()``,
        that method is invoked when renormalization is enabled for the
        current step.
    loss_fn : LossFn
        Callable evaluated on every optimisation step as
        ``loss_fn(gn, **loss_kwargs)``. Must return either a scalar loss
        tensor or a ``(loss, metrics_dict)`` tuple.
    loss_kwargs : Mapping[str, Any] | None, optional
        Extra keyword arguments forwarded to ``loss_fn``. When passed through
        :func:`fit`, tensors and arrays are coerced to ``gn``'s device/dtype
        via :func:`gradnet.utils._to_like_struct` before being stored here.
        At the start of training, any tensors they contain are moved to the
        device Lightning placed this module on.
    optim_cls : type[torch.optim.Optimizer]
        Optimiser class instantiated over ``gn.parameters()``.
    optim_kwargs : dict, optional
        Arguments passed to ``optim_cls`` (e.g., ``{"lr": 1e-2}``).
    sched_cls : type | None, optional
        Optional LR scheduler class applied on top of the optimiser. It is
        stepped once per training epoch. A ``ReduceLROnPlateau`` scheduler
        receives the epoch value logged under ``monitor_key``; its step is
        skipped when that value is unavailable.
    sched_kwargs : dict | None, optional
        Keyword arguments for ``sched_cls``.
    grad_clip_val : float, optional
        Gradient-norm clipping threshold. ``0.0`` disables clipping.
    post_step_renorm : bool, optional
        Master switch for post-step renormalization. When ``True``, the
        module's ``should_renorm_after_step()`` policy is used if available;
        otherwise renormalization runs whenever ``renorm_params()`` is
        available.
    monitor_key : str, optional
        Metric name under which the primary loss is logged.
    compile_model : bool, optional
        Attempt to wrap the model with :func:`torch.compile` during ``setup``;
        fall back with a warning when compilation fails.
    """

    def __init__(
        self,
        *,
        gn: nn.Module,
        loss_fn: LossFn,
        loss_kwargs: Mapping[str, Any] | None = None,  # kwargs for the loss function
        optim_cls: type[torch.optim.Optimizer],
        optim_kwargs: dict,
        sched_cls: Optional[type] = None,
        sched_kwargs: Optional[dict] = None,
        grad_clip_val: float = 0.0,
        post_step_renorm: bool = True,
        monitor_key: str = "loss",
        compile_model: bool = False,
    ):
        super().__init__()
        # Reconstruction metadata stored in checkpoints by on_save_checkpoint.
        self._gradnet_config = gn.export_config() if isinstance(gn, GradNet) else None
        self.gn = gn
        self.loss_fn = loss_fn
        self.loss_kwargs = {} if loss_kwargs is None else loss_kwargs
        self.optim_cls = optim_cls
        self.optim_kwargs = optim_kwargs
        self.sched_cls = sched_cls
        self.sched_kwargs = sched_kwargs or {}
        self.grad_clip_val = float(grad_clip_val)
        self.post_step_renorm = bool(post_step_renorm)
        self.monitor_key = monitor_key
        self.compile_model = bool(compile_model)
        self.automatic_optimization = False  # manual optimization

    def setup(self, stage: Optional[str] = None):
        """Optionally compile the wrapped model."""
        if self.compile_model:
            try:
                self.gn = torch.compile(self.gn)  # type: ignore[attr-defined]
            except Exception as e:
                pl.utilities.rank_zero.rank_zero_warn(f"torch.compile failed; continuing uncompiled. Error: {e}")

    def on_train_start(self) -> None:
        """Move tensors held in ``loss_kwargs`` onto this module's device.

        ``loss_kwargs`` may have been prepared for the device ``gn`` had when
        it was passed in (e.g. :func:`fit` coerces them before calling
        ``trainer.fit``). Lightning may then move the module to a different
        accelerator, and the loss would otherwise mix devices. Non-tensor
        values are left untouched.
        """
        from pytorch_lightning.utilities import move_data_to_device

        self.loss_kwargs = move_data_to_device(self.loss_kwargs, self.device)

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization update using ``loss_fn``."""
        # compute loss (+ optional metrics)
        out = self.loss_fn(self.gn, **self.loss_kwargs)
        loss, metrics = (out, {}) if isinstance(out, torch.Tensor) else out

        opt = self.optimizers()
        self.manual_backward(loss)
        if self.grad_clip_val > 0:
            self.clip_gradients(opt, gradient_clip_val=self.grad_clip_val, gradient_clip_algorithm="norm")
        opt.step()
        opt.zero_grad(set_to_none=True)

        # optional: renormalize after each update according to model policy
        if self.post_step_renorm and hasattr(self.gn, "renorm_params"):
            should_renorm = True
            policy = getattr(self.gn, "should_renorm_after_step", None)
            if callable(policy):
                should_renorm = bool(policy())
            if should_renorm:
                self.gn.renorm_params()

        self.log(self.monitor_key, loss, prog_bar=True, on_epoch=True, on_step=False, sync_dist=True, batch_size=1)
        for k, v in metrics.items():
            v = v if isinstance(v, torch.Tensor) else torch.tensor(float(v), device=loss.device)
            self.log(k, v, prog_bar=False, on_epoch=True, on_step=False, sync_dist=True, batch_size=1)
        return loss.detach()

    def on_train_epoch_end(self) -> None:
        """Step LR schedulers once per epoch.

        Lightning skips scheduler updates when ``automatic_optimization`` is
        ``False``, so without this hook ``sched_cls`` would have no effect.
        """
        schedulers = self.lr_schedulers()
        if schedulers is None:
            return
        # lr_schedulers() returns a single scheduler or a list of them
        if not isinstance(schedulers, (list, tuple)):
            schedulers = [schedulers]
        for sched in schedulers:
            if isinstance(sched, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # plateau schedulers need the monitored value; skip if it was not logged
                metric = self.trainer.callback_metrics.get(self.monitor_key)
                if metric is not None:
                    sched.step(metric)
            else:
                sched.step()

    def configure_optimizers(self):
        """Build the optimizer (and optional scheduler) config for Lightning."""
        opt = self.optim_cls(self.gn.parameters(), **self.optim_kwargs)
        if self.sched_cls is None:
            return opt
        sched = self.sched_cls(opt, **self.sched_kwargs)
        # "interval" is informational here (e.g. for LearningRateMonitor);
        # the actual stepping happens in on_train_epoch_end.
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": sched, "interval": "epoch", "frequency": 1, "name": "lr"},
        }

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Persist GradNet reconstruction metadata when available."""
        if self._gradnet_config is not None:
            checkpoint["gradnet_config"] = self._gradnet_config
class _EpochTQDM(Callback):
    """Minimal epoch-wise TQDM progress bar callback.

    Shows total updates, and displays numeric metrics collected in
    ``trainer.callback_metrics``. Only the global-zero process draws a bar, so
    multi-process runs do not print one bar per rank. The bar is also closed
    when fitting is interrupted by an exception.
    """

    def __init__(self) -> None:
        super().__init__()
        self.bar = None

    def on_fit_start(self, trainer, *_):
        if trainer.is_global_zero:
            self.bar = tqdm(total=trainer.max_epochs, desc="Updates", dynamic_ncols=True)

    def on_train_epoch_end(self, trainer, *_):
        if self.bar is None:
            return
        self.bar.set_postfix(
            {
                k: v.item() if hasattr(v, "item") else v
                for k, v in trainer.callback_metrics.items()
                if isinstance(v, (int, float)) or hasattr(v, "item")
            }
        )
        self.bar.update(1)

    def on_fit_end(self, *_):
        self._close()

    def on_exception(self, *_):
        # on_fit_end is not reached when fit raises (incl. KeyboardInterrupt)
        self._close()

    def _close(self) -> None:
        """Close the bar once; safe to call repeatedly or when no bar exists."""
        if self.bar is not None:
            self.bar.close()
            self.bar = None


def _make_checkpoint(**kwargs) -> ModelCheckpoint:
    """Create a ModelCheckpoint compatible across Lightning versions."""
    try:
        return ModelCheckpoint(save_on_train_epoch_end=True, **kwargs)
    except TypeError:  # older PL versions may not support this kwarg
        return ModelCheckpoint(**kwargs)


def _resolve_logger(
    logger: LightningLoggerBase | list[LightningLoggerBase] | tuple[LightningLoggerBase, ...] | bool | None,
    *,
    verbose: bool,
    log_dir: Optional[str] = None,
) -> LightningLoggerBase | list[LightningLoggerBase] | bool:
    """Return a Lightning logger (or list of loggers) or ``False``.

    * A logger instance is returned unchanged.
    * A non-empty list/tuple made only of logger instances is returned as a
      list (``pl.Trainer`` accepts several loggers).
    * ``False``/``None``/empty sequences disable logging.
    * Any other truthy value (typically ``True``) builds a
      ``TensorBoardLogger``, falling back to ``CSVLogger`` when TensorBoard is
      unavailable.

    ``verbose`` is accepted for interface compatibility and is not used.
    """
    if isinstance(logger, LightningLoggerBase):
        return logger
    if not logger:  # covers False, None and empty sequences
        return False
    if isinstance(logger, (list, tuple)) and all(isinstance(item, LightningLoggerBase) for item in logger):
        return list(logger)

    # determine save_dir and name based on provided `log_dir` or defaults
    if log_dir is None:
        save_dir, name = "lightning_logs", "gradnet"
    else:
        # treat `log_dir` as the final directory path (e.g., "./lightning_logs/gradnet")
        # and split into (parent, basename) for Lightning loggers
        norm = os.path.normpath(log_dir)
        save_dir, name = os.path.dirname(norm) or ".", os.path.basename(norm) or "gradnet"

    try:
        from pytorch_lightning.loggers import TensorBoardLogger

        return TensorBoardLogger(save_dir=save_dir, name=name)
    except Exception as exc:  # pragma: no cover - depends on optional dependency
        warnings.warn(
            "TensorBoard logger unavailable; using CSVLogger instead. "
            "Install the `tensorboard` package to re-enable TensorBoard logging. "
            f"Original error: {exc}",
            RuntimeWarning,
        )
        from pytorch_lightning.loggers import CSVLogger

        return CSVLogger(save_dir=save_dir, name=name)
def fit(
    *,
    gn: GradNet,
    loss_fn: LossFn,
    loss_kwargs: Mapping[str, Any] | None = None,  # kwargs for the loss function
    num_updates: int,
    optim_cls: type[torch.optim.Optimizer] = torch.optim.Adam,
    optim_kwargs: Optional[dict] = None,
    sched_cls: Optional[type] = None,
    sched_kwargs: Optional[dict] = None,
    # runtime
    precision: Union[str, int] = "32-true",
    accelerator: str = "auto",
    # logging/ckpt
    logger: LightningLoggerBase | bool | None = False,
    log_dir: Optional[str] = None,
    enable_checkpointing: bool = False,
    checkpoint_dir: Optional[str] = None,
    checkpoint_every_n: Optional[int] = None,
    save_last: bool = False,
    callbacks: Optional[list[pl.Callback]] = None,
    max_time: Optional[str] = None,
    # extras
    grad_clip_val: float = 0.0,
    post_step_renorm: bool = True,
    compile_model: bool = False,
    seed: Optional[int] = None,
    deterministic: Optional[Union[bool, str]] = None,
    verbose: bool = True,
):
    """Optimise a :class:`gradnet.GradNet` for a fixed number of updates.

    One trainer epoch corresponds to a single optimiser step, so
    ``num_updates`` equals the number of optimisation steps executed. Each
    step evaluates ``loss_fn(gn, **loss_kwargs)`` and drives manual
    optimisation via :class:`GradNetLightning`.

    Notes
    -----
    When called from a Jupyter notebook, PyTorch Lightning writes progress and
    hardware summaries to ``stderr``. Jupyter renders those lines in red by
    default, but the colour does not indicate an error.

    Parameters
    ----------
    gn : GradNet
        Network to optimise.
    loss_fn : LossFn
        Callable invoked as ``loss_fn(gn, **loss_kwargs)`` and returning
        either a scalar loss tensor or ``(loss, metrics_dict)``.
    loss_kwargs : Mapping[str, Any] | None, optional
        Extra keyword arguments forwarded to ``loss_fn``. When provided,
        tensors and arrays are coerced to ``gn``'s device/dtype via
        :func:`gradnet.utils._to_like_struct`.
    num_updates : int
        Number of optimisation steps to run.
    optim_cls : type[torch.optim.Optimizer], optional
        Optimiser class constructed as
        ``optim_cls(gn.parameters(), **optim_kwargs)``.
    optim_kwargs : dict | None, optional
        Keyword arguments for the optimiser. Defaults to ``{"lr": 1e-2}`` when
        ``None``. An explicit empty dict keeps the optimiser's own defaults.
    sched_cls : type | None, optional
        Optional learning-rate scheduler applied to the optimiser.
    sched_kwargs : dict | None, optional
        Keyword arguments for ``sched_cls``.
    precision : str | int, optional
        Forwarded to ``pl.Trainer(precision=...)`` (e.g., ``"32-true"``,
        ``16``).
    accelerator : str, optional
        Passed to ``pl.Trainer(accelerator=...)`` (``"auto"``, ``"cpu"``,
        ``"gpu"``, etc.).
    logger : LightningLoggerBase | bool | None, optional
        Logger configuration forwarded to ``pl.Trainer``. Use ``True`` for the
        default logger (falls back to ``CSVLogger`` when TensorBoard is
        unavailable), ``False`` to disable logging, or supply a Lightning
        logger instance.
    log_dir : str | None, optional
        When ``logger`` is ``True`` (request default logging), use this
        directory for logs. Treats the value as the final directory path
        (e.g., ``"./lightning_logs/gradnet"``). When ``None``, defaults to
        ``"./lightning_logs/gradnet"``. Ignored when a custom ``logger``
        instance is provided or logging is disabled.
    enable_checkpointing : bool, optional
        Enable the default ``ModelCheckpoint`` callback. When ``True`` it
        monitors ``loss`` in ``min`` mode to keep the best checkpoint, can
        store the last checkpoint when ``save_last=True``, and optionally adds
        periodic checkpoints via ``checkpoint_every_n``.
    checkpoint_dir : str | None, optional
        Directory used by ``ModelCheckpoint`` when checkpointing is enabled.
    checkpoint_every_n : int | None, optional
        Save an additional checkpoint every ``checkpoint_every_n`` epochs
        (updates). Provide an integer greater than or equal to 1 to enable
        periodic saves, or ``None`` to disable them. Only validated and used
        when ``enable_checkpointing`` is ``True``.
    save_last : bool, optional
        Whether to always save the final checkpoint in addition to the best.
    callbacks : list[pl.Callback] | None, optional
        Additional Lightning callbacks to register.
    max_time : str | None, optional
        Training time limit forwarded to ``pl.Trainer(max_time=...)``.
    grad_clip_val : float, optional
        Gradient-norm clipping threshold applied before optimiser steps.
    post_step_renorm : bool, optional
        Master switch for post-step renormalization. When enabled, if ``gn``
        exposes ``should_renorm_after_step()``, that policy decides whether
        ``gn.renorm_params()`` runs after each optimizer step.
    compile_model : bool, optional
        Attempt to wrap ``gn`` with :func:`torch.compile` during setup.
    seed : int | None, optional
        When provided, seeds PyTorch Lightning via ``pl.seed_everything``.
    deterministic : bool | str | None, optional
        If not ``None``, forwarded to
        :func:`torch.use_deterministic_algorithms`. ``"warn"`` enables
        deterministic algorithms in warn-only mode (``warn_only=True``); any
        other value is converted with ``bool``. This changes process-wide
        PyTorch state and is not reset after training.
    verbose : bool, optional
        Show progress via :class:`tqdm.auto.tqdm` and Lightning's model
        summary. When ``False``, the ``pytorch_lightning``, ``lightning`` and
        ``lightning_fabric`` loggers are raised to ``ERROR`` for the duration
        of the call. Their previous levels are restored even if training
        raises.

    Returns
    -------
    tuple[pl.Trainer, str | None]
        The configured trainer and the best checkpoint path (``None`` when
        checkpointing is disabled).

    Raises
    ------
    TypeError
        If ``loss_kwargs`` is neither ``None`` nor a mapping.
    ValueError
        If checkpointing is enabled and ``checkpoint_every_n`` is neither
        ``None`` nor an integer ``>= 1``.

    Examples
    --------
    >>> from pytorch_lightning.loggers import TensorBoardLogger
    >>> logger = TensorBoardLogger(save_dir="logs", name="demo")
    >>> trainer, best_ckpt = fit(
    ...     gn=model,
    ...     loss_fn=loss,
    ...     num_updates=100,
    ...     logger=logger,
    ... )

    .. seealso::
        PyTorch Optimizers (``torch.optim``), PyTorch LR Schedulers
        (``torch.optim.lr_scheduler``), and PyTorch Lightning's Trainer and
        Callbacks documentation for accepted values of ``precision`` and
        ``accelerator`` and for available callback types.
    """
    # Validate arguments before touching any global state (RNG, torch flags, log levels).
    if loss_kwargs is not None and not isinstance(loss_kwargs, Mapping):
        raise TypeError("`loss_kwargs` must be a Mapping of keyword arguments (or None).")
    if enable_checkpointing and checkpoint_every_n is not None:
        if not isinstance(checkpoint_every_n, int) or checkpoint_every_n < 1:
            raise ValueError("`checkpoint_every_n` must be an integer >= 1 or None.")

    # Silence PL info logs if verbose is False; levels are restored in `finally`.
    prev_levels: dict[str, int] = {}
    if not verbose:
        for name in ("pytorch_lightning", "lightning", "lightning_fabric"):
            lg = logging.getLogger(name)
            prev_levels[name] = lg.level
            lg.setLevel(logging.ERROR)

    try:
        if seed is not None:
            pl.seed_everything(seed, workers=True)
        if deterministic is not None:
            if deterministic == "warn":
                # mirror Lightning's `deterministic="warn"`: warn instead of raising
                torch.use_deterministic_algorithms(True, warn_only=True)
            else:
                torch.use_deterministic_algorithms(bool(deterministic))

        loss_kwargs = _to_like_struct(loss_kwargs, gn) if isinstance(loss_kwargs, Mapping) else {}

        module = GradNetLightning(
            gn=gn,
            loss_fn=loss_fn,
            loss_kwargs=loss_kwargs,
            optim_cls=optim_cls,
            # only None selects the documented default; {} keeps optimiser defaults
            optim_kwargs={"lr": 1e-2} if optim_kwargs is None else optim_kwargs,
            sched_cls=sched_cls,
            sched_kwargs=sched_kwargs,
            grad_clip_val=grad_clip_val,
            post_step_renorm=post_step_renorm,
            monitor_key="loss",
            compile_model=compile_model,
        )

        cb = list(callbacks or [])
        ckpt = None
        if enable_checkpointing:
            ckpt = _make_checkpoint(
                dirpath=checkpoint_dir,
                filename="gn-{epoch:05d}",
                monitor="loss",
                mode="min",
                save_top_k=1,
                save_last=save_last,
                auto_insert_metric_name=False,
            )
            cb.append(ckpt)
            if checkpoint_every_n is not None:
                # keep regularly spaced checkpoints alongside the metric-based best
                periodic = _make_checkpoint(
                    dirpath=checkpoint_dir,
                    filename="gn-periodic-{epoch:05d}",
                    monitor=None,
                    save_top_k=-1,
                    save_last=False,
                    every_n_epochs=checkpoint_every_n,
                    auto_insert_metric_name=False,
                )
                cb.append(periodic)

        # progress bar only when verbose
        if verbose:
            cb.append(_EpochTQDM())

        trainer_logger = _resolve_logger(logger, verbose=verbose, log_dir=log_dir)
        trainer = pl.Trainer(
            max_epochs=int(num_updates),
            accelerator=accelerator,
            precision=precision,
            logger=trainer_logger,
            enable_checkpointing=enable_checkpointing,
            callbacks=cb,
            log_every_n_steps=1,
            max_time=max_time,
            enable_progress_bar=False,
            enable_model_summary=bool(verbose),
        )

        loader = DataLoader(_OneItem(), batch_size=1, shuffle=False, num_workers=0)
        trainer.fit(module, train_dataloaders=loader)
    finally:
        # Restore previous PL logger levels if we changed them, even on failure
        for name, lvl in prev_levels.items():
            logging.getLogger(name).setLevel(lvl)

    return trainer, (ckpt.best_model_path if ckpt is not None else None)