Variational Inference with Neural Networks

Bayesian neural network classification using ADVI in PyMC.

bayesian
neural-networks
classification
variational-inference
Builds a Bayesian neural network for binary classification on the moons dataset using variational inference (ADVI) in PyMC, demonstrating uncertainty quantification in neural network predictions.
Published

June 18, 2025

%matplotlib inline
import pytensor
import pymc as pm
import arviz as az
import pytensor.tensor as pt
import sklearn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from warnings import filterwarnings
from sklearn import datasets
from sklearn.preprocessing import scale
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_moons
# Plot styling used by every figure in this notebook.
sns.set_style('whitegrid')

# Two-moons toy data: 1000 noisy points generated with a fixed seed, then
# passed through sklearn's scale(); labels are stored as float64.
X, Y = make_moons(n_samples=1000, noise=0.2, random_state=0)
X = scale(X)
Y = Y.astype('float64')

# Hold out half of the points for testing (the split itself is unseeded).
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.5)

# Scatter the full data set, one series per class.
fig, ax = plt.subplots()
is_class_0 = Y == 0
is_class_1 = Y == 1
ax.scatter(X[is_class_0, 0], X[is_class_0, 1], label='Class 0')
ax.scatter(X[is_class_1, 0], X[is_class_1, 1], color='r', label='Class 1')
sns.despine()
ax.legend()
ax.set(xlabel='X', ylabel='Y', title='Toy binary classification data set');

# Sanity-check the split: half of the 1000 samples, two features each.
X_train.shape
(500, 2)
def construct_nn(ann_input, ann_output, n_hidden=5):
    """Build a Bayesian neural network with two hidden layers for 0/1 labels.

    Every weight gets an independent Normal(0, 1) prior, both hidden layers
    use tanh activations, and the sigmoid of the output layer is the success
    probability of a Bernoulli likelihood observed at ``ann_output``.

    Parameters
    ----------
    ann_input
        Network inputs (for example a shared variable or a minibatch
        tensor). Its column count must equal ``X.shape[1]`` because the
        first weight matrix is sized from the module-level ``X``.
    ann_output
        Observed labels matching the rows of ``ann_input``.
    n_hidden
        Number of units in each of the two hidden layers. Defaults to 5.

    Returns
    -------
    pm.Model
        The model holding the three weight variables and the ``out``
        likelihood.

    Notes
    -----
    Reads the module-level ``X`` (input dimension) and ``Y_train``
    (``total_size``). Initial weight values come from NumPy's global random
    state, so they differ between calls unless that state is seeded.
    """
    n_features = X.shape[1]

    # Random starting points for the weights of each layer.
    init_1 = np.random.randn(n_features, n_hidden).astype(np.float64)
    init_2 = np.random.randn(n_hidden, n_hidden).astype(np.float64)
    init_out = np.random.randn(n_hidden).astype(np.float64)

    with pm.Model() as neural_network:
        # Weights from input to first hidden layer
        weights_in_1 = pm.Normal('w_in_1', 0, sigma=1,
                                 shape=(n_features, n_hidden),
                                 initval=init_1)

        # Weights from first to second hidden layer
        weights_1_2 = pm.Normal('w_1_2', 0, sigma=1,
                                shape=(n_hidden, n_hidden),
                                initval=init_2)

        # Weights from second hidden layer to the output unit
        weights_2_out = pm.Normal('w_2_out', 0, sigma=1,
                                  shape=(n_hidden,),
                                  initval=init_out)

        # Forward pass: tanh hidden layers, sigmoid output probability
        act_1 = pm.math.tanh(pm.math.dot(ann_input, weights_in_1))
        act_2 = pm.math.tanh(pm.math.dot(act_1, weights_1_2))
        act_out = pm.math.sigmoid(pm.math.dot(act_2, weights_2_out))

        # Binary classification -> Bernoulli likelihood.  The variable is
        # registered on the model by name, so no local reference is kept.
        pm.Bernoulli('out',
                     p=act_out,
                     observed=ann_output,
                     total_size=Y_train.shape[0]  # IMPORTANT for minibatches
                     )
    return neural_network

# Wrap the training arrays in PyTensor shared variables before building the
# model.  The model keeps referring to these containers, so replacing their
# contents later with set_value() (e.g. with the test set) changes the data
# the model sees without rebuilding it.
ann_input = pytensor.shared(X_train)
ann_output = pytensor.shared(Y_train)
neural_network = construct_nn(ann_input, ann_output)

