gradnet.ode.integrate_ode
- gradnet.ode.integrate_ode(gn, f, x0, tt, *, f_kwargs=None, method='dopri5', rtol=0.0001, atol=0.0001, solver_options=None, adjoint=False, adjoint_options=None, adjoint_params=None, event_fn=None, track_gradients=True)
Integrate the ODE dx/dt = f(t, x, A, **f_kwargs) using torchdiffeq.
This is a convenience wrapper around torchdiffeq.odeint with optional adjoint sensitivities and event-based termination. The vector field is called as f(t, x, A, **f_kwargs), where A is the network adjacency matrix represented as a (potentially sparse) torch.Tensor.
- Parameters:
gn (Callable[[], torch.Tensor] | torch.Tensor) – Adjacency tensor A, or a zero-argument callable returning A. If an nn.Module is provided, its parameters are included in the default adjoint parameter set.
f (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]) – Vector field returning dx/dt with the same shape as x.
x0 (torch.Tensor | float | int) – Initial state (scalars are promoted).
tt (torch.Tensor) – 1D time grid (monotone; may decrease for reverse-time event searches).
f_kwargs (Mapping[str, Any] | None, optional) – Keyword arguments passed to f (and to event_fn, if provided). Tensors and NumPy arrays are moved to the device of A without being cast to A.dtype.
method (str, optional) – Integrator, e.g., the adaptive-step "dopri5" (default) or the fixed-step "rk4"; see the torchdiffeq documentation for more options.
rtol (float, optional) – Relative tolerance.
atol (float, optional) – Absolute tolerance.
solver_options (dict | None, optional) – Additional solver options.
adjoint (bool, optional) – If True, use the adjoint method.
adjoint_options (dict | None, optional) – Options for the adjoint solve (e.g., {"norm": "seminorm"}).
adjoint_params (Sequence[torch.Tensor] | None, optional) – Explicit list of parameters for adjoint gradients. Defaults to the parameters discovered in the wrapped vector field (including modules in f_kwargs).
event_fn (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] | None, optional) – Optional scalar function g(t, x, A, **f_kwargs); integration stops when it crosses zero.
track_gradients (bool, optional) – Enable autograd during the solve.
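The rtol and atol parameters govern step-size control in adaptive methods such as "dopri5". As a minimal sketch, assuming the common mixed-tolerance test (a per-component tolerance atol + rtol * max(|y_old|, |y_new|), combined with an RMS norm, which is our understanding of torchdiffeq's default), a step is accepted when the scaled error ratio is at most 1. The helper error_ratio is hypothetical and works on plain Python lists rather than tensors.

```python
import math
from typing import Sequence


def error_ratio(y_old: Sequence[float], y_new: Sequence[float],
                err: Sequence[float], rtol: float, atol: float) -> float:
    """Hypothetical helper: RMS of the local error scaled by a mixed tolerance.

    Under the assumed scheme, a step is accepted when the ratio is <= 1.
    """
    total = 0.0
    for a, b, e in zip(y_old, y_new, err):
        # Per-component tolerance: absolute floor plus a relative part.
        tol = atol + rtol * max(abs(a), abs(b))
        total += (e / tol) ** 2
    return math.sqrt(total / len(err))


# Defaults rtol = atol = 1e-4 on a unit-magnitude state: tol = 2e-4 per
# component, so a local error of 1e-4 gives a ratio of 0.5 (accepted).
ratio_ok = error_ratio([1.0, 1.0], [1.0, 1.0], [1e-4, 1e-4], rtol=1e-4, atol=1e-4)

# Near-zero state: only atol applies, so an error of 2e-4 gives 2.0 (rejected).
ratio_bad = error_ratio([0.0], [0.0], [2e-4], rtol=1e-4, atol=1e-4)
```

The second case shows why atol matters: for components near zero, the relative term vanishes and atol alone sets the allowed error.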
- Returns:
(tt_out, x_out), where x_out has shape (len(tt_out), *x0.shape).
If an event is used, (tt_out, x_out, t_event, x_event) is returned instead: tt_out and x_out are truncated at the detected event time, and t_event and x_event are the differentiable event time and state.
- Raises:
TypeError – If f_kwargs is not a mapping (and not None).
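To illustrate the calling convention described above (gn is either A or a zero-argument callable returning A; f is called as f(t, x, A, **f_kwargs); a non-mapping f_kwargs raises TypeError), here is a minimal pure-Python sketch. The helpers resolve_adjacency, check_f_kwargs and call_vector_field are hypothetical and are not gradnet's internals; device placement and adjoint parameter discovery are omitted, and nested lists stand in for tensors.

```python
from collections.abc import Mapping
from typing import Any, Callable, Optional


def resolve_adjacency(gn: Any) -> Any:
    """Hypothetical: return A itself, or call gn() if it is a zero-arg callable."""
    return gn() if callable(gn) else gn


def check_f_kwargs(f_kwargs: Optional[Any]) -> dict:
    """Hypothetical: normalize f_kwargs; reject values that are not mappings."""
    if f_kwargs is None:
        return {}
    if not isinstance(f_kwargs, Mapping):
        raise TypeError(
            f"f_kwargs must be a mapping or None, got {type(f_kwargs).__name__}"
        )
    return dict(f_kwargs)


def call_vector_field(f: Callable[..., Any], t: float, x: Any, gn: Any,
                      f_kwargs: Optional[Any] = None) -> Any:
    # Documented convention: f(t, x, A, **f_kwargs).
    A = resolve_adjacency(gn)
    return f(t, x, A, **check_f_kwargs(f_kwargs))


def vf(t, x, A, scale=1.0):
    # Dense matrix-vector product on nested lists, scaled by a keyword argument.
    return [scale * sum(a * xi for a, xi in zip(row, x)) for row in A]


A = [[0.0, 1.0], [-1.0, 0.0]]
# A passed through a zero-arg callable, plus a keyword argument for f.
dx = call_vector_field(vf, 0.0, [1.0, 0.0], lambda: A, {"scale": 2.0})
```

Here dx evaluates to [0.0, -2.0]; passing a list of pairs such as [("scale", 2.0)] as f_kwargs raises TypeError instead of being silently converted.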
Examples
Basic integration without events:
    import torch
    from gradnet.ode import integrate_ode

    # Rotation system: dx/dt = A @ x
    A = torch.tensor([[0., 1.], [-1., 0.]])

    def vf(t, x, A):
        return A @ x

    x0 = torch.tensor([1., 0.])
    tt = torch.linspace(0, 1, steps=11)
    # x_out holds one state per time point: shape (len(t_out), 2)
    t_out, x_out = integrate_ode(A, vf, x0, tt)
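The system in the basic example is a rotation: with A = [[0, 1], [-1, 0]] and x0 = (1, 0), dx/dt = A x has the exact solution x(t) = (cos t, -sin t). As a sanity check that does not depend on torchdiffeq, the sketch below applies the classic fixed-step RK4 scheme in pure Python on the same 11-point grid. It only illustrates the idea behind a fixed-step method such as "rk4"; torchdiffeq's own "rk4" may use a different fourth-order variant, and the helper names matvec and rk4 are illustrative only.

```python
import math
from typing import Callable, List

Vector = List[float]


def matvec(A: List[Vector], x: Vector) -> Vector:
    # Dense matrix-vector product on nested lists.
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def rk4(f: Callable[[float, Vector, List[Vector]], Vector], x0: Vector,
        tt: List[float], A: List[Vector]) -> List[Vector]:
    """Classic RK4 taking exactly one step per interval of the grid tt."""
    xs = [list(x0)]
    for t0, t1 in zip(tt[:-1], tt[1:]):
        h = t1 - t0
        x = xs[-1]
        k1 = f(t0, x, A)
        k2 = f(t0 + h / 2, [xi + h / 2 * k for xi, k in zip(x, k1)], A)
        k3 = f(t0 + h / 2, [xi + h / 2 * k for xi, k in zip(x, k2)], A)
        k4 = f(t1, [xi + h * k for xi, k in zip(x, k3)], A)
        xs.append([xi + h / 6 * (a + 2 * b + 2 * c + d)
                   for xi, a, b, c, d in zip(x, k1, k2, k3, k4)])
    return xs


A = [[0.0, 1.0], [-1.0, 0.0]]
tt = [i / 10 for i in range(11)]  # same grid as torch.linspace(0, 1, steps=11)
xs = rk4(lambda t, x, A: matvec(A, x), [1.0, 0.0], tt, A)

# Compare the final state with the exact solution (cos 1, -sin 1).
err = max(abs(xs[-1][0] - math.cos(1.0)), abs(xs[-1][1] + math.sin(1.0)))
```

With h = 0.1 the fourth-order scheme stays well within 1e-5 of the exact solution at t = 1, so x_out from the torch example should track (cos t, -sin t) to roughly the requested tolerances.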
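Event-driven integration until x[0] crosses zero:

    # With this A and x0, x[0] = cos(t), which first crosses zero at
    # t = pi/2 ≈ 1.571; extend the grid past it so the crossing lies
    # inside the integration window.
    tt_long = torch.linspace(0, 3, steps=31)

    def event(t, x, A):
        return x[0]  # stop when x[0] crosses 0

    tt_partial, x_partial, t_event, x_event = integrate_ode(
        A, vf, x0, tt_long, event_fn=event
    )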
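To make the truncation semantics concrete, the sketch below detects the first sign change of an event function on an already-sampled trajectory, estimates the crossing time by linear interpolation, and drops the samples after it. This is a simplified illustration only: torchdiffeq locates events during the solve rather than from grid samples, and it is an assumption here that the truncated grid keeps just the points before the event. locate_event is a hypothetical helper.

```python
import math
from typing import Callable, List, Optional, Tuple

State = List[float]


def locate_event(tt: List[float], xs: List[State],
                 g: Callable[[float, State], float]
                 ) -> Optional[Tuple[List[float], List[State], float]]:
    """Hypothetical: return (tt_out, xs_out, t_event) at g's first zero crossing.

    tt_out/xs_out keep only the grid points before the crossing; t_event is
    estimated by linear interpolation of g between the bracketing points.
    Returns None if g never crosses zero on the grid.
    """
    g_prev = g(tt[0], xs[0])
    for i in range(1, len(tt)):
        g_cur = g(tt[i], xs[i])
        if g_cur == 0.0:
            # Exact hit on a grid point: keep it as the last sample.
            return tt[:i + 1], xs[:i + 1], tt[i]
        if g_prev * g_cur < 0.0:
            # Sign change inside (tt[i-1], tt[i]): interpolate linearly.
            frac = g_prev / (g_prev - g_cur)
            t_event = tt[i - 1] + frac * (tt[i] - tt[i - 1])
            return tt[:i], xs[:i], t_event
        g_prev = g_cur
    return None


# Exact trajectory of the rotation example, x(t) = (cos t, -sin t), on [0, 3].
tt = [i / 10 for i in range(31)]
xs = [[math.cos(t), -math.sin(t)] for t in tt]
result = locate_event(tt, xs, lambda t, x: x[0])
tt_out, xs_out, t_event = result
```

The estimate lands within about 1e-4 of pi/2, and the truncated grid ends at t = 1.5, the last sample before the crossing. On a grid ending at t = 1 the sketch returns None, since cos t stays positive there.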
See also the torchdiffeq documentation for supported methods and options.