Scaling Quantum State-Vector Simulation to 36 Qubits on Google Cloud TPU

Ashitesh Singh
Independent Researcher, supported by the Google TPU Research Cloud (TRC) Program
GitHub: https://github.com/AshiteshSingh/Tpu-Accelerated-Quantum-JAX


The Problem

Simulating quantum circuits on a classical computer is exponentially hard. An n-qubit state vector holds 2ⁿ complex amplitudes. At 8 bytes per amplitude (single-precision complex), that is 8 GiB at 30 qubits, 64 GiB at 33 qubits, and 512 GiB (about 549 GB) at 36 qubits. No single GPU or TPU chip can hold the 36-qubit vector.
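The sizes above follow from simple arithmetic. Here is a minimal sketch, assuming 8-byte complex64 amplitudes (the width implied by the 549.76 GB figure for 36 qubits); `state_vector_bytes` is a hypothetical helper name, not something from the repo:

```python
def state_vector_bytes(n_qubits, bytes_per_amplitude=8):
    """Bytes needed for a dense n-qubit state vector (complex64 assumed)."""
    return (2 ** n_qubits) * bytes_per_amplitude

for n in (30, 33, 36):
    size = state_vector_bytes(n)
    # Print both binary (GiB) and decimal (GB) units, since the two are easy to mix up.
    print(f"{n} qubits: {size / 2**30:.0f} GiB = {size / 1e9:.2f} GB")

# 30 qubits: 8 GiB = 8.59 GB
# 33 qubits: 64 GiB = 68.72 GB
# 36 qubits: 512 GiB = 549.76 GB
```

Each extra qubit doubles the footprint, which is why adding three qubits needs eight times the memory.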

On a single device, most state-vector frameworks (Qiskit, PennyLane) top out around 29–30 qubits. Pushing a JAX simulator past that point runs into two problems:

  1. Memory wall — the full state vector doesn't fit on one device
  2. XLA graph explosion — deep circuits compiled with Python for-loops unroll into millions of XLA graph nodes, crashing the compiler host before a single gate runs

I built a simulator in 100% pure JAX that solves both — scaling to 36 qubits (549 GB) across a 64-chip Google Cloud TPU v6e cluster, with full reverse-mode autodiff and sub-millisecond gate latency.


Solution 1: Distributed State Vector with NamedSharding

Instead of storing all 2ⁿ amplitudes on one chip, I split the state vector across the entire TPU mesh using JAX's native sharding API:

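import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over 16 chips; axis 0 of the state is split across the "chips" axis
mesh = Mesh(mesh_utils.create_device_mesh((16,)), axis_names=("chips",))
state_sharding = NamedSharding(mesh, P("chips", None))

# 33-qubit state: shape [16, 2^29], so each chip holds 2^29 amplitudes
state = jax.jit(init_ground_state, out_shardings=state_sharding)()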

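Each chip holds exactly its slice of the state vector in its own HBM. Gates that act on qubits spanning several chips need cross-chip communication: XLA inserts the required collectives and the TPU Inter-Chip Interconnect (ICI) carries them, so no manual communication code is needed. The host CPU never holds the full array.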
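To make the sharded initialisation concrete, here is a minimal sketch that runs on whatever devices are available (a single CPU device works). The function name `make_sharded_ground_state` and the body of `init_ground_state` are assumptions for illustration; the repo's own initialiser may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def make_sharded_ground_state(n_qubits):
    """Return |0...0> as a [n_devices, 2^n / n_devices] array, one row per device."""
    devices = np.array(jax.devices())      # 1 device on CPU, 16 on a v5e-16, ...
    n_dev = devices.size
    dim = 2 ** n_qubits
    if dim % n_dev != 0:
        raise ValueError("device count must divide 2^n_qubits")

    mesh = Mesh(devices, axis_names=("chips",))
    sharding = NamedSharding(mesh, P("chips", None))   # split axis 0 across chips

    def init_ground_state():
        state = jnp.zeros((n_dev, dim // n_dev), dtype=jnp.complex64)
        return state.at[0, 0].set(1.0)     # amplitude 1 on |0...0>

    # out_shardings asks XLA to produce the result already laid out across the mesh
    return jax.jit(init_ground_state, out_shardings=sharding)()

state = make_sharded_ground_state(n_qubits=10)
print(state.shape, state.sharding)
```

Because the array is created inside the jitted function, the full state vector is not built on the host first and then copied out.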

Cluster | Qubits | State Vector Size | Per-Chip
Single GPU | 29 | 4.29 GB (4 GiB) | 4.29 GB (4 GiB)
TPU v5e-16 | 33 | 68.72 GB (64 GiB) | 4.29 GB (4 GiB)
TPU v6e-64 | 36 | 549.76 GB (512 GiB) | 8.59 GB (8 GiB)

Solution 2: lax.fori_loop for O(1) Compiler Graph Size

A Python for-loop inside @jax.jit is unrolled at trace time: every iteration's operations are copied into the XLA graph. At 100 circuit layers × 30 qubits, the HLO graph has millions of nodes, and the XLA compiler host runs out of memory during compilation, before execution even starts.

jax.lax.fori_loop compiles the loop body once as a single While HLO node:

# ✅ O(1) XLA graph — compiles once, runs D times at hardware speed
state = lax.fori_loop(0, D, circuit_body, state)

# ❌ O(D) XLA graph — explodes at 100+ layers
for i in range(D):
    state = apply_gate(state, ...)

Graph size stays constant regardless of circuit depth. No compiler OOM.
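The difference can be observed directly by counting equations in the traced jaxpr, used here as a rough proxy for HLO graph size. This is a toy sketch: `layer`, `unrolled`, `rolled`, and `n_equations` are hypothetical names, and the "circuit" is a single fixed rotation on one qubit rather than a real gate layer.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Toy "layer": one fixed single-qubit rotation applied to a 1-qubit state.
GATE = jnp.array([[0.8, -0.6], [0.6, 0.8]], dtype=jnp.float32)

def layer(state):
    return GATE @ state

def unrolled(depth):
    def circuit(state):
        for _ in range(depth):        # Python loop: the body is traced `depth` times
            state = layer(state)
        return state
    return circuit

def rolled(depth):
    def circuit(state):
        # The body is traced once; `depth` only appears as a loop bound.
        return lax.fori_loop(0, depth, lambda i, s: layer(s), state)
    return circuit

def n_equations(fn, state):
    """Number of top-level equations in the traced jaxpr."""
    return len(jax.make_jaxpr(fn)(state).jaxpr.eqns)

state0 = jnp.array([1.0, 0.0], dtype=jnp.float32)
for depth in (10, 100):
    print(depth, n_equations(unrolled(depth), state0), n_equations(rolled(depth), state0))
```

The unrolled count grows with depth while the rolled count does not change, and both versions return the same state.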


The Full Compilation Stack

Python (JAX / jnp)
      ↓  @jax.jit traces once
XLA High-Level Operations (HLO)
      ↓  XLA compiler: fusion, tiling, layout
TPU machine instructions (operating on per-chip HBM)
      ↓  ICI for cross-chip ops
Result: 36-qubit state vector across 64 HBM3 pools

The entire simulation (state init, gate application, observables, and gradient accumulation) runs inside one jit-compiled XLA program. No Python re-entry between gates. No CPU↔device copies mid-circuit.
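As an illustration of that single-program structure, here is a minimal one-qubit sketch in which state init, a loop of gates, the observable, and the gradient all sit behind one jitted call. The circuit (a chain of RY rotations measured in Z) and the names `ry`, `energy`, and `energy_and_grad` are assumptions chosen for brevity, not the repo's code.

```python
import jax
import jax.numpy as jnp
from jax import lax

def ry(theta):
    """Single-qubit RY rotation matrix (real-valued)."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]])

def energy(thetas):
    """<Z> after applying RY(theta_k) for each k to |0>."""
    state0 = jnp.array([1.0, 0.0])              # state init

    def body(k, state):                         # gate application
        return ry(thetas[k]) @ state

    # Static trip count, so reverse-mode autodiff through the loop is supported.
    state = lax.fori_loop(0, thetas.shape[0], body, state0)
    z = jnp.array([1.0, -1.0])
    return jnp.sum(z * state ** 2)              # observable (amplitudes are real here)

# Forward pass and backward pass are compiled together into one program.
energy_and_grad = jax.jit(jax.value_and_grad(energy))

thetas = jnp.array([0.1, 0.2, 0.3])
value, grads = energy_and_grad(thetas)
print(value, grads)
```

For this toy circuit the rotations add up, so the value should match cos(0.6) and every gradient entry should match −sin(0.6), which gives an easy correctness check.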


Key Results

  • Gate latency (after JIT warm-up): ~0.01 ms
  • Gradient of a 100-parameter circuit: ~8 ms (one jax.grad backward pass, vs ~10 s with PennyLane parameter-shift)
  • VQE H₂ ground state: converged to −1.13627 Hartree, within 1.6×10⁻³ Hartree (chemical accuracy, about 1 kcal/mol) of the FCI reference
  • Grover's algorithm at 36 qubits: marked-state probability peaks at ~0.998 after 205,887 iterations (≈ π/4 · √(2³⁶))
  • Barren plateau: gradient variance confirmed to decay as 2⁻ⁿ up to 24 qubits (consistent with McClean et al. 2018)
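The Grover iteration count in the list above matches the textbook optimum for a single marked item. A short sketch of that arithmetic, using the standard formulas for ideal, noiseless Grover search; the helper names are hypothetical:

```python
import math

def grover_iterations(n_qubits, n_marked=1):
    """Iteration count floor(pi/4 * sqrt(N / M)) for N = 2^n items, M marked."""
    n_states = 2 ** n_qubits
    return math.floor(math.pi / 4 * math.sqrt(n_states / n_marked))

def success_probability(n_qubits, iterations, n_marked=1):
    """Ideal probability sin^2((2k + 1) * theta), where sin(theta) = sqrt(M / N)."""
    theta = math.asin(math.sqrt(n_marked / 2 ** n_qubits))
    return math.sin((2 * iterations + 1) * theta) ** 2

k = grover_iterations(36)
print(k)  # 205887
print(success_probability(36, k))
```

The ideal formula gives a reference value to compare against the measured ~0.998 peak.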

Try It for Free

The repo includes a Colab notebook that runs on a free Colab TPU v5e-1 — no Google Cloud account needed.


Six self-contained experiments: GHZ state prep, VQC XOR classifier, VQE H₂, QAOA MaxCut, Monte Carlo noise simulation, and noisy NISQ fidelity decay.


Why Pure JAX — Not PennyLane, Qiskit, or PyTorch?

Framework | Why it fails at this scale
PennyLane | Python-level gate dispatch; no monolithic XLA kernel; parameter-shift needs 2 circuit evaluations per parameter
Qiskit-Aer | C++/CUDA backend; no TPU support
TensorFlow Quantum | Cirq simulation runs on CPU; only classical ops hit TPU
PyTorch + torch_xla | Bridge adds per-op latency at every Python→XLA boundary
JAX is XLA. There is no bridge. Python code is the source language for a JIT compiler that targets TPU silicon directly.
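The parameter-shift cost mentioned in the comparison can be seen on a toy example. This sketch assumes a cost function with a closed form, the product of cos(θₖ), which equals ⟨Z⊗…⊗Z⟩ after an RY(θₖ) on each qubit of |0…0⟩. `expectation` and `parameter_shift_grad` are hypothetical names, and no PennyLane code is involved.

```python
import jax
import jax.numpy as jnp

def expectation(thetas):
    """Toy cost: <Z...Z> after RY(theta_k) on each qubit of |0...0>."""
    return jnp.prod(jnp.cos(thetas))

def parameter_shift_grad(f, thetas):
    """Parameter-shift rule: two evaluations of f per parameter (2 * P in total)."""
    grads = []
    for k in range(thetas.shape[0]):
        shift = jnp.zeros_like(thetas).at[k].set(jnp.pi / 2)
        grads.append(0.5 * (f(thetas + shift) - f(thetas - shift)))
    return jnp.stack(grads)

thetas = jnp.linspace(0.1, 1.0, 8)
g_shift = parameter_shift_grad(expectation, thetas)   # 16 evaluations of f
g_auto = jax.grad(expectation)(thetas)                # one forward + one backward pass
print(jnp.max(jnp.abs(g_shift - g_auto)))
```

Both methods agree to floating-point precision here. What differs is the cost: the number of circuit evaluations grows with the parameter count for parameter-shift, whereas reverse-mode autodiff needs a single backward pass.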


References

  1. Bradbury et al. (2018). JAX: composable transformations of Python+NumPy programs.
  2. McClean et al. (2018). Barren plateaus in quantum neural network training landscapes. Nature Communications.
  3. Peruzzo et al. (2014). A variational eigenvalue solver on a photonic quantum processor. Nature Communications.
  4. Farhi et al. (2014). A quantum approximate optimization algorithm. arXiv:1411.4028.
  5. Vidal (2003). Efficient classical simulation of slightly entangled quantum computations. PRL.

Supported by the Google TPU Research Cloud (TRC) program.
Apache 2.0 License · github.com/AshiteshSingh/Tpu-Accelerated-Quantum-JAX