# Sample the posterior with NUTS: 1000 tuning steps, then 2000 draws per chain.
# NOTE(review): the recorded run reports 1324 divergences and r_hat > 1.01,
# so this trace should not be treated as converged.
with neural_network:
    nutstrace = pm.sample(draws=2000, tune=1000)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w_in_1, w_1_2, w_2_out]
/Users/rahul/Library/Caches/uv/archive-v0/wV-uT_3pb4u247-POgKgx/lib/python3.14/site-packages/rich/live.py:260: 
UserWarning: install "ipywidgets" for Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 86 seconds.
There were 1324 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
# Per-weight posterior summary of the NUTS trace, including the ess and
# r_hat diagnostics.
az.summary(nutstrace)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
w_in_1[0, 0] -0.252 1.517 -3.140 2.490 0.235 0.076 48.0 243.0 1.09
w_in_1[0, 1] 0.491 1.602 -2.438 3.289 0.337 0.083 25.0 203.0 1.11
w_in_1[0, 2] -0.230 1.559 -3.060 2.722 0.174 0.136 85.0 167.0 1.03
w_in_1[0, 3] -0.006 1.327 -2.855 2.509 0.134 0.083 105.0 90.0 1.03
w_in_1[0, 4] -0.067 1.507 -2.828 2.824 0.152 0.088 107.0 273.0 1.05
w_in_1[1, 0] 0.095 0.673 -1.151 1.740 0.063 0.043 112.0 223.0 1.03
w_in_1[1, 1] 0.109 0.670 -1.344 1.489 0.088 0.045 51.0 192.0 1.07
w_in_1[1, 2] -0.025 0.575 -1.255 1.280 0.042 0.035 182.0 280.0 1.02
w_in_1[1, 3] -0.047 0.607 -1.451 1.246 0.035 0.037 295.0 348.0 1.03
w_in_1[1, 4] -0.020 0.623 -1.367 1.339 0.037 0.046 268.0 387.0 1.03
w_1_2[0, 0] 0.018 1.249 -2.313 2.341 0.046 0.021 759.0 1966.0 1.01
w_1_2[0, 1] -0.109 1.266 -2.463 2.251 0.043 0.025 878.0 1319.0 1.01
w_1_2[0, 2] 0.041 1.244 -2.286 2.308 0.039 0.021 1024.0 1661.0 1.00
w_1_2[0, 3] 0.012 1.259 -2.227 2.436 0.039 0.022 1076.0 2152.0 1.01
w_1_2[0, 4] 0.018 1.249 -2.236 2.403 0.038 0.023 1075.0 1865.0 1.00
w_1_2[1, 0] -0.088 1.237 -2.289 2.282 0.046 0.020 741.0 1930.0 1.01
w_1_2[1, 1] -0.037 1.273 -2.374 2.390 0.067 0.037 361.0 428.0 1.01
w_1_2[1, 2] 0.033 1.223 -2.330 2.216 0.039 0.019 1007.0 1920.0 1.00
w_1_2[1, 3] 0.085 1.266 -2.222 2.406 0.043 0.020 899.0 2093.0 1.00
w_1_2[1, 4] 0.012 1.232 -2.252 2.269 0.040 0.021 948.0 2108.0 1.01
w_1_2[2, 0] -0.054 1.300 -2.407 2.345 0.061 0.025 457.0 1600.0 1.00
w_1_2[2, 1] -0.064 1.285 -2.365 2.383 0.065 0.022 400.0 1204.0 1.01
w_1_2[2, 2] 0.030 1.248 -2.287 2.306 0.041 0.018 950.0 2124.0 1.00
w_1_2[2, 3] -0.022 1.273 -2.355 2.349 0.046 0.024 771.0 1048.0 1.00
w_1_2[2, 4] -0.042 1.278 -2.354 2.400 0.060 0.046 456.0 196.0 1.01
w_1_2[3, 0] 0.083 1.269 -2.274 2.491 0.047 0.023 719.0 1048.0 1.01
w_1_2[3, 1] 0.176 1.352 -2.329 2.724 0.076 0.042 321.0 497.0 1.01
w_1_2[3, 2] 0.020 1.235 -2.304 2.305 0.043 0.021 852.0 1649.0 1.01
w_1_2[3, 3] -0.020 1.257 -2.279 2.358 0.049 0.026 670.0 1065.0 1.01
w_1_2[3, 4] 0.043 1.293 -2.244 2.467 0.051 0.027 647.0 1191.0 1.00
w_1_2[4, 0] 0.042 1.268 -2.253 2.453 0.052 0.024 601.0 1333.0 1.01
w_1_2[4, 1] 0.090 1.310 -2.536 2.390 0.074 0.041 314.0 482.0 1.02
w_1_2[4, 2] -0.010 1.245 -2.368 2.297 0.041 0.022 924.0 1141.0 1.00
w_1_2[4, 3] -0.042 1.271 -2.310 2.371 0.046 0.024 771.0 1208.0 1.00
w_1_2[4, 4] -0.013 1.274 -2.448 2.240 0.051 0.022 637.0 775.0 1.00
w_2_out[0] 0.254 2.365 -4.302 4.399 0.125 0.049 374.0 807.0 1.01
w_2_out[1] 0.161 2.468 -4.633 4.412 0.176 0.069 196.0 627.0 1.01
w_2_out[2] -0.089 2.228 -4.237 4.191 0.083 0.045 736.0 1540.0 1.00
w_2_out[3] 0.177 2.365 -4.232 4.607 0.128 0.058 353.0 1147.0 1.01
w_2_out[4] -0.101 2.346 -4.494 4.114 0.113 0.050 448.0 1317.0 1.00
# Fit a variational approximation with ADVI for 30000 iterations.  The ADVI
# object is kept in its own variable so its loss history can be plotted later.
with neural_network:
    inference = pm.ADVI()
    approx = pm.fit(n=30_000, method=inference)
