# value(x::Vector{Float64}) = x
# function value(x)
# ForwardDiff.value(x)
# end
"""
    mylogpdf(info, theta)

Return the standard-normal log-density of `theta` up to an additive constant,
`-0.5 * sum(theta .^ 2)`.

Side effect: the primal values `ForwardDiff.value.(theta)` (a fresh array with any dual
parts stripped) are appended to `info.all_samples`, so every evaluation point is recorded.
"""
function mylogpdf(info, theta)
    primal = ForwardDiff.value.(theta)
    push!(info.all_samples, primal)
    squared_norm = sum(theta .^ 2)
    return -0.5 * squared_norm
end
"""
    mylogpdf_and_gradient(info, theta)

Return the tuple `(logp, grad)` for the standard-normal log-density up to an additive
constant: `logp = -0.5 * sum(theta .^ 2)` and its gradient `grad = -theta`.

Side effect: a copy of `theta` is appended to `info.all_samples`, so every evaluation
point is recorded.
"""
function mylogpdf_and_gradient(info, theta)
    # Record a copy, not `theta` itself. Storing the caller's array would alias it, and
    # any later in-place update would silently rewrite the recorded history.
    # `mylogpdf` likewise stores a freshly allocated array.
    push!(info.all_samples, copy(theta))
    return -0.5 * sum(theta .^ 2), -theta
end
"""
    run_experiment(; no_dimensions=2, time_step_size=1., no_samples=1, termination_threshold=0)

Run the NUTS sampler on the standard-normal target defined by `mylogpdf` and
`mylogpdf_and_gradient`, starting from a `randn` point in `no_dimensions` dimensions.

The setup uses:
- a `DiagEuclideanMetric`;
- a `Leapfrog` integrator whose step size is fixed at `time_step_size` (no step-size
  adaptation is configured here);
- a `NUTS{MultinomialTS, EarlyGeneralisedNoUTurn}` proposal, which receives
  `termination_threshold` unchanged.

Returns `(no_dimensions, all_samples, samples, stats)` as a NamedTuple:
- `all_samples` holds every point at which either density callback was evaluated;
- `samples` and `stats` are whatever `sample` returned.
"""
function run_experiment(;
    no_dimensions=2, time_step_size=1., no_samples=1,
    termination_threshold=0
)
    # Shared evaluation log; both density callbacks push into `info.all_samples`.
    info = (no_dimensions=no_dimensions, all_samples=[])
    initial_theta = randn(no_dimensions)

    logdensity = theta -> mylogpdf(info, theta)
    logdensity_and_gradient = theta -> mylogpdf_and_gradient(info, theta)
    hamiltonian = Hamiltonian(
        DiagEuclideanMetric(no_dimensions), logdensity, logdensity_and_gradient
    )

    # The leapfrog step size is used as given; nothing below tunes it.
    proposal = NUTS{MultinomialTS, EarlyGeneralisedNoUTurn}(
        Leapfrog(time_step_size);
        termination_threshold=termination_threshold
    )

    samples, stats = sample(
        hamiltonian, proposal, initial_theta, no_samples;
        progress=false, verbose=false
    )
    return merge(info, (samples=samples, stats=stats))
end
# vis_jump(x) = scatter(
# x[1:end-1], x[2:end],
# alpha=.5, title=mean(abs.(x[1:end-1] - x[2:end]))
# )
"""
    vis_jump(x; kwargs...)

Create a fresh plot and draw the jump-size histogram of `x` on it via `vis_jump!`.
"""
vis_jump(x; kwargs...) = vis_jump!(plot(), x; kwargs...)

"""
    vis_jump!(p, x::AbstractVector; kwargs...)
    vis_jump!(p, x::AbstractMatrix; kwargs...)

Add to plot `p` a histogram, drawn with `alpha=.5`, of the distances between consecutive
entries of `x`:
- for a vector, `abs(x[i+1] - x[i])`;
- for a matrix, the Euclidean norm of the difference between consecutive columns.

Extra `kwargs` (e.g. `bins`, `label`) are forwarded to `histogram!`.
"""
function vis_jump!(p, x::AbstractVector; kwargs...)
    jumps = abs.(diff(x))
    return histogram!(p, jumps, alpha=.5; kwargs...)
end

function vis_jump!(p, x::AbstractMatrix; kwargs...)
    column_steps = diff(x; dims=2)
    jumps = norm.(eachcol(column_steps))
    return histogram!(p, jumps, alpha=.5; kwargs...)
