gradnet.ode.integrate_ode

gradnet.ode.integrate_ode(gn, f, x0, tt, *, f_kwargs=None, method='dopri5', rtol=0.0001, atol=0.0001, solver_options=None, adjoint=False, adjoint_options=None, adjoint_params=None, event_fn=None, track_gradients=True)

Integrate an ODE dx/dt = f(t, x, A, **f_kwargs) using torchdiffeq.

This is a convenience wrapper around torchdiffeq.odeint with optional adjoint sensitivities and event-based termination. The vector field is called as f(t, x, A, **f_kwargs) where A is the network adjacency matrix represented as a (potentially sparse) torch.Tensor.
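The calling convention can be pictured as an adapter that turns the three-argument field f(t, x, A, **f_kwargs) into the two-argument func(t, x) form that torchdiffeq.odeint expects, resolving A from either a value or a zero-argument callable on each evaluation. The block below is a minimal pure-Python sketch under that assumption, not gradnet's implementation: `make_rhs` is a hypothetical name, and plain lists stand in for torch tensors.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Callable, Optional, Sequence

Vector = Sequence[float]
Matrix = Sequence[Sequence[float]]


def make_rhs(
    gn: Any,
    f: Callable[..., Vector],
    f_kwargs: Optional[Mapping[str, Any]] = None,
) -> Callable[[float, Vector], Vector]:
    """Hypothetical adapter: expose f(t, x, A, **f_kwargs) as rhs(t, x)."""
    if f_kwargs is None:
        f_kwargs = {}
    if not isinstance(f_kwargs, Mapping):
        # Mirrors the documented TypeError for a non-mapping f_kwargs.
        raise TypeError("f_kwargs must be a mapping or None")
    kwargs = dict(f_kwargs)

    def rhs(t: float, x: Vector) -> Vector:
        # Resolve A on every call, so a callable gn can return an up-to-date matrix.
        A = gn() if callable(gn) else gn
        return f(t, x, A, **kwargs)

    return rhs


def vf(t: float, x: Vector, A: Matrix, scale: float = 1.0) -> list:
    # Linear field dx/dt = scale * A @ x, written with plain lists.
    return [scale * sum(a * xi for a, xi in zip(row, x)) for row in A]


A = [[0.0, 1.0], [-1.0, 0.0]]
rhs_value = make_rhs(A, vf)                                # A passed directly
rhs_callable = make_rhs(lambda: A, vf, {"scale": 2.0})     # A from a zero-arg callable
print(rhs_value(0.0, [1.0, 0.0]))     # [0.0, -1.0]
print(rhs_callable(0.0, [1.0, 0.0]))  # [0.0, -2.0]
```

In gradnet itself, A and x are torch tensors, and gn may be an nn.Module whose forward pass produces A; a tensor is not callable, whereas a module is, which is what lets a single code path accept either.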

Parameters:
  • gn (Callable[[], torch.Tensor] | torch.Tensor) – Tensor A or a zero-arg callable returning A. If an nn.Module is provided, its parameters are included in the default adjoint parameter set.

  • f (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]) – Vector field returning dx/dt with the same shape as x.

  • x0 (torch.Tensor | float | int) – Initial state (scalars are promoted).

  • tt (torch.Tensor) – 1D time grid, strictly increasing or strictly decreasing (a decreasing grid integrates in reverse time, e.g. for reverse-time event searches).

  • f_kwargs (Mapping[str, Any] | None, optional) – Keyword arguments passed to f (and to event_fn, if provided). Tensors and NumPy arrays are placed on the device of A but are not cast to A.dtype.

  • method (str, optional) – Integrator name, e.g. the adaptive-step "dopri5" (default) or the fixed-step "rk4"; see the torchdiffeq documentation for further options.

  • rtol (float, optional) – Relative tolerance.

  • atol (float, optional) – Absolute tolerance.

  • solver_options (dict | None, optional) – Additional solver options.

  • adjoint (bool, optional) – If True, use the adjoint method.

  • adjoint_options (dict | None, optional) – Options for adjoint solve (e.g., {"norm": "seminorm"}).

  • adjoint_params (Sequence[torch.Tensor] | None, optional) – Explicit list of parameters for adjoint gradients. Defaults to parameters discovered in the wrapped vector field (including modules in f_kwargs).

  • event_fn (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] | None, optional) – Optional scalar event function g(t, x, A, **f_kwargs); integration stops at the first zero crossing of g.

  • track_gradients (bool, optional) – Enable autograd during the solve.
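To make the roles of tt, a fixed-step method and the output layout concrete, here is a minimal pure-Python sketch that takes one classical RK4 step between consecutive grid points and returns one state per time point. It is an illustration, not gradnet's or torchdiffeq's code: `rk4_on_grid` is a hypothetical name, torchdiffeq's "rk4" uses its own fourth-order step formula (which need not coincide with the classical one below) and can subdivide the grid via solver options, and the state is a flat list, so the result has shape (len(tt), len(x0)).

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def rk4_on_grid(
    rhs: Callable[[float, Vector], Vector],
    x0: Sequence[float],
    tt: Sequence[float],
) -> List[Vector]:
    """Hypothetical fixed-step integrator: one classical RK4 step per grid interval."""
    if len(tt) < 2:
        raise ValueError("tt needs at least two time points")
    diffs = [b - a for a, b in zip(tt, tt[1:])]
    # The grid must be strictly increasing, or strictly decreasing for reverse time.
    if not (all(d > 0 for d in diffs) or all(d < 0 for d in diffs)):
        raise ValueError("tt must be strictly monotone")

    def axpy(h: float, k: Vector, x: Vector) -> Vector:
        # Returns x + h * k, elementwise.
        return [xi + h * ki for xi, ki in zip(x, k)]

    xs = [list(x0)]
    for t, h in zip(tt, diffs):
        x = xs[-1]
        k1 = rhs(t, x)
        k2 = rhs(t + h / 2, axpy(h / 2, k1, x))
        k3 = rhs(t + h / 2, axpy(h / 2, k2, x))
        k4 = rhs(t + h, axpy(h, k3, x))
        xs.append([xi + h / 6 * (a + 2 * b + 2 * c + d)
                   for xi, a, b, c, d in zip(x, k1, k2, k3, k4)])
    return xs  # xs[i] is the state at tt[i]


def rotation(t: float, x: Vector) -> Vector:
    # dx/dt = A @ x with A = [[0, 1], [-1, 0]], as in the Examples section.
    return [x[1], -x[0]]


tt = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0
xs = rk4_on_grid(rotation, [1.0, 0.0], tt)
# The exact solution at t = 1 is (cos 1, -sin 1).
print(len(xs), xs[-1], (math.cos(1.0), -math.sin(1.0)))
```

Passing the reversed grid tt[::-1] with the final state as the initial value integrates back in time toward (1, 0), which is the reverse-time case the tt parameter allows.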

Returns:

(tt_out, x_out) where x_out has shape (len(tt_out), *x0.shape).

If event_fn is given, the result is (tt_out, x_out, t_event, x_event): tt_out and x_out are truncated at the detected event time, and t_event and x_event are the differentiable event time and the state at that time.
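The event contract (stop at the first zero crossing of g, truncate the trajectory there, and report the crossing time and state separately) can be illustrated with the pure-Python sketch below. It assumes the crossing is bracketed by a sign change between two grid points and then refined by bisection on the length of a single RK4 step. torchdiffeq's own event handling differs in detail, and unlike gradnet's t_event and x_event these values are not differentiable. `rk4_step` and `integrate_until_event` are hypothetical names, and the sketch keeps only the grid points before the crossing; how gradnet treats the boundary point is not specified here.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Rhs = Callable[[float, Vector], Vector]


def rk4_step(rhs: Rhs, t: float, x: Vector, h: float) -> Vector:
    """One classical RK4 step of size h (h may be negative)."""
    k1 = rhs(t, x)
    k2 = rhs(t + h / 2, [xi + h / 2 * k for xi, k in zip(x, k1)])
    k3 = rhs(t + h / 2, [xi + h / 2 * k for xi, k in zip(x, k2)])
    k4 = rhs(t + h, [xi + h * k for xi, k in zip(x, k3)])
    return [xi + h / 6 * (a + 2 * b + 2 * c + d)
            for xi, a, b, c, d in zip(x, k1, k2, k3, k4)]


def integrate_until_event(
    rhs: Rhs,
    event: Callable[[float, Vector], float],
    x0: Sequence[float],
    tt: Sequence[float],
    tol: float = 1e-10,
) -> Tuple[List[float], List[Vector], float, Vector]:
    """Hypothetical sketch: step along tt and stop at the first sign change of event."""
    ts, xs = [tt[0]], [list(x0)]
    for t_next in tt[1:]:
        t, x = ts[-1], xs[-1]
        x_next = rk4_step(rhs, t, x, t_next - t)
        g0, g1 = event(t, x), event(t_next, x_next)
        if g0 != 0 and g0 * g1 <= 0:
            # Sign change inside [t, t_next]: bisect on the step length.
            # Invariant: g keeps the sign of g0 at lo and has flipped (or is 0) at hi.
            lo, hi = 0.0, t_next - t
            while abs(hi - lo) > tol:
                mid = (lo + hi) / 2
                if g0 * event(t + mid, rk4_step(rhs, t, x, mid)) > 0:
                    lo = mid
                else:
                    hi = mid
            return ts, xs, t + hi, rk4_step(rhs, t, x, hi)
        ts.append(t_next)
        xs.append(x_next)
    raise RuntimeError("no zero crossing of the event function on tt")


def rotation(t: float, x: Vector) -> Vector:
    # dx/dt = A @ x with A = [[0, 1], [-1, 0]]; for x0 = (1, 0), x[0](t) = cos(t).
    return [x[1], -x[0]]


def first_component(t: float, x: Vector) -> float:
    return x[0]  # crosses zero at t = pi / 2


tt = [i / 10 for i in range(31)]  # 0.0, 0.1, ..., 3.0
ts, xs, t_event, x_event = integrate_until_event(rotation, first_component, [1.0, 0.0], tt)
print(len(ts), t_event, math.pi / 2)
```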

Return type:

tuple

Raises:

TypeError – If f_kwargs is not a mapping (and not None).

Examples

Basic integration without events:

import torch
from gradnet.ode import integrate_ode

A = torch.tensor([[0., 1.], [-1., 0.]])

def vf(t, x, A):
    return A @ x

x0 = torch.tensor([1., 0.])
tt = torch.linspace(0, 1, steps=11)
t_out, x_out = integrate_ode(A, vf, x0, tt)

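Event-driven integration until x[0] crosses zero. For this A and x0, x[0](t) = cos(t), which first crosses zero at t = π/2 ≈ 1.57, past the end of the grid above, so a longer grid is used:

def event(t, x, A):
    return x[0]  # stop when x[0] crosses 0

tt_long = torch.linspace(0, 3, steps=31)  # covers the crossing at t = pi/2
tt_partial, x_partial, t_event, x_event = integrate_ode(
    A, vf, x0, tt_long, event_fn=event
)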

See also the torchdiffeq documentation for supported methods and options.