Source code for gradnet.utils

import torch
import random
import numpy as np
import time
import warnings
import tempfile
from functools import wraps
from pathlib import Path
import torch.linalg as LA
from typing import Mapping, Optional, Union
import csv


def random_seed(seed):
    """Seed every RNG used by the package so that runs are reproducible.

    Seeds PyTorch, NumPy's global generator and Python's ``random`` module,
    and additionally the CUDA generator when CUDA is available. Also turns on
    PyTorch's deterministic-algorithms mode.

    :param seed: Seed value handed unchanged to each library.
    """
    # This is a global PyTorch setting; it is not restored afterwards.
    torch.use_deterministic_algorithms(True)

    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
def laplacian(A):
    """Compute the graph Laplacian ``L = D - A`` from an adjacency matrix.

    ``D`` is the diagonal matrix of row sums of ``A`` (sum over the last
    axis). Works for a single ``(N, N)`` matrix and for batched input of
    shape ``(..., N, N)``, in which case every matrix in the batch is
    processed independently.

    :param A: Dense adjacency tensor of shape ``(N, N)`` or ``(..., N, N)``.
    :return: Tensor of the same shape as ``A`` holding the Laplacian(s).
    """
    # diag_embed places the degree vector on the diagonal of the last two
    # dimensions, which equals torch.diag for 2-D input but also handles
    # leading batch dimensions.
    degrees = A.sum(dim=-1)
    return torch.diag_embed(degrees) - A
def prune_edges(
    del_adj: torch.Tensor,
    *,
    threshold: Optional[float] = None,
    target_edge_number: Optional[int] = None,
    renorm: bool = True,
) -> torch.Tensor:
    """Prune edges in an adjacency-like tensor.

    Use either a numeric ``threshold`` (keeps entries with
    ``abs(x) >= threshold``) or specify ``target_edge_number`` to
    automatically determine a threshold that yields exactly that many
    unpruned entries. Exactly one of these must be provided.

    If ``renorm`` is True, rescales the pruned tensor to match the original
    L1 norm. If all entries are pruned, returns an all-zero tensor.

    :param del_adj: Tensor whose entries are pruned by absolute value.
    :param threshold: Entries with ``abs(x) < threshold`` are set to zero.
    :param target_edge_number: Number of non-zero entries to keep; must lie
        between 0 and the number of non-zero entries of ``del_adj``.
    :param renorm: Rescale the result so its L1 norm equals that of
        ``del_adj``.
    :return: Tensor shaped like ``del_adj`` with pruned entries zeroed.
    :raises ValueError: If neither or both of ``threshold`` and
        ``target_edge_number`` are given, if ``target_edge_number`` is out of
        range, or if ties at the pruning boundary make the exact count
        impossible.
    """
    # Validate mutually exclusive arguments
    if (threshold is None) == (target_edge_number is None):
        raise ValueError("Provide exactly one of 'threshold' or 'target_edge_number'.")

    magnitudes = torch.abs(del_adj)

    # Determine threshold via target count if requested
    if threshold is None:
        k = int(target_edge_number)  # type: ignore[arg-type]
        # Consider only strictly positive magnitudes so zeros are always pruned
        pos_vals = magnitudes[magnitudes > 0]
        M = int(pos_vals.numel())
        if k < 0 or k > M:
            raise ValueError(
                f"target_edge_number must be between 0 and {M} for the given tensor; got {k}."
            )

        if k == 0:
            # Keeping nothing means an all-zero result. Returning it directly
            # avoids building a "just above the maximum" threshold, which
            # rounds back onto the maximum for large magnitudes and would
            # leave the largest entry unpruned.
            return torch.zeros_like(del_adj)

        if k == M:
            # Keep everything with positive magnitude
            threshold = pos_vals.min().item()
        else:
            # Sort magnitudes descending and examine neighbors around k.
            sorted_vals = torch.sort(pos_vals, descending=True).values
            v_k = sorted_vals[k - 1].item()
            v_next = sorted_vals[k].item()
            if v_k == v_next:
                # Exact k is impossible due to ties at the boundary
                raise ValueError(
                    "Cannot achieve the requested target_edge_number exactly due to "
                    "duplicate magnitudes at the pruning boundary."
                )
            # Any threshold in (v_next, v_k] keeps exactly k entries. Use v_k
            # itself: it is exactly representable in the tensor dtype, whereas
            # the midpoint of two adjacent floats can round onto v_next.
            threshold = v_k

    # Apply pruning using the resolved threshold
    norm = magnitudes.sum()
    pruned = torch.where(
        magnitudes < float(threshold), torch.zeros_like(del_adj), del_adj
    )

    if not renorm:
        return pruned

    pruned_norm = torch.abs(pruned).sum()
    if pruned_norm < 1e-12:  # all edges pruned
        return torch.zeros_like(del_adj)
    return pruned * (norm / pruned_norm)
