Differentiable HistFactory Layer for PyTorch
NextStat provides a differentiable NLL layer that integrates directly into PyTorch training loops. This enables end-to-end optimization of neural network classifiers on physics significance with full systematic uncertainty handling.
The intended pipeline: NN scores → differentiable histogram → profiled likelihood → q₀ → Z₀. Concretely: SoftHistogram + SignificanceLoss in nextstat.torch.
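To see the histogram step in isolation, here is a minimal sketch using the SoftHistogram configuration from the Quick Start below; the random scores are a stand-in for classifier outputs:
import torch
from nextstat.torch import SoftHistogram

# Differentiable histogram in isolation: gradients flow from bin yields
# back to the raw scores, which is what lets the NLL reach the NN weights.
soft_hist = SoftHistogram(
    bin_edges=torch.linspace(0.0, 1.0, 11),  # 10 bins on [0, 1]
    bandwidth="auto", mode="kde",
)
scores = torch.rand(1000, requires_grad=True)  # stand-in for NN scores
hist = soft_hist(scores)                       # [10] smooth bin yields
hist.sum().backward()                          # populates scores.grad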
What Is Differentiable (and What Is Not)
- The layer is differentiable w.r.t. the signal histogram s: a 1D tensor of bin yields produced by your model + SoftHistogram.
- For fixed-parameter nll_loss: nuisance parameters are constants; backward returns ∂NLL/∂s (checked numerically in the sketch below).
- For profiled objectives (q₀, qμ, Z₀): the envelope theorem gives exact gradients at fitted optima without backpropagating through optimizer iterations.
- Not yet differentiable: derived quantities like CLs bands or μ_up(α), which require implicit differentiation through root-finding (see Phase 2 below).
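A minimal sketch of the fixed-parameter check, assuming a CUDA build and a workspace_json loaded elsewhere (the calls are those from the Quick Start and API Reference below):
import torch
import nextstat
from nextstat.torch import create_session, nll_loss

model = nextstat.from_pyhf(workspace_json)  # workspace_json: your pyhf workspace
session = create_session(model, signal_sample_name="signal")

signal = torch.rand(session.signal_n_bins(), dtype=torch.float64,
                    device="cuda", requires_grad=True)
nll_loss(signal, session).backward()        # fills signal.grad with dNLL/ds

# Central finite difference on bin 0 should match the analytic gradient.
eps = 1e-6
s_hi = signal.detach().clone(); s_hi[0] += eps
s_lo = signal.detach().clone(); s_lo[0] -= eps
fd = (nll_loss(s_hi, session) - nll_loss(s_lo, session)) / (2 * eps)
print(float(signal.grad[0]), float(fd))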
Architecture
PyTorch Training Loop (GPU)
|
+-- Neural Network --> signal_histogram (torch.Tensor, CUDA)
|
+-- NextStatNLL.forward(signal_histogram)
| +-- tensor.data_ptr() --> raw CUDA device pointer (u64)
| +-- Rust (PyO3): passes pointer to CUDA kernel
| +-- Kernel: reads signal bins from PyTorch memory (ZERO-COPY)
| +-- Kernel: writes dNLL/d(signal) into PyTorch grad tensor (ZERO-COPY)
| +-- Return: NLL scalar
|
+-- NextStatNLL.backward(grad_output)
+-- grad_signal = cached from forward (already on CUDA)
+-- return grad_output * grad_signal --> flows back to NN

Key insight (CUDA, Phase 1 NLL): No device↔host copies of the signal histogram or its gradient. The kernel reads signal bins directly from PyTorch GPU memory and writes gradients back. Small transfers remain for the nuisance parameters and the returned scalar NLL.
For profiled objectives (q₀/qμ/Z₀/Zμ), signal is still read zero-copy, but the final ∂/∂signal is returned as a small host Vec<f64> and materialized as a CUDA tensor (O(100–1000) floats).
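In autograd terms the pattern looks roughly like the following sketch. This is illustrative, not the shipped NextStatNLL; it assumes nll_grad_signal returns the NLL value, and uses the zeroed-buffer and synchronization rules from the Practical Notes below:
import torch

class ZeroCopyNLLSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, signal, session, params):
        grad = torch.zeros_like(signal)   # kernel accumulates via atomicAdd
        torch.cuda.synchronize()          # PyTorch and NextStat streams may differ
        nll = session.nll_grad_signal(params, signal.data_ptr(), grad.data_ptr())
        torch.cuda.synchronize()
        ctx.save_for_backward(grad)       # dNLL/ds, already on the GPU
        return signal.new_tensor(nll)

    @staticmethod
    def backward(ctx, grad_output):
        (grad_signal,) = ctx.saved_tensors
        return grad_output * grad_signal, None, None  # chain rule into the NN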
Quick Start — End-to-end Z₀ Training
import torch
import nextstat
from nextstat.torch import SoftHistogram, SignificanceLoss
model = nextstat.from_pyhf(workspace_json)  # workspace_json: your pyhf workspace
# One-time GPU session init (CUDA preferred, Metal fallback)
loss_fn = SignificanceLoss(model, signal_sample_name="signal", device="auto")
soft_hist = SoftHistogram(
    bin_edges=torch.linspace(0.0, 1.0, 11),
    bandwidth="auto", mode="kde",
)
optimizer = torch.optim.Adam(nn.parameters(), lr=1e-3)  # nn: your torch.nn.Module classifier
for batch in dataloader:
    optimizer.zero_grad()
    scores = nn(batch)                        # [N]
    signal_hist = soft_hist(scores).double()  # [B]
    if torch.cuda.is_available():
        signal_hist = signal_hist.cuda()      # CUDA path expects float64
    loss = loss_fn(signal_hist)               # returns -Z0
    loss.backward()
    optimizer.step()

Quick Start — Fixed-parameter NLL
import torch
import nextstat
from nextstat.torch import create_session, nll_loss
model = nextstat.from_pyhf(workspace_json)
session = create_session(model, signal_sample_name="signal")
optimizer = torch.optim.Adam(nn.parameters(), lr=1e-3)
for batch in dataloader:
    optimizer.zero_grad()
    signal_hist = nn(batch).double().cuda()  # CUDA float64
    loss = nll_loss(signal_hist, session)
    loss.backward()                          # gradient flows to NN
    optimizer.step()

API Reference
| Function | Description |
|---|---|
| create_session(model, signal_sample_name) | Create GPU session. Returns DifferentiableSession. |
| nll_loss(signal, session, params?) | NLL at fixed nuisance parameters (fast, zero-copy) |
| create_profiled_session(model, sample) | Create profiled GPU session for q₀/qμ |
| profiled_q0_loss(signal, session) | Discovery test statistic q₀ |
| profiled_z0_loss(signal, session) | √q₀ with numerical stability |
| profiled_qmu_loss(signal, session, mu_test) | Upper-limit test statistic qμ |
| SignificanceLoss(model, ...) | High-level: wraps profiled_z0_loss, returns −Z₀ |
| SoftHistogram(bin_edges, bandwidth, mode) | Differentiable histogram (KDE / sigmoid) |
| NextStatNLL | Low-level torch.autograd.Function |
DifferentiableSession (Native API)
Available as nextstat._core.DifferentiableSession when built with CUDA:
| Method | Description |
|---|---|
| .nll_grad_signal(params, signal_ptr, grad_ptr) | Raw kernel call: reads signal from signal_ptr, writes gradient to grad_ptr |
| .signal_n_bins() | Total number of signal bins across all channels |
| .n_params() | Number of model parameters (nuisance + POI) |
| .parameter_init() | Default parameter values (list of float) |
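A sketch of driving the session by hand (the nll_loss wrapper performs the equivalent steps internally; workspace_json is assumed to be loaded elsewhere):
import torch
import nextstat
from nextstat.torch import create_session

model = nextstat.from_pyhf(workspace_json)  # workspace_json: your pyhf workspace
session = create_session(model, signal_sample_name="signal")

signal = torch.rand(session.signal_n_bins(), dtype=torch.float64, device="cuda")
grad = torch.zeros_like(signal)          # must be zeroed: kernel uses atomicAdd
params = session.parameter_init()        # list of float, length session.n_params()

torch.cuda.synchronize()                 # cross-stream safety
nll = session.nll_grad_signal(params, signal.data_ptr(), grad.data_ptr())
torch.cuda.synchronize()                 # grad now holds dNLL/d(signal)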
Profiled Significance (q₀ / qμ) on GPU
GPU-accelerated profiled test statistics for discovery and upper limits. These require two profile fits per forward pass but directly optimize physics metrics:
from nextstat.torch import (
create_profiled_session,
profiled_q0_loss, # discovery test statistic q₀
profiled_z0_loss, # √q₀ with numerical stability
profiled_qmu_loss, # upper-limit test statistic qμ
)
session = create_profiled_session(model, "signal")
signal = nn(batch).double().cuda().requires_grad_(True)
q0 = profiled_q0_loss(signal, session)
z0 = profiled_z0_loss(signal, session)
qmu = profiled_qmu_loss(signal, session, mu_test=5.0)

Gradient Formulas
Envelope theorem (profiled objectives):
∂q₀/∂s = 2 · ( ∂NLL/∂s |_{θ=θ̂₀} − ∂NLL/∂s |_{θ=θ̂} )

Phase 1 — fixed-parameter NLL gradient per signal bin:
expected_i = (signal_i + delta_i) * factor_i + sum_{other samples}
dNLL/d(signal_i) = (1 - obs_i / expected_i) * factor_i

where factor_i is the product of all multiplicative modifiers (NormFactor, NormSys, ShapeSys, etc.) on the signal sample at bin i, and delta_i collects additive shifts. One-sided discovery: if μ̂ < 0 or q₀ clamps to zero, the returned q₀ and its gradient are zero.
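The fixed-parameter formula can be verified in plain PyTorch, independent of NextStat; the tensors below are made-up stand-ins for one channel:
import torch

# Hand-written Poisson NLL (constants dropped) with one additive shift and
# one multiplicative modifier; autograd must reproduce the analytic formula.
obs = torch.tensor([12.0, 7.0, 3.0], dtype=torch.float64)
other = torch.tensor([8.0, 5.0, 2.0], dtype=torch.float64)    # other samples
delta = torch.tensor([0.2, -0.1, 0.0], dtype=torch.float64)   # additive shifts
factor = torch.tensor([1.1, 0.9, 1.0], dtype=torch.float64)   # multiplicative modifiers
signal = torch.tensor([4.0, 2.0, 1.0], dtype=torch.float64, requires_grad=True)

expected = (signal + delta) * factor + other
nll = (expected - obs * torch.log(expected)).sum()
nll.backward()

analytic = (1.0 - obs / expected.detach()) * factor
assert torch.allclose(signal.grad, analytic)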
Practical Notes (correctness)
- Zeroed gradient buffer (CUDA): the kernel accumulates via atomicAdd, so the gradient output tensor must be initialized to zeros. The Python wrapper does torch.zeros_like(signal).
- CUDA stream synchronization: PyTorch and NextStat may use different CUDA streams. The wrappers call torch.cuda.synchronize() before and after the native call.
- Multi-channel signal layout: if the signal sample appears in multiple channels, the external signal buffer is a concatenation: [ch0_bins..., ch1_bins..., ...]. Use session.signal_n_bins() for the total count, as in the snippet below.
# Multi-channel: concatenate per-channel histograms in model order
signal = torch.cat([signal_sr, signal_vr], dim=0).double().cuda()
loss = loss_fn(signal)

Metal (Apple Silicon)
create_profiled_session(..., device="auto") prefers CUDA and falls back to Metal (requires --features metal).
- GPU computation in f32 (Apple GPU precision); inputs/outputs converted at API boundary
- Signal uploaded from CPU (no raw pointer interop with MPS tensors)
- L-BFGS-B tolerance relaxed to 1e-3 (vs 1e-5 on CUDA)
Validation and Evidence
- CUDA zero-copy NLL + signal gradient (∂NLL/∂s): tests/python/test_torch_layer.py
- Profiled q₀/qμ envelope gradients: tests/python/test_differentiable_profiled_q0.py
- Finite-difference agreement: max FD error 2.07e-9 on benchmark fixtures
Architecture Decisions
Why a separate CUDA kernel? The existing batch_nll_grad.cu is optimized for batch toy fitting (1 block = 1 toy). The differentiable kernel has different requirements: single model evaluation, external signal pointer, signal gradient output. Keeping them separate avoids branching in the hot path.
Why zero-copy? Traditional approach: PyTorch GPU → CPU → Rust → GPU → CPU → PyTorch GPU (4 PCIe transfers). Zero-copy: the kernel reads signal bins and writes gradients directly in GPU memory. The only H→D transfer is the small nuisance parameter vector (~250 doubles = 2KB).
Phase 2 (Future): Implicit Differentiation
For derived profiled metrics (interpolated upper limits μ_up(α), full CLs bands), you need implicit differentiation through the solver/root-finding layer. The simple envelope-theorem gradient suffices for q₀ and qμ but not for all downstream quantities.
dq/ds = dq/ds|_{θ fixed} - (d²NLL/(ds dθ))ᵀ (d²NLL/dθ²)⁻¹ dq/dθ|_{s fixed}

This requires the cross-Hessian d²NLL/(ds dθ), which can be computed via finite differences of the GPU gradient w.r.t. signal bins, as sketched below.
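A hypothetical sketch of that finite-difference cross-Hessian, assuming a grad_signal(params) helper that returns ∂NLL/∂s at fixed nuisance parameters (for example, a wrapper around session.nll_grad_signal):
import torch

# H[i, j] ≈ d²NLL / (ds_i dθ_j) by central differences of the signal gradient.
def cross_hessian(grad_signal, params: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    n_bins = grad_signal(params).numel()
    H = torch.zeros(n_bins, params.numel(), dtype=torch.float64)
    for j in range(params.numel()):
        p_hi, p_lo = params.clone(), params.clone()
        p_hi[j] += eps
        p_lo[j] -= eps
        H[:, j] = (grad_signal(p_hi) - grad_signal(p_lo)) / (2.0 * eps)
    return H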
Competitive Landscape
| Project | Status | Limitations |
|---|---|---|
| pyhf #882 | Open, not implemented | Pure Python, SLSQP not differentiable |
| neos (arXiv:2203.05570) | Proof of concept (2022) | Slow, no GPU batching, simple models only |
| gradhep/relaxed | Utility library | Not a pipeline, only soft histogram ops |
| arXiv:2508.17802 | Scikit-HEP+JAX, 2025 | JAX-only, JIT tracer leaks |
| NextStat | Production | First PyTorch-native, CUDA zero-copy |
Requirements
- PyTorch (optional dependency, imported lazily)
- NVIDIA GPU with CUDA for zero-copy path; Apple Silicon (Metal) as fallback
- CPU tensors supported, but they incur host↔device copies (no zero-copy)
