Poisson Regression — Model Comparison and Hierarchical Overdispersion

WAIC model comparison, ensemble averaging, and varying-intercepts for overdispersed counts.

bayesian, regression, hierarchical models
We return to the oceanic tools dataset to illustrate model comparison using WAIC, model averaging with Akaike weights, counterfactual posterior predictives, and fighting overdispersion with a hierarchical varying-intercepts Poisson regression.
Published

May 28, 2025

We go back to our island tools data set to illustrate model comparison using WAIC, model averaging with Akaike weights, counterfactual posterior predictives, and a hierarchical varying-intercepts Poisson regression for fighting overdispersion.

%matplotlib inline
import numpy as np
import scipy as sp
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import pandas as pd
pd.set_option('display.width', 500)
pd.set_option('display.max_columns', 100)
pd.set_option('display.notebook_repr_html', True)
import seaborn as sns
sns.set_style("whitegrid")
sns.set_context("poster")
import pymc as pm
import arviz as az
import pytensor.tensor as pt
df=pd.read_csv("data/islands.csv", sep=';')
df
culture population contact total_tools mean_TU
0 Malekula 1100 low 13 3.2
1 Tikopia 1500 low 22 4.7
2 Santa Cruz 3600 low 24 4.0
3 Yap 4791 high 43 5.0
4 Lau Fiji 7400 high 33 5.0
5 Trobriand 8000 high 19 4.0
6 Chuuk 9200 high 40 3.8
7 Manus 13000 low 28 6.6
8 Tonga 17500 high 55 5.4
9 Hawaii 275000 low 71 6.6
df['logpop']=np.log(df.population)
df['clevel']=(df.contact=='high')*1
df
culture population contact total_tools mean_TU logpop clevel
0 Malekula 1100 low 13 3.2 7.003065 0
1 Tikopia 1500 low 22 4.7 7.313220 0
2 Santa Cruz 3600 low 24 4.0 8.188689 0
3 Yap 4791 high 43 5.0 8.474494 1
4 Lau Fiji 7400 high 33 5.0 8.909235 1
5 Trobriand 8000 high 19 4.0 8.987197 1
6 Chuuk 9200 high 40 3.8 9.126959 1
7 Manus 13000 low 28 6.6 9.472705 0
8 Tonga 17500 high 55 5.4 9.769956 1
9 Hawaii 275000 low 71 6.6 12.524526 0
def postscat(idata, thevars):
    d={}
    for v in thevars:
        d[v] = idata.posterior[v].values.flatten()
    df = pd.DataFrame.from_dict(d)
    g = sns.pairplot(df, diag_kind="kde", plot_kws={'s':10})
    for i, j in zip(*np.triu_indices_from(g.axes, 1)):
        g.axes[i, j].set_visible(False)
    return g

Centered Model

As usual, centering the log-population removes the correlation between the intercept and the slopes and makes sampling well behaved:

df['logpop_c'] = df.logpop - df.logpop.mean()
with pm.Model() as m1c:
    betap = pm.Normal("betap", 0, 1)
    betac = pm.Normal("betac", 0, 1)
    betapc = pm.Normal("betapc", 0, 1)
    alpha = pm.Normal("alpha", 0, 100)
    loglam = alpha + betap*df.logpop_c + betac*df.clevel + betapc*df.clevel*df.logpop_c
    y = pm.Poisson("ntools", mu=pt.exp(loglam), observed=df.total_tools)
with m1c:
    trace1c = pm.sample(5000, tune=1000, idata_kwargs={"log_likelihood": True})
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [betap, betac, betapc, alpha]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 8 seconds.
az.plot_trace(trace1c);

az.plot_autocorr(trace1c);

az.ess(trace1c)
<xarray.Dataset> Size: 32B
Dimensions:  ()
Data variables:
    betap    float64 8B 1.369e+04
    betac    float64 8B 1.156e+04
    betapc   float64 8B 1.537e+04
    alpha    float64 8B 1.158e+04
Attributes:
    created_at:                 2026-03-07T16:32:55.888467+00:00
    arviz_version:              0.23.4
    inference_library:          pymc
    inference_library_version:  5.28.1
    sampling_time:              8.113940954208374
    tuning_steps:               1000
postscat(trace1c, ["betap", "betac", "betapc", "alpha"]);

az.plot_posterior(trace1c);

Model comparison for interaction significance

This is an example of feature selection: we want to decide whether to keep the interaction term, that is, whether the interaction is significant. We’ll use model comparison to answer this.

We can see some summary stats from this model:

dfsum=az.summary(trace1c)
dfsum
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
betap 0.263 0.035 0.196 0.329 0.000 0.000 13689.0 14213.0 1.0
betac 0.286 0.119 0.074 0.517 0.001 0.001 11563.0 12625.0 1.0
betapc 0.065 0.170 -0.249 0.395 0.001 0.001 15371.0 13990.0 1.0
alpha 3.311 0.091 3.143 3.483 0.001 0.001 11577.0 12893.0 1.0
# pm.dic is removed in modern pymc; use WAIC or LOO instead
az.waic(trace1c)
/Users/rahul/Library/Caches/uv/archive-v0/aTiHGxSE8gD8G3bEQyxJO/lib/python3.14/site-packages/arviz/stats/stats.py:1652: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
  warnings.warn(
Computed from 20000 posterior samples and 10 observations log-likelihood matrix.

          Estimate       SE
elpd_waic   -41.98     6.06
p_waic        6.98        -

There has been a warning during the calculation. Please check the results.

Sampling from multiple different centered models

(A) Our complete model

(B) A model with no interaction

with pm.Model() as m2c_nopc:
    betap = pm.Normal("betap", 0, 1)
    betac = pm.Normal("betac", 0, 1)
    alpha = pm.Normal("alpha", 0, 100)
    loglam = alpha + betap*df.logpop_c + betac*df.clevel
    y = pm.Poisson("ntools", mu=pt.exp(loglam), observed=df.total_tools)
    trace2c_nopc = pm.sample(5000, tune=1000, idata_kwargs={"log_likelihood": True})
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [betap, betac, alpha]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 6 seconds.

(C) A model with no contact term

with pm.Model() as m2c_onlyp:
    betap = pm.Normal("betap", 0, 1)
    alpha = pm.Normal("alpha", 0, 100)
    loglam = alpha + betap*df.logpop_c
    y = pm.Poisson("ntools", mu=pt.exp(loglam), observed=df.total_tools)
    trace2c_onlyp = pm.sample(5000, tune=1000, idata_kwargs={"log_likelihood": True})
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [betap, alpha]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 5 seconds.

(D) A model with only the contact term

with pm.Model() as m2c_onlyc:
    betac = pm.Normal("betac", 0, 1)
    alpha = pm.Normal("alpha", 0, 100)
    loglam = alpha +  betac*df.clevel
    y = pm.Poisson("ntools", mu=pt.exp(loglam), observed=df.total_tools)
    trace2c_onlyc = pm.sample(5000, tune=1000, idata_kwargs={"log_likelihood": True})
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [betac, alpha]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 6 seconds.

(E) A model with only the intercept.

with pm.Model() as m2c_onlyic:
    alpha = pm.Normal("alpha", 0, 100)
    loglam = alpha
    y = pm.Poisson("ntools", mu=pt.exp(loglam), observed=df.total_tools)
    trace2c_onlyic = pm.sample(5000, tune=1000, idata_kwargs={"log_likelihood": True})
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.

We create a dictionary of these models and their traces, so that we can keep track of the names as well.

modeldict = {
    "m1c": {"idata": trace1c, "model": m1c},
    "m2c_nopc": {"idata": trace2c_nopc, "model": m2c_nopc},
    "m2c_onlyp": {"idata": trace2c_onlyp, "model": m2c_onlyp},
    "m2c_onlyc": {"idata": trace2c_onlyc, "model": m2c_onlyc},
    "m2c_onlyic": {"idata": trace2c_onlyic, "model": m2c_onlyic},
}
# Compute log_likelihood for each model so az.compare can use it
compare_dict = {}
for name, d in modeldict.items():
    idata = d["idata"]
    model = d["model"]
    if not hasattr(idata, "log_likelihood"):
        pm.compute_log_likelihood(idata, model=model)
    compare_dict[name] = idata

Comparing the models using WAIC

Finally we use az.compare to create a dataframe of comparisons.

comparedf = az.compare(compare_dict, ic="waic", method="pseudo-BMA")
comparedf.head()
/Users/rahul/Library/Caches/uv/archive-v0/aTiHGxSE8gD8G3bEQyxJO/lib/python3.14/site-packages/arviz/stats/stats.py:1652: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
  warnings.warn(
rank elpd_waic p_waic elpd_diff weight se dse warning scale
m2c_nopc 0 -39.493347 4.195777 0.000000 8.694607e-01 5.528580 0.000000 True log
m1c 1 -41.976041 6.978613 2.482694 7.261555e-02 6.064680 1.834442 True log
m2c_onlyp 2 -42.202093 3.730087 2.708746 5.792372e-02 4.469968 3.961737 True log
m2c_onlyic 3 -70.740050 8.273443 31.246703 2.338725e-14 15.812725 16.368695 True log
m2c_onlyc 4 -75.175479 16.730025 35.682132 2.771385e-16 22.393098 22.251583 True log

From McElreath, here is how to read this table. Note that ArviZ reports everything on the elpd (log) scale, where elpd_waic = −WAIC/2, so larger values are better:

  1. elpd_waic corresponds to McElreath’s WAIC for each model (up to the factor of −2). Higher elpd_waic (i.e. smaller WAIC) indicates better estimated out-of-sample predictive accuracy.
  2. p_waic (pWAIC) is the estimated effective number of parameters. This provides a clue as to how flexible each model is in fitting the sample.
  3. elpd_diff plays the role of McElreath’s dWAIC: it is the difference between the top model’s elpd and each model’s elpd (dWAIC is just twice this). Since only relative predictive accuracy matters, this column shows the differences in relative fashion.
  4. weight is the Akaike weight for each model. These values are transformed information criterion values; I’ll explain them below.
  5. se is the standard error of the elpd estimate. WAIC is an estimate, and provided the sample size N is large enough, its uncertainty will be well approximated by its standard error. So this value isn’t necessarily very precise, but it does provide a check against overconfidence in differences between WAIC values.
  6. dse is the standard error of the difference in elpd between each model and the top-ranked model; it is zero for the top model itself. A quick use of it is sketched just below.
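
As a crude rule of thumb, we can ask whether the top-ranked model is decisively better than the full model by comparing the elpd difference to roughly two of its standard errors (this uses comparedf from above; the exact numbers depend on the run):

# Crude check: is m2c_nopc decisively better than the full model m1c?
diff = comparedf.loc["m1c", "elpd_diff"]
dse = comparedf.loc["m1c", "dse"]
print(diff - 2*dse, diff + 2*dse)

With the numbers in the table above this interval is roughly (−1.2, 6.2); it includes zero, so WAIC does not decisively separate the no-interaction model from the full model.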

The weight for a model i in a set of m models is given by:

\[w_i = \frac{\exp\left(-\frac{1}{2}\,dWAIC_i\right)}{\sum_j \exp\left(-\frac{1}{2}\,dWAIC_j\right)}\]

The Akaike weight formula might look rather odd, but really all it is doing is putting WAIC on a probability scale, so it just undoes the multiplication by −2 and then exponentiates to reverse the log transformation. Then it standardizes by dividing by the total. So each weight will be a number from 0 to 1, and the weights together always sum to 1. Now larger values are better.
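
Since ArviZ reports elpd (= −WAIC/2), we have \(\exp(-\tfrac{1}{2}dWAIC_i) = \exp(elpd_i - elpd_{max})\), so the same weights can be recomputed by hand from the elpd_waic column. This minimal check should essentially reproduce the weight column of the pseudo-BMA comparison above:

# Recompute pseudo-BMA (Akaike-style) weights by hand from the elpd_waic column
elpd = comparedf["elpd_waic"]
rel = np.exp(elpd - elpd.max())   # exp(-dWAIC/2) for each model
print(rel / rel.sum())            # normalized weights; should match comparedf["weight"]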

But what do these weights mean?

Akaike’s interpretation:

A model’s weight is an estimate of the probability that the model will make the best predictions on new data, conditional on the set of models considered…the Akaike weights are analogous to posterior probabilities of models, conditional on expected future data.

So you can heuristically read each weight as an estimated probability that each model will perform best on future data. In simulation at least, interpreting weights in this way turns out to be appropriate. (McElreath 199-200)

We can make visual comparison plots in the style of McElreath’s book. We can see that all the weight is in the no-interaction, full, and only log(population) models.

az.plot_compare(comparedf)

Comparing for non-centered models

We can redo the comparison for non-centered models.

with pm.Model() as m1:
    betap = pm.Normal("betap", 0, 1)
    betac = pm.Normal("betac", 0, 1)
    betapc = pm.Normal("betapc", 0, 1)
    alpha = pm.Normal("alpha", 0, 100)
    loglam = alpha + betap*df.logpop + betac*df.clevel + betapc*df.clevel*df.logpop
    y = pm.Poisson("ntools", mu=pt.exp(loglam), observed=df.total_tools)
    trace1 = pm.sample(10000, tune=2000, idata_kwargs={"log_likelihood": True})
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [betap, betac, betapc, alpha]

Sampling 4 chains for 2_000 tune and 10_000 draw iterations (8_000 + 40_000 draws total) took 17 seconds.
with pm.Model() as m2_onlyp:
    betap = pm.Normal("betap", 0, 1)
    alpha = pm.Normal("alpha", 0, 100)
    loglam = alpha + betap*df.logpop
    y = pm.Poisson("ntools", mu=pt.exp(loglam), observed=df.total_tools)
    trace2_onlyp = pm.sample(10000, tune=2000, idata_kwargs={"log_likelihood": True})
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [betap, alpha]

Sampling 4 chains for 2_000 tune and 10_000 draw iterations (8_000 + 40_000 draws total) took 6 seconds.
with pm.Model() as m2_nopc:
    betap = pm.Normal("betap", 0, 1)
    betac = pm.Normal("betac", 0, 1)
    alpha = pm.Normal("alpha", 0, 100)
    loglam = alpha + betap*df.logpop + betac*df.clevel
    y = pm.Poisson("ntools", mu=pt.exp(loglam), observed=df.total_tools)
    trace2_nopc = pm.sample(10000, tune=2000, idata_kwargs={"log_likelihood": True})
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [betap, betac, alpha]

Sampling 4 chains for 2_000 tune and 10_000 draw iterations (8_000 + 40_000 draws total) took 9 seconds.
modeldict2 = {
    "m1": {"idata": trace1, "model": m1},
    "m2_nopc": {"idata": trace2_nopc, "model": m2_nopc},
    "m2_onlyp": {"idata": trace2_onlyp, "model": m2_onlyp},
    "m2_onlyc": {"idata": trace2c_onlyc, "model": m2c_onlyc},
    "m2_onlyic": {"idata": trace2c_onlyic, "model": m2c_onlyic},
}
compare_dict2 = {}
for name, d in modeldict2.items():
    idata = d["idata"]
    model = d["model"]
    if not hasattr(idata, "log_likelihood"):
        pm.compute_log_likelihood(idata, model=model)
    compare_dict2[name] = idata
comparedf2 = az.compare(compare_dict2, ic="waic", method="pseudo-BMA")
comparedf2
/Users/rahul/Library/Caches/uv/archive-v0/aTiHGxSE8gD8G3bEQyxJO/lib/python3.14/site-packages/arviz/stats/stats.py:1652: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
  warnings.warn(
rank elpd_waic p_waic elpd_diff weight se dse warning scale
m2_nopc 0 -39.574344 4.268525 0.000000 6.002391e-01 5.533746 0.000000 True log
m1 1 -40.096776 4.906833 0.522432 3.559876e-01 5.627265 0.579995 True log
m2_onlyp 2 -42.192647 3.725374 2.618303 4.377334e-02 4.460087 3.965167 True log
m2_onlyic 3 -70.740050 8.273443 31.165706 1.750773e-14 15.812725 16.357640 True log
m2_onlyc 4 -75.175479 16.730025 35.601135 2.074663e-16 22.393098 22.238972 True log

What we find now is that the full (interaction) model gets much more weight.

az.plot_compare(comparedf2)

In either the centered or non-centered case, the top model excludes the interaction, but the second-best model includes it. In the centered case, the no-interaction model gets most of the weight, while in the non-centered case the weight is shared more equally.

When the interaction model carries that much weight, it is probably overfitting. So in a sense centering even helps with overfitting: by removing the correlation between the predictors, it prevents the interaction term from borrowing spurious weight, and the comparison clearly prefers the no-interaction model.
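
As a quick sanity check on the correlation claim (illustrative only; exact values depend on the run), we can compare the posterior correlation between the population slope and the interaction coefficient in the non-centered and centered traces; the non-centered one should be far stronger in magnitude:

# Posterior correlation between betap and betapc, non-centered vs centered
for name, t in [("non-centered", trace1), ("centered", trace1c)]:
    bp = t.posterior["betap"].values.flatten()
    bpc = t.posterior["betapc"].values.flatten()
    print(name, np.corrcoef(bp, bpc)[0, 1])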

Computing the (counterfactual) posterior predictive for checking

We now write some code to compute the posterior predictive at arbitrary points, without having to use pytensor shared variables and sample_posterior_predictive, in two counterfactual situations: low contact and high contact. Since some of our models omit certain terms, we substitute arrays of zeros for the missing coefficients, so that one general function works for all the models.

def trace_or_zero(idata, name):
    if name in idata.posterior:
        return idata.posterior[name].values.flatten()
    else:
        nsamples = idata.posterior.sizes["chain"] * idata.posterior.sizes["draw"]
        return np.zeros(nsamples)
# Number of total samples = chains * draws
nsamples = trace1c.posterior.sizes["chain"] * trace1c.posterior.sizes["draw"]
nsamples, trace1c.posterior['alpha'].values.flatten().shape[0]
(20000, 20000)
from scipy.stats import poisson
def compute_pp(lpgrid, idata, contact=0):
    alphatrace = trace_or_zero(idata, 'alpha')
    betaptrace = trace_or_zero(idata, 'betap')
    betactrace = trace_or_zero(idata, 'betac')
    betapctrace = trace_or_zero(idata, 'betapc')
    tl = len(alphatrace)
    gl = lpgrid.shape[0]
    lam = np.empty((gl, tl))
    lpgrid_c = lpgrid - lpgrid.mean()
    for i, v in enumerate(lpgrid):
        temp = alphatrace + betaptrace*lpgrid_c[i] + betactrace*contact + betapctrace*contact*lpgrid_c[i]
        lam[i,:] = poisson.rvs(np.exp(temp))
    return lam

We compute the posterior predictive in the two counterfactual cases: remember that what we are doing here is turning the contact feature on and off.

lpgrid = np.linspace(6,13,30)
pplow = compute_pp(lpgrid, trace1c)
pphigh = compute_pp(lpgrid, trace1c, contact=1)

We compute the medians and the HPDs, and plot these against the data:

pplowmed = np.median(pplow, axis=1)
pplowhpd = az.hdi(pplow.T)
pphighmed = np.median(pphigh, axis=1)
pphighhpd = az.hdi(pphigh.T)
/var/folders/wq/mr3zj9r14dzgjnq9rjx_vqbc0000gn/T/ipykernel_99269/375403103.py:2: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  pplowhpd = az.hdi(pplow.T)
/var/folders/wq/mr3zj9r14dzgjnq9rjx_vqbc0000gn/T/ipykernel_99269/375403103.py:4: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  pphighhpd = az.hdi(pphigh.T)
with sns.plotting_context('poster'):
    plt.plot(df[df['clevel']==1].logpop, df[df['clevel']==1].total_tools,'.', color="g")
    plt.plot(df[df['clevel']==0].logpop, df[df['clevel']==0].total_tools,'.', color="r")
    plt.plot(lpgrid, pphighmed, color="g", label="c=1")
    plt.fill_between(lpgrid, pphighhpd[:,0], pphighhpd[:,1], color="g", alpha=0.2)
    plt.plot(lpgrid, pplowmed, color="r", label="c=0")
    plt.fill_between(lpgrid, pplowhpd[:,0], pplowhpd[:,1], color="r", alpha=0.2)

This is for the full centered model. The high-contact predictive and data are in green, the low-contact ones in red. We undertake this exercise as a prelude to ensembling the models with high Akaike weights.

Ensembling

Ensembles are a good way to combine models when one model is good at some things and another at others. Ensembles also help with overfitting, if the variance cancels out between the ensemble members: they would all probably overfit in slightly different ways. Let’s write a function to do the ensembling for us.
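
The model-averaging idea behind this is the weight-normalized mixture of the individual posterior predictives,

\[p_{ens}(\tilde{y}) = \frac{\sum_i w_i\, p_i(\tilde{y})}{\sum_i w_i},\]

which the function below approximates by taking a weighted average of the per-model posterior-predictive draws; we divide by the accumulated weight because we only keep the models with non-negligible weight.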

def ensemble(grid, modeldict, comparedf, modelnames, contact=0):
    accum_pp=0
    accum_weight=0
    for m in modelnames:
        weight = comparedf.loc[m]['weight']
        pp = compute_pp(grid, modeldict[m]["idata"], contact)
        print(m, weight, np.median(pp))
        accum_pp += pp*weight
        accum_weight +=weight
    return accum_pp/accum_weight
ens_pp_low = ensemble(lpgrid, modeldict, comparedf, ['m1c', 'm2c_nopc', 'm2c_onlyp'])
m1c 0.07261554626595455 28.0
m2c_nopc 0.8694607349682668 28.0
m2c_onlyp 0.057923718765755104 33.0
ens_pp_high = ensemble(lpgrid, modeldict, comparedf, ['m1c', 'm2c_nopc', 'm2c_onlyp'], contact=1)
m1c 0.07261554626595455 37.0
m2c_nopc 0.8694607349682668 37.0
m2c_onlyp 0.057923718765755104 32.0
with sns.plotting_context('poster'):
    pplowmed = np.median(ens_pp_low, axis=1)
    pplowhpd = az.hdi(ens_pp_low.T)
    pphighmed = np.median(ens_pp_high, axis=1)
    pphighhpd = az.hdi(ens_pp_high.T)
    plt.plot(df[df['clevel']==1].logpop, df[df['clevel']==1].total_tools,'o', color="g")
    plt.plot(df[df['clevel']==0].logpop, df[df['clevel']==0].total_tools,'o', color="r")
    plt.plot(lpgrid, pphighmed, color="g")
    plt.fill_between(lpgrid, pphighhpd[:,0], pphighhpd[:,1], color="g", alpha=0.2)
    plt.plot(lpgrid, pplowmed, color="r")
    plt.fill_between(lpgrid, pplowhpd[:,0], pplowhpd[:,1], color="r", alpha=0.2)
/var/folders/wq/mr3zj9r14dzgjnq9rjx_vqbc0000gn/T/ipykernel_99269/628422166.py:3: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  pplowhpd = az.hdi(ens_pp_low.T)
/var/folders/wq/mr3zj9r14dzgjnq9rjx_vqbc0000gn/T/ipykernel_99269/628422166.py:5: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  pphighhpd = az.hdi(ens_pp_high.T)

The ensemble gives sensible limits and even regularizes down the green band at high population by giving more weight to the no-interaction model.

Hierarchical Modelling

Overdispersion is a problem one finds in most Poisson models: the variance of the data is larger than the mean, whereas the Poisson distribution constrains them to be equal.
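
As a crude first check (ignoring the predictors entirely), we can compare the marginal variance of the tool counts to their mean; for Poisson data these would be roughly equal, whereas here the ratio comes out well above 1:

# Marginal variance-to-mean ratio of the counts; ~1 for Poisson, >1 suggests overdispersion
print(df.total_tools.var() / df.total_tools.mean())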

To simplify things, let us consider here only the model with log(population). Since there is no contact variable, there are no counterfactual cases, and we can look directly at the posterior predictive.

ppsamps = compute_pp(lpgrid, trace2c_onlyp)
ppmed = np.median(ppsamps, axis=1)
pphpd = az.hdi(ppsamps.T)
plt.plot(df[df['clevel']==1].logpop, df[df['clevel']==1].total_tools,'o', color="g")
plt.plot(df[df['clevel']==0].logpop, df[df['clevel']==0].total_tools,'o', color="r")
plt.plot(lpgrid, ppmed, color="b")
plt.fill_between(lpgrid, pphpd[:,0], pphpd[:,1], color="b", alpha=0.1)
#plt.ylim([0, 300])
/var/folders/wq/mr3zj9r14dzgjnq9rjx_vqbc0000gn/T/ipykernel_99269/478959609.py:3: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  pphpd = az.hdi(ppsamps.T)

By taking the ratio of the posterior-predictive variance to the posterior-predictive mean at each grid point, we see that it exceeds 1 everywhere: the predictions are overdispersed relative to a pure Poisson, for which the ratio would be exactly 1.

ppvar=np.var(ppsamps, axis=1)
ppmean=np.mean(ppsamps, axis=1)
ppvar/ppmean
array([1.28609537, 1.26022013, 1.25911084, 1.22957116, 1.21939143,
       1.21164235, 1.20977116, 1.17637416, 1.1547616 , 1.15159053,
       1.14548628, 1.11847757, 1.10974092, 1.11537954, 1.09401532,
       1.11399381, 1.12332672, 1.13285241, 1.11783819, 1.13130799,
       1.15628753, 1.19329506, 1.22502034, 1.2738189 , 1.31209513,
       1.39818014, 1.45712839, 1.56125456, 1.68234374, 1.80818061])

Overdispersion can be handled with a mixture model, which we shall see next week. But hierarchical modelling is also a great way to deal with it.

Varying Intercepts hierarchical model

What we are basically doing is splitting the intercept into a value constant across the societies and a residual that is society-dependent. We assume this residual is drawn from a Gaussian with mean 0 and standard deviation sigmasoc (\(\sigma_{society}\)). Since there is a varying intercept for every observation, \(\sigma_{society}\) ends up being an estimate of the overdispersion amongst societies.
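
In symbols (with \(T_i\) the tool count and \(\log(P_i)_c\) the centered log-population), the varying-intercepts model written in code below is:

\[T_i \sim \text{Poisson}(\lambda_i), \qquad \log \lambda_i = \alpha + \alpha_{society[i]} + \beta_p \log(P_i)_c, \qquad \alpha_{society} \sim \mathcal{N}(0, \sigma_{society}).\]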

with pm.Model() as m3c:
    betap = pm.Normal("betap", 0, 1)
    alpha = pm.Normal("alpha", 0, 100)
    sigmasoc = pm.HalfCauchy("sigmasoc", 1)
    alphasoc = pm.Normal("alphasoc", 0, sigmasoc, shape=df.shape[0])
    loglam = alpha + alphasoc + betap*df.logpop_c 
    y = pm.Poisson("ntools", mu=pt.exp(loglam), observed=df.total_tools)
with m3c:
    trace3 = pm.sample(5000, tune=1000, idata_kwargs={"log_likelihood": True})
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [betap, alpha, sigmasoc, alphasoc]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 4 seconds.

Notice that we are fitting 13 parameters to 10 points. Ordinarily this would scream overfitting, but our parameters live at different levels, and in the hierarchical setup 10 of them are partially pooled together through one sigma. So the effective number of parameters is something lower.

az.plot_trace(trace3);

np.mean(trace3.sample_stats['diverging'].values)
np.float64(0.0)
az.summary(trace3)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
betap 0.260 0.082 0.104 0.416 0.001 0.001 8767.0 7947.0 1.0
alpha 3.442 0.126 3.207 3.679 0.002 0.002 7387.0 7646.0 1.0
alphasoc[0] -0.204 0.245 -0.670 0.244 0.002 0.003 11983.0 10493.0 1.0
alphasoc[1] 0.046 0.221 -0.355 0.486 0.002 0.002 10082.0 10852.0 1.0
alphasoc[2] -0.046 0.197 -0.419 0.327 0.002 0.002 12636.0 10971.0 1.0
alphasoc[3] 0.331 0.193 -0.026 0.692 0.002 0.002 8312.0 10502.0 1.0
alphasoc[4] 0.047 0.180 -0.293 0.387 0.002 0.002 11440.0 11203.0 1.0
alphasoc[5] -0.321 0.211 -0.703 0.074 0.002 0.002 10461.0 10778.0 1.0
alphasoc[6] 0.147 0.177 -0.187 0.479 0.002 0.002 10068.0 11227.0 1.0
alphasoc[7] -0.170 0.186 -0.526 0.181 0.002 0.002 11203.0 8724.0 1.0
alphasoc[8] 0.278 0.179 -0.056 0.614 0.002 0.002 8798.0 9830.0 1.0
alphasoc[9] -0.092 0.297 -0.656 0.483 0.003 0.004 8845.0 7786.0 1.0
sigmasoc 0.313 0.130 0.097 0.551 0.002 0.002 4570.0 5225.0 1.0

We can ask WAIC how many effective parameters the model has, and it tells us roughly 5. So what really matters is the number of hyper-parameters, not so much the number of lower-level parameters.

az.waic(trace3)
/Users/rahul/Library/Caches/uv/archive-v0/aTiHGxSE8gD8G3bEQyxJO/lib/python3.14/site-packages/arviz/stats/stats.py:1652: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
  warnings.warn(
Computed from 20000 posterior samples and 10 observations log-likelihood matrix.

          Estimate       SE
elpd_waic   -34.94     1.28
p_waic        4.94        -

There has been a warning during the calculation. Please check the results.

We now write code that samples from the Normal corresponding to \(\sigma_{society}\) to simulate new societies. Again, we don’t use pytensor shared variables, opting simply to generate samples of the residual intercepts. How many? As many as there are trace samples. You might have thought you only need to generate as many as there are grid points, i.e. 30, but in the end the posterior predictive must marginalize over the trace at each grid point, so generating one residual intercept per trace sample suffices.

def compute_pp2(lpgrid, idata, contact=0):
    alphatrace = trace_or_zero(idata, 'alpha')
    betaptrace = trace_or_zero(idata, 'betap')
    sigmasoctrace = trace_or_zero(idata, 'sigmasoc')
    tl = len(alphatrace)
    gl = lpgrid.shape[0]
    lam = np.empty((gl, tl))
    lpgrid_c = lpgrid - lpgrid.mean()
    #simulate. alphasocs generated here
    alphasoctrace = np.random.normal(0, sigmasoctrace)
    for i, v in enumerate(lpgrid):
        temp = alphatrace + betaptrace*lpgrid_c[i] + alphasoctrace
        lam[i,:] = poisson.rvs(np.exp(temp))
    return lam
ppsamps = compute_pp2(lpgrid, trace3)
ppmed = np.median(ppsamps, axis=1)
pphpd = az.hdi(ppsamps.T)
plt.plot(df[df['clevel']==1].logpop, df[df['clevel']==1].total_tools,'o', color="g")
plt.plot(df[df['clevel']==0].logpop, df[df['clevel']==0].total_tools,'o', color="r")
plt.plot(lpgrid, ppmed, color="b")
plt.fill_between(lpgrid, pphpd[:,0], pphpd[:,1], color="b", alpha=0.1)
/var/folders/wq/mr3zj9r14dzgjnq9rjx_vqbc0000gn/T/ipykernel_99269/3714602409.py:2: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  pphpd = az.hdi(ppsamps.T)

The envelope of predictions is much wider here, and it covers all the points! This is because of the varying intercepts, and it reflects the fact that there is much more variation in the data than a pure Poisson model expects.

Cross-validation (LOO) and stacking BMA in ArviZ

We can repeat the comparison with PSIS-LOO cross-validation in place of WAIC, and with stacking weights in place of pseudo-BMA.

comparedf = az.compare(compare_dict, ic="loo", method="pseudo-BMA")
comparedf.head()
/Users/rahul/Library/Caches/uv/archive-v0/aTiHGxSE8gD8G3bEQyxJO/lib/python3.14/site-packages/arviz/stats/stats.py:782: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
  warnings.warn(
rank elpd_loo p_loo elpd_diff weight se dse warning scale
m2c_nopc 0 -39.885648 4.588078 0.000000 9.398024e-01 5.537213 0.000000 True log
m2c_onlyp 1 -42.853990 4.381984 2.968343 4.829495e-02 4.495307 3.978302 True log
m1c 2 -44.254554 9.257126 4.368906 1.190268e-02 6.535852 2.657803 True log
m2c_onlyic 3 -70.966581 8.499973 31.080933 2.983725e-14 15.939870 16.284487 True log
m2c_onlyc 4 -75.538210 17.092756 35.652563 3.085498e-16 22.478023 22.088001 True log
az.plot_compare(comparedf)

comparedf_s = az.compare(compare_dict, ic="waic", method="stacking")
comparedf_s.head()
/Users/rahul/Library/Caches/uv/archive-v0/aTiHGxSE8gD8G3bEQyxJO/lib/python3.14/site-packages/arviz/stats/stats.py:1652: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
  warnings.warn(
rank elpd_waic p_waic elpd_diff weight se dse warning scale
m2c_nopc 0 -39.493347 4.195777 0.000000 0.760466 5.528580 0.000000 True log
m1c 1 -41.976041 6.978613 2.482694 0.000000 6.064680 1.834442 True log
m2c_onlyp 2 -42.202093 3.730087 2.708746 0.239534 4.469968 3.961737 True log
m2c_onlyic 3 -70.740050 8.273443 31.246703 0.000000 15.812725 16.368695 True log
m2c_onlyc 4 -75.175479 16.730025 35.682132 0.000000 22.393098 22.251583 True log