def to_networkx(gn, pruning_threshold: float = 1e-8):
    """Build a NetworkX graph from the adjacency produced by ``gn``.

    Only edges whose absolute weight is strictly greater than
    ``pruning_threshold`` are added. Handles both dense and sparse adjacency
    tensors.

    :param gn: Object that is callable with no arguments (returning the
        adjacency tensor) and exposes ``directed``, ``num_nodes`` and
        ``mask`` attributes.
    :param pruning_threshold: Absolute weight an edge must exceed to be kept.
    :type pruning_threshold: float
    :return: A ``networkx.DiGraph`` if ``gn.directed`` else a ``Graph``.
    :rtype: networkx.Graph | networkx.DiGraph
    """
    try:
        import networkx as nx
    except ImportError as exc:  # pragma: no cover - exercised when the extra is absent
        raise ImportError(
            "to_networkx requires the optional 'networkx' extra; install it with"
            " `pip install gradnet[networkx]`."
        ) from exc

    is_directed = gn.directed
    graph = nx.DiGraph() if is_directed else nx.Graph()
    graph.add_nodes_from(range(gn.num_nodes))

    A = gn()

    # Sparse adjacency: walk the stored (row, col, value) triplets.
    if isinstance(A, torch.Tensor) and A.layout != torch.strided:
        A = A.coalesce()
        pairs = A.indices().t().tolist()
        weights = A.values().detach().cpu().tolist()

        if is_directed:
            for (src, dst), w in zip(pairs, weights):
                if abs(w) > pruning_threshold:
                    graph.add_edge(int(src), int(dst), weight=float(w))
            return graph

        # Undirected: only the first stored orientation of each node pair is
        # considered; self-loops are skipped.
        visited = set()
        for (i, j), w in zip(pairs, weights):
            if i == j:
                continue
            key = (i, j) if i < j else (j, i)
            if key in visited:
                continue
            visited.add(key)
            if abs(w) > pruning_threshold:
                graph.add_edge(key[0], key[1], weight=float(w))
        return graph

    # Dense adjacency: scan entries, consulting the (densified) mask.
    dense = A.detach().cpu()
    mask = gn.mask
    if isinstance(mask, torch.Tensor) and mask.layout != torch.strided:
        mask = mask.to_dense()

    n_nodes = gn.num_nodes
    for i in range(n_nodes):
        # Undirected graphs only need the strict upper triangle.
        columns = range(n_nodes) if is_directed else range(i + 1, n_nodes)
        for j in columns:
            w = float(dense[i, j])
            if abs(w) > pruning_threshold and (mask[i, j] != 0):
                graph.add_edge(i, j, weight=w)
    return graph
