API Reference
Sampling
WarmupHMC.adaptive_warmup_mcmc Function
```julia
adaptive_warmup_mcmc(rng, lpdf; kwargs...)
adaptive_warmup_mcmc(rngs::AbstractArray, lpdf_or_lpdfs; parallel=true, kwargs...)
```

Run windowed adaptive NUTS warm-up + sampling against the LogDensityProblems-compatible lpdf, returning a NamedTuple of posterior positions/gradients plus diagnostics. The multi-chain dispatch broadcasts over rngs and (optionally) over per-chain log densities.
The warm-up procedure is windowed and inspired by Stan's and nutpie's warm-up procedures, but differs in several important ways:
- Initializes via Pathfinder (an L-BFGS-based variational approximation).
- Warm-up windows target a number of gradient evaluations rather than MCMC transitions (default 1000, doubled after every window).
- Uses positions and gradients (like nutpie), plus the intermediate positions and gradients visited during NUTS tree traversal (selected pseudo-randomly, and only if the Hamiltonian error is small enough). Up to recording_target intermediate states are kept.
- Learns three candidate linear transformations in parallel at the end of every warm-up window:
  - Pathfinder's initial transformation plus an updated diagonal scaling,
  - a standard diagonal "mass matrix",
  - a novel, adaptive sequence of Householder reflections followed by a diagonal scaling.
- Selection minimises loss(p', g') = sum(abs2(log(std(p') * std(g')))) on the transformed intermediate positions/gradients; this loss is zero for an uncorrelated Normal target.
- Adapts the step size only for the first stepsize_adaptation_limit transitions per window (default 50), then freezes it and treats subsequent transitions as posterior samples.
- Stops warm-up adaptively: if the marginal-scale condition number drops below variance_cond_target (default 2.0), no new window is started.
If nonlinear_adapt=true (the default) and lpdf wraps a ReparametrizedProblem, the active IndexedReparametrization is optimised at the end of every warm-up window, and posterior samples are transformed back to the original parametrization before returning.
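The single- and multi-chain calls look like this (a minimal sketch: lpdf stands for any LogDensityProblems-compatible log density, and the RNG choices and keyword values are illustrative):

```julia
using WarmupHMC, Random

rng = Xoshiro(1)
result = adaptive_warmup_mcmc(rng, lpdf; n_draws = 1000)
result.posterior_position    # posterior draws
result.n_divergent_samples   # divergence diagnostic

# Multi-chain: one RNG per chain; chains run on Threads.@threads.
results = adaptive_warmup_mcmc([Xoshiro(i) for i in 1:4], lpdf; parallel = true)
```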
init kwarg
init controls per-chain initialization:
- missing (default) — random Uniform(-2, +2) start, then Pathfinder.
- a Real — random Uniform(-init, +init) start, then Pathfinder.
- a Distribution — sample a start from it, then Pathfinder.
- an AbstractVector — use as the unconstrained starting position, then Pathfinder.
- a PathfinderResult — take its first draw and skip running Pathfinder.
- a NamedTuple — interpret as a pre-built initialization (position, position_and_gradient, scale, squared_scale); skips Pathfinder entirely.
For the multi-chain method, pass either a scalar to broadcast or a Vector of length length(rngs) for per-chain initial values. There is no separate initial_params kwarg.
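The accepted init forms can be sketched as follows (assumptions: lpdf is a d-dimensional LogDensityProblems-compatible target, and MvNormal comes from Distributions.jl; the concrete values are illustrative):

```julia
using WarmupHMC, Distributions, LinearAlgebra, Random

rng, d = Xoshiro(1), 10
adaptive_warmup_mcmc(rng, lpdf; init = 2.5)                    # Uniform(-2.5, +2.5) start
adaptive_warmup_mcmc(rng, lpdf; init = MvNormal(zeros(d), I))  # sample the start from a Distribution
adaptive_warmup_mcmc(rng, lpdf; init = zeros(d))               # explicit unconstrained start

# Multi-chain: a scalar broadcasts, or pass one value per chain.
rngs = [Xoshiro(i) for i in 1:4]
adaptive_warmup_mcmc(rngs, lpdf; init = [randn(rng, d) for _ in rngs])
```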
Selected keyword arguments
- n_draws=1000 — number of posterior draws to collect.
- n_evaluations=1000 — gradient-evaluation budget for the first window; doubled each subsequent window.
- recording_target=1000 — maximum number of intermediate positions/gradients to keep.
- stepsize_adaptation_limit=50 — per-window cap on step-size adaptation transitions.
- target_acceptance_rate=0.8, max_tree_depth=10 — standard NUTS knobs.
- nonlinear_adapt=true — whether to activate the reparametrization hooks (a no-op when lpdf carries no reparametrization).
- variance_cond_target=2.0 — restart threshold on the marginal-scale condition number.
- progress=nothing, description="MCMC", monitor_ess — progress and diagnostic reporting via Treebars.
- parallel=true (multi-chain only) — run chains on Threads.@threads.
Returns
For the single-chain method, a NamedTuple with fields including initial_position, halo_position, halo_gradient, posterior_position, posterior_gradient, ess, scale_options, active_transformation, stepsize, total_evaluation_counter, n_divergent_samples, position_and_gradient, scale_changes. For the multi-chain method, a Vector of such NamedTuples.
Reparametrizations
WarmupHMC.ReparametrizedProblem Type
```julia
ReparametrizedProblem(reparametrizer, problem, ad_backend=nothing)
```

Wrap a LogDensityProblems-compatible problem with a nonlinear reparametrization. The reparametrizer (typically an IndexedReparametrization) transforms the parameter vector before evaluating the log density, adding the log-Jacobian correction.
Gradients are computed by differentiating only through the reparametrization transform (using ad_backend, e.g. AutoMooncake() from DifferentiationInterface.jl), while reusing the inner problem's native gradient. This allows use with FFI-based backends like BridgeStan.
Example
```julia
using WarmupHMC, DifferentiationInterface, Mooncake

# Example indices: reparametrize dims in param_indices, with the location
# at x[loc_idx] and the log-scale at x[scale_idx].
param_indices, loc_idx, scale_idx = 1:8, 9, 10
ir = IndexedReparametrization([
    i => Reparametrization(PartiallyCentered(1.0), PartiallyCentered(1.0),
                           x -> x[loc_idx], x -> x[scale_idx])
    for i in param_indices
])
rp = ReparametrizedProblem(ir, my_problem, AutoMooncake())
result = adaptive_warmup_mcmc(rng, rp)
```

WarmupHMC.IndexedReparametrization Type
```julia
IndexedReparametrization(pairs)
```

Maps dimension indices to Reparametrization objects. This is the main container passed to ReparametrizedProblem.
pairs is a vector of idx => Reparametrization(...) entries. Dimensions not listed are passed through unchanged.
Example
```julia
# Eight schools: reparametrize dims 1:8 with shared location (dim 9) and scale (dim 10)
ir = IndexedReparametrization(
    1:8 .=> Ref(Reparametrization(
        PartiallyCentered(1.0), PartiallyCentered(1.0),
        x -> x[9], x -> x[10]
    ))
)
```

WarmupHMC.PartiallyCentered Type

```julia
PartiallyCentered(c)
```

Centering parameter for hierarchical reparametrization, where c ∈ [0, 1]: c = 0 is fully non-centered, c = 1 is fully centered.
For a parameter x with location loc and log-scale log_scale, the transform from PartiallyCentered(source) to PartiallyCentered(target) is:
```julia
y = target * loc + (x - source * loc) * exp(log_scale * (target - source))
```

with log-Jacobian log_scale * (target - source).
During warmup, the optimizer tries multiple candidate centering values and picks the one minimizing a correlation-based loss.
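As a quick numeric check of the transform above (values chosen purely for illustration):

```julia
# Move a centered draw (source = 1) to the non-centered parametrization (target = 0).
loc, log_scale = 2.0, log(3.0)
x, source, target = 5.0, 1.0, 0.0
y = target * loc + (x - source * loc) * exp(log_scale * (target - source))
# y = (5 - 2) / 3 = 1.0, and the log-Jacobian is log_scale * (target - source) = -log(3).
```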
WarmupHMC.Reparametrization Type
```julia
Reparametrization(target, source, args...)
```

Maps between two PartiallyCentered parametrizations. The args are either constant values or functions x -> x[i] that extract the location and log-scale from the full parameter vector.
Example
```julia
# Reparametrize dimensions 1:8, with location at x[9] and log-scale at x[10]
Reparametrization(PartiallyCentered(1.0), PartiallyCentered(1.0), x -> x[9], x -> x[10])
```

During adaptation, the source centering is updated to minimize a loss function while target stays fixed.