Running Multiple Chains (Sequentially or in Parallel) in StochTree
Mixing of an MCMC sampler is a perennial concern for complex Bayesian models, and BART and BCF are no exception. One common way to address such concerns is to run multiple independent "chains" of an MCMC sampler. If each chain gets stuck in a different region of the posterior, their combined samples attain better coverage of the full posterior.
This idea works with the classic "root-initialized" MCMC sampler of Chipman et al. (2010). A key insight of He and Hahn (2023) and Krantsevich et al. (2023) is that the GFR algorithm may also be used to warm-start multiple chains of the BART / BCF MCMC sampler.
Operationally, the above two approaches have the same implementation (setting num_gfr > 0 if warm-start initialization is desired), so this vignette will demonstrate how to run a multi-chain sampler sequentially.
To begin, load stochtree and other relevant libraries
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
from sklearn.model_selection import train_test_split
from stochtree import BARTModel, BCFModel
Demo 1: Supervised Learning
Data Simulation
Simulate a simple partitioned linear model
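In equation form, write $x_1$ for X[:, 0] and $w$ for leaf_basis[:, 0]. The simulation code in this section draws all ten covariates and $w$ independently from Uniform(0, 1). Only $x_1$ enters the mean, so the other nine covariates are pure noise. The outcome is linear in $w$, with a slope that is piecewise constant across the quartiles of $x_1$:

$$
f(\mathbf{x}, w) = w \cdot c(x_1), \qquad
c(x_1) = \begin{cases}
-7.5 & 0 \le x_1 < 0.25 \\
-2.5 & 0.25 \le x_1 < 0.5 \\
\phantom{-}2.5 & 0.5 \le x_1 < 0.75 \\
\phantom{-}7.5 & 0.75 \le x_1 < 1
\end{cases}
$$

$$
y = f(\mathbf{x}, w) + \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, \sigma^2), \qquad \sigma = \widehat{\mathrm{sd}}(f) / 3
$$

Here $\widehat{\mathrm{sd}}(f)$ is the empirical standard deviation of f_XW over the simulated sample (np.std), and 3 is the snr setting.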
# Generate the data
random_seed = 1111
rng = np.random.default_rng(random_seed)
n = 500
p_x = 10
p_w = 1
snr = 3
X = rng.uniform(size=(n, p_x))
leaf_basis = rng.uniform(size=(n, p_w))
f_XW = (((0 <= X[:, 0]) & (0.25 > X[:, 0])) *
(-7.5 * leaf_basis[:, 0]) +
((0.25 <= X[:, 0]) & (0.5 > X[:, 0])) * (-2.5 * leaf_basis[:, 0]) +
((0.5 <= X[:, 0]) & (0.75 > X[:, 0])) * (2.5 * leaf_basis[:, 0]) +
((0.75 <= X[:, 0]) & (1 > X[:, 0])) * (7.5 * leaf_basis[:, 0]))
noise_sd = np.std(f_XW) / snr
y = f_XW + rng.normal(0, noise_sd, size=n)
# Split data into test and train sets
test_set_pct = 0.2
train_inds, test_inds = train_test_split(np.arange(n), test_size=test_set_pct, random_state=random_seed)
n_train = len(train_inds)
n_test = len(test_inds)
X_train = X[train_inds]
X_test = X[test_inds]
leaf_basis_train = leaf_basis[train_inds]
leaf_basis_test = leaf_basis[test_inds]
y_train = y[train_inds]
y_test = y[test_inds]
Sampling Multiple Chains Sequentially from Scratch
The simplest way to sample multiple chains of a stochtree model is to do so "sequentially": chain 1 is sampled first, then chain 2 is run independently from a fresh initial state, and so on for each requested chain. This is supported internally by both BARTModel.sample() and BCFModel.sample(), via the num_chains key of the general_params dictionary.
Define some high-level parameters, including the number of chains to run and the number of samples per chain. Here we run 4 independent chains. Each chain is initialized from root (num_gfr = 0), burned in for 1000 iterations, and then run for 2000 retained MCMC iterations. Setting num_gfr > 0 instead would warm-start each chain from a different grow-from-root sample. For example, with num_gfr = 5 and 4 chains, the chains would start from the last 4 of the 5 GFR samples.
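The following toy example illustrates what "sequential" means. It uses a simple random-walk Metropolis sampler on a standard normal target rather than BART. Each chain runs to completion from its own starting value before the next one begins, and the retained draws are concatenated chain by chain. The helper names are hypothetical, and this is not stochtree's internal implementation.

```python
import numpy as np


def random_walk_metropolis(log_density, init, num_burnin, num_mcmc, step, rng):
    """Run one random-walk Metropolis chain and return only the post-burn-in draws."""
    x = init
    log_p = log_density(x)
    draws = np.empty(num_mcmc)
    for i in range(num_burnin + num_mcmc):
        proposal = x + step * rng.normal()
        log_p_prop = log_density(proposal)
        # Accept with probability min(1, p(proposal) / p(x))
        if np.log(rng.uniform()) < log_p_prop - log_p:
            x, log_p = proposal, log_p_prop
        # Keep the state only once burn-in is over
        if i >= num_burnin:
            draws[i - num_burnin] = x
    return draws


def run_chains_sequentially(log_density, num_chains, num_burnin, num_mcmc, seed=0):
    """Run chains one after another and concatenate their draws chain by chain."""
    rng = np.random.default_rng(seed)
    chains = []
    for _ in range(num_chains):
        # Each chain gets its own dispersed starting value
        init = rng.normal(scale=5.0)
        chains.append(
            random_walk_metropolis(log_density, init, num_burnin, num_mcmc, step=1.0, rng=rng)
        )
    # Chain 1's draws come first, then chain 2's, and so on
    return np.concatenate(chains)


# Toy target: standard normal log density (up to an additive constant)
samples = run_chains_sequentially(lambda x: -0.5 * x**2, num_chains=4, num_burnin=500, num_mcmc=1000)
print(samples.shape)  # (4000,)
```

The flat output mirrors the storage layout stochtree uses for multi-chain samples, with all draws of one chain contiguous.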
num_chains = 4
num_gfr = 0
num_burnin = 1000
num_mcmc = 2000
Run the sampler
bart_model = BARTModel()
bart_model.sample(
X_train = X_train,
leaf_basis_train = leaf_basis_train,
y_train = y_train,
num_gfr = num_gfr,
num_burnin = num_burnin,
num_mcmc = num_mcmc,
general_params = {'num_chains' : num_chains}
)
Now we have a BARTModel object with num_chains * num_mcmc samples stored internally. These samples are arranged sequentially: the first num_mcmc samples correspond to chain 1, the next num_mcmc samples to chain 2, and so on.
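This chain-by-chain layout is why a plain row-major reshape separates the chains. Here is a small NumPy illustration with a stand-in array in place of real samples: flat index i belongs to chain i // num_mcmc at draw i % num_mcmc.

```python
import numpy as np

num_chains, num_mcmc = 4, 3

# Stand-in for a flat vector of draws stored chain by chain:
# values 0..2 belong to chain 0, 3..5 to chain 1, and so on.
flat_draws = np.arange(num_chains * num_mcmc)

# Row-major (C-order) reshape puts each chain's draws in its own row
by_chain = np.reshape(flat_draws, (num_chains, num_mcmc))
print(by_chain)
# [[ 0  1  2]
#  [ 3  4  5]
#  [ 6  7  8]
#  [ 9 10 11]]

# Equivalently, flat index i maps to (chain, draw) = divmod(i, num_mcmc)
chain, draw = divmod(7, num_mcmc)
print(chain, draw)  # 2 1
```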
Since each chain is a set of samples of the same model, we can analyze the samples collectively, for example, by looking at out-of-sample predictions.
y_hat_test = bart_model.predict(
X = X_test,
leaf_basis = leaf_basis_test,
type = "mean",
terms = "y_hat"
)
plt.scatter(y_hat_test, y_test)
plt.xlabel("Estimated conditional mean")
plt.ylabel("Actual outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
Now, suppose we want to analyze each of the chains separately to assess mixing / convergence.
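Because we know how the chain samples are arranged internally, we can reshape them into a (num_chains, num_mcmc) array. ArviZ interprets this array as (chain, draw), and from it we can compute various diagnostics: trace plots, effective sample size (ESS), R-hat, and autocorrelation plots.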
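To build intuition for what the R-hat diagnostic measures, here is the classic (non-split) Gelman-Rubin statistic. It takes a (num_chains, num_draws) array, the same layout as sigma2_samples_by_chain["sigma2"], and compares between-chain variance B with mean within-chain variance W. The statistic is $\hat{R} = \sqrt{((n-1)/n \cdot W + B/n) / W}$, and values near 1 suggest the chains agree. The function name is our own. Note that az.rhat defaults to a rank-normalized split R-hat, a more robust refinement, so its numbers will differ somewhat from this sketch.

```python
import numpy as np


def gelman_rubin_rhat(draws):
    """Classic (non-split) Gelman-Rubin R-hat for an array of shape (num_chains, num_draws)."""
    m, n = draws.shape
    chain_means = draws.mean(axis=1)
    # Between-chain variance B (scaled by n) and mean within-chain variance W
    B = n * chain_means.var(ddof=1)
    W = draws.var(axis=1, ddof=1).mean()
    # Pooled estimate of the marginal posterior variance
    var_hat = (n - 1) / n * W + B / n
    return np.sqrt(var_hat / W)


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 2000))                       # four chains sampling the same distribution
stuck = mixed + np.array([0.0, 0.0, 0.0, 3.0])[:, None]  # the fourth chain is shifted away
print(round(gelman_rubin_rhat(mixed), 3), round(gelman_rubin_rhat(stuck), 3))
```

The well-mixed chains give a value essentially equal to 1. Shifting a single chain inflates B and pushes R-hat well above the common 1.01 to 1.1 warning thresholds.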
sigma2_samples = bart_model.global_var_samples
sigma2_samples_by_chain = {"sigma2": np.reshape(sigma2_samples, (num_chains, num_mcmc))}
az.plot_trace(sigma2_samples_by_chain)
[Figure: ArviZ trace plot for sigma2 (per-chain density and trace panels)]
az.ess(sigma2_samples_by_chain)
<xarray.Dataset> Size: 8B
Dimensions: ()
Data variables:
sigma2 float64 8B 666.9
Attributes:
created_at: 2025-11-27T01:58:03.539083+00:00
arviz_version: 0.22.0
az.rhat(sigma2_samples_by_chain)
<xarray.Dataset> Size: 8B
Dimensions: ()
Data variables:
sigma2 float64 8B 1.009
Attributes:
created_at: 2025-11-27T01:58:03.723022+00:00
arviz_version: 0.22.0
az.plot_autocorr(sigma2_samples_by_chain)
[Figure: ArviZ autocorrelation plots of sigma2, one panel per chain (chains 0 to 3)]
Sampling Multiple Chains Sequentially from XBART Forests
In the example above, each chain was initialized from "root", meaning each tree in a forest was a single root node and all parameter values were set to a "default" starting point. If we sample a model using a small number of 'grow-from-root' iterations, we can use these forests to initialize MCMC chains.
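The following is a loose analogy for why the starting point matters. It is not GFR or BART, just a one-dimensional random-walk Metropolis sampler with a deliberately small step size. It compares a chain started far from the posterior mode (roughly what a "root" start can look like) with one started near it. The helper name and the N(10, 1) target are made up for illustration.

```python
import numpy as np


def metropolis_chain(init, num_iter, step, rng, target_mean=10.0):
    """Random-walk Metropolis targeting N(target_mean, 1); returns every visited state."""
    x = init
    draws = np.empty(num_iter)
    for i in range(num_iter):
        proposal = x + step * rng.normal()
        # Log acceptance ratio for a unit-variance normal target
        log_ratio = -0.5 * ((proposal - target_mean) ** 2 - (x - target_mean) ** 2)
        if np.log(rng.uniform()) < log_ratio:
            x = proposal
        draws[i] = x
    return draws


rng = np.random.default_rng(42)
cold = metropolis_chain(init=0.0, num_iter=200, step=0.1, rng=rng)   # start far from the mode
warm = metropolis_chain(init=10.0, num_iter=200, step=0.1, rng=rng)  # start near the mode
print(f"cold start mean: {cold.mean():.2f}, warm start mean: {warm.mean():.2f}")
```

The cold-started chain spends its early iterations drifting toward the mode, so those draws are biased and must be discarded as burn-in. The warm-started chain skips most of that transient. This is the intuition behind initializing MCMC chains from GFR forests.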
num_chains = 4
num_gfr = 5
num_burnin = 1000
num_mcmc = 2000
Run the initial GFR sampler
xbart_model = BARTModel()
xbart_model.sample(
X_train = X_train,
leaf_basis_train = leaf_basis_train,
y_train = y_train,
num_gfr = num_gfr,
num_burnin = 0,
num_mcmc = 0
)
xbart_model_json = xbart_model.to_json()
Run the multi-chain BART sampler, with each chain initialized from a different GFR forest
bart_model = BARTModel()
bart_model.sample(
X_train = X_train,
leaf_basis_train = leaf_basis_train,
y_train = y_train,
num_gfr = 0,
num_burnin = num_burnin,
num_mcmc = num_mcmc,
general_params = {'num_chains' : num_chains},
previous_model_json = xbart_model_json,
previous_model_warmstart_sample_num = num_gfr - 1
)
Now we have a BARTModel object with num_chains * num_mcmc samples stored internally. These samples are arranged sequentially: the first num_mcmc samples correspond to chain 1, the next num_mcmc samples to chain 2, and so on.
Since each chain is a set of samples of the same model, we can analyze the samples collectively, for example, by looking at out-of-sample predictions.
y_hat_test = bart_model.predict(
X = X_test,
leaf_basis = leaf_basis_test,
type = "mean",
terms = "y_hat"
)
plt.scatter(y_hat_test, y_test)
plt.xlabel("Estimated conditional mean")
plt.ylabel("Actual outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
Now, suppose we want to analyze each of the chains separately to assess mixing / convergence.
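Because we know how the chain samples are arranged internally, we can reshape them into a (num_chains, num_mcmc) array. ArviZ interprets this array as (chain, draw), and from it we can compute various diagnostics: trace plots, effective sample size (ESS), R-hat, and autocorrelation plots.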
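The az.ess diagnostic reports an effective sample size: roughly, how many independent draws the autocorrelated MCMC draws are worth. Below is a deliberately simplified multi-chain version. It averages the per-chain autocorrelations $\rho_k$, sums them up to the first non-positive lag, and returns $M N / (1 + 2\sum_k \rho_k)$ for M chains of N draws. This is not ArviZ's rank-normalized bulk ESS, so its numbers will not match az.ess exactly, and the function names are hypothetical.

```python
import numpy as np


def autocorr(x, max_lag):
    """Sample autocorrelation of a 1-D array for lags 0..max_lag."""
    x = x - x.mean()
    n = len(x)
    var = np.dot(x, x) / n
    return np.array([np.dot(x[: n - k], x[k:]) / n / var for k in range(max_lag + 1)])


def simple_ess(draws, max_lag=200):
    """Crude ESS for draws of shape (num_chains, num_draws).

    Averages per-chain autocorrelations, then sums them until the first
    non-positive value (a simplified truncation rule)."""
    num_chains, num_draws = draws.shape
    rho = np.mean([autocorr(chain, max_lag) for chain in draws], axis=0)
    tau = 1.0  # integrated autocorrelation time
    for k in range(1, max_lag + 1):
        if rho[k] <= 0:
            break
        tau += 2.0 * rho[k]
    return num_chains * num_draws / tau


rng = np.random.default_rng(1)
iid = rng.normal(size=(4, 2000))  # independent draws: ESS close to 8000
ar1 = np.zeros((4, 2000))
for t in range(1, 2000):
    # AR(1) chains with coefficient 0.9 are strongly autocorrelated
    ar1[:, t] = 0.9 * ar1[:, t - 1] + rng.normal(size=4)
print(round(simple_ess(iid)), round(simple_ess(ar1)))
```

For an AR(1) process with coefficient $\phi$, the theoretical autocorrelation time is $(1+\phi)/(1-\phi) = 19$ at $\phi = 0.9$. The second estimate should therefore land far below the 8000 raw draws.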
sigma2_samples = bart_model.global_var_samples
sigma2_samples_by_chain = {"sigma2": np.reshape(sigma2_samples, (num_chains, num_mcmc))}
az.plot_trace(sigma2_samples_by_chain)
[Figure: ArviZ trace plot for sigma2 (per-chain density and trace panels)]
az.ess(sigma2_samples_by_chain)
<xarray.Dataset> Size: 8B
Dimensions: ()
Data variables:
sigma2 float64 8B 350.2
Attributes:
created_at: 2025-11-27T01:58:34.424478+00:00
arviz_version: 0.22.0
az.rhat(sigma2_samples_by_chain)
<xarray.Dataset> Size: 8B
Dimensions: ()
Data variables:
sigma2 float64 8B 1.01
Attributes:
created_at: 2025-11-27T01:58:34.435651+00:00
arviz_version: 0.22.0
az.plot_autocorr(sigma2_samples_by_chain)
[Figure: ArviZ autocorrelation plots of sigma2, one panel per chain (chains 0 to 3)]
Demo 2: Causal Inference
Data Simulation
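Simulate a simple causal inference dataset. It has a binary treatment Z whose propensity depends on X[:, 0], a prognostic function of X[:, 0] and X[:, 2], and a treatment effect that is linear in X[:, 1].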
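In equation form, write $x_j$ for X[:, j-1], with all five covariates drawn independently from Uniform(0, 1). The simulation code in this section sets:

$$
\pi(\mathbf{x}) = 0.25 + 0.5\,x_1, \qquad Z \mid \mathbf{x} \sim \mathrm{Bernoulli}\big(\pi(\mathbf{x})\big)
$$

$$
\mu(\mathbf{x}) = 5\,\pi(\mathbf{x}) + 2\,x_3, \qquad \tau(\mathbf{x}) = 2\,x_2 - 1
$$

$$
Y = \mu(\mathbf{x}) + \tau(\mathbf{x})\,Z + \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, 1)
$$

A few consequences follow directly:

- Since $x_2 \sim \mathrm{Uniform}(0, 1)$, the average treatment effect is $\mathbb{E}[\tau(\mathbf{x})] = 2 \cdot \tfrac{1}{2} - 1 = 0$, even though $\tau$ ranges from $-1$ to $1$ across units.
- The propensity $\pi(\mathbf{x})$ stays within $[0.25, 0.75]$, so every unit has a non-trivial chance of landing in either treatment arm.
- Because $x_1$ enters both $\pi$ and $\mu$, it confounds the treatment-outcome relationship. This is one reason the propensity is passed to BCFModel.sample().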
# Generate the data
n = 1000
p_X = 5
X = rng.uniform(0, 1, (n, p_X))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)
mu_X = pi_X * 5 + 2 * X[:, 2]
tau_X = X[:, 1] * 2 - 1
epsilon = rng.normal(0, 1, n)
E_Y_XZ = mu_X + tau_X * Z
y = E_Y_XZ + epsilon
# Split data into test and train sets
test_set_pct = 0.2
train_inds, test_inds = train_test_split(np.arange(n), test_size=test_set_pct, random_state=random_seed)
n_train = len(train_inds)
n_test = len(test_inds)
X_train = X[train_inds]
X_test = X[test_inds]
Z_train = Z[train_inds]
Z_test = Z[test_inds]
E_Y_XZ_train = E_Y_XZ[train_inds]
E_Y_XZ_test = E_Y_XZ[test_inds]
propensity_train = pi_X[train_inds]
propensity_test = pi_X[test_inds]
y_train = y[train_inds]
y_test = y[test_inds]
Sampling Multiple Chains Sequentially from Scratch
The simplest way to sample multiple chains of a stochtree model is to do so "sequentially": chain 1 is sampled first, then chain 2 is run independently from a fresh initial state, and so on for each requested chain. This is supported internally by both BARTModel.sample() and BCFModel.sample(), via the num_chains key of the general_params dictionary.
Define some high-level parameters, including the number of chains to run and the number of samples per chain. Here we run 4 independent chains. Each chain is initialized from root (num_gfr = 0), burned in for 1000 iterations, and then run for 2000 retained MCMC iterations. Setting num_gfr > 0 instead would warm-start each chain from a different grow-from-root sample. For example, with num_gfr = 5 and 4 chains, the chains would start from the last 4 of the 5 GFR samples.
num_chains = 4
num_gfr = 0
num_burnin = 1000
num_mcmc = 2000
Run the sampler
bcf_model = BCFModel()
bcf_model.sample(
X_train = X_train,
Z_train = Z_train,
propensity_train = propensity_train,
y_train = y_train,
num_gfr = num_gfr,
num_burnin = num_burnin,
num_mcmc = num_mcmc,
general_params = {'num_chains' : num_chains}
)
Now we have a BCFModel object with num_chains * num_mcmc samples stored internally. These samples are arranged sequentially: the first num_mcmc samples correspond to chain 1, the next num_mcmc samples to chain 2, and so on.
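The same chain-by-chain ordering applies to any per-draw quantity, not just sigma2. The sketch below does not call stochtree. It assumes a hypothetical matrix of posterior draws with one row per observation and one column per retained draw (num_chains * num_mcmc columns, chain 1 first), and shows how to split the draw axis into chains with NumPy.

```python
import numpy as np

num_chains, num_mcmc, n_obs = 4, 3, 2

# Hypothetical (n_obs, num_chains * num_mcmc) matrix of posterior draws,
# with columns ordered chain by chain
draws = np.arange(n_obs * num_chains * num_mcmc, dtype=float).reshape(n_obs, num_chains * num_mcmc)

# Split the draw axis into (chain, draw); C-order keeps each chain's columns together
draws_by_chain = draws.reshape(n_obs, num_chains, num_mcmc)

# Per-chain posterior mean for each observation, shape (n_obs, num_chains)
per_chain_means = draws_by_chain.mean(axis=2)
print(per_chain_means)

# Move the observation axis last to get the (chain, draw, ...) layout ArviZ expects
az_ready = np.moveaxis(draws_by_chain, 0, -1)
print(az_ready.shape)  # (4, 3, 2)
```

Comparing per_chain_means across columns is a quick, informal check that the chains agree on each observation's prediction.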
Since each chain is a set of samples of the same model, we can analyze the samples collectively, for example, by looking at out-of-sample predictions.
y_hat_test = bcf_model.predict(
X = X_test,
Z = Z_test,
propensity = propensity_test,
type = "mean",
terms = "y_hat"
)
plt.scatter(y_hat_test, y_test)
plt.xlabel("Estimated conditional mean")
plt.ylabel("Actual outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
Now, suppose we want to analyze each of the chains separately to assess mixing / convergence.
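Because we know how the chain samples are arranged internally, we can reshape them into a (num_chains, num_mcmc) array. ArviZ interprets this array as (chain, draw), and from it we can compute various diagnostics: trace plots, effective sample size (ESS), R-hat, and autocorrelation plots.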
sigma2_samples = bcf_model.global_var_samples
sigma2_samples_by_chain = {"sigma2": np.reshape(sigma2_samples, (num_chains, num_mcmc))}
az.plot_trace(sigma2_samples_by_chain)
[Figure: ArviZ trace plot for sigma2 (per-chain density and trace panels)]
az.ess(sigma2_samples_by_chain)
<xarray.Dataset> Size: 8B
Dimensions: ()
Data variables:
sigma2 float64 8B 1.07e+03
Attributes:
created_at: 2025-11-27T02:00:31.448953+00:00
arviz_version: 0.22.0
az.rhat(sigma2_samples_by_chain)
<xarray.Dataset> Size: 8B
Dimensions: ()
Data variables:
sigma2 float64 8B 1.009
Attributes:
created_at: 2025-11-27T02:00:31.459872+00:00
arviz_version: 0.22.0
az.plot_autocorr(sigma2_samples_by_chain)
[Figure: ArviZ autocorrelation plots of sigma2, one panel per chain (chains 0 to 3)]
Sampling Multiple Chains Sequentially from XBCF Forests
In the example above, each chain was initialized from "root", meaning each tree in a forest was a single root node and all parameter values were set to a "default" starting point. If we sample a model using a small number of 'grow-from-root' iterations, we can use these forests to initialize MCMC chains.
num_chains = 4
num_gfr = 5
num_burnin = 1000
num_mcmc = 2000
Run the initial GFR sampler
xbcf_model = BCFModel()
xbcf_model.sample(
X_train = X_train,
Z_train = Z_train,
propensity_train = propensity_train,
y_train = y_train,
num_gfr = num_gfr,
num_burnin = 0,
num_mcmc = 0
)
xbcf_model_json = xbcf_model.to_json()
Run the multi-chain BCF sampler, with each chain initialized from a different GFR forest
bcf_model = BCFModel()
bcf_model.sample(
X_train = X_train,
Z_train = Z_train,
propensity_train = propensity_train,
y_train = y_train,
num_gfr = 0,
num_burnin = num_burnin,
num_mcmc = num_mcmc,
general_params = {'num_chains' : num_chains},
previous_model_json = xbcf_model_json,
previous_model_warmstart_sample_num = num_gfr - 1
)
Now we have a BCFModel object with num_chains * num_mcmc samples stored internally. These samples are arranged sequentially: the first num_mcmc samples correspond to chain 1, the next num_mcmc samples to chain 2, and so on.
Since each chain is a set of samples of the same model, we can analyze the samples collectively, for example, by looking at out-of-sample predictions.
y_hat_test = bcf_model.predict(
X = X_test,
Z = Z_test,
propensity = propensity_test,
type = "mean",
terms = "y_hat"
)
plt.scatter(y_hat_test, y_test)
plt.xlabel("Estimated conditional mean")
plt.ylabel("Actual outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
Now, suppose we want to analyze each of the chains separately to assess mixing / convergence.
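Because we know how the chain samples are arranged internally, we can reshape them into a (num_chains, num_mcmc) array. ArviZ interprets this array as (chain, draw), and from it we can compute various diagnostics: trace plots, effective sample size (ESS), R-hat, and autocorrelation plots.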
sigma2_samples = bcf_model.global_var_samples
sigma2_samples_by_chain = {"sigma2": np.reshape(sigma2_samples, (num_chains, num_mcmc))}
az.plot_trace(sigma2_samples_by_chain)
[Figure: ArviZ trace plot for sigma2 (per-chain density and trace panels)]
az.ess(sigma2_samples_by_chain)
<xarray.Dataset> Size: 8B
Dimensions: ()
Data variables:
sigma2 float64 8B 3.366e+03
Attributes:
created_at: 2025-11-27T02:02:29.284646+00:00
arviz_version: 0.22.0
az.rhat(sigma2_samples_by_chain)
<xarray.Dataset> Size: 8B
Dimensions: ()
Data variables:
sigma2 float64 8B 1.003
Attributes:
created_at: 2025-11-27T02:02:29.294459+00:00
arviz_version: 0.22.0
az.plot_autocorr(sigma2_samples_by_chain)
[Figure: ArviZ autocorrelation plots of sigma2, one panel per chain (chains 0 to 3)]