def plot_adjacency_heatmap(
    gn,
    *,
    ax=None,
    title: Optional[str] = None,
    xlabel: str = "$j$",
    ylabel: str = "$i$",
    cbar_label: str = "$A_{ij}$",
    add_colorbar: bool = True,
    cbar_kwargs: Optional[dict] = None,
    imshow_kwargs: Optional[dict] = None,
    plt_show: bool = False,
):
    """Plot an adjacency matrix as a heatmap.

    - If ``ax`` is ``None``, creates a new figure and axes.
    - The colorbar attaches to ``ax.figure`` unless ``add_colorbar=False``.
    - Accepts a GradNet-like object (callable with no args), a Torch tensor,
      or any array-like representing an adjacency. Sparse Torch tensors are
      converted to dense before plotting.

    :param gn: Adjacency source (callable, tensor or array-like).
    :param ax: Matplotlib axes to draw on; a new figure is created if omitted.
    :param title: Axes title.
    :param xlabel: Label of the x axis.
    :param ylabel: Label of the y axis.
    :param cbar_label: Default colorbar label (overridable via
        ``cbar_kwargs["label"]``).
    :param add_colorbar: Whether to attach a colorbar to the figure.
    :param cbar_kwargs: Extra keyword arguments for ``Figure.colorbar``.
    :param imshow_kwargs: Extra keyword arguments for ``Axes.imshow``.
    :param plt_show: Call ``plt.show()`` before returning.
    :return: The image object returned by ``ax.imshow``.
    """
    import matplotlib.pyplot as plt

    # Resolve the input to its adjacency: call GradNet-like objects first.
    source = gn
    if not isinstance(source, torch.Tensor) and callable(source):
        source = source()

    if isinstance(source, torch.Tensor):
        tensor = source.detach()
        if tensor.layout != torch.strided:
            # Sparse layouts cannot be converted to NumPy directly.
            tensor = tensor.to_dense()
        data = tensor.cpu().numpy()
    else:
        data = np.asarray(source)

    # Copy user-supplied dicts so the caller's objects are never mutated.
    imshow_kwargs = {} if imshow_kwargs is None else dict(imshow_kwargs)

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    im = ax.imshow(data, **imshow_kwargs)

    if add_colorbar:
        cb_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)
        cb_kwargs.setdefault("label", cbar_label)
        fig.colorbar(im, ax=ax, **cb_kwargs)

    ax.set(title=title, xlabel=xlabel, ylabel=ylabel)

    if plt_show:
        plt.show()
    return im
def plot_graph(
    gn,
    *,
    ax=None,
    pruning_threshold: float = 1e-8,
    layout: str = "spring",
    node_size: float = 15.0,
    edgecolors: str = "black",
    edge_width_scaling: float = 1.0,
    draw_kwargs: Optional[dict] = None,
    add_colorbar: bool = False,
    colorbar_label: str = None,
    plt_show: bool = False,
):
    # TODO! layout="networkx" option is weird for positions
    """Render ``gn`` as a graph drawing via NetworkX.

    - A new figure and axes are created when ``ax`` is ``None``.
    - The graph comes from ``to_networkx``; edge widths are the edge weights
      multiplied by ``edge_width_scaling``.
    - ``layout`` is either the suffix of a ``networkx.draw_*`` function or a
      callable used in its place.
    - With ``add_colorbar=True`` a colorbar is added when
      ``draw_kwargs["node_color"]`` is a sized, non-string value.

    Returns ``None``.
    """
    import matplotlib.pyplot as plt
    from matplotlib.cm import ScalarMappable
    import numpy as np

    try:
        import networkx as nx
    except ImportError as exc:
        raise ImportError(
            "plot_graph requires the optional 'networkx' extra; install it with"
            " `pip install gradnet[networkx]`."
        ) from exc

    if ax is None:
        _, ax = plt.subplots()

    net = to_networkx(gn, pruning_threshold=pruning_threshold)

    # Scaled edge weights double as line widths; an edgeless graph passes None.
    widths = [
        weight * edge_width_scaling
        for weight in nx.get_edge_attributes(net, "weight").values()
    ]
    if not widths:
        widths = None

    # Work on a copy so the caller's dict is untouched; explicit entries in
    # draw_kwargs take precedence over the defaults below.
    options = {} if draw_kwargs is None else dict(draw_kwargs)
    defaults = (
        ("nodelist", sorted(net.nodes())),
        ("node_size", node_size),
        ("width", widths),
        ("edgecolors", edgecolors),
    )
    for key, value in defaults:
        options.setdefault(key, value)

    if isinstance(layout, str):
        draw_fn = getattr(nx, f"draw_{layout}")
    else:
        draw_fn = layout
    if not callable(draw_fn):
        raise ValueError(f"layout '{layout}' is not callable")

    # Draw the network
    draw_fn(net, ax=ax, **options)

    # Optionally add a colorbar
    if add_colorbar and "node_color" in options:
        colors = options["node_color"]
        if hasattr(colors, "__len__") and not isinstance(colors, str):
            mappable = ScalarMappable(cmap=options.get("cmap", plt.cm.viridis))
            mappable.set_array(np.asarray(colors))
            ax.figure.colorbar(mappable, ax=ax, label=colorbar_label)

    if plt_show:
        plt.show()
    return
