Early NUTS

Author

Nikolas Siccha

Published

November 15, 2022

# Notebook setup: switch to the project directory, activate its environment, make sure
# every package used below is installed, then load the packages.
# NOTE(review): the path is absolute and machine-specific — this cell only runs as-is
# where that directory exists.
cd("/home/niko/github/earlynuts/quarto")
using Pkg
# Activate the environment defined in the current directory.
Pkg.activate(".")
# Each `Pkg.add` resolves the environment again (see the repeated "Resolving package
# versions..." lines in the captured output below).
Pkg.add("Revise")
Pkg.add("Plots")
Pkg.add("AdvancedHMC")
Pkg.add("ForwardDiff")
Pkg.add("LinearAlgebra")
Pkg.add("Statistics")
using Revise
using Plots
using AdvancedHMC
using ForwardDiff
using LinearAlgebra
using Statistics
#
  Activating project at `~/github/earlynuts/quarto`
    Updating registry at `~/.julia/registries/General.toml`
   Resolving package versions...
  No Changes to `~/github/earlynuts/quarto/Project.toml`
  No Changes to `~/github/earlynuts/quarto/Manifest.toml`
   Resolving package versions...
  No Changes to `~/github/earlynuts/quarto/Project.toml`
  No Changes to `~/github/earlynuts/quarto/Manifest.toml`
   Resolving package versions...
  No Changes to `~/github/earlynuts/quarto/Project.toml`
  No Changes to `~/github/earlynuts/quarto/Manifest.toml`
   Resolving package versions...
  No Changes to `~/github/earlynuts/quarto/Project.toml`
  No Changes to `~/github/earlynuts/quarto/Manifest.toml`
   Resolving package versions...
  No Changes to `~/github/earlynuts/quarto/Project.toml`
  No Changes to `~/github/earlynuts/quarto/Manifest.toml`
   Resolving package versions...
  No Changes to `~/github/earlynuts/quarto/Project.toml`
  No Changes to `~/github/earlynuts/quarto/Manifest.toml`
"""
    EarlyGeneralisedNoUTurn(; max_depth=10, Δ_max=1000.0, termination_threshold=0.0)

Dynamic termination criterion whose `termination_threshold` is handed to
`early_generalised_uturn_criterion` by the `AdvancedHMC.isterminated` method defined
for this type in this file.

# Fields
- `max_depth::Int`: not read anywhere in this file; presumably the maximum tree depth
  consumed by AdvancedHMC's tree builder — confirm against AdvancedHMC.
- `Δ_max::F`: not read anywhere in this file; presumably AdvancedHMC's energy-error
  bound — confirm against AdvancedHMC.
- `termination_threshold::F`: bound compared against the normalised dot products in
  the termination check.
"""
Base.@kwdef struct EarlyGeneralisedNoUTurn{F<:AbstractFloat} <: AdvancedHMC.DynamicTerminationCriterion
    max_depth::Int=10
    Δ_max::F=1000.0
    termination_threshold::F=0.
end

# Promoting constructor. `Δ_max` and `termination_threshold` share the single type
# parameter `F`, so the default constructor only accepts two floats of the *same* type;
# an integer threshold such as `termination_threshold=0`, or mixed float widths, would
# otherwise raise a `MethodError`. Both values are converted to a common float type.
function EarlyGeneralisedNoUTurn(max_depth::Integer, Δ_max::Real, termination_threshold::Real)
    Δ, threshold = promote(float(Δ_max), float(termination_threshold))
    return EarlyGeneralisedNoUTurn{typeof(Δ)}(max_depth, Δ, threshold)
end
"""
    normalized_dot(x, y)

Return `dot(x, y)` divided by the product of the norms of `x` and `y`.
"""
function normalized_dot(x, y)
    scale = norm(x) * norm(y)
    return dot(x, y) / scale
end

