Source code for gradnet.ode

"""ODE integration utilities with optional adjoint and event support.

This module provides a thin wrapper around :mod:`torchdiffeq` to integrate
ordinary differential equations whose dynamics may depend on a static
adjacency matrix (e.g., produced by a :class:`gradnet.GradNet`). It offers:

- A single entry point :func:`integrate_ode` for forward solves, with optional
  adjoint sensitivity via :func:`torchdiffeq.odeint_adjoint`.
- Event-based termination via :func:`torchdiffeq.odeint_event` to stop an
  integration when a user-defined scalar function crosses zero.
- Careful device/dtype alignment for initial conditions, time grids, and
  keyword arguments.

The public API mirrors the style used in :mod:`gradnet.trainer` so the
docstrings render well when building documentation with Sphinx.
"""

from __future__ import annotations
from typing import Callable, Any, Optional, Union, Mapping, Sequence
import numpy as np
import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint, odeint_event


class _VectorField(nn.Module):
    """Internal wrapper that exposes vector-field parameters to adjoint solves."""

    def __init__(
        self,
        f: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
        A: torch.Tensor,
        kwargs: Mapping[str, Any] | None,
        gn_module: Optional[nn.Module] = None,
        params_modules: Optional[dict[str, nn.Module]] = None,
    ):
        super().__init__()
        self._f = f
        self.A = A
        if isinstance(gn_module, nn.Module):
            self.gn = gn_module  # register so adjoint sees gn.parameters()
        if params_modules:
            for k, m in params_modules.items():
                self.add_module(f"param_mod_{k}", m)
        self._kwargs = {} if kwargs is None else kwargs

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the user vector field as ``f(t, x, A, **kwargs)``."""
        return self._f(t, x, self.A, **self._kwargs)


def _real_dtype(dtype: torch.dtype) -> torch.dtype:
    """Return the corresponding real dtype for ``dtype``."""
    return torch.zeros((), dtype=dtype).real.dtype


def _promoted_state_dtype(
    A: torch.Tensor, x0: Union[torch.Tensor, float, int]
) -> torch.dtype:
    """Choose a solve dtype that preserves complex-valued states."""
    x0_dtype = x0.dtype if isinstance(x0, torch.Tensor) else torch.as_tensor(x0).dtype
    return torch.promote_types(A.dtype, x0_dtype)


def _to_device_struct(obj: Any, device: torch.device) -> Any:
    """Recursively move tensors/NumPy arrays in ``obj`` to ``device``."""
    if isinstance(obj, torch.Tensor):
        return obj.to(device=device)
    if isinstance(obj, np.ndarray):  # also catches np.matrix
        return torch.as_tensor(obj).to(device=device)
    if isinstance(obj, np.generic):  # NumPy scalar (e.g., np.float32(3.0))
        return torch.as_tensor(obj, device=device)
    if isinstance(obj, Mapping):
        return obj.__class__({k: _to_device_struct(v, device) for k, v in obj.items()})
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # namedtuple
        return obj.__class__(*[_to_device_struct(v, device) for v in obj])
    if isinstance(obj, (list, tuple)):
        typ = obj.__class__
        return typ(_to_device_struct(v, device) for v in obj)
    return obj  # nn.Module or anything else stays as-is


[docs] def integrate_ode( gn: Union[Callable[[], torch.Tensor], torch.Tensor], f: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], x0: Union[torch.Tensor, float, int], tt: torch.Tensor, *, f_kwargs: Mapping[str, Any] | None = None, # kwargs for f / event_fn method: str = "dopri5", rtol: float = 1e-4, atol: float = 1e-4, solver_options: Optional[dict] = None, adjoint: bool = False, adjoint_options: Optional[dict] = None, # e.g., {'norm': 'seminorm'} adjoint_params: Optional[Sequence[torch.Tensor]] = None, # optional override event_fn: Optional[ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] ] = None, track_gradients: bool = True, ): """Integrate an ODE ``dx/dt = f(t, x, A, **f_kwargs)`` using torchdiffeq. This is a convenience wrapper around ``torchdiffeq.odeint`` with optional adjoint sensitivities and event-based termination. The vector field is called as ``f(t, x, A, **f_kwargs)`` where ``A`` is the network adjacency matrix represented as a (potentially sparse) torch.Tensor. Args: gn (Callable[[], torch.Tensor] | torch.Tensor): Tensor ``A`` or a zero-arg callable returning ``A``. If an ``nn.Module`` is provided, its parameters are included in the default adjoint parameter set. f (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]): Vector field returning ``dx/dt`` with the same shape as ``x``. x0 (torch.Tensor | float | int): Initial state (scalars are promoted). tt (torch.Tensor): 1D time grid (monotone; may decrease for reverse-time event searches). f_kwargs (Mapping[str, Any] | None, optional): Keyword arguments passed to ``f`` (and ``event_fn`` if provided). Tensors/NumPy arrays are moved to the adjacency device without forcing them to ``A.dtype``. method (str, optional): Integrator, e.g., adaptive stepsize ``"dopri5"`` (default), or fixed-step ``"rk4"``, see more options in `torchdiffeq documentation <https://github.com/rtqichen/torchdiffeq>`_). rtol (float, optional): Relative tolerance. atol (float, optional): Absolute tolerance. solver_options (dict | None, optional): Additional solver options. adjoint (bool, optional): If ``True``, use the adjoint method. adjoint_options (dict | None, optional): Options for adjoint solve (e.g., ``{"norm": "seminorm"}``). adjoint_params (Sequence[torch.Tensor] | None, optional): Explicit list of parameters for adjoint gradients. Defaults to parameters discovered in the wrapped vector field (including modules in ``f_kwargs``). event_fn (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] | None, optional): Optional scalar function ``g(t, x, A, **f_kwargs)``; the integration stops on zero-crossing. track_gradients (bool, optional): Enable autograd during the solve. Returns: tuple: ``(tt_out, x_out)`` where ``x_out`` has shape ``(len(tt_out), *x0.shape)``. If an event is used, ``(tt_out, x_out, t_event, x_event)`` ``tt_out`` and ``x_out`` are truncated at the detected event time. ``t_event`` and ``x_event`` are the differentiable event time and state. Raises: TypeError: If ``f_kwargs`` is not a mapping (and not ``None``). Examples: Basic integration without events:: import torch from gradnet.ode import integrate_ode A = torch.tensor([[0., 1.], [-1., 0.]]) def vf(t, x, A): return A @ x x0 = torch.tensor([1., 0.]) tt = torch.linspace(0, 1, steps=11) t_out, x_out = integrate_ode(A, vf, x0, tt) Event-driven integration until ``x[0]`` crosses zero:: def event(t, x, A): return x[0] # stop when it crosses 0 tt_partial, x_partial, t_event, x_event = integrate_ode( A, vf, x0, tt, event_fn=event ) (*) See also the `torchdiffeq documentation <https://github.com/rtqichen/torchdiffeq>`_ for supported methods and options. """ # Build adjacency once and keep a module handle for adjoint parameter discovery. if callable(gn): gn_module = gn if isinstance(gn, nn.Module) else None A = gn() else: gn_module = None A = gn if not isinstance(A, torch.Tensor): A = torch.as_tensor(A) # Ensure dense adjacency for downstream vector fields that use dense-style indexing if hasattr(A, "layout") and A.layout != torch.strided: A = A.to_dense() # Align the solve state to a promoted dtype so complex inputs stay complex. state_dtype = _promoted_state_dtype(A, x0) time_dtype = _real_dtype(state_dtype) x0 = torch.as_tensor(x0, device=A.device, dtype=state_dtype) tt = torch.as_tensor(tt, device=A.device, dtype=time_dtype) # params must be kwargs if provided if f_kwargs is None: f_kwargs = {} elif isinstance(f_kwargs, Mapping): f_kwargs = _to_device_struct(f_kwargs, A.device) else: raise TypeError("`f_kwargs` must be a Mapping of keyword arguments (or None).") # Collect any nn.Modules inside params for adjoint to see them automatically params_modules = {k: v for k, v in f_kwargs.items() if isinstance(v, nn.Module)} # Wrap the vector field so adjoint sees relevant module parameters. vf = _VectorField( f=f, A=A, kwargs=f_kwargs, gn_module=gn_module, params_modules=params_modules if params_modules else None, ).to(A.device) # Select solver interface and shared ODE kwargs. ode_interface = odeint_adjoint if adjoint else odeint solver_options = {} if solver_options is None else solver_options base_kwargs = dict(rtol=rtol, atol=atol, method=method, options=solver_options) if adjoint: if adjoint_options is not None: base_kwargs["adjoint_options"] = adjoint_options if adjoint_params is not None: base_kwargs["adjoint_params"] = tuple(adjoint_params) # else: default is tuple(vf.parameters()), which now includes gn / any nn.Modules in params # Temporarily switch gradient tracking mode for this solve. with torch.set_grad_enabled(track_gradients): # Event-aware path if event_fn is not None: def _efn(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor: return event_fn(t, x, A, **f_kwargs) t0 = tt[0] t1 = tt[-1] decreasing = (t1 - t0) < 0 # Cap the event so integration always terminates by the requested end. def _efn_capped(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor: g = _efn(t, x) t_cap = (t - t1) if decreasing else (t1 - t) t_cap = t_cap.to(dtype=g.dtype, device=g.device) return torch.minimum(g, t_cap) _ret = odeint_event( vf, x0, t0, event_fn=_efn_capped, odeint_interface=ode_interface, **base_kwargs, ) # torchdiffeq versions may return (t, x) or (t, x, index). t_event, x_event = _ret[0], _ret[1] # Ensure the stop time does not extend beyond the integration grid. t_stop = ( torch.maximum(t_event, t1) if decreasing else torch.minimum(t_event, t1) ) # Snap tiny endpoint roundoff to the exact requested endpoint. if t_stop.is_floating_point(): eps = torch.finfo(t_stop.dtype).eps scale = torch.maximum(torch.abs(t1), torch.ones_like(t1)) near_end = torch.abs(t_stop - t1) <= (128.0 * eps) * scale t_stop = torch.where(near_end, t1, t_stop) # Build output grid up to the event, preserving time direction. # If times decrease (reverse-time), include points >= t_event and append t_event. if decreasing: mask = tt >= t_stop else: mask = tt <= t_stop tt_partial = tt[mask] if tt_partial.numel() == 0 or not torch.equal(tt_partial[-1], t_stop): tt_partial = torch.cat([tt_partial, t_stop.unsqueeze(0)], dim=0) x_partial = ode_interface(vf, x0, tt_partial, **base_kwargs) # Make sure returned (t_event, x_event) correspond to the capped output. return tt_partial, x_partial, tt_partial[-1], x_partial[-1] # Standard solve y = ode_interface(vf, x0, tt, **base_kwargs) return tt, y