# Copyright (c) 2015-2025 Patricio Cubillos and contributors.
# mc3 is open-source software under the MIT license (see LICENSE).
# Names exported by ``from <module> import *`` (note: this list also
# includes the private helpers _histogram and _pairwise).
__all__ = [
    'subplot',
    '_histogram',
    '_pairwise',
    'hist_2D',
    'Marginal',
    'Figure',
    'Posterior',
]
import copy
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import _pylab_helpers
from matplotlib.colors import is_color_like, to_rgba
from matplotlib.lines import Line2D
import scipy.interpolate as si
from .. import stats as ms
from .. import utils as u
from . import colors
def is_open(fig):
    """
    Check whether a figure is still open.

    A figure counts as open while one of pyplot's figure managers
    (matplotlib._pylab_helpers.Gcf) holds it; closed figures are
    removed from that registry.

    Parameters
    ----------
    fig: matplotlib.figure.Figure (or None)
        Figure to look for.

    Returns
    -------
    is_open: Bool
        True if fig is among the currently managed figures.
    """
    managed_figures = (
        manager.canvas.figure
        for manager in _pylab_helpers.Gcf.figs.values()
    )
    return fig in managed_figures
def subplot(rect, margin, pos, nx, ny=None, ymargin=None, dry=False):
    """
    Create an axis instance for one panel (with index pos) of a grid
    of nx-by-ny panels, where the grid is located inside rect (xleft,
    ybottom, xright, ytop).

    Parameters
    ----------
    rect: 1D List/ndarray
        Rectangle with xlo, ylo, xhi, yhi positions of the grid boundaries.
    margin: Float
        Width of margin between panels.
    pos: Integer
        Index of panel to create (as in plt.subplots): panel 1 is the
        top-left panel, panel 2 is to its right, and so on.
    nx: Integer
        Number of panels along the x axis.
    ny: Integer
        Number of panels along the y axis. If None, assume ny=nx.
    ymargin: Float
        Width of margin between panels along y axes (if None, adopt margin).
    dry: Bool
        If True, do not create an axes; only return its position.

    Returns
    -------
    axes: Matplotlib.axes.Axes
        An Axes instance at the specified position (if dry is False).
    position: List of floats
        If dry is True, the [xleft, ybottom, width, height] of the panel.
    """
    if ny is None:
        ny = nx
    if ymargin is None:
        ymargin = margin

    # Size of a panel (the grid size minus the margins between panels):
    Dx = rect[2] - rect[0]
    Dy = rect[3] - rect[1]
    dx = Dx/nx - (nx-1.0)* margin/nx
    dy = Dy/ny - (ny-1.0)*ymargin/ny

    # Position of panel pos:
    # Follow plt's scheme, where panel 1 is at the top left panel,
    # panel 2 is to the right of panel 1, and so on:
    xloc = (pos-1) % nx
    yloc = (ny-1) - ((pos-1) // nx)

    # Bottom-left corner of panel:
    xpanel = rect[0] + xloc*(dx+ margin)
    ypanel = rect[1] + yloc*(dy+ymargin)

    if dry:
        return [xpanel, ypanel, dx, dy]
    return plt.axes([xpanel, ypanel, dx, dy])
def hist_2D(posterior, ranges, nbins):
    """
    Construct the 2D histograms of all pairs of parameters.

    Parameters
    ----------
    posterior: 2D float ndarray
        Posterior samples, of shape [nsamples, npars].
    ranges: 1D iterable of (min, max) pairs, or None
        Histogram range of each parameter. A None entry (or ranges=None
        for all parameters) adopts the min/max of the samples.
    nbins: Integer
        Number of bins along each axis.

    Returns
    -------
    hist_xran: 2D float ndarray [npars, nbins+1]
        Bin edges of each parameter.
    hist: 4D float ndarray [npars-1, npars-1, nbins, nbins]
        hist[irow, icol] (for icol <= irow) is the histogram of parameter
        irow+1 (first axis) versus parameter icol (second axis); the
        remaining entries are zeros.
    lmax: 2D float ndarray [npars-1, npars-1]
        Maximum count of each histogram plus one.
    """
    npars = np.shape(posterior)[1]
    if ranges is None:
        ranges = [None] * npars

    # Column index matches par index, row index matches par index + 1
    hist_xran = np.zeros((npars, nbins+1))
    hist = np.zeros((npars-1, npars-1, nbins, nbins))
    lmax = np.zeros((npars-1, npars-1))
    for icol in range(npars-1):
        for irow in range(icol, npars-1):
            # Each dimension takes its own range; numpy auto-ranges any
            # dimension whose range is None. (Both ranges are always
            # honored, so the edges match the requested axis limits.)
            h, y, x = np.histogram2d(
                posterior[:,irow+1],
                posterior[:,icol],
                bins=nbins,
                range=[ranges[irow+1], ranges[icol]],
                density=False,
            )
            hist[irow, icol] = h
            # Record each parameter's edges once, from the first column:
            if icol == 0:
                hist_xran[irow+1] = y
                if irow == 0:
                    hist_xran[irow] = x
            lmax[irow, icol] = np.amax(h) + 1
    return hist_xran, hist, lmax
def _histogram(
    posterior, estimates, ranges, axes,
    nbins, pdf, xpdf, hpd_min, low_bounds, high_bounds,
    linewidth, theme, orientation, alpha=0.6,
    top_pad=1.05,
    clear=True,
):
    """
    Lowest-level routine to plot marginal posterior distributions.

    Parameters
    ----------
    posterior: 2D float ndarray
        Posterior samples, of shape [nsamples, npars].
    estimates: 1D iterable
        Estimate of each parameter, drawn as a dashed line (skipped
        where None).
    ranges: 1D iterable of (min, max) pairs
        Plotting range of each parameter. A None entry adopts the span
        of the histogram bins.
    axes: 1D iterable of Axes
        One axes per parameter.
    nbins: Integer
        Number of histogram bins.
    pdf, xpdf: 1D iterables of 1D arrays
        Marginal PDF of each parameter, evaluated at xpdf.
    hpd_min: 1D iterable or None
        If not None, shade the region where pdf >= hpd_min.
    low_bounds, high_bounds: 1D iterables or None
        Otherwise, if low_bounds is not None, shade between the bounds.
    linewidth: Float
        Line width of the histogram outline and estimate lines.
    theme: colors.Theme
        Provides the color, light_color, and dark_color of the plot.
    orientation: String
        'vertical' for parameter values along x; otherwise along y.
    alpha: Float
        Opacity of the filled areas.
    top_pad: Float
        Headroom factor above the histogram maximum.
    clear: Bool
        If True, clear each axes before plotting.
    """
    nsamples, npars = np.shape(posterior)
    has_credible_interval = (
        hpd_min is not None
        or low_bounds is not None
    )

    hist_kw = {
        'bins': nbins,
        'linewidth': linewidth,
        'orientation': orientation,
        'facecolor': to_rgba(theme.light_color, alpha=alpha),
        'edgecolor': theme.color,
        'histtype': 'stepfilled',
        'density': True,
    }
    if has_credible_interval:
        # Draw only the outline; the credible region is filled below:
        hist_kw['facecolor'] = 'none'

    for i in range(npars):
        ax = axes[i]
        if clear:
            ax.clear()
        if orientation == 'vertical':
            xax, yax = ax.xaxis, ax.yaxis
            fill_between = ax.fill_between
            axline = ax.axvline
        else:
            xax, yax = ax.yaxis, ax.xaxis
            fill_between = ax.fill_betweenx
            axline = ax.axhline

        vals, bins, _ = ax.hist(posterior[:,i], range=ranges[i], **hist_kw)
        # Without an explicit range, use the span of the histogram bins:
        if ranges[i] is None:
            xlo, xhi = bins[0], bins[-1]
        else:
            xlo, xhi = ranges[i]

        # Plot the credible intervals as shaded areas:
        if has_credible_interval:
            # Pad with empty bins at both ends so that the nearest-neighbor
            # interpolant (evaluated at bin centers) covers the full range:
            width = bins[1] - bins[0]
            vals = np.r_[0, vals, 0]
            bins = np.r_[bins[0] - width, bins]
            f = si.interp1d(bins+0.5*width, vals, kind='nearest')
            xran = (xpdf[i]>xlo) & (xpdf[i]<xhi)
            x_in = xpdf[i][xran]
            if hpd_min is not None:
                x_shade = pdf[i][xran] >= hpd_min[i]
            else:
                x_shade = (
                    (x_in >= low_bounds[i]) &
                    (x_in <= high_bounds[i])
                )
            fill_between(
                x_in, 0.0, f(x_in),
                where=x_shade,
                facecolor=theme.light_color,
                edgecolor='none',
                interpolate=False,
                alpha=alpha,
            )

        if estimates[i] is not None:
            axline(
                estimates[i],
                dashes=(9,2),
                lw=linewidth,
                color=theme.dark_color,
            )
        # Only ever grow the density axis (keeps room for overplots):
        ytop = top_pad * np.amax(vals)
        if ytop > yax.get_view_interval()[1]:
            yax.set_view_interval(0, ytop, ignore=True)
        xax.set_view_interval(xlo, xhi, ignore=True)
def _pairwise(
    hist, hist_xran, axes, ranges, estimates,
    palette, nlevels, absolute_dens, lmax,
    linewidth, theme, alpha=0.8, clear=True,
    swap_axes=False,
):
    """
    Lowest-level routine to plot pair-wise posterior distributions.
    (Everything happening inside the axes)

    Parameters
    ----------
    hist: 4D ndarray
        2D histograms; hist[irow, icol] pairs parameter irow+1 (first
        histogram axis) with parameter icol (presumably as built by
        hist_2D).
    hist_xran: 2D ndarray
        Bin edges of each parameter.
    axes: 2D array of Axes
        axes[irow, icol] is the panel for hist[irow, icol].
    ranges: 1D iterable of (min, max) pairs or None
        Axis limits of each parameter (left untouched where None).
    estimates: 1D iterable
        Estimates drawn as dashed lines (skipped where None).
    palette: Callable
        Maps values in [0, 1] (and an alpha) to an array of RGBA colors.
    nlevels: Integer
        Number of contour levels above the empty-bin level.
    absolute_dens: Bool
        If True, scale every panel to the largest value in lmax.
        The input lmax array itself is not modified.
    lmax: 2D ndarray
        Upper boundary of the contour levels of each panel.
    linewidth: Float
        Line width of the estimate lines.
    theme: colors.Theme
        Provides the color and dark_color used for edges and lines.
    alpha: Float
        Opacity of the contour colors.
    clear: Bool
        If True, clear each axes before plotting.
    swap_axes: Bool
        If True, transpose the panels (parameter irow+1 along x).
    """
    npars = len(ranges)
    # Use a common upper boundary for all panels if requested. Build a
    # new array so the caller's per-panel maxima survive for later calls:
    if absolute_dens:
        lmax = np.full(np.shape(lmax), np.amax(lmax))

    for icol in range(npars-1):
        for irow in range(icol, npars-1):
            ax = axes[irow,icol]
            if swap_axes:
                row, col = icol, irow+1
                histo = hist[col-1,row].T
            else:
                row, col = irow+1, icol
                histo = hist[row-1,col]
            if clear:
                ax.clear()
            extent = (
                hist_xran[col, 0],
                hist_xran[col, -1],
                hist_xran[row, 0],
                hist_xran[row, -1],
            )
            # First level (0 to 1 counts) marks empty bins: transparent.
            levels = np.zeros(nlevels+1)
            levels[1:] = np.linspace(1.0, lmax[irow,icol], nlevels)
            level_colors = palette(levels/lmax[irow,icol], alpha=alpha)
            level_colors[0,3] = 0.0
            level_colors[1,3] = 0.75*alpha
            cont = ax.contourf(
                histo,
                colors=level_colors, levels=levels,
                origin='lower', extent=extent,
            )
            edge_color = to_rgba(theme.color, alpha=0.25)
            edge_colors = [edge_color for _ in cont.get_paths()]
            edge_colors[0] = (1,1,1,0)
            cont.set_edgecolor(edge_colors)
            cont.set_linewidth(0.1)

            if estimates[col] is not None:
                ax.axvline(
                    estimates[col],
                    dashes=(9,2), lw=linewidth, color=theme.dark_color,
                )
            if estimates[row] is not None:
                ax.axhline(
                    estimates[row],
                    dashes=(9,2), lw=linewidth, color=theme.dark_color,
                )
            if ranges[col] is not None:
                ax.set_xlim(ranges[col])
            if ranges[row] is not None:
                ax.set_ylim(ranges[row])
def _plot_marginal(obj):
    """
    Re-draw everything except the data inside the axes (marginal plot).

    Sets the tick parameters and axis labels of every histogram panel,
    replaces the statistics texts (when obj.show_texts), and, when the
    axes were created automatically (obj.auto_axes), re-positions each
    panel on the obj.nx-by-obj.ny grid and labels the first column.
    """
    fig = obj.hist_axes[0].get_figure()

    # Hide the texts from a previous call before writing new ones:
    for previous in obj.stats_texts:
        previous.set_visible(False)
    obj.stats_texts = []

    is_vertical = obj.orientation == 'vertical'
    for i in range(obj.npars):
        ax = obj.hist_axes[i]
        # The value axis runs along the parameter, the density axis across:
        if is_vertical:
            value_axis, density_axis = ax.xaxis, ax.yaxis
        else:
            value_axis, density_axis = ax.yaxis, ax.xaxis

        ax.tick_params(
            labelsize=obj.fontsize-1, direction='in', left=False, top=True,
        )
        value_axis.set_label_text(obj.pnames[i], fontsize=obj.fontsize)
        density_axis.set_ticklabels([])
        if obj.show_texts:
            obj.stats_texts += colors.rainbow_text(
                ax,
                [rf'{obj.source.tex_estimates[i]}'],
                obj.fontsize-0.25,
                loc='inside',
            )

        # Layout is left alone for user-provided axes:
        if not obj.auto_axes:
            continue
        ax.set_position(subplot(
            obj.rect, obj.margin, i+1, obj.nx, obj.ny, obj.ymargin, dry=True,
        ))
        if i == 0:
            # Tick length: at most 3.5 pt, shorter for small panels.
            pt_to_pix = fig.canvas.get_renderer().points_to_pixels(72.0)
            size_pt = np.amin(ax.get_window_extent().size) / pt_to_pix * 72.0
            tick_size = np.amin([3.5, size_pt/15.0])
        ax.tick_params(length=tick_size)
        if i % obj.nx == 0:
            density_axis.set_label_text('Posterior', fontsize=obj.fontsize)
def _plot_pairwise(obj):
    """
    Re-draw everything except the data inside the axes (pairwise plot).

    Re-positions and labels the pair-wise panels, re-creates the colorbar
    in its existing axes (storing the new one in obj.colorbar), and, if
    obj.plot_marginal, positions and labels the marginal histograms along
    the diagonal (hiding them otherwise).
    """
    npars = obj.npars
    fig = obj.pair_axes[0,0].get_figure()
    # Tick length in points; refined below from the size of the first
    # panel (this value only remains if there are no pair-wise panels):
    tick_size = 3.5
    nx = npars + int(obj.plot_marginal) - 1
    for icol in range(npars-1):
        for irow in range(icol, npars-1):
            ax = obj.pair_axes[irow,icol]
            if obj.orientation == 'vertical':
                h = nx*irow + icol + 1 + npars*int(obj.plot_marginal)
            else:
                h = 2 + (nx+int(obj.plot_marginal))*icol + irow - icol
            ax_position = subplot(
                obj.rect, obj.margin, h, nx, nx, obj.ymargin, dry=True,
            )
            ax.set_position(ax_position)
            if icol==0 and irow==0:
                # Estimate size of axes (to set the length of the ticks)
                pt_to_pix = fig.canvas.get_renderer().points_to_pixels(72.0)
                axes_size_pix = np.amin(ax.get_window_extent().size)
                axes_size_pt = axes_size_pix / pt_to_pix * 72.0
                tick_size = np.amin([3.5, axes_size_pt/15.0])
            # Labels:
            ax.tick_params(
                labelsize=obj.fontsize-1,
                length=tick_size,
                direction='in',
            )
            if icol == 0:
                ax.set_ylabel(obj.pnames[irow+1], size=obj.fontsize)
            else:
                ax.set_yticklabels([])
            if irow == npars-2:
                ax.set_xlabel(obj.pnames[icol], size=obj.fontsize)
                plt.setp(ax.xaxis.get_majorticklabels(), rotation=90)
            else:
                ax.set_xticklabels([])

    # Re-draw the colorbar (re-using its axes):
    cbar_ax = obj.colorbar.ax
    dx = (obj.rect[2]-obj.rect[0])*0.03
    dy = (obj.rect[3]-obj.rect[1])*0.45
    cbar_ax.set_position([obj.rect[2]-dx, obj.rect[3]-dy, dx, dy])
    cbar_ax.clear()
    boundaries = np.linspace(0.0, 1.0, obj.nlevels)
    norm = mpl.colors.BoundaryNorm(boundaries, obj.nlevels)
    cmap = mpl.colors.ListedColormap(obj.palette(boundaries, alpha=0.8))
    mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    obj.colorbar = mpl.colorbar.Colorbar(
        ax=cbar_ax,
        mappable=mappable,
        boundaries=boundaries,
        ticks=np.linspace(0.0, 1.0, 6),
        ticklocation='left',
    )
    # Label the new colorbar, not the replaced one, so that the live
    # artist is the one flagged for re-drawing:
    obj.colorbar.set_label('Posterior density', fontsize=obj.fontsize)
    cbar_ax.tick_params(labelsize=obj.fontsize-1, direction='in')
    cbar_ax.minorticks_off()
    for col in cbar_ax.collections:
        col.set_edgecolor('face')
    cbar_ax.set_visible(obj.show_colorbar)

    # Histogram:
    for text in obj.stats_texts:
        text.set_visible(False)
    obj.stats_texts = []
    nx = npars
    for i in range(npars):
        ax = obj.hist_axes[i]
        ax.set_visible(obj.plot_marginal)
        if not obj.plot_marginal:
            continue
        if obj.orientation == 'vertical':
            xax, yax = ax.xaxis, ax.yaxis
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=90)
        else:
            xax, yax = ax.yaxis, ax.xaxis
        # Diagonal panel of an npars-by-npars grid:
        h = (npars+1)*i + 1
        ax_position = subplot(
            obj.rect, obj.margin, h, nx, nx, obj.ymargin, dry=True)
        ax.set_position(ax_position)
        ax.tick_params(
            labelsize=obj.fontsize-1,
            length=tick_size,
            direction='in', left=False, top=True,
        )
        # Only the last (bottom-right) histogram shows its value labels:
        if i == npars-1:
            xax.set_label_text(obj.pnames[i], fontsize=obj.fontsize)
        else:
            xax.set_label_text('', fontsize=obj.fontsize)
            xax.set_ticklabels([])
        yax.set_ticklabels([])

        if not obj.show_texts:
            continue
        elif i < npars-1:
            stats_text = rf'{obj.pnames[i]} = {obj.source.tex_estimates[i]}'
        else:
            stats_text = rf'{obj.source.tex_estimates[i]}'
        texts = [stats_text]
        obj.stats_texts += colors.rainbow_text(ax, texts, obj.fontsize)
