NextStat

Neural Surrogate Distillation

Train Nanosecond Surrogates from NextStat's Likelihood

NextStat can serve as a high-fidelity oracle for training neural network surrogates of the HistFactory likelihood surface. The surrogate runs in nanoseconds instead of milliseconds, enabling real-time MCMC, global EFT fits, or interactive dashboards.

The Workflow

Sample parameter space (Sobol / LHS) → Evaluate NLL + gradient via NextStat GPU → Train a small MLP → Deploy the surrogate for real-time inference. NextStat provides the ground truth; the surrogate provides the speed.

Quick Start

import numpy as np
import nextstat
from nextstat.distill import generate_dataset, train_mlp_surrogate, predict_nll

# workspace_json: a pyhf workspace specification (e.g. read from a workspace JSON file)
model = nextstat.from_pyhf(workspace_json)

# 1. Generate 100k (params, NLL, gradient) tuples
ds = generate_dataset(model, n_samples=100_000, method="sobol")
print(f"{ds.n_samples} points, {ds.n_params} params")
print(f"NLL range: [{ds.nll.min():.1f}, {ds.nll.max():.1f}]")

# 2. Train a surrogate MLP (built-in convenience)
surrogate = train_mlp_surrogate(ds, epochs=100, device="cuda")

# 3. Predict NLL at new points (nanoseconds per eval)
test_params = np.array(model.parameter_init())
pred_nll = predict_nll(surrogate, test_params)
print(f"Surrogate NLL: {pred_nll:.2f}")

Sampling Methods

| Method | Coverage | Best for |
|---|---|---|
| sobol | Quasi-random, low discrepancy | Default. Best coverage with the fewest points. |
| lhs | Stratified per dimension | Good coverage, no power-of-2 requirement. |
| uniform | Pure random | Baseline comparison. |
| gaussian | Concentrated near the MLE | Fine-tuning near the minimum; focused surrogates. |
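To make the table concrete, here is a sketch of what each strategy does, written with SciPy's scipy.stats.qmc module rather than NextStat. The helper sample_box and its center/scale arguments are hypothetical, and nothing here is claimed to match the internals of generate_dataset. The space-filling methods draw in the unit hypercube and are then mapped onto the parameter bounds; gaussian instead samples around a given center (for example the MLE) and clips back into the box.

```python
import numpy as np
from scipy.stats import qmc


def sample_box(method, n, lower, upper, center=None, scale=None, seed=0):
    """Draw n points inside the box [lower, upper] (hypothetical helper, sketch only).

    center and scale are only used (and required) for method="gaussian".
    """
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    d = lower.size
    if method == "sobol":
        # Scrambled Sobol' points; balance properties are best when n is a power of 2
        unit = qmc.Sobol(d=d, scramble=True, seed=seed).random(n)
    elif method == "lhs":
        # Latin hypercube: exactly one point per stratum in every dimension, any n
        unit = qmc.LatinHypercube(d=d, seed=seed).random(n)
    elif method == "uniform":
        # Independent uniform draws, the baseline
        unit = np.random.default_rng(seed).random((n, d))
    elif method == "gaussian":
        # Normal draws around a center (e.g. the MLE), clipped back into the box
        rng = np.random.default_rng(seed)
        pts = rng.normal(center, scale, size=(n, d))
        return np.clip(pts, lower, upper)
    else:
        raise ValueError(f"unknown method: {method!r}")
    # Map from the unit hypercube [0, 1)^d onto the parameter box
    return qmc.scale(unit, lower, upper)
```

Comparing qmc.discrepancy on the unit-scaled points (lower is better) is a quick way to see why sobol is the default: for the same n it typically covers the box more evenly than uniform.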

Custom Training Loop

For production use, convert the dataset to PyTorch and write your own training loop:

from nextstat.distill import generate_dataset, to_torch_dataset
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

ds = generate_dataset(model, n_samples=500_000, method="sobol")
train_ds = to_torch_dataset(ds)  # yields (params, nll, grad) tuples
loader = torch.utils.data.DataLoader(train_ds, batch_size=4096, shuffle=True)

surrogate = torch.nn.Sequential(
    torch.nn.Linear(ds.n_params, 256), torch.nn.SiLU(),
    torch.nn.Linear(256, 256), torch.nn.SiLU(),
    torch.nn.Linear(256, 1),
).to(device)

optimizer = torch.optim.Adam(surrogate.parameters(), lr=1e-3)
grad_weight = 0.1  # weight of the gradient (Sobolev) term; set to 0.0 to disable it

for epoch in range(100):
    for params_batch, nll_batch, grad_batch in loader:
        # Track gradients w.r.t. the inputs only when the Sobolev term is used
        params_batch = params_batch.to(device).requires_grad_(grad_weight > 0)
        nll_batch = nll_batch.to(device)
        grad_batch = grad_batch.to(device)

        # squeeze(-1) keeps the batch dimension even for a final batch of size 1
        pred = surrogate(params_batch).squeeze(-1)
        loss = F.mse_loss(pred, nll_batch)

        # Optional: gradient-informed training (Sobolev loss) matching dNLL/dparams.
        # Reuses the same forward pass; create_graph=True lets loss.backward()
        # differentiate through the predicted gradient.
        if grad_weight > 0:
            pred_grad = torch.autograd.grad(pred.sum(), params_batch, create_graph=True)[0]
            loss = loss + grad_weight * F.mse_loss(pred_grad, grad_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
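Written out, the objective in the loop above is a Sobolev-style loss: a value-matching term plus a gradient-matching term weighted by 0.1. For a batch of B points with d parameters, surrogate f_φ, NLL targets y_i and NextStat gradients g_i, and with F.mse_loss averaging over every element (B values for the first term, B·d gradient components for the second), it reads:

```latex
\mathcal{L}(\phi)
  = \frac{1}{B}\sum_{i=1}^{B}\bigl(f_\phi(\theta_i) - y_i\bigr)^2
  + \lambda\,\frac{1}{B\,d}\sum_{i=1}^{B}\bigl\lVert \nabla_\theta f_\phi(\theta_i) - g_i \bigr\rVert_2^2,
  \qquad \lambda = 0.1
```

The second term is why the inputs need requires_grad: the gradient of f_φ with respect to θ must itself be differentiable with respect to the network weights φ.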

Export Formats

| Function | Format | Use case |
|---|---|---|
| to_torch_dataset(ds) | TensorDataset | PyTorch DataLoader for training |
| to_numpy(ds) | dict of ndarray | SciPy, scikit-learn, JAX |
| to_npz(ds, path) | .npz (compressed) | Persistent storage, reproducibility |
| to_parquet(ds, path) | .parquet (zstd) | Polars, DuckDB, Spark pipelines |
| from_npz(path) | .npz → Dataset | Reload a previously saved dataset |
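Assuming to_npz writes a standard NumPy archive, a saved file can also be opened with plain np.load, which is handy for inspecting a dataset without NextStat installed. The array names that to_npz actually uses are not documented here, so the sketch below uses hypothetical keys (params, nll, grad) and random stand-in data purely to show a compressed save/reload round trip; with a real file, list data.files to see the actual keys.

```python
import os
import tempfile

import numpy as np

# Hypothetical arrays standing in for a distillation dataset (not real NextStat output)
rng = np.random.default_rng(0)
arrays = {
    "params": rng.random((1000, 5)),
    "nll": rng.random(1000),
    "grad": rng.random((1000, 5)),
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "distill_dataset.npz")
    # savez_compressed stores each array as a deflate-compressed .npy inside a zip
    np.savez_compressed(path, **arrays)
    # np.load returns a lazy NpzFile; copy the arrays out before the file is closed
    with np.load(path) as data:
        reloaded = {key: data[key] for key in data.files}
```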

Validation

Always validate the surrogate against NextStat's exact computation:

import numpy as np
from nextstat.distill import predict_nll

# Compare surrogate vs exact at random points
test_params = np.random.default_rng(99).uniform(
    ds.parameter_bounds[:, 0], ds.parameter_bounds[:, 1],
    size=(1000, ds.n_params)
)

pred = predict_nll(surrogate, test_params)
exact = np.array([model.nll(p.tolist()) for p in test_params])

rmse = np.sqrt(np.mean((pred - exact) ** 2))
max_err = np.max(np.abs(pred - exact))
print(f"RMSE: {rmse:.4f}, Max error: {max_err:.4f}")
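Uniform test points mostly land far from the minimum, so the global RMSE can hide errors in the region that drives intervals and limits. A minimal sketch, using a hypothetical helper error_report, reports the same RMSE and max error both overall and on the subset of test points whose exact NLL lies within delta_nll_cut of the best exact value among them. The default cut of 5.0 NLL units is an illustrative assumption, not a NextStat default.

```python
import numpy as np


def error_report(pred, exact, delta_nll_cut=5.0):
    """Surrogate-vs-exact NLL errors, overall and near the best test point.

    delta_nll_cut is an illustrative threshold in NLL units (not a NextStat default).
    """
    pred = np.asarray(pred, dtype=float)
    exact = np.asarray(exact, dtype=float)
    err = pred - exact
    # Points whose exact NLL is within delta_nll_cut of the best exact value;
    # always non-empty because it includes the best point itself
    near = (exact - exact.min()) < delta_nll_cut
    return {
        "rmse": float(np.sqrt(np.mean(err ** 2))),
        "max_err": float(np.max(np.abs(err))),
        "n_near": int(near.sum()),
        "rmse_near": float(np.sqrt(np.mean(err[near] ** 2))),
        "max_err_near": float(np.max(np.abs(err[near]))),
    }
```

With the arrays from the validation snippet, error_report(pred, exact) gives both views at once.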

When to Use Surrogates

  • MCMC with many parameters — the surrogate replaces expensive NLL calls in the inner loop of HMC/NUTS.
  • Interactive dashboards — real-time likelihood contours as the user drags sliders.
  • Global EFT fits — scanning 100+ Wilson coefficients where exact fits are prohibitive.
  • NOT recommended for final results — always validate with NextStat's exact fit. The surrogate is for exploration, the exact computation is for publication.
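To illustrate where a surrogate plugs into a sampler, here is a minimal random-walk Metropolis sketch; it is a simpler, gradient-free stand-in for HMC/NUTS, not either of those algorithms. The function metropolis is hypothetical and accepts any callable mapping a 1-D parameter array to an NLL value, so the exact NextStat NLL and a trained surrogate are interchangeable in the inner loop.

```python
import numpy as np


def metropolis(nll_fn, theta0, n_steps, step_size, seed=0):
    """Random-walk Metropolis targeting exp(-nll_fn(theta)) (illustrative sketch).

    nll_fn: any callable taking a 1-D array and returning a float NLL,
    e.g. the exact model NLL or a surrogate wrapped to return a scalar.
    """
    rng = np.random.default_rng(seed)
    theta = np.asarray(theta0, dtype=float).copy()
    nll = float(nll_fn(theta))
    chain = np.empty((n_steps, theta.size))
    n_accept = 0
    for i in range(n_steps):
        # Symmetric Gaussian proposal around the current point
        proposal = theta + step_size * rng.standard_normal(theta.size)
        nll_prop = float(nll_fn(proposal))  # the expensive call a surrogate replaces
        # Accept with probability min(1, exp(nll - nll_prop))
        if np.log(rng.random()) < nll - nll_prop:
            theta, nll = proposal, nll_prop
            n_accept += 1
        chain[i] = theta
    return chain, n_accept / n_steps
```

For example, metropolis(lambda t: 0.5 * float(t @ t), np.zeros(2), 40_000, 1.5) samples a standard 2-D Gaussian. With a trained surrogate one would pass a thin wrapper around predict_nll, assuming it returns a scalar for a single point; gradient-based samplers such as HMC/NUTS would additionally use the surrogate's autograd gradient.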