NextStat

Training with SignificanceLoss

End-to-End Differentiable Pipeline

This guide walks through training a PyTorch neural network where the loss function is the profiled discovery significance Z₀. The gradient flows from the statistical test, through the differentiable histogram, all the way back into your network weights.
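For reference, the usual profile-likelihood definition of the one-sided discovery statistic is shown below. We assume NextStat's profiled Z₀ follows it. The numerator and denominator come from two fits: one with the signal strength μ fixed to 0 and one with μ free. θ denotes the nuisance parameters. The last line is a special case only. It holds for independent Poisson bins with known backgrounds bᵢ and no nuisance parameters, evaluated on Asimov data nᵢ = sᵢ + bᵢ. In that case both Z₀ and its per-bin gradient have closed forms.

```latex
q_0 =
\begin{cases}
  -2 \ln \dfrac{L\!\left(\mu = 0, \hat{\hat{\theta}}_0\right)}{L\!\left(\hat{\mu}, \hat{\theta}\right)} & \hat{\mu} \ge 0 \\[1ex]
  0 & \hat{\mu} < 0
\end{cases}
\qquad
Z_0 = \sqrt{q_0}

% Special case: no nuisance parameters, Asimov data n_i = s_i + b_i
Z_0 = \sqrt{2 \sum_i \left[ (s_i + b_i) \ln\!\left(1 + \frac{s_i}{b_i}\right) - s_i \right]},
\qquad
\frac{\partial Z_0}{\partial s_i} = \frac{\ln\left(1 + s_i / b_i\right)}{Z_0}
```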

Architecture Overview

Neural Network                    NextStat (Rust/CUDA)
─────────────                    ────────────────────
  Input features                   HistFactory model
       │                           (systematics, backgrounds)
       ▼                                  │
  Classifier(x) → scores                 │
       │                                  │
       ▼                                  │
  SoftHistogram(scores) → bins ──────────▶ SignificanceLoss
       │                                  │
       │               ◄─── ∂(-Z₀)/∂bins ─┘
       ▼
  loss.backward()  →  ∂loss/∂weights  →  optimizer.step()
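The diagram has an arrow labelled ∂(-Z₀)/∂bins. A pure-Python sketch makes it concrete for the simplest possible model: independent Poisson bins with known backgrounds, no nuisance parameters, and Asimov data (n = s + b). The function names are hypothetical. This is not how NextStat computes the gradient; NextStat profiles the systematics with L-BFGS-B fits. The sketch only shows what a per-bin gradient looks like and checks the analytic form against central finite differences.

```python
import math


def asimov_z0(signal, background):
    """Discovery significance for independent Poisson bins with no nuisance
    parameters, evaluated on Asimov data n_i = s_i + b_i (hypothetical helper)."""
    q0 = 2.0 * sum((s + b) * math.log1p(s / b) - s
                   for s, b in zip(signal, background))
    return math.sqrt(q0)


def asimov_z0_grad(signal, background):
    """Analytic per-bin gradient dZ0/ds_i = ln(1 + s_i / b_i) / Z0."""
    z0 = asimov_z0(signal, background)
    return [math.log1p(s / b) / z0 for s, b in zip(signal, background)]


def finite_diff_grad(signal, background, h=1e-6):
    """Central finite differences, used only to check the analytic gradient."""
    grad = []
    for i in range(len(signal)):
        up, down = list(signal), list(signal)
        up[i] += h
        down[i] -= h
        grad.append((asimov_z0(up, background) - asimov_z0(down, background)) / (2 * h))
    return grad


signal = [1.0, 4.0, 9.0]
background = [100.0, 25.0, 4.0]
analytic = asimov_z0_grad(signal, background)
numeric = finite_diff_grad(signal, background)
for i, (a, n) in enumerate(zip(analytic, numeric)):
    print(f"bin {i}: analytic {a:.6f}  finite-diff {n:.6f}")
```

The loss in the diagram is -Z₀, so the gradient that flows back is the negative of these values. The bin with the highest s/b pulls hardest.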

Step 1: Prepare Your Statistical Model

NextStat uses the HistFactory format — the same JSON that pyhf produces. If you already have a pyhf workspace, you can load it directly.

import json
import nextstat

# Load a pyhf-style workspace JSON
with open("workspace.json") as f:
    ws = json.load(f)

model = nextstat.from_pyhf(ws)
# model contains: channels, samples, systematics, observed data
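If you do not have a workspace yet, you can write a minimal one by hand. The sketch below follows the pyhf JSON workspace schema. It has one channel with three bins, a "signal" sample scaled by a `normfactor` parameter of interest `mu`, and a "background" sample with per-bin uncorrelated uncertainties (`shapesys`). The channel names and yields are illustrative placeholders. It is an assumption that NextStat accepts every pyhf modifier type, including the one used here.

```python
import json

# Minimal pyhf-style workspace: one channel, two samples, three bins.
# Names and yields are illustrative placeholders.
workspace = {
    "channels": [
        {
            "name": "singlechannel",
            "samples": [
                {
                    "name": "signal",  # the sample the NN will control
                    "data": [5.0, 10.0, 15.0],
                    "modifiers": [
                        {"name": "mu", "type": "normfactor", "data": None}
                    ],
                },
                {
                    "name": "background",
                    "data": [50.0, 40.0, 30.0],
                    "modifiers": [
                        # per-bin absolute background uncertainties
                        {"name": "uncorr_bkguncrt", "type": "shapesys",
                         "data": [5.0, 4.0, 3.0]}
                    ],
                },
            ],
        }
    ],
    "observations": [{"name": "singlechannel", "data": [52.0, 45.0, 33.0]}],
    "measurements": [
        {"name": "Measurement", "config": {"poi": "mu", "parameters": []}}
    ],
    "version": "1.0.0",
}

with open("workspace.json", "w") as f:
    json.dump(workspace, f, indent=2)
```

Loading this file with the code above gives a model whose signal sample is named "signal". That is the name passed as `signal_sample_name` when creating the loss.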

Step 2: Create the Loss Function

from nextstat.torch import SignificanceLoss

# SignificanceLoss wraps profiled Z₀ with familiar __call__ semantics.
# By default it returns -Z₀ so SGD minimisation maximises significance.
loss_fn = SignificanceLoss(
    model,
    signal_sample_name="signal",  # which sample the NN controls
    device="auto",                # "cuda", "metal", or "auto"
    negate=True,                  # -Z₀ for minimisation (default)
    eps=1e-12,                    # numerical stability in sqrt
)

print(f"Signal bins: {loss_fn.n_bins}")    # e.g. 10
print(f"Nuisance params: {loss_fn.n_params}")  # e.g. 23
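The `negate` and `eps` options are easiest to understand on the last step of the computation, Z₀ = √q₀. The derivative d√q/dq = 1/(2√q) diverges at q₀ = 0, which can happen early in training before the network separates signal from background. A small epsilon under the square root keeps that gradient finite. The function below is a hypothetical scalar version that assumes eps enters as √(max(q₀, 0) + eps). NextStat's exact placement of eps may differ.

```python
import math


def z0_and_grad(q0, negate=True, eps=1e-12):
    """Hypothetical scalar sketch of the final step Z0 = sqrt(q0).

    Returns (loss, dloss/dq0).
    """
    q = max(q0, 0.0)          # q0 is one-sided; clamp tiny negatives from the fits
    z0 = math.sqrt(q + eps)   # eps keeps the argument of sqrt away from 0
    dz_dq = 0.5 / z0 if q0 >= 0.0 else 0.0  # zero slope where the clamp is active
    sign = -1.0 if negate else 1.0          # negate=True: minimising maximises Z0
    return sign * z0, sign * dz_dq


loss, grad = z0_and_grad(9.0)    # loss ≈ -3.0, dloss/dq0 ≈ -1/6
loss0, grad0 = z0_and_grad(0.0)  # loss ≈ -1e-6, dloss/dq0 ≈ -5e5: large but finite
print(loss, grad, loss0, grad0)
```

With `eps=0.0`, the same call at `q0=0.0` raises `ZeroDivisionError`. An autograd square root at zero would instead produce an infinite gradient.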

Step 3: Differentiable Binning

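A standard histogram is not differentiable. Each bin count is a step function of the scores, so its gradient is zero almost everywhere and undefined at the bin edges. SoftHistogram replaces the hard bin assignment with a smooth approximation, either a Gaussian KDE or a sigmoid.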

import torch
from nextstat.torch import SoftHistogram

# Define bin edges matching your statistical model
soft_hist = SoftHistogram(
    bin_edges=torch.linspace(0.0, 1.0, 11),  # 10 bins over [0, 1]
    bandwidth=0.05,     # KDE bandwidth (smaller = sharper, noisier)
    mode="kde",         # "kde" (Gaussian) or "sigmoid" (faster)
)

# Usage (classifier, batch_features and weights are your own objects):
scores = classifier(batch_features)     # [N] continuous outputs in [0, 1]
histogram = soft_hist(scores, weights)  # [10] differentiable bin counts
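
| Mode    | Speed (N scores, B bins) | Gradient quality                                       |
|---------|--------------------------|--------------------------------------------------------|
| kde     | O(N × B)                 | Smooth, low variance. Recommended for training.        |
| sigmoid | O(N × B)                 | Sharper bins, noisier gradients. Good for fine-tuning. |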
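Below is a pure-Python sketch of the two smoothing schemes. It assumes the common formulations; NextStat's exact kernels and normalisation may differ. In "kde" mode, each event contributes the mass of a Gaussian N(x, h²) that falls inside each bin, Φ((hi − x)/h) − Φ((lo − x)/h). In "sigmoid" mode, the hard indicator 1[lo ≤ x < hi] is replaced by σ((x − lo)/h) − σ((x − hi)/h). That is the same difference of CDFs, taken for a logistic kernel. Both are smooth in x, and both approach the hard histogram as h → 0. The function names are hypothetical.

```python
import math


def _normal_cdf(z):
    # Gaussian kernel: CDF of N(0, 1)
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def _logistic_cdf(z):
    # Logistic kernel (the sigmoid), written to avoid overflow in math.exp
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def soft_histogram(scores, weights, edges, bandwidth, mode="kde"):
    """Weighted histogram whose bin counts are smooth functions of each score."""
    cdfs = {"kde": _normal_cdf, "sigmoid": _logistic_cdf}
    if mode not in cdfs:
        raise ValueError(f"unknown mode {mode!r}")
    cdf = cdfs[mode]
    counts = [0.0] * (len(edges) - 1)
    for x, w in zip(scores, weights):
        for j in range(len(counts)):
            lo, hi = edges[j], edges[j + 1]
            # Soft indicator of lo <= x < hi: kernel mass between the two edges
            counts[j] += w * (cdf((hi - x) / bandwidth) - cdf((lo - x) / bandwidth))
    return counts


edges = [i / 10 for i in range(11)]  # 10 bins over [0, 1]
scores = [0.05, 0.35, 0.37, 0.72, 0.95]
weights = [1.0] * len(scores)
hard = [sum(1 for x in scores if lo <= x < hi) for lo, hi in zip(edges[:-1], edges[1:])]
soft_kde = soft_histogram(scores, weights, edges, bandwidth=1e-3, mode="kde")
soft_sig = soft_histogram(scores, weights, edges, bandwidth=1e-3, mode="sigmoid")
print(hard)
print([round(c, 4) for c in soft_kde])
```

With `bandwidth=1e-3`, both modes reproduce the hard histogram for these scores. With larger bandwidths, an event near an edge is shared between neighbouring bins. That sharing is what makes the counts differentiable in the scores, at the cost of some bias relative to the hard histogram.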

Step 4: Training Loop

import torch

# MyClassifier and dataloader are placeholders for your own model and data.
# The classifier's scores should lie in the SoftHistogram range (e.g. final sigmoid).
classifier = MyClassifier(input_dim=20, hidden=64).cuda()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

for epoch in range(50):
    for batch_x, batch_w in dataloader:
        optimizer.zero_grad()

        # Forward: NN → scores → soft histogram → -Z₀
        scores = classifier(batch_x.cuda())
        histogram = soft_hist(scores, batch_w.cuda())
        loss = loss_fn(histogram.double().cuda())  # SignificanceLoss expects float64

        # Backward: gradients flow through NextStat into NN weights
        loss.backward()
        optimizer.step()

    # Monitor: with negate=True the loss is -Z₀; this reports the last batch only
    z0 = -loss.item()
    print(f"Epoch {epoch}: Z₀ = {z0:.3f}σ")
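The forward/backward/step cycle can also be seen in a toy that uses neither PyTorch nor NextStat. The "network" is a vector of logits θ that distributes a fixed signal yield S over bins through a softmax. The loss is −Z₀ from the no-systematics Asimov formula. The update is plain gradient descent (not Adam) with a hand-derived gradient. All names are hypothetical. The toy only illustrates that descending on −Z₀ moves signal into the bins with the best s/b.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(t - m) for t in logits]
    total = sum(exps)
    return [e / total for e in exps]


def neg_z0_and_grad(logits, total_signal, background):
    """Toy forward + hand-written backward: logits -> signal bins -> -Z0."""
    p = softmax(logits)
    s = [total_signal * pi for pi in p]
    # Asimov Z0 for independent Poisson bins, no nuisance parameters
    q0 = 2.0 * sum((si + bi) * math.log1p(si / bi) - si for si, bi in zip(s, background))
    z0 = math.sqrt(q0)
    g = [math.log1p(si / bi) / z0 for si, bi in zip(s, background)]  # dZ0/ds_i
    g_bar = sum(pi * gi for pi, gi in zip(p, g))
    # Softmax chain rule: dZ0/dtheta_j = S * p_j * (g_j - sum_i p_i g_i); negated for -Z0
    grad = [-total_signal * pj * (gj - g_bar) for pj, gj in zip(p, g)]
    return -z0, grad


background = [80.0, 40.0, 20.0, 10.0]
total_signal = 10.0
logits = [0.0] * len(background)  # "untrained": signal spread evenly over bins
lr = 0.5

initial_loss, _ = neg_z0_and_grad(logits, total_signal, background)
for _ in range(200):
    loss, grad = neg_z0_and_grad(logits, total_signal, background)  # forward + backward
    logits = [t - lr * gt for t, gt in zip(logits, grad)]            # optimizer step
final_loss, _ = neg_z0_and_grad(logits, total_signal, background)

allocation = softmax(logits)
print(f"Z0: {-initial_loss:.3f} -> {-final_loss:.3f}")
print("signal fraction per bin:", [round(a, 3) for a in allocation])
```

The real loss adds nuisance parameters, which is why each step needs profile fits rather than a closed form.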

Step 5: Interoperability (JAX, CuPy)

If your data pipeline uses JAX or CuPy, use as_tensor() to turn its arrays into PyTorch tensors without copying, via the DLPack protocol.

from nextstat.torch import as_tensor

# JAX → PyTorch (zero-copy on GPU via DLPack)
import jax.numpy as jnp
jax_hist = jnp.array([10.0, 20.0, 30.0, 40.0])  # float32 unless JAX's x64 mode is enabled
torch_hist = as_tensor(jax_hist).double()       # the float32 → float64 cast allocates a new tensor

# CuPy → PyTorch (zero-copy via __dlpack__)
import cupy as cp
cupy_hist = cp.array([10.0, 20.0, 30.0, 40.0])  # float64 by default
torch_hist = as_tensor(cupy_hist).double()      # already float64, so no extra copy

# Works with: PyTorch, JAX, CuPy, NumPy, Apache Arrow, lists
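You can check the zero-copy behaviour with NumPy alone, which implements the same DLPack protocol (`__dlpack__` and `np.from_dlpack`, public since NumPy 1.23). An array imported through DLPack shares memory with its source. The sketch also shows the caveat behind `.double()`: a dtype conversion always allocates a new buffer. A float32 array is therefore copied at that step, even though the DLPack hand-off itself is free. `as_tensor()` is not used here; NumPy stands in for the libraries above.

```python
import numpy as np

# A float64 buffer exchanged through DLPack: no copy, same memory.
src64 = np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float64)
view64 = np.from_dlpack(src64)
shared_64 = np.shares_memory(src64, view64)

# A float32 buffer: the DLPack import itself is still zero-copy ...
src32 = np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float32)
view32 = np.from_dlpack(src32)
# ... but widening to float64 (what .double() does) must allocate.
widened = view32.astype(np.float64)
shared_32 = np.shares_memory(src32, widened)

print(f"float64 via DLPack shares memory: {shared_64}")
print(f"float32 -> float64 shares memory: {shared_32}")
```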

Advanced: Direct Jacobian Access

For external optimisers (SciPy, Optuna) or bin-level analysis, extract the raw gradient of q₀ with respect to each signal bin, ∂q₀/∂sᵢ, without going through autograd:

from nextstat.torch import signal_jacobian, signal_jacobian_numpy

# signal_hist: your float64 signal histogram tensor (placeholder)
# As PyTorch tensor (same device)
grad = signal_jacobian(signal_hist, loss_fn.session)

# As NumPy array (for SciPy / Optuna)
grad_np = signal_jacobian_numpy(signal_hist, loss_fn.session)

# Fast pruning: flag bins whose gradient magnitude exceeds a problem-dependent cut
important = grad.abs() > 0.01
print(f"Important bins: {important.sum().item()}/{important.numel()}")
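The fixed cut `0.01` in the pruning example is in absolute units of ∂q₀/∂s. Its meaning therefore changes with the signal normalisation and with the overall size of q₀. A scale-free alternative keeps bins whose gradient magnitude is at least some fraction of the largest one; it is sketched below with a hypothetical helper. The toy gradient uses the no-systematics Asimov form ∂q₀/∂sᵢ = 2 ln(1 + sᵢ/bᵢ) as a stand-in for the output of `signal_jacobian_numpy`.

```python
import math


def important_bins(grad, rel_threshold=0.05):
    """Indices of bins with |gradient| >= rel_threshold * max |gradient|."""
    scale = max(abs(g) for g in grad)
    if scale == 0.0:
        return []
    return [i for i, g in enumerate(grad) if abs(g) >= rel_threshold * scale]


# Toy stand-in for signal_jacobian_numpy(): dq0/ds_i = 2 ln(1 + s_i / b_i)
signal = [0.2, 1.0, 3.0, 6.0, 2.0]
background = [400.0, 120.0, 30.0, 8.0, 1.5]
grad = [2.0 * math.log1p(s / b) for s, b in zip(signal, background)]

keep = important_bins(grad, rel_threshold=0.05)
print(f"Important bins: {len(keep)}/{len(grad)} -> {keep}")
```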

Advanced: Batch Evaluation

import torch
from nextstat.torch import batch_profiled_q0_loss

# Evaluate multiple histograms (e.g. ensemble members);
# hist_1..hist_3 are placeholder float64 tensors of shape [n_bins]
histograms = torch.stack([hist_1, hist_2, hist_3])  # [3, n_bins]
q0_list = batch_profiled_q0_loss(histograms, loss_fn.session)
# q0_list = [Tensor(q0_1), Tensor(q0_2), Tensor(q0_3)]
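The batch call returns q₀ rather than Z₀, so ranking ensemble members takes one more step, Z₀ = √q₀. The pure-Python sketch below assumes the per-member values have already been converted to floats (for example with `[q.item() for q in q0_list]`). The clamp guards against tiny negative values caused by finite fit precision. The q₀ inputs are hypothetical.

```python
import math


def rank_by_significance(q0_values):
    """Convert q0 -> Z0 = sqrt(max(q0, 0)); return (Z0 list, index of best member)."""
    z0 = [math.sqrt(max(q, 0.0)) for q in q0_values]
    best = max(range(len(z0)), key=z0.__getitem__)
    return z0, best


# Hypothetical q0 values for three ensemble members
q0_values = [6.25, 9.0, -1e-9]
z0, best = rank_by_significance(q0_values)
print(f"Z0 per member: {z0}, best member: {best}")
```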

Tips

  • dtype — SignificanceLoss expects float64. Always call .double() before passing tensors.
  • bandwidth tuning — start with bandwidth="auto", then decrease for sharper bins once training stabilises.
  • learning rate — the loss landscape is non-convex, and every step runs two L-BFGS-B fits (one with the signal strength free, one with it fixed to zero). Use Adam with lr ≈ 1e-3 to 1e-4.
  • warm-up — consider pre-training with cross-entropy for a few epochs before switching to SignificanceLoss.
  • Metal (macOS) — Apple Silicon is supported. The signal histogram is uploaded via the CPU (no zero-copy), but the L-BFGS-B fits run on the Metal GPU.
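One way to implement the warm-up tip is to blend the two losses as `loss = (1 - w) * ce_loss + w * sig_loss`. The weight w stays at 0 during pre-training and then ramps linearly to 1, so the gradient scale changes gradually rather than all at once. The schedule below is a hypothetical sketch: `ce_loss`, `sig_loss` and the epoch counts are illustrative. Setting `ramp_epochs=0` gives a hard switch.

```python
def significance_weight(epoch, warmup_epochs=5, ramp_epochs=5):
    """Weight w for loss = (1 - w) * ce_loss + w * sig_loss.

    w = 0 during cross-entropy warm-up, then rises linearly to 1.
    ramp_epochs=0 gives a hard switch at epoch == warmup_epochs.
    """
    if epoch < warmup_epochs:
        return 0.0
    if ramp_epochs <= 0:
        return 1.0
    return min(1.0, (epoch - warmup_epochs + 1) / ramp_epochs)


schedule = [significance_weight(e) for e in range(12)]
print(schedule)
```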