// Aggregated Poisson model: identical observations are grouped, so each
// unique value contributes its log-probability once, weighted by its count.
data {
  int<lower=0> K;                  // number of unique observed values
  array[K] int<lower=0> n_obs;     // number of times each unique value occurs
  array[K] int<lower=0> y;         // the unique observed values
}
parameters {
  real<lower=0> lambda;            // Poisson rate
}
model {
  lambda ~ gamma(25, 1);
  // Weighted log-likelihood: sum_j N_j * log f(y_j | lambda).
  for (i in 1:K) {
    target += n_obs[i] * poisson_lpmf(y[i] | lambda);
  }
}
This will be a trivial point, but if feasible it can be more computationally efficient to work with an aggregated version of the log likelihood.
Let \(y_1, y_2, \ldots, y_N\) be a collection of i.i.d. observations with a log-likelihood function
\[ \ell(\theta | y) = \sum_{i=1}^N \log f(y_i | \theta) \]
for some arbitrary probability function \(f(\cdot)\). Obviously, we can group identical observations together where \(y_i = y_j\) for \(i \neq j\). Then,
\[ \ell(\theta | y) = \sum_{j=1}^K N_j \log f(y_j | \theta) \]
where \(y_j\) now denotes the \(j\)-th unique value, \(N_j\) denotes the number of times \(y_j\) appears in the dataset, and \(K\) is the number of unique values.
In the case where \(y_i \sim \text{Poisson}\left(\lambda \right)\), this can easily be coded in Stan as follows:
Benchmarking
For a simulated dataset of 2,000 observations drawn from a Poisson distribution with \(\lambda = 25\), there is a decrease in runtime when using the aggregated version of the model.
library(cmdstanr)
library(dplyr)
library(ggplot2)
library(ggthemes)
library(parallel)
# Leave one core free for the rest of the system, but never request fewer
# than one worker: `mclapply()` rejects `mc.cores < 1`, and `detectCores()`
# may return 1 (or NA) on some machines.
options(mc.cores = max(1L, parallel::detectCores() - 1L, na.rm = TRUE))

# Stan programs for the unaggregated (mod1) and aggregated (mod2) models.
mod1 <- cmdstan_model("./poisson.stan")
mod2 <- cmdstan_model("./poisson_aggregated.stan")
# Time the evaluation of an expression.
#
# `expr` is a lazily evaluated argument: it is not run until it is forced
# inside the function, so the work happens between the two clock readings.
#
# Returns a `difftime` measured in seconds. The units are fixed explicitly so
# that results from several runs can be combined (e.g. with `unlist()`, which
# drops the units attribute) without silently mixing seconds and minutes.
bench <- function(expr) {
  start <- Sys.time()
  force(expr)
  end <- Sys.time()
  difftime(end, start, units = "secs")
}
# Draw samples from a model using a single chain, with sampler messages and
# exceptions silenced so that benchmark output stays quiet.
#
# mod:  an object exposing a `$sample()` method (here, a cmdstanr model).
# data: the data list passed through to `$sample()`.
#
# Returns whatever `mod$sample()` returns.
stan_sample <- function(mod, data) {
  mod$sample(
    data = data,
    chains = 1,
    show_messages = FALSE,
    show_exceptions = FALSE
  )
}
###
# Simulate data
N <- 2e3
lambda <- 25
y <- rpois(N, lambda)

# Model 1 data
data <- list(N = N, y = y)

# Model 2 data - aggregated: one row per unique value of `y` with its count
df <- data.frame(y = y) |>
  group_by(y) |>
  summarise(n = n())
agg_data <- list(K = nrow(df), n_obs = df$n, y = df$y)

###
# Run each model for `iter` iterations
iter <- 100

# I believe `mclapply` uses a process pool, so discard the first
# mc.cores iterations which will have an associated startup cost.
burnin <- getOption("mc.cores")
# Guard against discarding every run (e.g. on a machine with many cores).
stopifnot(burnin < iter)

l1 <- mclapply(seq_len(iter), \(i) bench(stan_sample(mod1, data)))
l2 <- mclapply(seq_len(iter), \(i) bench(stan_sample(mod2, agg_data)))

# Runs kept after dropping the warm-up iterations.
kept <- (burnin + 1):iter

est <- data.frame(
  time = c(unlist(l1[kept]), unlist(l2[kept])),
  model = c(
    rep("Unaggregated", iter - burnin),
    rep("Aggregated", iter - burnin)
  )
)

ggplot(est, aes(x = model, y = time, color = model)) +
  geom_jitter(width = 0.1, alpha = 0.5) +
  xlab("Model") +
  ylab("Runtime (sec)") +
  theme_minimal(base_size = 12) +
  theme(
    axis.title.x = element_text(face = "bold"),
    axis.title.y = element_text(face = "bold"),
    legend.position = "none"
  ) +
  scale_colour_solarized()
Miscellaneous
Stan code for the unaggregated model.
// Unaggregated Poisson model: every observation enters the likelihood
// individually through a single vectorised statement.
data {
  int<lower=0> N;              // number of observations
  array[N] int<lower=0> y;     // observed counts
}
parameters {
  real<lower=0> lambda;        // Poisson rate
}
model {
  lambda ~ gamma(25, 1);
  target += poisson_lpmf(y | lambda);
}
Computing environment:
sessionInfo()
R version 4.4.0 (2024-04-24)
Platform: x86_64-pc-linux-gnu
Running under: Arch Linux
Matrix products: default
BLAS/LAPACK: /usr/lib/libopenblas.so.0.3; LAPACK version 3.12.0
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
time zone: Europe/Oslo
tzcode source: system (glibc)
attached base packages:
[1] parallel stats graphics grDevices utils datasets methods
[8] base
other attached packages:
[1] ggthemes_5.1.0 ggplot2_3.5.1 dplyr_1.1.4 cmdstanr_0.8.0
loaded via a namespace (and not attached):
[1] gtable_0.3.5 jsonlite_1.8.8 compiler_4.4.0
[4] tidyselect_1.2.1 stringr_1.5.1 scales_1.3.0
[7] yaml_2.3.8 fastmap_1.2.0 R6_2.5.1
[10] labeling_0.4.3 generics_0.1.3 distributional_0.4.0
[13] knitr_1.46 backports_1.4.1 checkmate_2.3.1
[16] tibble_3.2.1 munsell_0.5.1 pillar_1.9.0
[19] posterior_1.5.0 rlang_1.1.3 utf8_1.2.4
[22] stringi_1.8.4 xfun_0.44 cli_3.6.2
[25] withr_3.0.0 magrittr_2.0.3 ps_1.7.6
[28] digest_0.6.35 grid_4.4.0 processx_3.8.4
[31] lifecycle_1.0.4 vctrs_0.6.5 evaluate_0.23
[34] glue_1.7.0 tensorA_0.36.2.1 farver_2.1.2
[37] codetools_0.2-20 abind_1.4-5 colorspace_2.1-0
[40] fansi_1.0.6 purrr_1.0.2 rmarkdown_2.27
[43] tools_4.4.0 pkgconfig_2.0.3 htmltools_0.5.8.1
sprintf("cmdstan version: %s", cmdstan_version()) |> print()
[1] "cmdstan version: 2.34.1"