/Users/rahul/Library/Caches/uv/archive-v0/wV-uT_3pb4u247-POgKgx/lib/python3.14/site-packages/rich/live.py:260: 
UserWarning: install "ipywidgets" for Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Finished [100%]: Average Loss = 170.33
# Draw 5000 samples from the fitted approximation and summarise them.  The
# draws form a single chain (see the arviz shape warning), so r_hat is NaN.
advitrace = approx.sample(draws=5_000)
az.summary(advitrace)
arviz - WARNING - Shape validation failed: input_shape: (1, 5000), minimum_shape: (chains=2, draws=4)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
w_in_1[0, 0] 0.205 0.392 -0.539 0.904 0.006 0.004 4488.0 5023.0 NaN
w_in_1[0, 1] 0.545 0.136 0.285 0.791 0.002 0.001 4718.0 4809.0 NaN
w_in_1[0, 2] -2.102 0.556 -3.149 -1.092 0.008 0.006 4495.0 4604.0 NaN
w_in_1[0, 3] 0.560 0.132 0.304 0.804 0.002 0.001 5250.0 4973.0 NaN
w_in_1[0, 4] -0.208 0.527 -1.169 0.809 0.007 0.005 4996.0 4938.0 NaN
w_in_1[1, 0] -0.246 0.531 -1.229 0.789 0.008 0.005 4596.0 4611.0 NaN
w_in_1[1, 1] -0.478 0.255 -0.939 0.005 0.004 0.003 4909.0 4759.0 NaN
w_in_1[1, 2] -1.276 0.527 -2.258 -0.263 0.008 0.005 4894.0 4982.0 NaN
w_in_1[1, 3] -0.490 0.214 -0.889 -0.081 0.003 0.002 4709.0 4972.0 NaN
w_in_1[1, 4] 0.390 0.637 -0.821 1.552 0.009 0.007 4960.0 4601.0 NaN
w_1_2[0, 0] 0.007 0.995 -1.957 1.790 0.014 0.010 4931.0 4778.0 NaN
w_1_2[0, 1] 0.174 1.041 -1.708 2.144 0.015 0.011 4788.0 4808.0 NaN
w_1_2[0, 2] -0.080 0.602 -1.180 1.070 0.009 0.006 4700.0 4935.0 NaN
w_1_2[0, 3] -0.134 0.603 -1.220 1.006 0.009 0.006 4679.0 4726.0 NaN
w_1_2[0, 4] 0.209 0.569 -0.830 1.321 0.009 0.006 4435.0 4906.0 NaN
w_1_2[1, 0] -0.315 0.943 -2.119 1.369 0.014 0.009 4703.0 4713.0 NaN
w_1_2[1, 1] 0.330 0.972 -1.509 2.145 0.014 0.010 5166.0 4631.0 NaN
w_1_2[1, 2] -1.088 0.537 -2.103 -0.095 0.008 0.005 4404.0 4605.0 NaN
w_1_2[1, 3] -0.888 0.550 -1.935 0.138 0.008 0.005 5042.0 4912.0 NaN
w_1_2[1, 4] 1.284 0.496 0.358 2.215 0.007 0.005 5161.0 4813.0 NaN
w_1_2[2, 0] -0.485 0.844 -2.117 1.039 0.012 0.009 4709.0 4754.0 NaN
w_1_2[2, 1] 0.444 0.871 -1.180 2.125 0.012 0.009 4919.0 4782.0 NaN
w_1_2[2, 2] -0.730 0.359 -1.387 -0.034 0.005 0.004 4661.0 4599.0 NaN
w_1_2[2, 3] -0.785 0.357 -1.455 -0.101 0.005 0.004 4994.0 5021.0 NaN
w_1_2[2, 4] 1.000 0.348 0.317 1.619 0.005 0.003 4949.0 4941.0 NaN
w_1_2[3, 0] -0.295 0.944 -2.071 1.439 0.014 0.010 4820.0 4909.0 NaN
w_1_2[3, 1] 0.045 0.979 -1.757 1.873 0.014 0.010 4854.0 4590.0 NaN
w_1_2[3, 2] -1.151 0.533 -2.129 -0.128 0.008 0.005 5021.0 5017.0 NaN
w_1_2[3, 3] -1.384 0.528 -2.420 -0.405 0.008 0.005 4921.0 4901.0 NaN
w_1_2[3, 4] 1.139 0.485 0.201 2.029 0.007 0.005 4924.0 4938.0 NaN
w_1_2[4, 0] 0.247 0.962 -1.606 2.041 0.014 0.010 4792.0 4866.0 NaN
w_1_2[4, 1] -0.207 0.973 -2.013 1.607 0.014 0.010 4962.0 4785.0 NaN
w_1_2[4, 2] 0.176 0.531 -0.811 1.189 0.008 0.005 4972.0 4812.0 NaN
w_1_2[4, 3] 0.261 0.554 -0.815 1.270 0.008 0.005 4788.0 4494.0 NaN
w_1_2[4, 4] 0.154 0.501 -0.755 1.104 0.007 0.005 5008.0 4458.0 NaN
w_2_out[0] -0.193 0.244 -0.649 0.258 0.003 0.003 5080.0 4896.0 NaN
w_2_out[1] 0.147 0.245 -0.308 0.606 0.003 0.002 5096.0 4971.0 NaN
w_2_out[2] -1.140 0.231 -1.580 -0.706 0.003 0.002 5087.0 4785.0 NaN
w_2_out[3] -1.148 0.236 -1.589 -0.709 0.003 0.002 5020.0 5066.0 NaN
w_2_out[4] 1.211 0.232 0.778 1.631 0.003 0.002 5120.0 4901.0 NaN
# Plot the negated ADVI loss history (for ADVI the loss is the negative ELBO).
elbo_history = -inference.hist
plt.plot(elbo_history, alpha=0.3)

