Mixture Models and MCMC

Gaussian mixtures, label-switching, and identifiability in MCMC sampling.

Explores Gaussian mixture models with MCMC sampling using Old Faithful data and synthetic 3-component mixtures. Demonstrates label-switching problems, identifiability issues, and mitigations via ordered transforms and potentials.
Published

June 25, 2025

We now study how to learn mixture models with MCMC. We have already done this for the Zero-Inflated Poisson model; here we stick to Gaussian mixture models.

%matplotlib inline
import numpy as np
import scipy as sp
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import pandas as pd
pd.set_option('display.width', 500)
pd.set_option('display.max_columns', 100)
pd.set_option('display.notebook_repr_html', True)
import seaborn as sns
sns.set_style('whitegrid')
sns.set_context('poster')
import pymc as pm
import pytensor.tensor as pt
import arviz as az

Mixture of 2 Gaussians: the Old Faithful data

We start by considering waiting times between eruptions of the Old Faithful geyser in Yellowstone National Park.

ofdata=pd.read_csv("data/oldfaithful.csv")
ofdata.head()
   eruptions  waiting
0      3.600       79
1      1.800       54
2      3.333       74
3      2.283       62
4      4.533       85
sns.histplot(ofdata.waiting, kde=True);

Visually, there seem to be two components to the waiting time, so let us model this using a mixture of two Gaussians. Remember that this is an unsupervised model: all we are doing is modelling \(p(x)\), with the assumption that there are two clusters and a hidden variable \(z\) that indexes them.
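
Written out, with \(\pi_k\), \(\mu_k\) and \(\sigma_k\) as our assumed notation for the weight, mean and standard deviation of component \(k\) (these symbols are ours, not something fixed by the data), the density being modelled is:

\[
p(z_i = k) = \pi_k, \qquad x_i \mid z_i = k \sim \mathcal{N}(\mu_k, \sigma_k^2), \qquad k \in \{0, 1\}, \quad \pi_0 + \pi_1 = 1
\]

\[
p(x_i) = \sum_{k} p(z_i = k)\, p(x_i \mid z_i = k) = \pi_0\, \mathcal{N}(x_i \mid \mu_0, \sigma_0^2) + \pi_1\, \mathcal{N}(x_i \mid \mu_1, \sigma_1^2)
\]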

Notice that these gaussians seem well separated. The separation of gaussians impacts how your sampler will perform.

with pm.Model() as ofmodel:
    p1 = pm.Uniform('p', 0, 1)
    p2 = 1 - p1
    p = pt.stack([p1, p2])
    assignment = pm.Categorical("assignment", p, 
                                shape=ofdata.shape[0])
    sds = pm.Uniform("sds", 0, 40, shape=2)
    centers = pm.Normal("centers", 
                        mu=np.array([50, 80]), 
                        sigma=np.array([20, 20]), 
                        shape=2)
    
    # and to combine it with the observations:
    observations = pm.Normal("obs", mu=centers[assignment], sigma=sds[assignment], observed=ofdata.waiting)
with ofmodel:
    oftrace = pm.sample(10000, target_accept=0.95)
Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>NUTS: [p, sds, centers]
>BinaryGibbsMetropolis: [assignment]
Sampling 4 chains for 1_000 tune and 10_000 draw iterations (4_000 + 40_000 draws total) took 77 seconds.
pm.model_to_graphviz(ofmodel)

az.plot_trace(oftrace, var_names=["p", "centers", "sds"]);

az.summary(oftrace, var_names=["p", "centers", "sds"])
              mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
p            0.362  0.031   0.302    0.421      0.000    0.000   23423.0   27215.0    1.0
centers[0]  54.639  0.744  53.222   56.014      0.006    0.004   14494.0   19863.0    1.0
centers[1]  80.080  0.522  79.104   81.074      0.004    0.003   18198.0   21226.0    1.0
sds[0]       6.025  0.591   4.919    7.112      0.005    0.004   13170.0   17820.0    1.0
sds[1]       5.951  0.421   5.172    6.737      0.004    0.002   14442.0   20604.0    1.0
az.plot_autocorr(oftrace, var_names=['p', 'centers', 'sds']);

oftrace.posterior["centers"].mean(dim=["chain", "draw"]).values
array([54.63938979, 80.08034908])

We can visualize the two clusters, suitably scaled by the category-belonging probability, by plugging in the posterior means. Note that this misses the smearing from posterior uncertainty that goes into the posterior predictive.

from scipy.stats import norm
x = np.linspace(20, 120, 500)
# for pretty colors later in the book.
colors = ["#348ABD", "#A60628"] if oftrace.posterior["centers"].values[0, -1, 0] > oftrace.posterior["centers"].values[0, -1, 1] \
    else ["#A60628", "#348ABD"]

posterior_center_means = oftrace.posterior["centers"].mean(dim=["chain", "draw"]).values
posterior_std_means = oftrace.posterior["sds"].mean(dim=["chain", "draw"]).values
posterior_p_mean = oftrace.posterior["p"].mean(dim=["chain", "draw"]).values.item()

plt.hist(ofdata.waiting, bins=20, histtype="step", density=True, color="k",
     lw=2, label="histogram of data")
y = posterior_p_mean * norm.pdf(x, loc=posterior_center_means[0],
                                scale=posterior_std_means[0])
plt.plot(x, y, label="Cluster 0 (using posterior-mean parameters)", lw=3)
plt.fill_between(x, y, color=colors[1], alpha=0.3)

y = (1 - posterior_p_mean) * norm.pdf(x, loc=posterior_center_means[1],
                                      scale=posterior_std_means[1])
plt.plot(x, y, label="Cluster 1 (using posterior-mean parameters)", lw=3)
plt.fill_between(x, y, color=colors[0], alpha=0.3)

plt.legend(loc="upper left")
plt.title("Visualizing Clusters using posterior-mean parameters");
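
The plot above plugs in posterior means. To see what the "smearing" amounts to, here is a minimal sketch that instead averages the mixture density over posterior draws, which is a Monte Carlo estimate of the posterior predictive density. The draws below are hypothetical stand-ins generated with NumPy, chosen only to resemble the summary table; with the real trace one would flatten `oftrace.posterior["p"]`, `["centers"]` and `["sds"]` over chain and draw. The helper names are ours.

```python
import numpy as np

def normal_pdf(x, mu, sd):
    """Gaussian density; broadcasts over x, mu and sd."""
    return np.exp(-0.5 * ((x - mu) / sd) ** 2) / (sd * np.sqrt(2.0 * np.pi))

def mixture_density_per_draw(x, p, centers, sds):
    """Two-component mixture density on the grid x, one row per posterior draw.

    p has shape (n_draws,); centers and sds have shape (n_draws, 2).
    Returns an array of shape (n_draws, len(x)).
    """
    w = np.stack([p, 1.0 - p], axis=1)  # weights, shape (n_draws, 2)
    comp = normal_pdf(x[None, None, :], centers[:, :, None], sds[:, :, None])
    return (w[:, :, None] * comp).sum(axis=1)

# Hypothetical stand-in draws (NOT the real trace).
rng = np.random.default_rng(0)
n_draws = 2000
p_draws = rng.normal(0.36, 0.03, n_draws).clip(0.01, 0.99)
center_draws = np.column_stack([rng.normal(54.6, 0.7, n_draws),
                                rng.normal(80.1, 0.5, n_draws)])
sd_draws = np.column_stack([rng.normal(6.0, 0.6, n_draws),
                            rng.normal(6.0, 0.4, n_draws)]).clip(0.5, None)

x = np.linspace(20, 120, 500)

# Posterior predictive estimate: average the density over the draws.
pp_density = mixture_density_per_draw(x, p_draws, center_draws, sd_draws).mean(axis=0)

# Plug-in density: evaluate once at the mean of the draws.
plugin_density = mixture_density_per_draw(
    x,
    p_draws.mean(keepdims=True),
    center_draws.mean(axis=0, keepdims=True),
    sd_draws.mean(axis=0, keepdims=True),
)[0]
```

Both curves integrate to one over the grid. The averaged curve is generally a little broader than the plug-in curve, because it also carries the uncertainty in the parameters.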

A tetchy 3 Gaussian Model

Let us set up our data. Our analysis here follows that of https://colindcarroll.com/2018/07/20/why-im-excited-about-pymc3-v3.5.0/ , and we have chosen 3 gaussians reasonably close to each other to show the problems that arise!

mu_true = np.array([-2, 0, 2])
sigma_true = np.array([1, 1, 1])
lambda_true = np.array([1/3, 1/3, 1/3])
n = 100
from scipy.stats import multinomial
# Simulate from each component according to the mixing proportions lambda_true
z = multinomial.rvs(1, lambda_true, size=n)  # one-hot rows, shape (n, 3)
components = z.argmax(axis=1)                # component index of each draw
data = np.random.normal(mu_true[components], sigma_true[components])
sns.histplot(data, bins=50, kde=True);

np.savetxt("data/3gv2.dat", data)
with pm.Model() as mof:
    #p = pm.Dirichlet('p', a=np.array([1., 1., 1.]), shape=3)
    p=[1/3, 1/3, 1/3]

    # cluster centers
    means = pm.Normal('means', mu=0, sigma=10, shape=3)


    #sds = pm.HalfCauchy('sds', 5, shape=3)
    sds = np.array([1., 1., 1.])
    
    # latent cluster of each observation
    category = pm.Categorical('category',
                              p=p,
                              shape=data.shape[0])

    # likelihood for each observed value
    points = pm.Normal('obs',
                       mu=means[category],
                       sigma=1., #sds[category],
                       observed=data)
with mof:
    trace_mof = pm.sample(10000)
Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>NUTS: [means]
>CategoricalGibbsMetropolis: [category]
Sampling 4 chains for 1_000 tune and 10_000 draw iterations (4_000 + 40_000 draws total) took 27 seconds.
There were 4 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
az.plot_trace(trace_mof, var_names=["means"], combined=True);

az.plot_autocorr(trace_mof, var_names=['means']);

Problems with clusters and sampling

Some of the traces seem OK, but the autocorrelation is quite bad, and there is label-switching. This reflects two major problems with using MCMC for clustering: first, the lack of parameter identifiability (the so-called label-switching problem), and second, the multimodality of the posteriors.

AND THIS IS WITHOUT MODELING \(p\) OR \(\sigma\). It gets much worse otherwise! (Sampling would be easier if the Gaussians were widely separated.)

We have seen non-identifiability before. Switching labels on the means and z’s, for example, does not change the likelihood. The problem with this is that cluster parameters cannot be compared across chains: a parameter that describes one cluster in one chain may well describe a different cluster in another chain. Even within a single chain, indices might swap, leading to a telltale back and forth in the traces for long chains or for data that is not cleanly separated.
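
A minimal NumPy sketch of this invariance follows. It evaluates the mixture log-likelihood with the z’s summed out, then again after relabelling the clusters. The data, the weights and the helper name `mixture_loglik` are ours and purely illustrative.

```python
import numpy as np

def mixture_loglik(x, weights, means, sds):
    """Log-likelihood of x under a Gaussian mixture, with z summed out."""
    x = np.asarray(x, dtype=float)[:, None]
    # log(weight_k * Normal(x_i | mean_k, sd_k)) for every i and k, shape (n, K)
    log_comp = (np.log(weights) - np.log(sds) - 0.5 * np.log(2.0 * np.pi)
                - 0.5 * ((x - means) / sds) ** 2)
    # numerically stable log-sum-exp over the components
    m = log_comp.max(axis=1, keepdims=True)
    return float(np.sum(m[:, 0] + np.log(np.exp(log_comp - m).sum(axis=1))))

# Illustrative synthetic data, not the data set used in this post.
rng = np.random.default_rng(1)
x = np.concatenate([rng.normal(-2, 1, 50), rng.normal(0, 1, 50), rng.normal(2, 1, 50)])

weights = np.array([0.2, 0.3, 0.5])
means = np.array([-2.0, 0.0, 2.0])
sds = np.array([1.0, 1.0, 1.0])

perm = np.array([2, 0, 1])  # relabel the clusters

ll = mixture_loglik(x, weights, means, sds)
ll_perm = mixture_loglik(x, weights[perm], means[perm], sds[perm])
print(np.isclose(ll, ll_perm))
```

The two values agree up to floating-point rounding. With 3 components there are 3! = 6 such relabellings, so under an exchangeable prior every posterior mode comes with 6 equivalent copies.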

Also, the (joint) posteriors can be highly multimodal. One source of multimodality is the non-identifiability itself, though even without identifiability issues the posteriors can be highly multimodal.

To quote the Stan manual:

> Bayesian inference fails in cases of high multimodality because there is no way to visit all of the modes in the posterior in appropriate proportions and thus no way to evaluate integrals involved in posterior predictive inference. In light of these two problems, the advice often given in fitting clustering models is to try many different initializations and select the sample with the highest overall probability. It is also popular to use optimization-based point estimators such as expectation maximization or variational Bayes, which can be much more efficient than sampling-based approaches.

Some mitigation via ordering in PyMC

Constraining the means to be ordered removes the label-switching. But this is not a panacea. Sampling is still very hard.
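
To give a feel for what an ordering transform does, here is a NumPy sketch of one common construction: keep the first coordinate free and add strictly positive increments for the rest. This is the construction Stan documents for its ordered vectors. We assume, without having verified it, that the `ordered` transform used in the model below works along the same lines; the function names are ours.

```python
import numpy as np

def to_ordered(y):
    """Map an unconstrained vector y to a strictly increasing vector."""
    y = np.asarray(y, dtype=float)
    # first element is free; the others become positive increments
    increments = np.concatenate([y[:1], np.exp(y[1:])])
    return np.cumsum(increments)

def from_ordered(x):
    """Inverse map: recover the unconstrained vector from an increasing one."""
    x = np.asarray(x, dtype=float)
    return np.concatenate([x[:1], np.log(np.diff(x))])

y = np.array([0.3, -1.2, 2.0])
x_ordered = to_ordered(y)
print(x_ordered, from_ordered(x_ordered))
```

The sampler then moves freely in the unconstrained space, while every point it visits maps back to means that are already sorted, so only one of the relabelled copies of the posterior is reachable.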

with pm.Model() as mof2:
    
    p = [1/3, 1/3, 1/3]

    # cluster centers
    means = pm.Normal('means', mu=0, sigma=10, shape=3,
                  transform=pm.distributions.transforms.ordered,
                  initval=np.array([-1, 0, 1]))


                                         
    # measurement error
    #sds = pm.Uniform('sds', lower=0, upper=20, shape=3)

    # latent cluster of each observation
    category = pm.Categorical('category',
                              p=p,
                              shape=data.shape[0])

    # likelihood for each observed value
    points = pm.Normal('obs',
                       mu=means[category],
                       sigma=1., #sds[category],
                       observed=data)
with mof2:
    trace_mof2 = pm.sample(10000)
Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>NUTS: [means]
>CategoricalGibbsMetropolis: [category]
Sampling 4 chains for 1_000 tune and 10_000 draw iterations (4_000 + 40_000 draws total) took 33 seconds.
There were 906 divergences after tuning. Increase `target_accept` or reparameterize.
az.plot_trace(trace_mof2, var_names=["means"], combined=True);

az.plot_autocorr(trace_mof2, var_names=['means']);

Full sampling is horrible, even with potentials

Now let's put a Dirichlet prior (and this is a strongly centering Dirichlet) on the mixing probabilities.

from scipy.stats import dirichlet
ds = dirichlet(alpha=[10,10,10]).rvs(1000)
"""
Visualize points on the 3-simplex (eg, the parameters of a
3-dimensional multinomial distributions) as a scatter plot 
contained within a 2D triangle.
David Andrzejewski (david.andrzej@gmail.com)
"""
import numpy as NP
import matplotlib.pyplot as P
import matplotlib.ticker as MT
import matplotlib.lines as L
import matplotlib.cm as CM
import matplotlib.colors as C
import matplotlib.patches as PA

def plotSimplex(points, fig=None, 
                vertexlabels=['1','2','3'],
                **kwargs):
    """
    Plot Nx3 points array on the 3-simplex 
    (with optionally labeled vertices) 
    
    kwargs will be passed along directly to matplotlib.pyplot.scatter    
    Returns Figure, caller must .show()
    """
    if fig is None:
        fig = P.figure()
    # Draw the triangle
    l1 = L.Line2D([0, 0.5, 1.0, 0], # xcoords
                  [0, NP.sqrt(3) / 2, 0, 0], # ycoords
                  color='k')
    fig.gca().add_line(l1)
    fig.gca().xaxis.set_major_locator(MT.NullLocator())
    fig.gca().yaxis.set_major_locator(MT.NullLocator())
    # Draw vertex labels
    fig.gca().text(-0.05, -0.05, vertexlabels[0])
    fig.gca().text(1.05, -0.05, vertexlabels[1])
    fig.gca().text(0.5, NP.sqrt(3) / 2 + 0.05, vertexlabels[2])
    # Project and draw the actual points
    projected = projectSimplex(points)
    fig.gca().scatter(projected[:,0], projected[:,1], **kwargs)
    # Leave some buffer around the triangle for vertex labels
    fig.gca().set_xlim(-0.2, 1.2)
    fig.gca().set_ylim(-0.2, 1.2)

    return fig    

def projectSimplex(points):
    """ 
    Project probabilities on the 3-simplex to a 2D triangle
    
    N points are given as N x 3 array
    """
    # Convert points one at a time
    tripts = NP.zeros((points.shape[0],2))
    for idx in range(points.shape[0]):
        # Init to triangle centroid
        x = 1.0 / 2
        y = 1.0 / (2 * NP.sqrt(3))
        # Vector 1 - bisect out of lower left vertex 
        p1 = points[idx, 0]
        x = x - (1.0 / NP.sqrt(3)) * p1 * NP.cos(NP.pi / 6)
        y = y - (1.0 / NP.sqrt(3)) * p1 * NP.sin(NP.pi / 6)
        # Vector 2 - bisect out of lower right vertex  
        p2 = points[idx, 1]  
        x = x + (1.0 / NP.sqrt(3)) * p2 * NP.cos(NP.pi / 6)
        y = y - (1.0 / NP.sqrt(3)) * p2 * NP.sin(NP.pi / 6)        
        # Vector 3 - bisect out of top vertex
        p3 = points[idx, 2]
        y = y + (1.0 / NP.sqrt(3) * p3)
      
        tripts[idx,:] = (x,y)

    return tripts
plotSimplex(ds, s=20);
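
How strongly this prior centers can also be checked numerically. The sketch below uses the standard closed-form mean and variance of a Dirichlet distribution to compare `alpha=[10, 10, 10]` with a flat `alpha=[1, 1, 1]`. The helper name is ours.

```python
import numpy as np

def dirichlet_mean_sd(alpha):
    """Closed-form mean and standard deviation of each Dirichlet component."""
    alpha = np.asarray(alpha, dtype=float)
    a0 = alpha.sum()
    mean = alpha / a0
    var = alpha * (a0 - alpha) / (a0 ** 2 * (a0 + 1.0))
    return mean, np.sqrt(var)

mean_flat, sd_flat = dirichlet_mean_sd([1, 1, 1])
mean_strong, sd_strong = dirichlet_mean_sd([10, 10, 10])
print(mean_flat, sd_flat)
print(mean_strong, sd_strong)
```

Both priors have mean 1/3 for every component, but the standard deviation drops from about 0.24 for the flat prior to about 0.085 for `alpha=[10, 10, 10]`, which is why the sampled points cluster near the centre of the triangle.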

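A Potential is a term added to the model's log-probability that is not part of the likelihood; here it enforces a constraint by setting the probability to 0 (a log-probability of \(-\infty\)) whenever the constraint is violated. We use it to give each cluster some membership, by requiring every mixing probability to be at least 0.1. The ordering of the means, which removes the non-identifiability problem, is again imposed through the ordered transform. See below for how it's used.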
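
As a rough picture of what such a term does to the target density, here is a plain NumPy sketch that adds a hard-constraint term to an unnormalized Dirichlet log-density. It only mimics the arithmetic and is not how PyMC evaluates the model internally. The function names and the two test points are ours.

```python
import numpy as np

def min_weight_log_potential(p, floor=0.1):
    """Log-probability term: 0 if every weight is at least `floor`, else -inf."""
    return 0.0 if np.min(p) >= floor else -np.inf

def log_dirichlet_unnormalized(p, alpha):
    """Dirichlet log-density up to an additive constant."""
    return float(np.sum((np.asarray(alpha) - 1.0) * np.log(p)))

alpha = np.array([10.0, 10.0, 10.0])
ok = np.array([0.3, 0.3, 0.4])          # every cluster has reasonable weight
starved = np.array([0.05, 0.45, 0.50])  # first cluster falls below the floor

logp_ok = log_dirichlet_unnormalized(ok, alpha) + min_weight_log_potential(ok)
logp_starved = (log_dirichlet_unnormalized(starved, alpha)
                + min_weight_log_potential(starved))
print(logp_ok, logp_starved)
```

The first point keeps a finite log-probability, while the second is assigned \(-\infty\), so a proposal landing there would never be accepted.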

The sampler below has a lot of problems.

with pm.Model() as mofb:
    p = pm.Dirichlet('p', a=np.array([10., 10., 10.]), shape=3)
    # ensure all clusters have some points
    p_min_potential = pm.Potential('p_min_potential', pt.switch(pt.min(p) < .1, -np.inf, 0))
    # cluster centers
    means = pm.Normal('means', mu=0, sigma=10, shape=3, transform=pm.distributions.transforms.ordered,
                  initval=np.array([-1, 0, 1]))

    category = pm.Categorical('category',
                              p=p,
                              shape=data.shape[0])

    # likelihood for each observed value
    points = pm.Normal('obs',
                       mu=means[category],
                       sigma=1., #sds[category],
                       observed=data)
with mofb:
    trace_mofb = pm.sample(10000, target_accept=0.95)
Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>NUTS: [p, means]
>CategoricalGibbsMetropolis: [category]
Sampling 4 chains for 1_000 tune and 10_000 draw iterations (4_000 + 40_000 draws total) took 60 seconds.
There were 33 divergences after tuning. Increase `target_accept` or reparameterize.
az.plot_trace(trace_mofb, var_names=["means", "p"], combined=True);

az.summary(trace_mofb, var_names=["means", "p"])
            mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
means[0]  -1.713  0.288  -2.268   -1.195      0.005    0.003    2970.0    6546.0    1.0
means[1]  -0.248  0.435  -1.105    0.552      0.011    0.006    1694.0    2841.0    1.0
means[2]   1.967  0.320   1.373    2.562      0.005    0.002    4139.0   10717.0    1.0
p[0]       0.351  0.079   0.204    0.501      0.002    0.001    2244.0    5286.0    1.0
p[1]       0.355  0.078   0.215    0.507      0.001    0.001    2946.0    7071.0    1.0
p[2]       0.293  0.061   0.180    0.408      0.001    0.000    3595.0    8348.0    1.0

Making Problems go away

Many of these problems go away when identifiability improves, as it does with well-separated Gaussians. But that means changing the data. If we want any further improvement on this data, we are going to have to stop sampling so many discrete categoricals. And for that we will need a marginalization trick.
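
In outline, and in our own notation (\(\pi_k\), \(\mu_k\), \(\sigma_k\) for the weight, mean and standard deviation of component \(k\)), marginalization sums each discrete \(z_i\) out of the likelihood so that only continuous parameters remain to be sampled:

\[
p(x_i \mid \pi, \mu, \sigma) = \sum_{k} p(z_i = k \mid \pi)\, p(x_i \mid z_i = k, \mu, \sigma) = \sum_{k} \pi_k\, \mathcal{N}(x_i \mid \mu_k, \sigma_k^2)
\]

\[
\log p(x \mid \pi, \mu, \sigma) = \sum_{i=1}^{n} \log \sum_{k} \pi_k\, \mathcal{N}(x_i \mid \mu_k, \sigma_k^2)
\]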