def load_scalars(log_dir: Union[str, Path]):
    """Load scalar metrics from Lightning logs as step-aligned series.

    ``log_dir`` may point at one version directory (e.g.,
    ``lightning_logs/gradnet/version_3``) or at the parent folder holding
    several ``version_*`` subdirectories (e.g., ``lightning_logs/gradnet``);
    in the latter case the directory with the highest version number (ties
    broken by modification time) is used.

    A ``metrics.csv`` file is preferred when present; otherwise TensorBoard
    event files are read.

    Returns ``(steps, series)`` where ``steps`` is one sorted list of
    integers (epoch/step) shared by all metrics and ``series`` maps
    ``{name: values}`` with values aligned to ``steps``. Missing values are
    filled with ``nan``.

    Usage:
        >>> steps, series = load_scalars('lightning_logs/gradnet')
        >>> loss = series['loss']

    :param log_dir: Path to a logger directory or its parent.
    :return: tuple[list[int], dict[str, list[float]]]
    :raises RuntimeError: If no ``metrics.csv`` exists and ``tensorboard``
        cannot be imported.
    """
    nan = float("nan")
    # Columns of a wide-format CSV that are bookkeeping rather than metrics.
    reserved_columns = {"epoch", "step", "global_step", "time", "created_at"}

    def _has_logs(folder: Path) -> bool:
        """True when ``folder`` directly contains a CSV log or event files."""
        if not folder.is_dir():
            return False
        if (folder / "metrics.csv").exists():
            return True
        return any(
            entry.is_file() and entry.name.startswith("events.out.tfevents")
            for entry in folder.iterdir()
        )

    def _version_number(folder: Path) -> int:
        """Numeric suffix after the last underscore of the name, -1 if none."""
        try:
            return int(folder.name.split("_")[-1])
        except Exception:
            return -1

    def _resolve_version_dir(base: Path) -> Path:
        """Pick ``base`` itself or its newest ``version*`` child."""
        if _has_logs(base):
            return base
        versions = [
            child
            for child in base.iterdir()
            if child.is_dir() and child.name.startswith("version")
        ]
        if not versions:
            return base  # best effort; may still contain event files directly
        ordered = sorted(
            versions, key=lambda d: (_version_number(d), d.stat().st_mtime)
        )
        return ordered[-1]

    def _row_position(row) -> Optional[int]:
        """x for a CSV row: epoch first, then step/global_step, else None."""
        candidates = (row.get("epoch"), row.get("step") or row.get("global_step"))
        for raw in candidates:
            if raw is None or str(raw) == "":
                continue
            try:
                return int(float(raw))
            except Exception:
                continue
        return None

    def _parse_value(raw) -> Optional[float]:
        """Float for a CSV cell, or None for blank/'nan'/unparseable cells."""
        if raw is None or raw == "" or str(raw).lower() == "nan":
            return None
        try:
            return float(raw)
        except Exception:
            return None

    version_dir = _resolve_version_dir(Path(log_dir))

    # 1) Try CSV (CSVLogger)
    metrics_file = version_dir / "metrics.csv"
    if metrics_file.exists():
        by_x: dict[int, dict[str, float]] = {}
        names: set[str] = set()
        auto_index = 0  # x used for rows without a usable epoch/step

        with metrics_file.open(newline="") as handle:
            reader = csv.DictReader(handle)
            columns = reader.fieldnames or []
            is_long = "metric" in columns and "value" in columns

            for row in reader:
                x = _row_position(row)
                if x is None:
                    x = auto_index
                    auto_index += 1

                if is_long:
                    # long format: one (metric, value) pair per row
                    entries = [(row.get("metric"), row.get("value"))]
                else:
                    # wide format: every non-bookkeeping column is a metric
                    entries = [
                        (column, cell)
                        for column, cell in row.items()
                        if column not in reserved_columns
                    ]

                for name, raw in entries:
                    if name is None:
                        continue
                    value = _parse_value(raw)
                    if value is None:
                        continue
                    names.add(name)
                    by_x.setdefault(x, {})[name] = value

        if not by_x:
            return [], {}

        steps = sorted(by_x)
        position = {step: i for i, step in enumerate(steps)}
        # Pre-fill with NaN so metrics missing at a step stay NaN.
        series: dict[str, list[float]] = {
            name: [nan] * len(steps) for name in sorted(names)
        }
        for step, values in by_x.items():
            slot = position[step]
            for name, value in values.items():
                series[name][slot] = float(value)
        return steps, series

    # 2) Try TensorBoard events (TensorBoardLogger)
    try:
        from tensorboard.backend.event_processing.event_accumulator import EventAccumulator  # type: ignore
    except Exception:
        # tensorboard not available and no CSV found
        raise RuntimeError(
            f"No metrics.csv found in '{version_dir}' and 'tensorboard' is not installed. "
            "Install it with `pip install tensorboard`, or enable CSVLogger."
        )

    # If the provided directory is not a version dir, EventAccumulator can
    # still discover event files within it.
    accumulator = EventAccumulator(str(version_dir), size_guidance={"scalars": 0})
    accumulator.Reload()

    tags = list(accumulator.Tags().get("scalars", []))
    if not tags:
        return [], {}

    events_by_tag: dict[str, list] = {tag: accumulator.Scalars(tag) for tag in tags}

    # unify steps across all tags
    steps = sorted(
        {int(event.step) for events in events_by_tag.values() for event in events}
    )
    position = {step: i for i, step in enumerate(steps)}
    series = {tag: [nan] * len(steps) for tag in tags}
    for tag, events in events_by_tag.items():
        for event in events:
            series[tag][position[int(event.step)]] = float(event.value)
    return steps, series