class SoftUpdate:
    """
    Data descriptor for plot settings that only need a 'soft' re-draw.

    Setting the attribute stores the value on the instance (under a
    leading-underscore name) and on ``obj.source``; then, if the plot
    already has axes, re-draws everything except the data inside them.
    See https://docs.python.org/3/howto/descriptor.html
    """
    def __set_name__(self, owner, name):
        # Backing attribute, e.g.: 'fontsize' -> '_fontsize'
        self.private_name = f'_{name}'

    def __get__(self, obj, objtype=None):
        return getattr(obj, self.private_name)

    def __set__(self, obj, value):
        public_name = self.private_name[1:]
        setattr(obj, self.private_name, value)
        setattr(obj.source, public_name, value)
        # TBD: Inspect this clause
        # Objects with pair-wise axes are pairwise figures; otherwise
        # the object is a marginal-only plot:
        is_pairwise = hasattr(obj, 'pair_axes')
        plot_axes = obj.pair_axes if is_pairwise else obj.hist_axes
        if plot_axes is None:
            return
        if is_pairwise:
            _plot_pairwise(obj)
        else:
            _plot_marginal(obj)
        plt.draw()

    def raise_array_size_error(self, obj, value):
        """Raise a ValueError for a per-parameter input of wrong length."""
        public_name = self.private_name[1:]
        raise ValueError(
            f"Invalid {public_name} input. Array size "
            f"({len(value)}) does not match number of parameters "
            f"({obj.npars})"
        )
