NUTS.jl


Author: Nikolas Siccha

A non-allocating NUTS implementation. It is equivalent to, but faster than, Stan’s default implementation, DynamicHMC.jl’s implementation, and AdvancedHMC.jl’s HMCKernel(Trajectory{MultinomialTS}(Leapfrog(stepsize), StrictGeneralisedNoUTurn())).

For a 100-dimensional standard normal target with unit stepsize and 1k samples, I measure it to be ~5x slower than direct sampling (randn!(...)), ~6x faster than DynamicHMC, ~15x faster than AdvancedHMC, and ~25x faster than Stan.jl. For most other posteriors the computational cost will be dominated by evaluating the log density gradient, so any real-world speed-ups should be smaller.

Usage

The package exports a single function, nuts!!(state). Use it e.g. as

nuts_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    state = (;rng, posterior, stepsize, position)
    for i in 1:n_samples
        state = nuts!!(state)
        samples[:, i] .= state.position
    end
    state
end

where posterior has to implement log_density = NUTS.log_density_gradient!(posterior, position, log_density_gradient), i.e. a method that returns the log density at position and writes its gradient into log_density_gradient.
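
For example, a standard normal target (up to an additive constant in the log density) could be implemented along these lines; this is a minimal sketch, and the StdNormal type name is made up for illustration:

struct StdNormal end
NUTS.log_density_gradient!(::StdNormal, position, log_density_gradient) = begin
    @. log_density_gradient = -position   # write the gradient of -0.5 * sum(abs2, position) in place
    -0.5 * sum(abs2, position)            # return the log density up to a constant
end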

Benchmark

Standard normal

Benchmarking and validating the implementation using 100 chains sampling from a 100-dimensional standard normal distribution with unit stepsize. See the benchmark code either at https://github.com/nsiccha/NUTS.jl/blob/main/docs/index.qmd or on this page via the menu in the top right corner of the text body.

5×6 DataFrame
 Row │ f                    time_mean   n_leapfrog_mean  mean_err_mean  var_err_mean  q_err_mean
     │ Function             Float64     Float64          Float64        Float64       Float64
─────┼───────────────────────────────────────────────────────────────────────────────────────────
   1 │ iid_sample!          0.00133044  NaN              0.317203       0.448968      0.00325363
   2 │ nuts_sample!         0.0023915   3.0              0.297439       1.39767       0.0118994
   3 │ dynamichmc_sample!   0.0124168   3.0              0.297527       1.41391       0.0122698
   4 │ advancedhmc_sample!  0.0303231   3.0              0.293058       1.39392       0.0119313
   5 │ stan_sample!         0.054984    NaN              0.297357       1.4106        0.0121138

Non-standard normal

Benchmarking and validating the implementation using 100 chains sampling from a 100-dimensional non-standard normal distribution with unit stepsize. See the benchmark code either at https://github.com/nsiccha/NUTS.jl/blob/main/docs/index.qmd or on this page via the menu in the top right corner of the text body.

3×6 DataFrame
 Row │ f                   time_mean   n_leapfrog_mean  mean_err_mean  var_err_mean  q_err_mean
     │ Function            Float64     Float64          Float64        Float64       Float64
─────┼─────────────────────────────────────────────────────────────────────────────────────────
   1 │ iid_sample!         0.00145097  NaN              0.317203       0.448968      0.00325363
   2 │ nuts_sample!        0.00440357  3.0              0.29371        1.42862       0.0124545
   3 │ dynamichmc_sample!  0.0142346   3.0              0.296059       1.38725       0.0119488
Source Code
{{< include ../README.md >}}

## Benchmark
  
### Standard normal

Benchmarking and validating the implementation using 100 chains sampling from a 100-dimensional standard normal distribution with unit stepsize. See the benchmark code either at [https://github.com/nsiccha/NUTS.jl/blob/main/docs/index.qmd](https://github.com/nsiccha/NUTS.jl/blob/main/docs/index.qmd) or on this page via the menu in the top right corner of the text body.

```{julia}
using DynamicHMC, NUTS, Random, Distributions, LinearAlgebra, LogExpFunctions, Chairmarks, LogDensityProblems, AdvancedHMC, Plots, DataFrames, StatsBase
ENV["CMDSTAN"] = "/home/niko/.cmdstan/cmdstan-2.34.0"  
using Stan
    
sm = (@isdefined sm) ? sm : Stan.SampleModel("Normal", """
data{int n;}
parameters{vector[n] x;}
model{x ~ std_normal();}
""")
struct MatrixFactorization{M1,M2} <: AbstractMatrix{Float64}
    m1::M1
    m2::M2
end
Base.size(M::MatrixFactorization) = (size(M.m1, 1), size(M.m2, 2))
Base.size(M::MatrixFactorization, i) = size(M)[i]
LinearAlgebra.lmul!(M::MatrixFactorization, x::AbstractVector) = begin 
    lmul!(M.m1, lmul!(M.m2, x))
end
LinearAlgebra.mul!(y::AbstractVector, M::MatrixFactorization, x::AbstractVector) = begin 
    lmul!(M.m1, mul!(y, M.m2, x))
end
LinearAlgebra.ldiv!(M::MatrixFactorization, x::AbstractVector) = begin 
    ldiv!(M.m2, ldiv!(M.m1, x))
end
LinearAlgebra.ldiv!(M::MatrixFactorization, x::AbstractMatrix) = begin
    for xi in eachcol(x)
        ldiv!(M, xi)
    end
    x
end
Base.adjoint(M::MatrixFactorization) = MatrixFactorization(M.m2', M.m1')
Base.inv(M::MatrixFactorization) = MatrixFactorization(inv(M.m2), inv(M.m1))
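# Householder reflection I - 2vv' for a unit vector v: orthogonal, self-adjoint and self-inverse.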
struct Reflection{V}
    v::V
end
Base.size(M::Reflection) = (length(M.v), length(M.v))
Base.size(M::Reflection, i) = length(M.v)
LinearAlgebra.lmul!(M::Reflection, x::AbstractVector) = x .-= 2 .* M.v .* dot(M.v, x)
LinearAlgebra.mul!(y::AbstractVector, M::Reflection, x::AbstractVector) = begin 
    y .= x .- 2 .* M.v .* dot(M.v, x)
end
LinearAlgebra.ldiv!(M::Reflection, x::AbstractVector) = lmul!(M, x)
Base.adjoint(M::Reflection) = M
Base.inv(M::Reflection) = M

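# Extend NUTS.square (scale M -> M*M') for the scale types used below.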
NUTS.square(M::UniformScaling) = M * M
NUTS.square(M::MatrixFactorization) = MatrixFactorization(M, M')
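# Normal distribution parameterised by location, scale (covariance = scale*scale') and a preallocated cache.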
struct PreparedNormal{P,S,C}
    location::P
    scale::S
    cache::C
end
PreparedNormal(location, scale) = PreparedNormal(location, scale, zero(location))
StandardNormal{P,C} = PreparedNormal{P,UniformScaling{Bool},C}
whiten!(samples, d::PreparedNormal) = begin
    samples .-= d.location
    ldiv!(d.scale, samples)
end
NUTS.log_density_gradient!(d::PreparedNormal, x::AbstractVector, g::AbstractVector) = begin 
    @. g = d.location - x
    ldiv!(d.scale, g)
    rv = -.5 * sum(abs2, g)
    ldiv!(d.scale', g)
    rv
end
NUTS.log_density_gradient!(d::StandardNormal, x::AbstractVector, g::AbstractVector) = begin 
    @. g = -x
    -.5 * dot(x,x)
end

Random.rand!(rng::AbstractRNG, d::PreparedNormal, x::AbstractVector) = begin 
    lmul!(d.scale, randn!(rng, x))
    @. x += d.location
end
StatsBase.cov(d::PreparedNormal) = if isa(d.scale, UniformScaling)
    Diagonal(fill(d.scale.λ ^ 2, length(d.location)))
else
    NUTS.square(d.scale)
end
scale(d::PreparedNormal) = if isa(d.scale, UniformScaling)
    Diagonal(fill(d.scale.λ, length(d.location)))
else
    d.scale
end
LogDensityProblems.capabilities(::PreparedNormal) = LogDensityProblems.LogDensityOrder{1}()
LogDensityProblems.dimension(d::PreparedNormal) = length(d.location)
LogDensityProblems.logdensity_and_gradient(d::PreparedNormal, x::AbstractVector) = begin 
    g = zero(x)
    NUTS.log_density_gradient!(d, x, g), g
end
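# Validation metrics on whitened samples: norm of per-dimension means,
# norm of per-dimension deviations of the variance from 1, and mean squared
# quantile error against the standard normal.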
q_err(x::AbstractVector) = mean((sort(x) .- quantile.(Normal(), range(0, 1, length(x)+2)[2:end-1])).^2)
q_err(x::AbstractMatrix) = mean(q_err, eachrow(x))
errs(x::AbstractMatrix) = (;mean_err=norm(mean(x; dims=2)), var_err=norm(1 .- var(x; dims=2)), q_err=q_err(x))
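# Reference sampler: direct i.i.d. draws from the target.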
iid_sample!(samples, rng, posterior; n_samples=size(samples, 2), kwargs...) = begin 
    time = @elapsed for i in 1:n_samples
        rand!(rng, posterior, view(samples, :, i))
    end
    (;time, n_leapfrog=NaN, errs(whiten!(samples, posterior))...)
end
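# NUTS.jl: the state is a NamedTuple of rng, posterior, stepsize, position and scale;
# nuts!! returns the updated state, including the number of leapfrog steps taken.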
nuts_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    state = (;rng, posterior, stepsize, position, posterior.scale)
    n_leapfrog = 0
    time = @elapsed for i in 1:n_samples
        state = nuts!!(state)
        samples[:, i] .= state.position
        n_leapfrog += state.n_leapfrog
    end
    (;time, n_leapfrog=n_leapfrog / n_samples, errs(whiten!(samples, posterior))...)
end
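# DynamicHMC.jl with a fixed stepsize and the target's covariance as Gaussian kinetic energy.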
dynamichmc_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    algorithm = DynamicHMC.NUTS()
    H = DynamicHMC.Hamiltonian(DynamicHMC.GaussianKineticEnergy(cov(posterior), inv(scale(posterior)')), posterior)
    Q = DynamicHMC.evaluate_ℓ(posterior, position; strict=true)
    n_leapfrog = 0
    time = @elapsed for i in 1:n_samples
        Q, stats = DynamicHMC.sample_tree(rng, algorithm, H, Q, stepsize)
        samples[:, i] .= Q.q
        n_leapfrog += stats.steps
    end
    (;time, n_leapfrog=n_leapfrog / n_samples, errs(whiten!(samples, posterior))...)
end
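# AdvancedHMC.jl with a unit Euclidean metric and the multinomial/strict-no-U-turn kernel.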
advancedhmc_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    h = AdvancedHMC.Hamiltonian(UnitEuclideanMetric(d), posterior)
    kernel = HMCKernel(Trajectory{MultinomialTS}(Leapfrog(stepsize), StrictGeneralisedNoUTurn()))
    z = AdvancedHMC.phasepoint(rng, position, h) 
    n_leapfrog = 0
    time = @elapsed for i in 1:n_samples
        (;stat, z) = AdvancedHMC.transition(rng, h, kernel, z)
        samples[:, i] .= z.θ
        n_leapfrog += stat.n_steps
    end
    (;time, n_leapfrog=n_leapfrog / n_samples, errs(whiten!(samples, posterior))...)
end
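# Stan via CmdStan: adaptation disabled (engaged=false), no warmup, fixed stepsize.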
stan_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin 
    data = Dict("n"=>length(position))
    time = @elapsed Stan.stan_sample(sm; data, num_chains=1, num_samples=n_samples, num_warmups=0, engaged=false, stepsize, init=Dict("x"=>randn(d)));
    df = Stan.read_samples(sm, :dataframe)
    (;time, n_leapfrog=NaN, errs(Matrix(df)')...)
end
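# Benchmark setup: 100-dimensional target, unit stepsize, 1000 draws per chain, 100 seeded chains.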
d = 100
stepsize = 1.
n_samples = 1_000
samples = zeros((d, n_samples))
seeds = 1:100;
```
```{julia}
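# Standard normal target: zero location, identity scale.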
posterior = PreparedNormal(zeros(d), I)
fs = (iid_sample!, nuts_sample!, dynamichmc_sample!, advancedhmc_sample!, stan_sample!)
df = DataFrame([
    merge((;f, seed), f(samples, Xoshiro(seed), posterior; stepsize))
    for f in fs
    for seed in seeds
])
combine(groupby(df, :f), :time=>mean, :n_leapfrog=>mean, :mean_err=>mean, :var_err=>mean, :q_err=>mean)
```
```{julia}
plot(
    plot(xscale=:log10, [
        sort(df[df.f .== f, :time]) for f in fs
    ], seeds, label=permutedims(collect(fs)), title="Runtime"), 
    plot([
        sort(df[df.f .== f, :mean_err]) for f in fs
    ], seeds, label=permutedims(collect(fs)), title="Mean err"),
    plot([
        sort(df[df.f .== f, :var_err]) for f in fs
    ], seeds, label=permutedims(collect(fs)), title="Var err"), 
    plot([
        sort(df[df.f .== f, :q_err]) for f in fs
    ], seeds, label=permutedims(collect(fs)), title="Q err"),
    layout=(:,1), size=(800, 1200)
)
```

  
### Non-standard normal

Benchmarking and validating the implementation using 100 chains sampling from a 100-dimensional non-standard normal distribution with unit stepsize. See the benchmark code either at [https://github.com/nsiccha/NUTS.jl/blob/main/docs/index.qmd](https://github.com/nsiccha/NUTS.jl/blob/main/docs/index.qmd) or on this page via the menu in the top right corner of the text body.
```{julia}
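# Non-standard normal target: random location, scale = (random Householder reflection) * (diagonal with log-normal entries).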
posterior = PreparedNormal(randn(d), MatrixFactorization(Reflection(normalize!(randn(d))), Diagonal(exp.(randn(d)))))
fs = (iid_sample!, nuts_sample!, dynamichmc_sample!)#, advancedhmc_sample!)
df = DataFrame([
    merge((;f, seed), f(samples, Xoshiro(seed), posterior; stepsize))
    for f in fs
    for seed in seeds
])
combine(groupby(df, :f), :time=>mean, :n_leapfrog=>mean, :mean_err=>mean, :var_err=>mean, :q_err=>mean)
```
```{julia}
plot(
    plot(xscale=:log10, [
        sort(df[df.f .== f, :time]) for f in fs
    ], seeds, label=permutedims(collect(fs)), title="Runtime"), 
    plot([
        sort(df[df.f .== f, :mean_err]) for f in fs
    ], seeds, label=permutedims(collect(fs)), title="Mean err"),
    plot([
        sort(df[df.f .== f, :var_err]) for f in fs
    ], seeds, label=permutedims(collect(fs)), title="Var err"), 
    plot([
        sort(df[df.f .== f, :q_err]) for f in fs
    ], seeds, label=permutedims(collect(fs)), title="Q err"),
    layout=(:,1), size=(800, 1200)
)
```