Training with SignificanceLoss
End-to-End Differentiable Pipeline
This guide walks through training a PyTorch neural network where the loss function is the profiled discovery significance Z₀. The gradient flows from the statistical test, through the differentiable histogram, all the way back into your network weights.
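For reference, Z₀ here is the usual profiled discovery significance. Asymptotically (Cowan, Cranmer, Gross and Vitells), with q₀ the profile-likelihood-ratio test statistic for the background-only hypothesis μ = 0 and θ̂̂ the nuisance parameters profiled at μ = 0:

$$
q_0 = -2 \ln \frac{L(\mu = 0,\, \hat{\hat{\theta}})}{L(\hat{\mu},\, \hat{\theta})}
\quad \text{for } \hat{\mu} \ge 0 \ (\text{else } q_0 = 0),
\qquad
Z_0 = \sqrt{q_0}.
$$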
Architecture Overview
Neural Network                               NextStat (Rust/CUDA)
──────────────                               ────────────────────
Input features                               HistFactory model
      │                                      (systematics, backgrounds)
      ▼                                               │
Classifier(x) → scores                                │
      │                                               │
      ▼                                               │
SoftHistogram(scores) → bins ──────────────▶ SignificanceLoss
      │                                               │
      │ ◄─────────────── ∂(-Z₀)/∂bins ────────────────┘
      ▼
loss.backward() → ∂loss/∂weights → optimizer.step()
Step 1: Prepare Your Statistical Model

NextStat uses the HistFactory format — the same JSON that pyhf produces. If you already have a pyhf workspace, you can load it directly.
import json
import nextstat
# Load a pyhf-style workspace JSON
with open("workspace.json") as f:
    ws = json.load(f)
model = nextstat.from_pyhf(ws)
# model contains: channels, samples, systematics, observed data
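For orientation, a minimal workspace of this shape looks roughly like the following. This is a sketch of the pyhf workspace schema written as a Python dict (on disk it is the equivalent JSON); the channel, sample, and bin values are illustrative, not taken from a real analysis.

workspace_sketch = {
    "channels": [{
        "name": "signal_region",
        "samples": [
            {"name": "signal",
             "data": [5.0, 8.0, 12.0],
             "modifiers": [{"name": "mu", "type": "normfactor", "data": None}]},
            {"name": "background",
             "data": [50.0, 40.0, 30.0],
             "modifiers": [{"name": "bkg_norm", "type": "normsys",
                            "data": {"hi": 1.1, "lo": 0.9}}]},
        ],
    }],
    "observations": [{"name": "signal_region", "data": [52.0, 47.0, 45.0]}],
    "measurements": [{"name": "meas", "config": {"poi": "mu", "parameters": []}}],
    "version": "1.0.0",
}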
Step 2: Create the Loss Function

from nextstat.torch import SignificanceLoss
# SignificanceLoss wraps profiled Z₀ with familiar __call__ semantics.
# By default it returns -Z₀ so SGD minimisation maximises significance.
loss_fn = SignificanceLoss(
    model,
    signal_sample_name="signal",  # which sample the NN controls
    device="auto",                # "cuda", "metal", or "auto"
    negate=True,                  # -Z₀ for minimisation (default)
    eps=1e-12,                    # numerical stability in sqrt
)
print(f"Signal bins: {loss_fn.n_bins}") # e.g. 10
print(f"Nuisance params: {loss_fn.n_params}") # e.g. 23Step 3: Differentiable Binning
Step 3: Differentiable Binning

A standard histogram is not differentiable (hard bin edges have zero gradient). SoftHistogram solves this with Gaussian KDE or sigmoid approximations.
import torch
from nextstat.torch import SoftHistogram
# Define bin edges matching your statistical model
soft_hist = SoftHistogram(
    bin_edges=torch.linspace(0.0, 1.0, 11),  # 10 bins over [0, 1]
    bandwidth=0.05,                          # KDE bandwidth (smaller = sharper, noisier)
    mode="kde",                              # "kde" (Gaussian) or "sigmoid" (faster)
)
# Usage:
scores = classifier(batch_features) # [N] continuous outputs
histogram = soft_hist(scores, weights)  # [10] differentiable bin counts

| Mode | Speed | Gradient Quality |
|---|---|---|
| kde | O(N × B) | Smooth, low variance. Recommended for training. |
| sigmoid | O(N × B) | Sharper bins, noisier gradients. Good for fine-tuning. |
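To make the KDE mode concrete, here is a self-contained sketch of the idea (not the NextStat implementation): each score contributes to every bin the integral of a Gaussian kernel of width bandwidth over that bin, which is smooth in the score and therefore differentiable. The O(N × B) cost in the table comes from the per-score, per-bin kernel evaluation.

import torch

def soft_histogram_kde(scores, bin_edges, bandwidth, weights=None):
    """Gaussian-KDE soft histogram sketch: [N] scores -> [B] differentiable bin counts."""
    normal = torch.distributions.Normal(0.0, 1.0)
    lo, hi = bin_edges[:-1], bin_edges[1:]       # [B] lower/upper edges
    s = scores.unsqueeze(1)                      # [N, 1]
    # Fraction of each score's Gaussian kernel that falls inside each bin.
    frac = normal.cdf((hi - s) / bandwidth) - normal.cdf((lo - s) / bandwidth)  # [N, B]
    if weights is not None:
        frac = frac * weights.unsqueeze(1)
    return frac.sum(dim=0)                       # [B]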
Step 4: Training Loop
import torch
classifier = MyClassifier(input_dim=20, hidden=64).cuda()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
for epoch in range(50):
    for batch_x, batch_w in dataloader:
        optimizer.zero_grad()

        # Forward: NN → scores → soft histogram → -Z₀
        scores = classifier(batch_x.cuda())
        histogram = soft_hist(scores, batch_w.cuda())
        loss = loss_fn(histogram.double().cuda())

        # Backward: gradients flow through NextStat into NN weights
        loss.backward()
        optimizer.step()

    # Monitor: negate to get positive Z₀
    with torch.no_grad():
        z0 = -loss.item()
    print(f"Epoch {epoch}: Z₀ = {z0:.3f}σ")
Step 5: Interoperability (JAX, CuPy)

If your data pipeline uses JAX or CuPy, use as_tensor() to convert arrays without copying, via the DLPack protocol.
from nextstat.torch import as_tensor
# JAX → PyTorch (zero-copy on GPU via DLPack)
import jax.numpy as jnp
jax_hist = jnp.array([10.0, 20.0, 30.0, 40.0])
torch_hist = as_tensor(jax_hist).double()
# CuPy → PyTorch (zero-copy via __dlpack__)
import cupy as cp
cupy_hist = cp.array([10.0, 20.0, 30.0, 40.0])
torch_hist = as_tensor(cupy_hist).double()
# Works with: PyTorch, JAX, CuPy, NumPy, Apache Arrow, lists
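NumPy arrays and plain Python lists go through the same call. A small addition to the examples above:

import numpy as np

np_hist = np.array([10.0, 20.0, 30.0, 40.0])
torch_hist = as_tensor(np_hist).double()

# Plain lists are accepted too:
torch_hist = as_tensor([10.0, 20.0, 30.0, 40.0]).double()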
Advanced: Direct Jacobian Access

For external optimisers (SciPy, Optuna) or bin-level analysis, extract the raw gradient ∂q₀/∂signal without going through autograd:
from nextstat.torch import signal_jacobian, signal_jacobian_numpy
# As PyTorch tensor (same device)
grad = signal_jacobian(signal_hist, loss_fn.session)
# As NumPy array (for SciPy / Optuna)
grad_np = signal_jacobian_numpy(signal_hist, loss_fn.session)
# Fast pruning: identify low-impact bins
important = grad.abs() > 0.01
print(f"Important bins: {important.sum()}/{len(important)}")Advanced: Batch Evaluation
Advanced: Batch Evaluation

from nextstat.torch import batch_profiled_q0_loss
# Evaluate multiple histograms (e.g. ensemble members)
histograms = torch.stack([hist_1, hist_2, hist_3]) # [3, n_bins]
q0_list = batch_profiled_q0_loss(histograms, loss_fn.session)
# q0_list = [Tensor(q0_1), Tensor(q0_2), Tensor(q0_3)]
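Since Z₀ = √q₀, the batched values convert directly to significances (assuming each entry is a scalar tensor, as shown above):

# clamp guards against tiny negative q0 values from numerical noise
z0_values = [float(torch.sqrt(q0.clamp_min(0.0))) for q0 in q0_list]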
Tips

- dtype — SignificanceLoss expects float64. Always call .double() before passing tensors.
- bandwidth tuning — start with bandwidth="auto", then decrease for sharper bins once training stabilises.
- learning rate — the loss landscape is non-convex (two L-BFGS-B fits per step). Use Adam with lr ≈ 1e-3 to 1e-4.
- warm-up — consider pre-training with cross-entropy for a few epochs before switching to SignificanceLoss (see the sketch after this list).
- Metal (macOS) — Apple Silicon is supported. The signal histogram is uploaded via CPU (no zero-copy), but L-BFGS-B fits run on the Metal GPU.
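A minimal sketch of the warm-up suggested above, assuming a hypothetical pretrain_loader that also yields binary signal/background labels and a classifier whose outputs already lie in [0, 1] (per-event weights are ignored here for simplicity):

bce = torch.nn.BCELoss()
for epoch in range(5):                                    # a few cross-entropy epochs
    for batch_x, batch_y, _ in pretrain_loader:
        optimizer.zero_grad()
        scores = classifier(batch_x.cuda()).squeeze(-1)   # probabilities in [0, 1]
        loss = bce(scores, batch_y.cuda().float())
        loss.backward()
        optimizer.step()
# ...then continue with the SignificanceLoss loop from Step 4.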
