JAX Compile vs Execution
The Benchmark You Actually Need — compile latency matters in short scientific ML loops.
2026-02-08 · 6 min read
Many ML "benchmarks" measure only steady-state throughput. That's the right metric if you run a model for hours. But in scientific pipelines, a lot of work happens in short loops: hyperparameter sweeps, repeated small fits, interactive analysis iterations, short training runs for ablations. In those settings, compile latency can dominate the total cost.
Protocol and artifacts: Public Benchmarks. Validation pack: Validation Report. Suite runbook (repo path): docs/benchmarks/suites/ml.md.
Abstract. The benchmark we actually need is not "examples/sec at steady state". It is a two-regime measurement: time-to-first-result (TTFR) in a fresh process (cold start) and warm throughput once compilation and caches are populated. To make that publishable, we treat each run as a snapshot with raw distributions, a pinned environment, and an explicit cache policy.
1. Two regimes, two metrics
Regime A: cold-start / time-to-first-result
This includes import time, graph tracing, compilation, and the first execution. This is the metric that matters for short runs.
Regime B: warm throughput
Steady-state execution when compilation caches are populated, kernels are loaded, and the process is already running. This matters for long runs.
Publishing one number without specifying the regime is not meaningful.
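To make regime A measurable in practice, here is a minimal sketch (not the published harness): each trial launches a fresh interpreter, and the child process times its own import and first call. The child workload, shapes, and trial count are illustrative placeholders.

```python
# Minimal cold-start sketch (not the published harness): each trial runs in
# a fresh Python process; the child times its own import and first call.
import json
import statistics
import subprocess
import sys

CHILD = r"""
import time, json
t0 = time.perf_counter()
import jax, jax.numpy as jnp              # t_import ends here
t1 = time.perf_counter()

@jax.jit
def f(x):
    return (x * x + 1.0).sum()

x = jnp.ones((1024, 1024))
f(x).block_until_ready()                  # trace + compile + first execution
t2 = time.perf_counter()
print(json.dumps({"t_import": t1 - t0, "ttfr": t2 - t0}))
"""

def one_cold_trial():
    out = subprocess.run([sys.executable, "-c", CHILD],
                         capture_output=True, text=True, check=True)
    return json.loads(out.stdout)

trials = [one_cold_trial() for _ in range(5)]
print("TTFR median (s):", statistics.median(t["ttfr"] for t in trials))
```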
2. Definitions: what we time
For publishable runs we report component timings, not just a single aggregate:
| Metric | What It Measures |
|---|---|
| t_import | Importing the runtime stack (best-effort proxy for "startup cost") |
| t_first_call | First call that triggers tracing + compilation + first execution |
| t_second_call | Second call on the same shapes (warm execution proxy) |
| t_steady_state | Distribution over repeated warm calls (with a declared sync policy) |
For GPU-backed runtimes, the benchmark must explicitly synchronize (or "block until ready") to avoid timing only CPU dispatch.
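As a sketch of these component timings for a JAX workload (the function and shapes are illustrative), with block_until_ready() as the declared sync policy so we never time only the asynchronous dispatch:

```python
# Component-timing sketch for a JAX workload (illustrative shapes/function).
# block_until_ready() is the sync point: it waits for device-side completion.
import time
import statistics

t0 = time.perf_counter()
import jax
import jax.numpy as jnp
t_import = time.perf_counter() - t0

@jax.jit
def step(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((2048, 2048), dtype=jnp.float32)

t0 = time.perf_counter()
step(x).block_until_ready()               # tracing + compilation + first execution
t_first_call = time.perf_counter() - t0

t0 = time.perf_counter()
step(x).block_until_ready()               # same shapes: warm-execution proxy
t_second_call = time.perf_counter() - t0

warm = []
for _ in range(50):                       # t_steady_state as a distribution
    t0 = time.perf_counter()
    step(x).block_until_ready()
    warm.append(time.perf_counter() - t0)

print({"t_import": t_import, "t_first_call": t_first_call,
       "t_second_call": t_second_call,
       "t_steady_state_median": statistics.median(warm)})
```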
3. Benchmark protocol (what must be specified)
- Whether the process is fresh (new process) or persistent
- Cache state (clean cache vs warmed cache)
- Dataset sizes and data layout
- What is included/excluded from the measurement window
For cold-start benchmarks, the only honest baseline is a fresh process with a declared cache policy.
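For concreteness, a run manifest under this protocol could declare those items explicitly. The field names below are illustrative, not the published schema.

```python
# Illustrative run manifest (hypothetical field names, not the published
# nextstat.* artifact schema): every protocol item is declared explicitly.
manifest = {
    "process": "fresh",                   # "fresh" or "persistent"
    "cache_state": "cold",                # "cold" or "warm"
    "dataset": {"shape": [2048, 2048], "dtype": "float32", "layout": "row-major"},
    "measurement_window": {
        "includes": ["import", "trace", "compile", "first_execution"],
        "excludes": ["data_generation", "result_serialization"],
    },
    "sync_policy": "block_until_ready",
}
```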
4. Cache policy: "cold" has multiple meanings
Compile-latency results are extremely sensitive to caching. So instead of pretending there is one "cold start", we publish explicit modes:
- Cold process, warm cache: new Python process, but a persistent compilation cache is allowed
- Cold process, cold cache: new Python process, and the compilation cache directory is empty
- Warm process: same long-lived process (typical for interactive analysis)
If we can't reliably clear a cache (because the runtime stores it outside our control), we treat that as a constraint and publish the limitation.
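A sketch of how the first two modes above can be driven for JAX, assuming the persistent-compilation-cache option (jax_compilation_cache_dir) exposed by recent JAX releases; the cache path and mode names are illustrative.

```python
# Cache-mode sketch for JAX (illustrative path and mode names). Assumes the
# jax_compilation_cache_dir option of JAX's persistent compilation cache;
# caches stored outside this option (e.g. vendor kernel caches) are not covered.
import shutil
import tempfile
from pathlib import Path

CACHE_DIR = Path(tempfile.gettempdir()) / "nextstat-jax-cache"

def configure_cache(mode: str) -> None:
    import jax
    if mode == "cold_process_cold_cache":
        shutil.rmtree(CACHE_DIR, ignore_errors=True)      # start from an empty cache dir
    if mode in ("cold_process_cold_cache", "cold_process_warm_cache"):
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        jax.config.update("jax_compilation_cache_dir", str(CACHE_DIR))
    # "warm_process": reuse the already-running interpreter; nothing to configure here.
```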
5. What we will publish
For each snapshot:
- Cold-start distributions (not just a single timing)
- Warm-throughput distributions
- Baseline manifest (versions, hardware, settings)
- Cache policy and harness version
Publishing contract: Public Benchmarks. Validation pack artifact: Validation Report.
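Schematically, a snapshot bundles raw distributions with its manifest rather than a single aggregate. The field names below are hypothetical and not the published nextstat.* contracts.

```python
# Hypothetical snapshot structure (illustrative field names only).
snapshot = {
    "cold_start_s": [],        # raw per-trial TTFR seconds, one entry per fresh process
    "warm_call_s": [],         # raw per-call seconds from the warm regime
    "baseline_manifest": {},   # pinned versions, hardware, settings
    "cache_policy": "cold_process_cold_cache",
    "harness_version": "seed",
}
```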
6. Why this is part of the NextStat benchmark program
NextStat's core value proposition is not that it wins a microbenchmark; it's that entire scientific pipelines become faster, more reproducible, and easier to audit. Compile-vs-execution tradeoffs are part of that story when ML is inside the loop.
A. Appendix: seed harness status (today)
The public-benchmarks seed repo includes a runnable ML suite under benchmarks/nextstat-public-benchmarks/suites/ml/.
- Measures cold-start TTFR (import + first call) using multiple fresh processes
- Measures warm-call throughput as a per-call distribution
- Runs with NumPy by default and includes optional `jax_jit_cpu_*` cases: these record `warn` with `reason="missing_dependency: jax"` if JAX is not installed
- GPU-intended path: `jax_jit_gpu_*` cases run only on CUDA-capable runners. If the JAX backend cannot provide a GPU platform, the suite records `warn` with `reason="gpu_unavailable"`
JAX dependency templates live in benchmarks/nextstat-public-benchmarks/env/python/requirements-ml-jax-cpu.txt and benchmarks/nextstat-public-benchmarks/env/python/requirements-ml-jax-cuda12.txt. The suite records best-effort device metadata via the runtime (e.g. JAX device platform, kind, and count). For JAX GPU compilation, the suite also tries to prefer a CUDA toolkit ptxas if one is present on the host.
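The device-metadata part can be approximated with JAX's public device API; the helper below is a sketch, and the field names in the published artifacts may differ.

```python
# Best-effort device metadata via jax.devices() (sketch; artifact field
# names may differ from these illustrative keys).
def jax_device_metadata():
    try:
        import jax
    except ImportError:
        return {"status": "missing_dependency: jax"}
    devices = jax.devices()
    return {
        "platform": devices[0].platform,        # e.g. "cpu", "gpu", "tpu"
        "device_kind": devices[0].device_kind,  # e.g. a specific accelerator model string
        "device_count": len(devices),
    }

print(jax_device_metadata())
```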
Reproducible seed run (suite runner writes per-case JSON + suite index):
```bash
python benchmarks/nextstat-public-benchmarks/suites/ml/suite.py \
  --deterministic \
  --out-dir benchmarks/nextstat-public-benchmarks/out/ml
```

Published JSON artifact contracts (ML suite):
- Per-case results: `nextstat.ml_benchmark_result.v1`
- Suite index: `nextstat.ml_benchmark_suite_result.v1`
