Em Algorithm: Part 2 - MNIST used finite mixture models to predict handwritten digits from the MNIST dataset. The downside of finite mixture models is that we have to pre-specify the number of clusters and, absent strong prior information, choosing that number can be challenging.
An alternative approach is to dynamically estimate the number of clusters as part of the statistical model. For this we will be looking at the infinite extension of mixture models, which allows the model to create and remove cluster groups as part of the estimation procedure. However, unlike the previous finite mixture models, we won’t be using the expectation-maximization algorithm to fit our infinite mixture models; instead we turn to Gibbs sampling.
Brief Recap
For \(i = 1, 2, \ldots, N\) exchangeable observations and K clusters, let \(Y_i\) be conditionally distributed given a latent cluster label, \(z_i \in \{1, 2, \ldots, K \}\), according to some probability distribution \(F\) with parameters \(\theta_{z_i}\).
\[
Y_i | z_i = k \sim F(\theta_k)
\]
The marginalized distribution for \(Y_i\) is then a mixture of distributions with weights \(\pi_1, \ldots, \pi_K\), where \(P(z_i = k) = \pi_k\) and \(\sum_{k=1}^K \pi_k = 1\).
In a Bayesian setup without additional prior information, we typically place a symmetric Dirichlet prior with concentration parameter \(\alpha > 0\) on the cluster probabilities.
We complete the generative model with priors on the cluster specific distribution parameters, \(\theta_k \sim G_0\).
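Putting the recap together, and writing the symmetric Dirichlet prior explicitly with parameter \(\frac{\alpha}{K}\) (the parameterization that shows up in the derivation below), the generative model is
\[
\begin{align*}
\boldsymbol{\pi} | \alpha & \sim \text{Dirichlet}\left(\tfrac{\alpha}{K}, \ldots, \tfrac{\alpha}{K}\right) \\
z_i | \boldsymbol{\pi} & \sim \text{Categorical}(\boldsymbol{\pi}) \\
\theta_k & \sim G_0 \\
Y_i | z_i, \theta_{z_i} & \sim F(\theta_{z_i})
\end{align*}
\]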
Infinite Mixtures
Infinite mixture models arise from removing the upper bound on the number of latent labels, or clusters, i.e., letting \(K \to \infty\). This might seem like it would yield a model that is computationally intractable, but we can derive an expression for the model that is amenable to Gibbs sampling.
Essentially, our estimation procedure will dynamically allocate observations to any of the infinite number of clusters, of which only a finite subset will actually be instantiated. For this we will need expressions for the probability of assignment to an already instantiated cluster and the probability of creating a new cluster.
Start by expressing the joint probability of the latent labels, conditional on \(\alpha\), using the cluster counts, \(n_k = \sum_{i=1}^N \mathbb{I}\{z_i = k\}\) for \(k = 1, 2, \ldots, K\). Marginalizing over the Dirichlet-distributed weights,
\[
P(z_1, \ldots, z_N | \alpha) = \int P(z_1, \ldots, z_N | \boldsymbol{\pi}) \, p(\boldsymbol{\pi} | \alpha) \, d\boldsymbol{\pi} = \frac{\Gamma(\alpha)}{\Gamma\left(\frac{\alpha}{K}\right)^K} \int \prod_{k=1}^K \pi_k^{n_k + \frac{\alpha}{K} - 1} \, d\boldsymbol{\pi}
\]
We don’t have to solve this integral directly. Instead, recognize that \(\prod_{k=1}^K \pi_k^{n_k + \frac{\alpha}{K} - 1}\) matches the kernel of a Dirichlet distribution with parameters \(n_k + \frac{\alpha}{K}\). Since a valid continuous probability distribution integrates to one, the integral equals that Dirichlet's normalizing constant, the multivariate Beta function \(B\left(n_1 + \frac{\alpha}{K}, \ldots, n_K + \frac{\alpha}{K}\right) = \frac{\prod_{k=1}^K \Gamma\left(n_k + \frac{\alpha}{K}\right)}{\Gamma(N + \alpha)}\). Then,
\[
P(z_1, \ldots, z_N | \alpha) = \frac{\Gamma(\alpha)}{\Gamma\left(\frac{\alpha}{K}\right)^K} \cdot \frac{\prod_{k=1}^K \Gamma\left(n_k + \frac{\alpha}{K}\right)}{\Gamma(N + \alpha)}
\]
The conditional probability of a single label given all the others is the ratio
\[
P(z_i = k | \boldsymbol{z_{-i}}) = \frac{P(z_i = k, \boldsymbol{z_{-i}})}{P(\boldsymbol{z_{-i}})}
\]
The numerator was found previously in Equation 2, while the denominator can be found similarly by removing the \(i\)th observation from the cluster counts, denoted \(n_{-i,k}\) for \(k = 1, 2, \ldots, K\).
Substituting into Equation 3 and taking advantage of the property \(\Gamma(x + 1) = x \Gamma(x)\) of the Gamma function, we are left with the following.
\[
P(z_i = k | \boldsymbol{z_{-i}}) = \frac{n_{-i,k} + \frac{\alpha}{K}}{N - 1 + \alpha}
\]
Letting \(K \to \infty\), the probability of joining an already instantiated cluster and the probability of creating a new cluster become
\[
\begin{align*}
P(z_i = k | \boldsymbol{z_{-i}}) & = \frac{n_{-i,k}}{N - 1 + \alpha} \\
P(z_i \neq z_j \; \forall j \neq i | \boldsymbol{z_{-i}}) & = \frac{\alpha}{N - 1 + \alpha}
\end{align*}
\]
where the new-cluster probability comes from summing \(\frac{\alpha / K}{N - 1 + \alpha}\) over the infinitely many unoccupied clusters. Note, from Equation 5 the probability of being allocated to one of the instantiated clusters is proportional to the cluster count, while creating a new cluster is governed by the concentration parameter \(\alpha > 0\). As \(\alpha \to 0\), \(P(z_i \in \boldsymbol{z_{-i}} | \boldsymbol{z_{-i}}) \to 1\) and observations are pushed into existing clusters; as \(\alpha \to \infty\), each observation tends to receive its own cluster.
The Turing.jl documentation on infinite mixture models has a cool animated plot showing the dynamic process of assigning and/or creating clusters for new observations that we can shamelessly copy.
Inference for infinite mixtures can be approached in several different ways.1 We’re going to use a collapsed Gibbs sampler with conjugate priors, which, albeit slow and computationally intensive, is a simple and straightforward algorithm.
A Gibbs sampler creates a Markov chain that generates samples from a joint distribution by iteratively sampling from the conditional distributions.
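For a target distribution \(p(x_1, \ldots, x_d)\), one sweep of the sampler updates each coordinate in turn, conditioning on the most recent values of all the others:
\[
x_j^{(t+1)} \sim p\left(x_j | x_1^{(t+1)}, \ldots, x_{j-1}^{(t+1)}, x_{j+1}^{(t)}, \ldots, x_d^{(t)}\right), \qquad j = 1, \ldots, d
\]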
In the context of an infinite mixture model, if \(G_0\) is a conjugate base distribution, we can integrate out the cluster parameters \(\theta_k\) to estimate the posterior \(p(\boldsymbol{z} | \boldsymbol{Y})\) by iteratively sampling from \(p(z_i | \boldsymbol{z_{-i}})\).
We can formalize a single iteration of the collapsed Gibbs sampler as follows:
For \(i = 1, 2, \ldots, N\) observations
Remove observation \(i\) from its currently assigned cluster.
Assign a cluster to observation \(i\) by drawing from \(z_i | \boldsymbol{z_{-i}}, \boldsymbol{Y}\) with probabilities:
\[
\begin{align*}
P(z_i = k | \boldsymbol{z_{-i}}, \boldsymbol{Y}) & \propto n_{-i,k} \int F(y_i, \theta_k) \, dH_{-i}(\theta_k) \\
P(z_i \neq z_j \; \forall j \neq i | \boldsymbol{z_{-i}}, \boldsymbol{Y}) & \propto \alpha \int F(y_i, \theta) \, dG_0(\theta)
\end{align*}
\]
where \(H_{-i}(\theta_k)\) is the posterior \(p(\theta_k | \boldsymbol{Y_{-i}}) \propto p(\boldsymbol{Y_{-i}} | \theta_k) G_0(\theta_k)\), based on the observations currently assigned to cluster \(k\).
MNIST Revisited
Let’s return to MNIST. In Em Algorithm: Part 2 - MNIST, we fit separate Gaussian mixture models for each digit. For comparability we’ll follow the same procedure.
For \(i = 1, 2, \ldots, N\) images, let \(Y_i \in \mathbb{R}^{D^\star}\) be a column vector of reduced dimensionality \(D^\star\), obtained from the original dataset through principal component analysis.
```julia
using MKL, MLDatasets, LinearAlgebra

# We will eventually parallelize at the model level
BLAS.set_num_threads(1)

ENV["DATADEPS_ALWAYS_ACCEPT"] = true
pixels, labels = MNIST(split=:train)[:]
model_input = reshape(pixels, (28 * 28, 60_000))

using MultivariateStats

# Match the dimensionality of EM - Part 2
pca = fit(PCA, model_input, maxoutdim=50)
Y = transform(pca, model_input)
```
We partition the data according to the digit labels s.t. \(\boldsymbol{Y}^{(m)} = \{ Y_i : L_i = m \}\) for \(m = 1, 2, \ldots, M\) models where \(L_i\) is the original MNIST label for the \(i^{th}\) image. The mixture model for a specific digit is then specified as follows.
\[
Y_i^{(m)} | z_i^{(m)} = k \sim N(\mu_{mk}, \Sigma_{mk})
\]
This time we do not bound the total number of clusters, and we place a normal-inverse-Wishart prior on the mean vector and the covariance matrix of the likelihood.
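In the usual parameterization, with hyperparameters \(\mu_0\), \(\kappa_0\), \(\nu_0\), and \(\Lambda_0\) that reappear throughout the code below, the prior is
\[
\begin{align*}
\mu_{mk} | \Sigma_{mk} & \sim N\left(\mu_0, \tfrac{1}{\kappa_0} \Sigma_{mk}\right) \\
\Sigma_{mk} & \sim \text{Inverse-Wishart}\left(\nu_0, \Lambda_0\right)
\end{align*}
\]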
To pass around our priors we create a simple struct. Note, we will eventually be using some low-level BLAS/LAPACK functions, which do not accept mixing Float32 and Float64 values, so we will have to ensure consistent types throughout our code.
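The definition of the Priors struct isn't shown above, so here's a minimal sketch consistent with how it's used later. The field order matches the Priors{Float64}(κ_0, ν_0, μ_0, Matrix(Λ_0)) call when we instantiate the models, and the convert method mirrors what the Gibbs constructor expects; the exact field types are an assumption on my part.

```julia
# Minimal sketch of the Priors struct (assumed; the original definition is not shown).
struct Priors{T<:AbstractFloat}
    κ_0::T          # Prior pseudo-count for the mean
    ν_0::T          # Prior degrees of freedom
    μ_0::Vector{T}  # Prior mean vector
    Λ_0::Matrix{T}  # Prior scale matrix
end

# Convert all fields to a common floating point type, as the Gibbs constructor requires
Base.convert(::Type{Priors{T}}, p::Priors) where {T<:AbstractFloat} =
    Priors{T}(p.κ_0, p.ν_0, p.μ_0, p.Λ_0)
```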
In the math equations that follow I will drop the \(m\) super/subscripts to reduce notational clutter and leave implicit that we are dealing with a single model for a single digit subset of the original data.
Conjugacy
In a multivariate normal model, \(y \sim N\left(\mu, \Sigma\right)\), with a conjugate normal-inverse-Wishart prior, the posterior follows the same form as the prior with the following parameter updates:
\[
\begin{align*}
\mu_n & = \frac{\kappa_0 \mu_0 + n \bar{y}}{\kappa_0 + n} \\
\kappa_n & = \kappa_0 + n \\
\nu_n & = \nu_0 + n \\
\Lambda_n & = \Lambda_0 + Q - n \bar{y} \bar{y}' + \frac{\kappa_0 n}{\kappa_0 + n} \left(\bar{y} - \mu_0\right)\left(\bar{y} - \mu_0\right)'
\end{align*}
\]
where \(n\) is the number of observations, \(\bar{y}\) is the sample mean, and \(Q = \sum_{i=1}^n y_i y_i'\).
Moreover, the posterior predictive distribution for a new observation, \(\tilde{y}\), is a multivariate t-distribution.2
\[
\tilde{y} | y \sim \text{Multi-t}_{\nu_n - D + 1}\left( \mu_n, \Lambda_n \frac{\kappa_n + 1}{\kappa_n (\nu_n - D + 1)} \right)
\]
In our infinite mixture model, this is the closed form solution to the integral \(\int F(y_i, \theta) dH_{-i}(\theta_j)\) where \(\theta_j = (\mu_j, \Sigma_j)\), which is the posterior predictive distribution for cluster \(j\) for a new observation.
The prior predictive distribution in the infinite mixture model, \(\int F(y_i, \theta) dG_0(\theta)\), also follows a multivariate t-distribution under the prior distribution, \(G_0\).
\[
y_i \sim \text{Multi-t}_{\nu_0 - D + 1}\left( \mu_0, \Lambda_0 \frac{\kappa_0 + 1}{\kappa_0 (\nu_0 - D + 1)} \right)
\]
In terms of implementation, there is no multivariate t-distribution in the documentation for Distributions.jl; however, it is actually implemented and exported in the source as MvTDist. We will code our own implementation, though, in hopes of eking out a bit more speed in our sampler.
Thankfully, the multivariate t-distribution is fairly straightforward. If a random \(D\)-vector, \(y\), is distributed as \(y \sim \text{Multi-t}_{\nu}(\mu, \Sigma)\), then the probability density function is the following:
\[
p(y) = \frac{\Gamma\left(\frac{\nu + D}{2}\right)}{\Gamma\left(\frac{\nu}{2}\right) \left(\nu \pi\right)^{D/2} \left|\Sigma\right|^{1/2}} \left[1 + \frac{1}{\nu} (y - \mu)' \Sigma^{-1} (y - \mu)\right]^{-\frac{\nu + D}{2}}
\]
For computational efficiency we will work with the Cholesky decomposition of the covariance matrix, \(\Sigma = L L'\), where \(L\) is a lower triangular matrix. This gives us the determinant, \(\left|\Sigma\right| = \left( \prod_{i=1}^D L_{ii} \right)^2\), and the inverse, \(\Sigma^{-1} = \left( L^{-1}\right)' L^{-1}\).
Our constructor for the multivariate t-distribution then just calculates the Cholesky decomposition of \(\Sigma\) and pre-computes the constants in the log probability density function.
```julia
struct Multi_t{T<:Real}
    constant::T      # Constant term for the logpdf
    H::T             # Type consistent 0.5
    D::Int           # Dimensionality of data
    ν::T             # Degrees of freedom
    μ::Vector{T}     # Mean vector
    L::Matrix{T}     # Lower triangular from Cholesky decomposition of covariance matrix
    _rdx::Vector{T}  # Pre-allocated buffer for logpdf
end

using SpecialFunctions, IntelVectorMath

function Multi_t(ν::Real, μ::Vector{R}, Σ::Matrix{T}) where {R<:Real, T<:AbstractFloat}
    S = promote_type(typeof(ν), R, T)
    μ = convert(Vector{S}, μ)
    Σ = convert(Matrix{S}, Σ)

    # Cholesky decomposition on the lower triangular
    LAPACK.potrf!('L', Σ)

    # Log determinant of covariance matrix
    _rdx = diag(Σ)
    IVM.log!(_rdx)

    D = length(μ)
    H = S(0.5)
    a = H * ν
    b = H * D
    constant = loggamma(a + b) - loggamma(a) - b * log(π * ν) - sum(_rdx)

    Multi_t{S}(constant, H, D, S(ν), μ, Σ, _rdx)
end

Multi_t(ν::Real, μ::Vector{R}, Σ::Matrix{T}) where {R<:Real, T<:Integer} = Multi_t(ν, μ, float(Σ))
```
The remaining log-transformed, non-constant term from the multi-t probability density function can be simplified as follows:
\[
-\frac{\nu + D}{2} \log\left[1 + \frac{1}{\nu} (y - \mu)' \Sigma^{-1} (y - \mu)\right] = -\frac{\nu + D}{2} \log\left(1 + \frac{\lVert L^{-1} (y - \mu) \rVert^2}{\nu}\right)
\]
where \(L^{-1}(y - \mu)\) is obtained by solving the triangular system \(L b = y - \mu\).
```julia
function multi_t_logpdf(td::Multi_t, y::AbstractVector{<:Real})
    @. td._rdx = y - td.μ

    # Solve L * b = (y - μ) for b
    BLAS.trsv!('L', 'N', 'N', td.L, td._rdx)

    z = sum(abs2, td._rdx)
    td.constant - td.H * (td.ν + td.D) * log1p(z / td.ν)
end

multi_t_logpdf(td::Multi_t, Y::AbstractMatrix) = [multi_t_logpdf(td, y) for y in eachcol(Y)]
```
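As a quick, purely illustrative sanity check (not part of the original pipeline), we can compare the custom implementation against MvTDist. Note that the Multi_t constructor overwrites its covariance argument with the Cholesky factor, so we pass a copy.

```julia
using Distributions, LinearAlgebra, Random

let
    Random.seed!(1)
    d = 3
    μ = randn(d)
    A = randn(d, d)
    Σ = A * A' + d * I               # a positive definite scale matrix
    y = randn(d)

    td = Multi_t(5.0, μ, copy(Σ))    # copy: the constructor factorizes Σ in place
    @assert multi_t_logpdf(td, y) ≈ logpdf(MvTDist(5.0, μ, Σ), y)
end
```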
Clusters
Before we can fully implement our Gibbs sampler, we need some bookkeeping code. We need to track the sufficient statistics for each instantiated cluster that allow us to re-create the posterior parameters.
Specifically, for an instantiated cluster we save the number of assigned observations, \(n_k\); the sum of the observations, \(S_k = \sum_{i : z_i = k} Y_i\); and the sum of outer products of the assigned observations, \(Q_k = \sum_{i : z_i = k} Y_i Y_i'\).
```julia
mutable struct Cluster{T<:Real}
    N::Int              # Num. of observations for a cluster
    D::Int              # Dimensionality of observations
    S::Vector{T}        # Sum of observations
    Q::Matrix{T}        # Sum of outer products of observations

    # Scratch buffers for posterior calculations
    _rdx::Vector{T}
    _rdy::Vector{T}
    _rdm::Matrix{T}

    priors::Priors{T}   # Priors --- note, BLAS cannot mix Float64 & Float32
    td::Multi_t         # Predictive distribution

    function Cluster(Y::AbstractVecOrMat{T}, priors::Priors{T}; predictive=true) where {T<:Real}
        # Sufficient statistics
        D = size(Y, 1)
        N = size(Y, 2)
        S = vec(sum(Y, dims=2))
        Q = Y * Y'

        _rdx = Vector{T}(undef, D)
        _rdy = Vector{T}(undef, D)
        _rdm = Matrix{T}(undef, D, D)

        obj = new{T}(N, D, S, Q, _rdx, _rdy, _rdm, priors)

        if predictive === true
            obj.td = posterior_predictive(obj)
        end

        return obj
    end
end
```
The posterior parameters are updated using the sufficient statistics according to Equation 7.
```julia
function posterior(cluster::Cluster)
    κ = cluster.priors.κ_0 + cluster.N

    # μ = (S + κ_0 * μ_0) / κ --- posterior predictive mean
    @. cluster._rdx = (cluster.S + cluster.priors.κ_0 * cluster.priors.μ_0) / κ

    # (Ȳ - μ_0) --- required for posterior predictive covariance matrix
    @. cluster._rdy = cluster.S / cluster.N - cluster.priors.μ_0

    # Scale matrix
    # Λ = Λ_0 + Q - 1 / N * S * S' + κ_0 * N / κ * (S / N - μ_0) * (S / N - μ_0)'
    fused_update!(cluster._rdm, cluster.priors.Λ_0, cluster.Q, cluster.S,
                  cluster._rdy, 1 / cluster.N, cluster.priors.κ_0 * cluster.N / κ)

    return (cluster.priors.ν_0 + cluster.N, κ, cluster._rdx, cluster._rdm)
end

function fused_update!(rdm, Λ_0, Q, S, rdy, α, β)
    D = size(rdm, 1)
    for j in 1:D
        @simd for i in j:D
            @inbounds rdm[i, j] = Λ_0[i, j] + Q[i, j] - α * S[i] * S[j] + β * rdy[i] * rdy[j]
        end
    end
end
```
From this we can write a function to form the posterior predictive distribution for a new observation.
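The posterior_predictive function itself isn't shown above. A minimal sketch, assuming it simply plugs the posterior parameters into the multivariate t formula from the Conjugacy section, could look like this:

```julia
# Assumed sketch of posterior_predictive: form the multivariate t posterior
# predictive from the posterior normal-inverse-Wishart parameters.
function posterior_predictive(cluster::Cluster)
    ν, κ, μ, Λ = posterior(cluster)
    df = ν - cluster.D + 1

    # posterior() only fills the lower triangle of the scale matrix and returns
    # re-used buffers, so symmetrize, scale, and copy before handing off to Multi_t.
    Σ = Matrix(Symmetric(Λ, :L)) .* ((κ + 1) / (κ * df))

    Multi_t(df, copy(μ), Σ)
end
```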
When we add or remove an observation from a cluster we update our sufficient statistics.3 For example, when adding an observation, \(Y_i\), to cluster \(k\),
\[
n_k \leftarrow n_k + 1, \qquad S_k \leftarrow S_k + Y_i, \qquad Q_k \leftarrow Q_k + Y_i Y_i'
\]
```julia
function add!(cluster::Cluster, y::AbstractVector{T}) where {T<:Real}
    cluster.N += 1
    axpy!(1, y, cluster.S)

    # Q += y * y'
    BLAS.syr!('L', T(1.0), y, cluster.Q)

    cluster.td = posterior_predictive(cluster)
end
```
Removing an observation is similarly straightforward.
```julia
function remove!(cluster::Cluster, y::AbstractVector{T}) where {T<:Real}
    cluster.N -= 1
    axpy!(-1, y, cluster.S)

    # Q -= y * y'
    BLAS.syr!('L', T(-1.0), y, cluster.Q)

    cluster.td = posterior_predictive(cluster)
end
```
Gibbs Sampler
We can now implement our sampler. Again, for bookkeeping purposes we will track state using a struct and instantiate the sampler using a constructor.
```julia
mutable struct Gibbs{S<:Real, T<:Real}
    α::S
    Y::AbstractMatrix{T}
    D::Int
    N::Int
    priors::Priors{T}
    H::Multi_t
    __active::Dict{Int32, Cluster{T}}
    __lp_new_cluster::Vector{T}
    iterations::Int32
    Z::Matrix{Int32}
    map_clusters::Vector

    function Gibbs(Y::AbstractMatrix{S}, priors::Priors{T};
                   α=1, init_clusters=10) where {S<:Real, T<:Real}
        D, N = size(Y)
        R = promote_type(S, T)
        model_data = convert(Matrix{R}, Y)
        model_priors = convert(Priors{R}, priors)

        # Randomly initialize cluster assignments
        inits = rand(1:init_clusters, N)
        clusters = Dict(k => Cluster(model_data[:, inits .== k], model_priors)
                        for k in unique(inits))
        Z = reshape(inits, length(inits), 1)

        # Pre-calculate unnormalized log-probability for creating a new cluster
        Σ_0 = model_priors.Λ_0 * ((model_priors.κ_0 + 1) /
                                  (model_priors.κ_0 * (model_priors.ν_0 - D + 1)))
        td = Multi_t(model_priors.ν_0 - D + 1, model_priors.μ_0, Matrix(Σ_0))
        lp_new_cluster = log(α) .+ multi_t_logpdf(td, model_data)

        new{eltype(α), T}(α, model_data, D, N, model_priors, td, clusters,
                          lp_new_cluster, 1, Z)
    end
end
```
The actual sampling logic is encoded in a functor that runs for a fixed number of iterations, sequentially sampling from \(z_i | \boldsymbol{z_{-i}}, \boldsymbol{Y} \; \forall i \in \{1, 2, \ldots, N \}\) on each pass. The result is the posterior cluster labels for our model, stored as a matrix in the \(Z\) struct field.
```julia
using LogExpFunctions: softmax!
using StatsBase: wsample   # weighted sampling among the instantiated clusters

function (g::Gibbs)(niter; refresh=div(niter, 10))
    M = g.iterations == 1 ? niter - 1 : niter

    if size(g.Z, 2) < M + g.iterations
        z_samples = zeros(Int32, g.N, M)
        g.Z = hcat(g.Z, z_samples)
    end

    # Pre-allocate vector to hold proportional log-probabilities for each cluster
    clp = Vector{eltype(g.__lp_new_cluster)}(undef, 2 * (length(g.__active) + 1))

    vY = eachcol(g.Y)
    max_id = keys(g.__active) |> maximum

    # Iteration bounds
    lower, upper = (g.iterations + 1), (M + g.iterations)

    @info "Sampling from model..."

    @inbounds for m in lower:upper
        m % refresh == 0 && @info "Iteration: $m / $upper"

        for i in 1:g.N
            if g.__active[g.Z[i, m - 1]].N > 1
                remove!(g.__active[g.Z[i, m - 1]], vY[i])
            else
                delete!(g.__active, g.Z[i, m - 1])
            end

            cluster_keys = keys(g.__active) |> collect
            num_clusters = length(cluster_keys)
            length(clp) < num_clusters + 1 && resize!(clp, 2 * length(clp))

            for idx in eachindex(cluster_keys)
                cluster = g.__active[cluster_keys[idx]]
                clp[idx] = log(cluster.N) + multi_t_logpdf(cluster.td, vY[i])
            end
            clp[num_clusters + 1] = g.__lp_new_cluster[i]

            # Transform and normalize to probability simplex
            softmax!(view(clp, 1:num_clusters + 1))

            # Z_i ~ Categorical(clp_{1:K+1})
            if rand() > clp[num_clusters + 1]
                g.Z[i, m] = wsample(cluster_keys, view(clp, 1:num_clusters))
                add!(g.__active[g.Z[i, m]], vY[i])
            else
                g.Z[i, m] = (max_id += 1)
                g.__active[g.Z[i, m]] = Cluster(vY[i], g.priors)
            end
        end
    end

    g.iterations += M
end
```
Posterior Prediction
Our ultimate goal remains to predict the correct digit given a PCA-transformed vector of gray-scale pixel values. We assign the digit label for a new observation, \(\tilde{y}\), by finding the model with the “best fit.”
Typically this would involve the posterior predictive distribution,
\[
p(\tilde{y} | \boldsymbol{Y}) = \sum_{\boldsymbol{z}} p(\tilde{y} | \boldsymbol{z}, \boldsymbol{Y}) \, p(\boldsymbol{z} | \boldsymbol{Y}) \approx \frac{1}{S} \sum_{s=1}^S p\left(\tilde{y} | \boldsymbol{z}^{(s)}, \boldsymbol{Y}\right)
\]
averaged over \(S\) posterior draws \(\boldsymbol{z}^{(s)}\).
However, evaluating this over even a small subset of the posterior draws is fairly computationally expensive. Since we are only interested in a point estimate prediction under zero-one loss, we’ll form the log-likelihood using the maximum a posteriori (MAP) estimate for the latent labels \(z\),
\[
p(\boldsymbol{z} | \boldsymbol{Y}) \propto p(\boldsymbol{z}) \, p(\boldsymbol{Y} | \boldsymbol{z}) = p(z_1, \ldots, z_N) \prod_{k=1}^{K^\star} \int \prod_{i : z_i = k} F(y_i, \theta_k) \, dG_0(\theta_k)
\]
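The second term factors over the instantiated clusters, and under the conjugate normal-inverse-Wishart prior each factor has a standard closed form, which is what integrated_log_lik below computes on the log scale:
\[
p(\boldsymbol{Y}_k | \boldsymbol{z}) = \pi^{-\frac{n_k D}{2}} \, \frac{\Gamma_D\left(\frac{\nu_{n_k}}{2}\right)}{\Gamma_D\left(\frac{\nu_0}{2}\right)} \, \frac{\left|\Lambda_0\right|^{\nu_0 / 2}}{\left|\Lambda_{n_k}\right|^{\nu_{n_k} / 2}} \left(\frac{\kappa_0}{\kappa_{n_k}}\right)^{D/2}
\]
where \(\boldsymbol{Y}_k\) are the observations assigned to cluster \(k\), the posterior parameters are computed from those observations, and \(\Gamma_D(\cdot)\) is the multivariate Gamma function.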
```julia
using StatsFuns: logmvgamma

function integrated_log_lik(Y, priors)
    cluster = Cluster(Y, priors; predictive=false)
    ν, κ, _, Λ = posterior(cluster)
    Λ .= Symmetric(Λ, :L)

    a = 0.5 * cluster.D
    b = 0.5 * priors.ν_0
    c = 0.5 * ν

    a * log(priors.κ_0 / κ) + b * logdet(priors.Λ_0) - c * logdet(Λ) +
        logmvgamma(cluster.D, c) - logmvgamma(cluster.D, b) - (cluster.N * a) * log(π)
end
```
Meanwhile, the first term in Equation 8 is the prior for the latent cluster labels, \(z\), which can be found by expanding the joint probability, \(p(z_1, \ldots, z_N)\), and using Equation 5 and Equation 6,
\[
p(z_1, \ldots, z_N) = \alpha^{K^\star} \, \frac{\prod_{k=1}^{K^\star} (n_k - 1)!}{\prod_{j=0}^{N-1} (\alpha + j)}
\]
where \(K^\star\) is the instantiated number of clusters after observing \(N\) observations.
Combining the two terms, we implement the joint log likelihood, which is proportional to \(p(z | \boldsymbol{Y})\).
```julia
function joint_log_lik(draw, Y, α, priors)
    N = size(Y, 2)
    z = unique(draw)
    K = length(z)

    # p(y | z)
    lpx = sum(integrated_log_lik(Y[:, draw .== k], priors) for k in z)

    # p(z)
    lpz = K * log(α) +
          sum(logfactorial(sum(draw .== k) - 1) for k in z) -
          sum(log(α + j) for j in 0:(N - 1))

    lpx + lpz
end
```
The MAP estimate for \(z\) is then the posterior draw which maximizes this joint log-likelihood.
```julia
draws(g::Gibbs; burnin=0, thin=1) = g.Z[:, (burnin + 1):thin:end]

function optim!(g::Gibbs; kwargs...)
    @info "Finding most likely cluster assignments..."

    Z = draws(g; kwargs...)
    column = argmax(joint_log_lik(draw, g.Y, g.α, g.priors) for draw in eachcol(Z))
    map_labels = Z[:, column]

    g.map_clusters = [Cluster(g.Y[:, map_labels .== k], g.priors) for k in unique(map_labels)]
end
```
Generating Predictions
To actually generate a prediction, we evaluate the posterior predictive density of \(\tilde{y}\) under each digit’s model, using the MAP estimate for \(z\), and pick the digit whose model assigns the highest value.
```julia
function predict(Y, models; kwargs...)
    [argmax(new_log_lik(sampler, y; kwargs...) for sampler in models) - 1 for y in eachcol(Y)]
end
```
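The new_log_lik function referenced above isn't defined in the post as shown. A minimal sketch, assuming it evaluates the posterior predictive density of a new observation under a model's MAP clustering, i.e. a mixture of the cluster posterior predictives weighted by cluster size plus the prior predictive weighted by \(\alpha\) (following Equation 5 and Equation 6), might be:

```julia
using LogExpFunctions: logsumexp

# Assumed implementation of new_log_lik: log posterior predictive density of y
# under the MAP clustering of a single digit model. The keyword arguments are
# accepted only to match the predict signature and are ignored in this sketch.
function new_log_lik(g::Gibbs, y::AbstractVector; kwargs...)
    clusters = g.map_clusters
    N = sum(c.N for c in clusters)

    lp = [log(c.N) + multi_t_logpdf(c.td, y) for c in clusters]
    push!(lp, log(g.α) + multi_t_logpdf(g.H, y))   # weight for opening a new cluster

    logsumexp(lp) - log(N + g.α)
end
```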
Running the Models
That’s enough math. Let’s actually run the models.
Start by instantiating a Gibbs sampler for each subset of the training dataset corresponding to the different MNIST digits with fairly uninformative priors.
```julia
models = map(0:9) do k
    data = Y[:, labels .== k]
    D = size(data, 1)

    μ_0 = zeros(D)
    κ_0 = 1
    ν_0 = D + 1
    Λ_0 = I(D)

    priors = Priors{Float64}(κ_0, ν_0, μ_0, Matrix(Λ_0))

    Gibbs(data, priors, α = 1, init_clusters = 4)
end
```
We run each sampler in a parallelized fashion for a fixed number of iterations. Warning, this will take a substantial amount of time.
```julia
using Base.Threads

niter  = 3_000
burnin = 1_500
thin   = 3

Threads.@threads :greedy for sampler in models
    sampler(niter)
end
```
Before we generate the MAP estimates for \(z\) for each model, let’s check the behaviour of our samplers. Figure 1 shows the traceplots for the cluster counts after discarding the first half of the posterior draws.
```julia
using Plots   # plotting backend assumed to be Plots.jl

plts = map(models) do sampler
    n = [length(unique(z)) for z in eachcol(draws(sampler; burnin, thin))]
    plot(1:length(n), n, legend = :none)
end

plot(plts..., layout = (5, 2))
```
Figure 1: MCMC traceplots of the cluster count per model.
We can also plot the acceptance ratios, i.e. the proportion of observations assigned the same cluster label between sampling iterations, as seen in Figure 2.
```julia
function acceptance(Z)
    N, M = size(Z)
    [sum(Z[:, i] .== Z[:, i + 1]) / N for i in 1:(M - 1)]
end

ratios = [draws(sampler; burnin, thin) |> acceptance for sampler in models]

plts = [plot(1:length(p), p, legend = :none) for p in ratios]
plot(plts..., layout = (5, 2))
```
Figure 2: MCMC traceplots for acceptance ratios per model.
Finally, let’s check our prediction accuracy for the training dataset.
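The code for this check isn't shown above, so here's a minimal sketch. I assume a simple helper, check, that returns the share of matching labels; optim! has to run first so that each model has its MAP cluster assignments.

```julia
# Assumed helper: proportion of predictions that match the true labels.
check(predictions, truth) = sum(predictions .== truth) / length(truth)

# Extract the MAP cluster assignments for every digit model.
for sampler in models
    optim!(sampler; burnin, thin)
end

Z_train = predict(Y, models)
println("Training accuracy: $(check(Z_train, labels))")
```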
We can also explore how the models learned different variants of the same digit. Looking at the first model that corresponds to the ‘0’ digit, Figure 3 plots the expectation for each cluster.
```julia
using Colors: Gray   # assumed source of the Gray constructor

imgs = [reconstruct(pca, cluster.td.μ) for cluster in models[1].map_clusters]

plts = [plot(Gray.(reshape(1 .- i, 28, 28)')) for i in imgs]
plot(plts..., axis = false, ticks = false)
```
Figure 3: Predicted images for each variant of the digit ‘0’.
Turns out that ‘0’ is pretty easy to model. This is reflected in the per digit accuracy scores.
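The test-set preparation isn't shown above either; a plausible sketch is to load the MNIST test split, project it with the PCA fitted on the training data, and predict.

```julia
# Assumed test-set preparation: same pipeline as the training data.
test_pixels, test_labels = MNIST(split=:test)[:]
test_input = reshape(test_pixels, (28 * 28, 10_000))

Y_test = transform(pca, test_input)
Z_test = predict(Y_test, models)
```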
```julia
for i in 0:9
    aux = check(Z_test[test_labels .== i], fill(i, sum(test_labels .== i)))
    println("Digit '$(i)' accuracy: $aux")
end
```
Compared to the ensemble of finite Gaussian mixture models we ended up improving our accuracy on the test dataset by approximately one percentage point at the cost of a significantly increased runtime — on my local machine the infinite mixture models, although memory efficient, take roughly 5x-6x longer to fit. The slow speed, even when generating posterior predictions, undermines the practicality of these models, at least in the form used here.
That said, there is always more that can be done. Similar to the finite mixture models, we did not explore different pre-processing strategies or discuss hyperparameter tuning; in the case of the latter, a common extension is to place a Gamma prior on the concentration parameter \(\alpha\). Furthermore, a production grade solution would actually have to address the question of convergence for the Gibbs samplers.4 And, we would need to design a more robust implementation, checking the dimensions of inputs, asserting that the prior scale matrix is positive-definite, etc.5
Maybe, some day, I’ll code up one of the more practical inference algorithms for infinite mixture models.
Addendum
An alternative, and perhaps more canonical, exposition of infinite mixture models from the perspective of Bayesian non-parametrics goes via Dirichlet processes.
A Dirichlet process is a distribution over all possible probability measures defined on some measurable space. For our purposes, we denote a random probability measure distributed according to a Dirichlet process as \(G \sim \text{DP}(\alpha G_0)\), defined on the measurable space \((\Omega, \mathcal{A})\), with concentration parameter \(\alpha > 0\) and baseline probability measure \(G_0\), also defined on \((\Omega, \mathcal{A})\). By definition, a realization \(G\) is discrete with probability one and satisfies the property
\[
\left(G(A_1), \ldots, G(A_k)\right) \sim \text{Dirichlet}\left(\alpha G_0(A_1), \ldots, \alpha G_0(A_k)\right)
\]
for any finite partition \(A_1, \ldots, A_k\) of \(\Omega\). It follows that \(\mathbb{E}[G(A)] = G_0(A)\) and \(\text{Var}\left(G(A)\right) = \frac{G_0(A)(1 - G_0(A))}{1 + \alpha}\) for all \(A \in \mathcal{A}\).6
In the context of a mixture model, we let the mixture parameters be distributed according to a Dirichlet process measure. For an exchangeable sequence, \(Y_1, Y_2, \ldots, Y_N\), this leads to the following generative model.
\[
\begin{align*}
Y_i | \theta_i & \sim F(\theta_i) \\
\theta_i | G & \sim G \\
G & \sim \text{DP}\left(G_0, \alpha\right)
\end{align*}
\]
Marginalizing over the measure \(G\) results in the following conditional distribution,
\[
\theta_i | \boldsymbol{\theta_{-i}} \sim \frac{1}{N - 1 + \alpha} \sum_{j \neq i} \delta_{\theta_j} + \frac{\alpha}{N - 1 + \alpha} G_0
\]
where \(\delta_{\theta_j}\) denotes a point mass at \(\theta_j\).
Define \(z_i = k \Leftrightarrow \theta_i = \theta_k\), and we have exactly our conditional probabilities given by Equation 5 and Equation 6, commonly referred to as the Chinese Restaurant Process.
References
Blackwell, David, and James B. MacQueen. 1973. “Ferguson Distributions via Polya Urn Schemes.” The Annals of Statistics 1 (2). https://doi.org/10.1214/aos/1176342372.
Blei, David M., and Michael I. Jordan. 2006. “Variational Inference for Dirichlet Process Mixtures.” Bayesian Analysis 1 (1). https://doi.org/10.1214/06-BA104.
Murphy, Kevin P. 2013. Machine Learning: A Probabilistic Perspective. 4th printing (fixed many typos). Adaptive Computation and Machine Learning Series. Cambridge, Mass.: MIT Press.
Neal, Radford M. 2000. “Markov Chain Sampling Methods for Dirichlet Process Mixture Models.” Journal of Computational and Graphical Statistics 9 (2): 249–65.
Raftery, Adrian E., and Steven M. Lewis. 1992. “[Practical Markov Chain Monte Carlo]: Comment: One Long Run with Diagnostics: Implementation Strategies for Markov Chain Monte Carlo.” Statistical Science 7 (4). https://doi.org/10.1214/ss/1177011143.
Raykov, Yordan P., Alexis Boukouvalas, and Max A. Little. 2016. “Simple Approximate MAP Inference for Dirichlet Processes Mixtures.” Electronic Journal of Statistics 10 (2). https://doi.org/10.1214/16-EJS1196.
Neal (2000) is a canonical reference for Gibbs sampling in the context of conjugate and non-conjugate priors for infinite mixture models. Meanwhile, Blei and Jordan (2006) provides a variational inference approach that is used by scikit-learn and Raykov, Boukouvalas, and Little (2016) offer an approximate MAP estimator.↩︎
For the derivation, refer to your favorite source on conjugate priors. Or, check out section 3.6 “Multivariate normal with unknown mean and variance” in BDA3.↩︎
The code for adding/removing observations from a cluster should set off some alarm bells as it is not numerically stable and will lead to accumulated floating point error. A more robust approach would re-calculate the sufficient statistics from scratch, rather than partially updating. I’ll partially mitigate the issue by using Float64 variables, and leave the rest as an exercise to the reader.↩︎
Previous work discussing infinite mixture models typically uses the Raftery and Lewis diagnostic to assess convergence (Raftery and Lewis 1992).↩︎
We should ideally be checking the return values from our BLAS/LAPACK calls. For example, accumulated floating point error when adding/removing observations can lead to the Cholesky decomposition failing when creating the posterior predictive distribution for a cluster.↩︎
There are several proofs showing that Dirichlet process distributed random probability measures are discrete almost surely. My favorite is Blackwell and MacQueen (1973) using the Polya urn scheme.↩︎