"""
    early_generalised_uturn_criterion(threshold, rho, p_sharp_minus, p_sharp_plus)

Return `true` as soon as the normalised dot product of `rho` with either
`p_sharp_minus` or `p_sharp_plus` is less than or equal to `threshold`; `false`
otherwise. The `p_sharp_plus` product is only evaluated when the `p_sharp_minus`
check did not already succeed.
"""
function early_generalised_uturn_criterion(threshold, rho, p_sharp_minus, p_sharp_plus)
    normalized_dot(rho, p_sharp_minus) <= threshold && return true
    return normalized_dot(rho, p_sharp_plus) <= threshold
end

# Turn statistic for `EarlyGeneralisedNoUTurn`: built from the `r` field of the phase
# point (presumably its momentum — confirm against AdvancedHMC's `PhasePoint`).
# The criterion argument is used for dispatch only.
function AdvancedHMC.TurnStatistic(::EarlyGeneralisedNoUTurn, z::AdvancedHMC.PhasePoint)
    r = z.r
    return AdvancedHMC.TurnStatistic(r)
end

# Termination check for a whole tree `t`: applies `early_generalised_uturn_criterion`
# to the tree's `rho` statistic and to `∂H∂r` evaluated at the `r` of the left-most
# and right-most phase points, using the criterion's `termination_threshold`.
# The second `Termination` argument is always `false` here.
function AdvancedHMC.isterminated(
    egnut::EarlyGeneralisedNoUTurn,
    h::AdvancedHMC.Hamiltonian,
    t::AdvancedHMC.BinaryTree,
)
    threshold = egnut.termination_threshold
    rho = t.ts.rho
    p_sharp_left = AdvancedHMC.∂H∂r(h, t.zleft.r)
    p_sharp_right = AdvancedHMC.∂H∂r(h, t.zright.r)
    should_stop = early_generalised_uturn_criterion(
        threshold, rho, p_sharp_left, p_sharp_right
    )
    return AdvancedHMC.Termination(should_stop, false)
end
# Five-argument form: the two sub-tree arguments are ignored and the check is
# delegated to the three-argument method for `EarlyGeneralisedNoUTurn`.
AdvancedHMC.isterminated(
    tc::EarlyGeneralisedNoUTurn,
    h::AdvancedHMC.Hamiltonian,
    t::AdvancedHMC.BinaryTree,
    _tleft,
    _tright,
) = AdvancedHMC.isterminated(tc, h, t)
# value(x::Vector{Float64}) = x
# function value(x) 
#     ForwardDiff.value(x)
# end
"""
    mylogpdf(info, theta)

Return `-0.5 * sum(theta .^ 2)` and, as a side effect, append
`ForwardDiff.value.(theta)` to `info.all_samples`.
"""
function mylogpdf(info, theta)
    recorded = ForwardDiff.value.(theta)
    push!(info.all_samples, recorded)
    squared_norm = sum(theta .^ 2)
    return -0.5 * squared_norm
end
"""
    mylogpdf_and_gradient(info, theta)

Return the tuple `(-0.5 * sum(abs2, theta), -theta)` and, as a side effect, append a
snapshot of `theta` to `info.all_samples`.
"""
function mylogpdf_and_gradient(info, theta)
    # Store a copy, not the caller's array: if the caller later mutates or reuses
    # `theta`, the recorded evaluation points must not change with it.
    push!(info.all_samples, copy(theta))
    return -0.5 * sum(abs2, theta), -theta
