MNIST Revisited - Infinite Mixtures

2025-02-26

EM Algorithm: Part 2 - MNIST used finite mixture models to predict handwritten digits from the MNIST dataset. The downside of finite mixture models is that we have to pre-specify the number of clusters. Absent strong prior information, deciding the optimal number can be challenging.

An alternative approach would be to dynamically estimate the number of clusters as part of the statistical model. For this we will be looking at the infinite extension of mixture models, which will allow the model to create and remove cluster groups as part of the estimation procedure. However, unlike the previous finite mixture models we won’t be using the expectation maximization algorithm to fit our infinite mixture models, and instead turn to Gibbs sampling.

Brief Recap

For \(i = 1, 2, \ldots, N\) exchangeable observations and \(K\) clusters, let \(Y_i\) be conditionally distributed given a latent cluster label, \(z_i \in \{1, 2, \ldots, K \}\), according to some probability distribution \(F\) with parameters \(\theta_{z_i}\).

\[ Y_i | z_i = k \sim F(\theta_k) \]

The marginalized distribution for \(Y_i\) is then a mixture of distributions with weights, \(\pi_1, \ldots, \pi_K\), where \(P(z_i = k) = \pi_k\) and \(\sum_{k=1}^K \pi_k = 1\).

\[ p(Y_i | \boldsymbol{\pi}, \boldsymbol{\theta}) = \sum_{k=1}^K \pi_k F(\theta_k) \]

In a Bayesian setup without additional prior information, we typically place a symmetric Dirichlet prior with concentration parameter \(\alpha > 0\) for the cluster probabilities.

\[ \pi_1, \ldots, \pi_K \sim \text{Dirichlet}\left(\frac{\alpha}{K}, \ldots, \frac{\alpha}{K}\right) \]

We complete the generative model with priors on the cluster specific distribution parameters, \(\theta_k \sim G_0\).

Infinite Mixtures

Infinite mixture models arise from removing the upper bound on the number of latent labels, or clusters. This might seem like it would yield a computationally intractable model, but we can derive an expression for the model that is amenable to Gibbs sampling.

Essentially, our estimation procedure will dynamically allocate observations to any of the infinite number of clusters, of which only a finite subset will actually be instantiated. For this we will need expressions for the probability of assignment to an already instantiated cluster and the probability of creating a new cluster.

Start by expressing the joint conditional probability of the latent labels using the cluster counts, \(n_k = \sum_{i=1}^N \mathbb{I}\{z_i = k\}\) for \(k = 1, 2, \ldots, K\).

\[ p(z_1, \ldots, z_N | \pi_1, \ldots, \pi_K) = \prod_{k=1}^K \pi_k^{n_k} \tag{1}\]

Marginalize over the mixture probabilities in Equation 1.

\[ \begin{align*} p(z_1, \ldots, z_N) & = \int_{\boldsymbol{\pi}} p(z_1, \ldots, z_N | \boldsymbol{\pi}) p(\pi_1, \ldots, \pi_K) d\boldsymbol{\pi} \\ & = \int_{\boldsymbol{\pi}} \prod_{k=1}^K \pi_k^{n_k} \cfrac{\Gamma(\alpha)}{\Gamma(\frac{\alpha}{K})^K} \prod_{k=1}^K \pi_k^{\frac{\alpha}{K} - 1} d\boldsymbol{\pi} \\ & = \cfrac{\Gamma(\alpha)}{\Gamma(\frac{\alpha}{K})^K} \int_{\boldsymbol{\pi}} \prod_{k=1}^K \pi_k^{n_k + \frac{\alpha}{K} - 1} d\boldsymbol{\pi} \end{align*} \]

We don’t have to solve this integral. Instead, recognize that \(\prod_{k=1}^K \pi_k^{n_k + \frac{\alpha}{K} - 1}\) is the kernel of a Dirichlet distribution with parameters \(n_k + \frac{\alpha}{K}\) for \(k = 1, \ldots, K\). Since a valid continuous probability density integrates to one, our integral equals the inverse of that distribution's normalizing constant \(\frac{1}{B(n_1 + \frac{\alpha}{K}, \ldots, n_K + \frac{\alpha}{K})}\), where \(B(\cdot)\) is the multivariate Beta function. Then,

\[ p(z_1, \ldots, z_N) = \cfrac{\Gamma(\alpha)}{\Gamma(\alpha + N)} \prod_{k=1}^K \cfrac{\Gamma(n_k + \frac{\alpha}{K})}{\Gamma(\frac{\alpha}{K})} \tag{2}\]
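
For readers who want the intermediate step written out, the integral is the multivariate Beta function evaluated at the updated Dirichlet parameters. Using \(\sum_{k=1}^K n_k = N\), the parameters sum to \(\alpha + N\):

\[ \int_{\boldsymbol{\pi}} \prod_{k=1}^K \pi_k^{n_k + \frac{\alpha}{K} - 1} d\boldsymbol{\pi} = \cfrac{\prod_{k=1}^K \Gamma(n_k + \frac{\alpha}{K})}{\Gamma\left(\sum_{k=1}^K \left(n_k + \frac{\alpha}{K}\right)\right)} = \cfrac{\prod_{k=1}^K \Gamma(n_k + \frac{\alpha}{K})}{\Gamma(\alpha + N)} \]

Multiplying by the prior's constant, \(\frac{\Gamma(\alpha)}{\Gamma(\frac{\alpha}{K})^K}\), gives Equation 2.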

Next, we find the marginal probability that observation \(i\) is assigned cluster \(j\) given the rest of the latent labels, \(\boldsymbol{z_{-i}}\).

\[ p(z_i = j | \boldsymbol{z_{-i}}) = \cfrac{p(z_1, \ldots, z_N)}{p(\boldsymbol{z_{-i}})} \tag{3}\]

The numerator was found previously in Equation 2 while the denominator can be found similarly by removing the \(i\)th observation from the cluster counts, denoted \(n_{-i,k}\) for \(k = 1, 2, \ldots, K\).

\[ p(\boldsymbol{z_{-i}}) = \cfrac{\Gamma(\alpha)}{\Gamma(\alpha + N - 1)} \prod_{k=1}^K \cfrac{\Gamma(n_{-i,k} + \frac{\alpha}{K})}{\Gamma(\frac{\alpha}{K})} \]

Substituting into Equation 3 and taking advantage of the property \(\Gamma(x + 1) = x \Gamma(x)\) for the Gamma function, we are left with the following.

\[ \begin{align*} p(z_i = j | \boldsymbol{z_{-i}}) & = \cfrac{\Gamma(\alpha + N - 1)}{\Gamma(\alpha + N)} \cfrac{\Gamma(n_{-i,j} + 1 + \frac{\alpha}{K})}{\Gamma(n_{-i,j} + \frac{\alpha}{K})} \\ & = \cfrac{n_{-i,j} + \frac{\alpha}{K}}{\alpha + N - 1} \end{align*} \tag{4}\]

Letting \(K \to \infty\) in Equation 4 we find the probability of an observation being assigned to cluster \(j\) when \(n_{-i,j} > 0\).

\[ p(z_i = j | \boldsymbol{z_{-i}}) = \frac{n_{-i,j}}{\alpha + N - 1} \tag{5}\]

We can also find the probability that \(z_i\) is allocated a new cluster, i.e. one of the infinite clusters where \(n_{-i,j} = 0\).

\[ \begin{align*} p(z_i \notin \boldsymbol{z_{-i}} | \boldsymbol{z_{-i}}) & = 1 - p(z_i \in \boldsymbol{z_{-i}} | \boldsymbol{z_{-i}}) \\ & = 1 - \sum_j \frac{n_{-i,j}}{\alpha + N - 1} \\ & = 1 - \frac{N - 1}{\alpha + N - 1} \\ & = \frac{\alpha}{\alpha + N - 1} \end{align*} \tag{6}\]

Note, from Equation 5 the probability of being allocated to one of the instantiated clusters is proportional to the cluster count. Meanwhile, creating a new cluster is governed by the concentration parameter \(\alpha > 0\). As \(\alpha \to 0\) then \(P(z_i \in \boldsymbol{z_{-i}} | \boldsymbol{z_{-i}}) = 1\) almost surely, and if \(\alpha \to \infty\), then each observation will receive its own cluster.
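
As a quick numerical illustration of Equation 5 and Equation 6, the sketch below computes the assignment probabilities for a single observation from the counts of the remaining \(N - 1\) observations. The helper name `crp_probs` and the example counts are made up for illustration and are not used elsewhere in this post.

```julia
# Conditional assignment probabilities for one observation, given the
# cluster counts of the other N - 1 observations (Equations 5 and 6).
# `crp_probs` is an illustrative helper, not part of the sampler.
function crp_probs(counts::AbstractVector{<:Integer}, α::Real)
    denom = α + sum(counts)        # α + N - 1
    existing = counts ./ denom     # Equation 5: one entry per instantiated cluster
    fresh = α / denom              # Equation 6: probability of opening a new cluster
    return existing, fresh
end

# Three instantiated clusters holding 5, 3 and 1 of the other observations
existing, fresh = crp_probs([5, 3, 1], 1.0)

# The probabilities over "existing clusters + new cluster" sum to one
total = sum(existing) + fresh
```

With these counts, larger clusters attract the observation more strongly, while increasing `α` shifts mass toward `fresh`.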

The Turing.jl documentation on infinite mixture models has a cool animated plot showing the dynamic process of assigning and/or creating clusters for new observations that we can shamelessly copy.

using Logging, StatsBase, Plots

disable_logging(Logging.Info)

α = 10
N = 250

z = zeros(Int, N)
lens = Vector{Int}()

anim = @animate for i in 1:N
    i % 25 == 0 && @info "Iteration: $i"
    if rand() < α / (α + i - 1)
        z[i] = length(lens) + 1
        push!(lens, 1)
    else
        z[i] = wsample(eachindex(lens), lens)
        lens[z[i]] += 1
    end

    scatter(1:i, z[1:i], markersize=3, markeralpha=0.5, legend=false,
            xlabel = "Observation", ylabel = "Cluster")
end

webm(anim)

Collapsed Gibbs Sampler

Inference for infinite mixtures can be approached in several different ways.1 We’re going to use a collapsed Gibbs sampler with conjugate priors, which, albeit slow and computationally intensive, is a simple and straightforward algorithm.

A Gibbs sampler creates a Markov chain that generates samples for a joint distribution by iteratively sampling from the conditional distributions.

In the context of an infinite mixture model, if \(G_0\) is a conjugate base distribution, we can integrate out the cluster parameters \(\theta_k\) to estimate the posterior \(p(\boldsymbol{z} | \boldsymbol{Y})\) by iteratively sampling from \(p(z_i | \boldsymbol{z_{-i}}, \boldsymbol{Y})\).

We can formalize a single iteration of the collapsed Gibbs sampler as follows:

  • For \(i = 1, 2, \ldots, N\) observations

    1. Remove observation \(i\) from its currently assigned cluster.

    2. Assign a cluster to observation \(i\) by drawing from \(z_i | \boldsymbol{z_{-i}}, \boldsymbol{Y}\) with probabilities:

      \[ p(z_i = j | \boldsymbol{z_{-i}}, \boldsymbol{Y}) \propto \begin{cases} n_{-i,j} \int F(Y_i, \theta) \; dH_{-i}(\theta) & j \in \boldsymbol{z_{-i}} \\ \alpha \int F(Y_i, \theta) \; dG_0(\theta) & j \notin \boldsymbol{z_{-i}} \; \text{(new cluster)} \end{cases} \]

      where \(H_{-i}(\theta)\) is the posterior for the parameters of cluster \(j\) given the other observations currently assigned to it, \(p(\theta | \boldsymbol{Y_{-i}}) \propto p(\boldsymbol{Y_{-i}} | \theta) G_0(\theta)\).
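
The two cases in step 2 are usually evaluated on the log scale and normalized before a label is drawn. The sketch below shows that normalization and the categorical draw using only Base Julia. The helper names `normalize_logweights` and `draw_index`, as well as the example log-weights, are hypothetical.

```julia
# Turn unnormalized log-weights into probabilities. Subtracting the maximum
# before exponentiating avoids underflow (exp(-1200) is 0.0 in Float64).
function normalize_logweights(logw::AbstractVector{<:Real})
    m = maximum(logw)
    w = exp.(logw .- m)
    return w ./ sum(w)
end

# Draw an index from a categorical distribution by inverting the CDF.
function draw_index(p::AbstractVector{<:Real})
    u = rand()
    c = zero(eltype(p))
    for (k, pk) in enumerate(p)
        c += pk
        u < c && return k
    end
    return length(p)    # guard against round-off in the cumulative sum
end

# Two instantiated clusters plus a final "new cluster" slot.
# Each entry is log(count or α) + a made-up predictive log-density.
logw = [log(4) - 1200.0, log(2) - 1201.5, log(1.0) - 1203.0]
p = normalize_logweights(logw)
k = draw_index(p)       # k == length(p) would mean "open a new cluster"
```

The sampler later in this post takes the same approach, using `softmax!` and `wsample` instead of hand-rolled helpers.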

MNIST Revisited

Let’s return to MNIST. In EM Algorithm: Part 2 - MNIST, we fit separate Gaussian mixture models for each digit. For comparability we’ll follow the same procedure.

For \(i = 1, 2, \ldots, N\) images, let \(Y_i \in \mathbb{R}^{D^\star}\) be a column vector of reduced dimensionality from the original dataset through a principal component analysis.

using MKL, MLDatasets, LinearAlgebra

# We will eventually parallelize at the model level
BLAS.set_num_threads(1)

ENV["DATADEPS_ALWAYS_ACCEPT"] = true

pixels, labels = MNIST(split=:train)[:]
model_input = reshape(pixels, (28*28, 60_000))

using MultivariateStats

# Match the dimensionality of EM - Part 2
pca = fit(PCA, model_input, maxoutdim=50)
Y = transform(pca, model_input)

We partition the data according to the digit labels s.t. \(\boldsymbol{Y}^{(m)} = \{ Y_i : L_i = m \}\) for \(m = 1, 2, \ldots, M\) models where \(L_i\) is the original MNIST label for the \(i^{th}\) image. The mixture model for a specific digit is then specified as follows.

\[ Y_i^{(m)} | z_i^{(m)} = k \sim N(\mu_{mk}, \Sigma_{mk}) \]

This time we do not bound the total number of clusters and set a normal-inverse-Wishart prior on the mean vector and the covariance matrix of the likelihood.

\[ \begin{align*} \mu_{mk} | \Sigma_{mk} \sim N(\mu_0, \frac{1}{\kappa_0} \Sigma_{mk}) \\ \Sigma_{mk} \sim \mathcal{W}^{-1}_{\nu_0}(\Lambda_0^{-1}) \end{align*} \]

To pass around our priors we create a simple struct. Note, we will eventually be using some low-level BLAS/LAPACK functions, which do not accept mixing Float32 and Float64 values, so we will have to ensure consistent types throughout our code.

import Base: convert

struct Priors{T<:Real}
    κ_0::T
    ν_0::T
    μ_0::Vector{T}
    Λ_0::Matrix{T}
end

convert(::Type{Priors{T}}, p::Priors) where {T<:Real} = Priors{T}(p.κ_0, p.ν_0, p.μ_0, p.Λ_0)

In the math equations that follow I will drop the \(m\) super/subscripts to reduce notational clutter and leave implicit that we are dealing with a single model for a single digit subset of the original data.

Conjugacy

In a multivariate normal model, \(y \sim N\left(\mu, \Sigma\right)\), with a conjugate normal-inverse-Wishart prior, the posterior follows the same form as the prior with the following parameter updates:

\[ \begin{align*} \mu_n & = \frac{\kappa_0}{\kappa_0 + n} \mu_0 + \frac{n}{\kappa_0 + n} \bar{y} \\ \kappa_n & = \kappa_0 + n \\ \nu_n & = \nu_0 + n \\ \Lambda_n & = \Lambda_0 + Q - n \bar{y} \bar{y}' + \frac{\kappa_0 n}{\kappa_0 + n} (\bar{y} - \mu_0 ) (\bar{y} - \mu_0)' \end{align*} \tag{7}\]

where \(n\) is the number of observations and \(Q = \sum_{i=1}^n y_i y_i'\).
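
To make Equation 7 concrete, here is a direct and unoptimized transcription using only `LinearAlgebra`. The function name `niw_posterior` and the toy data are assumptions for illustration. The implementation used by the sampler works from cached sufficient statistics instead.

```julia
using LinearAlgebra

# Normal-inverse-Wishart posterior parameters from Equation 7.
# Columns of Y are observations; `niw_posterior` is an illustrative name.
function niw_posterior(Y::AbstractMatrix, κ_0, ν_0, μ_0, Λ_0)
    n = size(Y, 2)
    ybar = vec(sum(Y, dims=2)) ./ n          # sample mean
    Q = Y * Y'                               # sum of outer products

    κ_n = κ_0 + n
    ν_n = ν_0 + n
    μ_n = (κ_0 .* μ_0 .+ n .* ybar) ./ κ_n
    d = ybar .- μ_0
    Λ_n = Λ_0 + Q - n * (ybar * ybar') + (κ_0 * n / κ_n) * (d * d')

    return μ_n, κ_n, ν_n, Λ_n
end

# Toy example: D = 2 dimensions, n = 3 observations
Y = [1.0 2.0 4.0; 0.0 1.0 -1.0]
μ_n, κ_n, ν_n, Λ_n = niw_posterior(Y, 1.0, 4.0, zeros(2), Matrix(1.0I, 2, 2))
```

Note that \(Q - n \bar{y} \bar{y}'\) is just the centered scatter matrix, \(\sum_i (y_i - \bar{y})(y_i - \bar{y})'\), which is why only \(n\), the sum of the observations, and \(Q\) need to be stored.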

Moreover, the posterior predictive distribution for a new observation, \(\tilde{y}\), is a multivariate t-distribution.2

\[ \tilde{y} | y \sim \text{Multi-t}_{\nu_n - D + 1}\left( \mu_n, \Lambda_n \frac{\kappa_n + 1}{\kappa_n (\nu_n - D + 1)} \right) \]

In our infinite mixture model, this is the closed form solution to the integral \(\int F(y_i, \theta) dH_{-i}(\theta_j)\) where \(\theta_j = (\mu_j, \Sigma_j)\), which is the posterior predictive distribution for cluster \(j\) for a new observation.

The prior predictive distribution in the infinite mixture model, \(\int F(y_i, \theta) dG_0(\theta)\), also follows a multivariate t-distribution under the prior distribution, \(G_0\).

\[ y_i \sim \text{Multi-t}_{\nu_0 - D + 1}\left( \mu_0, \Lambda_0 \frac{\kappa_0 + 1}{\kappa_0 (\nu_0 - D + 1)} \right) \]

In terms of implementation, there is no multivariate t-distribution in the documentation for Distributions.jl; however, it is actually implemented and exported in the source as MvTDist. We will code our own implementation though in hopes of trying to eke out a bit more speed in our sampler.

Thankfully, the multivariate t-distribution is fairly straightforward. If a random \(D\)-vector, \(y\), is distributed as \(y \sim \text{Multi-t}_{\nu}(\mu, \Sigma)\), then the probability density function is the following:

\[ p(y) = \cfrac{\Gamma\left(\frac{\nu + D}{2}\right)}{\Gamma\left(\frac{\nu}{2}\right) \left(\nu \pi\right)^{D / 2} |\Sigma|^{\frac{1}{2}}} \left[1 + \frac{1}{\nu}(y - \mu)' \Sigma^{-1} (y - \mu)\right]^{\frac{-(\nu + D)}{2}} \]

For computational efficiency we will work with the Cholesky decomposition of the covariance matrix, \(\Sigma = L L'\) where \(L\) is a lower triangular matrix. This gives us the determinant, \(\left|\Sigma\right| = \left( \prod_{i=1}^D L_{ii} \right)^2\), and the inverse, \(\Sigma^{-1} = \left( L^{-1}\right) ' L^{-1}\).
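
Both identities are easy to check numerically. The sketch below uses the high-level `cholesky` function from `LinearAlgebra` rather than the raw LAPACK call used in the implementation, and the matrix and vector are arbitrary examples.

```julia
using LinearAlgebra

Σ = [4.0 2.0 0.6; 2.0 2.0 0.5; 0.6 0.5 1.0]   # an arbitrary positive definite matrix
x = [1.0, -2.0, 0.5]                          # stands in for y - μ

L = cholesky(Symmetric(Σ)).L                  # Σ = L * L'

# log|Σ| = 2 * Σ_i log(L_ii)
logdet_chol = 2 * sum(log, diag(L))

# (y - μ)' Σ⁻¹ (y - μ) = b'b, where b solves the triangular system L * b = (y - μ)
b = L \ x
quad_chol = sum(abs2, b)
```

The triangular solve replaces an explicit matrix inverse, which is the main saving when the density is evaluated many times.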

Our constructor for the multivariate t-distribution then just calculates the Cholesky decomposition of \(\Sigma\) and pre-computes the constants in the log probability density function.

struct Multi_t{T<:Real}
    constant::T     # Constant term for the logpdf
    H::T            # Type consistent 0.5
    D::Int          # Dimensionality of data
    ν::T            # Degrees of freedom
    μ::Vector{T}    # Mean vector
    L::Matrix{T}    # Lower triangular from Cholesky decomposition of covariance matrix

    _rdx::Vector{T} # Pre-allocated buffer for logpdf
end

using SpecialFunctions, IntelVectorMath
function Multi_t(ν::Real, μ::Vector{R}, Σ::Matrix{T}) where {R<:Real, T<:AbstractFloat}
    S = promote_type(typeof(ν), R, T)
    μ = convert(Vector{S}, μ)
    Σ = convert(Matrix{S}, Σ)

    # Cholesky decomposition on the lower triangular
    LAPACK.potrf!('L', Σ)

    # Log determinant of covariance matrix
    _rdx = diag(Σ)
    IVM.log!(_rdx)

    D = length(μ)
    H = S(0.5)

    a = H * ν
    b = H * D
    constant = loggamma(a + b) - loggamma(a) - b * log(π * ν) -  sum(_rdx)

    Multi_t{S}(constant, H, D, S(ν), μ, Σ, _rdx)
end

Multi_t(ν::Real, μ::Vector{R}, Σ::Matrix{T}) where {R<:Real, T<:Integer} = Multi_t(ν, μ, float(Σ))

The remaining log transformed non-constant term from the multi-t probability density function can be simplified as follows:

\[ \frac{-(\nu + D)}{2} \log \left[ 1 + \frac{1}{\nu}(y - \mu)' \Sigma^{-1} (y - \mu) \right] = \frac{-(\nu + D)}{2} \log \left[ 1 + \frac{1}{\nu} b' b \right] \]

where \(b = L^{-1}(y - \mu)\).

function multi_t_logpdf(td::Multi_t, y::AbstractVector{<:Real})
    @. td._rdx = y - td.μ;

    # Solve L * b = (y - μ) for b
    BLAS.trsv!('L', 'N', 'N', td.L, td._rdx);
    z = sum(abs2, td._rdx)

    td.constant - td.H * (td.ν + td.D) * log1p(z / td.ν)
end

multi_t_logpdf(td::Multi_t, Y::AbstractMatrix) = [multi_t_logpdf(td, y) for y in eachcol(Y)]

Clusters

Before we can fully implement our Gibbs sampler, we need some bookkeeping code. We need to track the sufficient statistics for each instantiated cluster that allow us to re-create the posterior parameters.

Specifically, for an instantiated cluster we save the number of assigned observations, \(n_k\); the sum of the observations, \(S_k\); and, the outer products of the assigned observations, \(Q_k\).

mutable struct Cluster{T<:Real}
  N::Int             # Num. of observations for a cluster
  D::Int             # Dimensionality of observations
  S::Vector{T}       # Sum of observations
  Q::Matrix{T}       # Sum of outer products of observations

  # Scratch buffers for posterior calculations
  _rdx::Vector{T}
  _rdy::Vector{T}
  _rdm::Matrix{T}

  priors::Priors{T}  # Priors --- note, BLAS cannot mix Float64 & Float32
  td::Multi_t        # Predictive distribution

  function Cluster(Y::AbstractVecOrMat{T}, priors::Priors{T}; predictive=true) where {T<:Real}
      # Sufficient statistics
      D = size(Y, 1)
      N = size(Y, 2)
      S = vec(sum(Y, dims=2))
      Q = Y * Y'

      _rdx = Vector{T}(undef, D)
      _rdy = Vector{T}(undef, D)
      _rdm = Matrix{T}(undef, D, D)

      obj = new{T}(N, D, S, Q, _rdx, _rdy, _rdm, priors)
      if predictive === true
          obj.td = posterior_predictive(obj)
      end

      return obj
  end
end

The posterior parameters are updated using the sufficient statistics according to Equation 7.

function posterior(cluster::Cluster)
    κ = cluster.priors.κ_0 + cluster.N

    # μ = (S + κ_0 * μ_0) / κ --- posterior predictive mean
    @. cluster._rdx = (cluster.S + cluster.priors.κ_0 * cluster.priors.μ_0) / κ;

    # (Ȳ - μ_0) --- required for posterior predictive covariance matrix
    @. cluster._rdy = cluster.S / cluster.N - cluster.priors.μ_0

    # Scale matrix
    #   Λ = Λ_0 + Q - (1 / N) * S * S' +
    #       κ_0 * N / κ * (S / N - μ_0) * (S / N - μ_0)'
    fused_update!(cluster._rdm, cluster.priors.Λ_0, cluster.Q, cluster.S,
                  cluster._rdy, 1 / cluster.N, cluster.priors.κ_0 * cluster.N / κ)

    return (cluster.priors.ν_0 + cluster.N, κ, cluster._rdx, cluster._rdm)
end

function fused_update!(rdm, Λ_0, Q, S, rdy, α, β)
    D = size(rdm, 1)
    for j in 1:D
        @simd for i in j:D
            @inbounds rdm[i,j] = Λ_0[i,j] + Q[i,j] - α * S[i] * S[j] + β * rdy[i] * rdy[j]
        end
    end
end

From this we can write a function to form the posterior predictive distribution for a new observation.

function posterior_predictive(cluster::Cluster)
    ν, κ, μ, Λ = posterior(cluster)

    df = ν - cluster.D + 1
    BLAS.scal!((κ + 1) / (κ * df), Λ)
    Multi_t(df, μ, Λ)
end

Adding/Removing Observations

When we add or remove an observation from a cluster we update our sufficient statistics.3 For example, when adding an observation, \(Y_i\), to cluster \(k\) then

\[ \begin{align*} n_k^\star &= n_k + 1 \\ S_k^\star & = S_k + y_i \\ Q_k^\star & = Q_k + y_i y_i' \end{align*} \]

function add!(cluster::Cluster, y::AbstractVector{T}) where {T<:Real}
    cluster.N += 1
    axpy!(1, y, cluster.S)

    # Q += y y'
    BLAS.syr!('L', T(1.0), y, cluster.Q)
    cluster.td = posterior_predictive(cluster)
end

Removing an observation is similarly straightforward.

function remove!(cluster::Cluster, y::AbstractVector{T}) where {T<:Real}
    cluster.N -= 1
    axpy!(-1, y, cluster.S)

    # Q -= y y'
    BLAS.syr!('L', T(-1.0), y, cluster.Q)
    cluster.td = posterior_predictive(cluster)
end
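
To see the bookkeeping in isolation, the sketch below applies the same rank-one updates to plain arrays and compares the incrementally updated statistics against a full recomputation, which is one way to monitor accumulated floating point error. The helpers `add_stats!` and `remove_stats!` are illustrative stand-ins for `add!` and `remove!`, and the data are made up.

```julia
using LinearAlgebra

# Rank-one updates of the sufficient statistics (S, Q) on plain arrays.
function add_stats!(S, Q, y)
    S .+= y
    Q .+= y * y'
    return S, Q
end

function remove_stats!(S, Q, y)
    S .-= y
    Q .-= y * y'
    return S, Q
end

Y = [0.5 1.5 -1.0; 2.0 0.0 1.0]        # D = 2, three observations
S = vec(sum(Y, dims=2))
Q = Y * Y'
S_0, Q_0 = copy(S), copy(Q)

# Add a new observation and compare against recomputing from scratch
y_new = [3.0, -2.0]
add_stats!(S, Q, y_new)
Y_all = hcat(Y, y_new)
drift_add = maximum(abs, Q - Y_all * Y_all')

# Remove it again; the statistics should return to their starting values
remove_stats!(S, Q, y_new)
drift_roundtrip = maximum(abs, Q - Q_0)
```

Over thousands of sweeps the drift need not stay negligible, so periodically rebuilding `S` and `Q` from the assigned observations is a reasonable safeguard.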

Gibbs Sampler

We can now implement our sampler. Again, for bookkeeping purposes we will track state using a struct and instantiate the sampler using a constructor.

mutable struct Gibbs{S<:Real, T<:Real}
    α::S
    Y::AbstractMatrix{T}
    D::Int
    N::Int

    priors::Priors{T}
    H::Multi_t

    __active::Dict{Int32, Cluster{T}}
    __lp_new_cluster::Vector{T}

    iterations::Int32
    Z::Matrix{Int32}

    map_clusters::Vector

    function Gibbs(Y::AbstractMatrix{S},
                   priors::Priors{T};
                   α=1,
                   init_clusters=10,) where {S<:Real, T<:Real}
        D, N = size(Y)

        R = promote_type(S, T)
        model_data = convert(Matrix{R}, Y)
        model_priors = convert(Priors{R}, priors)

        # Randomly initialize cluster assignments
        inits = rand(1:init_clusters, N)
        clusters = Dict(k => Cluster(model_data[:, inits .== k], model_priors)
                                     for k in unique(inits))

        Z = reshape(inits, length(inits), 1)

        # Pre-calculate unnormalized log-probability for creating a new cluster
        Σ_0 = model_priors.Λ_0 * ((model_priors.κ_0 + 1) / (model_priors.κ_0 * (model_priors.ν_0 - D + 1)))
        td = Multi_t(model_priors.ν_0 - D + 1, model_priors.μ_0, Matrix(Σ_0))

        lp_new_cluster = log(α) .+ multi_t_logpdf(td, model_data)

        # Parameterize on the promoted type R, which model_data and model_priors share
        new{typeof(α), R}(α, model_data, D, N, model_priors, td, clusters,
                          lp_new_cluster, 1, Z)
    end
end

The actual sampling logic is encoded in a functor that iterates for a fixed number of iterations sequentially sampling from \(z_i | \boldsymbol{z_{-i}}, \boldsymbol{Y} \; \forall i \in \{1, 2, \ldots, N \}\) for each pass. The result is the posterior cluster labels for our model stored as a matrix in the \(Z\) struct field.

using LogExpFunctions: softmax!
function (g::Gibbs)(niter; refresh=div(niter, 10))
    M = g.iterations == 1 ? niter - 1 : niter

    if size(g.Z, 2) < M + g.iterations
        z_samples = zeros(Int32, g.N, M)
        g.Z = hcat(g.Z, z_samples)
    end

    # Pre-allocate vector to hold proportional log-probabilities for each cluster
    clp = Vector{eltype(g.__lp_new_cluster)}(undef, 2 * (length(g.__active) + 1))

    vY = eachcol(g.Y)
    max_id = keys(g.__active) |> maximum

    # Iteration bounds
    lower, upper = (g.iterations + 1), (M + g.iterations)

    @info "Sampling from model..."
    @inbounds for m in lower:upper
        m % refresh == 0 && @info "Iteration: $m / $upper"

        for i in 1:g.N
            if g.__active[g.Z[i, m - 1]].N > 1
                remove!(g.__active[g.Z[i, m - 1]], vY[i])
            else
                delete!(g.__active, g.Z[i, m - 1])
            end

            cluster_keys = keys(g.__active) |> collect
            num_clusters = length(cluster_keys)
            length(clp) < num_clusters + 1 && resize!(clp, 2 * length(clp))

            for idx in eachindex(cluster_keys)
                cluster = g.__active[cluster_keys[idx]]
                clp[idx] = log(cluster.N) + multi_t_logpdf(cluster.td, vY[i])
            end
            clp[num_clusters + 1] = g.__lp_new_cluster[i]

            # Transform and normalize to probability simplex
            softmax!(view(clp, 1:num_clusters + 1))

            # Z_i ~ Categorical(clp_{1:K+1})
            if rand() > clp[num_clusters + 1]
                g.Z[i, m] = wsample(cluster_keys, view(clp, 1:num_clusters))
                add!(g.__active[g.Z[i, m]], vY[i])
            else
                g.Z[i, m] = (max_id += 1)
                g.__active[g.Z[i, m]] = Cluster(vY[i], g.priors)
            end
        end
    end

    g.iterations += M
end

Posterior Prediction

Our ultimate goal remains to predict the correct digit given a PCA transformed vector of gray-scale pixel values. We assign the digit label for a new observation, \(\tilde{y}\), by finding the model with the “best fit.”

Typically this would involve the posterior predictive distribution.

\[ p(\tilde{y} | \boldsymbol{Y}) = \sum_{k=1}^{K^\star} \frac{n_k}{N + \alpha} p(\tilde{y} | \boldsymbol{Y_k}) + \frac{\alpha}{N + \alpha} p(\tilde{y} | G_0) \]

However, evaluating this over even a small subset of the posterior draws is fairly computationally expensive. Since we are only interested in a point estimate prediction under zero-one loss we’ll form the log-likelihood using the maximum a posteriori (MAP) estimate for the latent labels \(z\).

\[ \begin{align*} z_{\text{MAP}} & = \text{arg} \max_z p(z | \boldsymbol{Y}) \\ & = \text{arg} \max_z \left[ \log p(z) + \sum_{k=1}^{K^\star} \log p\left(\boldsymbol{Y_k} | z_k\right) \right] \end{align*} \tag{8}\]

The second term in Equation 8 is the integrated log-likelihood. Recall, \(Y_i | z_i = k \sim N(\mu_k, \Sigma_k)\). Then,

\[ p\left(\boldsymbol{Y_k} | z_k\right) = \int \int p\left(\boldsymbol{Y_k} | \mu_k, \Sigma_k\right) p\left(\mu_k | \Sigma_k\right) p\left(\Sigma_k\right) d\mu_k d\Sigma_k \]

which you can either solve or simply look up in Murphy (2013), Section 5.3.2.3 “Gaussian-Gaussian-Wishart model”, to find the following solution.

\[\begin{equation} p\left(\boldsymbol{Y_k} | z_k\right) = \frac{1}{\pi^{\frac{n_k D}{2}}} \left(\frac{\kappa_0}{\kappa_{n_k}}\right)^{\frac{D}{2}} \frac{\left|\Lambda_0\right|^{\frac{\nu_0}{2}}}{\left|\Lambda_{n_k}\right|^{\frac{\nu_{n_k}}{2}}} \frac{\Gamma_D\left(\frac{\nu_{n_k}}{2}\right)}{\Gamma_D\left(\frac{\nu_0}{2}\right)} \end{equation}\]

using StatsFuns: logmvgamma
function integrated_log_lik(Y, priors)
    cluster = Cluster(Y, priors; predictive=false)

    ν, κ, _, Λ = posterior(cluster)
    Λ .= Symmetric(Λ, :L)

    a = 0.5 * cluster.D
    b = 0.5 * priors.ν_0
    c = 0.5 * ν

    a * log(priors.κ_0 / κ) +
      b * logdet(priors.Λ_0) -
      c * logdet(Λ) +
      logmvgamma(cluster.D, c) -
      logmvgamma(cluster.D, b) -
      (cluster.N * a) * log(π)
end

Meanwhile, the first term in Equation 8 is the prior for the latent cluster labels, \(z\), which can be found by expanding the joint probability, \(p(z_1, \ldots, z_N)\), and using Equation 5 and Equation 6,

\[ \begin{align*} p(z_1, \ldots, z_N) & = p(z_1) p(z_2 | z_1) \cdots p(z_N | z_{N-1}, \ldots, z_1) \\ & = \frac{\alpha^{K^\star}}{(\alpha + N - 1)(\alpha + N - 2)\cdots\alpha}\prod_{k=1}^{K^\star} \left( n_k - 1\right)! \end{align*} \]

where \(K^\star\) is the instantiated number of clusters after observing \(N\) observations.
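
The closed form can be checked against the sequential product of the conditionals from Equation 5 and Equation 6. The sketch below computes the log prior both ways using only Base Julia. Both function names and the example label vector are illustrative assumptions.

```julia
# (a) Sequentially: multiply the conditional probabilities one label at a time.
function crp_logprior_sequential(z::AbstractVector{<:Integer}, α::Real)
    counts = Dict{Int, Int}()
    lp = 0.0
    for (i, k) in enumerate(z)
        n_k = get(counts, k, 0)
        # A new cluster contributes α; an existing one contributes its current count
        lp += log(n_k == 0 ? α : n_k) - log(α + i - 1)
        counts[k] = n_k + 1
    end
    return lp
end

# (b) Closed form: α^K ∏_k (n_k - 1)! / ∏_{j=0}^{N-1} (α + j)
function crp_logprior_closed(z::AbstractVector{<:Integer}, α::Real)
    N = length(z)
    lp = -sum(log(α + j) for j in 0:(N - 1))
    for k in unique(z)
        n_k = count(==(k), z)
        lp += log(α) + sum(log, 1:(n_k - 1); init=0.0)   # log(α) + log((n_k - 1)!)
    end
    return lp
end

z = [1, 1, 2, 1, 3, 2]
lp_seq = crp_logprior_sequential(z, 0.7)
lp_closed = crp_logprior_closed(z, 0.7)
```

The two values agree up to floating point error for any ordering of the labels, which reflects the exchangeability of the observations.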

Combining the two terms, we implement the joint log-likelihood, \(\log p(z) + \log p(\boldsymbol{Y} | z)\), which equals \(\log p(z | \boldsymbol{Y})\) up to an additive constant.

function joint_log_lik(draw, Y, α, priors)
    N = size(Y, 2)
    z = unique(draw)
    K = length(z)

    # p(y | z)
    lpx = sum(integrated_log_lik(Y[:, draw .== k], priors) for k in z)

    # p(z)
    lpz = K * log(α) + sum(logfactorial(sum(draw .== k) - 1) for k in z) -
          sum(log(α + j) for j in 0:(N - 1))

    lpx + lpz
end

The MAP estimate for \(z\) is then the posterior draw which maximizes this joint log-likelihood.

draws(g::Gibbs; burnin=0, thin=1) = g.Z[:, (burnin+1):thin:end]

function optim!(g::Gibbs; kwargs...)
    @info "Finding most likely cluster assignments..."
    Z = draws(g; kwargs...)
    column = argmax(joint_log_lik(draw, g.Y, g.α, g.priors) for draw in eachcol(Z))

    map_labels = Z[:, column]
    g.map_clusters = [Cluster(g.Y[:, map_labels .== k], g.priors)
                      for k in unique(map_labels)]
end

Generating Predictions

To actually generate a prediction, we evaluate the posterior predictive density for a single model using the MAP estimate for \(z\).

using LogExpFunctions: logsumexp
function new_log_lik(g::Gibbs, y; kwargs...)
    !isdefined(g, :map_clusters) && optim!(g; kwargs...)

    lp = [log(cluster.N) + multi_t_logpdf(cluster.td, y) for cluster in g.map_clusters]
    push!(lp, log(g.α) + multi_t_logpdf(g.H, y))

    logsumexp(lp) - log(g.α + g.N)
end

Then, we assign a prediction by selecting the model with the greatest log-likelihood.

\[ \hat{m}\left(Y^{\text{new}}\right) = \underset{m \in \{1, 2, \ldots, M \}}{\operatorname{arg max}} \ell\left(z^{(m)}_{\text{MAP}} | Y^{\text{new}}\right) \]

function predict(Y, models; kwargs...)
    [argmax(new_log_lik(sampler, y; kwargs...) for sampler in models) - 1
     for y in eachcol(Y)]
end

Running the Models

That’s enough math. Let’s actually run the models.

Start by instantiating a Gibbs sampler for each subset of the training dataset corresponding to the different MNIST digits with fairly uninformative priors.

models = map(0:9) do k
    data = Y[:, labels .== k]
    D = size(data, 1)

    μ_0 = zeros(D)
    κ_0 = 1
    ν_0 = D + 1
    Λ_0 = I(D)

    priors = Priors{Float64}(κ_0, ν_0, μ_0, Matrix(Λ_0))

    Gibbs(data, priors, α = 1, init_clusters = 4)
end

We run each sampler in a parallelized fashion for a fixed number of iterations. Warning, this will take a substantial amount of time.

using Base.Threads

niter = 3_000
burnin = 1_500
thin = 3

Threads.@threads :greedy for sampler in models
    sampler(niter)
end

Before we generate the MAP estimates for \(z\) for each model, let’s check the behaviour of our samplers. Figure 1 shows the traceplots for the cluster counts after discarding the first half of the posterior draws.

plts = map(models) do sampler
    n = [length(unique(z)) for z in eachcol(draws(sampler; burnin, thin))]
    plot(1:length(n), n, legend = :none)
end

plot(plts..., layout = (5, 2))
Figure 1: MCMC traceplots of the cluster count per model.

We can also plot the acceptance ratios, i.e. the proportion of observations assigned the same cluster label between sampling iterations, as seen in Figure 2.

function acceptance(Z)
    N, M = size(Z)
    [sum(Z[:, i] .== Z[:, i+1]) / N for i in 1:(M - 1)]
end

ratios = [draws(sampler; burnin, thin) |> acceptance for sampler in models]
plts = [plot(1:length(p), p, legend = :none) for p in ratios]
plot(plts..., layout = (5, 2))
Figure 2: MCMC traceplots for acceptance ratios per model.

Finally, let’s check our prediction accuracy for the training dataset.

check(Z_hat, labels) = round(mean(Z_hat .== labels), digits=3)

Z_hat = predict(Y, models; burnin, thin)
println("Training dataset accuracy: $(check(Z_hat, labels))")
Training dataset accuracy: 0.996

And, likewise, for the test dataset.

test_pixels, test_labels = MNIST(split=:test)[:]
test_input = reshape(test_pixels, (28*28, 10_000))
test_reduced = transform(pca, test_input)

Z_test = predict(test_reduced, models)
println("Test dataset accuracy: $(check(Z_test, test_labels))")
Test dataset accuracy: 0.98

We can also explore how the models learned different variants of the same digit. Looking at the first model that corresponds to the ‘0’ digit, Figure 3 plots the expectation for each cluster.

imgs = [reconstruct(pca, cluster.td.μ) for cluster in models[1].map_clusters]
plts = [plot(Gray.(reshape(1 .- i, 28, 28)')) for i in imgs]
plot(plts..., axis = false, ticks = false)
Figure 3: Predicted images for each variant of the digit ‘0’.

Turns out that ‘0’ is pretty easy to model. This is reflected in the per digit accuracy scores.

for i in 0:9
    aux = check(Z_test[test_labels .== i], fill(i, sum(test_labels .== i)))
    println("Digit '$(i)' accuracy: $aux")
end
Digit '0' accuracy: 0.998
Digit '1' accuracy: 0.989
Digit '2' accuracy: 0.982
Digit '3' accuracy: 0.978
Digit '4' accuracy: 0.979
Digit '5' accuracy: 0.962
Digit '6' accuracy: 0.987
Digit '7' accuracy: 0.967
Digit '8' accuracy: 0.975
Digit '9' accuracy: 0.967

Conclusion

Compared to the ensemble of finite Gaussian mixture models we ended up improving our accuracy on the test dataset by approximately one percentage point at the cost of a significantly increased runtime — on my local machine the infinite mixture models, although memory efficient, take roughly 5x-6x longer to fit. The slow speed, even when generating posterior predictions, undermines the practicality of these models, at least in the form used here.

That said, there is always more that can be done. Similar to the finite mixture models, we did not explore different pre-processing strategies or discuss hyperparameter tuning; in the case of the latter, a common extension is to place a Gamma prior on the concentration parameter \(\alpha\). Furthermore, a production grade solution would actually have to address the question of convergence for the Gibbs samplers.4 And, we would need to design a more robust implementation: checking the dimensions of inputs, asserting that the prior scale matrix is positive definite, etc.5

Maybe, some day, I’ll code up one of the more practical inference algorithms for infinite mixture models.

Addendum

An alternative, and perhaps more canonical, exposition of infinite mixture models from the perspective of Bayesian non-parametrics goes via Dirichlet processes.

A Dirichlet process is a distribution over all possible probability measures defined on some measurable space. For our purposes, we denote a random probability measure distributed according to a Dirichlet process as \(G \sim \text{DP}(G_0, \alpha)\) defined on measurable space \((\Omega, \mathcal{A})\) with concentration parameter \(\alpha > 0\) and \(G_0\), a baseline probability measure also defined on \((\Omega, \mathcal{A})\). A realization \(G\) is discrete with probability one and, by definition, satisfies the property

\[ G(A_1), G(A_2), \ldots, G(A_k) \sim \text{Dirichlet}\left(\alpha G_0(A_1), \alpha G_0(A_2), \ldots, \alpha G_0(A_k)\right) \]

for any finite partition \(A_1, \ldots, A_k\) of \(\Omega\). This implies that \(\mathbb{E}[G(A)] = G_0(A)\) and \(\text{Var}\left(G(A)\right) = \frac{G_0(A)(1 - G_0(A))}{1 + \alpha}\) for all \(A \in \mathcal{A}\).6
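
Both moments follow from applying the defining property to the two-set partition \(\{A, A^c\}\), under which \(G(A)\) has a Beta marginal. Writing \(a = \alpha G_0(A)\) and \(b = \alpha (1 - G_0(A))\), so that \(a + b = \alpha\):

\[ \begin{align*} G(A) & \sim \text{Beta}(a, b) \\ \mathbb{E}[G(A)] & = \frac{a}{a + b} = G_0(A) \\ \text{Var}\left(G(A)\right) & = \frac{a b}{(a + b)^2 (a + b + 1)} = \frac{\alpha^2 G_0(A)(1 - G_0(A))}{\alpha^2 (\alpha + 1)} = \frac{G_0(A)(1 - G_0(A))}{1 + \alpha} \end{align*} \]

Larger values of \(\alpha\) therefore concentrate \(G\) more tightly around the base measure \(G_0\).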

In the context of a mixture model, we let the mixture parameters be distributed according to a Dirichlet process measure. For an exchangeable sequence, \(Y_1, Y_2, \ldots, Y_N\), this leads to the following generative model.

\[ \begin{align*} Y_i | \theta_i & \sim F(\theta_i) \\ \theta_i | G & \sim G \\ G & \sim \text{DP}\left(G_0, \alpha\right) \end{align*} \]

Marginalizing over the measure \(G\) results in the following conditional distribution,

\[ \theta_i | \boldsymbol{\theta_{-i}} \sim \frac{1}{N - 1 + \alpha} \sum_{j\neq i}^{N} \delta_{\theta_j} + \frac{\alpha}{N - 1 + \alpha} G_0 \]

where \(\delta_{\theta_j}\) is the Dirac measure assigning probability mass to the single point \(\theta_j\).

Let \(\boldsymbol{\theta^\star} = \{\theta_1^\star, \ldots, \theta_{K}^\star\}\) be the unique values of \(\boldsymbol{\theta}\). Then,

\[ P(\theta_i = \theta_k^\star) = \frac{n_{-i,k}}{\alpha + N - 1} \]

where \(n_{-i,k} = \sum_{j\neq i}^N \mathbb{I}\{\theta_j = \theta_k^\star\}\). Meanwhile,

\[ P(\theta_i \notin \boldsymbol{\theta^\star}) = \frac{\alpha}{\alpha + N - 1} \]

Define \(z_i = k \Leftrightarrow \theta_i = \theta_k^\star\), and we have exactly our conditional probabilities given by Equation 5 and Equation 6, commonly referred to as the Chinese Restaurant Process.

References

Blackwell, David, and James B. MacQueen. 1973. “Ferguson Distributions via Polya Urn Schemes.” The Annals of Statistics 1 (2). https://doi.org/10.1214/aos/1176342372.
Blei, David M., and Michael I. Jordan. 2006. “Variational Inference for Dirichlet Process Mixtures.” Bayesian Analysis 1 (1). https://doi.org/10.1214/06-BA104.
Murphy, Kevin P. 2013. Machine Learning: A Probabilistic Perspective. 4. print. (fixed many typos). Adaptive Computation and Machine Learning Series. Cambridge, Mass.: MIT Press.
Neal, Radford M. 2000. “Markov Chain Sampling Methods for Dirichlet Process Mixture Models.” Journal of Computational and Graphical Statistics 9 (2): 249–65.
Raftery, Adrian E., and Steven M. Lewis. 1992. “[Practical Markov Chain Monte Carlo]: Comment: One Long Run with Diagnostics: Implementation Strategies for Markov Chain Monte Carlo.” Statistical Science 7 (4). https://doi.org/10.1214/ss/1177011143.
Raykov, Yordan P., Alexis Boukouvalas, and Max A. Little. 2016. “Simple Approximate MAP Inference for Dirichlet Processes Mixtures.” Electronic Journal of Statistics 10 (2). https://doi.org/10.1214/16-EJS1196.

  1. Neal (2000) is a canonical reference for Gibbs sampling in the context of conjugate and non-conjugate priors for infinite mixture models. Meanwhile, Blei and Jordan (2006) provides a variational inference approach that is used by scikit-learn and Raykov, Boukouvalas, and Little (2016) offer an approximate MAP estimator.↩︎

  2. For the derivation, refer to your favorite source on conjugate priors. Or, check out section 3.6 “Multivariate normal with unknown mean and variance” in BDA3.↩︎

  3. The code for adding/removing observations from a cluster should set off some alarm bells as it is not numerically stable and will lead to accumulated floating point error. A more robust approach would re-calculate the sufficient statistics from scratch, rather than partially updating. I’ll partially mitigate the issue by using Float64 variables, and leave the rest as an exercise to the reader.↩︎

  4. Previous work discussing infinite mixture models typically uses the Raftery and Lewis diagnostic to assess convergence (Raftery and Lewis 1992).↩︎

  5. We should ideally be checking the return values from our BLAS/LAPACK calls. For example, accumulated floating point error when adding/removing observations can lead to the Cholesky decomposition failing when creating the posterior predictive distribution for a cluster.↩︎

  6. There are several proofs showing that Dirichlet process distributed random probability measures are discrete almost surely. My favorite is Blackwell and MacQueen (1973) using the Polya urn scheme.↩︎