class SizeUpdate(SoftUpdate):
    """
    Figure-size descriptor.

    Stores the size as a tuple and, if a figure already exists, resizes
    it. Unlike SoftUpdate, the value is not mirrored onto obj.source.
    """
    def __set__(self, obj, value):
        size = tuple(value)
        setattr(obj, self.private_name, size)
        if obj.fig is None:
            return
        obj.fig.set_size_inches(*list(value))
class ThemeUpdate(SoftUpdate):
    """
    Theme descriptor.

    Accepts a colors.Theme, the name of a predefined theme (a key of
    colors.THEMES), or any matplotlib color (wrapped into a Theme);
    other values are stored unchanged. The value is stored here and on
    obj.source without re-drawing.
    """
    def __set__(self, obj, value):
        # TBD: add checks
        if not isinstance(value, colors.Theme):
            if isinstance(value, str) and value in colors.THEMES:
                value = colors.THEMES[value]
            elif is_color_like(value):
                value = colors.Theme(value)
        setattr(obj, self.private_name, value)
        setattr(obj.source, self.private_name[1:], value)
class BestpUpdate(SoftUpdate):
    """
    Best-fit-parameters descriptor.

    Ignored until obj.npars is defined. None becomes a list of npars
    None values; otherwise the length must match npars (ValueError).
    The value is stored here and on obj.source without re-drawing.
    """
    def __set__(self, obj, value):
        if hasattr(obj, 'npars'):
            bestp = [None] * obj.npars if value is None else value
            if len(bestp) != obj.npars:
                self.raise_array_size_error(obj, bestp)
            setattr(obj, self.private_name, bestp)
            setattr(obj.source, self.private_name[1:], bestp)
