EM Algorithm: Part 2 - MNIST

2024-02-18

This note is the second in a three-part series on the expectation-maximization algorithm. Part 1 gives a cursory overview of the algorithm, Part 2 deals with mixture models, and Part 3 applies the EM algorithm to hidden Markov models.


Let’s see the EM algorithm in action by using it on a toy example. We’ll implement several mixture models in Julia to classify handwritten digits from the MNIST dataset.

Mixture models are a form of unsupervised soft clustering that assigns to each observation a probability of belonging to each unobserved class. In this case, our observed data will be the MNIST images, to which we will try to assign the correct digit labels.

The first attempt will be a fairly naive Bernoulli mixture model after which we’ll see if we can improve our predictions using a mixture of Gaussians.

MNIST

The MNIST dataset consists of 70,000 grayscale images of handwritten digits. We can grab the dataset from the MLDatasets package where each image is represented as a 28x28 matrix of individual pixels normalized to be between 0 and 1.

using Colors, MLDatasets, Plots

ENV["DATADEPS_ALWAYS_ACCEPT"] = true
pixels, labels = MNIST(split=:train)[:]

plts = [plot(1 .- Gray.(img')) for img in eachslice(pixels[:, :, 1:6], dims=3)]
plot(plts..., layout = (2, 3), axis = false, ticks = false)
Figure 1: First six digits from the MNIST training dataset.

Bernoulli Mixture Model

For \(i = 1, 2, \ldots, N\) images and \(j = 1, 2, \ldots, D\) pixels, start by creating a binary indicator for each pixel, \(X_{ij} = \mathbb{I}(X_{ij} > 0.5)\).

binned_pixels = collect(pixels .> 0.5)

# We'll also flatten the 28x28x60,000 array into a 784x60,000 matrix
model_input = reshape(binned_pixels, (28*28, 60_000))

We model each pixel as an independent Bernoulli random variable with a latent variable \(Z_i \in \{ 1, 2, \ldots, K \}\) denoting the unknown classification for image \(i\).

\[\begin{equation} X_{ij} | Z_i = k \sim \text{Bern}(\phi_{kj}) \end{equation}\]

The marginal log-likelihood for \(X\) follows easily, where \(\pi_k = p(Z_i = k)\) denotes the mixture weight for class \(k\) and \(\phi_{kj}\) the probability that pixel \(j\) is active in class \(k\).

\[\begin{equation} \ell(\pi, \phi | X) = \sum_{i=1}^N \log \left( \sum_{k=1}^K \pi_k \prod_{j=1}^D \phi_{kj}^{X_{ij}} (1 - \phi_{kj})^{(1 - X_{ij})} \right) \end{equation}\]

Unfortunately, marginalizing over the latent states leaves a summation inside the natural logarithm, which means that there is no closed form solution to this MLE problem.

With the goal of finding an estimate for \(\theta = \{ \pi, \phi \}\) that maximizes \(\ell(\theta | X)\), we turn to the EM algorithm and define our objective function.

\[\begin{aligned} Q(\theta^{(t)}, \theta) & = \mathbb{E}_Z \left[ \log p(X, Z | \theta) | X, \theta^{(t)} \right] \\ & = \sum_{i=1}^N \sum_{k=1}^K p\left(Z_i = k | X, \theta^{(t)} \right) \left( \log \pi_k + \log \prod_{j=1}^D \phi_{kj}^{X_{ij}} (1 - \phi_{kj}) ^{1 - X_{ij}} \right) \end{aligned}\]

We can find \(p(Z_i = k | X, \theta^{(t)})\) using Bayes’ rule:

\[\begin{equation} p \left(Z_i = k | X, \theta^{(t)} \right) = \cfrac{p(X_i | Z_i = k, \theta^{(t)}) \pi_k}{\sum_{l=1}^K p(X_i | Z_i = l, \theta^{(t)}) \pi_l} \end{equation}\]

where \(p\left( X_i | Z_i = k, \theta^{(t)} \right) = \prod_{j=1}^D \phi_{kj}^{X_{ij}} (1 - \phi_{kj}) ^{1 - X_{ij}}\).

E-Step

Given the previous equations, the E-step is fairly straightforward. Start by defining the log probability of \(X_i\) conditional on its class, \(\log p(X_i | Z_i = k, \theta)\). Throughout our implementation of the EM algorithm we’ll work with log probabilities for numerical stability.1

# Computes log p(X_i | Z_i = k), where lp = log.(ϕ[:, k]) and lp1m = log.(1 .- ϕ[:, k])
function log_mvbernoulli_pmf(y, lp, lp1m)
    sum(@inbounds (pixel == 1) ? lp[i] : lp1m[i] for (i, pixel) in enumerate(y))
end
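
The footnote’s point is easy to verify directly. With arbitrarily chosen numbers, a raw product of hundreds of probabilities underflows to zero, while the equivalent sum of logs stays comfortably within Float64 range.

prod(fill(0.1, 784))       # 0.0 (underflow)
sum(log.(fill(0.1, 784)))  # ≈ -1805.2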

Next, we define the actual E-step function where we calculate \(p \left(Z_i = k | X_i, \theta^{(t)} \right) \; \forall i \in \{1, 2, \ldots, N \}, k \in \{ 1, 2, \ldots, K \}\). To avoid memory allocation overheads, we directly modify the \(\gamma\) variable to store the posterior estimates.

using LogExpFunctions

function E_step!(Y, γ, ϕ, π)
    lpi = log.(π)
    lp = log.(ϕ)
    lp1m = log.(1 .- ϕ)

    log_resps = zeros(Float64, length(π))

    for i in 1:size(Y, 2)
        @views for k in 1:length(lpi)
            log_resps[k] = lpi[k] + log_mvbernoulli_pmf(Y[:, i], lp[:, k], lp1m[:, k])
        end

        γ[i, :] .= exp.(log_resps .- logsumexp(log_resps))
    end
end
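
As a quick sanity check on the implementation, a made-up example (5 “pixels”, 10 “images”, 3 latent classes, random parameters) should produce a \(\gamma\) whose rows each sum to one.

Y_toy = rand(Bool, 5, 10)
ϕ_toy = rand(5, 3)
π_toy = fill(1/3, 3)
γ_toy = zeros(10, 3)

E_step!(Y_toy, γ_toy, ϕ_toy, π_toy)
all(x -> x ≈ 1.0, sum(γ_toy, dims=2))   # true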

M-Step

The M-step involves finding \(\theta^{(t+1)} = \text{arg max}_{\theta \in \Theta} Q(\theta^{(t)}, \theta)\). Treating \(p \left(Z_i = k | X, \theta^{(t)} \right)\) as constant, we find \(\theta^{(t+1)}\) by setting the partial derivatives of the objective function, \(\frac{\partial Q(\theta^{(t)}, \theta)}{\partial\theta}\), to zero and solving, which leads to the following updating equations.
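
For instance, the update for \(\pi_k\) comes from maximizing \(Q\) subject to \(\sum_{k=1}^K \pi_k = 1\) via a Lagrange multiplier \(\lambda\):

\[\begin{equation} \frac{\partial}{\partial \pi_k} \left( Q(\theta^{(t)}, \theta) + \lambda \left( 1 - \sum_{k=1}^K \pi_k \right) \right) = \cfrac{\sum_{i=1}^N p \left( Z_i = k | X, \theta^{(t)} \right)}{\pi_k} - \lambda = 0 \end{equation}\]

Summing the resulting \(\pi_k = \frac{1}{\lambda} \sum_{i=1}^N p(Z_i = k | X, \theta^{(t)})\) over \(k\) gives \(\lambda = N\), which yields the first update below.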

\[\begin{equation} \pi_k = \cfrac{\sum_{i=1}^N p \left( Z_i = k | X_i, \theta^{(t)} \right)}{N} \end{equation}\]

\[\begin{equation} \phi_{kj} = \cfrac{\sum_{i=1}^N X_{ij} p \left(Z_i = k| X_i, \theta^{(t)} \right)}{\sum_{i=1}^N p\left( Z_i = k | X_i, \theta^{(t)} \right)} \end{equation}\]

using LinearAlgebra

function M_step!(sY, γ, ϕ, π)
    cluster_sums = vec(sum(γ, dims=1))

    π .= cluster_sums ./ size(sY, 2)

    mul!(ϕ, sY, γ)
    ϕ ./= cluster_sums'
    clamp!(ϕ, eps(), 1 - eps())
end

Putting it Together

All that remains is to write a function that selects initial values for \(\theta\), and iterates between the E-step and M-step until convergence.

using SparseArrays

function marginal_log_lik(Y, ϕ, π)
    lpi = log.(π)
    lp = log.(ϕ)
    lp1m = log.(1 .- ϕ)

    aux = 0
    @views for i in 1:size(Y, 2)
        aux += logsumexp(lpi[k] + log_mvbernoulli_pmf(Y[:, i], lp[:, k], lp1m[:, k])
                         for k in 1:length(π))
    end

    return aux
end

function EM(Y, K; max_iter=1_000, tol=1e-5)
    π = fill(1/K, K)
    ϕ = rand(Float64, size(Y, 1), K)
    γ = zeros(Float64, size(Y, 2), K)

    # The M-Step matrix multiplication can be greatly sped up using a
    # sparse matrix representation of the binary data
    sY = SparseMatrixCSC(Y)

    log_lik_prev = log_lik_new = -Inf
    for i in 1:max_iter
        E_step!(Y, γ, ϕ, π)
        M_step!(sY, γ, ϕ, π)

        log_lik_new = marginal_log_lik(Y, ϕ, π)
        log_lik_diff = abs(log_lik_new - log_lik_prev)
        @info "Iteration $(i): log-likelihood = $(log_lik_new)"
        if log_lik_diff < tol
            return ϕ, π, γ
        end

        log_lik_prev = log_lik_new
    end

    error("Model failed to converge")
end

But we’re missing one final piece of the puzzle before we can run our model. We want to recover the latent classification \(\hat{k}_i\) for each image \(i\) by finding the latent state with the highest posterior probability given the observed data and parameter vector \(\theta\).

\[\begin{equation} \hat{k}_i = \underset{k \in \{ 1, 2, \ldots, K \}}{\operatorname{arg max}} \; p\left(Z_i = k | X_i, \theta^{(t)}\right) \end{equation}\]

These cluster indices won’t necessarily map back to the original labels in the dataset, so we need a rule for matching each cluster to a digit. There are multiple approaches, but we’ll keep it simple and assign each cluster the mode of its members’ labels.2

using StatsBase

function label_mapping(clusters, labels)
    Dict(k => mode(labels[clusters .== k])
         for k in 1:maximum(clusters) if k in clusters)
end
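
As a toy illustration with made-up clusters and labels, each cluster is assigned the most common label among its members.

label_mapping([1, 1, 2, 2, 2], [7, 7, 3, 3, 9])
# Dict(2 => 3, 1 => 7) (the ordering of a Dict is unspecified)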

Once we have \(\hat{k}_i\) for \(i = 1, 2, \ldots, N\), we can calculate the proportion of observations correctly classified.

function check(Z_hat, labels)
    correct = round(mean(Z_hat .== labels), digits=3)
    println("Proportion correctly classified: $(correct)")
end

Running the Model

Depending on your computer this may take a significant amount of time.

# Disable Info logging to keep the output tidy; comment out the
# disable_logging line to see per-iteration progress
using Logging
disable_logging(Logging.Info)

# Set the number of latent states to the total number of digits we're
# trying to classify
ϕ, π, γ = EM(model_input, 10)

clusters = map(argmax, eachrow(γ))
mapping = label_mapping(clusters, labels)
Z_hat = [mapping[i] for i in clusters]

check(Z_hat, labels)
Proportion correctly classified: 0.589

That’s quite underwhelming. To further investigate our model performance, we can plot the estimated probabilities \(\phi_{kj}\) of each pixel \(j\) being active in cluster \(k\), which depict the corresponding “idealized” digit images.

plts = [Gray.(reshape(1 .- ϕ[:, i], 28, 28)') |> plot for i in 1:size(ϕ, 2)]
plot(plts..., layout = size(ϕ, 2), axis = false, ticks = false)
Figure 2: Ideal digit predictions from our Bernoulli mixture model.

Clearly, some digits are more difficult to model than others, and some digits are being overpredicted. It’s important to note that the Bernoulli model does not guarantee that each latent class \(k\) will map one-to-one onto the original dataset labels.

How can we improve this? We could, for example, increase the number of latent states to account for the overpredicted digits (sketched below), or switch to a more flexible model.
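
As a rough sketch of the first idea (not run here, with K = 30 chosen arbitrarily), the only changes needed are the number of latent states and the cluster-to-label mapping.

# Hypothetical: with more latent states than digits, several clusters can
# share a digit label via label_mapping's mode-based assignment
ϕ30, π30, γ30 = EM(model_input, 30)
clusters30 = map(argmax, eachrow(γ30))
mapping30 = label_mapping(clusters30, labels)
check([mapping30[i] for i in clusters30], labels)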

Gaussian Mixture Model(s)

Let’s model each image instead with a multivariate Gaussian using the normalized values for the pixels. We’re also going to take the opportunity to modify our approach in two different ways.

First, we’ll improve computational efficiency by reducing the dimensionality of our data through principal component analysis. This involves calculating the covariance matrix \(S\) given by:

\[\begin{equation} S = \cfrac{1}{N - 1}(X - \bar{X})'(X - \bar{X}) \end{equation}\]

and then performing an eigendecomposition \(S V = V \Lambda\), where \(V\) is a matrix of eigenvectors and \(\Lambda\) a diagonal matrix of eigenvalues. Selecting the \(D^\star < D\) eigenvectors associated with the largest eigenvalues, the new dataset, \(Y \in \mathbb{R}^{N \times D^\star}\), is formed by applying the projection matrix, \(P\), to the centered data.

\[\begin{aligned} P & = [v_1, v_2, \ldots, v_{D^\star}] \; \text{for} \; v_j \in V \\ Y & = (X - \bar{X}) P \end{aligned}\]

using MultivariateStats

# Normally we would use a more robust criterion for selecting D*,
# but for simplicity we'll just (arbitrarily) set it at 50.
# Note: PCA is fit on the original normalized pixel values, not the
# binarized matrix used by the Bernoulli model
pixel_input = reshape(pixels, (28*28, 60_000))
pca = fit(PCA, pixel_input, maxoutdim=50)
reduced_data = transform(pca, pixel_input)
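
For intuition, here is a minimal, hypothetical sketch of the projection described by the equations above, written for matrices laid out D×N like ours. Applied to the same matrix passed to fit, it should agree with transform(pca, ...) up to the sign of each component.

using LinearAlgebra, Statistics

# Illustrative only: PCA via an explicit eigendecomposition of the covariance.
# X is D x N (columns are observations), Dstar the target dimension.
function pca_project(X, Dstar)
    X̄ = mean(X, dims=2)
    Xc = X .- X̄                         # center each variable (pixel)
    S = (Xc * Xc') ./ (size(X, 2) - 1)   # D x D covariance matrix
    vals, vecs = eigen(Symmetric(S))     # eigenvalues in ascending order
    P = vecs[:, end:-1:end-Dstar+1]      # top Dstar eigenvectors
    return P' * Xc                       # Dstar x N projected data
end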

Second, instead of fitting a single mixture model, we’ll fit \(M = 10\) separate models, one for each digit in the MNIST dataset. Each mixture model will be estimated with \(K\) latent states ideally corresponding to different “variants” of the same digit.3 Let \(L_i\) be the original MNIST label for image \(i\); then model \(m \in \{ 1, 2, \ldots, M \}\) is formed as follows:

\[\begin{aligned} Y^{(m)} & = \{ Y_i : L_i = m \} \\ Y^{(m)}_i | Z^{(m)}_i & = k \sim \text{N}\left(\mu_{mk}, \Sigma_{mk} \right) \end{aligned}\]

with the corresponding EM objective

\[\begin{equation} Q\left(\theta^{(t)}_m, \theta_m \right) = \sum_{i=1}^{N_m} \sum_{k=1}^{K} p\left(Z^{(m)}_i = k | Y^{(m)}, \theta^{(t)}_m\right) \left( \log \pi_{mk} + \log \text{N}(\mu_{mk}, \Sigma_{mk}) \right) \end{equation}\]

where \(N(\cdot)\) is a slight abuse of notation to denote the Normal probability density function.

E-Step

This time we’ll rely on the Distributions package for the Normal probability density function.

using Distributions

function gaussians(μ, Σ)
    @views [MvNormal(μ[:, i], Σ[:, :, i]) for i in 1:size(μ, 2)]
end

Again, we let \(\gamma\) denote \(p\left( Z^{(m)}_i = k | Y^{(m)}, \theta^{(t)}_m \right)\), updating it in place while working with log probabilities.

function E_step!(Y, γ, π, μ, Σ)
    dists = gaussians(μ, Σ)
    lpi = log.(π)

    for i in 1:size(Y, 2)
        lp = [lpi[k] + logpdf(dists[k], view(Y, :, i)) for k in 1:length(π)]
        denominator = logsumexp(lp)

        for k in 1:length(π)
            @inbounds γ[i, k] = exp(lp[k] - denominator)
        end
    end
end

M-Step

We employ a similar strategy as before for the M-Step. Setting \(\frac{\partial Q(\theta^{(t)}_m, \theta_m)}{\partial \theta_m} = 0\) and solving, we find the updating equations for each parameter.

\[\begin{aligned} \pi_{mk} & = \cfrac{\sum_{i=1}^{N_m} p \left( Z_i^{(m)} = k | Y_i^{(m)}, \theta_m^{(t)} \right)}{N_m} \\ \mu_{mk} & = \cfrac{\sum_{i=1}^{N_m} Y^{(m)}_i p \left(Z^{(m)}_i = k | Y_i^{(m)}, \theta^{(t)}_m\right)}{\sum_{i=1}^{N_m} p \left(Z_i^{(m)} = k| Y_i^{(m)}, \theta_m^{(t)}\right)} \\ \Sigma_{mk} & = \cfrac{\sum_{i=1}^{N_m} p \left( Z_i^{(m)} = k | Y_i^{(m)}, \theta_m^{(t)}\right) \left( Y_i^{(m)} - \mu_{mk} \right) \left(Y_i^{(m)} - \mu_{mk} \right)'}{\sum_{i=1}^{N_m} p \left(Z_i^{(m)} = k | Y_i^{(m)}, \theta_m^{(t)}\right)} \end{aligned}\]

function M_step!(Y, γ, π, μ, Σ)
    cluster_sums = vec(sum(γ, dims=1))
    π .= cluster_sums ./ size(Y, 2)

    mul!(μ, Y, γ)
    μ ./= cluster_sums'

    @views for k in 1:length(π)
        Y_centered = Y .- μ[:, k]
        Σ_k = zeros(size(Y, 1), size(Y, 1))
        for i in 1:size(Y, 2)
            Σ_k += γ[i, k] * (Y_centered[:, i] * Y_centered[:, i]')
        end
        Σ[:, :, k] .= Σ_k / cluster_sums[k]
    end
end
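
One caveat worth flagging, although it is not handled in the code above: if a cluster collapses onto very few points, the weighted covariance can become numerically non-positive-definite and MvNormal will reject it. A common safeguard, sketched here with an arbitrary ridge value, is to nudge each covariance toward the identity at the end of the M-step.

# Hypothetical regularization: add a small ridge to each covariance's diagonal
# so the matrices stay positive definite (the 1e-6 value is arbitrary)
for k in 1:size(Σ, 3)
    Σ[:, :, k] .+= 1e-6 .* I(size(Σ, 1))
end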

EM Function

Our EM function for the Gaussian mixture model is largely the same as the Bernoulli version. We select initial values for \(\theta_m\) and iterate between the E-Step and M-Step until convergence.

function marginal_log_lik(Y, π, μ, Σ)
    d = MixtureModel(gaussians(μ, Σ), π)
    sum(logpdf(d, y) for y in eachcol(Y))
end

function EM(Y, K; max_iter=1_000, tol=1e-5)
    π = fill(1/K, K)
    μ = Y[:, rand(1:size(Y, 2), K)]
    Σ = cat([Diagonal(ones(size(Y, 1))) for k in 1:K]..., dims=3)
    γ = zeros(size(Y, 2), K)

    log_lik_prev = log_lik_new = -Inf
    for i in 1:max_iter
        E_step!(Y, γ, π, μ, Σ)
        M_step!(Y, γ, π, μ, Σ)

        log_lik_new = marginal_log_lik(Y, π, μ, Σ)
        @info "Iteration $i: Log Likelihood = $log_lik_new"

        if abs(log_lik_new - log_lik_prev) < tol
            return π, μ, Σ, γ
        end

        log_lik_prev = log_lik_new
    end

    error("Model failed to converge")
end

Once we have fit our models, we need a way to predict the digit label for a given image. Instead of recovering the latent classifications within each model, we’ll assign the label based on the model with the highest log-likelihood given the data.

\[\begin{equation} \hat{m}(Y^{\text{new}}) = \underset{m \in \{ 1, 2, \ldots, M \}}{\operatorname{arg max}} \ell\left(\theta_{m} | Y^{\text{new}}\right) \end{equation}\]

function predict(Y, models)
    dists = [MixtureModel(gaussians(μ, Σ), π) for (π, μ, Σ, _) in models]
    [argmax(logpdf(d, y) for d in dists) - 1 for y in eachcol(Y)]
end

Running the Model

To actually run our models, we spawn a separate task per digit so they can be fit in parallel. Note that Threads.@spawn only runs tasks in parallel when Julia is started with multiple threads (e.g. via the --threads command-line flag). Be forewarned that this will probably consume a significant amount of memory.

using Base.Threads

# As before, disable Info logging; comment out the next line to see
# per-iteration progress
disable_logging(Logging.Info)

# Fit a model for each digit with K = 4
results = fetch.([@spawn EM(reduced_data[:, labels .== i], 4) for i in 0:9])

Z_hat = predict(reduced_data, results)
check(Z_hat, labels)
Proportion correctly classified: 0.973

This is a significant improvement over our initial Bernoulli model.

With the fitted models we can also check the prediction rate for the MNIST test dataset.

test_pixels, test_labels = MNIST(split=:test)[:]
test_input = reshape(test_pixels, (28*28, 10_000))
test_reduced = transform(pca, test_input)

Z_test = predict(test_reduced, results)
check(Z_test, test_labels)
Proportion correctly classified: 0.969

We can also plot how a specific model learns the different variants of the same digit. We’ll look at the model fit on the subset of images labeled as fives in MNIST and plot the expected value of the Gaussian associated with each latent classification.

digit = results[6]  # models are ordered by digit 0 through 9, so index 6 is the digit 5
img = map(μ -> reconstruct(pca, μ), eachcol(digit[2]))
plts = [plot(Gray.(reshape(1 .- i, 28, 28)')) for i in img]
plot(plts..., layout = (1, 4), axis = false, ticks = false)
Figure 3: Predicted images for each variant of the digit ‘5’.

Finally, we can also plot what the “average” digit looks like according to our models by plotting the expected value of the mixture for each digit.

EY = map(((π, μ, _, _),) -> reconstruct(pca, μ * π), results)
plts = [Gray.(reshape(1 .- μ, 28, 28)') |> plot for μ in EY]
plot(plts..., layout = length(EY), axis = false, ticks = false)
Figure 4: “Average” predicted image for each digit

Can we do better? Probably. There’s still more that can be done regarding hyperparameter tuning, exploring different initialization strategies, and general optimizations to the code, but that will be left as a future exercise.


  1. If we were not to use log probabilities, consider what would happen to \(\prod_{j=1}^D \phi_{kj}^{X_{ij}} (1 - \phi_{kj}) ^{1 - X_{ij}}\) as \(D \to \infty\) and/or \(\phi_{kj} \to 0\) with regard to floating-point underflow.↩︎

  2. In case of ties, the mode function from StatsBase will return the first element.↩︎

  3. For simplicity, we’ll assume a constant number of latent states for each model. Obviously, a more advanced implementation would allow \(K\) to vary since some digits may be inherently more difficult to model although this would be at the cost of additional computational complexity.↩︎