StanBlocks.jl

On this page

  • StanBlocks.jl implementation
    • Variadic functions
    • Function-type arguments and method selection via their type
    • Untyped or abstractly typed arguments
  • Full Julia + StanBlocks.jl code to define the models
  • Generated Stan code
  • How does StanBlocks.jl infer type and dimension?

Reimplementing the Stan models from https://github.com/bob-carpenter/pcr-sensitivity-vs-time

Author

Nikolas Siccha

The full Julia code for this notebook can be accessed via the top right corner (</> Code).

The Julia packages needed to reproduce this document are StanBlocks.jl (for the model generation) and QuartoComponents.jl (for the “pretty” printing). Both packages have to be installed from the latest main branch (as of Oct 14th 2025).

I do agree that only being able to infer the type/transformation of a parameter from its sampling distribution can be unnecessarily magical - I have hence opened this feature “reminder” (I obviously cannot “request” things from myself).

StanBlocks.jl implementation

The function and model definitions below make use of

  • variadic functions (and argument splatting) - “It is often convenient to be able to write functions taking an arbitrary number of arguments.”,
  • function-type arguments and dispatch/method selection via their type - in Julia, “[e]ach function has its own type, which is a subtype of Function.”,
  • “untyped” arguments,
  • “automatic” type and dimension inference (see below) - partly because currently no other way of specifying type and dimension of a parameter is implemented. As stated above, I do agree that this is slightly too magical, and an optional(?) inline type annotation could be clearer.

A few words on each of these points:

Variadic functions

The two lines below are variadic method definitions for the my_bernoulli_lpmfs function,

my_bernoulli_lpmfs(y::int[n], args...) = jbroadcasted(my_bernoulli_lpmfs, y, args...)
my_bernoulli_lpmfs(y::int, args...) = my_bernoulli_lpmf(y, args...)

used in the computation of the pointwise log likelihoods. On the left hand side, args... matches all trailing positional arguments after the first one, and on the right hand side these arguments are splatted into the call: in the first definition they are forwarded to the built-in jbroadcasted function, which mimics Julia-style broadcasting of its first (function-type) argument over all remaining arguments, and in the second definition they are forwarded to the scalar my_bernoulli_lpmf.

In the below models, my_bernoulli_lpmfs will be called with the following signatures:

my_bernoulli_lpmfs(y::int[n], f::typeof(logit), theta::vector[n])
my_bernoulli_lpmfs(y::int[n], f::typeof(log), theta::vector[n])
my_bernoulli_lpmfs(y::int, f::typeof(logit), theta::real)
my_bernoulli_lpmfs(y::int, f::typeof(log), theta::real)

all of which will be covered by the variadic function definitions at the beginning of this section.
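For readers less familiar with Julia, the same pattern can be sketched in plain Julia. The scalar kernel `my_lpmf` below is a toy stand-in (not the actual StanBlocks implementation), and the dot syntax plays the role of the built-in `jbroadcasted`:

```julia
# Toy scalar kernel standing in for `my_bernoulli_lpmf`.
my_lpmf(y::Integer, f::Function, theta::Real) = f(theta) + y

# Vector method: `args...` captures all trailing arguments, and the dot
# syntax broadcasts over them (functions count as scalars in broadcasting).
my_lpmfs(y::AbstractVector, args...) = my_lpmfs.(y, args...)
# Scalar method: forward the trailing arguments to the kernel.
my_lpmfs(y::Integer, args...) = my_lpmf(y, args...)

my_lpmfs([0, 1], log, [1.0, 2.0])  # == [log(1.0) + 0, log(2.0) + 1]
```

All four call signatures listed above hit these same two methods: the vector method peels off the broadcasting, and the scalar method forwards to the kernel.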

Function-type arguments and method selection via their type

The following are the simplest possible method definitions that depend on the type of a function-type argument:

upper_alpha(::typeof(logit)) = negative_infinity()
upper_alpha(::typeof(log)) = 0

The defined function, upper_alpha, can be called in one of two ways:

  • Either as upper_alpha(logit), matching the first method definition and thus returning negative infinity, or
  • as upper_alpha(log), matching the second method definition and thus returning zero.

The above function gets used in the regression and regression_mix models to make the upper bound of the alpha parameter depend on the link function link_f, which can be either logit or log.

A slightly more complex example would be the following:

my_bernoulli_lpmf(y, ::typeof(logit), theta) = bernoulli_logit_lpmf(y, theta)
my_bernoulli_lpmf(y, ::typeof(log), theta) = bernoulli_lpmf(y, exp(theta))

which gets used to make the likelihood implementation depend on the link function of the model, allowing us

  • to forward y and theta to bernoulli_logit_lpmf(y, theta), or
  • to forward y and theta to bernoulli_lpmf(y, exp(theta)).
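A self-contained plain-Julia sketch of this dispatch-on-link-function pattern follows. Note the assumptions: `logit` is defined locally (base Julia has no `logit`), and the method bodies are the standard Bernoulli log-mass written out by hand rather than calls into Stan's math library:

```julia
logit(p) = log(p / (1 - p))  # base Julia has no `logit`; define it locally

# logit link: theta is a log-odds.
my_bernoulli(y, ::typeof(logit), theta) = y * theta - log1p(exp(theta))
# log link: theta is a log-probability.
my_bernoulli(y, ::typeof(log), theta) = y * theta + (1 - y) * log1p(-exp(theta))

my_bernoulli(1, logit, 0.0)     # == -log(2): P(y = 1) = 1/2 at log-odds 0
my_bernoulli(1, log, log(0.5))  # == log(0.5)
```

Passing `logit` or `log` as an ordinary argument selects the matching method at dispatch time, exactly as in the StanBlocks definitions above.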

Untyped or abstractly typed arguments

Untyped function arguments are simply arguments for which we don’t specify a type beforehand, allowing them to match any passed-in type. Do note that this can lead to “Method Ambiguities” - something that cannot happen in Stan, because there you always have to specify the concrete types of all function arguments. StanBlocks.jl implements a limited abstract type hierarchy, starting at the top with anything and descending e.g. towards ordered as anything -> any_vector -> vector -> ordered.
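A minimal plain-Julia example of such a method ambiguity (unrelated to any StanBlocks function):

```julia
f(x::Int, y) = "first"   # more specific in the first argument
f(x, y::Int) = "second"  # more specific in the second argument

f(1, 2.0)  # "first" — only the first method matches
f(1.0, 2)  # "second" — only the second method matches
# f(1, 2)  # would throw a MethodError: both methods match,
#          # and neither is more specific than the other
```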

Full Julia + StanBlocks.jl code to define the models

The following reproduces all of the code necessary to implement the 2x5 model matrix (printing excluded):

using StanBlocks

import StanBlocks.stan: logit

@deffun begin 
    "Needed for cross validation"
    my_bernoulli_lpmfs(y::int[n], args...) = jbroadcasted(my_bernoulli_lpmfs, y, args...)
    "Needed for cross validation"
    my_bernoulli_lpmfs(y::int, args...) = my_bernoulli_lpmf(y, args...)
    "Needed to compute the joint likelihood"
    my_bernoulli_lpmf(y, ::typeof(logit), theta) = bernoulli_logit_lpmf(y, theta)
    "Needed for posterior predictions"
    my_bernoulli_rng(::typeof(logit), theta) = bernoulli_logit_rng(theta)
    "Needed to compute the joint likelihood"
    my_bernoulli_lpmf(y, ::typeof(log), theta) = bernoulli_lpmf(y, exp(theta))
    "Needed for posterior predictions"
    my_bernoulli_rng(::typeof(log), theta) = bernoulli_rng(exp(theta))

    "The `Hetero` prior density after constraining - common to both link functions"
    hetero_lpdf(x::vector[n]) = normal_lpdf(x, 0, 3)
    "The `Hetero` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the logit link"
    hetero_lpdf(x::ordered[n], ::typeof(logit), n) = hetero_lpdf(x)
    "The `Hetero` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the log link"
    hetero_lpdf(x::positive_ordered[n], ::typeof(log), n) = hetero_lpdf(x)
    "The `RW(1)` prior density after constraining - common to both link functions"
    rw1_lpdf(x::vector[n], sigma) = normal_lpdf(x[1], 0, sigma) + normal_lpdf(x[2:n], x[1:n-1], sigma)
    "The `RW(1)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the logit link"
    rw1_lpdf(x::ordered[n], ::typeof(logit), sigma, n) = rw1_lpdf(x, sigma)
    "The `RW(1)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the log link"
    rw1_lpdf(x::positive_ordered[n], ::typeof(log), sigma, n) = rw1_lpdf(x, sigma)
    "The `RW(2)` prior density after constraining - common to both link functions"
    rw2_lpdf(x::vector[n], sigma) = rw1_lpdf(x[1:2], sigma) + normal_lpdf(x[3:n], 2x[2:n-1] - x[1:n-2], sigma)
    "The `RW(2)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the logit link"
    rw2_lpdf(x::ordered[n], ::typeof(logit), sigma, n) = rw2_lpdf(x, sigma)
    "The `RW(2)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the log link"
    rw2_lpdf(x::positive_ordered[n], ::typeof(log), sigma, n) = rw2_lpdf(x, sigma)
    "The upper bound of the alpha parameter (for the logit link function)"
    upper_alpha(::typeof(logit)) = negative_infinity()
    "The upper bound of the alpha parameter (for the log link function)"
    upper_alpha(::typeof(log)) = 0
end
# The main things that StanBlocks.jl has to know are that `y` is an array of `int`s and that `t` is a vector.
mock_data = (;y=[1], t=[1.])

# The base model with the mock data attached - this does not yet have an implementation for theta 
base_model = @slic mock_data begin 
    n = dims(y)[1]
    "The exact implementation of the likelihood depends on the passed link function `link_f`"
    y ~ my_bernoulli(link_f, theta)
end
# The submodel for the `Hetero` models
centered_hetero = @slic begin 
    "The type of the `xi` parameter depends on the passed link function `link_f`"
    xi ~ hetero(link_f, n)
    "The negation is needed to ensure that `theta` is in descending order"
    return -xi
end
# The submodel for the `RW(1)` models
centered_rw1 = @slic begin 
    sigma ~ std_normal(;lower=0)
    "The type of the `xi` parameter depends on the passed link function `link_f`"
    xi ~ rw1(link_f, sigma, n)
    "The negation is needed to ensure that `theta` is in descending order"
    return -xi
end
# The submodel for the `RW(2)` models
centered_rw2 = @slic begin 
    sigma ~ normal(0, .5; lower=0)
    "The type of the `xi` parameter depends on the passed link function `link_f`"
    xi ~ rw2(link_f, sigma, n)
    "The negation is needed to ensure that `theta` is in descending order"
    return -xi
end
# The submodel for the `regression` models - reused in the `regression_mix` models
regression = @slic begin 
    "The upper bound of alpha `alpha_upper` depends on the link function `link_f`"
    alpha_upper = upper_alpha(link_f)
    # This could be `alpha ~ normal(0, .5; upper=upper_alpha(link_f))`, once https://github.com/nsiccha/StanBlocks.jl/issues/35 is fixed
    alpha ~ normal(0, .5; upper=alpha_upper)
    beta ~ normal(0, .5; upper=0.)
    return alpha + beta * t
end
# The submodel for the `regression_mix` models - reusing the `regression` submodel
regression_mix = @slic begin 
    c1 ~ regression(;link_f, t) 
    c2 ~ regression(;link_f, t)
    lambda ~ beta(2, 2)
    return lambda * link_f(c1) + (1-lambda) * link_f(c2)
end

bases = (;
    hetero=base_model(quote 
        theta ~ centered_hetero(;link_f, n)
    end),
    rw1=base_model(quote 
        theta ~ centered_rw1(;link_f, n)
    end),
    rw2=base_model(quote 
        theta ~ centered_rw2(;link_f, n)
    end),
    regression=base_model(quote 
        theta ~ regression(;link_f, t)
    end),
    regression_mix=base_model(quote 
        theta ~ regression_mix(;link_f, t)
        y ~ bernoulli(theta)
    end)
)
link_fs = (;logit, log)

# `posteriors` will be a (nested) named tuple - accessing e.g. the `Hetero (log)` model works via `posteriors.hetero.log`
posteriors = map(bases) do base 
    map(link_fs) do link_f 
        base(;link_f)
    end
end
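The nested map over two NamedTuples is what produces the 2x5 grid. The following plain-Julia sketch (with strings standing in for the model objects, and only two base keys for brevity) shows how map preserves NamedTuple keys at both levels:

```julia
# Strings stand in for the base models and link functions.
bases = (; hetero = "hetero", rw1 = "rw1")
link_fs = (; logit = "logit", log = "log")

# `map` over a NamedTuple returns a NamedTuple with the same keys,
# so nesting two maps yields a key-addressable grid.
posteriors = map(bases) do base
    map(link_fs) do link_f
        "$(base)-$(link_f)"
    end
end

posteriors.hetero.log  # == "hetero-log"
```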

Generated Stan code

The generated Stan code below is accessible via two nested tabsets. The top level (with keys hetero, rw1, rw2, regression, and regression_mix) combines with the link function in the second level (with keys logit and log) to give you access to the corresponding Stan models from the poster.

Warning

Due to this issue, the Stan code below will not actually compile as is.

I think removing the offending UDF definitions (the unused helpers with the invalid func return type, e.g. broadcasted_getindex_logit) should make compilation work.

  • hetero
  • rw1
  • rw2
  • regression
  • regression_mix
  • logit
  • log
functions {
// The `Hetero` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the logit link
real hetero_logit_lpdf(
    vector x,
    int n
) {
    return hetero_lpdf(x);
}
// The `Hetero` prior density after constraining - common to both link functions
real hetero_lpdf(
    vector x
) {
    int n = dims(x)[1];
    return normal_lpdf(x | 0, 3);
}
// Needed to compute the joint likelihood
real my_bernoulli_logit_lpmf(
    array[] int y,
    vector theta
) {
    return bernoulli_logit_lpmf(y | theta);
}
// Needed for cross validation
vector my_bernoulli_logit_lpmfs(
    array[] int y,
    vector args1
) {
    int n = dims(y)[1];
    return jbroadcasted_my_bernoulli_logit_lpmfs(y, args1);
}
vector jbroadcasted_my_bernoulli_logit_lpmfs(
    array[] int x1,
    vector x3
) {
    int n = dims(x1)[1];
    vector[n] rv;
    for(i in 1:n) {
        rv[i] = my_bernoulli_logit_lpmfs(broadcasted_getindex(x1, i), broadcasted_getindex(x3, i));
    }
    return rv;
}
// Needed for cross validation
real my_bernoulli_logit_lpmfs(
    int y,
    real args1
) {
    return my_bernoulli_logit_lpmf(y | args1);
}
// Needed to compute the joint likelihood
real my_bernoulli_logit_lpmf(
    int y,
    real theta
) {
    return bernoulli_logit_lpmf(y | theta);
}
int broadcasted_getindex(array[] int x, int i) {
    int m = dims(x)[1];
    return x[i];
}
func broadcasted_getindex_logit(int i) {
    return logit;
}
real broadcasted_getindex(vector x, int i) {
    int m = dims(x)[1];
    return x[i];
}
// Needed for posterior predictions
array[] int my_bernoulli_logit_rng(
    vector theta
) {
    return bernoulli_logit_rng(theta);
}
}
data {
    int y_n;
    array[y_n] int y;
}
transformed data {
    int n = dims(y)[1];
}
parameters {
    // The type of the `xi` parameter depends on the passed link function `link_f`
    ordered[n] theta_xi;
}
transformed parameters {
    // The negation is needed to ensure that `theta` is in descending order
    vector[n] theta = (-theta_xi);
}
model {
    // The type of the `xi` parameter depends on the passed link function `link_f`
    theta_xi ~ hetero_logit(n);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    y ~ my_bernoulli_logit(theta);
}
generated quantities {
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    vector[y_n] y_likelihood = my_bernoulli_logit_lpmfs(y, theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    array[n] int y_gen = my_bernoulli_logit_rng(theta);
}
functions {
// The `Hetero` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the log link
real hetero_log_lpdf(
    vector x,
    int n
) {
    return hetero_lpdf(x);
}
// The `Hetero` prior density after constraining - common to both link functions
real hetero_lpdf(
    vector x
) {
    int n = dims(x)[1];
    return normal_lpdf(x | 0, 3);
}
// Needed to compute the joint likelihood
real my_bernoulli_log_lpmf(
    array[] int y,
    vector theta
) {
    return bernoulli_lpmf(y | exp(theta));
}
// Needed for cross validation
vector my_bernoulli_log_lpmfs(
    array[] int y,
    vector args1
) {
    int n = dims(y)[1];
    return jbroadcasted_my_bernoulli_log_lpmfs(y, args1);
}
vector jbroadcasted_my_bernoulli_log_lpmfs(
    array[] int x1,
    vector x3
) {
    int n = dims(x1)[1];
    vector[n] rv;
    for(i in 1:n) {
        rv[i] = my_bernoulli_log_lpmfs(broadcasted_getindex(x1, i), broadcasted_getindex(x3, i));
    }
    return rv;
}
// Needed for cross validation
real my_bernoulli_log_lpmfs(
    int y,
    real args1
) {
    return my_bernoulli_log_lpmf(y | args1);
}
// Needed to compute the joint likelihood
real my_bernoulli_log_lpmf(
    int y,
    real theta
) {
    return bernoulli_lpmf(y | exp(theta));
}
int broadcasted_getindex(array[] int x, int i) {
    int m = dims(x)[1];
    return x[i];
}
func broadcasted_getindex_log(int i) {
    return log;
}
real broadcasted_getindex(vector x, int i) {
    int m = dims(x)[1];
    return x[i];
}
// Needed for posterior predictions
array[] int my_bernoulli_log_rng(
    vector theta
) {
    return bernoulli_rng(exp(theta));
}
}
data {
    int y_n;
    array[y_n] int y;
}
transformed data {
    int n = dims(y)[1];
}
parameters {
    // The type of the `xi` parameter depends on the passed link function `link_f`
    positive_ordered[n] theta_xi;
}
transformed parameters {
    // The negation is needed to ensure that `theta` is in descending order
    vector[n] theta = (-theta_xi);
}
model {
    // The type of the `xi` parameter depends on the passed link function `link_f`
    theta_xi ~ hetero_log(n);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    y ~ my_bernoulli_log(theta);
}
generated quantities {
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    vector[y_n] y_likelihood = my_bernoulli_log_lpmfs(y, theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    array[n] int y_gen = my_bernoulli_log_rng(theta);
}
  • logit
  • log
functions {
// The `RW(1)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the logit link
real rw1_logit_lpdf(
    vector x,
    real sigma,
    int n
) {
    return rw1_lpdf(x | sigma);
}
// The `RW(1)` prior density after constraining - common to both link functions
real rw1_lpdf(
    vector x,
    real sigma
) {
    int n = dims(x)[1];
    return (normal_lpdf(x[1] | 0, sigma) + normal_lpdf(x[2:n] | x[1:(n - 1)], sigma));
}
// Needed to compute the joint likelihood
real my_bernoulli_logit_lpmf(
    array[] int y,
    vector theta
) {
    return bernoulli_logit_lpmf(y | theta);
}
// Needed for cross validation
vector my_bernoulli_logit_lpmfs(
    array[] int y,
    vector args1
) {
    int n = dims(y)[1];
    return jbroadcasted_my_bernoulli_logit_lpmfs(y, args1);
}
vector jbroadcasted_my_bernoulli_logit_lpmfs(
    array[] int x1,
    vector x3
) {
    int n = dims(x1)[1];
    vector[n] rv;
    for(i in 1:n) {
        rv[i] = my_bernoulli_logit_lpmfs(broadcasted_getindex(x1, i), broadcasted_getindex(x3, i));
    }
    return rv;
}
// Needed for cross validation
real my_bernoulli_logit_lpmfs(
    int y,
    real args1
) {
    return my_bernoulli_logit_lpmf(y | args1);
}
// Needed to compute the joint likelihood
real my_bernoulli_logit_lpmf(
    int y,
    real theta
) {
    return bernoulli_logit_lpmf(y | theta);
}
int broadcasted_getindex(array[] int x, int i) {
    int m = dims(x)[1];
    return x[i];
}
func broadcasted_getindex_logit(int i) {
    return logit;
}
real broadcasted_getindex(vector x, int i) {
    int m = dims(x)[1];
    return x[i];
}
// Needed for posterior predictions
array[] int my_bernoulli_logit_rng(
    vector theta
) {
    return bernoulli_logit_rng(theta);
}
}
data {
    int y_n;
    array[y_n] int y;
}
transformed data {
    int n = dims(y)[1];
}
parameters {
    real<lower=0> theta_sigma;
    // The type of the `xi` parameter depends on the passed link function `link_f`
    ordered[n] theta_xi;
}
transformed parameters {
    // The negation is needed to ensure that `theta` is in descending order
    vector[n] theta = (-theta_xi);
}
model {
    theta_sigma ~ std_normal();
    // The type of the `xi` parameter depends on the passed link function `link_f`
    theta_xi ~ rw1_logit(theta_sigma, n);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    y ~ my_bernoulli_logit(theta);
}
generated quantities {
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    vector[y_n] y_likelihood = my_bernoulli_logit_lpmfs(y, theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    array[n] int y_gen = my_bernoulli_logit_rng(theta);
}
functions {
// The `RW(1)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the log link
real rw1_log_lpdf(
    vector x,
    real sigma,
    int n
) {
    return rw1_lpdf(x | sigma);
}
// The `RW(1)` prior density after constraining - common to both link functions
real rw1_lpdf(
    vector x,
    real sigma
) {
    int n = dims(x)[1];
    return (normal_lpdf(x[1] | 0, sigma) + normal_lpdf(x[2:n] | x[1:(n - 1)], sigma));
}
// Needed to compute the joint likelihood
real my_bernoulli_log_lpmf(
    array[] int y,
    vector theta
) {
    return bernoulli_lpmf(y | exp(theta));
}
// Needed for cross validation
vector my_bernoulli_log_lpmfs(
    array[] int y,
    vector args1
) {
    int n = dims(y)[1];
    return jbroadcasted_my_bernoulli_log_lpmfs(y, args1);
}
vector jbroadcasted_my_bernoulli_log_lpmfs(
    array[] int x1,
    vector x3
) {
    int n = dims(x1)[1];
    vector[n] rv;
    for(i in 1:n) {
        rv[i] = my_bernoulli_log_lpmfs(broadcasted_getindex(x1, i), broadcasted_getindex(x3, i));
    }
    return rv;
}
// Needed for cross validation
real my_bernoulli_log_lpmfs(
    int y,
    real args1
) {
    return my_bernoulli_log_lpmf(y | args1);
}
// Needed to compute the joint likelihood
real my_bernoulli_log_lpmf(
    int y,
    real theta
) {
    return bernoulli_lpmf(y | exp(theta));
}
int broadcasted_getindex(array[] int x, int i) {
    int m = dims(x)[1];
    return x[i];
}
func broadcasted_getindex_log(int i) {
    return log;
}
real broadcasted_getindex(vector x, int i) {
    int m = dims(x)[1];
    return x[i];
}
// Needed for posterior predictions
array[] int my_bernoulli_log_rng(
    vector theta
) {
    return bernoulli_rng(exp(theta));
}
}
data {
    int y_n;
    array[y_n] int y;
}
transformed data {
    int n = dims(y)[1];
}
parameters {
    real<lower=0> theta_sigma;
    // The type of the `xi` parameter depends on the passed link function `link_f`
    positive_ordered[n] theta_xi;
}
transformed parameters {
    // The negation is needed to ensure that `theta` is in descending order
    vector[n] theta = (-theta_xi);
}
model {
    theta_sigma ~ std_normal();
    // The type of the `xi` parameter depends on the passed link function `link_f`
    theta_xi ~ rw1_log(theta_sigma, n);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    y ~ my_bernoulli_log(theta);
}
generated quantities {
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    vector[y_n] y_likelihood = my_bernoulli_log_lpmfs(y, theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    array[n] int y_gen = my_bernoulli_log_rng(theta);
}
  • logit
  • log
functions {
// The `RW(2)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the logit link
real rw2_logit_lpdf(
    vector x,
    real sigma,
    int n
) {
    return rw2_lpdf(x | sigma);
}
// The `RW(2)` prior density after constraining - common to both link functions
real rw2_lpdf(
    vector x,
    real sigma
) {
    int n = dims(x)[1];
    return (rw1_lpdf(x[1:2] | sigma) + normal_lpdf(x[3:n] | ((2 * x[2:(n - 1)]) - x[1:(n - 2)]), sigma));
}
// The `RW(1)` prior density after constraining - common to both link functions
real rw1_lpdf(
    vector x,
    real sigma
) {
    int n = dims(x)[1];
    return (normal_lpdf(x[1] | 0, sigma) + normal_lpdf(x[2:n] | x[1:(n - 1)], sigma));
}
// Needed to compute the joint likelihood
real my_bernoulli_logit_lpmf(
    array[] int y,
    vector theta
) {
    return bernoulli_logit_lpmf(y | theta);
}
// Needed for cross validation
vector my_bernoulli_logit_lpmfs(
    array[] int y,
    vector args1
) {
    int n = dims(y)[1];
    return jbroadcasted_my_bernoulli_logit_lpmfs(y, args1);
}
vector jbroadcasted_my_bernoulli_logit_lpmfs(
    array[] int x1,
    vector x3
) {
    int n = dims(x1)[1];
    vector[n] rv;
    for(i in 1:n) {
        rv[i] = my_bernoulli_logit_lpmfs(broadcasted_getindex(x1, i), broadcasted_getindex(x3, i));
    }
    return rv;
}
// Needed for cross validation
real my_bernoulli_logit_lpmfs(
    int y,
    real args1
) {
    return my_bernoulli_logit_lpmf(y | args1);
}
// Needed to compute the joint likelihood
real my_bernoulli_logit_lpmf(
    int y,
    real theta
) {
    return bernoulli_logit_lpmf(y | theta);
}
int broadcasted_getindex(array[] int x, int i) {
    int m = dims(x)[1];
    return x[i];
}
func broadcasted_getindex_logit(int i) {
    return logit;
}
real broadcasted_getindex(vector x, int i) {
    int m = dims(x)[1];
    return x[i];
}
// Needed for posterior predictions
array[] int my_bernoulli_logit_rng(
    vector theta
) {
    return bernoulli_logit_rng(theta);
}
}
data {
    int y_n;
    array[y_n] int y;
}
transformed data {
    int n = dims(y)[1];
}
parameters {
    real<lower=0> theta_sigma;
    // The type of the `xi` parameter depends on the passed link function `link_f`
    ordered[n] theta_xi;
}
transformed parameters {
    // The negation is needed to ensure that `theta` is in descending order
    vector[n] theta = (-theta_xi);
}
model {
    theta_sigma ~ normal(0, 0.5);
    // The type of the `xi` parameter depends on the passed link function `link_f`
    theta_xi ~ rw2_logit(theta_sigma, n);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    y ~ my_bernoulli_logit(theta);
}
generated quantities {
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    vector[y_n] y_likelihood = my_bernoulli_logit_lpmfs(y, theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    array[n] int y_gen = my_bernoulli_logit_rng(theta);
}
functions {
// The `RW(2)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the log link
real rw2_log_lpdf(
    vector x,
    real sigma,
    int n
) {
    return rw2_lpdf(x | sigma);
}
// The `RW(2)` prior density after constraining - common to both link functions
real rw2_lpdf(
    vector x,
    real sigma
) {
    int n = dims(x)[1];
    return (rw1_lpdf(x[1:2] | sigma) + normal_lpdf(x[3:n] | ((2 * x[2:(n - 1)]) - x[1:(n - 2)]), sigma));
}
// The `RW(1)` prior density after constraining - common to both link functions
real rw1_lpdf(
    vector x,
    real sigma
) {
    int n = dims(x)[1];
    return (normal_lpdf(x[1] | 0, sigma) + normal_lpdf(x[2:n] | x[1:(n - 1)], sigma));
}
// Needed to compute the joint likelihood
real my_bernoulli_log_lpmf(
    array[] int y,
    vector theta
) {
    return bernoulli_lpmf(y | exp(theta));
}
// Needed for cross validation
vector my_bernoulli_log_lpmfs(
    array[] int y,
    vector args1
) {
    int n = dims(y)[1];
    return jbroadcasted_my_bernoulli_log_lpmfs(y, args1);
}
vector jbroadcasted_my_bernoulli_log_lpmfs(
    array[] int x1,
    vector x3
) {
    int n = dims(x1)[1];
    vector[n] rv;
    for(i in 1:n) {
        rv[i] = my_bernoulli_log_lpmfs(broadcasted_getindex(x1, i), broadcasted_getindex(x3, i));
    }
    return rv;
}
// Needed for cross validation
real my_bernoulli_log_lpmfs(
    int y,
    real args1
) {
    return my_bernoulli_log_lpmf(y | args1);
}
// Needed to compute the joint likelihood
real my_bernoulli_log_lpmf(
    int y,
    real theta
) {
    return bernoulli_lpmf(y | exp(theta));
}
int broadcasted_getindex(array[] int x, int i) {
    int m = dims(x)[1];
    return x[i];
}
func broadcasted_getindex_log(int i) {
    return log;
}
real broadcasted_getindex(vector x, int i) {
    int m = dims(x)[1];
    return x[i];
}
// Needed for posterior predictions
array[] int my_bernoulli_log_rng(
    vector theta
) {
    return bernoulli_rng(exp(theta));
}
}
data {
    int y_n;
    array[y_n] int y;
}
transformed data {
    int n = dims(y)[1];
}
parameters {
    real<lower=0> theta_sigma;
    // The type of the `xi` parameter depends on the passed link function `link_f`
    positive_ordered[n] theta_xi;
}
transformed parameters {
    // The negation is needed to ensure that `theta` is in descending order
    vector[n] theta = (-theta_xi);
}
model {
    theta_sigma ~ normal(0, 0.5);
    // The type of the `xi` parameter depends on the passed link function `link_f`
    theta_xi ~ rw2_log(theta_sigma, n);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    y ~ my_bernoulli_log(theta);
}
generated quantities {
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    vector[y_n] y_likelihood = my_bernoulli_log_lpmfs(y, theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    array[n] int y_gen = my_bernoulli_log_rng(theta);
}
  • logit
  • log
functions {
// The upper bound of the alpha parameter (for the logit link function)
real upper_alpha_logit(
    
) {
    return negative_infinity();
}
// Needed to compute the joint likelihood
real my_bernoulli_logit_lpmf(
    array[] int y,
    vector theta
) {
    return bernoulli_logit_lpmf(y | theta);
}
// Needed for cross validation
vector my_bernoulli_logit_lpmfs(
    array[] int y,
    vector args1
) {
    int n = dims(y)[1];
    return jbroadcasted_my_bernoulli_logit_lpmfs(y, args1);
}
vector jbroadcasted_my_bernoulli_logit_lpmfs(
    array[] int x1,
    vector x3
) {
    int n = dims(x1)[1];
    vector[n] rv;
    for(i in 1:n) {
        rv[i] = my_bernoulli_logit_lpmfs(broadcasted_getindex(x1, i), broadcasted_getindex(x3, i));
    }
    return rv;
}
// Needed for cross validation
real my_bernoulli_logit_lpmfs(
    int y,
    real args1
) {
    return my_bernoulli_logit_lpmf(y | args1);
}
// Needed to compute the joint likelihood
real my_bernoulli_logit_lpmf(
    int y,
    real theta
) {
    return bernoulli_logit_lpmf(y | theta);
}
int broadcasted_getindex(array[] int x, int i) {
    int m = dims(x)[1];
    return x[i];
}
func broadcasted_getindex_logit(int i) {
    return logit;
}
real broadcasted_getindex(vector x, int i) {
    int m = dims(x)[1];
    return x[i];
}
// Needed for posterior predictions
array[] int my_bernoulli_logit_rng(
    vector theta
) {
    return bernoulli_logit_rng(theta);
}
}
data {
    int y_n;
    array[y_n] int y;
    int t_n;
    vector[t_n] t;
}
transformed data {
    int n = dims(y)[1];
    // The upper bound of alpha `alpha_upper` depends on the link function `link_f`
    real theta_alpha_upper = upper_alpha_logit();
}
parameters {
    real<upper=theta_alpha_upper> theta_alpha;
    real<upper=0.0> theta_beta;
}
transformed parameters {
    vector[t_n] theta = (theta_alpha + (theta_beta * t));
}
model {
    theta_alpha ~ normal(0, 0.5);
    theta_beta ~ normal(0, 0.5);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    y ~ my_bernoulli_logit(theta);
}
generated quantities {
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    vector[y_n] y_likelihood = my_bernoulli_logit_lpmfs(y, theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    array[t_n] int y_gen = my_bernoulli_logit_rng(theta);
}
functions {
// The upper bound of the alpha parameter (for the log link function)
int upper_alpha_log(
    
) {
    return 0;
}
// Needed to compute the joint likelihood
real my_bernoulli_log_lpmf(
    array[] int y,
    vector theta
) {
    return bernoulli_lpmf(y | exp(theta));
}
// Needed for cross validation
vector my_bernoulli_log_lpmfs(
    array[] int y,
    vector args1
) {
    int n = dims(y)[1];
    return jbroadcasted_my_bernoulli_log_lpmfs(y, args1);
}
vector jbroadcasted_my_bernoulli_log_lpmfs(
    array[] int x1,
    vector x3
) {
    int n = dims(x1)[1];
    vector[n] rv;
    for(i in 1:n) {
        rv[i] = my_bernoulli_log_lpmfs(broadcasted_getindex(x1, i), broadcasted_getindex(x3, i));
    }
    return rv;
}
// Needed for cross validation
real my_bernoulli_log_lpmfs(
    int y,
    real args1
) {
    return my_bernoulli_log_lpmf(y | args1);
}
// Needed to compute the joint likelihood
real my_bernoulli_log_lpmf(
    int y,
    real theta
) {
    return bernoulli_lpmf(y | exp(theta));
}
int broadcasted_getindex(array[] int x, int i) {
    int m = dims(x)[1];
    return x[i];
}
func broadcasted_getindex_log(int i) {
    return log;
}
real broadcasted_getindex(vector x, int i) {
    int m = dims(x)[1];
    return x[i];
}
// Needed for posterior predictions
array[] int my_bernoulli_log_rng(
    vector theta
) {
    return bernoulli_rng(exp(theta));
}
}
data {
    int y_n;
    array[y_n] int y;
    int t_n;
    vector[t_n] t;
}
transformed data {
    int n = dims(y)[1];
    // The upper bound of alpha `alpha_upper` depends on the link function `link_f`
    int theta_alpha_upper = upper_alpha_log();
}
parameters {
    real<upper=theta_alpha_upper> theta_alpha;
    real<upper=0.0> theta_beta;
}
transformed parameters {
    vector[t_n] theta = (theta_alpha + (theta_beta * t));
}
model {
    theta_alpha ~ normal(0, 0.5);
    theta_beta ~ normal(0, 0.5);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    y ~ my_bernoulli_log(theta);
}
generated quantities {
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    vector[y_n] y_likelihood = my_bernoulli_log_lpmfs(y, theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    array[t_n] int y_gen = my_bernoulli_log_rng(theta);
}
  • logit
  • log
functions {
// The upper bound of the alpha parameter (for the logit link function)
real upper_alpha_logit(
    
) {
    return negative_infinity();
}
vector bernoulli_lpmfs(
    array[] int obs,
    vector args1
) {
    int n = dims(obs)[1];
    return jbroadcasted_bernoulli_lpmfs(obs, args1);
}
vector jbroadcasted_bernoulli_lpmfs(
    array[] int x1,
    vector x2
) {
    int n = dims(x1)[1];
    vector[n] rv;
    for(i in 1:n) {
        rv[i] = bernoulli_lpmfs(broadcasted_getindex(x1, i), broadcasted_getindex(x2, i));
    }
    return rv;
}
real bernoulli_lpmfs(int args1, real args2) {
    return bernoulli_lpmf(args1 | args2);
}
int broadcasted_getindex(array[] int x, int i) {
    int m = dims(x)[1];
    return x[i];
}
real broadcasted_getindex(vector x, int i) {
    int m = dims(x)[1];
    return x[i];
}
// Needed to compute the joint likelihood
real my_bernoulli_logit_lpmf(
    array[] int y,
    vector theta
) {
    return bernoulli_logit_lpmf(y | theta);
}
// Needed for cross validation
vector my_bernoulli_logit_lpmfs(
    array[] int y,
    vector args1
) {
    int n = dims(y)[1];
    return jbroadcasted_my_bernoulli_logit_lpmfs(y, args1);
}
vector jbroadcasted_my_bernoulli_logit_lpmfs(
    array[] int x1,
    vector x3
) {
    int n = dims(x1)[1];
    vector[n] rv;
    for(i in 1:n) {
        rv[i] = my_bernoulli_logit_lpmfs(broadcasted_getindex(x1, i), broadcasted_getindex(x3, i));
    }
    return rv;
}
// Needed for cross validation
real my_bernoulli_logit_lpmfs(
    int y,
    real args1
) {
    return my_bernoulli_logit_lpmf(y | args1);
}
// Needed to compute the joint likelihood
real my_bernoulli_logit_lpmf(
    int y,
    real theta
) {
    return bernoulli_logit_lpmf(y | theta);
}
func broadcasted_getindex_logit(int i) {
    return logit;
}
// Needed for posterior predictions
array[] int my_bernoulli_logit_rng(
    vector theta
) {
    return bernoulli_logit_rng(theta);
}
}
data {
    int t_n;
    vector[t_n] t;
    int y_n;
    array[y_n] int y;
}
transformed data {
    // The upper bound of alpha `alpha_upper` depends on the link function `link_f`
    real theta_c1_alpha_upper = upper_alpha_logit();
    // The upper bound of alpha `alpha_upper` depends on the link function `link_f`
    real theta_c2_alpha_upper = upper_alpha_logit();
    int n = dims(y)[1];
}
parameters {
    real<upper=theta_c1_alpha_upper> theta_c1_alpha;
    real<upper=0.0> theta_c1_beta;
    real<upper=theta_c2_alpha_upper> theta_c2_alpha;
    real<upper=0.0> theta_c2_beta;
    real<lower=0, upper=1> theta_lambda;
}
transformed parameters {
    vector[t_n] theta_c1 = (theta_c1_alpha + (theta_c1_beta * t));
    vector[t_n] theta_c2 = (theta_c2_alpha + (theta_c2_beta * t));
    vector[t_n] theta = ((theta_lambda * logit(theta_c1)) + ((1 - theta_lambda) * logit(theta_c2)));
}
model {
    theta_c1_alpha ~ normal(0, 0.5);
    theta_c1_beta ~ normal(0, 0.5);
    theta_c2_alpha ~ normal(0, 0.5);
    theta_c2_beta ~ normal(0, 0.5);
    theta_lambda ~ beta(2, 2);
    y ~ bernoulli(theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    y ~ my_bernoulli_logit(theta);
}
generated quantities {
    vector[y_n] y_likelihood = bernoulli_lpmfs(y, theta);
    array[t_n] int y_gen = bernoulli_rng(theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    vector[y_n] y_likelihood = my_bernoulli_logit_lpmfs(y, theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    array[t_n] int y_gen = my_bernoulli_logit_rng(theta);
}
functions {
// The upper bound of the alpha parameter (for the log link function)
int upper_alpha_log(
    
) {
    return 0;
}
vector bernoulli_lpmfs(
    array[] int obs,
    vector args1
) {
    int n = dims(obs)[1];
    return jbroadcasted_bernoulli_lpmfs(obs, args1);
}
vector jbroadcasted_bernoulli_lpmfs(
    array[] int x1,
    vector x2
) {
    int n = dims(x1)[1];
    vector[n] rv;
    for(i in 1:n) {
        rv[i] = bernoulli_lpmfs(broadcasted_getindex(x1, i), broadcasted_getindex(x2, i));
    }
    return rv;
}
real bernoulli_lpmfs(int args1, real args2) {
    return bernoulli_lpmf(args1 | args2);
}
int broadcasted_getindex(array[] int x, int i) {
    int m = dims(x)[1];
    return x[i];
}
real broadcasted_getindex(vector x, int i) {
    int m = dims(x)[1];
    return x[i];
}
// Needed to compute the joint likelihood
real my_bernoulli_log_lpmf(
    array[] int y,
    vector theta
) {
    return bernoulli_lpmf(y | exp(theta));
}
// Needed for cross validation
vector my_bernoulli_log_lpmfs(
    array[] int y,
    vector args1
) {
    int n = dims(y)[1];
    return jbroadcasted_my_bernoulli_log_lpmfs(y, args1);
}
vector jbroadcasted_my_bernoulli_log_lpmfs(
    array[] int x1,
    vector x3
) {
    int n = dims(x1)[1];
    vector[n] rv;
    for(i in 1:n) {
        rv[i] = my_bernoulli_log_lpmfs(broadcasted_getindex(x1, i), broadcasted_getindex(x3, i));
    }
    return rv;
}
// Needed for cross validation
real my_bernoulli_log_lpmfs(
    int y,
    real args1
) {
    return my_bernoulli_log_lpmf(y | args1);
}
// Needed to compute the joint likelihood
real my_bernoulli_log_lpmf(
    int y,
    real theta
) {
    return bernoulli_lpmf(y | exp(theta));
}
func broadcasted_getindex_log(int i) {
    return log;
}
// Needed for posterior predictions
array[] int my_bernoulli_log_rng(
    vector theta
) {
    return bernoulli_rng(exp(theta));
}
}
data {
    int t_n;
    vector[t_n] t;
    int y_n;
    array[y_n] int y;
}
transformed data {
    // The upper bound of alpha `alpha_upper` depends on the link function `link_f`
    int theta_c1_alpha_upper = upper_alpha_log();
    // The upper bound of alpha `alpha_upper` depends on the link function `link_f`
    int theta_c2_alpha_upper = upper_alpha_log();
    int n = dims(y)[1];
}
parameters {
    real<upper=theta_c1_alpha_upper> theta_c1_alpha;
    real<upper=0.0> theta_c1_beta;
    real<upper=theta_c2_alpha_upper> theta_c2_alpha;
    real<upper=0.0> theta_c2_beta;
    real<lower=0, upper=1> theta_lambda;
}
transformed parameters {
    vector[t_n] theta_c1 = (theta_c1_alpha + (theta_c1_beta * t));
    vector[t_n] theta_c2 = (theta_c2_alpha + (theta_c2_beta * t));
    vector[t_n] theta = ((theta_lambda * log(theta_c1)) + ((1 - theta_lambda) * log(theta_c2)));
}
model {
    theta_c1_alpha ~ normal(0, 0.5);
    theta_c1_beta ~ normal(0, 0.5);
    theta_c2_alpha ~ normal(0, 0.5);
    theta_c2_beta ~ normal(0, 0.5);
    theta_lambda ~ beta(2, 2);
    y ~ bernoulli(theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    y ~ my_bernoulli_log(theta);
}
generated quantities {
    vector[y_n] y_likelihood = bernoulli_lpmfs(y, theta);
    array[t_n] int y_gen = bernoulli_rng(theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    vector[y_n] y_likelihood = my_bernoulli_log_lpmfs(y, theta);
    // The exact implementation of the likelihood depends on the passed link function `link_f`
    array[t_n] int y_gen = my_bernoulli_log_rng(theta);
}

How does StanBlocks.jl infer type and dimension?

The way StanBlocks.jl currently figures out the type and dimension of a model parameter defined by

theta ~ distribution(foo, bar)

is by looking for a function definition

distribution_lpdf(theta::theta_type[theta_dim1, ...], foo::foo_type[...], bar::bar_type[...]) = ...

matching the types of foo and bar. theta_type then has to be one of Stan’s built-in types, and theta_dim1, ... can be expressions depending on foo, foo_type[...], bar, and bar_type[...].

For example, the built-in definition

multi_normal_lpdf(obs::vector[n], loc::vector[n], cov)

tells StanBlocks.jl that any parameter theta initialized via

theta ~ multi_normal(loc, cov)

will be a vector of the same shape as loc.

Once I add this feature, an alternative syntax to specify types and dimensions could be

theta::vector[n] ~ distribution(foo, bar)

which would communicate intent more clearly, both to the programmer and the compiler.

Source Code
---
title: "Reimplementing the Stan models from https://github.com/bob-carpenter/pcr-sensitivity-vs-time"
---

The full Julia code for this notebook can be accessed via the top right corner (`</> Code`).

The Julia packages needed to reproduce this document are [`StanBlocks.jl`](https://github.com/nsiccha/StanBlocks.jl) (for the model generation) and [`QuartoComponents.jl`](https://github.com/nsiccha/QuartoComponents.jl) (for the "pretty" printing). 
Both packages have to be installed from the latest `main` branch (as of Oct 14th 2025).

I do agree that only having the option of inferring the type/transformation of a parameter from its sampling distribution can be unnecessarily magical - I have hence opened [this feature "reminder"](https://github.com/nsiccha/StanBlocks.jl/issues/31) (I can obviously not "request" things from myself).

## StanBlocks.jl implementation

The function and model definitions below make use of

* variadic functions (and argument splatting) - ["It is often convenient to be able to write functions taking an arbitrary number of arguments."](https://docs.julialang.org/en/v1/manual/functions/#Varargs-Functions),
* function-type arguments and dispatch/method selection via their type - [in Julia, "[e]ach function has its own type, which is a subtype of Function."](https://docs.julialang.org/en/v1/manual/types/#Types-of-functions),
* "untyped" arguments,
* "automatic" type and dimension inference ([see below](#how-does-stanblocks.jl-infer-type-and-dimension)) - partly because currently no other way of specifying type and dimension of a parameter is implemented. As stated above, I do agree that this is slightly too magical, and an optional(?) inline type annotation could be clearer.

A few words on each of these points:

### Variadic functions

The lines below are two examples of variadic method definitions for the `my_bernoulli_lpmfs` function,

```julia
my_bernoulli_lpmfs(y::int[n], args...) = jbroadcasted(my_bernoulli_lpmfs, y, args...)
my_bernoulli_lpmfs(y::int, args...) = my_bernoulli_lpmf(y, args...)
```
used in the computation of the pointwise log likelihoods. 
In the first method, `args...` on the left-hand side matches all trailing positional arguments after the first one, 
and on the right-hand side these arguments are forwarded to the built-in `jbroadcasted` function,
which mimics Julia-style broadcasting of its first (function-type) argument over all other arguments.
The second method handles the scalar case by forwarding directly to `my_bernoulli_lpmf`.

In the below models, `my_bernoulli_lpmfs` will be called with the following signatures:

```julia
my_bernoulli_lpmfs(y::int[n], f::typeof(logit), theta::vector[n])
my_bernoulli_lpmfs(y::int[n], f::typeof(log), theta::vector[n])
my_bernoulli_lpmfs(y::int, f::typeof(logit), theta::real)
my_bernoulli_lpmfs(y::int, f::typeof(log), theta::real)
```
all of which will be covered by the variadic function definitions at the beginning of this section.
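Outside of `@deffun`, the same two mechanisms can be demonstrated in plain Julia (the names `collect_args` and `forward` below are hypothetical, chosen only for this sketch):

```julia
# `rest...` in a signature collects all trailing arguments into a tuple;
# `args...` at a call site splats a tuple back out into individual arguments.
collect_args(first, rest...) = (first, rest)
forward(args...) = collect_args(args...)  # forwards all arguments unchanged

collect_args(1, 2, 3)  # (1, (2, 3))
forward(1, 2, 3)       # (1, (2, 3))
```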

### Function-type arguments and method selection via their type

The below are the simplest possible method definitions which depend on the type of a function-type argument:

```julia
upper_alpha(::typeof(logit)) = negative_infinity()
upper_alpha(::typeof(log)) = 0
```

The defined function, `upper_alpha`, can be called in one of two ways: 

* Either as `upper_alpha(logit)`, matching the first method definition and thus returning negative infinity, or
* as `upper_alpha(log)`, matching the second method definition and thus returning zero. 

The function above is used in the `regression` and `regression_mix` models to make the upper bound of the `alpha` parameter
depend on the link function `link_f`, which can be either `logit` or `log`.
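The same dispatch pattern works in plain Julia outside of `@deffun`; in the sketch below, `logit` is defined locally only for illustration (inside `@deffun`, StanBlocks.jl's own `logit` and `negative_infinity()` are used instead):

```julia
# Dispatch on a function's type: each function has its own singleton type.
logit(p) = log(p / (1 - p))  # stand-in definition for this example only

upper_alpha(::typeof(logit)) = -Inf  # logit link: alpha is unbounded above
upper_alpha(::typeof(log)) = 0       # log link: alpha must stay at or below zero

upper_alpha(logit)  # -Inf
upper_alpha(log)    # 0
```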

**A slightly more complex example** would be the following:
```julia
my_bernoulli_lpmf(y, ::typeof(logit), theta) = bernoulli_logit_lpmf(y, theta)
my_bernoulli_lpmf(y, ::typeof(log), theta) = bernoulli_lpmf(y, exp(theta))
```
which makes the likelihood implementation depend on the model's link function, allowing us 

* to forward `y` and `theta` to `bernoulli_logit_lpmf(y, theta)` for the logit link, or
* to forward `y` and `theta` to `bernoulli_lpmf(y, exp(theta))` for the log link.

### Untyped or abstractly typed arguments

Untyped function arguments are simply arguments for which we don't specify a type beforehand, allowing them to match any passed-in type.
Do note that this can lead to ["Method Ambiguities"](https://docs.julialang.org/en/v1/manual/methods/#man-ambiguities) - something that cannot happen in Stan because you always have to specify the **concrete** types of all function arguments.
[StanBlocks.jl implements a limited abstract type hierarchy](https://github.com/nsiccha/StanBlocks.jl/blob/ab181eba40b5d2b7bbcf30cc283470df95b7d4bd/src/slic_stan/functions.jl#L1-L21), starting at the top with `anything`, and e.g. descending towards `ordered` as `anything -> any_vector -> vector -> ordered`.
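A minimal plain-Julia demonstration of such an ambiguity (standard Julia, unrelated to StanBlocks.jl):

```julia
# Two methods, neither of which is more specific than the other for (Int, Int) arguments.
g(x::Int, y) = 1
g(x, y::Int) = 2

g(1, 2.0)  # 1 — only the first method applies
g(1.0, 2)  # 2 — only the second method applies
# g(1, 2) raises a MethodError: both methods match and Julia refuses to pick one.
```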

## Full Julia + StanBlocks.jl code to define the models

The following reproduces all of the code necessary to implement the 2x5 model matrix (printing excluded):

```julia
using StanBlocks

import StanBlocks.stan: logit

@deffun begin 
    "Needed for cross validation"
    my_bernoulli_lpmfs(y::int[n], args...) = jbroadcasted(my_bernoulli_lpmfs, y, args...)
    "Needed for cross validation"
    my_bernoulli_lpmfs(y::int, args...) = my_bernoulli_lpmf(y, args...)
    "Needed to compute the joint likelihood"
    my_bernoulli_lpmf(y, ::typeof(logit), theta) = bernoulli_logit_lpmf(y, theta)
    "Needed for posterior predictions"
    my_bernoulli_rng(::typeof(logit), theta) = bernoulli_logit_rng(theta)
    "Needed to compute the joint likelihood"
    my_bernoulli_lpmf(y, ::typeof(log), theta) = bernoulli_lpmf(y, exp(theta))
    "Needed for posterior predictions"
    my_bernoulli_rng(::typeof(log), theta) = bernoulli_rng(exp(theta))

    "The `Hetero` prior density after constraining - common to both link functions"
    hetero_lpdf(x::vector[n]) = normal_lpdf(x, 0, 3)
    "The `Hetero` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the logit link"
    hetero_lpdf(x::ordered[n], ::typeof(logit), n) = hetero_lpdf(x)
    "The `Hetero` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the log link"
    hetero_lpdf(x::positive_ordered[n], ::typeof(log), n) = hetero_lpdf(x)
    "The `RW(1)` prior density after constraining - common to both link functions"
    rw1_lpdf(x::vector[n], sigma) = normal_lpdf(x[1], 0, sigma) + normal_lpdf(x[2:n], x[1:n-1], sigma)
    "The `RW(1)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the logit link"
    rw1_lpdf(x::ordered[n], ::typeof(logit), sigma, n) = rw1_lpdf(x, sigma)
    "The `RW(1)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the log link"
    rw1_lpdf(x::positive_ordered[n], ::typeof(log), sigma, n) = rw1_lpdf(x, sigma)
    "The `RW(2)` prior density after constraining - common to both link functions"
    rw2_lpdf(x::vector[n], sigma) = rw1_lpdf(x[1:2], sigma) + normal_lpdf(x[3:n], 2x[2:n-1] - x[1:n-2], sigma)
    "The `RW(2)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the logit link"
    rw2_lpdf(x::ordered[n], ::typeof(logit), sigma, n) = rw2_lpdf(x, sigma)
    "The `RW(2)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the log link"
    rw2_lpdf(x::positive_ordered[n], ::typeof(log), sigma, n) = rw2_lpdf(x, sigma)
    "The upper bound of the alpha parameter (for the logit link function)"
    upper_alpha(::typeof(logit)) = negative_infinity()
    "The upper bound of the alpha parameter (for the log link function)"
    upper_alpha(::typeof(log)) = 0
end
# The main things that StanBlocks.jl has to know are that `y` is an array of `int`s and that `t` is a vector.
mock_data = (;y=[1], t=[1.])

# The base model with the mock data attached - this does not yet have an implementation for theta 
base_model = @slic mock_data begin 
    n = dims(y)[1]
    "The exact implementation of the likelihood depends on the passed link function `link_f`"
    y ~ my_bernoulli(link_f, theta)
end
# The submodel for the `Hetero` models
centered_hetero = @slic begin 
    "The type of the `xi` parameter depends on the passed link function `link_f`"
    xi ~ hetero(link_f, n)
    "The negation is needed to ensure that `theta` is in descending order"
    return -xi
end
# The submodel for the `RW(1)` models
centered_rw1 = @slic begin 
    sigma ~ std_normal(;lower=0)
    "The type of the `xi` parameter depends on the passed link function `link_f`"
    xi ~ rw1(link_f, sigma, n)
    "The negation is needed to ensure that `theta` is in descending order"
    return -xi
end
# The submodel for the `RW(2)` models
centered_rw2 = @slic begin 
    sigma ~ normal(0, .5; lower=0)
    "The type of the `xi` parameter depends on the passed link function `link_f`"
    xi ~ rw2(link_f, sigma, n)
    "The negation is needed to ensure that `theta` is in descending order"
    return -xi
end
# The submodel for the `regression` models - reused in the `regression_mix` models
regression = @slic begin 
    "The upper bound of alpha `alpha_upper` depends on the link function `link_f`"
    alpha_upper = upper_alpha(link_f)
    # This could be `alpha ~ normal(0, .5; upper=upper_alpha(link_f))`, once https://github.com/nsiccha/StanBlocks.jl/issues/35 is fixed
    alpha ~ normal(0, .5; upper=alpha_upper)
    beta ~ normal(0, .5; upper=0.)
    return alpha + beta * t
end
# The submodel for the `regression_mix` models - reusing the `regression` submodel
regression_mix = @slic begin 
    c1 ~ regression(;link_f, t) 
    c2 ~ regression(;link_f, t)
    lambda ~ beta(2, 2)
    return lambda * link_f(c1) + (1-lambda) * link_f(c2)
end

bases = (;
    hetero=base_model(quote 
        theta ~ centered_hetero(;link_f, n)
    end),
    rw1=base_model(quote 
        theta ~ centered_rw1(;link_f, n)
    end),
    rw2=base_model(quote 
        theta ~ centered_rw2(;link_f, n)
    end),
    regression=base_model(quote 
        theta ~ regression(;link_f, t)
    end),
    regression_mix=base_model(quote 
        theta ~ regression_mix(;link_f, t)
        y ~ bernoulli(theta)
    end)
)
link_fs = (;logit, log)

# `posteriors` will be a (nested) named tuple - accessing e.g. the `Hetero (log)` model works via `posteriors.hetero.log`
posteriors = map(bases) do base 
    map(link_fs) do link_f 
        base(;link_f)
    end
end
```

## Generated Stan code

The generated Stan code below is accessible via two nested tabsets. 
The top level (with keys `hetero`, `rw1`, `rw2`, `regression`, and `regression_mix`) 
combines with the link function in the second level (with keys `logit` and `log`)
to give you access to the corresponding Stan models from the poster.

::: {.callout-warning}

Due to [this issue](https://github.com/nsiccha/StanBlocks.jl/issues/38), the generated Stan programs below will not actually compile as they are.

I think removing the offending UDF definitions should make compilation work.

:::

```{julia}
using StanBlocks, QuartoComponents

import StanBlocks.stan: logit

@deffun begin 
    "Needed for cross validation"
    my_bernoulli_lpmfs(y::int[n], args...) = jbroadcasted(my_bernoulli_lpmfs, y, args...)
    "Needed for cross validation"
    my_bernoulli_lpmfs(y::int, args...) = my_bernoulli_lpmf(y, args...)
    "Needed to compute the joint likelihood"
    my_bernoulli_lpmf(y, ::typeof(logit), theta) = bernoulli_logit_lpmf(y, theta)
    "Needed for posterior predictions"
    my_bernoulli_rng(::typeof(logit), theta) = bernoulli_logit_rng(theta)
    "Needed to compute the joint likelihood"
    my_bernoulli_lpmf(y, ::typeof(log), theta) = bernoulli_lpmf(y, exp(theta))
    "Needed for posterior predictions"
    my_bernoulli_rng(::typeof(log), theta) = bernoulli_rng(exp(theta))

    "The `Hetero` prior density after constraining - common to both link functions"
    hetero_lpdf(x::vector[n]) = normal_lpdf(x, 0, 3)
    "The `Hetero` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the logit link"
    hetero_lpdf(x::ordered[n], ::typeof(logit), n) = hetero_lpdf(x)
    "The `Hetero` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the log link"
    hetero_lpdf(x::positive_ordered[n], ::typeof(log), n) = hetero_lpdf(x)
    "The `RW(1)` prior density after constraining - common to both link functions"
    rw1_lpdf(x::vector[n], sigma) = normal_lpdf(x[1], 0, sigma) + normal_lpdf(x[2:n], x[1:n-1], sigma)
    "The `RW(1)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the logit link"
    rw1_lpdf(x::ordered[n], ::typeof(logit), sigma, n) = rw1_lpdf(x, sigma)
    "The `RW(1)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the log link"
    rw1_lpdf(x::positive_ordered[n], ::typeof(log), sigma, n) = rw1_lpdf(x, sigma)
    "The `RW(2)` prior density after constraining - common to both link functions"
    rw2_lpdf(x::vector[n], sigma) = rw1_lpdf(x[1:2], sigma) + normal_lpdf(x[3:n], 2x[2:n-1] - x[1:n-2], sigma)
    "The `RW(2)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the logit link"
    rw2_lpdf(x::ordered[n], ::typeof(logit), sigma, n) = rw2_lpdf(x, sigma)
    "The `RW(2)` prior density after constraining - this function definition gets used to infer the type and shape of the parameter for the log link"
    rw2_lpdf(x::positive_ordered[n], ::typeof(log), sigma, n) = rw2_lpdf(x, sigma)
    "The upper bound of the alpha parameter (for the logit link function)"
    upper_alpha(::typeof(logit)) = negative_infinity()
    "The upper bound of the alpha parameter (for the log link function)"
    upper_alpha(::typeof(log)) = 0
end
# The main things that StanBlocks.jl has to know are that `y` is an array of `int`s and that `t` is a vector.
mock_data = (;y=[1], t=[1.])

# The base model with the mock data attached - this does not yet have an implementation for theta 
base_model = @slic mock_data begin 
    n = dims(y)[1]
    "The exact implementation of the likelihood depends on the passed link function `link_f`"
    y ~ my_bernoulli(link_f, theta)
end
# The submodel for the `Hetero` models
centered_hetero = @slic begin 
    "The type of the `xi` parameter depends on the passed link function `link_f`"
    xi ~ hetero(link_f, n)
    "The negation is needed to ensure that `theta` is in descending order"
    return -xi
end
# The submodel for the `RW(1)` models
centered_rw1 = @slic begin 
    sigma ~ std_normal(;lower=0)
    "The type of the `xi` parameter depends on the passed link function `link_f`"
    xi ~ rw1(link_f, sigma, n)
    "The negation is needed to ensure that `theta` is in descending order"
    return -xi
end
# The submodel for the `RW(2)` models
centered_rw2 = @slic begin 
    sigma ~ normal(0, .5; lower=0)
    "The type of the `xi` parameter depends on the passed link function `link_f`"
    xi ~ rw2(link_f, sigma, n)
    "The negation is needed to ensure that `theta` is in descending order"
    return -xi
end
# The submodel for the `regression` models - reused in the `regression_mix` models
regression = @slic begin 
    "The upper bound of alpha `alpha_upper` depends on the link function `link_f`"
    alpha_upper = upper_alpha(link_f)
    # This could be `alpha ~ normal(0, .5; upper=upper_alpha(link_f))`, once https://github.com/nsiccha/StanBlocks.jl/issues/35 is fixed
    alpha ~ normal(0, .5; upper=alpha_upper)
    beta ~ normal(0, .5; upper=0.)
    return alpha + beta * t
end
# The submodel for the `regression_mix` models - reusing the `regression` submodel
regression_mix = @slic begin 
    c1 ~ regression(;link_f, t) 
    c2 ~ regression(;link_f, t)
    lambda ~ beta(2, 2)
    return lambda * link_f(c1) + (1-lambda) * link_f(c2)
end

bases = (;
    hetero=base_model(quote 
        theta ~ centered_hetero(;link_f, n)
    end),
    rw1=base_model(quote 
        theta ~ centered_rw1(;link_f, n)
    end),
    rw2=base_model(quote 
        theta ~ centered_rw2(;link_f, n)
    end),
    regression=base_model(quote 
        theta ~ regression(;link_f, t)
    end),
    regression_mix=base_model(quote 
        theta ~ regression_mix(;link_f, t)
        y ~ bernoulli(theta)
    end)
)
link_fs = (;logit, log)  

# `posteriors` will be a (nested) named tuple - accessing e.g. the `Hetero (log)` model works via `posteriors.hetero.log`
posteriors = map(bases) do base 
    map(link_fs) do link_f 
        base(;link_f)
    end
end

tabsets = map(posteriors) do subposteriors 
    map(subposteriors) do posterior 
        QuartoComponents.Code("stan", stan_code(posterior))
    end |> QuartoComponents.Tabset
end |> QuartoComponents.Tabset


```

## How does StanBlocks.jl infer type and dimension?

The way StanBlocks.jl currently figures out the type and dimension of a model parameter defined by 
```julia
theta ~ distribution(foo, bar)
``` 
is by looking for a function definition 
```julia
distribution_lpdf(theta::theta_type[theta_dim1, ...], foo::foo_type[...], bar::bar_type[...]) = ...
```
matching the types of `foo` and `bar`. `theta_type` then has to be one of Stan's built-in types, and `theta_dim1, ...` can be expressions depending on `foo`, `foo_type[...]`, `bar`, and `bar_type[...]`.

For example, [the built-in definition](https://github.com/nsiccha/StanBlocks.jl/blob/e1145f528e423f6b46c89488bf25a18a51d0718a/src/slic_stan/builtin.jl#L204) 
```julia
multi_normal_lpdf(obs::vector[n], loc::vector[n], cov)
``` 
tells StanBlocks.jl that any parameter `theta` initialized via 
```julia
theta ~ multi_normal(loc, cov)
```
will be a vector of the same shape as `loc`.
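The lookup can be mimicked with plain Julia dispatch. Everything below is a hypothetical toy (the names `MultiNormal` and `param_dims` are invented for this sketch and are not StanBlocks.jl internals):

```julia
# Toy sketch of signature-driven shape inference: one "shape rule" per
# distribution, mirroring the corresponding lpdf signature. Dispatch selects
# the rule matching the argument types, and the rule evaluates the declared
# dimension expression against the actual arguments.
struct MultiNormal end
param_dims(::MultiNormal, loc::AbstractVector, cov) = (length(loc),)

param_dims(MultiNormal(), zeros(3), nothing)  # (3,) — theta inherits loc's length
```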

Once I add [this feature](https://github.com/nsiccha/StanBlocks.jl/issues/31), an alternative syntax to specify types and dimensions could be 

```julia
theta::vector[n] ~ distribution(foo, bar)
```
which would communicate intent more clearly, both to the programmer and the compiler.