NextStat

NextStat v0.9.6: Zero-JIT Tax, ESS/grad, and Convergence

Final canonical results with strict backend split: Metal, CUDA V100, and EPYC CPU.

NextStat · MAMS · LAPS · BlackJAX · Benchmarks · Bayesian Inference

2026-02-18 · 18 min read


TL;DR (only final v0.9.6 numbers)

  • LAPS Metal: final matrix is 8/8 ok with Div%=0 across all cases.
  • CUDA V100 parity (3-seed median, canonical): NextStat LAPS keeps zero runtime JIT tax (cold ≈ warm), while BlackJAX cold-start is 11.8–90.6 s in this setup.
  • Time-to-result in real edit cycles: when model structure/shape changes (priors, parameterization, dimensions), JAX/XLA workflows commonly recompile; that compile wall repeats across iterations. NextStat AOT kernels keep iteration latency close to warm-path behavior.
  • ESS/grad (V100 sampling, report-chain normalized): on matched targets, NS LAPS ranges from 2.46× to 45.11× vs BlackJAX in this canonical run.
  • CPU funnel fairness fixed: FunnelNcpModel (NCP) is 6/6 ok across 3 seeds on EPYC for both MAMS and NUTS; centered funnel remains a known pathological control.

0. MAMS and LAPS in one page (for NUTS users)

If you already know NUTS, the key mental model is:

  • NUTS is dynamic HMC with per-transition tree expansion (adaptive path length each step).
  • MAMS (Metropolis Adjusted Microcanonical Sampler) uses microcanonical/isokinetic dynamics with a fixed trajectory length in preconditioned space.
  • LAPS is the massively parallel GPU/Metal sampler path that applies MAMS-style dynamics across thousands of chains with hardware-oriented execution.
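For readers who think in code, the contrast can be sketched in a few lines: a fixed-shape transition runs an identical, data-independent loop on every chain, which is what lets thousands of chains execute in lockstep. The sketch below is plain fixed-length HMC for illustration only, not NextStat's microcanonical kernel; all names are hypothetical.

```python
import math
import random

def fixed_length_transition(x, logp, grad, eps, n_steps, rng):
    """One Metropolis-adjusted transition with a FIXED number of leapfrog
    steps. Every chain executes the identical loop, with no data-dependent
    recursion -- the property that makes fixed-shape kernels SIMD/GPU
    friendly. (Plain HMC for illustration, not microcanonical dynamics.)"""
    p = rng.gauss(0.0, 1.0)                     # fresh momentum
    x_new, p_new = x, p
    for _ in range(n_steps):                    # fixed trip count
        p_new += 0.5 * eps * grad(x_new)
        x_new += eps * p_new
        p_new += 0.5 * eps * grad(x_new)
    h_old = -logp(x) + 0.5 * p * p              # Hamiltonian before / after
    h_new = -logp(x_new) + 0.5 * p_new * p_new
    if math.log(rng.random()) < h_old - h_new:  # Metropolis correction
        return x_new, True
    return x, False
```

A NUTS transition, by contrast, grows a trajectory tree whose depth depends on the data seen so far, so neighboring chains diverge in control flow.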

What is new relative to standard NUTS implementations

  • Fixed-shape transition kernels are easier to run efficiently on SIMD/GPU backends than recursive tree building.
  • The sampler is built for very large chain parallelism (4096 chains is a normal operating point in this report).
  • In this release, MAMS/LAPS use microcanonical-aware diagnostics gates and explicit NCP-vs-centered disclosure for funnel-like geometries.

Where this design is strongest

  • Time-to-result in iterative workflows (no repeated JIT compile wall in our AOT path).
  • Hierarchical and multi-scale targets where we observe high ESS/grad in the matched-target runs.
  • Local acceleration paths (Metal) and server GPU paths (CUDA) with the same sampler semantics.

Where NUTS or XLA-heavy stacks can still win

  • Simple low-dimensional targets after warm compilation (higher raw warm-path throughput is possible).
  • Some smooth concentrated posteriors on CPU (for example large-n logistic in this report).

This report is therefore not claiming a universal sampler winner; it documents where each execution model is stronger, with explicit fairness caveats.


1. Protocol and fairness rules

Backend split (no mixing)

  • LAPS Metal results are reported separately from LAPS CUDA.
  • CPU (EPYC) results are reported separately from GPU.
  • No cross-backend table mixes Metal and CUDA values.

Multi-run aggregation (anti-cherry-pick)

  • Final CPU/GPU comparison tables use three independent seeds: 42, 123, 777.
  • We report median as the primary number (robust to outliers), with mean ± std shown where useful.
  • Single-seed values are kept only in raw artifacts, not as headline claims.

Time-to-result in iterative modeling

  • Bayesian work is an edit cycle (change prior, add covariate, reparameterize, rerun).
  • In JIT/XLA stacks, graph/shape changes often invalidate compiled executables and trigger recompilation, so cold-start costs recur during exploration.
  • NextStat uses AOT-compiled Rust/CUDA kernels, so wall-clock iteration time stays near warm-path latency even as models evolve.
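The edit-cycle arithmetic can be made concrete with a toy cost model. The function and its inputs are illustrative accounting only (not a NextStat or JAX API); the timings are borrowed from the section-3 std_normal_10d row, with the compile wall approximated as cold minus warm.

```python
def time_to_result(n_edits, run_s, compile_s, recompile_every_edit):
    """Toy edit-cycle cost model: each model edit reruns the sampler.
    A JIT stack may pay the compile wall again whenever shapes change;
    an AOT path pays (approximately) nothing extra per edit."""
    per_edit = run_s + (compile_s if recompile_every_edit else 0.0)
    return n_edits * per_edit

# std_normal_10d on V100 (section 3): warm ~0.225 s, cold-warm gap ~13.8 s
jit_total = time_to_result(10, 0.225, 13.839, recompile_every_edit=True)   # ~140 s
aot_total = time_to_result(10, 0.240, 0.0, recompile_every_edit=False)     # ~2.4 s
```

Ten shape-changing edits at these rates cost minutes under recurring compilation versus seconds on the AOT path, which is the asymmetry the bullet describes.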

Funnel parameterization disclosure

For std_normal, eight_schools, and glm_logistic, both engines sample the same target density (identical log-density functions).

For neal_funnel_10d, the parameterizations differ in the V100 parity run:

  • NS LAPS samples the Non-Centered Parameterization (NCP): log p(v, z) = -v²/18 - 0.5 · Σ(z_i²).
  • BlackJAX samples the centered parameterization: log p(v, x) = -v²/18 - 0.5·exp(-v)·Σ(x_i²) - 0.5·(d-1)·v.

These are not the same optimization problem. The centered funnel has position-dependent curvature that is fundamentally harder for fixed-metric samplers. The neal_funnel rows in section 3 and the Appendix therefore reflect both algorithmic and parameterization differences and should not be interpreted as a like-for-like throughput comparison. They are retained to show convergence behavior (NS converges, BlackJAX does not) but excluded from headline ESS/grad claims.
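The two densities above are easy to check numerically. This sketch implements both log densities exactly as written and verifies that they agree up to the (d-1)·v/2 log-Jacobian of the change of variables x_i = z_i·exp(v/2); the function names are ours, not engine code.

```python
import math

D = 10  # neal_funnel_10d: v plus d-1 = 9 latent coordinates

def logp_centered(v, x):
    """Centered: log p(v, x) = -v^2/18 - 0.5*exp(-v)*sum(x_i^2) - 0.5*(d-1)*v."""
    return (-v * v / 18.0
            - 0.5 * math.exp(-v) * sum(xi * xi for xi in x)
            - 0.5 * (D - 1) * v)

def logp_ncp(v, z):
    """Non-centered: log p(v, z) = -v^2/18 - 0.5*sum(z_i^2)."""
    return -v * v / 18.0 - 0.5 * sum(zi * zi for zi in z)

# The parameterizations are linked by x_i = z_i * exp(v/2); the log-Jacobian
# of that map is (d-1)*v/2, so the densities agree up to exactly that term:
v, z = 1.3, [0.4] * (D - 1)
x = [zi * math.exp(v / 2.0) for zi in z]
assert math.isclose(logp_centered(v, x), logp_ncp(v, z) - 0.5 * (D - 1) * v)
```

Same target distribution, but the NCP surface the sampler actually traverses is a fixed-curvature Gaussian, while the centered surface has curvature that varies with v.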

  • CPU now has explicit FunnelNcpModel for fair NCP comparisons (section 6).
  • Centered FunnelModel remains a separate hard-geometry control.

Algorithmic changes in v0.9.6

  • MAMS uses eps_jitter=0.1 by default (±10% uniform step-size noise per transition), breaking fixed-L periodicity and improving tail ESS on periodic targets like std_normal.
  • Default trajectory length: L = √d in preconditioned space (Robnik et al. 2023).
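As a sketch, the two defaults amount to the following (function names are illustrative, not NextStat API):

```python
import math
import random

def jittered_step(eps_base, rng, eps_jitter=0.1):
    """Per-transition step size with +/-10% uniform noise (eps_jitter=0.1
    default), breaking the resonance a fixed step size can develop with a
    fixed trajectory length on periodic targets like std_normal."""
    return eps_base * (1.0 + rng.uniform(-eps_jitter, eps_jitter))

def default_trajectory_length(d):
    """Default trajectory length L = sqrt(d) in preconditioned space
    (Robnik et al.)."""
    return math.sqrt(d)
```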

BlackJAX configuration (V100 parity run)

To preempt concerns about competitor misconfiguration, the full BlackJAX setup:

  • Sampler: blackjax.adjusted_mclmc with isokinetic_mclachlan integrator.
  • Warmup: built-in blackjax.adjusted_mclmc_find_L_and_step_size (500 iterations, single-chain warmup, target_accept=0.9, diagonal_preconditioning=True).
  • Trajectory length: tuned by BlackJAX warmup (L, step_size), then n_steps = round(L / step_size).
  • Mass matrix: sampling uses tuned inverse_mass_matrix from BlackJAX warmup.
  • Multi-chain: 4096 chains, jax.vmap(run_chain), block_until_ready() + device_get() for fair host-side timing.
  • Cold/warm: cold = first vmap call (includes XLA compilation); warm = second call with cached JIT.
  • Init: chains are initialized around the warmed single-chain state (warmed_state.position + N(0, 0.5)).
  • Seeds: 42, 123, 777; for each seed, the warm run uses a seed + 1000 key path (e.g. cold seed 42 → warm seed 1042).
  • Source: benchmarks/gpu_triple_bench.py, functions _blackjax_builtin_warmup() and bench_blackjax().

V100 parity run config (NS LAPS, 3 seeds)

  • n_chains=4096, n_warmup=500, n_samples=1000, report_chains=256, seeds=42/123/777.
  • Section 3/4 report median across 3 seeds.
  • R̂ computed from 256 report chains (512 half-chains), giving materially tighter diagnostics than the earlier 64-chain reporting.
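For reference, a minimal pure-Python split-R̂ in the Gelman et al. style, showing how 256 report chains become 512 half-chains; this is an illustrative re-derivation, not NextStat's implementation.

```python
import random
from statistics import mean, variance

def split_rhat(chains):
    """Split-Rhat: halve each chain, then compare between- and within-
    half-chain variance. 256 report chains -> 512 half-chains, which is
    what tightens the diagnostic relative to 64-chain reporting."""
    halves = []
    for c in chains:
        m = len(c) // 2
        halves += [c[:m], c[m:2 * m]]
    n = len(halves[0])
    means = [mean(h) for h in halves]
    b = n * variance(means)                    # between-half-chain variance
    w = mean(variance(h) for h in halves)      # within-half-chain variance
    var_plus = (n - 1) / n * w + b / n         # pooled posterior-variance estimate
    return (var_plus / w) ** 0.5

rng = random.Random(42)
good = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(8)]
```

Well-mixed chains give a value near 1.0; chains stuck in different regions inflate the between-chain term and push R̂ well above 1.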

2. Canonical LAPS Metal results (final)

Hardware: Apple M5, 10 GPU cores, 24 GB unified memory.

| Model | Chains | w+s | Wall (s) | R̂ | ESS/s | Div% | Status |
| --- | --- | --- | --- | --- | --- | --- | --- |
| std_normal_10d | 256 | 100+100 | 0.14 | 1.175 | 3,680 | 0.0 | ok |
| std_normal_10d_4096ch | 4096 | 200+500 | 0.09 | 1.038 | 12,585 | 0.0 | ok |
| eight_schools | 4096 | 500+2000 | 0.25 | 1.007 | 124,705 | 0.0 | ok |
| neal_funnel_10d | 4096 | 500+2000 | 0.31 | 1.006 | 22,791 | 0.0 | ok |
| neal_funnel_riemannian | 4096 | 500+2000 | 0.27 | 1.010 | 14,142 | 0.0 | ok |
| glm_logistic_n200_p6 | 4096 | 500+2000 | 2.15 | 1.005 | 4,647 | 0.0 | ok |
| glm_logistic_n1000_p20 | 4096 | 500+2000 | 34.32 | 1.010 | 248 | 0.0 | ok |
| glm_logistic_n5000_p20 | 4096 | 500+2000 | 59.06 | 1.015 | 110 | 0.0 | ok |

Note: the 256-chain std_normal_10d row (R̂ 1.175) demonstrates the minimum viable chain count; the 4096-chain row is the canonical benchmark configuration.

In practice, this shows that local Apple Silicon can run datacenter-style massively parallel inference workloads with strong convergence diagnostics, without CUDA setup or JIT compile latency.

Quality gate policy used for this matrix:

  • MAMS/LAPS: QualityGates::microcanonical() (EBFMI is warning-only).
  • NUTS: strict default gate preserved (EBFMI fail < 0.20).

3. CUDA V100 parity run (LAPS vs BlackJAX, 3-seed median)

Hardware: Tesla V100-PCIE-16GB.

| Model | Engine | Cold (s) | Warm (s) | min ESS | ESS/s (warm) | R̂ |
| --- | --- | --- | --- | --- | --- | --- |
| std_normal_10d | NS LAPS GPU | 1.554 | 0.240 | 159,753 | 680,785 | 1.0062 |
| std_normal_10d | BlackJAX GPU | 14.064 | 0.225 | 1,771 | 7,847 | 1.1010 |
| eight_schools | NS LAPS GPU | 1.425 | 0.241 | 75,682 | 314,476 | 1.0065 |
| eight_schools | BlackJAX GPU | 11.769 | 0.346 | 28,020 | 75,255 | 1.0080 |
| neal_funnel_10d | NS LAPS GPU | 1.404 | 0.259 | 54,768 | 211,581 | 1.0083 |
| neal_funnel_10d | BlackJAX GPU | 15.517 | 0.412 | 706 | 1,759 | 1.2732 |
| glm_logistic | NS LAPS GPU | 23.791 | 9.254 | 77,852 | 8,415 | 1.0086 |
| glm_logistic | BlackJAX GPU | 90.615 | 77.765 | 19,583 | 226 | 1.0122 |

Reading this table

  • Zero JIT tax: NS LAPS cold remains close to warm (AOT-compiled Rust/CUDA). BlackJAX cold-start is materially higher in this setup (11.8–90.6 s).
  • Warm-start throughput (canonical run): NS LAPS is higher on all matched targets in this setup.
  • neal_funnel is not a like-for-like comparison (see section 1: NS samples NCP, BlackJAX samples centered). In these 3 seeds, BlackJAX centered-funnel R̂ ranges 1.260–1.275 and remains weaker than NS NCP, which is expected from parameterization difficulty, not a sampler defect.

4. ESS/grad on V100 (sampling phase, matched targets only, 3-seed median)

| Model | NS LAPS ESS/grad | BlackJAX ESS/grad | Ratio (NS/BJ) |
| --- | --- | --- | --- |
| std_normal_10d | 0.312017 | 0.006917 | 45.11× |
| eight_schools | 0.098544 | 0.040104 | 2.46× |
| glm_logistic | 0.101370 | 0.002638 | 38.43× |

neal_funnel is excluded from this table because the two engines sample different parameterizations (see section 1).

A major contributor to the change vs earlier drafts is denominator normalization: both engines now compute ESS/grad on the same report_chains budget.
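The normalization can be sketched as follows. The gradient-count model (one gradient per integrator step per reported chain) and the steps_per_transition value below are illustrative assumptions for the sketch, not NextStat's exact counters.

```python
def ess_per_grad(min_ess, n_transitions, steps_per_transition, report_chains):
    """ESS per gradient evaluation, with BOTH engines charged on the SAME
    report-chain budget -- the denominator fix described above. The
    accounting model here is illustrative, not engine internals."""
    n_grads = n_transitions * steps_per_transition * report_chains
    return min_ess / n_grads

# Hypothetical accounting: 1000 sampling transitions, 2 gradient evals per
# transition, 256 report chains, min ESS from the section-3 std_normal row.
ratio = ess_per_grad(159753, 1000, 2, 256)
```

The key point is that min_ess and n_grads must be measured over the same set of chains; mixing a 4096-chain ESS with a 256-chain gradient count (or vice versa) distorts the ratio by more than an order of magnitude.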

The practical interpretation for this canonical run is:

  • NS LAPS achieves higher ESS/grad across all matched targets reported here.
  • glm_logistic remains the most expensive target for both engines in absolute wall time.

5. LAPS quality verification on V100 (report_chains=256)

Separate run with tighter diagnostics (report_chains=256 → 512 half-chains → SE(R̂) ≈ 0.015).

| Model | R̂ max | ESS_tail min | E-BFMI | Status |
| --- | --- | --- | --- | --- |
| StdNormal 10d | 1.0175 | 18,947 | 1.035 | ok |
| NealFunnel NCP 10d | 1.0126 | 48,202 | 0.970 | ok |
| GLM n=5000 p=20 | 1.0149 | 49,660 | 0.863 | ok |
| GLM n=200 p=6 | 1.0044 | 55,423 | 0.449 | ok |
| NealFunnel centered 10d | 1.2914 | 257 | 0.000 | fail (expected control) |

This confirms that LAPS convergence is solid when measured with sufficient diagnostic chains. The parity-run R̂ values (section 3, report_chains=256) are directly comparable to the quality run.


6. CPU EPYC (MAMS vs NUTS) and funnel parity fix

Hardware: AMD EPYC 7502P, 32 cores / 64 threads, 128 GB RAM (Hetzner dedicated).

EPYC multi-seed summary (42/123/777, 3-run aggregate)

Config: n_chains=4, n_warmup=1000, n_samples=1000, eps_jitter=0.1.

| Model | MAMS ESS/s (median) | MAMS (mean ± std) | NUTS ESS/s (median) | NUTS (mean ± std) | Ratio |
| --- | --- | --- | --- | --- | --- |
| std_normal_d2 | 129,592 | 137,761 ± 75,444 | 200,841 | 200,329 ± 13,460 | 0.645 |
| std_normal_d10 | 100,420 | 103,641 ± 4,692 | 85,159 | 95,604 ± 15,815 | 1.179 |
| std_normal_d50 | 13,007 | 13,150 ± 867 | 28,305 | 26,113 ± 3,638 | 0.460 |
| eight_schools | 98,201 | 93,408 ± 8,227 | 48,577 | 46,018 ± 5,781 | 2.022 |
| logreg_n1000_p10 | 714 | 711 ± 10 | 3,896 | 3,914 ± 28 | 0.183 |
| logreg_n5000_p20 | 37 | 36 ± 4 | 186 | 190 ± 11 | 0.200 |

Observed pattern in this real-run matrix:

| Case | dim | n_data | Ratio MAMS/NUTS | Leader |
| --- | --- | --- | --- | --- |
| std_normal_d2 | 2 | — | 0.645 | NUTS |
| eight_schools | 10 | 8 | 2.022 | MAMS |
| std_normal_d10 | 10 | — | 1.179 | MAMS |
| std_normal_d50 | 50 | — | 0.460 | NUTS |
| logreg_n1000_p10 | 10 | 1000 | 0.183 | NUTS |
| logreg_n5000_p20 | 20 | 5000 | 0.200 | NUTS |

Why large-n logistic favors NUTS in this CPU protocol

  • Gradient cost scales with O(n·p) per leapfrog step; with n=5000, p=20, each extra step is expensive.
  • NUTS can terminate trajectories early via U-turn, while MAMS uses fixed trajectory length in preconditioned space.
  • As n grows, posterior geometry becomes closer to well-conditioned Gaussian; this is a strong regime for NUTS with adaptive path length.
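The O(n·p) point in the first bullet as a back-of-the-envelope: the gradient of the logistic log-likelihood needs X·β and then Xᵀ·(sigmoid(X·β) − y), each touching all n·p design-matrix entries. The constant factor below is illustrative.

```python
def logistic_grad_flops(n, p):
    """Rough cost of one log-likelihood gradient for logistic regression:
    two full passes over the n x p design matrix, so per-leapfrog-step
    cost grows as O(n * p). The factor of 2 is illustrative."""
    return 2 * n * p

small = logistic_grad_flops(1000, 10)   # logreg_n1000_p10
large = logistic_grad_flops(5000, 20)   # logreg_n5000_p20: 10x the work per step
```

At 10× the per-step cost, every integrator step a sampler can skip matters, which is exactly where NUTS's early U-turn termination pays off.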

Practical recommendation

  • Prefer MAMS for hierarchical / multi-scale geometries.
  • Prefer NUTS for large-n GLM-like posteriors on CPU.
  • A reasonable product direction is an explicit method="auto" heuristic (e.g. GLM with large n → NUTS; hierarchical/funnel-like targets → MAMS), while keeping manual override.
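A minimal sketch of such a heuristic follows; the choose_method name, thresholds, and categories are hypothetical illustrations of the section-6 pattern, not shipped behavior.

```python
def choose_method(model_kind, n_data=0, dim=0):
    """Hypothetical method="auto" dispatch: large-n GLM-like posteriors
    -> NUTS; hierarchical/funnel-like geometries (and everything else by
    default) -> MAMS. Thresholds are illustrative; a manual override
    would always remain available."""
    if model_kind == "glm" and n_data >= 1000:
        return "nuts"
    return "mams"
```

Keeping the rule this simple makes it auditable: a user can always see why "auto" picked a sampler and pin the other one explicitly.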

Funnel parameterization control (EPYC, 3 seeds)

Config: n_chains=4, n_warmup=1000, n_samples=1000.

MAMS:

| Model | Seed | R̂ | ESS_tail | EBFMI | Status |
| --- | --- | --- | --- | --- | --- |
| Centered (FunnelModel) | 42 | 1.0785 | 221 | n/a | ok |
| Centered (FunnelModel) | 123 | 1.0353 | 31 | n/a | fail |
| Centered (FunnelModel) | 777 | 1.0781 | 244 | n/a | ok |
| NCP (FunnelNcpModel) | 42 | 1.0067 | 1,914 | n/a | ok |
| NCP (FunnelNcpModel) | 123 | 1.0100 | 1,897 | n/a | ok |
| NCP (FunnelNcpModel) | 777 | 1.0048 | 1,924 | n/a | ok |

NUTS:

| Model | Seed | R̂ | ESS_tail | EBFMI | Status |
| --- | --- | --- | --- | --- | --- |
| Centered (FunnelModel) | 42 | 2.3844 | 14 | n/a | fail |
| Centered (FunnelModel) | 123 | 1.3636 | 72 | n/a | fail |
| Centered (FunnelModel) | 777 | 1.9480 | 17 | n/a | fail |
| NCP (FunnelNcpModel) | 42 | 1.0026 | 2,516 | n/a | ok |
| NCP (FunnelNcpModel) | 123 | 1.0027 | 1,604 | n/a | ok |
| NCP (FunnelNcpModel) | 777 | 1.0024 | 2,385 | n/a | ok |

Interpretation:

  • NCP is 6/6 ok across all seeds for both MAMS and NUTS. ESS_tail ranges 1,604–2,516 (NUTS) and 1,897–1,924 (MAMS).
  • Centered is 3/3 fail for NUTS and 1/3 fail for MAMS.
  • The previous CPU funnel mismatch was methodological (centered vs NCP), not a "CPU is weak" issue.
  • FunnelNcpModel is the recommended benchmark parameterization for CPU/GPU parity.
  • Centered FunnelModel is kept as a known pathological control; this is a limitation demonstration, not a product regression.
  • In these EPYC funnel-control artifacts, EBFMI is not exported (n/a in the tables), so pass/fail here is based on R̂/ESS quality gates.

7. Reproducibility and environment metadata

  • V100 benchmark JSON includes top-level environment snapshot (python, jax, cuda, gpu, package versions).
  • EPYC suite stores hardware/config/seed metadata and per-case metrics; full package-level environment snapshot is currently only present in the V100 parity JSON.

Artifacts (all in docs/blog/artifacts/v096-zero-jit-tax/):

  • V100 3-seed matrix (canonical): v100-multi-seed-matrix-canonical.json
  • V100 chart data (canonical): v100-parity-chart-data-canonical.csv, v100-essgrad-ratio-canonical.csv
  • V100 raw 3-seed parity run (canonical): v100_v096_builtinwarmup_3seed_20260218T224654Z/seed_42/gpu_triple_bench.json, seed_123/..., seed_777/...
  • V100 funnel addendum raw runs: v100_ns_funnel_3seed_20260218T231337Z/*, v100_bj_funnel_builtin3seed_20260218T231204Z/*
  • V100 quality run: v100-quality-report256-5models.json
  • V100 + EPYC refresh note: 2026-02-17-v096-refresh-v100-epyc.md
  • EPYC multi-seed matrix: epyc-multi-seed-matrix.json
  • EPYC suite output: epyc-mams-suite.json
  • EPYC funnel-control: epyc-funnel-control-3seed.json

A. Appendix: V100 neal_funnel (different parameterizations)

Retained for transparency. These rows compare NS LAPS (NCP) against BlackJAX (centered) — not a like-for-like comparison.

| Metric | NS LAPS (NCP) | BlackJAX (centered) |
| --- | --- | --- |
| Cold (s) | 1.404 | 15.517 |
| Warm (s) | 0.259 | 0.412 |
| min_ESS | 54,768 | 706 |
| ESS/s (warm) | 211,581 | 1,759 |
| R̂ | 1.0083 | 1.2732 |
| ESS/grad | 0.071312 | 0.000710 |

BlackJAX's non-convergence on the centered funnel is expected (see section 6: even NUTS fails 3/3 on centered funnel with 4 chains and standard budget). This comparison primarily demonstrates that NS's default NCP dispatch produces converged results where the centered parameterization does not.


References

  • Robnik, Cohn-Gordon, Seljak. Metropolis Adjusted Microcanonical Hamiltonian Monte Carlo (MAMS). arXiv:2503.01707
  • BlackJAX. Composable Bayesian inference in JAX. arXiv:2402.10797
