Scaling Quantum State-Vector Simulation to 36 Qubits on Google Cloud TPU
Ashitesh Singh · Independent Researcher · Supported by the Google TPU Research Cloud (TRC) Program · GitHub: https://github.com/AshiteshSingh/Tpu-Accelerated-Quantum-JAX
The Problem
Simulating quantum circuits on a classical computer is exponentially hard. An n-qubit state vector holds 2ⁿ complex amplitudes. At 8 bytes per amplitude (complex64), that is 8.59 GB at 30 qubits, 68.72 GB at 33 qubits, and 549.76 GB at 36 qubits. No single GPU or TPU chip can hold the 36-qubit state.
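The sizes above follow from simple arithmetic. Here is a minimal sketch, assuming complex64 storage (8 bytes per amplitude, which is what the 549 GB figure implies); the helper name `state_vector_bytes` is only for illustration:

```python
BYTES_PER_AMPLITUDE = 8  # assumption: complex64 = two float32 values


def state_vector_bytes(n_qubits: int) -> int:
    """Bytes needed to store all 2**n_qubits complex amplitudes."""
    return (2 ** n_qubits) * BYTES_PER_AMPLITUDE


for n in (30, 33, 36):
    size = state_vector_bytes(n)
    # Print decimal gigabytes and binary gibibytes side by side.
    print(f"{n} qubits: {size / 1e9:.2f} GB ({size / 2**30:.0f} GiB)")
```

Each extra qubit doubles the footprint, so the 3-qubit step from 33 to 36 costs 8× the memory.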
Most frameworks (Qiskit, PennyLane) give up around 29–30 qubits for two reasons:
- Memory wall: the full state vector doesn't fit on one device
- XLA graph explosion: deep circuits compiled with Python for-loops unroll into millions of XLA graph nodes, crashing the compiler host before a single gate runs
I built a simulator in 100% pure JAX that solves both — scaling to 36 qubits (549 GB) across a 64-chip Google Cloud TPU v6e cluster, with full reverse-mode autodiff and sub-millisecond gate latency.
Solution 1: Distributed State Vector with NamedSharding
Instead of storing all 2ⁿ amplitudes on one chip, I split the state vector across the entire TPU mesh using JAX's native sharding API:
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
mesh = Mesh(mesh_utils.create_device_mesh((16,)), axis_names=("chips",))
state_sharding = NamedSharding(mesh, P("chips", None))
# 33-qubit state: shape [16, 2^29], so each chip holds 2^29 amplitudes
state = jax.jit(init_ground_state, out_shardings=state_sharding)()
Each chip holds exactly its slice of the state vector in its own HBM. Gates that act across slices need cross-chip communication, which XLA inserts automatically and the TPU Inter-Chip Interconnect (ICI) carries. The host CPU never holds the full array.
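To try the same layout without a 16-chip slice, here is a small sketch that shards a 10-qubit state over whatever devices are visible (a single CPU device works). It is an illustration, not the repo's code: `N_QUBITS`, this `init_ground_state` body, and `hadamard_on_top_qubit` are assumptions, and the device count is assumed to be a power of two.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

N_QUBITS = 10                      # illustrative; far below the 33 to 36 qubit runs
devices = np.array(jax.devices())  # 1 device on a laptop, many on a TPU slice
n_chips = devices.size             # assumed to be a power of two
mesh = Mesh(devices, axis_names=("chips",))
state_sharding = NamedSharding(mesh, P("chips", None))


def init_ground_state():
    # |00...0>: amplitude 1 at index 0, laid out as [chips, amplitudes_per_chip].
    flat = jnp.zeros(2 ** N_QUBITS, dtype=jnp.complex64).at[0].set(1.0)
    return flat.reshape(n_chips, -1)


def hadamard_on_top_qubit(state):
    # The most significant qubit splits the vector into two halves, so with
    # more than one chip this gate mixes amplitudes stored on different chips.
    h = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / jnp.sqrt(2.0)
    halves = state.reshape(2, -1)
    return (h @ halves).reshape(n_chips, -1)


state = jax.jit(init_ground_state, out_shardings=state_sharding)()
state = jax.jit(hadamard_on_top_qubit, out_shardings=state_sharding)(state)
print(state.shape, state.sharding)
```

The gate function contains no explicit communication code; under `jax.jit` the compiler decides what data has to move between devices to satisfy the requested output sharding.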
| Cluster | Qubits | State Vector Size | Per-Chip |
|---|---|---|---|
| Single GPU | 29 | 4.29 GB | — |
| TPU v5e-16 | 33 | 68.72 GB | 4.29 GB |
| TPU v6e-64 | 36 | 549.76 GB | 8.59 GB |
Solution 2: lax.fori_loop for O(1) Compiler Graph Size
A Python for-loop inside @jax.jit is unrolled at trace time: every iteration adds its own copy of the loop body to the XLA graph. At 100 circuit layers × 30 qubits, the HLO graph has millions of nodes, and the XLA compiler host runs out of memory during compilation, before execution even starts.
jax.lax.fori_loop compiles the loop body once as a single While HLO node:
# ✅ O(1) XLA graph: the body is traced once and runs D times at hardware speed
# circuit_body must have the signature (i, state) -> state
state = lax.fori_loop(0, D, circuit_body, state)
# ❌ O(D) XLA graph: explodes at 100+ layers
for i in range(D):
    state = apply_gate(state, ...)
Graph size stays constant regardless of circuit depth. No compiler OOM.
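The difference is easy to observe on a toy circuit by counting top-level equations in the traced jaxpr, used here as a rough proxy for HLO graph size. This is a sketch under stated assumptions: the 3-qubit RY-only circuit and the helpers `ry`, `apply_gate`, `layer`, and `n_eqns` are illustrative and not taken from the repo.

```python
import jax
import jax.numpy as jnp
from jax import lax

N = 3  # qubits (illustrative)


def ry(theta):
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]], dtype=jnp.complex64)


def apply_gate(state, gate, q):
    # Contract a 2x2 gate with axis q of a state tensor of shape (2,) * N.
    out = jnp.tensordot(gate, state, axes=((1,), (q,)))
    return jnp.moveaxis(out, 0, q)


def layer(state, thetas_row):
    for q in range(N):  # unrolled over qubits, but independent of depth
        state = apply_gate(state, ry(thetas_row[q]), q)
    return state


def ground_state():
    return jnp.zeros((2,) * N, dtype=jnp.complex64).at[(0,) * N].set(1.0)


def run_unrolled(thetas):
    state = ground_state()
    for i in range(thetas.shape[0]):  # Python loop: traced once per layer
        state = layer(state, thetas[i])
    return state


def run_fori(thetas):
    body = lambda i, state: layer(state, thetas[i])
    return lax.fori_loop(0, thetas.shape[0], body, ground_state())


def n_eqns(fn, depth):
    # Number of top-level equations in the traced program.
    return len(jax.make_jaxpr(fn)(jnp.ones((depth, N))).jaxpr.eqns)


for depth in (5, 50):
    print(depth, n_eqns(run_unrolled, depth), n_eqns(run_fori, depth))
```

The unrolled count grows with depth while the `fori_loop` count stays fixed, and both versions return the same state.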
The Full Compilation Stack
Python (JAX / jnp)
↓ @jax.jit traces once
XLA High-Level Operations (HLO)
↓ XLA compiler: fusion, tiling, layout
TPU machine instructions operating on HBM3
↓ ICI for cross-chip ops
Result: 36-qubit state vector across 64 HBM3 pools
The entire simulation — state init, gate application, observables, and gradient accumulation — runs inside one monolithic XLA kernel. No Python re-entry between gates. No CPU↔device copies mid-circuit.
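As a rough illustration of that structure, the sketch below puts state initialization, a `fori_loop` over layers, an observable, and its gradient into one function and compiles it ahead of time. The toy 3-qubit RY circuit, the sizes `N` and `D`, and the names `expectation_z0` and `step` are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp
from jax import lax

N, D = 3, 20  # qubits and layers (illustrative sizes)


def ry(theta):
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]], dtype=jnp.complex64)


def apply_gate(state, gate, q):
    out = jnp.tensordot(gate, state, axes=((1,), (q,)))
    return jnp.moveaxis(out, 0, q)


def expectation_z0(thetas):
    # State init, D layers of gates, and the observable share one trace.
    state = jnp.zeros((2,) * N, dtype=jnp.complex64).at[(0,) * N].set(1.0)

    def body(i, s):
        for q in range(N):
            s = apply_gate(s, ry(thetas[i, q]), q)
        return s

    state = lax.fori_loop(0, D, body, state)  # static bounds keep it differentiable
    probs = jnp.real(state * jnp.conj(state))
    p0 = probs.sum(axis=tuple(range(1, N)))   # marginal distribution of qubit 0
    return p0[0] - p0[1]                      # <Z> on qubit 0


# Forward value and reverse-mode gradient compiled as a single executable.
step = jax.jit(jax.value_and_grad(expectation_z0))
thetas = jnp.full((D, N), 0.1)
compiled = step.lower(thetas).compile()
value, grads = compiled(thetas)
print(value, grads.shape)
```

Once `compiled` exists, each call is a single dispatch from Python; the loop over layers and the backward pass both run on the device.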
Key Results
- Gate latency (after JIT warm-up): ~0.01 ms
- Gradient of a 100-parameter circuit: ~8 ms (one jax.grad backward pass, vs ~10 s with PennyLane parameter-shift)
- VQE H₂ ground state: converged to −1.13627 Hartree, within 1.6×10⁻³ Hartree of the FCI reference (quantum chemical accuracy)
- Grover's search at 36 qubits: marked state probability peaks at ~0.998 after 205,887 iterations
- Barren plateau: gradient variance confirmed to decay as 2⁻ⁿ up to 24 qubits (consistent with McClean et al. 2018)
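The Grover iteration count can be checked against the textbook closed form for a single marked item. This sketch uses only the standard formula, not the simulator, and `grover_plan` is a hypothetical helper name:

```python
import math


def grover_plan(n_qubits: int):
    """Optimal iteration count and ideal success probability (one marked item)."""
    n_states = 2 ** n_qubits
    theta = math.asin(1 / math.sqrt(n_states))       # rotation half-angle
    iterations = math.floor(math.pi / (4 * theta))   # about (pi / 4) * sqrt(N)
    p_success = math.sin((2 * iterations + 1) * theta) ** 2
    return iterations, p_success


iterations, p_success = grover_plan(36)
print(iterations)  # 205887
print(p_success)
```

The count matches the 205,887 iterations reported above. The probability printed here is the ideal, exact-arithmetic value, so it is an upper reference for the simulated peak rather than a reproduction of it.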
Try It for Free
The repo includes a Colab notebook that runs on a free Colab TPU v5e-1 — no Google Cloud account needed.
Six self-contained experiments: GHZ state prep, VQC XOR classifier, VQE H₂, QAOA MaxCut, Monte Carlo noise simulation, and noisy NISQ fidelity decay.
Why Pure JAX — Not PennyLane, Qiskit, or PyTorch?
| Framework | Why it fails at this scale |
|---|---|
| PennyLane | Python-level gate dispatch; no monolithic XLA kernel; parameter-shift needs 2 circuit evaluations per parameter |
| Qiskit-Aer | C++/CUDA, single device only; zero TPU support |
| TensorFlow Quantum | Cirq simulation runs on CPU — only classical ops hit TPU |
| PyTorch + torch_xla | Bridge adds per-op latency at every Python→XLA boundary |
JAX is XLA. There is no bridge. Python code is the source language for a JIT compiler that targets TPU silicon directly.
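To make the gradient-cost row concrete, here is a sketch comparing one `jax.grad` call with a hand-written parameter-shift rule on a toy circuit. It is not PennyLane's implementation: the 4-qubit RY-only circuit, `cost`, and `parameter_shift_grad` are assumptions chosen so the exact gradient, −sin(θ), is known.

```python
import jax
import jax.numpy as jnp

N = 4  # qubits, one RY parameter each (illustrative)


def ry(theta):
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]], dtype=jnp.complex64)


def apply_gate(state, gate, q):
    out = jnp.tensordot(gate, state, axes=((1,), (q,)))
    return jnp.moveaxis(out, 0, q)


def cost(thetas):
    # Sum of <Z_q> over all qubits after one RY rotation per qubit.
    state = jnp.zeros((2,) * N, dtype=jnp.complex64).at[(0,) * N].set(1.0)
    for q in range(N):
        state = apply_gate(state, ry(thetas[q]), q)
    probs = jnp.real(state * jnp.conj(state))
    total = 0.0
    for q in range(N):
        others = tuple(a for a in range(N) if a != q)
        p = probs.sum(axis=others)  # marginal distribution of qubit q
        total = total + (p[0] - p[1])
    return total


def parameter_shift_grad(f, thetas):
    # Two circuit evaluations per parameter: 2 * len(thetas) in total.
    shift = jnp.pi / 2
    grads = []
    for k in range(thetas.shape[0]):
        e_k = jnp.zeros_like(thetas).at[k].set(shift)
        grads.append((f(thetas + e_k) - f(thetas - e_k)) / 2)
    return jnp.stack(grads)


thetas = jnp.linspace(0.1, 1.0, N)
g_autodiff = jax.grad(cost)(thetas)             # one forward + one backward pass
g_shift = parameter_shift_grad(cost, thetas)    # 2 * N forward passes
print(g_autodiff, g_shift)
```

Both methods agree on this circuit, but the parameter-shift loop scales linearly with the number of parameters while reverse-mode autodiff needs a single backward pass.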
References
- Bradbury et al. (2018). JAX: composable transformations of Python+NumPy programs.
- McClean et al. (2018). Barren plateaus in quantum neural network training landscapes. Nature Communications.
- Peruzzo et al. (2014). A variational eigenvalue solver on a photonic quantum processor. Nature Communications.
- Farhi et al. (2014). A quantum approximate optimization algorithm. arXiv:1411.4028.
- Vidal (2003). Efficient classical simulation of slightly entangled quantum computations. PRL.
Supported by the Google TPU Research Cloud (TRC) program.
Apache 2.0 License · github.com/AshiteshSingh/Tpu-Accelerated-Quantum-JAX