NUTS.jl
A non-allocating NUTS implementation. Faster than and equivalent to Stan’s default implementation, DynamicHMC.jl’s implementation, and AdvancedHMC.jl’s `HMCKernel(Trajectory{MultinomialTS}(Leapfrog(stepsize), StrictGeneralisedNoUTurn()))`.
For a 100-dimensional standard normal target with unit stepsize and 1k samples, I measure it to be ~5x slower than direct sampling (`randn!(...)`), ~6x faster than DynamicHMC, ~15x faster than AdvancedHMC, and ~25x faster than Stan.jl. For most other posteriors the computational cost will be dominated by the cost of evaluating the log density gradient, so any real-world speed-ups should be smaller.
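To make the baseline concrete, here is a minimal sketch of what "direct sampling" means in this comparison, assuming the 100-dimensional, 1k-sample setup described above (the RNG seed and array shapes are illustrative, not taken from the benchmark code):

```julia
using Random

# Illustrative setup matching the description above: 1k i.i.d. draws from a
# 100-dimensional standard normal, stored one draw per column.
rng = Xoshiro(1)
samples = Matrix{Float64}(undef, 100, 1_000)

# "Direct sampling": fill the sample matrix in place with randn!.
randn!(rng, samples)
```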
Usage
Exports a single function, `nuts!!(state)`. Use it e.g. as
```julia
nuts_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    # Bundle everything nuts!! needs into a NamedTuple state.
    state = (; rng, posterior, stepsize, position)
    for i in 1:n_samples
        # Advance the chain by one NUTS transition and store the new position.
        state = nuts!!(state)
        samples[:, i] .= state.position
    end
    state
end
```
where `posterior` has to implement `log_density = NUTS.log_density_gradient!(posterior, position, log_density_gradient)`, i.e. it returns the log density and writes its gradient into `log_density_gradient`.
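For concreteness, here is a minimal sketch of such a `posterior` for a standard normal target, together with a call to the `nuts_sample!` helper from above. The `StandardNormalPosterior` struct, the RNG seed, and the array sizes are illustrative assumptions; only `NUTS.log_density_gradient!` and `nuts!!` come from the package.

```julia
using NUTS, Random

# Illustrative target: a standard normal "posterior". The struct is a stand-in,
# not part of the NUTS.jl API.
struct StandardNormalPosterior end

# Returns the log density at `position` and writes its gradient into the
# preallocated `log_density_gradient` buffer, as required above.
function NUTS.log_density_gradient!(::StandardNormalPosterior, position, log_density_gradient)
    log_density_gradient .= .-position    # ∇ log p(x) = -x
    return -sum(abs2, position) / 2       # log p(x) up to an additive constant
end

# Draw 1000 samples from a 100-dimensional standard normal with unit stepsize.
samples = Matrix{Float64}(undef, 100, 1_000)
nuts_sample!(samples, Xoshiro(1), StandardNormalPosterior(); stepsize=1.0)
```

Returning the log density while writing the gradient into a preallocated buffer fits the package's non-allocating design.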
Benchmark
Standard normal
Benchmarking and validating the implementation using 100 chains sampling from a 100-dimensional standard normal distribution with unit stepsize. See the code for benchmark details, either at https://github.com/nsiccha/NUTS.jl/blob/main/docs/index.qmd or on this page via the menu in the top right corner of the text body.

| Row | f | time_mean | n_leapfrog_mean | mean_err_mean | var_err_mean | q_err_mean |
|---|---|---|---|---|---|---|
| 1 | iid_sample! | 0.00133044 | NaN | 0.317203 | 0.448968 | 0.00325363 |
| 2 | nuts_sample! | 0.0023915 | 3.0 | 0.297439 | 1.39767 | 0.0118994 |
| 3 | dynamichmc_sample! | 0.0124168 | 3.0 | 0.297527 | 1.41391 | 0.0122698 |
| 4 | advancedhmc_sample! | 0.0303231 | 3.0 | 0.293058 | 1.39392 | 0.0119313 |
| 5 | stan_sample! | 0.054984 | NaN | 0.297357 | 1.4106 | 0.0121138 |
Non-standard normal
Benchmarking and validating the implementation using 100 chains sampling from a 100-dimensional non-standard normal distribution with unit stepsize. See the code for benchmark details, either at https://github.com/nsiccha/NUTS.jl/blob/main/docs/index.qmd or on this page via the menu in the top right corner of the text body.
| Row | f | time_mean | n_leapfrog_mean | mean_err_mean | var_err_mean | q_err_mean |
|---|---|---|---|---|---|---|
| 1 | iid_sample! | 0.00145097 | NaN | 0.317203 | 0.448968 | 0.00325363 |
| 2 | nuts_sample! | 0.00440357 | 3.0 | 0.29371 | 1.42862 | 0.0124545 |
| 3 | dynamichmc_sample! | 0.0142346 | 3.0 | 0.296059 | 1.38725 | 0.0119488 |