Download this notebook

Toy recurrent network learning logical gates

Main GradNet concepts demonstrated below

  • Discrete dynamical loss

  • Directed networks

  • Training on toy data

Problem setup

Small circuit-style networks trained with GradNet’s built-in trainer. Edges may be positive or negative, and tanh dynamics with a single hidden node let the network fit XOR as well as AND and OR.

Install

Install the required dependencies silently (the cell output is captured).

%%capture
!pip install 'gradnet[examples]'

Imports and helpers

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm

from gradnet import GradNet, fit
from gradnet.utils import plot_adjacency_heatmap, random_seed

# Fix the random seed via gradnet.utils.random_seed so the training runs below
# are repeatable. NOTE(review): assumes random_seed seeds torch's RNG — confirm.
random_seed(3)

Truth tables and directed network

Layout: A, B, H, OUT. The initial state loads the two inputs into A and B and zeros elsewhere. We unroll state = tanh(A @ state) for a few rounds (num_rounds) and read the last node as the output. No clamping or masks are applied beyond the zero diagonal from GradNet defaults.

device = torch.device("cpu")

# Every (A, B) input combination, listed in truth-table order.
inputs = torch.tensor(
    [
        [0., 0.],
        [0., 1.],
        [1., 0.],
        [1., 1.],
    ],
    device=device,
)

# Expected gate output for each row of `inputs` (same dtype/device as `inputs`).
gate_targets = dict(
    AND=inputs.new_tensor([0., 0., 0., 1.]),
    OR=inputs.new_tensor([0., 1., 1., 1.]),
    XOR=inputs.new_tensor([0., 1., 1., 0.]),
)

# Layout: A, B, H, OUT
num_nodes = 4
input_idx = (0, 1)          # nodes that receive the two inputs
output_idx = num_nodes - 1  # the last node is read as the gate output
num_rounds = 2              # recurrent update steps per forward pass
budget = 11                 # passed to GradNet(budget=...)

Propagation, loss, and training via fit

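Unroll state = tanh(A @ state) for num_rounds starting from the input-loaded state. The output is the raw tanh activation of the last node, compared with the 0/1 targets using a mean-squared-error loss (no sigmoid or BCE). Accuracy counts a prediction as 1 when it exceeds 0.5. No clamping is applied beyond the initial state.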

def propagate(A, x, rounds=num_rounds):
    """Run the recurrent dynamics and return the output node's final value.

    The state vector starts at zero except at the input nodes, which receive
    ``x``. It is then updated ``rounds`` times as ``state <- tanh(A @ state)``.

    Args:
        A: Adjacency matrix; ``A[i, j]`` weights node j's contribution to node i.
        x: Values written into the input nodes (one per entry of ``input_idx``).
        rounds: Number of recurrent updates (defaults to ``num_rounds``).

    Returns:
        The activation of node ``output_idx`` after the last update.
    """
    # Same dtype/device as A, all zeros.
    state = A.new_zeros(num_nodes)
    state[list(input_idx)] = x
    for _step in range(rounds):
        state = torch.tanh(A.matmul(state))
    return state[output_idx]


def gate_loss(gn, targets):
    """Score the current network against one gate's truth table.

    Args:
        gn: The GradNet model; calling it yields the current adjacency matrix.
        targets: Expected outputs, one per row of ``inputs``.

    Returns:
        ``(loss, metrics)``. ``loss`` is the mean-squared error between the
        output node and ``targets`` over every row of ``inputs``. ``metrics``
        holds ``"acc"``, the fraction of rows whose prediction and target fall
        on the same side of 0.5.
    """
    adjacency = gn()
    preds = torch.stack([propagate(adjacency, row) for row in inputs])
    predicted_on = preds > 0.5
    target_on = targets > 0.5
    metrics = {"acc": (predicted_on == target_on).float().mean()}
    return F.mse_loss(preds, targets), metrics


def train_gate(name, targets, num_updates=2000, lr=0.01):
    """Train a fresh GradNet to reproduce one gate's truth table.

    Args:
        name: Gate label; accepted for the caller's convenience, unused here.
        targets: Target outputs for the rows of ``inputs``.
        num_updates: Forwarded to ``fit`` as ``num_updates``.
        lr: Adam learning rate.

    Returns:
        ``(gn, adj, preds, targets)``: the trained model, its adjacency matrix
        on CPU, the post-training predictions on CPU, and the targets on CPU.
    """
    # Directed network, free-sign edges, near-zero init, soft (non-strict) budget.
    gn = GradNet(
        num_nodes=num_nodes,
        budget=budget,
        directed=True,
        delta_sign="free",
        rand_init_weights=1e-9,
        strict_budget=False,
    )

    def objective(model):
        return gate_loss(model, targets)

    fit(
        gn=gn,
        loss_fn=objective,
        num_updates=num_updates,
        optim_cls=torch.optim.Adam,
        optim_kwargs=dict(lr=lr),
        logger=False,
        accelerator="cpu",
    )

    # Evaluate the trained network without recording gradients.
    with torch.no_grad():
        adjacency = gn()
        outputs = torch.stack([propagate(adjacency, row) for row in inputs]).cpu()
    return gn, adjacency.detach().cpu(), outputs, targets.detach().cpu()

Train AND, OR, and XOR

# Train one network per gate and keep the model, adjacency, predictions and
# targets for the evaluation and plotting cells below.
gate_results = {}
for gate_name, gate_target in gate_targets.items():
    print(f"Training {gate_name} gate")
    model, adjacency, predictions, target_cpu = train_gate(gate_name, gate_target)
    gate_results[gate_name] = dict(
        model=model,
        adj=adjacency,
        preds=predictions,
        targets=target_cpu,
    )
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Training AND gate
  | Name | Type    | Params | Mode 
-----------------------------------------
0 | gn   | GradNet | 16     | train
-----------------------------------------
16        Trainable params
0         Non-trainable params
16        Total params
0.000     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=2000` reached.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Training OR gate
  | Name | Type    | Params | Mode 
-----------------------------------------
0 | gn   | GradNet | 16     | train
-----------------------------------------
16        Trainable params
0         Non-trainable params
16        Total params
0.000     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=2000` reached.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Training XOR gate
  | Name | Type    | Params | Mode 
-----------------------------------------
0 | gn   | GradNet | 16     | train
-----------------------------------------
16        Trainable params
0         Non-trainable params
16        Total params
0.000     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=2000` reached.

Check truth tables after training

# Print each gate's truth table: raw prediction, thresholded bit, and target.
for name, res in gate_results.items():
    print(f"{name} gate outputs")
    rows = zip(inputs.cpu().tolist(), res["preds"], res["targets"])
    for combo, pred, target in rows:
        bits = list(map(int, combo))
        print(f"{bits} -> pred={pred.item():.3f} (binary {int(pred > 0.5)}), target={int(target.item())}")

        
AND gate outputs
[0, 0] -> pred=0.000 (binary 0), target=0
[0, 1] -> pred=0.013 (binary 0), target=0
[1, 0] -> pred=0.013 (binary 0), target=0
[1, 1] -> pred=0.900 (binary 1), target=1
OR gate outputs
[0, 0] -> pred=0.000 (binary 0), target=0
[0, 1] -> pred=0.994 (binary 1), target=1
[1, 0] -> pred=0.996 (binary 1), target=1
[1, 1] -> pred=1.000 (binary 1), target=1
XOR gate outputs
[0, 0] -> pred=0.000 (binary 0), target=0
[0, 1] -> pred=0.813 (binary 1), target=1
[1, 0] -> pred=0.813 (binary 1), target=1
[1, 1] -> pred=0.096 (binary 0), target=0

Plot the adjacencies

# One symmetric, zero-centred color scale shared by all heatmaps so the gates
# can be compared directly. The 1e-6 floor keeps vmin < 0 < vmax as
# TwoSlopeNorm requires, even if every weight is zero.
abs_max = max(
    max(float(res["adj"].abs().max()) for res in gate_results.values()),
    1e-6,
)
norm = TwoSlopeNorm(vmin=-abs_max, vcenter=0.0, vmax=abs_max)
imshow_kwargs = dict(cmap="RdBu", norm=norm)

# Heatmap tick labels: first input "A", second input "B", output "O", and any
# remaining (hidden) nodes "H", then "H1", "H2", ...
hidden_counter = 0
node_labels = []
for n in range(num_nodes):
    if n == input_idx[0]:
        tick = "A"
    elif len(input_idx) > 1 and n == input_idx[1]:
        tick = "B"
    elif n == output_idx:
        tick = "O"
    else:
        tick = f"H{hidden_counter}" if hidden_counter else "H"
        hidden_counter += 1
    node_labels.append(tick)

# One heatmap panel per trained gate (2 inches wide each). squeeze=False keeps
# indexing uniform even when only one gate was trained.
n_gates = len(gate_results)
fig, axs = plt.subplots(
    1, n_gates, figsize=(2 * n_gates, 2.5), dpi=200, constrained_layout=True, squeeze=False
)
axs = axs[0]  # the grid has a single row
for ax, (name, res) in zip(axs, gate_results.items()):
    plot_adjacency_heatmap(res["adj"], ax=ax, title=f"{name} gate", imshow_kwargs=imshow_kwargs, add_colorbar=False)
    # Clear the axis labels and name the ticks after the nodes instead.
    ax.set_xlabel(None)
    ax.set_ylabel(None)
    ax.set_xticks(range(num_nodes), node_labels)
    ax.set_yticks(range(num_nodes), node_labels)

# A single colorbar for all panels. Every heatmap received the same
# imshow_kwargs (assumed to be forwarded to imshow, so same norm), so the last
# panel's image is representative.
cbar = fig.colorbar(axs[-1].images[0], ax=axs, location="right", shrink=0.9)
cbar.set_label("$A_{ij}$")
plt.show()
../_images/bf0bf832ccdf791aac8bc9c51c010bf6f756ec4dffbd6db6d824512a94ac0b97.png

Plot as networks

# Plot each learned gate as a small directed graph
import math
import matplotlib as mpl

# Re-compute a shared scale so widths/colors are comparable across gates.
# The 1e-8 floor keeps the normalization well-defined if every weight is zero.
edge_abs_max = max(
    max(float(res['adj'].abs().max()) for res in gate_results.values()),
    1e-8,
)
edge_norm = mpl.colors.TwoSlopeNorm(vmin=-edge_abs_max, vcenter=0.0, vmax=edge_abs_max)
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# the `matplotlib.colormaps` registry is the supported replacement.
edge_cmap = mpl.colormaps['RdBu']  # red = negative, blue = positive

# Horizontal spacing unit and vertical spacing between stacked nodes.
dx = 0.4
dy = 1.6

# Graph node labels, derived from input_idx/output_idx so they always match
# the heatmap ticks. Precedence: first input "A", second input "B", output "O";
# every other node is hidden and labelled "H", "H1", "H2", ...
node_labels = {output_idx: 'O'}
if len(input_idx) > 1:
    node_labels[input_idx[1]] = 'B'
node_labels[input_idx[0]] = 'A'
_hidden_nodes = [i for i in range(num_nodes) if i not in node_labels]
for _k, _node in enumerate(_hidden_nodes):
    node_labels[_node] = 'H' if _k == 0 else f'H{_k}'


def layout_positions(num_nodes, input_idx, output_idx, x_step=None, y_step=None):
    """Return plot coordinates for every node in a left-to-right layout.

    Input nodes are stacked vertically at x=0 (first input on top). All other
    non-output nodes form a middle column at x=x_step. The output node sits at
    x=2.5*x_step, vertically centred between the highest and lowest nodes
    placed so far.

    Args:
        num_nodes: Total number of nodes.
        input_idx: Indices of the input nodes.
        output_idx: Index of the output node.
        x_step: Horizontal spacing unit; defaults to the module-level ``dx``
            (read at call time).
        y_step: Vertical spacing between stacked nodes; defaults to the
            module-level ``dy`` (read at call time).

    Returns:
        dict mapping node index -> (x, y).
    """
    if x_step is None:
        x_step = dx
    if y_step is None:
        y_step = dy

    def _column_offsets(count):
        # Offsets symmetric around y=0, first entry on top.
        return [(((count - 1) / 2) - k) * y_step for k in range(count)]

    # Inputs stacked on the left (x=0). Named in_nodes so the module-level
    # `inputs` tensor is not shadowed.
    in_nodes = list(input_idx)
    positions = {idx: (0.0, y) for idx, y in zip(in_nodes, _column_offsets(len(in_nodes)))}

    # Middle column for all other non-output nodes.
    mid_nodes = [i for i in range(num_nodes) if i not in in_nodes and i != output_idx]
    for idx, y in zip(mid_nodes, _column_offsets(len(mid_nodes))):
        positions[idx] = (x_step, y)

    # Output on the right (x = 2.5 * x_step), centred on the nodes placed so far.
    if positions:
        ys = [y for _, y in positions.values()]
        out_y = 0.5 * (min(ys) + max(ys))
    else:
        out_y = 0.0
    positions[output_idx] = (2.5 * x_step, out_y)
    return positions


def draw_gate_graph(name, adj, ax):
    """Draw one learned gate as a directed graph on ``ax``.

    Every entry ``adj[i, j]`` with magnitude at least 1e-2 becomes an arrow
    from node j to node i, matching ``state <- tanh(A @ state)``. Arrow width
    scales with |weight| relative to ``edge_abs_max``; color comes from
    ``edge_cmap`` applied to ``edge_norm(weight)``. When the reverse edge of a
    pair is also drawn, both arrows are bent (arc3 rad=0.2) so they do not
    overlap.

    Args:
        name: Gate name used in the axes title.
        adj: Adjacency tensor for the gate.
        ax: Matplotlib axes to draw into.
    """
    weights = adj.detach().cpu().numpy()
    pos = layout_positions(weights.shape[0], input_idx, output_idx)
    min_weight = 1e-2
    bend = 0.20  # arc3 curvature used for bidirectional pairs
    node_colors = {'input': '#86c5da', 'hidden': '#cccccc', 'output': '#f4b400'}

    # Edges first (zorder 1) so the node markers end up on top of them.
    for i, row in enumerate(weights):
        for j, w in enumerate(row):
            if abs(w) < min_weight:
                continue
            reverse_drawn = i != j and abs(weights[j, i]) >= min_weight
            ax.annotate(
                '',
                xy=pos[i],
                xytext=pos[j],
                arrowprops=dict(
                    arrowstyle='-|>',
                    color=edge_cmap(edge_norm(w)),
                    lw=6.0 * (abs(w) / edge_abs_max),
                    alpha=1,
                    shrinkA=9, shrinkB=9,
                    connectionstyle=f'arc3,rad={bend if reverse_drawn else 0.0}',
                ),
                zorder=1,
            )

    # Nodes: colored disc plus a bold label.
    for idx, (x, y) in pos.items():
        if idx in input_idx:
            kind = 'input'
        elif idx == output_idx:
            kind = 'output'
        else:
            kind = 'hidden'
        ax.scatter(x, y, s=520, color=node_colors[kind], edgecolor='k', zorder=2)
        ax.text(x, y, node_labels.get(idx, f'n{idx}'), ha='center', va='center', weight='bold', fontsize=13)

    all_y = [y for _, y in pos.values()]
    pad = 0.6 if len(all_y) > 1 else 0.4
    ax.set_title(f"{name} gate")
    ax.set_xlim(-0.5, 2.5*dx + 0.5)
    ax.set_ylim(min(all_y) - pad, max(all_y) + pad)
    ax.axis('off')


# One graph panel per trained gate (4 inches wide each). squeeze=False always
# returns a 2-D grid, so a single gate needs no special case.
n_gates = len(gate_results)
fig, axes = plt.subplots(1, n_gates, figsize=(4 * n_gates, 3), dpi=200, squeeze=False)
axes = axes[0]

for ax, (name, res) in zip(axes, gate_results.items()):
    draw_gate_graph(name, res['adj'], ax)

plt.show()
/var/folders/72/79vqt54j447byqmvb80g_n3w0000gn/T/ipykernel_96387/3576067285.py:9: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.
  edge_cmap = mpl.cm.get_cmap('RdBu')  # red = negative, blue = positive
../_images/c49659de6fe02d1336aa5676697474468f80de7cd3fb23c0f4f8a4798119b0e0.png