def animate_adjacency(
    checkpoints: Union[str, Path],
    *,
    output_path: Optional[Union[str, Path]] = None,
    fps: int = 30,
    dpi: int = 100,
    figsize: Optional[tuple[float, float]] = None,
    title_template: Optional[str] = "Checkpoint {index}: {name}",
    imshow_kwargs: Optional[Mapping] = None,
):
    """Animate adjacency heatmaps for GradNet checkpoints named ``gn-periodic-*.ckpt``.

    Every matching checkpoint is loaded on CPU, its adjacency is taken via
    ``model.to_numpy()``, and the frames are assembled into a Matplotlib
    ``FuncAnimation``. The function then tries to write an MP4 with FFmpeg
    and to display the result (IPython display first, ``plt.show()`` as
    fallback). Failures along the way are reported with ``warnings.warn``
    rather than raised.

    :param checkpoints: A directory searched for ``gn-periodic-*.ckpt``
        files, or the path of a single such file.
    :param output_path: Destination of the MP4. When omitted, a temporary
        ``.mp4`` file is created instead.
    :param fps: Frames per second; values below 1 are raised to 1.
    :param dpi: Resolution passed to ``Animation.save``.
    :param figsize: Figure size passed to ``plt.subplots``.
    :param title_template: Format string receiving ``index`` and ``name``
        for each frame title; a falsy value disables titles.
    :param imshow_kwargs: Extra ``imshow`` options. ``vmin``/``vmax`` default
        to the global minimum/maximum over all frames so the color scale is
        shared.
    :return: The path of the saved MP4 when saving succeeded, otherwise the
        ``FuncAnimation`` object; ``None`` when Matplotlib cannot be imported
        or no checkpoint matches.
    """
    try:
        import matplotlib.pyplot as plt
        from matplotlib import animation as mpl_animation
    except Exception as exc:
        warnings.warn(f"Matplotlib is required for animations: {exc}")
        return None

    from .gradnet import GradNet

    fps = max(1, int(fps))
    root = Path(checkpoints)

    # Collect checkpoint files: all matches in a directory, or the single
    # file itself when its name follows the expected pattern.
    # NOTE(review): sorting is by path, i.e. lexicographic, so
    # "gn-periodic-10" would come before "gn-periodic-2" unless the names are
    # zero-padded - confirm against how checkpoints are written.
    if root.is_dir():
        ckpts = sorted(p for p in root.glob("gn-periodic-*.ckpt") if p.is_file())
    else:
        matches = (
            root.is_file()
            and root.name.startswith("gn-periodic-")
            and root.suffix == ".ckpt"
        )
        ckpts = [root] if matches else []

    if not ckpts:
        warnings.warn("No checkpoints matching 'gn-periodic-*.ckpt' were found.")
        return None

    # Load every checkpoint up front; all frames are held in memory.
    adjacencies = []
    for path in ckpts:
        model = GradNet.from_checkpoint(str(path), map_location="cpu")
        adjacencies.append(model.to_numpy())

    # Fix the color limits across frames unless the caller supplied them.
    show_kwargs = {} if imshow_kwargs is None else dict(imshow_kwargs)
    show_kwargs.setdefault("vmin", min(float(adj.min()) for adj in adjacencies))
    show_kwargs.setdefault("vmax", max(float(adj.max()) for adj in adjacencies))

    fig, ax = plt.subplots(figsize=figsize)
    im = plot_adjacency_heatmap(
        adjacencies[0],
        ax=ax,
        title=(
            title_template.format(index=0, name=ckpts[0].name)
            if title_template
            else None
        ),
        imshow_kwargs=show_kwargs,
    )

    def _update(index: int):
        # Swap the image data (and title) in place for frame ``index``.
        im.set_data(adjacencies[index])
        if title_template:
            ax.set_title(title_template.format(index=index, name=ckpts[index].name))
        return [im]

    ani = mpl_animation.FuncAnimation(
        fig, _update, frames=len(adjacencies), interval=1000.0 / fps
    )

    saved_path: Optional[Path] = None
    temporary_path: Optional[Path] = None
    ffmpeg_error: Optional[Exception] = None

    # Saving is attempted only when the FFmpeg writer class can be imported.
    try:
        from matplotlib.animation import FFMpegWriter
    except Exception as exc:
        ffmpeg_error = exc
    else:
        ffmpeg_error = None
        target_path: Optional[Path]
        if output_path:
            target_path = Path(output_path)
            target_path.parent.mkdir(parents=True, exist_ok=True)
        else:
            # No destination given: reserve a temporary .mp4 that outlives
            # this call (delete=False) so it can be returned to the caller.
            try:
                tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
                tmp.close()
                temporary_path = Path(tmp.name)
                target_path = temporary_path
            except Exception as exc:
                warnings.warn(
                    f"Unable to allocate temporary file for MP4 animation: {exc}"
                )
                target_path = None

        if target_path is not None:
            try:
                ani.save(str(target_path), writer=FFMpegWriter(fps=fps), dpi=dpi)
                saved_path = target_path
            except Exception as exc:
                warnings.warn(f"Failed to save animation to {target_path}: {exc}")
                # Remove the temporary file left behind by the failed save.
                if temporary_path and temporary_path.exists():
                    temporary_path.unlink(missing_ok=True)
                temporary_path = None
                saved_path = None

    # The missing-writer warning is only emitted when the caller explicitly
    # asked for an output file.
    if saved_path is None and output_path and ffmpeg_error is not None:
        warnings.warn(
            f"FFMpeg writer unavailable; failed to create MP4 animation: {ffmpeg_error}"
        )

    # Display: IPython (saved video, then HTML5 video, then JS HTML), and
    # finally a regular Matplotlib window.
    displayed = False
    try:
        from IPython.display import HTML, Video, display

        if saved_path and saved_path.exists():
            # Embed MP4 first to minimize notebook size when possible.
            display(Video(str(saved_path), embed=True))
            displayed = True
        else:
            try:
                display(HTML(ani.to_html5_video()))
                displayed = True
            except Exception:
                display(HTML(ani.to_jshtml()))
                displayed = True
    except Exception:
        pass

    if not displayed:
        try:
            plt.show()
            displayed = True
        except Exception:
            pass

    if not displayed:
        warnings.warn(
            "Unable to display the animation; consider running inside a notebook environment."
        )

    plt.close(fig)
    return saved_path or ani