# Point the shared containers at the held-out data, then simulate labels for
# it from the ADVI draws.
for shared_var, new_value in ((ann_input, X_test), (ann_output, Y_test)):
    shared_var.set_value(new_value)

with neural_network:
    ppc = pm.sample_posterior_predictive(advitrace)
Sampling: [out]
/Users/rahul/Library/Caches/uv/archive-v0/wV-uT_3pb4u247-POgKgx/lib/python3.14/site-packages/rich/live.py:260: 
UserWarning: install "ipywidgets" for Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

# Average the simulated labels over all posterior draws for each test point
# and threshold at 0.5 to get a hard class prediction.
out_draws = ppc.posterior_predictive['out'].values
pred = out_draws.reshape(-1, X_test.shape[0]).mean(axis=0) > 0.5

# Scatter the test points, one series per predicted class.
fig, ax = plt.subplots()
pred_is_0 = pred == 0
pred_is_1 = pred == 1
ax.scatter(X_test[pred_is_0, 0], X_test[pred_is_0, 1])
ax.scatter(X_test[pred_is_1, 0], X_test[pred_is_1, 1], color='r')
sns.despine()
ax.set(title='Predicted labels in testing set', xlabel='X', ylabel='Y');

# Percentage of test points whose thresholded prediction equals the true label.
accuracy_pct = (Y_test == pred).mean() * 100
print('Accuracy = {}%'.format(accuracy_pct))
Accuracy = 88.8%
# Evaluate the network on a regular 100 x 100 lattice covering [-3, 3] in
# both input dimensions.
grid = pm.floatX(np.mgrid[-3:3:100j, -3:3:100j])
grid_2d = grid.reshape(2, -1).T  # one (x, y) row per lattice point
n_grid_points = grid_2d.shape[0]