class StatsUpdate(SoftUpdate):
    """Store the statistics choice here and on obj.source, without re-drawing."""
    def __set__(self, obj, value):
        targets = (
            (obj, self.private_name),
            (obj.source, self.private_name[1:]),
        )
        for target, attr_name in targets:
            setattr(target, attr_name, value)
class QuantileUpdate(SoftUpdate):
    """Store the quantile here and on obj.source, without re-drawing."""
    def __set__(self, obj, value):
        public_name = self.private_name[1:]
        obj.__setattr__(self.private_name, value)
        obj.source.__setattr__(public_name, value)
class RangeUpdate(SoftUpdate):
    """
    Plotting-ranges descriptor.

    Ignored until obj.npars is defined. None (for the whole input or for
    individual entries) is replaced by the (min, max) of the posterior
    samples of the respective parameter (NaNs ignored). The length must
    match npars (ValueError). The value is stored here and on obj.source
    without re-drawing.
    """
    def __set__(self, obj, value):
        var_name = self.private_name[1:]
        if not hasattr(obj, 'npars'):
            return
        pmins = np.nanmin(obj.posterior, axis=0)
        pmaxs = np.nanmax(obj.posterior, axis=0)
        # Defaults:
        min_max = [(pmin, pmax) for pmin,pmax in zip(pmins, pmaxs)]
        if value is None:
            value = min_max
        if len(value) != obj.npars:
            self.raise_array_size_error(obj, value)
        # Fill missing entries into a new list instead of in place: this
        # leaves the caller's container untouched and also supports
        # immutable inputs (tuples). Inputs without missing entries keep
        # their identity, which the identity checks of the source's shared
        # attributes rely on to stop propagating the value back and forth.
        if any(ran is None for ran in value):
            value = [
                min_max[i] if ran is None else ran
                for i, ran in enumerate(value)
            ]
        setattr(obj, self.private_name, value)
        setattr(obj.source, var_name, value)
