Optimization Tutorial

Trey V. Wenger (c) June 2025

Here we demonstrate how to optimize the number of cloud components in a bayes_spec model.

[1]:
# General imports
import matplotlib.pyplot as plt
import arviz as az
import pandas as pd
import numpy as np

import pymc
print("pymc version:", pymc.__version__)

import bayes_spec
print("bayes_spec version:", bayes_spec.__version__)

# Notebook configuration
pd.options.display.max_rows = None
pymc version: 5.22.0
bayes_spec version: 1.7.9-staging+2.gb86248c.dirty

Model Definition and Simulated Data

As in the basic tutorial for GaussNoiseModel, our model is a Gaussian line profile with the spectral noise as a free parameter. We generate synthetic data in the same way as in the basic tutorial. Note that the data key must be "observation" for this model.

[2]:
from bayes_spec import SpecData
from bayes_spec.models import GaussNoiseModel

# Generate placeholder data with the correct structure for the simulation
velocity_axis = np.linspace(-250.0, 250.0, 501) # km/s
noise = 1.0 # K
brightness_data = noise * np.random.randn(len(velocity_axis)) # K

observation = SpecData(
    velocity_axis,
    brightness_data,
    noise,
    xlabel=r"Velocity (km s$^{-1}$)",
    ylabel="Brightness Temperature (K)",
)
dummy_data = {"observation": observation}

# Initialize and define the model
model = GaussNoiseModel(dummy_data, n_clouds=3, baseline_degree=2, seed=1234, verbose=True)
model.add_priors(
    prior_line_area = 500.0, # mode of k=2 gamma distribution prior on line area (K km s-1)
    prior_fwhm = 25.0, # mode of k=2 gamma distribution prior on FWHM line width (km s-1)
    prior_velocity = [0.0, 50.0], # mean and width of normal distribution prior on centroid velocity (km s-1)
    prior_baseline_coeffs = [1.0, 1.0, 1.0], # width of normal distribution prior on normalized baseline coefficients
    prior_rms = 1.0, # width of half-normal distribution prior on spectral rms (K)
)
model.add_likelihood()

# Simulate observation
sim_brightness = model.model.observation.eval({
    "fwhm": [25.0, 40.0, 35.0], # FWHM line width (km/s)
    "line_area": [250.0, 125.0, 175.0], # line area (K km/s)
    "velocity": [-35.0, 10.0, 55.0], # velocity (km/s)
    "baseline_observation_norm": [-0.5, -2.0, 3.0], # normalized baseline coefficients
    "rms_observation": noise, # spectral rms (K)
})

# Plot the simulated data
plt.plot(dummy_data["observation"].spectral, sim_brightness, 'k-')
plt.xlabel(dummy_data["observation"].xlabel)
_ = plt.ylabel(dummy_data["observation"].ylabel)
[Figure: simulated spectrum, Brightness Temperature (K) versus Velocity (km s$^{-1}$)]
[3]:
# Pack simulated data into SpecData
observation = SpecData(
    velocity_axis,
    sim_brightness,
    noise,
    xlabel=r"Velocity (km s$^{-1}$)",
    ylabel="Brightness Temperature (K)",
)
data = {"observation": observation}
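For reference, the line model described above can be written out by hand. The following is a minimal NumPy sketch, not the bayes_spec implementation: it assumes each cloud contributes a Gaussian parameterized by integrated line area, FWHM, and centroid velocity, and it omits the polynomial baseline and the noise. The function name `gaussian_spectrum` and the constant `AREA_FACTOR` are hypothetical names introduced for illustration. The parameter values are the ones used in the simulation above.

```python
import numpy as np

# For a Gaussian, area = amplitude * FWHM * sqrt(pi / (4 ln 2))
AREA_FACTOR = np.sqrt(np.pi / (4.0 * np.log(2.0)))


def gaussian_spectrum(velocity_axis, line_area, fwhm, velocity):
    """Sum of Gaussian components given area (K km/s), FWHM (km/s), centroid (km/s).

    Hypothetical helper for illustration only; no baseline, no noise.
    """
    line_area = np.asarray(line_area, dtype=float)
    fwhm = np.asarray(fwhm, dtype=float)
    velocity = np.asarray(velocity, dtype=float)

    # Peak brightness (K) of each component
    amplitude = line_area / (fwhm * AREA_FACTOR)

    # Offsets have shape (n_channels, n_clouds)
    offset = velocity_axis[:, None] - velocity[None, :]
    profiles = amplitude * np.exp(-4.0 * np.log(2.0) * offset**2 / fwhm**2)
    return profiles.sum(axis=1), amplitude


velocity_axis = np.linspace(-250.0, 250.0, 501)  # km/s
spectrum, amplitude = gaussian_spectrum(
    velocity_axis,
    line_area=[250.0, 125.0, 175.0],
    fwhm=[25.0, 40.0, 35.0],
    velocity=[-35.0, 10.0, 55.0],
)

# The integrated spectrum should recover the total line area
channel_width = velocity_axis[1] - velocity_axis[0]
print("total area (K km/s):", spectrum.sum() * channel_width)
print("peak amplitudes (K):", amplitude)
```

Parameterizing by line area rather than amplitude means the peak brightness of a component depends on both its area and its width, which is why a broad component can have a large area but a small amplitude.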

Optimize

We use the Optimize class to determine the optimal number of cloud components.

[4]:
from bayes_spec import Optimize

# Initialize optimizer
opt = Optimize(
    GaussNoiseModel,  # model definition
    data,  # data dictionary
    max_n_clouds=8,  # maximum number of clouds
    baseline_degree=2,  # polynomial baseline degree
    seed=1234,  # random seed
    verbose=True,  # verbosity
)

# Define each model
opt.add_priors(
    prior_line_area = 200.0, # mode of k=2 gamma distribution prior on line area (K km s-1)
    prior_fwhm = 30.0, # mode of k=2 gamma distribution prior on FWHM line width (km s-1)
    prior_velocity = [0.0, 50.0], # mean and width of normal distribution prior on centroid velocity (km s-1)
    prior_baseline_coeffs = [1.0, 1.0, 1.0], # width of normal distribution prior on normalized baseline coefficients
    prior_rms = 2.0, # width of half-normal distribution prior on spectral rms (K)
)
opt.add_likelihood()

Optimize has created max_n_clouds models, stored in opt.models and keyed by the number of clouds: opt.models[1] has n_clouds=1, opt.models[2] has n_clouds=2, etc.

[5]:
print(opt.models[4])
print(opt.models[4].n_clouds)
<bayes_spec.models.gauss_noise_model.GaussNoiseModel object at 0x7eade3784e90>
4

By default (approx=True), the optimization algorithm first loops over every model and approximates the posterior distribution using variational inference. The first model, in order of increasing n_clouds, to have a BIC within bic_threshold of the minimum BIC is the “best” model, and it is then sampled with MCMC. The algorithm will stop early if two successive models fail to converge to a unique solution with an improving BIC. We can supply arguments to model.fit and model.sample via the fit_kwargs and sample_kwargs dictionaries. Note the start_spread parameter: passing start_spread = {"velocity_norm": [-3.0, 3.0]} is like passing start = {"velocity_norm": np.linspace(-3.0, 3.0, n_clouds)} to model.fit or to model.sample within init_kwargs.
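To make the start_spread note concrete, here is a small sketch of the starting values implied by np.linspace(-3.0, 3.0, n_clouds) for a few values of n_clouds. The helper name `spread_start` is hypothetical and is not part of bayes_spec. It only reproduces the np.linspace call quoted above.

```python
import numpy as np


def spread_start(low, high, n_clouds):
    """Evenly spaced starting values between low and high (hypothetical helper)."""
    return np.linspace(low, high, n_clouds)


# Starting values of velocity_norm for models with 2, 3, and 4 clouds
for n_clouds in (2, 3, 4):
    print(n_clouds, spread_start(-3.0, 3.0, n_clouds))
```

Spreading the initial centroids across the normalized velocity range gives each cloud a distinct starting point, rather than starting every component at the same location.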

[6]:
fit_kwargs = {
    "rel_tolerance": 0.01,
    "abs_tolerance": 0.01,
    "learning_rate": 0.001,
}
sample_kwargs = {
    "chains": 4,
    "cores": 4,
    "init_kwargs": fit_kwargs,
    "nuts_kwargs": {"target_accept": 0.8},
}
opt.optimize(
    bic_threshold=10.0,
    kl_div_threshold=0.1,
    sample_kwargs=sample_kwargs,
    fit_kwargs=fit_kwargs,
    start_spread = {"velocity_norm": [-3.0, 3.0]},
)
Null hypothesis BIC = 2.895e+03
Approximating n_cloud = 1 posterior...
Convergence achieved at 23000
Interrupted at 22,999 [2%]: Average Loss = 1,139.8
Adding log-likelihood to trace
n_cloud = 1 BIC = 1.965e+03

Approximating n_cloud = 2 posterior...
Convergence achieved at 28600
Interrupted at 28,599 [2%]: Average Loss = 1,075.5
Adding log-likelihood to trace
n_cloud = 2 BIC = 1.533e+03

Approximating n_cloud = 3 posterior...
Convergence achieved at 37200
Interrupted at 37,199 [3%]: Average Loss = 962.58
Adding log-likelihood to trace
n_cloud = 3 BIC = 1.507e+03

Approximating n_cloud = 4 posterior...
Convergence achieved at 29700
Interrupted at 29,699 [2%]: Average Loss = 1,026
Adding log-likelihood to trace
n_cloud = 4 BIC = 1.522e+03

Stopping criteria met.
Approximating n_cloud = 5 posterior...
Convergence achieved at 21200
Interrupted at 21,199 [2%]: Average Loss = 1,241.4
Adding log-likelihood to trace
n_cloud = 5 BIC = 1.546e+03

Stopping criteria met.
Stopping early.
Sampling best model (n_cloud = 3)...
Initializing NUTS using custom advi+adapt_diag strategy
Convergence achieved at 37200
Interrupted at 37,199 [3%]: Average Loss = 962.58
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [baseline_observation_norm, line_area_norm, fwhm_norm, velocity_norm, rms_observation_norm]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
Adding log-likelihood to trace
GMM converged to unique solution
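The selection rule can be illustrated with the BIC values printed in the log above. This is a minimal sketch of the rule as described in the text (the first model whose BIC is within bic_threshold of the minimum), not the bayes_spec source. The function name `select_n_clouds` is hypothetical, and whether the library uses a strict or non-strict comparison is an assumption.

```python
def select_n_clouds(bics, bic_threshold):
    """Return the smallest n_clouds whose BIC is within bic_threshold of the minimum.

    Hypothetical helper; the strict inequality is an assumption.
    """
    min_bic = min(bics.values())
    for n_clouds in sorted(bics):
        if bics[n_clouds] - min_bic < bic_threshold:
            return n_clouds
    return None


# BIC values from the variational inference log above
vi_bics = {1: 1.965e3, 2: 1.533e3, 3: 1.507e3, 4: 1.522e3, 5: 1.546e3}

best_n_clouds = select_n_clouds(vi_bics, bic_threshold=10.0)
print("selected n_clouds:", best_n_clouds)
```

With bic_threshold=10.0, the n_clouds = 2 model is rejected because its BIC exceeds the minimum by about 26, so the rule selects n_clouds = 3, as in the log. A larger threshold favors simpler models: with a threshold of 30, the same rule would select n_clouds = 2.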

The “best” model is saved in opt.best_model.

[7]:
print(f"Best model has n_clouds = {opt.best_model.n_clouds}")
az.summary(opt.best_model.trace.solution_0)
Best model has n_clouds = 3
[7]:
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
baseline_observation_norm[0] -0.380 0.071 -0.514 -0.249 0.001 0.001 2770.0 2446.0 1.00
baseline_observation_norm[1] -1.870 0.156 -2.177 -1.594 0.002 0.003 4185.0 3016.0 1.00
baseline_observation_norm[2] 1.309 0.884 -0.289 3.045 0.017 0.013 2826.0 2878.0 1.00
velocity_norm[0] 0.229 0.142 -0.025 0.509 0.004 0.002 1076.0 1820.0 1.00
velocity_norm[1] -0.706 0.010 -0.724 -0.689 0.000 0.000 2519.0 2386.0 1.00
velocity_norm[2] 1.129 0.041 1.049 1.203 0.001 0.001 1302.0 1911.0 1.00
line_area_norm[0] 0.825 0.221 0.418 1.225 0.008 0.003 855.0 1528.0 1.01
line_area_norm[1] 1.116 0.081 0.972 1.277 0.002 0.001 1518.0 1885.0 1.00
line_area_norm[2] 0.724 0.177 0.398 1.040 0.006 0.003 811.0 1767.0 1.00
fwhm_norm[0] 2.093 0.512 1.153 3.020 0.017 0.008 907.0 1559.0 1.01
fwhm_norm[1] 0.793 0.040 0.719 0.868 0.001 0.001 2154.0 2492.0 1.00
fwhm_norm[2] 1.188 0.172 0.851 1.505 0.005 0.003 1065.0 2049.0 1.00
rms_observation_norm 0.501 0.016 0.471 0.531 0.000 0.000 4012.0 2812.0 1.00
line_area[0] 164.955 44.245 83.563 245.096 1.509 0.690 855.0 1528.0 1.01
line_area[1] 223.113 16.157 194.479 255.482 0.414 0.278 1518.0 1885.0 1.00
line_area[2] 144.807 35.410 79.565 207.919 1.254 0.548 811.0 1767.0 1.00
fwhm[0] 62.782 15.372 34.597 90.608 0.510 0.236 907.0 1559.0 1.01
fwhm[1] 23.777 1.197 21.566 26.031 0.026 0.018 2154.0 2492.0 1.00
fwhm[2] 35.653 5.162 25.516 45.154 0.158 0.089 1065.0 2049.0 1.00
velocity[0] 11.450 7.103 -1.238 25.439 0.220 0.122 1076.0 1820.0 1.00
velocity[1] -35.296 0.478 -36.211 -34.451 0.010 0.008 2519.0 2386.0 1.00
velocity[2] 56.427 2.063 52.458 60.155 0.058 0.036 1302.0 1911.0 1.00
amplitude[0] 2.467 0.240 2.036 2.923 0.004 0.003 3168.0 3381.0 1.00
amplitude[1] 8.812 0.404 8.084 9.619 0.009 0.006 2156.0 2960.0 1.00
amplitude[2] 3.774 0.545 2.766 4.680 0.018 0.007 1009.0 2302.0 1.00
rms_observation 1.002 0.032 0.942 1.063 0.001 0.001 4012.0 2812.0 1.00
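The summary lists both normalized parameters (with the _norm suffix) and their physical counterparts. The posterior means above are consistent with a simple rescaling by the prior values passed to opt.add_priors. The sketch below checks that arithmetic; the mapping is inferred from the numbers in the table and is an assumption, not a statement about the library internals.

```python
import numpy as np

# Prior values passed to opt.add_priors above
prior_line_area = 200.0  # K km/s
prior_fwhm = 30.0  # km/s
prior_velocity = [0.0, 50.0]  # mean and width, km/s
prior_rms = 2.0  # K

# Posterior means of the normalized parameters, copied from the summary
line_area_norm = np.array([0.825, 1.116, 0.724])
fwhm_norm = np.array([2.093, 0.793, 1.188])
velocity_norm = np.array([0.229, -0.706, 1.129])
rms_observation_norm = 0.501

# Assumed mapping: scale by the prior value (and shift by the prior mean for velocity)
line_area = prior_line_area * line_area_norm
fwhm = prior_fwhm * fwhm_norm
velocity = prior_velocity[0] + prior_velocity[1] * velocity_norm
rms_observation = prior_rms * rms_observation_norm

print("line_area (K km/s):", line_area)
print("fwhm (km/s):", fwhm)
print("velocity (km/s):", velocity)
print("rms_observation (K):", rms_observation)
```

The results agree with the line_area, fwhm, velocity, and rms_observation rows of the summary to within the rounding of the tabulated normalized means.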

With approx=False, the optimization algorithm samples each model with MCMC (or SMC if smc=True) and determines which is the first model with a BIC within bic_threshold of the minimum across all models. This is more accurate, but slower.

[8]:
opt.optimize(
    bic_threshold=10.0,
    kl_div_threshold=0.1,
    sample_kwargs=sample_kwargs,
    fit_kwargs=fit_kwargs,
    start_spread = {"velocity_norm": [-3.0, 3.0]},
    approx=False,
)
Null hypothesis BIC = 2.895e+03
Sampling n_cloud = 1 posterior...
Initializing NUTS using custom advi+adapt_diag strategy
Convergence achieved at 23000
Interrupted at 22,999 [2%]: Average Loss = 1,139.8
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [baseline_observation_norm, line_area_norm, fwhm_norm, velocity_norm, rms_observation_norm]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
Adding log-likelihood to trace
GMM converged to unique solution
n_cloud = 1 solution = 0 BIC = 1.964e+03

Sampling n_cloud = 2 posterior...
Initializing NUTS using custom advi+adapt_diag strategy
Convergence achieved at 28600
Interrupted at 28,599 [2%]: Average Loss = 1,075.5
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [baseline_observation_norm, line_area_norm, fwhm_norm, velocity_norm, rms_observation_norm]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.
Adding log-likelihood to trace
GMM converged to unique solution
n_cloud = 2 solution = 0 BIC = 1.531e+03

Sampling n_cloud = 3 posterior...
Initializing NUTS using custom advi+adapt_diag strategy
Convergence achieved at 37200
Interrupted at 37,199 [3%]: Average Loss = 962.58
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [baseline_observation_norm, line_area_norm, fwhm_norm, velocity_norm, rms_observation_norm]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
Adding log-likelihood to trace
GMM converged to unique solution
n_cloud = 3 solution = 0 BIC = 1.504e+03

Sampling n_cloud = 4 posterior...
Initializing NUTS using custom advi+adapt_diag strategy
Convergence achieved at 29700
Interrupted at 29,699 [2%]: Average Loss = 1,026
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [baseline_observation_norm, line_area_norm, fwhm_norm, velocity_norm, rms_observation_norm]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds.
Adding log-likelihood to trace
There were 668 divergences in converged chains.
No solution found!
0 of 4 chains appear converged.

Stopping criteria met.
Sampling n_cloud = 5 posterior...
Initializing NUTS using custom advi+adapt_diag strategy
Convergence achieved at 21200
Interrupted at 21,199 [2%]: Average Loss = 1,241.4
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [baseline_observation_norm, line_area_norm, fwhm_norm, velocity_norm, rms_observation_norm]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 12 seconds.
Adding log-likelihood to trace
There were 2 divergences in converged chains.
No solution found!
0 of 4 chains appear converged.

Stopping criteria met.
Stopping early.

Notice that models with more complexity than is present in the data (i.e., n_clouds > the true number of clouds) tend to have more divergences and difficulty converging. Eventually the stopping criteria terminate the algorithm early.

[9]:
print(f"Best model has n_clouds = {opt.best_model.n_clouds}")
az.summary(opt.best_model.trace.solution_0)
Best model has n_clouds = 3
[9]:
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
baseline_observation_norm[0] -0.380 0.071 -0.514 -0.249 0.001 0.001 2770.0 2446.0 1.00
baseline_observation_norm[1] -1.870 0.156 -2.177 -1.594 0.002 0.003 4185.0 3016.0 1.00
baseline_observation_norm[2] 1.309 0.884 -0.289 3.045 0.017 0.013 2826.0 2878.0 1.00
velocity_norm[0] 0.229 0.142 -0.025 0.509 0.004 0.002 1076.0 1820.0 1.00
velocity_norm[1] -0.706 0.010 -0.724 -0.689 0.000 0.000 2519.0 2386.0 1.00
velocity_norm[2] 1.129 0.041 1.049 1.203 0.001 0.001 1302.0 1911.0 1.00
line_area_norm[0] 0.825 0.221 0.418 1.225 0.008 0.003 855.0 1528.0 1.01
line_area_norm[1] 1.116 0.081 0.972 1.277 0.002 0.001 1518.0 1885.0 1.00
line_area_norm[2] 0.724 0.177 0.398 1.040 0.006 0.003 811.0 1767.0 1.00
fwhm_norm[0] 2.093 0.512 1.153 3.020 0.017 0.008 907.0 1559.0 1.01
fwhm_norm[1] 0.793 0.040 0.719 0.868 0.001 0.001 2154.0 2492.0 1.00
fwhm_norm[2] 1.188 0.172 0.851 1.505 0.005 0.003 1065.0 2049.0 1.00
rms_observation_norm 0.501 0.016 0.471 0.531 0.000 0.000 4012.0 2812.0 1.00
line_area[0] 164.955 44.245 83.563 245.096 1.509 0.690 855.0 1528.0 1.01
line_area[1] 223.113 16.157 194.479 255.482 0.414 0.278 1518.0 1885.0 1.00
line_area[2] 144.807 35.410 79.565 207.919 1.254 0.548 811.0 1767.0 1.00
fwhm[0] 62.782 15.372 34.597 90.608 0.510 0.236 907.0 1559.0 1.01
fwhm[1] 23.777 1.197 21.566 26.031 0.026 0.018 2154.0 2492.0 1.00
fwhm[2] 35.653 5.162 25.516 45.154 0.158 0.089 1065.0 2049.0 1.00
velocity[0] 11.450 7.103 -1.238 25.439 0.220 0.122 1076.0 1820.0 1.00
velocity[1] -35.296 0.478 -36.211 -34.451 0.010 0.008 2519.0 2386.0 1.00
velocity[2] 56.427 2.063 52.458 60.155 0.058 0.036 1302.0 1911.0 1.00
amplitude[0] 2.467 0.240 2.036 2.923 0.004 0.003 3168.0 3381.0 1.00
amplitude[1] 8.812 0.404 8.084 9.619 0.009 0.006 2156.0 2960.0 1.00
amplitude[2] 3.774 0.545 2.766 4.680 0.018 0.007 1009.0 2302.0 1.00
rms_observation 1.002 0.032 0.942 1.063 0.001 0.001 4012.0 2812.0 1.00

With posteriors sampled with MCMC, we can also use leave-one-out cross-validation to perform model comparison.

[10]:
# leave-one-out cross validation
loo = az.compare({
    n_gauss: model.trace for n_gauss, model in opt.models.items()
    if model.solutions # ignore models with no valid solutions
})
loo
[10]:
rank elpd_loo p_loo elpd_diff weight se dse warning scale
3 0 -717.831101 11.490170 0.000000 9.386881e-01 14.599520 0.000000 False log
2 1 -739.711946 9.757259 21.880846 6.131194e-02 14.725176 7.180182 False log
1 2 -964.569764 8.267954 246.738664 2.671086e-12 22.670923 23.981153 False log

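The model with the largest elpd_loo (rank 0) is preferred. The row labels are the number of clouds, so here the n_clouds = 3 model is preferred, in agreement with the BIC-based optimization.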
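The elpd_diff column can be reproduced from elpd_loo by hand, which helps when reading the comparison table. The sketch below rebuilds part of the table with pandas from the values printed above. The column name `diff_over_dse` is a hypothetical addition: it expresses each difference in units of its standard error (dse), which is one informal way to judge whether a difference is meaningful.

```python
import pandas as pd

# Values copied from the az.compare output above, indexed by number of clouds
loo_table = pd.DataFrame(
    {
        "elpd_loo": [-717.831101, -739.711946, -964.569764],
        "dse": [0.0, 7.180182, 23.981153],
    },
    index=pd.Index([3, 2, 1], name="n_clouds"),
)

# Difference in elpd_loo relative to the top-ranked model
loo_table["elpd_diff"] = loo_table["elpd_loo"].max() - loo_table["elpd_loo"]

# Difference in units of its standard error (undefined for the best model)
nonzero_dse = loo_table["dse"].where(loo_table["dse"] > 0.0)
loo_table["diff_over_dse"] = loo_table["elpd_diff"] / nonzero_dse

print("preferred n_clouds:", loo_table["elpd_loo"].idxmax())
print(loo_table)
```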