end
"""
    run_experiment(; no_dimensions=2, time_step_size=1., no_samples=1,
                   termination_threshold=0)

Sample with `NUTS{MultinomialTS, EarlyGeneralisedNoUTurn}` using the log-density
callbacks `mylogpdf` / `mylogpdf_and_gradient`, starting from a `randn` draw.

# Keywords
- `no_dimensions`: length of the parameter vector.
- `time_step_size`: step size passed to `Leapfrog`.
- `no_samples`: number of samples requested from `sample`.
- `termination_threshold`: threshold forwarded to `EarlyGeneralisedNoUTurn`; any real
  number is accepted and converted to floating point.

# Returns
A named tuple with `no_dimensions`, `all_samples` (every point recorded by the
log-density callbacks, in call order), `samples` and `stats` (as returned by `sample`).
"""
function run_experiment(;
    no_dimensions=2, time_step_size=1., no_samples=1,
    termination_threshold=0
)
    # Typed container instead of `[]` so the recorded points are not stored as `Any`.
    info = (no_dimensions=no_dimensions, all_samples=Vector{Float64}[])
    initial_theta = randn(no_dimensions)
    metric = DiagEuclideanMetric(no_dimensions)
    hamiltonian = Hamiltonian(
        metric,
        theta -> mylogpdf(info, theta),
        theta -> mylogpdf_and_gradient(info, theta),
    )

    # Leapfrog integrator with the fixed, caller-supplied step size.
    integrator = Leapfrog(time_step_size)
    # The criterion's threshold field is typed `F<:AbstractFloat`; convert so that the
    # integer default (`0`) does not end up in a float-typed field.
    proposal = NUTS{MultinomialTS, EarlyGeneralisedNoUTurn}(
        integrator; termination_threshold=float(termination_threshold)
    )

    samples, stats = sample(
        hamiltonian, proposal, initial_theta, no_samples;
        progress=false, verbose=false
    )
    return merge(info, (samples=samples, stats=stats))
end

# vis_jump(x) = scatter(
#     x[1:end-1], x[2:end], 
#     alpha=.5, title=mean(abs.(x[1:end-1] - x[2:end]))
# )
# Histogram of jump sizes between consecutive entries of `x`, drawn on a fresh plot.
function vis_jump(x; kwargs...)
    return vis_jump!(plot(), x; kwargs...)
end

# Vector input: jump size is the absolute difference of neighbouring elements.
function vis_jump!(p, x::AbstractVector; kwargs...)
    jumps = abs.(x[1:end-1] - x[2:end])
    return histogram!(p, jumps, alpha=.5; kwargs...)
end

# Matrix input: jump size is the norm of the difference of neighbouring columns.
function vis_jump!(p, x::AbstractMatrix; kwargs...)
    displacements = x[:, 1:end-1] .- x[:, 2:end]
    jumps = norm.(eachcol(displacements))
    return histogram!(p, jumps, alpha=.5; kwargs...)
end
# vline!(
#     [mean(norm.(eachcol(x[:, 1:end-1] .- x[:, 2:end])))]
# )

"""
    vis_experiment(info)

Plot the result of one `run_experiment` call.

- With more than one entry in `info.samples`: two stacked histograms of jump sizes
  (via `vis_jump`), for the samples and for their element-wise squares, each overlaid
  with the same histogram for independent `randn` draws of the same shape
  (labels "NUTS" and "IND").
- Otherwise: a scatter/line plot of the first two coordinates of every point in
  `info.all_samples`, with the origin marked "mode", the first recorded point marked
  "init" and the last entry of `info.samples` marked "sample".
"""
function vis_experiment(info)
    if length(info.samples) > 1
        # `reduce(hcat, ...)` avoids splatting one argument per sample.
        samples = reduce(hcat, info.samples)
        rsamples = randn(size(samples))
        binsd = LinRange(0, 5 * sqrt(size(samples, 1)), 40)
        jump_panel = vis_jump!(
            vis_jump(samples, bins=binsd, label="NUTS"),
            rsamples, bins=binsd, label="IND"
        )
        squared_jump_panel = vis_jump!(
            vis_jump(samples .^ 2, bins=binsd, label="NUTS"),
            rsamples .^ 2, bins=binsd, label="IND"
        )
        return plot(jump_panel, squared_jump_panel, layout=(:, 1))
    end

    all_samples = hcat(info.all_samples...)
    p = scatter([0], [0], label="mode", color=:black)
    no_samples = size(all_samples, 2)
    # Recorded points are drawn in consecutive segments whose length doubles each
    # time (1, 2, 4, ...), alternating colours. Drawing starts at the third recorded
    # point; the first two are presumably evaluations at the initial position —
    # confirm against AdvancedHMC's call order.
    idx, start, width = 1, 3, 1
    while start <= no_samples
        send = min(start + width - 1, no_samples)
        # Draw on `p` explicitly rather than relying on the implicit current plot.
        plot!(
            p,
            all_samples[1, start:send],
            all_samples[2, start:send],
            marker=:circle, label="$(start-2):$(send-2)", alpha=.5,
            color=(:black, :green)[idx % 2 + 1]
        )
        start += width
        width *= 2
        idx += 1
    end

    scatter!(p, all_samples[1:1, 1], all_samples[2:2, 1], color=:red, label="init")
    scatter!(
        p, info.samples[end][1:1], info.samples[end][2:2],
        color=:blue, label="sample"
    )
    return p