# Placeholder labels sized to match the grid, so the observed container has
# the same number of rows as the new input.
dummy_out = np.ones(n_grid_points, dtype=np.float64)

ann_input.set_value(grid_2d)
ann_output.set_value(dummy_out)

with neural_network:
    ppc_grid = pm.sample_posterior_predictive(advitrace, var_names=['out'])
Sampling: [out]
/Users/rahul/Library/Caches/uv/archive-v0/wV-uT_3pb4u247-POgKgx/lib/python3.14/site-packages/rich/live.py:260: 
UserWarning: install "ipywidgets" for Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

# Probability surface: average the simulated 0/1 draws of `out` at every grid
# point.  The mean of Bernoulli draws estimates the probability that the
# label equals 1.
cmap = sns.diverging_palette(250, 12, s=85, l=25, as_cmap=True)
fig, ax = plt.subplots(figsize=(16, 9))
ppc_out_mean = ppc_grid.posterior_predictive['out'].values.reshape(-1, grid_2d.shape[0]).mean(axis=0)
contour = ax.contourf(grid[0], grid[1], ppc_out_mean.reshape(100, 100), cmap=cmap)
# Overlay the test points, coloured by their predicted label.
ax.scatter(X_test[pred==0, 0], X_test[pred==0, 1])
ax.scatter(X_test[pred==1, 0], X_test[pred==1, 1], color='r')
cbar = plt.colorbar(contour, ax=ax)
_ = ax.set(xlim=(-3, 3), ylim=(-3, 3), xlabel='X', ylabel='Y');
# The plotted quantity is P(label = 1): high values coincide with the red
# (class 1) points, so the colorbar is labelled accordingly.
cbar.ax.set_ylabel('Posterior predictive mean probability of class label = 1');

# Uncertainty surface: standard deviation of the simulated labels across
# posterior draws at every grid point.
cmap = sns.cubehelix_palette(light=1, as_cmap=True)
fig, ax = plt.subplots(figsize=(16, 9))
grid_draws = ppc_grid.posterior_predictive['out'].values.reshape(-1, grid_2d.shape[0])
ppc_out_std = grid_draws.std(axis=0)
contour = ax.contourf(grid[0], grid[1], ppc_out_std.reshape(100, 100), cmap=cmap)

# Overlay the test points, one series per predicted class.
pred_is_0 = pred == 0
pred_is_1 = pred == 1
ax.scatter(X_test[pred_is_0, 0], X_test[pred_is_0, 1])
ax.scatter(X_test[pred_is_1, 0], X_test[pred_is_1, 1], color='r')

cbar = plt.colorbar(contour, ax=ax)
_ = ax.set(xlim=(-3, 3), ylim=(-3, 3), xlabel='X', ylabel='Y');
cbar.ax.set_ylabel('Uncertainty (posterior predictive standard deviation)');

# Draw the input and label minibatches from ONE pm.Minibatch call so both use
# the same random row indices.  Separate calls sample their rows
# independently, which would pair each input row with the label of a
# different row.
minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)
neural_network_minibatch = construct_nn(minibatch_x, minibatch_y)

# Fit ADVI on the minibatch model for 40000 iterations.
with neural_network_minibatch:
    approx_mb = pm.fit(40000, method=pm.ADVI())
/Users/rahul/Library/Caches/uv/archive-v0/wV-uT_3pb4u247-POgKgx/lib/python3.14/site-packages/rich/live.py:260: 
UserWarning: install "ipywidgets" for Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Finished [100%]: Average Loss = 35.717
# Plot the negated loss history of the minibatch ADVI fit.
elbo_history_mb = -approx_mb.hist
plt.plot(elbo_history_mb)