def positions_to_distance_matrix(positions: torch.Tensor, norm: float = 2.0):
    """Return the matrix of pairwise distances between node positions.

    Entry ``[i, j]`` is the ``norm``-norm (taken over the last axis) of
    ``positions[i] - positions[j]``.

    :param positions: Tensor of node positions, indexed by node along the
        first axis.
    :param norm: Order of the vector norm (``2.0`` gives Euclidean distance).
    :return: Tensor of pairwise distances.
    """
    # Broadcasting a column of positions against a row yields all differences.
    as_rows = positions[:, None]
    as_cols = positions[None]
    pairwise_diff = as_rows - as_cols
    return LA.vector_norm(pairwise_diff, ord=norm, dim=-1)
def regularization_loss(del_adj: torch.Tensor) -> torch.Tensor:
    """Regularization loss for sparsifying the delta adjacency.

    Computes ``sum(log(1 + abs(del_adj))) / N``, where ``N`` is the size of
    the last dimension of ``del_adj``.

    :param del_adj: Delta adjacency tensor.
    :return: Scalar tensor holding the loss.
    """
    # log1p stays accurate for the tiny magnitudes a sparsity penalty drives
    # entries towards; log(x + 1) rounds ``x + 1`` to exactly 1 there and
    # would report a loss of 0.
    penalty = torch.log1p(torch.abs(del_adj))
    return penalty.sum() / del_adj.shape[-1]
