Eliminating Autograd from Memory-Augmented Transformers
Memory-augmented transformers (Titans, TTT) require per-sample backward passes at inference time, a systems artifact rather than a fundamental requirement. We eliminate autograd entirely via two complementary methods: exact manual gradient kernels (cos=1.0, 5.5x speedup, 53% VRAM reduction) and learned Forward Alignment Networks (cos=0.91, architecture-agnostic). Same-seed verification on a 40M MAC transformer confirms identical training dynamics (BPC gap = 0.0005). Amdahl's-law accounting shows 1.41x end-to-end throughput at 40M, scaling favorably to 6.1x kernel speedup at dim_head=128.
- titans
- memory
- inference
- gradient
- systems
- methodology
Abstract
Memory-augmented transformers such as Titans (Behrouz et al., 2024) and TTT (Sun et al., 2024) update internal memory weights via gradient descent at inference time, requiring per-sample backward passes through vmap(grad). This dependency on automatic differentiation blocks torch.compile, inflates VRAM by 2-3x, and accounts for more than 60% of the time spent in memory operations. We argue that this is a systems artifact, not a fundamental requirement of gradient-based memory.
We present two complementary elimination strategies. For fixed-architecture memory modules, we derive manual gradient kernels, exact chain-rule gradients implemented as compilable forward operations, achieving cosine similarity 1.0 with autograd at 5.51x speedup per call and 53% VRAM reduction. For general or unknown architectures, we train Forward Alignment Networks (FANs) that predict gradient directions from forward activations alone, achieving 0.91 cosine alignment with statistically indistinguishable downstream quality (-1.4% +/- 1.6% perplexity gap, 3 seeds).
On a 40M-parameter MemoryAsContextTransformer trained on enwik8, same-seed verification confirms that the manual kernel reproduces the training dynamics of autograd to within the float32 noise floor (BPC gap = 0.0005). End-to-end throughput improves 1.41x with roughly half the VRAM footprint (38 GB to 18 GB). The kernel speedup also grows from 4.05x at dim_head=64 to 6.1x at dim_head=128 (the head dimension typical of production 7B-scale models), before easing to 5.26x at dim_head=256.
1. Introduction
A growing family of transformer architectures augments standard attention with learnable memory: internal parameters updated via gradient descent during inference. Titans (Behrouz et al., 2024) maintain a neural memory MLP whose weights adjust to minimize prediction error on incoming tokens. TTT layers (Sun et al., 2024) replace linear attention with per-segment gradient updates. These achieve strong long-context results by compressing sequential information into weight updates rather than growing KV caches.
All such systems share a critical deployment bottleneck: per-sample backward passes at inference time, implemented via torch.func.vmap(grad(forward_and_loss)). This single implementation choice creates three cascading constraints:
- Compilation barrier. torch.compile cannot trace through vmap(grad), creating a "compilation island" that forces eager execution. The model cannot be optimized end-to-end.
- VRAM inflation. Per-sample gradient computation materializes intermediate activations independently for each batch element. On our benchmark, this inflates memory-specific VRAM from 3 GB (manual kernel) to 23 GB (vmap(grad)), a 7.7x overhead.
- Latency floor. vmap(grad) accounts for >60% of memory-operation time due to graph construction overhead, per-sample dispatch, and inability to fuse across the backward pass.
Our thesis: These constraints are artifacts of the implementation (autograd), not the algorithm (gradient-based memory updates). The gradient of a fixed-architecture MLP with respect to its weights is a deterministic function of inputs and current weights. It can be computed as a sequence of matrix multiplications and element-wise operations, all standard, compilable, fusible forward ops.
Contributions
- We derive exact manual gradient kernels for Titans' depth-2 GELU MemoryMLP with LayerNorm and residual connections. We prove correctness (cos = 1.0) and measure 5.51x isolated speedup on A100.
- We introduce Forward Alignment Networks for architecture-agnostic gradient prediction. We identify per-token input representations as critical (responsible for +17.5 points of cosine improvement over chunk-mean baselines).
- We establish same-seed training equivalence: on identical hardware with identical random state, the manual kernel produces BPC within 0.0005 of exact autograd (the float32 noise floor).
- We provide honest Amdahl's-law accounting: 5.5x kernel speedup yields only 1.41x end-to-end at 40M scale. We characterize when and why this ratio improves at larger scales.
2. Related Work
2.1 Memory-Augmented Transformers
Titans (Behrouz et al., 2024; arXiv 2501.00663) introduces Neural Memory, a depth-2 MLP updated via surprise-gated gradient descent at inference time. The MemoryAsContextTransformer variant retrieves from memory and prepends results to the attention context window. Our work targets the gradient computation step that makes Titans expensive.
TTT layers (Sun et al., 2024; arXiv 2407.04620) replace linear attention with test-time training: a hidden state W is updated via gradient descent on a self-supervised loss within each segment. TTT-Linear uses a linear inner model (gradient is a simple outer product); TTT-MLP requires full backpropagation. Our manual kernel approach applies directly to TTT-MLP.
MIRAS (Google, 2025; arXiv 2504.13173) recasts all sequence models (attention, SSMs, linear attention, Titans) as instances of a unified associative memory framework. MIRAS is complementary to our work: they design memory architectures, we make them deployable. MIRAS does not address the vmap(grad) systems bottleneck.
2.2 Gradient-Free Alternatives
Linear attention variants (RetNet, RWKV, Mamba, GLA, DeltaNet) avoid per-sample gradients entirely by redesigning the memory update as a linear recurrence. This eliminates the systems problem but also eliminates the theoretical advantages of gradient-based memory (surprise gating, adaptive write strength, error-driven consolidation). Our approach preserves the gradient-based algorithm while fixing the systems implementation.
FAAST (2024) uses associative scan operations for memory updates, requiring a complete architecture redesign. Our methods are drop-in replacements for existing gradient-based systems without architecture changes.
2.3 Systems Optimization Precedents
FlashAttention (Dao et al., 2022) established the precedent that attention's mathematical definition can be rewritten as fused, memory-efficient operations without changing the computation's semantics. Our manual gradient kernels apply the same philosophy to memory gradient computation: same math, better implementation.
torch.compile (PyTorch 2.0+) traces static computation graphs for kernel fusion and operator scheduling. The fundamental incompatibility between vmap(grad) and static tracing is not a temporary tooling limitation. Backward passes are programs (requiring dynamic dispatch based on the computation graph), not operations (static sequences of tensor manipulations). Manual kernels resolve this by converting the backward pass into a static forward program.
2.4 Synthetic and Predicted Gradients
Synthetic Gradients (Jaderberg et al., 2017) train modules to predict gradients for decoupled parallel training of neural network layers. Their goal is training parallelism, eliminating the backward lock between layers. Our FANs operate at inference time and target per-sample memory update gradients, a fundamentally different setting. Synthetic Gradients also target inter-layer gradients in a standard feedforward network; we target intra-module gradients within a memory system.
Forward-mode AD computes Jacobian-vector products in a single forward pass but requires |theta| passes for the full gradient. For memory MLPs with 32K+ parameters, this is prohibitively expensive. Our manual kernels compute the exact full gradient in O(1) forward passes by exploiting known structure.
3. Method 1: Manual Gradient Kernels
3.1 Derivation
For a memory module with known architecture, the gradient of the loss with respect to each parameter can be derived analytically. Consider the Titans memory architecture:
Forward: H0 = X @ W0 (batched matmul)
A = GELU(H0) (element-wise)
M = A @ W1 (batched matmul)
H^ = LayerNorm(M) (reduction + element-wise)
Y^ = H^ * (gamma+1) + X (scale + residual)
Loss: L = sum_t w_t * (1/D) * ||Y^_t - V_t||^2
Backward: E = Y^ - V (prediction error)
dL/dY^ = (2/D) * E * w (element-wise; w_t broadcast over features)
dL/dgamma = sum_t (dL/dY^_t * H^_t) (reduction over tokens)
dx_hat = dL/dY^ * (gamma+1) (gradient w.r.t. the LayerNorm output H^)
dL/dM = (1/sigma) * (dx_hat - mean(dx_hat) - H^ * mean(dx_hat * H^)) (means over the feature dimension; sigma is the LayerNorm standard deviation)
dL/dW1 = A^T @ dL/dM (batched matmul)
dL/dW0 = X^T @ ((dL/dM @ W1^T) * GELU'(H0)) (batched matmul)
The GELU derivative: GELU'(x) = Phi(x) + x * phi(x) where Phi is the standard normal CDF and phi is the PDF.
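The derivation above can be checked numerically. The following is a minimal NumPy sketch, not the authors' kernel: it assumes per-sample weights with a leading batch dimension, a per-sample gamma of shape (B, D), and a LayerNorm epsilon of 1e-5, none of which are stated in the text. All function names are illustrative. It compares the manual gradients against central finite differences in float64.

```python
import math
import numpy as np

_erf = np.vectorize(math.erf, otypes=[float])

def gelu(x):
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)))

def gelu_grad(x):
    # GELU'(x) = Phi(x) + x * phi(x)
    cdf = 0.5 * (1.0 + _erf(x / math.sqrt(2.0)))
    pdf = np.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf

def forward(X, W0, W1, gamma, eps=1e-5):
    H0 = X @ W0                                   # (B, C, H)
    A = gelu(H0)
    M = A @ W1                                    # (B, C, D)
    sigma = np.sqrt(M.var(axis=-1, keepdims=True) + eps)
    Hn = (M - M.mean(axis=-1, keepdims=True)) / sigma
    Y = Hn * (gamma[:, None, :] + 1.0) + X        # scale + residual
    return Y, (H0, A, Hn, sigma)

def loss(X, V, w, W0, W1, gamma):
    Y, _ = forward(X, W0, W1, gamma)
    return float(np.sum(w * np.mean((Y - V) ** 2, axis=-1)))

def manual_grads(X, V, w, W0, W1, gamma):
    Y, (H0, A, Hn, sigma) = forward(X, W0, W1, gamma)
    D = X.shape[-1]
    dY = (2.0 / D) * (Y - V) * w[..., None]
    dgamma = np.sum(dY * Hn, axis=1)              # sum over tokens
    dx_hat = dY * (gamma[:, None, :] + 1.0)
    dM = (dx_hat - dx_hat.mean(axis=-1, keepdims=True)
          - Hn * (dx_hat * Hn).mean(axis=-1, keepdims=True)) / sigma
    dW1 = np.swapaxes(A, 1, 2) @ dM
    dH0 = (dM @ np.swapaxes(W1, 1, 2)) * gelu_grad(H0)
    dW0 = np.swapaxes(X, 1, 2) @ dH0
    return dW0, dW1, dgamma

def numeric_grad(f, P, h=1e-6):
    # Central finite differences, perturbing P in place one entry at a time.
    G = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        old = P[idx]
        P[idx] = old + h
        up = f()
        P[idx] = old - h
        down = f()
        P[idx] = old
        G[idx] = (up - down) / (2.0 * h)
    return G

rng = np.random.default_rng(0)
B, C, D, H = 2, 3, 4, 5                           # tiny sizes for the check
X, V = rng.normal(size=(B, C, D)), rng.normal(size=(B, C, D))
w = rng.uniform(0.5, 1.5, size=(B, C))
W0 = 0.5 * rng.normal(size=(B, D, H))
W1 = 0.5 * rng.normal(size=(B, H, D))
gamma = 0.1 * rng.normal(size=(B, D))

manual = manual_grads(X, V, w, W0, W1, gamma)
numeric = [numeric_grad(lambda: loss(X, V, w, W0, W1, gamma), P)
           for P in (W0, W1, gamma)]
max_err = max(float(np.max(np.abs(m - n))) for m, n in zip(manual, numeric))
print("max |manual - finite difference| =", max_err)
```

Every step of `manual_grads` is a matmul, a reduction, or an element-wise operation, which is the property the text relies on for compilability.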
3.2 Why This is Faster (Mechanism Decomposition)
The 5.51x speedup is not merely "avoiding the backward pass." Profiling reveals three distinct sources of overhead in vmap(grad):
- Graph construction (~35% of overhead): grad() builds a backward computation graph at each call. For a depth-2 MLP this graph has ~12 nodes. The manual kernel has no graph: just 9 sequential tensor operations.
- Per-sample dispatch (~40% of overhead): vmap must dispatch each batch element independently through the backward graph. This prevents batched kernel launches and forces sequential CUDA stream operations. The manual kernel uses native torch.bmm, which exploits batched GEMM from the start.
- Memory allocation (~25% of overhead): The backward graph materializes intermediate activations per sample for gradient computation. The manual kernel reuses forward activations directly; the intermediates needed for the backward pass are exactly those already computed.
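This decomposition predicts that speedup should increase with batch size (more dispatch overhead) and with dimension (more memory allocation overhead). The measurements in Section 5 are broadly consistent with both predictions: at dimension 64 the speedup is 4.05x at B=8 (Section 5.2) and 5.51x at B=48 (Section 5.1), and it rises from 4.05x to 6.10x between dim_head=64 and dim_head=128, although it eases to 5.26x at dim_head=256. The two benchmarks are separate runs, so the batch-size comparison is indicative only.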
3.3 Correctness
By construction, the manual kernel computes the exact same mathematical function as vmap(grad). We verify empirically:
- Isolated test: cos(manual, autograd) = 1.0000 with max relative error < 1e-6 (float32)
- Same-seed training test: On identical hardware, identical seed, identical data ordering, 1 epoch of enwik8 training produces BPC within 0.0005 between exact and manual kernel. This gap (0.023%) is below float32 non-associativity noise.
The same-seed methodology is important: cross-machine runs showed a 0.13 BPC gap (exact on instance A vs. manual on instance B) that could be misinterpreted as gradient quality degradation. It is entirely CUDA non-determinism (cuDNN algorithm selection, atomic operation ordering, floating-point reduction order across SMs).
3.4 Generalization to Deeper Memory
Each additional MLP layer adds exactly 3 operations to the kernel: one matmul (weight gradient), one activation derivative (element-wise), one matmul (upstream gradient propagation). Complexity is linear in depth, not exponential. We verify at depths 1-4 with correct gradients at all depths.
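The three-operations-per-layer structure can be sketched as a loop. This is an illustrative NumPy example under simplifying assumptions that are not in the text: a plain GELU MLP without LayerNorm or residual, a mean-squared-error loss, and hypothetical layer widths. It checks the manual gradients against finite differences at depths 1 to 4.

```python
import math
import numpy as np

_erf = np.vectorize(math.erf, otypes=[float])

def gelu(x):
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)))

def gelu_grad(x):
    cdf = 0.5 * (1.0 + _erf(x / math.sqrt(2.0)))
    return cdf + x * np.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)

def run_forward(X, Ws):
    acts, pres = [X], []                  # acts[i] is the input to layer i
    for i, W in enumerate(Ws):
        pres.append(acts[-1] @ W)
        is_last = i == len(Ws) - 1
        acts.append(pres[-1] if is_last else gelu(pres[-1]))
    return acts, pres

def mlp_loss(X, V, Ws):
    acts, _ = run_forward(X, Ws)
    return float(np.mean((acts[-1] - V) ** 2))

def manual_grads(X, V, Ws):
    acts, pres = run_forward(X, Ws)
    delta = 2.0 * (acts[-1] - V) / V.size          # dL/d(output)
    grads = [None] * len(Ws)
    for i in reversed(range(len(Ws))):
        grads[i] = acts[i].T @ delta               # op 1: weight gradient (matmul)
        if i > 0:
            upstream = delta @ Ws[i].T             # op 2: upstream gradient (matmul)
            delta = upstream * gelu_grad(pres[i - 1])  # op 3: activation derivative
    return grads

def numeric_grad(f, P, h=1e-6):
    G = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        old = P[idx]
        P[idx] = old + h
        up = f()
        P[idx] = old - h
        down = f()
        P[idx] = old
        G[idx] = (up - down) / (2.0 * h)
    return G

rng = np.random.default_rng(0)
C, D, HID = 3, 4, 6                                # hypothetical sizes
X, V = rng.normal(size=(C, D)), rng.normal(size=(C, D))
worst = {}
for depth in range(1, 5):
    widths = [D] + [HID] * (depth - 1) + [D]
    Ws = [0.5 * rng.normal(size=(a, b)) for a, b in zip(widths[:-1], widths[1:])]
    manual = manual_grads(X, V, Ws)
    numeric = [numeric_grad(lambda: mlp_loss(X, V, Ws), W) for W in Ws]
    worst[depth] = max(float(np.max(np.abs(m - n))) for m, n in zip(manual, numeric))
print(worst)
```

The backward loop body does not change with depth, so the operation count grows by a fixed amount per added layer.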
For architectures where manual derivation is impractical (attention-based memory, dynamic graphs, architectures that change during development), we provide FAN as the general-purpose alternative.
4. Method 2: Forward Alignment Networks
4.1 Architecture
A FAN takes the same quantities available during a forward pass (keys K, values V, and predictions Y^) and outputs a unit-norm vector in gradient space:
Per-token features: r_t = [k_t; v_t; y^_t; k_t - v_t; k_t * v_t]
Per-token hidden: h_t = MLP_token(r_t)
Aggregation: h_agg = (1/C) * sum_t(h_t + sigmoid(h_t) * h_t)
Output: g^ = normalize(MLP_out(h_agg))
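As a shape-level illustration of the four lines above, here is a minimal NumPy sketch of a FAN forward pass. The text does not specify MLP_token or MLP_out, so single tanh and linear layers stand in for them; the parameter names, the hidden width, and the output size `n_grad` are all hypothetical.

```python
import numpy as np

def fan_forward(K, V, Y, p):
    # Per-token features r_t = [k_t; v_t; y^_t; k_t - v_t; k_t * v_t]
    R = np.concatenate([K, V, Y, K - V, K * V], axis=-1)     # (C, 5D)
    # Stand-in for MLP_token (assumed: one tanh layer)
    Ht = np.tanh(R @ p["W_tok"] + p["b_tok"])                 # (C, hidden)
    # h_agg = (1/C) * sum_t (h_t + sigmoid(h_t) * h_t)
    gate = 1.0 / (1.0 + np.exp(-Ht))
    h_agg = np.mean(Ht + gate * Ht, axis=0)                   # (hidden,)
    # Stand-in for MLP_out (assumed: one linear layer), then unit-normalize
    out = h_agg @ p["W_out"] + p["b_out"]                     # (n_grad,)
    return out / np.linalg.norm(out)

rng = np.random.default_rng(0)
C, D, hidden, n_grad = 8, 4, 16, 40                           # illustrative sizes
params = {
    "W_tok": 0.3 * rng.normal(size=(5 * D, hidden)),
    "b_tok": np.zeros(hidden),
    "W_out": 0.3 * rng.normal(size=(hidden, n_grad)),
    "b_out": np.zeros(n_grad),
}
K, V, Y = (rng.normal(size=(C, D)) for _ in range(3))
g_hat = fan_forward(K, V, Y, params)
print(g_hat.shape, float(np.linalg.norm(g_hat)))
```

Note that the feature construction requires keys, values, and predictions to share the same width D, since `K - V` and `K * V` are element-wise.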
Critical design choice: per-token processing. The gradient of the memory loss is a sum of per-token contributions. Each contribution depends on the specific (k_t, v_t) pairing and the current parameters. Chunk-level averaging destroys these pairings.
Impact of this choice:
- FAN v1 (chunk-mean input, Titans): cos = 0.735
- FAN v2 (per-token input, Titans): cos = 0.910
- Improvement: +17.5 points from a single architectural decision
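A toy example shows why chunk-level averaging loses information the gradient depends on. The sketch below swaps in a linear memory (prediction k_t @ M, squared-error loss) in place of the Titans MLP purely to keep the algebra visible; that substitution and all names are assumptions for illustration. Two chunks with identical key and value means but different pairings yield orthogonal gradients.

```python
import numpy as np

def linear_memory_grad(K, V, M):
    # loss = sum_t ||k_t @ M - v_t||^2, so dL/dM = 2 * K^T (K M - V)
    return 2.0 * K.T @ (K @ M - V)

K = np.array([[1.0, 0.0],
              [0.0, 1.0]])
V_a = np.array([[1.0, 0.0],
                [0.0, 1.0]])      # token 0 pairs with value 0, token 1 with value 1
V_b = V_a[::-1].copy()            # same values, pairing swapped
M = np.zeros((2, 2))

g_a = linear_memory_grad(K, V_a, M)
g_b = linear_memory_grad(K, V_b, M)

# A chunk-mean input sees the same summary for both chunks.
means_match = np.allclose(V_a.mean(axis=0), V_b.mean(axis=0))
cos_ab = float(np.sum(g_a * g_b) / (np.linalg.norm(g_a) * np.linalg.norm(g_b)))
print("means match:", means_match, "cos(g_a, g_b):", cos_ab)
```

Any predictor that only receives the chunk means must output the same direction for both chunks, so it cannot be well aligned with both gradients.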
4.2 Training Protocol
FANs are trained via cosine similarity maximization against exact gradients:
L_FAN = -(1/|B|) * sum_i cos(g^_i, g*_i)
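The objective above can be written in a few lines. This NumPy sketch treats each row as one flattened gradient; the function name and the small `eps` guard against zero-norm rows are additions for illustration.

```python
import numpy as np

def fan_loss(G_pred, G_exact, eps=1e-12):
    # L_FAN = -(1/|B|) * sum_i cos(g^_i, g*_i), one row per sample
    dots = np.sum(G_pred * G_exact, axis=-1)
    norms = np.linalg.norm(G_pred, axis=-1) * np.linalg.norm(G_exact, axis=-1)
    return float(-np.mean(dots / (norms + eps)))

rng = np.random.default_rng(0)
G = rng.normal(size=(4, 10))
loss_aligned = fan_loss(G, G)           # perfectly aligned predictions
loss_scaled = fan_loss(3.0 * G, G)      # rescaling the prediction changes nothing
loss_opposite = fan_loss(-G, G)         # worst case
print(loss_aligned, loss_scaled, loss_opposite)
```

Because the loss depends only on direction, the FAN is never asked to match gradient magnitude, which is consistent with its unit-norm output.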
During the first E_distill epochs of model training, the FAN runs alongside the exact gradient path. After distillation, autograd is disabled entirely; the FAN provides all gradient predictions. The model's outer training loop, optimizer, and consolidation mechanism (momentum, decay, surprise gating) are unchanged.
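The two-phase schedule can be sketched as control flow. This is a schematic under stated assumptions, not the training code: `exact_grad`, `fan_grad`, `fit_fan`, and `apply_update` are hypothetical stubs standing in for the autograd path, the FAN prediction, the FAN's distillation step, and the unchanged memory update. The call counter mirrors the kind of check used to confirm that no exact-gradient calls occur after distillation.

```python
from collections import Counter

def run_schedule(num_epochs, e_distill, steps_per_epoch,
                 exact_grad, fan_grad, fit_fan, apply_update):
    for epoch in range(num_epochs):
        distilling = epoch < e_distill
        for _ in range(steps_per_epoch):
            if distilling:
                g_star = exact_grad()          # exact path drives the update
                fit_fan(fan_grad(), g_star)    # FAN trained alongside it
                apply_update(g_star)
            else:
                apply_update(fan_grad())       # FAN only, no autograd

calls = Counter()
applied = []

def exact_grad():
    calls["exact"] += 1
    return [1.0, 0.0]                          # placeholder gradient

def fan_grad():
    calls["fan"] += 1
    return [0.9, 0.1]                          # placeholder prediction

def fit_fan(g_hat, g_star):
    calls["fan_fit"] += 1

def apply_update(g):
    applied.append(g)

run_schedule(num_epochs=8, e_distill=4, steps_per_epoch=10,
             exact_grad=exact_grad, fan_grad=fan_grad,
             fit_fan=fit_fan, apply_update=apply_update)
print(dict(calls))
```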
4.3 Quality Verification
We run a controlled 3-seed perplexity gap test (384-dim, 6-layer model, WikiText-2). Epochs 1-4: both methods use exact gradients. Epochs 5-8: FAN replaces autograd entirely.
| Seed | Baseline PPL | FAN-only PPL | Gap |
|---|---|---|---|
| 42 | 807.54 | 803.77 | -0.47% |
| 123 | 817.02 | 790.60 | -3.23% |
| 456 | 801.97 | 797.06 | -0.61% |
| Mean | 808.84 | 797.14 | -1.44% +/- 1.56% |
The negative gap (FAN slightly better) is within noise; the methods are statistically indistinguishable. Assertion counters confirm zero exact-gradient calls during FAN-only epochs.
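The summary row can be reproduced from the per-seed values in the table. The sketch below also computes a one-sample t statistic as one possible way to read "within noise"; the choice of that test and the critical value 4.303 (two-sided 95%, 2 degrees of freedom) are additions here, not part of the reported analysis.

```python
import math
import statistics

# seed -> (baseline PPL, FAN-only PPL), taken from the table above
runs = {42: (807.54, 803.77), 123: (817.02, 790.60), 456: (801.97, 797.06)}

gaps = [100.0 * (fan - base) / base for base, fan in runs.values()]
mean_gap = statistics.mean(gaps)
std_gap = statistics.stdev(gaps)                 # sample standard deviation (n - 1)
t_stat = mean_gap / (std_gap / math.sqrt(len(gaps)))
T_CRIT_2DF = 4.303                               # Student's t, 95% two-sided, df = 2

print(f"gap = {mean_gap:.2f}% +/- {std_gap:.2f}%, t = {t_stat:.2f}")
```

With only three seeds the test has little power, so it supports "no detectable difference" rather than equivalence.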
4.4 FAN Limitations
- Cosine alignment drops from 0.91 (dim=64) to ~0.87 (dim=128). Scaling to larger memory modules likely requires per-parameter-group prediction or low-rank factorization.
- FANs require exact gradients during distillation. The benefit is inference-only; training still uses autograd.
- The ~8.6M parameter overhead is non-trivial for small models. At 7B+ scale it becomes negligible (<0.1%).
5. Experiments
5.1 Isolated Kernel Benchmark (A100 80GB)
| Method | Latency (ms) | Speedup | Compilable |
|---|---|---|---|
| vmap(grad) | 7.89 | 1.0x | No |
| Manual kernel | 1.43 | 5.51x | Yes |
Config: B=48, C=128, D=64, H=256. Cosine similarity = 1.0.
5.2 Dimension Scaling
| dim_head | vmap(grad) ms | Manual ms | Speedup |
|---|---|---|---|
| 64 | 5.18 | 1.28 | 4.05x |
| 128 | 10.98 | 1.80 | 6.10x |
| 256 | 29.41 | 5.59 | 5.26x |
Config: B=8, C=128, depth=2. All cos=1.0.
Speedup peaks at dim_head=128 in these measurements. From 64 to 128, vmap(grad) latency grows 2.1x while manual-kernel latency grows only 1.4x, consistent with vmap(grad) overhead (graph construction and per-sample dispatch) growing faster than the batched-matmul cost. From 128 to 256 the manual kernel's latency grows 3.1x against 2.7x for vmap(grad), so the ratio eases to 5.26x. At dim_head=128 (the head dimension used by Qwen 3.5-9B and other production models), the kernel is over 6x faster per call.
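The speedup column and the per-doubling growth of each method follow directly from the latencies in the table. A short sketch of that arithmetic, with variable names chosen for illustration:

```python
# dim_head -> (vmap(grad) latency in ms, manual kernel latency in ms)
latency_ms = {64: (5.18, 1.28), 128: (10.98, 1.80), 256: (29.41, 5.59)}

speedup = {d: vmap_ms / manual_ms for d, (vmap_ms, manual_ms) in latency_ms.items()}

dims = sorted(latency_ms)
growth = {
    (a, b): (latency_ms[b][0] / latency_ms[a][0],   # vmap(grad) growth per doubling
             latency_ms[b][1] / latency_ms[a][1])   # manual kernel growth per doubling
    for a, b in zip(dims, dims[1:])
}

for d in dims:
    print(f"dim_head={d}: speedup {speedup[d]:.2f}x")
for (a, b), (g_vmap, g_manual) in growth.items():
    print(f"{a} -> {b}: vmap(grad) x{g_vmap:.2f}, manual x{g_manual:.2f}")
```

The speedup rises whenever vmap(grad) latency grows faster than manual latency across a doubling, and falls otherwise.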
5.3 Phase 1: Synthetic Associative Recall
To test gradient quality in a regime where memory is the entire signal, we design a task where models must store 8 key-value associations and retrieve them after 64 distraction tokens.
All three gradient methods achieve 100% recall accuracy:
- Exact (vmap(grad)): phase transition at step ~2,500
- Manual kernel: phase transition at step ~3,500
- FAN-only: phase transition at step ~3,500 (distill cos = 0.989)
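The convergence delay for the manual kernel (identical gradients, cos = 1.0) is due to seed variation, not gradient quality. The FAN run (distill cos = 0.989 on this task) shows the same delay of roughly 1,000 steps, so this experiment cannot separate any effect of approximate gradients from seed variation. What it does show is that approximate gradients are sufficient for convergence to 100% recall.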
5.4 Phase 2: enwik8 Language Modeling
We train a 40M-parameter MemoryAsContextTransformer (dim=384, depth=8, heads=4, dim_head=64, 3 memory layers) on enwik8 (100MB byte-level Wikipedia) for 3 epochs with cosine LR schedule.
| Method | Test BPC | tok/s | Speedup | VRAM (GB) | Compilable |
|---|---|---|---|---|---|
| Exact (vmap+grad) | 1.413 | 3,970 | 1.0x | 38 | No |
| Manual kernel | 1.413* | 5,611 | 1.41x | 18 | Yes |
| No memory (ablation) | 1.552 | 9,098 | 2.29x | ~15 | Yes |
*Same-seed verified: BPC gap = 0.0005 on identical hardware.
What this tells us:
- Neural memory improves test BPC by 0.139 over the no-memory ablation (1.552 to 1.413, a 9.0% reduction)
- The manual kernel preserves this entire benefit while running 1.41x faster with 53% less VRAM
- Memory-specific VRAM overhead drops from 23 GB (exact) to 3 GB (manual), an 87% reduction
5.5 Same-Seed Verification (Methodology)
Cross-machine runs showed a persistent 0.13 BPC gap between exact and manual kernel. To distinguish "kernel quality issue" from "CUDA non-determinism," we run both methods on the same A100 instance, same seed (42), same data ordering, 1 epoch:
| Method | Val BPC |
|---|---|
| Exact | 2.1689 |
| Manual | 2.1684 |
| Gap | 0.0005 |
This is below float32 non-associativity noise. The manual kernel reproduces the exact path's training dynamics to within numerical noise. The 0.13 BPC cross-machine gap arises entirely from non-deterministic CUDA operations (cuDNN autotuning, atomicAdd ordering, SM scheduling).
We recommend same-seed verification as standard methodology for validating any kernel replacement. Cross-machine comparisons are insufficient.
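The noise source referred to here, floating-point non-associativity, is easy to demonstrate in isolation. The NumPy sketch below is an illustration of the arithmetic effect only; it does not model CUDA scheduling. Adding the same three float32 values in two different orders gives two different results.

```python
import numpy as np

a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(1.0)

left = (a + b) + c     # a and b cancel first, then 1.0 is added
right = a + (b + c)    # 1.0 is absorbed into -1e8: float32 spacing near 1e8 is 8

print(float(left), float(right))
```

A parallel reduction whose order varies between runs can therefore change low-order bits of a sum even when every input is identical, which is why the comparison above fixes hardware, seed, and data ordering.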
6. Analysis
6.1 Amdahl's Law (Honest End-to-End Accounting)
The 5.51x kernel speedup does not translate to 5.51x model speedup. Memory operations account for a fraction p of total compute:
S_e2e = 1 / ((1-p) + p/S_kernel)
With S_kernel = 5.51x and p ~ 0.32 (memory operations are ~32% of total forward time at 40M scale):
S_e2e = 1 / (0.68 + 0.32/5.51) = 1 / 0.738 = 1.35x (theoretical)
Observed: 1.41x (5,611 / 3,970 tok/s). The slight outperformance of theory is likely due to reduced memory pressure enabling better GPU occupancy.
When does this ratio improve?
- At larger model depths with more memory layers: p grows → S_e2e approaches S_kernel
- With torch.compile enabled (not possible with vmap): a projected additional 1.5-2x via kernel fusion across the entire model
- With TTT-Deep proposals (deeper inner models): memory fraction could reach 60-70%, yielding roughly 2.0-2.3x e2e at S_kernel = 5.51
When does it not help much?
- Very shallow models with few memory layers (p < 0.1): returns are minimal
- Batch-size-1 inference where the transformer attention itself is memory-bound
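The accounting in this subsection can be reproduced with two small functions. The second one inverts the formula to ask what memory fraction p would be needed to explain the observed 1.41x exactly; that inversion is an addition for illustration, and the function names are hypothetical.

```python
def amdahl(p, s_kernel):
    # S_e2e = 1 / ((1 - p) + p / S_kernel)
    return 1.0 / ((1.0 - p) + p / s_kernel)

def implied_fraction(s_e2e, s_kernel):
    # Solve the same formula for p given an observed end-to-end speedup.
    return (1.0 - 1.0 / s_e2e) / (1.0 - 1.0 / s_kernel)

S_KERNEL = 5.51
predicted = amdahl(0.32, S_KERNEL)                   # theoretical e2e at p = 0.32
observed = 5611 / 3970                               # tok/s ratio from Section 5.4
p_implied = implied_fraction(observed, S_KERNEL)
projection = {p: amdahl(p, S_KERNEL) for p in (0.1, 0.6, 0.7)}

print(f"predicted {predicted:.2f}x, observed {observed:.2f}x, implied p = {p_implied:.3f}")
print({p: round(s, 2) for p, s in projection.items()})
```

Under this model the observed throughput corresponds to a memory fraction a few points above the estimated 0.32, which is one possible reading of the gap between theory and measurement alongside the occupancy explanation given above.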
6.2 VRAM Decomposition
| Component | Exact | Manual | Reduction |
|---|---|---|---|
| Base model footprint (no memory) | 15 GB | 15 GB | - |
| Memory operation overhead | 23 GB | 3 GB | 87% |
| Total | 38 GB | 18 GB | 53% |
The 87% memory-specific reduction is the deployment-relevant number. It determines whether a memory-augmented model fits on a given GPU class:
- A100 80GB: both fit comfortably
- A100 40GB / A6000 48GB: manual fits, exact is tight
- RTX 4090 24GB (at 7B scale): manual projected to fit, exact does not
6.3 When to Use Each Method
| | Manual Kernel | FAN |
|---|---|---|
| Architecture known | Required | Not required |
| Gradient exactness | cos = 1.0 | cos ~ 0.91 |
| Derivation effort | Per-architecture | Train once |
| Runtime overhead | 0 parameters | ~8.6M parameters |
| Best for | Production deployment | Rapid prototyping, unknown architectures |
Manual kernels are strictly preferred when the memory architecture is fixed. FANs are for (a) architectures that change during development, (b) third-party modules without source, and (c) research exploration.
7. Limitations and Failure Modes
What manual kernels cannot do
- Custom activations: The derivation is activation-specific. A SiLU (Swish) memory MLP needs a new kernel (routine but manual). Truly novel activations require re-derivation.
- Dynamic architectures: If memory depth or width changes at runtime (e.g., conditional computation), the static kernel is invalid. FAN handles this case.
- Attention-based memory: Memory modules using attention internally (not just MLPs) have much more complex gradients. Our current kernels cover the standard Titans architecture.
- Scale validation: All experiments are at 40M parameters. We have not yet validated at 1B+ where memory is a smaller fraction of compute and numerical stability under mixed precision becomes more critical.
What FANs cannot do
- Guarantee convergence: cos = 0.91 corresponds to an angle of roughly 24.5 degrees between the predicted and exact gradient directions at every step. In our experiments this misalignment does not compound, but we have not proven this formally. There may exist loss landscapes where approximate gradients accumulate error.
- Scale to large memory modules: Alignment drops at dim_head=128. Production deployment of FANs at 7B scale requires architectural improvements we have not yet made.
- Replace autograd during training: FANs need exact gradient supervision for distillation. The benefit is inference-only.
What we tried that did not work
- FAN v1 (chunk-mean input): cos = 0.735 on Titans, insufficient. Only per-token processing reached the 0.90 threshold.
- torch.compile on vmap(grad): Confirmed incompatible in PyTorch 2.4. Not a tooling gap; fundamental architectural mismatch between dynamic backward graphs and static tracing.
- Forward-mode AD: O(|theta|) per sample. At 32K parameters, 32,000x slower than a single backward pass. Completely impractical.
8. Scope and Future Work
What we have proven: Inference-time per-sample gradients are an unnecessary systems artifact for fixed-architecture memory modules. The gradient computation can be replaced with equivalent forward operations at substantial speedup.
What remains unproven: Whether this matters at scale (1B+), whether FAN alignment holds for deeper/wider memory, whether the approach generalizes beyond MLP-based memory to attention-based or recurrent memory.
Near-term work in progress:
- Validation on Qwen 3.5-9B with Titans memory integration (8 memory layers, dim_head=128)
- torch.compile end-to-end benchmark with manual kernels enabled
- Formal convergence analysis for FAN-approximate gradients
References
- Behrouz, A. et al. "Titans: Learning to Memorize at Test Time." arXiv:2501.00663, 2024.
- Sun, Y. et al. "Learning to (Learn at Test Time): RNNs with Expressive Hidden States." arXiv:2407.04620, 2024.
- Jaderberg, M. et al. "Decoupled Neural Interfaces Using Synthetic Gradients." ICML, 2017.
- Dao, T. et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS, 2022.
- Behrouz, A. et al. "It's All Connected: A Journey Through Test-Time Memorization, Attentional Bias, Retention, and Online Optimization" (MIRAS). arXiv:2504.13173, 2025.
- Wang, P. "titans-pytorch: Implementation of Titans in PyTorch." github.com/lucidrains/titans-pytorch, 2024.
- Yang, S. et al. "Gated Linear Attention Transformers with Hardware-Efficient Training." ICML, 2024.
- Gu, A. & Dao, T. "Mamba: Linear-Time Sequence Modeling with Selective State Spaces." arXiv:2312.00752, 2023.
- Qin, Z. et al. "Scaling TransNormer to 175 Billion Parameters." arXiv:2307.14995, 2023.
- Peng, B. et al. "RWKV: Reinventing RNNs for the Transformer Era." EMNLP Findings, 2023.