
Simulating Quantum Computers at Scale: JAX, GPUs, and Cloud TPUs


Note: The detailed research paper and full experimental results for this project have already been published. This post provides a high-level summary of the core system design, key benchmarks, and engineering insights.

Classical simulation of quantum systems is the bedrock of modern quantum algorithm research. However, simulating variational quantum algorithms (VQAs) has historically run into a wall: the cost of computing gradients grows with the number of circuit parameters.

In this post, I’ll introduce the JAX Quantum Research Suite—a hardware-accelerated, fully differentiable quantum simulator designed to run seamlessly on consumer GPUs and scale up to Google Cloud TPU clusters.

We’ll explore how it overcomes traditional limitations, achieves up to a 48.7× speedup in gradient calculations, and scales simulations up to 36 qubits on TPU v6e and 1,024 qubits using Matrix Product States (MPS).

This research and scaling development was supported by the Google TPU Research Cloud (TRC) program.


The Core Challenge: Differentiating Quantum Circuits

For variational algorithms like VQE (Variational Quantum Eigensolver) or QAOA (Quantum Approximate Optimization Algorithm), we need to optimize parameterized quantum circuits.

In standard frameworks, gradients are typically computed with the Parameter-Shift Rule (PSR), which also works on real quantum hardware. For a circuit with $P$ parameters, PSR requires $2P$ full circuit evaluations (two shifted evaluations per parameter). With hundreds of parameters, this linear $O(P)$ overhead becomes a major bottleneck.

Because it is built on JAX, our simulator uses reverse-mode automatic differentiation (backpropagation) instead. One forward pass plus one backward pass yields all $P$ gradients, so the number of circuit passes is $O(1)$ in the parameter count rather than $O(P)$.
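The following toy-sized sketch makes the contrast concrete. It is not the suite's code. It builds a 3-qubit circuit of RY rotations and CZ gates and differentiates its cost $\langle Z_0 \rangle$ in two ways. The first is the parameter-shift rule: for a gate $e^{-i\theta Y/2}$, $\partial_\theta E = \tfrac{1}{2}[E(\theta + \pi/2) - E(\theta - \pi/2)]$, which costs $2P$ evaluations. The second is a single jax.grad call. The helper names (ry, apply_1q, apply_cz, energy, psr_grad) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np

N_QUBITS = 3

def ry(theta):
    # RY(theta) = exp(-i * theta * Y / 2); real-valued, so the toy state stays real
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]])

def apply_1q(state, gate, q):
    # contract the gate with qubit axis q, then move the output axis back to position q
    return jnp.moveaxis(jnp.tensordot(gate, state, axes=([1], [q])), 0, q)

def apply_cz(state, a, b):
    # CZ is diagonal: flip the sign of amplitudes where qubits a and b are both 1
    bits = jnp.indices(state.shape)
    return jnp.where((bits[a] == 1) & (bits[b] == 1), -state, state)

def energy(params):
    # params has shape (layers, N_QUBITS); the cost is <Z_0>
    state = jnp.zeros((2,) * N_QUBITS).at[(0,) * N_QUBITS].set(1.0)
    for layer in params:
        for q in range(N_QUBITS):
            state = apply_1q(state, ry(layer[q]), q)
        for q in range(N_QUBITS - 1):
            state = apply_cz(state, q, q + 1)
    z0 = 1.0 - 2.0 * jnp.indices(state.shape)[0]  # +1 if qubit 0 is |0>, -1 if |1>
    return jnp.sum(state ** 2 * z0)

def psr_grad(params):
    # parameter-shift rule: two full circuit evaluations per parameter, 2P in total
    flat = params.ravel()
    grads = []
    for k in range(flat.size):
        shift = jnp.zeros_like(flat).at[k].set(jnp.pi / 2)
        plus = energy((flat + shift).reshape(params.shape))
        minus = energy((flat - shift).reshape(params.shape))
        grads.append((plus - minus) / 2)
    return jnp.stack(grads).reshape(params.shape)

params = jax.random.uniform(jax.random.PRNGKey(0), (2, N_QUBITS), maxval=2 * jnp.pi)
g_ad = jax.grad(energy)(params)  # one forward + one backward pass for all P parameters
g_psr = psr_grad(params)         # 2P = 12 forward passes
```

Both methods produce the same gradient. Only the number of circuit evaluations differs, and that gap is what the benchmarks below measure at 120 parameters.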


🚀 Key Highlights & Benchmarks

1. 48.7× Faster Gradients (Reverse-Mode AD vs. Parameter-Shift)

On a 15-qubit Hardware-Efficient Ansatz circuit with 120 parameters, we benchmarked gradient calculation times:

  • Parameter-Shift Rule (PSR): 1,826 ms
  • JAX Reverse-Mode Autodiff (jax.grad): 37.5 ms
  • Performance Leap: A 48.7× speedup, which grows even larger as the number of parameters increases.
  • Vs. PennyLane JAX Backend: Our custom statevector simulator ran ~4× faster (2 ms vs. 8 ms for 50 parameters on GPU), which we attribute to XLA fusing the whole circuit into a single monolithic compiled kernel.

2. Grover's Algorithm at 36 Qubits (549 GB Statevector)

Full statevector simulation scales exponentially in memory: an $n$-qubit state holds $2^n$ complex amplitudes. At 36 qubits, that is $2^{36}$ complex64 amplitudes at 8 bytes each, or 549.76 GB (512 GiB). Using a 64-chip Google Cloud TPU v6e mesh, we successfully ran Grover's algorithm at 36 qubits, distributing the statevector across TPU HBM3 at roughly 8.6 GB per chip.
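The memory figures in this post are easy to check. The sketch below assumes complex64 amplitudes (8 bytes each), which is what the 549.76 GB figure implies. The helper statevector_bytes is hypothetical.

```python
BYTES_PER_AMPLITUDE = 8  # complex64: two float32 components

def statevector_bytes(n_qubits: int) -> int:
    # a full statevector stores 2**n complex amplitudes
    return (2 ** n_qubits) * BYTES_PER_AMPLITUDE

for n, chips in [(33, 16), (36, 64)]:
    total = statevector_bytes(n)
    print(f"{n} qubits: {total / 1e9:.2f} GB ({total / 2**30:.0f} GiB) total, "
          f"{total / chips / 2**30:.0f} GiB per chip on {chips} chips")
# 33 qubits: 68.72 GB (64 GiB) total, 4 GiB per chip on 16 chips
# 36 qubits: 549.76 GB (512 GiB) total, 8 GiB per chip on 64 chips
```

These numbers cover only the amplitudes themselves. Temporary buffers used during gate application may add to the per-chip footprint.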

3. Shor's Algorithm at 33 Qubits on TPU v5e-16

We demonstrated a distributed QFT (Quantum Fourier Transform) pipeline for Shor's order-finding at a 33-qubit scale (a 64 GiB statevector, about 68.7 GB) on a 16-chip Cloud TPU v5e mesh. Using point-to-point collective permutes (jax.lax.ppermute), we reduced inter-chip transfer spikes from 8 GB to 128 MB.
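The post does not show its QFT pipeline, so what follows only sketches the underlying idea. When the top qubit of the statevector is "global" (it selects which device holds a shard), a single-qubit gate on that qubit can be applied by swapping each shard with one partner device through a single jax.lax.ppermute, rather than gathering the full state. The sketch uses jax.pmap and fakes four CPU devices through XLA_FLAGS, which must be set before jax is imported. The function name apply_gate_on_top_qubit is hypothetical.

```python
import os
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")  # local CPU demo

import jax
import jax.numpy as jnp
import numpy as np

N_DEV = jax.device_count()
assert N_DEV >= 2 and (N_DEV & (N_DEV - 1)) == 0, "needs a power-of-two device count >= 2"
N_QUBITS = 6
LOCAL = 2 ** N_QUBITS // N_DEV  # amplitudes held per device

def apply_gate_on_top_qubit(shards, gate):
    """Apply a 2x2 gate to the most significant qubit, which selects the device."""
    half = N_DEV // 2
    perm = [(d, d ^ half) for d in range(N_DEV)]  # partner differs only in the top bit

    def per_device(local):
        partner = jax.lax.ppermute(local, axis_name="dev", perm=perm)  # one pairwise swap
        bit = jax.lax.axis_index("dev") // half  # value of the top qubit on this device
        return gate[bit, bit] * local + gate[bit, 1 - bit] * partner

    return jax.pmap(per_device, axis_name="dev")(shards)

rng = np.random.default_rng(0)
state = rng.standard_normal(2 ** N_QUBITS).astype(np.float32)
state /= np.linalg.norm(state)
h = jnp.array([[1.0, 1.0], [1.0, -1.0]], dtype=jnp.float32) / jnp.sqrt(2.0)

out = np.asarray(apply_gate_on_top_qubit(state.reshape(N_DEV, LOCAL), h)).ravel()
expected = (np.asarray(h) @ state.reshape(2, -1)).ravel()  # dense single-device reference
```

Each device sends and receives exactly one shard, so per-chip traffic equals the shard size rather than the full statevector. A full QFT also needs gates that couple global and local qubits, which this sketch does not cover.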

4. 1,024-Qubit MPS VQE & Numerical Stability Breakthroughs

To simulate beyond the statevector limit (36+ qubits), we built a differentiable Matrix Product State (MPS) simulator in pure JAX. Differentiating through the Singular Value Decomposition (SVD) is the hard part. JAX's complex (Wirtinger) SVD gradient contains terms like $1/(s_i^2 - s_j^2)$ and $1/s_i$, so degenerate or vanishing singular values $s_i$ frequently produce NaN or Inf gradients. We introduced three key stabilization techniques:

  1. SVD Epsilon Floor (1e-7) to prevent division by zero.
  2. Site-level Normalization to stop amplitude drift.
  3. SVD Jitter & Momentum SGD to low-pass filter the gradients and damp the "V-bounce" oscillations that appear when entanglement hits the bond-dimension ceiling.
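The post does not include its MPS code. The following sketch only shows where the three fixes could sit in a two-site SVD split. It rests on three assumptions: a real-valued float32 tensor for brevity, a floor applied after truncation, and a plain heavy-ball momentum update. The names split_two_site and momentum_step are hypothetical.

```python
import jax
import jax.numpy as jnp

SVD_EPS = 1e-7  # epsilon floor value quoted in the post

def split_two_site(theta, chi, key=None, jitter=0.0):
    """Hypothetical helper: SVD-split a two-site matrix into (u, s, vh), truncated to chi."""
    if key is not None and jitter > 0.0:
        # SVD jitter: a tiny random perturbation that breaks exact singular-value degeneracies
        theta = theta + jitter * jax.random.normal(key, theta.shape, theta.dtype)
    u, s, vh = jnp.linalg.svd(theta, full_matrices=False)
    k = min(chi, s.shape[0])                 # bond-dimension ceiling
    u, s, vh = u[:, :k], s[:k], vh[:k, :]
    s = jnp.maximum(s, SVD_EPS)              # epsilon floor: later 1/s terms stay finite
    s = s / jnp.linalg.norm(s)               # site-level normalization against amplitude drift
    return u, s, vh

def momentum_step(params, velocity, grad, lr=0.05, beta=0.9):
    # heavy-ball momentum: velocity is an exponentially weighted sum of past gradients,
    # which smooths (low-pass filters) gradient components that flip sign step to step
    velocity = beta * velocity + grad
    return params - lr * velocity, velocity

# rank-1 input: in exact arithmetic all but one singular value are zero
a = jax.random.normal(jax.random.PRNGKey(0), (8,))
u, s, vh = split_two_site(jnp.outer(a, a), chi=4, key=jax.random.PRNGKey(1), jitter=1e-6)

# momentum on f(x) = x**2 (gradient 2x)
x, v = 1.0, 0.0
for _ in range(200):
    x, v = momentum_step(x, v, 2.0 * x)
```

Note that flooring s after the decomposition protects downstream divisions by s. It does not change JAX's built-in SVD backward rule. That rule is where the jitter, which separates degenerate singular values, is meant to help.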

🛠️ Three Engineering Tricks for Cloud TPU Scaling

Tracing deep quantum circuits into XLA's HLO graph can crash the compilation host or exhaust its memory. We solved these issues with three JAX features:

A. Multi-Device PositionalSharding

We partition the statevector's leading dimension across physical TPU chips:

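import jax
from jax.sharding import PositionalSharding  # deprecated in newer JAX releases
sharding = PositionalSharding(jax.devices()).reshape(NUM_DEV, 1)  # dim 0 split across NUM_DEV chips
state = jax.device_put(state, sharding)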

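Gates that act only on qubits stored within a shard run locally on each chip. For gates that touch the sharded (high-order) qubits, XLA inserts the cross-chip communication automatically and routes it over the TPU Inter-Chip Interconnect (ICI).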
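Recent JAX releases deprecate PositionalSharding in favor of NamedSharding over a Mesh. The sketch below expresses the same row-wise layout that way. This is an assumption about how one might port it, not the suite's code. It then applies a Hadamard to the least significant qubit, which lives entirely inside each shard, so XLA should not need cross-device communication for it. Four CPU devices are faked through XLA_FLAGS for local testing, and hadamard_lowest_qubit is a hypothetical helper.

```python
import os
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")  # local CPU demo

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

N_QUBITS = 10
devices = np.array(jax.devices())
num_dev = devices.size  # assumed to be a power of two
mesh = Mesh(devices, axis_names=("chips",))
# split dim 0 across chips and keep dim 1 whole on each chip
sharding = NamedSharding(mesh, P("chips", None))

state = jnp.zeros(2 ** N_QUBITS, dtype=jnp.complex64).at[0].set(1.0)  # |00...0>
state = jax.device_put(state.reshape(num_dev, -1), sharding)

@jax.jit
def hadamard_lowest_qubit(s):
    # the least significant qubit is the last axis of each row, so it never crosses shards
    h = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / jnp.sqrt(2.0)
    rows = s.shape[0]
    return jnp.einsum("ij,rkj->rki", h, s.reshape(rows, -1, 2)).reshape(rows, -1)

out = hadamard_lowest_qubit(state)
```

Inspecting out.sharding shows how XLA laid out the result. Gates on the leading (device-selecting) qubits are the ones that require ICI traffic.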

B. lax.fori_loop for O(1) Compiler Graph Size

Unrolled Python loops create millions of HLO nodes, which blows up compile time and compiler memory. We instead express circuit execution as a JAX loop:

state_new = jax.lax.fori_loop(0, depth, body_fn, state_init)  # body_fn(i, state) applies gate block i

This forces XLA to compile the gate block once, keeping the compiler graph size constant ($O(1)$) regardless of circuit depth.
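A small experiment makes the compile-graph claim concrete. This sketch uses a 4-qubit toy circuit with a hypothetical ry_layer helper. It lowers the same circuit once with a Python loop and once with jax.lax.fori_loop, then counts lines in the lowered module text from jax.jit(...).lower(...).as_text(). Exact counts depend on the JAX version, so they are printed rather than stated here.

```python
import jax
import jax.numpy as jnp
import numpy as np

N = 4  # qubits in the toy circuit

def ry_layer(state, thetas):
    # apply RY(thetas[q]) to every qubit q of a (2,)*N real state tensor
    for q in range(N):
        c, s = jnp.cos(thetas[q] / 2), jnp.sin(thetas[q] / 2)
        gate = jnp.array([[c, -s], [s, c]])
        state = jnp.moveaxis(jnp.tensordot(gate, state, axes=([1], [q])), 0, q)
    return state

def run_unrolled(params, state):
    for i in range(params.shape[0]):  # Python loop: one copy of the layer per iteration
        state = ry_layer(state, params[i])
    return state

def run_fori(params, state):
    body_fn = lambda i, s: ry_layer(s, params[i])  # traced once, reused every iteration
    return jax.lax.fori_loop(0, params.shape[0], body_fn, state)

def hlo_lines(fn, depth):
    params = jnp.zeros((depth, N))
    state = jnp.zeros((2,) * N).at[(0,) * N].set(1.0)
    return len(jax.jit(fn).lower(params, state).as_text().splitlines())

print(hlo_lines(run_unrolled, 10), hlo_lines(run_unrolled, 100))  # grows with depth
print(hlo_lines(run_fori, 10), hlo_lines(run_fori, 100))          # stays roughly flat

state0 = jnp.zeros((2,) * N).at[(0,) * N].set(1.0)
params = jax.random.uniform(jax.random.PRNGKey(1), (20, N))

def amp0(run):
    # real amplitude of |0000> as a scalar cost
    return lambda p: run(p, state0)[(0,) * N]

g_fori = jax.grad(amp0(run_fori))(params)
g_unrolled = jax.grad(amp0(run_unrolled))(params)
```

Here the trip count is a Python int, so fori_loop is lowered through scan, which is what lets jax.grad differentiate through it. With traced bounds it becomes a while_loop, which supports forward-mode but not reverse-mode differentiation.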

C. jax.checkpoint for Lower Backpropagation Memory

Reverse-mode AD normally keeps every intermediate state from the forward pass in memory. For deep circuits, this exhausts HBM and triggers out-of-memory errors. We apply gradient checkpointing so that only each layer's input state is saved, and the states inside a layer are recomputed on the fly during the backward pass:

@jax.checkpoint
def circuit_layer(state, params):
    return apply_gates(state, params)
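The decorator above can be exercised end to end. In this sketch, apply_gates is a hypothetical stand-in (RY rotations plus a CZ chain), and the layers run inside jax.lax.scan. The test checks that checkpointing changes only memory and compute, not the gradient.

```python
import jax
import jax.numpy as jnp
import numpy as np

N = 4

def apply_gates(state, thetas):
    # hypothetical stand-in for the post's apply_gates: RY on every qubit, then a CZ chain
    for q in range(N):
        c, s = jnp.cos(thetas[q] / 2), jnp.sin(thetas[q] / 2)
        gate = jnp.array([[c, -s], [s, c]])
        state = jnp.moveaxis(jnp.tensordot(gate, state, axes=([1], [q])), 0, q)
    bits = jnp.indices(state.shape)
    for q in range(N - 1):
        state = jnp.where((bits[q] == 1) & (bits[q + 1] == 1), -state, state)
    return state

circuit_layer = jax.checkpoint(apply_gates)  # equivalent to decorating with @jax.checkpoint

def expectation_z0(params, layer_fn):
    state = jnp.zeros((2,) * N).at[(0,) * N].set(1.0)
    # scan keeps one statevector per layer (the carry); a checkpointed layer
    # recomputes its gate-level intermediates during the backward pass
    state, _ = jax.lax.scan(lambda s, th: (layer_fn(s, th), None), state, params)
    z0 = 1.0 - 2.0 * jnp.indices(state.shape)[0]
    return jnp.sum(state ** 2 * z0)

params = jax.random.uniform(jax.random.PRNGKey(2), (16, N), maxval=2 * jnp.pi)
g_plain = jax.grad(lambda p: expectation_z0(p, apply_gates))(params)
g_ckpt = jax.grad(lambda p: expectation_z0(p, circuit_layer))(params)
```

With checkpointing, the backward pass stores roughly one statevector per layer instead of every gate-level intermediate, and jax.ad_checkpoint.print_saved_residuals can be used to inspect what is kept. Memory therefore still grows with depth, only with a much smaller constant. The policy argument of jax.checkpoint controls which values are saved versus recomputed.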

🎯 Summary of Experiments

  • GHZ State Prep: Learned to prepare a 3-qubit GHZ state with fidelity $\mathcal{F} > 0.9999$ in 200 epochs.
  • XOR VQC Classifier: Batched classification with jax.vmap, reaching over 97% accuracy.
  • VQE (H₂ Molecule): Solved for the ground-state energy to within 0.1 mHa of the FCI reference, well inside chemical accuracy (1.6 mHa).
  • QAOA MaxCut: Achieved a 99.7% approximation ratio at depth $p=5$ on a 6-node weighted graph.
  • Monte Carlo Noise Trajectories: Simulated noisy NISQ channels (amplitude/phase damping); trajectory averages converge statistically to the analytical decay curves.
  • Barren Plateaus: Confirmed the exponential decay of gradient variance, $\mathcal{O}(2^{-n})$, predicted by McClean et al., up to 24 qubits on TPU.
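The next sketch illustrates the idea behind the noise-trajectory experiment. It is not the suite's implementation. Single-qubit amplitude damping is unraveled into quantum jumps and batched over 4,000 trajectories with jax.vmap. With $p = \gamma\,dt$ and Kraus operators $K_0 = \mathrm{diag}(1, \sqrt{1-p})$ and $K_1 = \sqrt{p}\,|0\rangle\langle 1|$, the exact excited-state population after $k$ steps is $|c_1(0)|^2 (1-p)^k$, which approaches $|c_1(0)|^2 e^{-\gamma t}$ as $dt \to 0$. The constants and the trajectory helper are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np

GAMMA, DT, STEPS = 1.0, 0.01, 200
P_DAMP = GAMMA * DT  # per-step damping probability p

def trajectory(key, psi0):
    """One quantum-jump trajectory of a single qubit under amplitude damping."""
    def step(psi, k):
        p_jump = P_DAMP * jnp.abs(psi[1]) ** 2            # ||K1 psi||^2
        jump = jax.random.uniform(k) < p_jump
        jumped = jnp.array([1.0, 0.0], dtype=psi.dtype)   # K1 psi normalized (global phase dropped)
        kept = psi * jnp.array([1.0, jnp.sqrt(1.0 - P_DAMP)])
        kept = kept / jnp.linalg.norm(kept)                # K0 psi normalized
        psi = jnp.where(jump, jumped, kept)
        return psi, jnp.abs(psi[1]) ** 2                   # record excited-state population
    _, pop1 = jax.lax.scan(step, psi0, jax.random.split(key, STEPS))
    return pop1

psi0 = jnp.array([1.0, 1.0], dtype=jnp.complex64) / jnp.sqrt(2.0)
keys = jax.random.split(jax.random.PRNGKey(0), 4000)
pops = jax.vmap(trajectory, in_axes=(0, None))(keys, psi0)   # shape (4000, STEPS)
mc_pop1 = pops.mean(axis=0)
exact_pop1 = 0.5 * (1.0 - P_DAMP) ** jnp.arange(1, STEPS + 1)  # |c1(0)|^2 (1 - p)^k
```

Each trajectory stays a pure state, so memory scales with the statevector rather than the density matrix. The price is statistical error, which shrinks as one over the square root of the number of trajectories.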

🙏 Acknowledgements

This research was made possible through the Google TPU Research Cloud (TRC) program, which provided the TPU v5e and v6e compute resources necessary to scale our simulations to 33 and 36 qubits.