end
# vline!(
# [mean(norm.(eachcol(x[:, 1:end-1] .- x[:, 2:end])))]
# )
"""
    vis_experiment(info)

Plot the result of `run_experiment`. `info` must provide `samples` and `all_samples`,
each a collection of equal-length vectors.

- If `info.samples` holds more than one draw: two stacked panels of jump-size
  histograms (via `vis_jump` / `vis_jump!`). Each panel compares consecutive NUTS draws
  (`"NUTS"`) with an equally sized matrix of independent `randn` draws (`"IND"`). The
  first panel uses raw values, the second their elementwise squares.
- Otherwise: the first two coordinates of every recorded evaluation point in
  `info.all_samples`, drawn as segments of doubling length. The first recorded point is
  marked `"init"` and the last returned sample is marked `"sample"`.
"""
function vis_experiment(info)
    if length(info.samples) > 1
        return _vis_jump_panels(info)
    end
    return _vis_trajectory(info)
end

# Stack a collection of equal-length vectors as the columns of a matrix.
# `map(identity, ...)` narrows an untyped (`Any`) container, which lets `reduce(hcat, ...)`
# take its single-allocation fast path instead of splatting every vector into `hcat`.
_stack_columns(vectors) = reduce(hcat, map(identity, vectors))

# Jump-size histograms: NUTS draws vs. independent standard-normal draws of the same size.
function _vis_jump_panels(info)
    samples = _stack_columns(info.samples)
    rsamples = randn(size(samples))
    binsd = LinRange(0, 5 * sqrt(size(samples, 1)), 40)
    return plot(
        vis_jump!(vis_jump(samples, bins=binsd, label="NUTS"), rsamples, bins=binsd, label="IND"),
        vis_jump!(vis_jump(samples .^ 2, bins=binsd, label="NUTS"), rsamples .^ 2, bins=binsd, label="IND"),
        layout=(:, 1)
    )
end

# First two coordinates of every recorded evaluation point. Points are drawn as segments
# of length 1, 2, 4, ... in alternating colors, plus markers for the origin, the first
# recorded point and the final returned sample.
function _vis_trajectory(info)
    all_samples = _stack_columns(info.all_samples)
    no_evaluations = size(all_samples, 2)
    p = scatter([0], [0], label="mode", color=:black)
    # Segments start at column 3 and labels are shifted by 2 so they count from 1.
    # NOTE(review): presumably the first two recorded points come from sampler
    # initialisation rather than the trajectory itself; confirm.
    idx, start, width = 1, 3, 1
    while start <= no_evaluations
        send = min(start + width - 1, no_evaluations)
        # Pass `p` explicitly instead of relying on Plots' global "current plot".
        plot!(
            p,
            all_samples[1, start:send],
            all_samples[2, start:send],
            marker=:circle, label="$(start-2):$(send-2)", alpha=.5,
            color=[:black, :green][idx % 2 + 1]
        )
        start += width
        width *= 2
        idx += 1
    end
    scatter!(p, all_samples[1:1, 1], all_samples[2:2, 1], color=:red, label="init")
    return scatter!(
        p, info.samples[end][1:1], info.samples[end][2:2],
        color=:blue, label="sample"
    )
end
"""
    vis_experiments(; time_step_sizes, termination_threshold=0, no_samples=1, no_dimensions=2)

For each entry of `time_step_sizes`, call `run_experiment` with that step size and the
shared settings, then render the result with `vis_experiment`.

All panels are stacked in a single column: 800 px wide, 400 px of height per panel, with
the legend placed outside at the top left.
"""
function vis_experiments(;
    time_step_sizes,
    termination_threshold=0,
    no_samples=1,
    no_dimensions=2
)
    panels = map(time_step_sizes) do time_step_size
        result = run_experiment(
            time_step_size=time_step_size,
            termination_threshold=termination_threshold,
            no_samples=no_samples,
            no_dimensions=no_dimensions
        )
        vis_experiment(result)
    end
    figure_height = length(time_step_sizes) * 400
    return plot(panels...;
        layout=(:, 1), size=(800, figure_height),
        legend=:outertopleft
    )
end
# Script-level experiment settings: leapfrog step sizes to sweep, target
# dimensionality, and number of draws. Presumably passed to `vis_experiments`
# further down the file; confirm.
time_step_sizes = [0.15, 0.2, 0.25]  # alternative sweep: 2. .^ LinRange(-3, -2, 2)
no_dimensions = 10
no_samples = 10_000