class Marginal():
    """
    A mostly-interactive marginal posterior plotting object.

    The descriptor attributes below store their values on this object;
    most of them also mirror the value onto ``source``, and the
    soft-update ones re-draw the labels/layout of an existing plot.
    """
    # Soft-update properties:
    pnames = SoftUpdate()
    rect = SoftUpdate()
    margin = SoftUpdate()
    ymargin = SoftUpdate()
    fontsize = SoftUpdate()
    figsize = SizeUpdate()
    show_texts = SoftUpdate()
    show_estimates = SoftUpdate()
    # Properties that require re-drawing:
    bestp = BestpUpdate()
    ranges = RangeUpdate()
    theme = ThemeUpdate()
    quantile = QuantileUpdate()
    statistics = StatsUpdate()

    def __init__(
        self, source, posterior, pnames, bestp, ranges, theme,
        nx=None, ny=None,
        statistics='med_central', quantile=0.683,
        bins=25, fontsize=11, linewidth=1.5,
        axes=None,
        show_texts=True, show_estimates=True,
    ):
        """
        Parameters
        ----------
        source: Object
            Holder of the statistics to plot (pdf, xpdf, hpd_min,
            estimates, bounds, tex_estimates); typically a Posterior.
        posterior: 2D float ndarray
            Posterior samples, of shape [nsamples, npars].
        pnames: 1D iterable of strings
            Parameter names.
        bestp: 1D iterable or None
            Best-fit parameter values.
        ranges: 1D iterable of (min, max) pairs or None
            Plotting ranges (None entries adopt the samples' min/max).
        theme: colors.Theme, theme name, or color
            Color theme of the plot.
        nx, ny: Integer
            Number of panel columns and rows. If only one is given, the
            other is the smallest value that fits all parameters; if
            neither is given, a default layout is adopted.
        statistics: String
            Statistics of the estimates and credible intervals.
        quantile: Float
            Quantile of the credible intervals.
        bins: Integer
            Number of histogram bins.
        fontsize: Float
            Font size of labels.
        linewidth: Float
            Line width of histograms and estimate lines.
        axes: None
            Currently unused (pass custom axes to plot() instead).
        show_texts: Bool
            Show the parameter-estimate texts.
        show_estimates: Bool
            Show the parameter estimates as dashed lines.
        """
        self.source = source
        self.fig = None
        self.hist_axes = None
        self.posterior = posterior
        nsamples, self.npars = np.shape(posterior)
        self.pnames = pnames
        self.bestp = bestp
        self.ranges = ranges
        self.theme = theme

        self.nx, self.ny = self._grid_shape(nx, ny)

        # Default layout sizes:
        dx0 = 0.4
        self.figsize = size = [
            dx0 + 1.45*self.nx,
            1.75*self.ny + 0.1
        ]
        self.margin = 0.04 / self.nx
        self.ymargin = 0.275 / self.ny
        self.rect = [
            dx0/size[0], self.ymargin, 1.0 - 0.2/size[0], 1.0 - 0.1/size[1],
        ]
        self.statistics = statistics
        self.quantile = quantile
        self.bins = bins
        self.fontsize = fontsize
        self.linewidth = linewidth
        self.orientation = 'vertical'
        self.show_texts = show_texts
        self.show_estimates = show_estimates
        self.stats_texts = []

    def _grid_shape(self, nx, ny):
        """Return the (nx, ny) panel grid, filling in unspecified values."""
        npars = self.npars
        if nx is not None and ny is not None:
            return nx, ny
        if ny is not None:
            # Fewest columns that fit all panels in ny rows:
            return 1 + (npars-1) // ny, ny
        if nx is None:
            # Default layout:
            if npars < 6:  # Single row, N columns
                nx = npars
            elif npars < 13:  # Two rows, up to 6 columns
                nx = (npars+1) // 2
            elif npars < 25:  # Six columns, up to 4 rows
                nx = 6
            elif npars < 56:  # 7 columns, up to 8 rows
                nx = 7
            else:
                nx = 8  # Stick with 8 columns from now on
        return nx, 1 + (npars-1) // nx

    def update(self):
        """Re-draw the plot from scratch with the current settings."""
        # Re-use user-provided axes (if any) rather than clearing their
        # whole figure and creating a new grid of axes:
        axes = None if getattr(self, 'auto_axes', True) else self.hist_axes
        self.plot(fignum=self.fignum, axes=axes)

    def plot(
        self, fignum=None, axes=None, quantile=None,
        savefile=None,
    ):
        """
        Marginal histogram plot.

        Parameters
        ----------
        fignum: Integer
            Number of the figure to (re)use when axes is None; the figure
            is cleared and resized.
        axes: 1D iterable of Axes
            Custom axes (one per parameter); their layout is left alone.
        quantile: Float
            If not None, set the credible-interval quantile (mirrored
            onto the source) before plotting.
        savefile: String
            If not None, save the figure to this file (dpi=300).
        """
        if quantile is not None:
            self.quantile = quantile

        # Create new figure unless explicitly point to existing axes:
        if axes is not None:
            self.hist_axes = axes
            self.fig = axes[0].get_figure()
            self.auto_axes = False
        else:
            self.fig = plt.figure(fignum)
            self.fig.set_size_inches(self.figsize)
            self.fig.clear()
            self.auto_axes = True  # False when user inputs custom axes
        self.fignum = self.fig.number
        self.figsize = self.fig.get_size_inches()

        if axes is None:
            self.hist_axes = np.tile(None, self.npars)
            for i in range(self.npars):
                self.hist_axes[i] = subplot(
                    self.rect, self.margin, i+1, self.nx, self.ny, self.ymargin,
                )

        # HPD thresholds only apply to the *_like statistics:
        if '_like' in self.statistics:
            hpd_min = self.source.hpd_min
        else:
            hpd_min = None
        estimates = self.source.estimates
        if not self.show_estimates:
            estimates = [None for _ in estimates]

        _histogram(
            self.posterior, estimates, self.ranges,
            self.hist_axes, self.bins,
            self.source.pdf, self.source.xpdf,
            hpd_min, self.source.low_bounds, self.source.high_bounds,
            self.linewidth, self.theme,
            self.orientation,
            top_pad=1.25,
        )
        _plot_marginal(self)
        if savefile is not None:
            self.fig.savefig(savefile, dpi=300)

    def overplot(self, post, labels=None, nlevels=4, alpha=0.4):
        """
        Overplot an additional posterior in the same figure.

        This method is still work in progress!
        Note that a call to self.update() or even soft updates
        will remove all/some of the overplot data. In such case
        the user would need to make a new call to self.overplot().
        It is also recommended to set show_estimates=False to
        prevent over-crowding the figures.

        Parameters
        ----------
        post: Posterior object
            Posterior to overplot (its posterior, statistics, estimates,
            ranges, pdf, credible bounds, theme, and orientation are
            used). Currently there are no checks that it has the same
            parameters as self; the user needs to make sure they are
            compatible.
        labels: 1D iterable of strings
            Currently unused.
        nlevels: Integer
            Currently unused.
        alpha: Float
            Currently unused.
        """
        if '_like' in post.statistics:
            hpd_min = post.hpd_min
        else:
            hpd_min = None
        estimates = post.estimates
        if not self.show_estimates:
            estimates = [None for _ in estimates]

        _histogram(
            post.posterior, estimates, post.ranges,
            self.hist_axes, self.bins,
            post.pdf, post.xpdf,
            hpd_min, post.low_bounds, post.high_bounds,
            self.linewidth, post.theme,
            post.orientation,
            top_pad=1.25,
            clear=False,
        )