################################################################################# private utils def _to_like_struct(obj, like: torch.Tensor): """Recursively move/cast tensors (and NumPy) inside obj to like.device/dtype; leave others as-is.""" if isinstance(obj, torch.Tensor): return obj.to(device=like.device, dtype=like.dtype) if isinstance(obj, np.ndarray): # also catches np.matrix # as_tensor shares memory on CPU; then we move/cast to match `like` t = torch.as_tensor(obj) # stays on CPU first return t.to(device=like.device, dtype=like.dtype) if isinstance(obj, np.generic): # NumPy scalar (e.g., np.float32(3.0)) return torch.tensor(obj, device=like.device, dtype=like.dtype) if isinstance(obj, Mapping): return obj.__class__({k: _to_like_struct(v, like) for k, v in obj.items()}) if isinstance(obj, tuple) and hasattr(obj, "_fields"): # namedtuple return obj.__class__(*[_to_like_struct(v, like) for v in obj]) if isinstance(obj, (list, tuple)): typ = obj.__class__ return typ(_to_like_struct(v, like) for v in obj) return obj # nn.Module or anything else stays as-is def _shortest_path(A: torch.Tensor, pair="full"): """Compute shortest path distances with SciPy and preserve Torch grads. - Accepts an adjacency tensor ``A`` (dense or sparse PyTorch). - Edge costs equal the provided weights. Zeros off-diagonal denote absence of edges. - ``pair`` may be ``"full"`` for all-pairs distances or a tuple ``(i, j)`` for a single-source, single-target distance. - Uses SciPy's Dijkstra to get predecessors and reconstructs distances by summing Torch weights along chosen paths so gradients flow. - For sparse Torch tensors, converts to SciPy CSR; otherwise uses a dense NumPy array. Dense/sparse behavior is preserved. Returns: - If ``pair == 'full'``: ``torch.Tensor`` of shape ``(N, N)`` with grads. - If ``pair`` is ``(i, j)``: a scalar ``torch.Tensor`` distance. 
""" try: from scipy.sparse import csr_matrix from scipy.sparse.csgraph import shortest_path as sp_shortest_path except Exception as e: # pragma: no cover - environment dependent raise RuntimeError("scipy is required for shortest_path computation") from e if not isinstance(A, torch.Tensor): raise TypeError("A must be a torch.Tensor (dense or sparse)") if A.ndim != 2 or A.shape[0] != A.shape[1]: raise ValueError("Adjacency must be a square 2D matrix") N = A.shape[0] # Infer directionality by symmetry (tolerant). Adense_for_sym = A.to_dense() if A.layout != torch.strided else A directed = not torch.allclose(Adense_for_sym, Adense_for_sym.T) # Build SciPy graph (costs) from Torch if A.layout != torch.strided: Ac = A.coalesce() ii = Ac.indices()[0].detach().cpu().numpy() jj = Ac.indices()[1].detach().cpu().numpy() vv_np = Ac.values().detach().cpu().numpy() C = csr_matrix((vv_np, (ii, jj)), shape=(N, N)) else: # zeros off-diagonal represent no edge for csgraph; keep diagonal zeros C = A.detach().cpu().numpy() # Helper to reconstruct Torch-summed cost along predecessor path def _reconstruct_cost_from_predecessors( src: int, dst: int, pred_row: np.ndarray ) -> torch.Tensor: if src == dst: # return a dense scalar to avoid sparse/dense copy issues downstream return torch.zeros((), dtype=A.dtype, device=A.device) k = int(dst) if k < 0 or k >= N: return torch.tensor(float("inf"), dtype=A.dtype, device=A.device) total = torch.zeros((), dtype=A.dtype, device=A.device) Adense = A.to_dense() if (A.layout != torch.strided) else A while k != src: pk = int(pred_row[k]) if pk == -9999 or pk < 0: return torch.tensor( float("inf"), dtype=A.dtype, device=A.device ) # unreachable total = total + Adense[pk, k] k = pk return total if pair == "full": # Compute all-pairs predecessors once dist_np, pred_np = sp_shortest_path( C, directed=directed, return_predecessors=True, unweighted=False ) # Reconstruct distances via Torch sums along chosen paths out = torch.empty((N, N), device=A.device, 
dtype=A.dtype) Adense = A.to_dense() if (A.layout != torch.strided) else A def cost_entry(u, v): return Adense[u, v] for i in range(N): pred_row = pred_np[i] for j in range(N): if i == j: out[i, j] = torch.zeros((), device=A.device, dtype=A.dtype) continue # unreachable? if not np.isfinite(dist_np[i, j]): out[i, j] = torch.tensor( float("inf"), device=A.device, dtype=A.dtype ) continue # backtrack using predecessors and sum torch costs along the path k = j total = torch.zeros((), device=A.device, dtype=A.dtype) while k != i: pk = int(pred_row[k]) if pk == -9999 or pk < 0: total = torch.tensor( float("inf"), device=A.device, dtype=A.dtype ) break total = total + cost_entry(pk, k) k = pk out[i, j] = total return out # pair = (i, j): single-source, single-target if not (isinstance(pair, (tuple, list)) and len(pair) == 2): raise ValueError("pair must be 'full' or a tuple (i, j)") i, j = int(pair[0]), int(pair[1]) dist_np, pred_np = sp_shortest_path( C, directed=directed, indices=i, return_predecessors=True, unweighted=False ) # If unreachable if not np.isfinite(dist_np[j]): return torch.tensor(float("inf"), dtype=A.dtype, device=A.device) return _reconstruct_cost_from_predecessors(i, j, pred_np) def _timeit(func): @wraps(func) def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() print(f"{func.__name__} took {end - start:.4f} seconds") return result return wrapper