end
"""
    vis_experiments(; time_step_sizes, termination_threshold=0, no_samples=1,
                    no_dimensions=2)

Run `run_experiment` once per entry of `time_step_sizes` (all other keywords are
forwarded unchanged), visualise each run with `vis_experiment`, and stack the
resulting panels vertically in one figure that is 800 wide and 400 high per panel.
"""
function vis_experiments(;
    time_step_sizes,
    termination_threshold=0,
    no_samples=1,
    no_dimensions=2
)
    panels = map(time_step_sizes) do time_step_size
        experiment = run_experiment(
            time_step_size=time_step_size,
            termination_threshold=termination_threshold,
            no_samples=no_samples,
            no_dimensions=no_dimensions
        )
        return vis_experiment(experiment)
    end
    figure_size = (800, length(time_step_sizes) * 400)
    return plot(
        panels...;
        layout=(:, 1), size=figure_size, legend=:outertopleft
    )
end

# Shared settings for the figures below: one panel per step size.
time_step_sizes=[.15, .2, .25]#2. .^ LinRange(-3, -2, 2)
no_dimensions = 10
no_samples = 10000
# NOTE(review): the bare literal below looks like the echoed value of `no_samples`
# captured from the rendered notebook output; it has no effect.
10000
# --- Runs with the default `no_samples=1` ---
# With a single requested sample `vis_experiment` is expected to take its trajectory
# branch (first two coordinates of every recorded point) — confirm that `sample`
# returns exactly one draw here.
# Threshold cos(pi/2): zero up to floating-point rounding.
vis_experiments(
    time_step_sizes=time_step_sizes, 
    termination_threshold=cos(pi/2),
    no_dimensions=no_dimensions
)

# Threshold cos(pi/3), i.e. about 0.5.
vis_experiments(
    time_step_sizes=time_step_sizes, 
    termination_threshold=cos(pi/3),
    no_dimensions=no_dimensions
)

# Threshold cos(pi/4), i.e. about 0.707. `no_dimensions` is not passed here, so the
# default of 2 applies instead of the 10 used in the other calls.
vis_experiments(
    time_step_sizes=time_step_sizes, 
    termination_threshold=cos(pi/4)
)

# --- Runs with `no_samples` = 10000 ---
# With more than one sample `vis_experiment` draws the jump-size histograms.
# Threshold cos(pi/2): zero up to floating-point rounding.
vis_experiments(
    time_step_sizes=time_step_sizes, 
    termination_threshold=cos(pi/2),
    no_samples=no_samples,
    no_dimensions=no_dimensions
)

# Threshold cos(pi/2 - pi/16), i.e. about 0.195.
vis_experiments(
    time_step_sizes=time_step_sizes, 
    termination_threshold=cos(pi/2-pi/16),
    no_samples=no_samples,
    no_dimensions=no_dimensions
)

# Threshold cos(pi/2 - pi/8), i.e. about 0.383.
vis_experiments(
    time_step_sizes=time_step_sizes,
    termination_threshold=cos(pi/2-pi/8),
    no_samples=no_samples,
    no_dimensions=no_dimensions
)

# Threshold cos(pi/2 - pi/4), i.e. about 0.707.
vis_experiments(
    time_step_sizes=time_step_sizes,
    termination_threshold=cos(pi/2-pi/4),
    no_samples=no_samples,
    no_dimensions=no_dimensions
)