class ShareUpdate:
    """
    Data descriptor for attributes shared with the object's figures.

    Setting a new value (re-setting the identical object is a no-op)
    stores it under a leading-underscore name and sets the same attribute
    on every figure in obj.figures that is still open; closed figures are
    removed from obj.figures.
    See https://docs.python.org/3/howto/descriptor.html
    """
    def __set_name__(self, owner, name):
        self.private_name = '_' + name

    def __get__(self, obj, objtype=None):
        return getattr(obj, self.private_name)

    def __set__(self, obj, value):
        private = self.private_name
        public = private[1:]
        if hasattr(obj, private) and value is getattr(obj, private):
            return
        setattr(obj, private, value)
        # Walk backwards so that popping closed figures keeps indices valid:
        for idx in range(len(obj.figures)-1, -1, -1):
            figure = obj.figures[idx]
            if is_open(figure.fig):
                setattr(figure, public, value)
            else:
                obj.figures.pop(idx)
class ShareTheme(ShareUpdate):
    """
    Shared theme descriptor.

    Normalizes the input (a colors.Theme, a key of colors.THEMES, or a
    matplotlib color) into a theme; any other value is kept as is. An
    equal theme is a no-op; otherwise the theme is shared with the open
    figures (closed figures are removed from obj.figures).
    """
    def __set__(self, obj, value):
        private = self.private_name
        public = private[1:]
        if isinstance(value, colors.Theme):
            theme = value
        elif isinstance(value, str) and value in colors.THEMES:
            theme = colors.THEMES[value]
        elif is_color_like(value):
            theme = colors.Theme(value)
        else:
            theme = value

        if hasattr(obj, private) and theme == getattr(obj, private):
            return
        setattr(obj, private, theme)
        for idx in range(len(obj.figures)-1, -1, -1):
            figure = obj.figures[idx]
            if is_open(figure.fig):
                setattr(figure, public, theme)
            else:
                obj.figures.pop(idx)
