using Colors, MLDatasets, Plots

ENV["DATADEPS_ALWAYS_ACCEPT"] = true
pixels, labels = MNIST(split=:train)[:]

plts = [plot(1 .- Gray.(img')) for img in eachslice(pixels[:, :, 1:6], dims=3)]
plot(plts..., layout = (2, 3), axis = false, ticks = false)
This note is the second in a three-part series on the expectation maximization algorithm. Part 1 gives a cursory overview of the algorithm, Part 2 deals with mixture models, and Part 3 applies the EM algorithm to hidden Markov models.
Let’s see the EM algorithm in action by using it on a toy example. We’ll implement several mixture models in Julia to classify handwritten digits from the MNIST dataset.
Mixture models are a form of unsupervised soft clustering that assign probabilities to each observation belonging to a certain unobserved class. In this case, our observed data will be the images from MNIST to which we will try to assign the correct digit label.
The first attempt will be a fairly naive Bernoulli mixture model, after which we’ll see if we can improve our predictions using a mixture of Gaussians.
MNIST
The MNIST dataset consists of 70,000 grayscale images of handwritten digits. We can grab the dataset from the MLDatasets package, where each image is represented as a 28x28 matrix of individual pixels normalized to be between 0 and 1.
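As a quick sanity check, we can inspect the pixels and labels arrays loaded at the top of this post (exact element types may vary slightly across MLDatasets versions):

size(pixels)       # expect (28, 28, 60000) for the training split
extrema(pixels)    # expect (0.0f0, 1.0f0) — already normalized
length(labels)     # expect 60000 labels, digits 0 through 9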
Bernoulli Mixture Model
For \(i = 1, 2, \ldots, N\) images and \(j = 1, 2, \ldots, D\) pixels, start by creating a binary indicator for each pixel, \(X_{ij} = \mathbb{I}(X_{ij} > 0.5)\).
binned_pixels = collect(pixels .> 0.5)

# We'll also flatten our nested array into a 784x60,000 matrix
model_input = reshape(binned_pixels, (28*28, 60_000))
We model each pixel as an independent Bernoulli random variable with a latent variable \(Z_i \in \{ 1, 2, \ldots, K \}\) denoting the unknown classification for image \(i\).
\[\begin{equation} X_{ij} | Z_i = k \sim \text{Bern}(\pi_k) \end{equation}\]
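To make the generative story concrete, here is a tiny simulation sketch with made-up parameters (π_demo and ϕ_demo are hypothetical, with \(K = 2\) and \(D = 4\)):

# Draw the latent class, then each pixel independently from Bern(ϕ_kj)
π_demo = [0.3, 0.7]                              # hypothetical mixing weights
ϕ_demo = [0.9 0.1; 0.8 0.2; 0.1 0.9; 0.2 0.8]    # hypothetical D×K pixel probabilities
k = findfirst(cumsum(π_demo) .>= rand())         # Zᵢ drawn with probabilities π
x = rand(4) .< ϕ_demo[:, k]                      # Xᵢⱼ | Zᵢ = k ~ Bern(ϕ_kj)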
The marginal log-likelihood for \(X\) follows easily.
\[\begin{equation} \ell(\pi, \phi | X) = \sum_{i=1}^N \log \left( \sum_{k=1}^K \pi_k \prod_{j=1}^D \phi_{kj}^{X_{ij}} (1 - \phi_{kj})^{(1 - X_{ij})} \right) \end{equation}\]
Unfortunately, marginalizing over the latent states leaves a summation inside the natural logarithm, which means that there is no closed form solution to this MLE problem.
With the goal of finding an estimate for \(\theta = \{ \pi, \phi \}\) that maximizes \(\ell(\theta | X)\), we turn to the EM algorithm and define our objective function.
\[\begin{aligned} Q(\theta^{(t)}, \theta) & = \mathbb{E}_Z \left[ \log p(X, Z | \theta) | X, \theta^{(t)} \right] \\ & = \sum_{i=1}^N \sum_{k=1}^K p\left(Z_i = k | X, \theta^{(t)} \right) \left( \log \pi_k + \log \prod_{j=1}^D \phi_{kj}^{X_{ij}} (1 - \phi_{kj}) ^{1 - X_{ij}} \right) \end{aligned}\]
We can find \(p(Z_i = k | X, \theta^{(t)})\) using Bayes' rule:
\[\begin{equation} p \left(Z_i = k | X, \theta^{(t)} \right) = \cfrac{p(X_i | Z_i = k, \theta^{(t)}) \pi_k}{\sum_{l=1}^K p(X_i | Z_i = l, \theta^{(t)}) \pi_l} \end{equation}\]
where \(p\left( X_i | Z_i = k, \theta^{(t)} \right) = \prod_{j=1}^D \phi_{kj}^{X_{ij}} (1 - \phi_{kj}) ^{1 - X_{ij}}\).
E-Step
Given the previous equations, the E-step is fairly straightforward. Start by defining a helper that computes the class-conditional log probability, \(\log p(X_i | Z_i = k, \theta)\), for a single image. Throughout our implementation of the EM algorithm we’ll be dealing with log probabilities for numeric stability.1
# Sum of log ϕ_kj over "on" pixels and log (1 - ϕ_kj) over "off" pixels
function log_mvbernoulli_pmf(y, lp, lp1m)
    sum(@inbounds (pixel == 1) ? lp[i] : lp1m[i] for (i, pixel) in enumerate(y))
end
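As a quick illustration of the footnote's point (toy numbers only, not part of the model code):

p = fill(0.1, 784)   # a plausible per-pixel probability with D = 784 pixels
prod(p)              # 0.0 — the product underflows in Float64
sum(log.(p))         # ≈ -1805.2 — the log-space version is perfectly representable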
Next, we define the actual E-step function where we calculate \(p \left(Z_i = k | X_i, \theta^{(t)} \right) \; \forall i \in \{1, 2, \ldots, N \}, k \in \{ 1, 2, \ldots, K \}\). To avoid memory allocation overheads, we directly modify the \(\gamma\) variable to store the posterior estimates.
using LogExpFunctions

function E_step!(Y, γ, ϕ, π)
    lpi = log.(π)
    lp = log.(ϕ)
    lp1m = log.(1 .- ϕ)

    log_resps = zeros(Float64, length(π))

    for i in 1:size(Y, 2)
        @views for k in 1:length(lpi)
            log_resps[k] = lpi[k] + log_mvbernoulli_pmf(Y[:, i], lp[:, k], lp1m[:, k])
        end
        γ[i, :] .= exp.(log_resps .- logsumexp(log_resps))
    end
end
M-Step
The M-step involves finding \(\theta^{(t+1)} = \text{arg max}_{\theta \in \Theta} Q(\theta^{(t)}, \theta)\). Treating \(p \left(Z_i = k | X, \theta^{(t)} \right)\) as constant, we find \(\theta^{(t+1)}\) by setting the partial derivatives of the objective function, \(\frac{\partial Q(\theta^{(t)}, \theta)}{\partial\theta}\), to zero and solving, which leads to the following updating equations.
\[\begin{equation} \pi_k = \cfrac{\sum_{i=1}^N p \left( Z_i = k | X_i, \theta^{(t)} \right)}{N} \end{equation}\]
\[\begin{equation} \phi_{kj} = \cfrac{\sum_{i=1}^N X_{ij} p \left(Z_i = k| X_i, \theta^{(t)} \right)}{\sum_{i=1}^N p\left( Z_i = k | X_i, \theta^{(t)} \right)} \end{equation}\]
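To see where the update for \(\pi_k\) comes from, maximize \(Q\) over \(\pi\) subject to \(\sum_{l=1}^K \pi_l = 1\) with a Lagrange multiplier \(\lambda\):
\[\begin{aligned} \frac{\partial}{\partial \pi_k} \left[ Q(\theta^{(t)}, \theta) + \lambda \left( \sum_{l=1}^K \pi_l - 1 \right) \right] & = \cfrac{\sum_{i=1}^N p \left(Z_i = k | X, \theta^{(t)} \right)}{\pi_k} + \lambda = 0 \\ \Rightarrow \pi_k & = -\cfrac{1}{\lambda} \sum_{i=1}^N p \left(Z_i = k | X, \theta^{(t)} \right) \end{aligned}\]
Summing both sides over \(k\) gives \(\lambda = -N\), which yields the first updating equation; the \(\phi_{kj}\) update follows from a similar (unconstrained) calculation.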
using LinearAlgebra

function M_step!(sY, γ, ϕ, π)
    cluster_sums = vec(sum(γ, dims=1))
    π .= cluster_sums ./ size(sY, 2)

    mul!(ϕ, sY, γ)
    ϕ ./= cluster_sums'
    clamp!(ϕ, eps(), 1 - eps())
end
Putting it Together
All that remains is to write a function that selects initial values for \(\theta\), and iterates between the E-step and M-step until convergence.
using SparseArrays

function marginal_log_lik(Y, ϕ, π)
    lpi = log.(π)
    lp = log.(ϕ)
    lp1m = log.(1 .- ϕ)

    aux = 0
    @views for i in 1:size(Y, 2)
        aux += logsumexp(lpi[k] + log_mvbernoulli_pmf(Y[:, i], lp[:, k], lp1m[:, k])
                         for k in 1:length(π))
    end
    return aux
end
function EM(Y, K; max_iter=1_000, tol=1e-5)
    π = fill(1/K, K)
    ϕ = rand(Float64, size(Y, 1), K)
    γ = zeros(Float64, size(Y, 2), K)

    # The M-Step matrix multiplication can be greatly sped up using a
    # sparse matrix representation of the binary data
    sY = SparseMatrixCSC(Y)

    log_lik_prev = log_lik_new = -Inf
    for i in 1:max_iter
        E_step!(Y, γ, ϕ, π)
        M_step!(sY, γ, ϕ, π)

        log_lik_new = marginal_log_lik(Y, ϕ, π)
        log_lik_diff = abs(log_lik_new - log_lik_prev)
        @info "Iteration $(i): log-likelihood = $(log_lik_new)"

        if log_lik_diff < tol
            return ϕ, π, γ
        end
        log_lik_prev = log_lik_new
    end
    error("Model failed to converge")
end
We’re missing one final piece of the puzzle before we can run our model. We want to be able to recover the latent classification \(\hat{k}_i\) for each image \(i\) by finding the latent state with the highest posterior probability given the observed data and parameter vector \(\theta\).
\[\begin{equation} \hat{k}_i = \underset{k \in \{ 1, 2, \ldots, K \}}{\operatorname{arg max}} \; p\left(Z_i = k | X_i, \theta^{(t)}\right) \end{equation}\]
These cluster indices won’t necessarily map back to the original labels in the dataset. There are multiple approaches to constructing such a mapping, but we’ll keep it simple and map each cluster to the mode of the labels of its members.2
using StatsBase
function label_mapping(clusters, labels)
    Dict(k => mode(labels[clusters .== k])
         for k in 1:maximum(clusters) if k in clusters)
end
Once we have \(\hat{k}_i\) for \(i = 1, 2, \ldots, N\), we can calculate the proportion of observations correctly classified.
function check(Z_hat, labels)
    correct = round(mean(Z_hat .== labels), digits=3)
    println("Proportion correctly classified: $(correct)")
end
Running the Model
Depending on your computer, this may take a significant amount of time.
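If you just want to verify that the code runs, fitting on a random subset first is much faster (a sketch only; the numbers reported below come from the full training set, and idx, small_input, and small_labels are illustrative names):

idx = rand(1:size(model_input, 2), 10_000)   # random subset (with replacement, for simplicity)
small_input, small_labels = model_input[:, idx], labels[idx]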
# Suppress the @info iteration messages (remove this line to watch progress)
using Logging
disable_logging(Logging.Info)

# Set the number of latent states to the total number of digits we're
# trying to classify
ϕ, π, γ = EM(model_input, 10)

clusters = map(argmax, eachrow(γ))
mapping = label_mapping(clusters, labels)
Z_hat = [mapping[i] for i in clusters]

check(Z_hat, labels)
Proportion correctly classified: 0.589
That’s quite underwhelming. To further investigate our model performance, we can plot the estimated probabilities \(\phi_{kj}\) of each pixel \(j\) being “on” in class \(k\), which depict the corresponding “idealized” images.
plts = [Gray.(reshape(1 .- ϕ[:, i], 28, 28)') |> plot for i in 1:size(ϕ, 2)]
plot(plts..., layout = size(ϕ, 2), axis = false, ticks = false)
Clearly, some digits are more difficult to model than others, and some digits are being overpredicted. It’s important to note that the Bernoulli model does not guarantee that each latent class \(k\) will map one-to-one to the original dataset labels.
How can we improve this? We could, for example, increase the number of latent states in the model to account for the overpredictions, or we could switch to a more flexible model.
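A sketch of the first option, reusing the functions defined above (K = 15 is an arbitrary choice and the results aren't shown here):

# With more latent states than digits, several clusters can share a label
# through the mode-based mapping
ϕ15, π15, γ15 = EM(model_input, 15)
clusters15 = map(argmax, eachrow(γ15))
mapping15 = label_mapping(clusters15, labels)
check([mapping15[i] for i in clusters15], labels)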
Gaussian Mixture Model(s)
Let’s model each image instead with a multivariate Gaussian using the normalized values for the pixels. We’re also going to take the opportunity to modify our approach in two different ways.
First, we’ll improve computational efficiency by reducing the dimensionality of our data through principal component analysis. This involves calculating the covariance matrix \(S\) given by:
\[\begin{equation} S = \cfrac{1}{N - 1}(X - \bar{X})'(X - \bar{X}) \end{equation}\]
and then performing an eigendecomposition \(S V = V \Lambda\) where \(V\) is a matrix of eigenvectors. Selecting \(D^\star\) eigenvectors where \(D^\star < D\), the new dataset, \(Y \in \mathbb{R}^{N \times D^\star}\), is formed by applying the projection matrix, \(P\), to the centered data.
\[\begin{aligned} P & = [v_1, v_2, \ldots, v_{D^\star}] \; \text{for} \; v_j \in V \\ Y & = (X - \bar{X}) P \end{aligned}\]
using MultivariateStats

# Normally we would use a more robust criterion for selecting D*,
# but for simplicity we'll just (arbitrarily) set it at 50.
pca = fit(PCA, model_input, maxoutdim=50)
reduced_data = transform(pca, model_input)
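For reference, the same projection written out by hand might look roughly like this, assuming observations are stored as columns as in model_input (the MultivariateStats code above is what we actually use; pca_project is just an illustrative name):

using LinearAlgebra, Statistics

function pca_project(X, d)
    Xc = X .- mean(X, dims=2)             # center each pixel across images
    S = (Xc * Xc') ./ (size(X, 2) - 1)    # D×D sample covariance matrix
    V = eigen(Symmetric(S)).vectors       # eigenvectors, eigenvalues in ascending order
    P = V[:, end-d+1:end]                 # keep the top-d eigenvectors
    P' * Xc                               # d×N projected data
end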
Second, instead of fitting a single mixture model, we’ll fit \(M = 10\) separate models, one for each digit in the MNIST dataset. Each mixture model will be estimated with \(K\) latent states ideally corresponding to different “variants” of the same digit.3 Let \(L_i\) be the original MNIST label for image \(i\); then model \(m \in \{ 1, 2, \ldots, M \}\) is formed as follows:
\[\begin{aligned} Y^{(m)} & = \{ Y_i : L_i = m \} \\ Y^{(m)}_i | Z^{(m)}_i & = k \sim \text{N}\left(\mu_{mk}, \Sigma_{mk} \right) \end{aligned}\]
with objective function
\[\begin{equation} Q\left(\theta^{(t)}_m, \theta_m \right) = \sum_{i=1}^{N_m} \sum_{k=1}^{K} p\left(Z^{(m)}_i = k | Y^{(m)}, \theta^{(t)}_m\right) \left( \log \pi_{mk} + \log \text{N}(\mu_{mk}, \Sigma_{mk}) \right) \end{equation}\]
where \(N(\cdot)\) is a slight abuse of notation to denote the Normal probability density function.
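Concretely, each \(Y^{(m)}\) is just a column slice of the reduced data; for example, the subset for the fives (a quick sketch using reduced_data and labels from above):

Y_5 = reduced_data[:, labels .== 5]   # columns whose MNIST label is 5
size(Y_5)                             # (D*, N_5), with D* at most 50 here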
E-Step
This time we’ll rely on the Distributions package for the Normal probability density function.
using Distributions
function gaussians(μ, Σ)
    @views [MvNormal(μ[:, i], Σ[:, :, i]) for i in 1:size(μ, 2)]
end
Again, we let \(\gamma\) denote \(P\left( Z^{(m)}_i = k | Y^{(m)}, \theta^{(t)}_m \right)\), updating it in place while working with log probabilities.
function E_step!(Y, γ, π, μ, Σ)
    dists = gaussians(μ, Σ)
    lpi = log.(π)

    for i in 1:size(Y, 2)
        lp = [lpi[k] + logpdf(dists[k], view(Y, :, i)) for k in 1:length(π)]
        denominator = logsumexp(lp)

        for k in 1:length(π)
            @inbounds γ[i, k] = exp(lp[k] - denominator)
        end
    end
end
M-Step
We also employ a similar strategy as before for the M-Step. Solving for \(\frac{\partial Q(\theta^{(t)}_m, \theta_m)}{\partial \theta_m} = 0\) we find the updating equation for each parameter.
\[\begin{aligned} \pi_{mk} & = \cfrac{\sum_{i=1}^{N_m} p \left( Z_i^{(m)} = k | Y_i^{(m)}, \theta_m^{(t)} \right)}{N_m} \\ \mu_{mk} & = \cfrac{\sum_{i=1}^{N_m} Y^{(m)}_i p \left(Z^{(m)}_i = k | Y_i^{(m)}, \theta^{(t)}_m\right)}{\sum_{i=1}^{N_m} p \left(Z_i^{(m)} = k| Y_i^{(m)}, \theta_m^{(t)}\right)} \\ \Sigma_{mk} & = \cfrac{\sum_{i=1}^{N_m} p \left( Z_i^{(m)} = k | Y_i^{(m)}, \theta_m^{(t)}\right) \left( Y_i^{(m)} - \mu_{mk} \right) \left(Y_i^{(m)} - \mu_{mk} \right)'}{\sum_{i=1}^{N_m} p \left(Z_i^{(m)} = k | Y_i^{(m)}, \theta_m^{(t)}\right)} \end{aligned}\]
function M_step!(Y, γ, π, μ, Σ)
    cluster_sums = vec(sum(γ, dims=1))
    π .= cluster_sums ./ size(Y, 2)

    mul!(μ, Y, γ)
    μ ./= cluster_sums'

    @views for k in 1:length(π)
        Y_centered = Y .- μ[:, k]
        Σ_k = zeros(size(Y, 1), size(Y, 1))
        for i in 1:size(Y, 2)
            Σ_k += γ[i, k] * (Y_centered[:, i] * Y_centered[:, i]')
        end
        Σ[:, :, k] .= Σ_k / cluster_sums[k]
    end
end
EM Function
Our EM function for the Gaussian mixture model is largely the same as the Bernoulli version. We select initial values for \(\theta_m\) and iterate between the E-Step and M-Step until convergence.
function marginal_log_lik(Y, π, μ, Σ)
    d = MixtureModel(gaussians(μ, Σ), π)
    sum(logpdf(d, y) for y in eachcol(Y))
end

function EM(Y, K; max_iter=1_000, tol=1e-5)
    π = fill(1/K, K)
    μ = Y[:, rand(1:size(Y, 2), K)]
    Σ = cat([Diagonal(ones(size(Y, 1))) for k in 1:K]..., dims=3)
    γ = zeros(size(Y, 2), K)

    log_lik_prev = log_lik_new = -Inf
    for i in 1:max_iter
        E_step!(Y, γ, π, μ, Σ)
        M_step!(Y, γ, π, μ, Σ)

        log_lik_new = marginal_log_lik(Y, π, μ, Σ)
        @info "Iteration $i: Log Likelihood = $log_lik_new"

        if abs(log_lik_new - log_lik_prev) < tol
            return π, μ, Σ, γ
        end
        log_lik_prev = log_lik_new
    end
    error("Model failed to converge")
end
Once we have fit our models, we need a way to predict the digit label for a given image. Instead of recovering the latent classifications for the observations, we’ll assign our label based on the model with the highest log-likelihood given the data.
\[\begin{equation} \hat{m}(Y^{\text{new}}) = \underset{m \in \{ 1, 2, \ldots, M \}}{\operatorname{arg max}} \ell\left(\theta_{m} | Y^{\text{new}}\right) \end{equation}\]
function predict(Y, models)
    dists = [MixtureModel(gaussians(μ, Σ), π) for (π, μ, Σ, _) in models]
    [argmax(logpdf(d, y) for d in dists) - 1 for y in eachcol(Y)]
end
Running the Model
To actually run our models, we spawn separate tasks with Base.Threads to enable parallel processing; make sure Julia is started with multiple threads (e.g. via the --threads flag), otherwise the models will simply run one after another. Be forewarned that this will probably consume a significant amount of memory.
using Base.Threads

# Again, suppress the @info iteration messages
disable_logging(Logging.Info)

# Fit a model for each digit with K = 4
results = fetch.([@spawn EM(reduced_data[:, labels .== i], 4) for i in 0:9])

Z_hat = predict(reduced_data, results)
check(Z_hat, labels)
Proportion correctly classified: 0.973
This is a significant improvement over our initial Bernoulli model.
With the fitted models we can also check the prediction rate for the MNIST test dataset.
test_pixels, test_labels = MNIST(split=:test)[:]
test_input = reshape(test_pixels, (28*28, 10_000))
test_reduced = transform(pca, test_input)

Z_test = predict(test_reduced, results)
check(Z_test, test_labels)
Proportion correctly classified: 0.969
We can also plot how a specific model learns the different variants of the same digit. We’ll look at the model fit on the subset of images labelled as fives in MNIST and plot the expected value of the Gaussian associated with each latent classification.
digit = results[6]
img = map(μ -> reconstruct(pca, μ), eachcol(digit[2]))
plts = [plot(Gray.(reshape(1 .- i, 28, 28)')) for i in img]
plot(plts..., layout = (1, 4), axis = false, ticks = false)
Finally, we can also plot what the “average” digit looks like according to our models by plotting the expected value of each mixture, \(\mathbb{E}\left[Y^{(m)}\right] = \sum_{k=1}^K \pi_{mk} \mu_{mk}\).
EY = map(((π, μ, _, _),) -> reconstruct(pca, μ * π), results)
plts = [Gray.(reshape(1 .- μ, 28, 28)') |> plot for μ in EY]
plot(plts..., layout = length(EY), axis = false, ticks = false)
Can we do better? Probably. There’s still more that can be done regarding hyperparameter tuning, exploring different initialization strategies, and general optimizations to the code, but that will be left as a future exercise.
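For example, one common refinement to the initialization (just a sketch; the Clustering package is an extra dependency not used elsewhere in this post) would be to seed the Gaussian means with k-means centers instead of random columns of Y:

using Clustering

# Hypothetical helper: k-means centers as initial values for μ in the Gaussian EM
function kmeans_init(Y, K)
    km = kmeans(Y, K)    # observations are columns, matching our EM functions
    km.centers           # D*×K matrix of cluster centers
end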
If we were not to use log probabilities, consider what would happen to \(\prod_{j=1}^D \phi_{kj}^{X_{ij}} (1 - \phi_{kj}) ^{1 - X_{ij}}\) as \(D \to \infty\) and/or \(\phi_{kj} \to 0\) with regard to floating-point error.↩︎
In case of ties, the mode function from StatsBase will return the first element.↩︎
For simplicity, we’ll assume a constant number of latent states for each model. Obviously, a more advanced implementation would allow \(K\) to vary since some digits may be inherently more difficult to model, although this would come at the cost of additional computational complexity.↩︎