# Source code for mc3.stats.gelman

# Copyright (c) 2015-2023 Patricio Cubillos and contributors.
# mc3 is open-source software under the MIT license (see LICENSE).

# Public API of this module: only the Gelman-Rubin driver is exported.
__all__ = ["gelman_rubin"]

import sys
import numpy as np


def gelman_rubin(Z, Zchain, burnin):
    """
    Gelman--Rubin convergence test on a MCMC chain of parameters
    (Gelman & Rubin, 1992).

    Parameters
    ----------
    Z: 2D float ndarray
        A 2D array of shape (nsamples, npars) containing
        the parameter MCMC chains.
    Zchain: 1D integer ndarray
        A 1D array of length nsamples indicating the chain for each
        sample.  Samples with a negative chain index (pre-MCMC samples)
        are ignored.
    burnin: Integer
        Number of iterations to remove from the start of each chain.

    Returns
    -------
    GRfactor: 1D float ndarray
        The potential scale reduction factors of the chain for each
        parameter.  If they are much greater than 1, the chain is not
        converging.  An array of zeros is returned when the test cannot
        be computed: fewer than two chains, or some chain (including a
        chain index with no samples at all) has no samples left after
        the burn-in.
    """
    Z = np.asarray(Z)
    Zchain = np.asarray(Zchain)
    # Number of free parameters:
    npars = np.shape(Z)[1]
    # Number of chains (chain indices run from 0 to nchains-1);
    # initial=-1 makes an empty Zchain yield zero chains instead of raising:
    nchains = int(np.amax(Zchain, initial=-1)) + 1
    # The between-chain variance needs at least two chains:
    if nchains < 2:
        print("Not enough chains for Gelman-Rubin test.")
        return np.zeros(npars)

    # Count samples per chain, skipping pre-MCMC samples (negative index).
    # minlength makes a chain index with no samples count as zero rather
    # than being silently dropped:
    chain_ids = Zchain[Zchain >= 0].astype(int)
    nsamples = np.bincount(chain_ids, minlength=nchains) - burnin
    # Number of iterations (common chain length) after burn-in:
    niter = np.amin(nsamples)
    if niter < 1:
        print("Not enough samples for Gelman-Rubin test.")
        return np.zeros(npars)

    # Stack the first niter post-burn-in samples of every chain into a
    # (nchains, niter, npars) array:
    data = np.zeros((nchains, niter, npars))
    for c in range(nchains):
        good = np.flatnonzero(Zchain == c)[burnin:burnin+niter]
        data[c] = Z[good]

    # Calculate the PSRF for each parameter:
    GRfactor = np.zeros(npars)
    for i in range(npars):
        GRfactor[i] = psrf(data[:, :, i])
    return GRfactor
def psrf(chains):
    """
    Calculate the potential scale reduction factor (PSRF) of the
    Gelman and Rubin convergence test on a fitting parameter.

    Parameters
    ----------
    chains: 2D ndarray
        Array containing the chains for a single parameter.
        Shape must be (nchains, chainlen), with nchains >= 2 and
        chainlen >= 2.

    Returns
    -------
    rf: float
        The potential scale reduction factor, sqrt(V/W).

    Raises
    ------
    ZeroDivisionError
        If there is a single chain (between-chain variance undefined).
    """
    # Number of chains and length of each chain:
    nchains, chainlen = np.shape(chains)
    # W (within-chain variance): mean of the per-chain sample variances.
    # Gelman & Rubin (1992) use the unbiased (n-1 denominator) estimator,
    # hence ddof=1:
    W = np.mean(np.var(chains, axis=1, ddof=1))
    # B (between-chain variance), scaled by the chain length:
    means = np.mean(chains, axis=1)
    mmean = np.mean(means)
    B = (chainlen/(nchains-1.0)) * np.sum((means-mmean)**2)
    # V: pooled estimate of the posterior marginal variance:
    V = W*((chainlen - 1.0)/chainlen) + B*((nchains + 1.0)/(chainlen*nchains))
    # Potential scale reduction factor:
    rf = np.sqrt(V/W)
    return rf