class StatisticsUpdate(ShareUpdate):
    """
    Shared descriptor for the attributes that define the parameter
    statistics (bestp, statistics, and quantile).

    Besides sharing the value with the open figures (closed ones are
    removed from obj.figures), once all three attributes are set it
    re-computes the HPD thresholds (only when needed), the estimates,
    the credible-interval bounds, and their LaTeX representation.
    """
    def __set__(self, obj, value):
        var_name = self.private_name[1:]
        priv_name = self.private_name
        if hasattr(obj, priv_name) and value is getattr(obj, priv_name):
            return
        setattr(obj, priv_name, value)
        for i in reversed(range(len(obj.figures))):
            fig = obj.figures[i]
            if not is_open(fig.fig):
                obj.figures.pop(i)
            else:
                setattr(fig, var_name, value)

        has_all_attributes = (
            hasattr(obj, 'bestp') and
            hasattr(obj, 'statistics') and
            hasattr(obj, 'quantile')
        )
        if not has_all_attributes:
            return

        # The HPD thresholds depend only on the posterior and the quantile;
        # skip the per-parameter cred_region() calls when only bestp or
        # statistics changed (unless the thresholds were never computed):
        stale_hpd = (
            var_name == 'quantile'
            or any(hpd is None for hpd in obj.hpd_min)
        )
        if stale_hpd:
            for i in range(obj.npars):
                _, _, obj.hpd_min[i] = ms.cred_region(
                    obj.posterior[:,i],
                    quantile=obj.quantile,
                )

        estimates, low_bounds, high_bounds = ms.marginal_statistics(
            obj.posterior, obj.statistics, obj.quantile,
            pdf=obj.pdf, xpdf=obj.xpdf,
        )
        # global_* statistics use the best-fit values as estimates:
        if obj.statistics.startswith('global_'):
            obj.estimates = obj.bestp
        else:
            obj.estimates = estimates
        obj.low_bounds = low_bounds
        obj.high_bounds = high_bounds
        obj.tex_estimates = u.tex_parameters(
            obj.estimates,
            obj.low_bounds,
            obj.high_bounds,
            significant_digits=2,
        )
