# Neural Surrogate Distillation

Train nanosecond surrogates from NextStat's likelihood.
NextStat can serve as a high-fidelity oracle for training neural network surrogates of the HistFactory likelihood surface. The surrogate runs in nanoseconds instead of milliseconds, enabling real-time MCMC, global EFT fits, or interactive dashboards.
## The Workflow
Sample parameter space (Sobol / LHS) → Evaluate NLL + gradient via NextStat GPU → Train a small MLP → Deploy the surrogate for real-time inference. NextStat provides the ground truth; the surrogate provides the speed.
## Quick Start
```python
import numpy as np

import nextstat
from nextstat.distill import generate_dataset, train_mlp_surrogate, predict_nll

# workspace_json: a pyhf workspace loaded as a JSON string/dict
model = nextstat.from_pyhf(workspace_json)

# 1. Generate 100k (params, NLL, gradient) tuples
ds = generate_dataset(model, n_samples=100_000, method="sobol")
print(f"{ds.n_samples} points, {ds.n_params} params")
print(f"NLL range: [{ds.nll.min():.1f}, {ds.nll.max():.1f}]")

# 2. Train a surrogate MLP (built-in convenience)
surrogate = train_mlp_surrogate(ds, epochs=100, device="cuda")

# 3. Predict NLL at new points (nanoseconds per eval)
test_params = np.array(model.parameter_init())
pred_nll = predict_nll(surrogate, test_params)
print(f"Surrogate NLL: {pred_nll:.2f}")
```

## Sampling Methods
| Method | Coverage | Best For |
|---|---|---|
| `sobol` | Quasi-random, low discrepancy | Default. Best coverage with the fewest points. |
| `lhs` | Stratified per dimension | Good coverage, no power-of-2 requirement. |
| `uniform` | Pure random | Baseline comparison. |
| `gaussian` | Concentrated near the MLE | Fine-tuning near the minimum; focused surrogates. |
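The trade-offs above can be seen with SciPy's quasi-Monte Carlo module, which provides the same Sobol and Latin-hypercube generators (this sketch is independent of NextStat — `generate_dataset` presumably does something similar internally, and the parameter bounds below are purely illustrative):

```python
import numpy as np
from scipy.stats import qmc

# Sobol: power-of-2 sample counts give the best discrepancy guarantees.
sobol = qmc.Sobol(d=3, scramble=True, seed=0)
unit_pts = sobol.random_base2(m=10)  # 2**10 = 1024 points in [0, 1)^3

# Latin hypercube: any sample count, exactly one point per stratum per dimension.
lhs = qmc.LatinHypercube(d=3, seed=0)
lhs_pts = lhs.random(n=1000)

# Scale the unit hypercube onto per-parameter bounds, e.g. a signal strength
# in [0, 5] and two nuisance parameters in [-3, 3] (illustrative bounds).
lower, upper = [0.0, -3.0, -3.0], [5.0, 3.0, 3.0]
params = qmc.scale(unit_pts, lower, upper)

print(params.shape)  # (1024, 3)
```

Note the power-of-2 requirement: `random_base2(m)` draws `2**m` points because Sobol sequences only achieve their low-discrepancy guarantees at those sample counts, which is exactly why `lhs` exists as the no-power-of-2 alternative.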
## Custom Training Loop

For production use, convert the dataset to PyTorch and write your own training loop:
```python
import torch
import torch.nn.functional as F

from nextstat.distill import generate_dataset, to_torch_dataset

ds = generate_dataset(model, n_samples=500_000, method="sobol")
train_ds = to_torch_dataset(ds)
loader = torch.utils.data.DataLoader(train_ds, batch_size=4096, shuffle=True)

surrogate = torch.nn.Sequential(
    torch.nn.Linear(ds.n_params, 256), torch.nn.SiLU(),
    torch.nn.Linear(256, 256), torch.nn.SiLU(),
    torch.nn.Linear(256, 1),
).cuda()
optimizer = torch.optim.Adam(surrogate.parameters(), lr=1e-3)

for epoch in range(100):
    for params_batch, nll_batch, grad_batch in loader:
        params_batch = params_batch.cuda().requires_grad_(True)
        nll_batch = nll_batch.cuda()
        grad_batch = grad_batch.cuda()

        pred = surrogate(params_batch).squeeze(-1)
        loss = F.mse_loss(pred, nll_batch)

        # Optional: gradient-informed training (Sobolev loss). Reuses the
        # forward pass above; create_graph=True lets loss.backward() flow
        # through the gradient computation.
        pred_grad = torch.autograd.grad(pred.sum(), params_batch, create_graph=True)[0]
        loss = loss + 0.1 * F.mse_loss(pred_grad, grad_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

## Export Formats
| Function | Format | Use Case |
|---|---|---|
| `to_torch_dataset(ds)` | `TensorDataset` | PyTorch `DataLoader` for training |
| `to_numpy(ds)` | dict of `ndarray` | SciPy, scikit-learn, JAX |
| `to_npz(ds, path)` | `.npz` (compressed) | Persistent storage, reproducibility |
| `to_parquet(ds, path)` | `.parquet` (zstd) | Polars, DuckDB, Spark pipelines |
| `from_npz(path)` | `.npz` → dataset | Reload a previously saved dataset |
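The `.npz` round trip can be mirrored with plain NumPy, which is useful for consuming a saved dataset from tools that don't import NextStat. This sketch uses a toy quadratic NLL and assumed field names (`params`, `nll`, `grad`) — NextStat's exact on-disk schema may differ:

```python
import numpy as np

# Illustrative arrays standing in for a generated dataset.
rng = np.random.default_rng(0)
params = rng.uniform(-3, 3, size=(1000, 5))
nll = (params ** 2).sum(axis=1)  # toy quadratic NLL
grad = 2 * params                # its exact gradient

# Compressed archive: one named array per field, like to_npz presumably writes.
np.savez_compressed("surrogate_ds.npz", params=params, nll=nll, grad=grad)

with np.load("surrogate_ds.npz") as loaded:
    assert np.array_equal(loaded["params"], params)
    print(loaded["nll"].shape)  # (1000,)
```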
## Validation
Always validate the surrogate against NextStat's exact computation:
```python
import numpy as np

from nextstat.distill import predict_nll

# Compare surrogate vs. exact NLL at random points inside the sampled bounds
test_params = np.random.default_rng(99).uniform(
    ds.parameter_bounds[:, 0], ds.parameter_bounds[:, 1],
    size=(1000, ds.n_params),
)

pred = predict_nll(surrogate, test_params)
exact = np.array([model.nll(p.tolist()) for p in test_params])

rmse = np.sqrt(np.mean((pred - exact) ** 2))
max_err = np.max(np.abs(pred - exact))
print(f"RMSE: {rmse:.4f}, max error: {max_err:.4f}")
```

## When to Use Surrogates
- MCMC with many parameters — the surrogate replaces expensive NLL calls in the inner loop of HMC/NUTS.
- Interactive dashboards — real-time likelihood contours as the user drags sliders.
- Global EFT fits — scanning 100+ Wilson coefficients where exact fits are prohibitive.
- NOT recommended for final results — always confirm with NextStat's exact fit. The surrogate is for exploration; the exact computation is for publication.
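The MCMC use case above can be sketched with a minimal random-walk Metropolis sampler. Here a toy quadratic NLL stands in for the trained surrogate — in practice the `surrogate_nll` call would be `predict_nll(surrogate, theta)`, and every name below is illustrative rather than NextStat API:

```python
import numpy as np

rng = np.random.default_rng(42)

def surrogate_nll(theta):
    # Stand-in for the trained surrogate: a quadratic bowl centred at 1.0,
    # i.e. a unit-Gaussian posterior in every dimension.
    return 0.5 * np.sum((theta - 1.0) ** 2)

n_params, n_steps, step_size = 4, 20_000, 0.5
theta = np.zeros(n_params)
current = surrogate_nll(theta)
chain = np.empty((n_steps, n_params))

for i in range(n_steps):
    proposal = theta + step_size * rng.standard_normal(n_params)
    prop_nll = surrogate_nll(proposal)
    # Metropolis rule for a target exp(-NLL): accept with prob exp(current - prop_nll).
    if np.log(rng.uniform()) < current - prop_nll:
        theta, current = proposal, prop_nll
    chain[i] = theta

burned = chain[5_000:]  # discard burn-in
print(burned.mean(axis=0))  # posterior mean near 1.0 in each dimension
```

Because each step costs only one surrogate evaluation, chains of millions of steps become interactive — but the resulting posterior should still be spot-checked against exact NLL calls, as in the Validation section.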