class Posterior(object):
    """
    Collection of posterior plotting tools.

    Parameters
    ----------
    posterior: 2D float ndarray
        Posterior samples, of shape [nsamples, npars]. A 1D array is
        interpreted as the samples of a single parameter.
    pnames: 1D iterable of strings
        Parameter names (default: p00, p01, ...).
    bestp: 1D iterable
        Global best-fit parameter values (default: all None).
    ranges: 1D iterable of (min, max) pairs
        Plotting ranges of the parameters.
    statistics: String
        Statistics to use for parameter estimates and uncertainties:
        global_* use global best-fit (bestp) estimate.
        max_*: Marginal maximum-likelihood (mode) estimate.
        med_*: Marginal median estimate.
        *_like: HPD credible interval.
        *_central: Central quantile interval.
    quantile: Float
        Quantile of the credible intervals.
    sample_size: Integer
        If the posterior has more samples than this, plot a random
        sub-sample (without replacement) of this size.
    theme: colors.Theme, theme name, or color
        Color theme of the plots.
    orientation: String
        Orientation of the plots.
    show_texts, show_estimates, show_colorbar: Bool
        Show parameter-estimate texts, estimate lines, and colorbar.
    seed: Integer
        Seed of the random generator used for the sub-sampling.

    Examples
    --------
    >>> import numpy as np
    >>> import mc3
    >>> mcmc = np.load('MCMC_HD209458b_sing_0.29-2.0um_MM2017.npz')
    >>> posterior, zchain, zmask = mc3.utils.burn(mcmc)
    >>> pnames = mcmc['texnames']
    >>> bestp = mcmc['bestp']
    >>> p = mc3.plots.Posterior(posterior, pnames, bestp)
    >>> f1 = p.plot(savefile=f'pairwise_{6:02d}pars.png')
    >>> f2 = p.plot_histogram(savefile=f'histogram_{6:02d}pars.png')
    """
    # Soft-update properties:
    pnames = ShareUpdate()
    ranges = ShareUpdate()
    theme = ShareTheme()
    bestp = StatisticsUpdate()
    statistics = StatisticsUpdate()
    quantile = StatisticsUpdate()
    show_texts = ShareUpdate()
    show_estimates = ShareUpdate()
    show_colorbar = ShareUpdate()

    def __init__(
        self, posterior, pnames=None, bestp=None, ranges=None,
        statistics='med_central', quantile=0.683,
        sample_size=20000,
        theme='blue', orientation='vertical',
        show_texts=True, show_estimates=True,
        show_colorbar=True,
        seed=314159,
    ):
        self.figures = []
        # Enforce a 2D [nsamples, npars] posterior:
        posterior = np.asarray(posterior)
        if posterior.ndim == 1:
            posterior = posterior[:, np.newaxis]
        nsamples, self.npars = np.shape(posterior)

        rng = np.random.default_rng(seed)
        if sample_size < nsamples:
            sample = rng.choice(nsamples, sample_size, replace=False)
            sampled_posterior = posterior[sample]
        else:
            sampled_posterior = np.copy(posterior)
        self.posterior = sampled_posterior

        # Defaults:
        if pnames is None:
            pnames = [f'p{i:02d}' for i in range(self.npars)]

        self.pnames = pnames
        self.ranges = ranges
        self.theme = theme
        self.orientation = orientation
        self.show_texts = show_texts
        self.show_estimates = show_estimates
        self.show_colorbar = show_colorbar

        self.pdf = [None for _ in range(self.npars)]
        self.xpdf = [None for _ in range(self.npars)]
        self.hpd_min = [None for _ in range(self.npars)]
        for i in range(self.npars):
            self.pdf[i], self.xpdf[i], _ = ms.cred_region(
                self.posterior[:,i], quantile=quantile,
            )

        # These will trigger the param estimate calcs in StatisticsUpdate():
        if bestp is None:
            self.bestp = [None for _ in range(self.npars)]
        else:
            self.bestp = bestp
        self.statistics = statistics
        self.quantile = quantile

    def plot(
        self, plot_marginal=True,
        fignum=None, figure=None,
        quantile=None,
        linewidth=None, fontsize=None,
        figsize=None, rect=None,
        margin=None, ymargin=None,
        show_texts=None, show_estimates=None,
        show_colorbar=None,
        savefile=None,
    ):
        """
        Plot marginal histograms and pairwise posteriors.

        quantile, show_texts, show_estimates, and show_colorbar default
        to this object's settings; the other None arguments are passed
        as such to Figure. The new figure is appended to self.figures
        and returned.
        """
        # Defaults:
        if quantile is None:
            quantile = self.quantile
        if show_estimates is None:
            show_estimates = self.show_estimates
        if show_texts is None:
            show_texts = self.show_texts
        if show_colorbar is None:
            show_colorbar = self.show_colorbar

        fig = Figure(
            self,
            self.posterior, self.pnames, self.bestp,
            self.ranges, self.theme,
            rect=rect,
            margin=margin,
            ymargin=ymargin,
            statistics=self.statistics,
            quantile=quantile,
            plot_marginal=plot_marginal,
            linewidth=linewidth,
            fontsize=fontsize,
            figsize=figsize,
            fignum=fignum,
            show_texts=show_texts,
            show_estimates=show_estimates,
            show_colorbar=show_colorbar,
            # bins=25, nlevels=20,
        )
        self.figures.append(fig)
        fig.plot(figure=figure, savefile=savefile)
        return fig

    def plot_histogram(
        self, fignum=None, axes=None, quantile=None,
        nx=None, ny=None,
        savefile=None,
        show_texts=None, show_estimates=None,
    ):
        """
        Plot histogram of marginal posteriors.

        Parameters
        ----------
        fignum: Integer
            Number of the figure to (re)use when axes is None.
        axes: 1D iterable of Axes
            Custom axes to plot into (one per parameter).
        quantile: Float
            Credible-interval quantile; if None, use self.quantile.
            A different value is mirrored back onto this object.
        nx, ny: Integer
            Number of panel columns and rows.
        savefile: String
            If not None, save the figure to this file.
        show_texts, show_estimates: Bool
            If None, use this object's settings.

        Returns
        -------
        fig: Marginal
            The new plot object (also appended to self.figures).
        """
        # Forward this object's quantile (rather than letting Marginal
        # fall back to its own default, which would be mirrored back
        # onto this object and reset its statistics):
        if quantile is None:
            quantile = self.quantile
        if show_estimates is None:
            show_estimates = self.show_estimates
        if show_texts is None:
            show_texts = self.show_texts

        fig = Marginal(
            self,
            self.posterior, self.pnames, self.bestp,
            self.ranges, self.theme,
            statistics=self.statistics,
            quantile=quantile,
            nx=nx, ny=ny,
            show_texts=show_texts,
            show_estimates=show_estimates,
        )
        self.figures.append(fig)
        fig.plot(savefile=savefile, fignum=fignum, axes=axes)
        return fig

    def add():
        """TBD: Do not call this method."""
        pass

    def update(self, **kwargs):
        """
        Set attributes (given as keyword arguments) and re-draw every
        figure in self.figures.

        Shared attributes (pnames, ranges, theme, bestp, statistics,
        quantile, show_*) are propagated to the open figures when set.
        """
        # TBD: validate the keys (warn about attributes that are not
        # shared with the figures).
        for key, value in kwargs.items():
            setattr(self, key, value)
        for fig in self.figures:
            fig.update()

    def plot_trace():
        # TBD: not implemented.
        pass