# Copyright (c) 2015-2023 Patricio Cubillos and contributors.
# mc3 is open-source software under the MIT license (see LICENSE).
# Names exported by a star-import of this module (this includes the
# low-level _histogram and _pairwise helpers):
__all__ = [
    'subplot', '_histogram', '_pairwise', 'hist_2D',
    'Marginal', 'Figure', 'Posterior',
]
import copy
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import _pylab_helpers
from matplotlib.colors import is_color_like, to_rgba
from matplotlib.lines import Line2D
import scipy.interpolate as si
from .. import stats as ms
from .. import utils as u
from . import colors
def is_open(fig):
    """
    Tell whether a matplotlib figure is still open.

    Parameters
    ----------
    fig: matplotlib.figure.Figure
        The figure to look for.

    Returns
    -------
    is_open: Bool
        True if fig belongs to one of pyplot's active figure managers,
        False otherwise (e.g., the figure has been closed).
    """
    for manager in _pylab_helpers.Gcf.figs.values():
        candidate = manager.canvas.figure
        # Same test that list membership ('in') performs:
        if candidate is fig or candidate == fig:
            return True
    return False
def subplot(rect, margin, pos, nx, ny=None, ymargin=None, dry=False):
    """
    Create an axis instance for one panel (with index pos) of a grid
    of nx by ny panels, where the grid is located inside rect.

    Parameters
    ----------
    rect: 1D List/ndarray
        Rectangle with xlo, ylo, xhi, yhi positions of the grid boundaries
        (same units as the position handed to plt.axes).
    margin: Float
        Width of the margin between panels along the x axis.
    pos: Integer
        Index of the panel to create, counting from 1 at the top-left
        panel, then left to right and top to bottom (as in plt.subplots).
    nx: Integer
        Number of panels along the x axis.
    ny: Integer
        Number of panels along the y axis. If None, assume ny=nx.
    ymargin: Float
        Width of the margin between panels along the y axis
        (if None, adopt margin).
    dry: Bool
        If True, do not create the axes; only return its position.

    Returns
    -------
    axes: Matplotlib.axes.Axes or list
        An Axes instance at the specified position, or (when dry is True)
        its [left, bottom, width, height] position.
    """
    ny = nx if ny is None else ny
    ymargin = margin if ymargin is None else ymargin

    x_lo, y_lo, x_hi, y_hi = rect[0], rect[1], rect[2], rect[3]
    # Panel size, leaving room for the (n-1) inter-panel margins:
    width = (x_hi - x_lo)/nx - (nx-1.0)*margin/nx
    height = (y_hi - y_lo)/ny - (ny-1.0)*ymargin/ny

    # Panels are numbered row by row from the top, whereas the grid
    # rows are stacked from the bottom of rect:
    row, column = divmod(pos-1, nx)
    bottom_row = (ny-1) - row
    left = x_lo + column*(width+margin)
    bottom = y_lo + bottom_row*(height+ymargin)

    position = [left, bottom, width, height]
    if dry:
        return position
    return plt.axes(position)
def hist_2D(posterior, ranges, nbins):
    """
    Construct the 2D histograms of all pairs of parameters.

    Parameters
    ----------
    posterior: 2D float ndarray
        Posterior samples of shape (nsamples, npars).
    ranges: 1D iterable or None
        Per-parameter (min, max) histogram ranges.  Any entry (or the
        whole argument) can be None to use the samples' min and max.
    nbins: Integer
        Number of bins along each axis.

    Returns
    -------
    hist_xran: 2D float ndarray
        Bin edges of each parameter, shape (npars, nbins+1).
    hist: 4D float ndarray
        Counts of shape (npars-1, npars-1, nbins, nbins).  hist[irow,icol]
        is the histogram of parameter irow+1 (first axis) versus
        parameter icol (second axis); only entries with irow >= icol
        are filled.
    lmax: 2D float ndarray
        Maximum count of each histogram, plus one.
    """
    _, npars = np.shape(posterior)
    if ranges is None:
        ranges = [None] * npars
    # Column index matches par index, row index matches par index + 1
    hist_xran = np.zeros((npars, nbins+1))
    hist = np.zeros((npars-1, npars-1, nbins, nbins))
    lmax = np.zeros((npars-1, npars-1))
    for icol in range(npars-1):
        for irow in range(icol, npars-1):
            # Honor each parameter's range independently (numpy accepts
            # None for an individual dimension), so the y range is not
            # dropped just because the x parameter has no range:
            ran = [ranges[irow+1], ranges[icol]]
            if ran[0] is None and ran[1] is None:
                ran = None
            h, y, x = np.histogram2d(
                posterior[:,irow+1],
                posterior[:,icol],
                bins=nbins, range=ran, density=False,
            )
            hist[irow, icol] = h
            # The first column provides the edges of every parameter
            # 1..npars-1; the first panel also provides parameter 0's:
            if icol == 0:
                hist_xran[irow+1] = y
            if irow == 0 and icol == 0:
                hist_xran[irow] = x
            lmax[irow, icol] = np.amax(h) + 1
    return hist_xran, hist, lmax
def _histogram(
        posterior, estimates, ranges, axes,
        nbins, pdf, xpdf, hpd_min, low_bounds, high_bounds,
        linewidth, theme, orientation, alpha=0.6,
        top_pad=1.05,
        clear=True,
    ):
    """
    Lowest-level routine to plot marginal posterior distributions.

    Draw one density-normalized histogram per parameter (column of
    posterior) into the matching entry of axes, optionally shading a
    credible interval under it and marking a parameter estimate.

    Parameters
    ----------
    posterior: 2D ndarray
        Posterior samples, one column per parameter.
    estimates: 1D iterable
        Per-parameter value marked with a dashed line (None to skip).
    ranges: 1D iterable
        Per-parameter (min, max) histogram ranges.
    axes: 1D iterable of Axes
        Axes where each parameter's histogram is drawn.
    nbins: Integer
        Number of histogram bins.
    pdf, xpdf: 1D iterables of 1D ndarrays
        Per-parameter marginal density values and their sampling points,
        used to locate the region to shade.
    hpd_min: 1D iterable or None
        Per-parameter density threshold; shade where pdf >= hpd_min.
    low_bounds, high_bounds: 1D iterables or None
        Per-parameter interval to shade (used when hpd_min is None).
    linewidth: Float
        Width of the histogram edges and the estimate markers.
    theme: colors.Theme
        Provides the color, light_color, and dark_color used here.
    orientation: String
        'vertical' puts the sampled values along the x axis; any other
        value puts them along the y axis.
    alpha: Float
        Opacity of the histogram / shaded-area fill.
    top_pad: Float
        Extend the density axis up to top_pad times the histogram peak.
    clear: Bool
        If True, clear each axes before drawing.
    """
    _, npars = np.shape(posterior)
    shade_interval = hpd_min is not None or low_bounds is not None

    fill_color = to_rgba(theme.light_color, alpha=alpha)
    # With a shaded interval, the histogram itself is drawn unfilled:
    hist_kw = dict(
        bins=nbins,
        linewidth=linewidth,
        orientation=orientation,
        facecolor='none' if shade_interval else fill_color,
        edgecolor=theme.color,
        histtype='stepfilled',
        density=True,
    )

    vertical = orientation == 'vertical'
    for i in range(npars):
        ax = axes[i]
        if clear:
            ax.clear()
        if vertical:
            value_axis, density_axis = ax.xaxis, ax.yaxis
            shade, mark = ax.fill_between, ax.axvline
        else:
            value_axis, density_axis = ax.yaxis, ax.xaxis
            shade, mark = ax.fill_betweenx, ax.axhline

        counts, edges, _ = ax.hist(
            posterior[:,i], range=ranges[i], **hist_kw)

        if shade_interval:
            # Pad with zero-count bins at both ends and build a
            # nearest-neighbor step function through the bin centers,
            # so the shaded area follows the histogram outline:
            counts = np.r_[0, counts, 0]
            edges = np.r_[edges[0] - (edges[1]-edges[0]), edges]
            step = si.interp1d(
                edges+0.5*(edges[1]-edges[0]), counts, kind='nearest',
            )
            inside = (xpdf[i]>ranges[i][0]) & (xpdf[i]<ranges[i][1])
            x = xpdf[i][inside]
            if hpd_min is not None:
                where = pdf[i][inside] >= hpd_min[i]
            else:
                where = (x >= low_bounds[i]) & (x <= high_bounds[i])
            shade(
                x, 0.0, step(x),
                where=where,
                facecolor=theme.light_color,
                edgecolor='none',
                interpolate=False,
                alpha=alpha,
            )

        if estimates[i] is not None:
            mark(
                estimates[i],
                dashes=(9,2),
                lw=linewidth,
                color=theme.dark_color,
            )

        # Only ever grow the density axis (keeps room for the texts):
        ytop = top_pad * np.amax(counts)
        if ytop > density_axis.get_view_interval()[1]:
            density_axis.set_view_interval(0, ytop, ignore=True)
        value_axis.set_view_interval(*ranges[i], ignore=True)
def _pairwise(
        hist, hist_xran, axes, ranges, estimates,
        palette, nlevels, absolute_dens, lmax,
        linewidth, theme, alpha=0.8, clear=True,
    ):
    """
    Lowest-level routine to plot pair-wise posterior distributions.
    (Everything happening inside the axes.)

    Parameters
    ----------
    hist: 4D ndarray
        Pairwise histograms as returned by hist_2D(); hist[irow,icol]
        holds parameter irow+1 versus parameter icol.
    hist_xran: 2D ndarray
        Per-parameter bin edges as returned by hist_2D().
    axes: 2D array of Axes
        axes[irow,icol] is the panel for parameter irow+1 versus icol.
    ranges: 1D iterable
        Per-parameter (min, max) plotting ranges (None entries are
        left untouched).
    estimates: 1D iterable
        Per-parameter values marked with dashed lines (None to skip).
    palette: callable
        Colormap-like callable, called as palette(values, alpha=alpha).
    nlevels: Integer
        Number of filled-contour levels above the lowest one.
    absolute_dens: Bool
        If True, use the global maximum of lmax for every panel
        (note: lmax is modified in place).
    lmax: 2D ndarray
        Per-panel upper contour level, as returned by hist_2D().
    linewidth: Float
        Width of the estimate lines.
    theme: colors.Theme
        Provides color (contour edges) and dark_color (estimate lines).
    alpha: Float
        Opacity of the filled contours.
    clear: Bool
        If True, clear each axes before drawing.
    """
    npars = len(ranges)
    # Reset upper boundary to absolute maximum value if requested:
    if absolute_dens:
        lmax[:] = np.amax(lmax)
    edge_color = to_rgba(theme.color, alpha=0.65)
    for icol in range(npars-1):
        for irow in range(icol, npars-1):
            ax = axes[irow,icol]
            if clear:
                ax.clear()
            extent = (
                hist_xran[icol,0],
                hist_xran[icol,-1],
                hist_xran[irow+1,0],
                hist_xran[irow+1,-1],
            )
            # First level spans counts in [0, 1), i.e. empty bins:
            levels = np.zeros(nlevels+1)
            levels[1:] = np.linspace(1.0, lmax[irow,icol], nlevels)
            level_colors = palette(levels/lmax[irow,icol], alpha=alpha)
            # Empty bins fully transparent, the faintest level dimmed:
            level_colors[0,3] = 0.0
            level_colors[1,3] = 0.75*alpha
            cont = ax.contourf(
                hist[irow,icol],
                colors=level_colors, levels=levels,
                origin='lower', extent=extent,
            )
            for c in cont.collections:
                c.set_edgecolor(edge_color)
                c.set_linewidth(0.1)
            cont.collections[0].set_edgecolor((1,1,1,0))
            if estimates[icol] is not None:
                ax.axvline(
                    estimates[icol],
                    dashes=(9,2), lw=linewidth, color=theme.dark_color,
                )
            if estimates[irow+1] is not None:
                ax.axhline(
                    estimates[irow+1],
                    dashes=(9,2), lw=linewidth, color=theme.dark_color,
                )
            if ranges[icol] is not None:
                ax.set_xlim(ranges[icol])
            # The y axis shows parameter irow+1, so test that range:
            if ranges[irow+1] is not None:
                ax.set_ylim(ranges[irow+1])
def _plot_marginal(obj):
    """
    Re-draw everything except the data inside the axes.

    Refresh the decorations of the marginal-histogram panels of obj:
    tick styling, axis labels, and statistics texts.  Only when obj
    created its own axes (obj.auto_axes) are the panel positions, tick
    lengths, and 'Posterior' axis labels also set.

    Parameters
    ----------
    obj: Marginal-like object
        Reads npars, hist_axes, orientation, fontsize, pnames, show_texts,
        source.tex_estimates, auto_axes, rect, margin, nx, ny, and
        ymargin; replaces obj.stats_texts with the newly drawn texts.
    """
    npars = obj.npars
    ax = obj.hist_axes[0]
    fig = ax.get_figure()
    # Hide the previously drawn statistics texts (they are not removed):
    for text in obj.stats_texts:
        text.set_visible(False)
    obj.stats_texts = []
    for i in range(npars):
        ax = obj.hist_axes[i]
        # xax holds the parameter values, yax the posterior density:
        if obj.orientation == 'vertical':
            xax, yax = ax.xaxis, ax.yaxis
        else:
            xax, yax = ax.yaxis, ax.xaxis
        ax.tick_params(
            labelsize=obj.fontsize-1, direction='in', left=False, top=True,
        )
        xax.set_label_text(obj.pnames[i], fontsize=obj.fontsize)
        yax.set_ticklabels([])
        if obj.show_texts:
            texts = [rf'{obj.source.tex_estimates[i]}']
            obj.stats_texts += colors.rainbow_text(
                ax, texts, obj.fontsize-0.25, loc='inside',
            )
        # User-provided axes: leave their layout untouched.
        if not obj.auto_axes:
            continue
        ax_position = subplot(
            obj.rect, obj.margin, i+1, obj.nx, obj.ny, obj.ymargin, dry=True,
        )
        ax.set_position(ax_position)
        if i == 0:
            # Tick length: 1/15 of the (smallest) side of the first panel,
            # in points, capped at 3.5; reused for all panels.
            pt_to_pix = fig.canvas.get_renderer().points_to_pixels(72.0)
            axes_size_pix = np.amin(ax.get_window_extent().size)
            axes_size_pt = axes_size_pix / pt_to_pix * 72.0
            tick_size = np.amin([3.5, axes_size_pt/15.0])
        ax.tick_params(length=tick_size)
        # Only the left-most panel of each row gets the density label:
        if i%obj.nx == 0:
            yax.set_label_text('Posterior', fontsize=obj.fontsize)
def _plot_pairwise(obj):
    """
    Re-draw everything except the data inside the axes.

    Refresh the layout and decorations of a pairwise-posterior figure:
    positions, ticks, and labels of the pairwise panels; the colorbar;
    and, if obj.plot_marginal, the marginal-histogram panels on the
    diagonal together with their statistics texts.

    Parameters
    ----------
    obj: Figure-like object
        Reads npars, pair_axes, hist_axes, plot_marginal, rect, margin,
        ymargin, fontsize, pnames, colorbar, nlevels, palette,
        show_colorbar, orientation, show_texts, and source.tex_estimates;
        replaces obj.colorbar and obj.stats_texts.
    """
    npars = obj.npars
    # Estimate size of axes (to later set the length of the ticks)
    ax = obj.pair_axes[0,0]
    fig = ax.get_figure()
    # The pairwise panels fill an (npars-1)-square grid, one row/column
    # larger when the marginal histograms sit on its diagonal:
    nx = npars + int(obj.plot_marginal) - 1
    for icol in range(npars-1):
        for irow in range(icol, npars-1):
            ax = obj.pair_axes[irow,icol]
            h = nx*irow + icol + 1 + npars*int(obj.plot_marginal)
            ax_position = subplot(
                obj.rect, obj.margin, h, nx, nx, obj.ymargin, dry=True,
            )
            ax.set_position(ax_position)
            if icol==0 and irow==0:
                # Tick length: 1/15 of the first panel's smallest side
                # (in points), capped at 3.5:
                pt_to_pix = fig.canvas.get_renderer().points_to_pixels(72.0)
                axes_size_pix = np.amin(ax.get_window_extent().size)
                axes_size_pt = axes_size_pix / pt_to_pix * 72.0
                tick_size = np.amin([3.5, axes_size_pt/15.0])
            # Labels:
            ax.tick_params(
                labelsize=obj.fontsize-1,
                length=tick_size,
                direction='in',
            )
            if icol == 0:
                ax.set_ylabel(obj.pnames[irow+1], size=obj.fontsize)
            else:
                ax.set_yticklabels([])
            if irow == npars-2:
                ax.set_xlabel(obj.pnames[icol], size=obj.fontsize)
                plt.setp(ax.xaxis.get_majorticklabels(), rotation=90)
            else:
                ax.set_xticklabels([])

    # Re-draw the colorbar on its existing axes:
    cb_ax = obj.colorbar.ax
    dx = (obj.rect[2]-obj.rect[0])*0.03
    dy = (obj.rect[3]-obj.rect[1])*0.45
    cb_ax.set_position([obj.rect[2]-dx, obj.rect[3]-dy, dx, dy])
    cb_ax.clear()
    boundaries = np.linspace(0.0, 1.0, obj.nlevels)
    norm = mpl.colors.BoundaryNorm(boundaries, obj.nlevels)
    cmap = mpl.colors.ListedColormap(obj.palette(boundaries, alpha=0.8))
    mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    colorbar = mpl.colorbar.Colorbar(
        ax=cb_ax,
        mappable=mappable,
        boundaries=boundaries,
        ticks=np.linspace(0.0, 1.0, 6),
        ticklocation='left',
    )
    obj.colorbar = colorbar
    # Style the new colorbar (not the stale object it just replaced):
    colorbar.set_label('Posterior density', fontsize=obj.fontsize)
    colorbar.ax.tick_params(labelsize=obj.fontsize-1, direction='in')
    colorbar.ax.minorticks_off()
    for col in colorbar.ax.collections:
        col.set_edgecolor('face')
    colorbar.ax.set_visible(obj.show_colorbar)

    # Histogram:
    for text in obj.stats_texts:
        text.set_visible(False)
    obj.stats_texts = []
    nx = npars
    for i in range(npars):
        ax = obj.hist_axes[i]
        ax.set_visible(obj.plot_marginal)
        if not obj.plot_marginal:
            continue
        if obj.orientation == 'vertical':
            xax, yax = ax.xaxis, ax.yaxis
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=90)
        else:
            xax, yax = ax.yaxis, ax.xaxis
        # Diagonal panel of the npars-square grid:
        h = (npars+1)*i + 1
        ax_position = subplot(
            obj.rect, obj.margin, h, nx, nx, obj.ymargin, dry=True)
        ax.set_position(ax_position)
        ax.tick_params(
            labelsize=obj.fontsize-1,
            length=tick_size,
            direction='in', left=False, top=True,
        )
        if i == npars-1:
            xax.set_label_text(obj.pnames[i], fontsize=obj.fontsize)
        else:
            xax.set_label_text('', fontsize=obj.fontsize)
            xax.set_ticklabels([])
        yax.set_ticklabels([])
        if not obj.show_texts:
            continue
        elif i < npars-1:
            stats_text = rf'{obj.pnames[i]} = {obj.source.tex_estimates[i]}'
        else:
            stats_text = rf'{obj.source.tex_estimates[i]}'
        texts = [stats_text]
        obj.stats_texts += colors.rainbow_text(ax, texts, obj.fontsize)
class SoftUpdate:
    """
    Data descriptor for attributes whose change only needs a 'soft'
    re-draw (the decorations, not the plotted data).

    Setting the attribute stores the value in a private '_<name>'
    attribute, mirrors it onto obj.source, and, once the object has
    been plotted, re-draws its decorations and refreshes its figure.
    See https://docs.python.org/3/howto/descriptor.html
    """
    def __set_name__(self, obj, name):
        self.private_name = '_' + name

    def __get__(self, obj, objtype=None):
        value = getattr(obj, self.private_name)
        return value

    def __set__(self, obj, value):
        var_name = self.private_name[1:]
        setattr(obj, self.private_name, value)
        setattr(obj.source, var_name, value)
        # TBD: Inspect this clause
        if hasattr(obj, 'pair_axes'):
            if obj.pair_axes is None:
                return
            _plot_pairwise(obj)
        else:
            if obj.hist_axes is None:
                return
            _plot_marginal(obj)
        # plt.draw() only refreshes the *current* figure, which need not
        # be this object's one when several figures are open:
        fig = getattr(obj, 'fig', None)
        if fig is None:
            plt.draw()
        else:
            fig.canvas.draw_idle()

    def raise_array_size_error(self, obj, value):
        """Raise a ValueError: len(value) does not match obj.npars."""
        raise ValueError(
            f"Invalid {self.private_name[1:]} input. Array size "
            f"({len(value)}) does not match number of parameters "
            f"({obj.npars})"
        )
class SizeUpdate(SoftUpdate):
    """
    Descriptor for the figure size: store it as a tuple and, if a figure
    already exists, resize it right away (nothing is forwarded to the
    source object).
    """
    def __set__(self, obj, value):
        setattr(obj, self.private_name, tuple(value))
        figure = obj.fig
        if figure is None:
            return
        figure.set_size_inches(*list(value))
class ThemeUpdate(SoftUpdate):
    """
    Descriptor for the color theme.  Accepts a colors.Theme, the name of
    one of colors.THEMES, or a matplotlib color (wrapped into a Theme);
    anything else is stored unchanged.  No re-draw is triggered.
    """
    def __set__(self, obj, value):
        # TBD: add checks
        if not isinstance(value, colors.Theme):
            if isinstance(value, str) and value in colors.THEMES:
                value = colors.THEMES[value]
            elif is_color_like(value):
                value = colors.Theme(value)
        setattr(obj, self.private_name, value)
        setattr(obj.source, self.private_name[1:], value)
class BestpUpdate(SoftUpdate):
    """
    Descriptor for the best-fit parameters: None becomes a list of
    npars Nones; other values must have npars entries.  The assignment
    is ignored while obj has no npars attribute.
    """
    def __set__(self, obj, value):
        if not hasattr(obj, 'npars'):
            return
        bestp = [None] * obj.npars if value is None else value
        if len(bestp) != obj.npars:
            self.raise_array_size_error(obj, bestp)
        setattr(obj, self.private_name, bestp)
        setattr(obj.source, self.private_name[1:], bestp)
class StatsUpdate(SoftUpdate):
    """Descriptor for the statistics choice: store the value and mirror
    it onto obj.source, without triggering a re-draw."""
    def __set__(self, obj, value):
        private_name = self.private_name
        public_name = private_name[1:]
        setattr(obj, private_name, value)
        setattr(obj.source, public_name, value)
class QuantileUpdate(SoftUpdate):
    """Descriptor for the credible-interval quantile: store the value and
    mirror it onto obj.source, without triggering a re-draw."""
    def __set__(self, obj, value):
        name = self.private_name
        setattr(obj, name, value)
        # The source object uses the public (non-underscored) name:
        setattr(obj.source, name[1:], value)
class RangeUpdate(SoftUpdate):
    """
    Descriptor for the per-parameter plotting ranges.

    A None value (or a None entry) is replaced by the (min, max) of the
    parameter's posterior samples (ignoring NaNs).  When entries need to
    be filled in, a new list is built, so the caller's sequence is never
    modified and tuples are accepted.  Otherwise the input object itself
    is stored and forwarded (ShareUpdate-based descriptors skip
    re-propagation when they receive the very same object).
    The assignment is ignored while obj has no npars attribute.
    """
    def __set__(self, obj, value):
        var_name = self.private_name[1:]
        if not hasattr(obj, 'npars'):
            return
        pmins = np.nanmin(obj.posterior, axis=0)
        pmaxs = np.nanmax(obj.posterior, axis=0)
        # Defaults:
        min_max = [(pmin, pmax) for pmin, pmax in zip(pmins, pmaxs)]
        if value is None:
            value = min_max
        if len(value) != obj.npars:
            self.raise_array_size_error(obj, value)
        if any(ran is None for ran in value):
            value = [
                default if ran is None else ran
                for ran, default in zip(value, min_max)
            ]
        setattr(obj, self.private_name, value)
        setattr(obj.source, var_name, value)
class Marginal(object):
    """
    A mostly-interactive marginal posterior plotting object.

    Most parameter attributes are managed by descriptors that mirror
    their values onto the source object (e.g., the Posterior that
    created this plot); SoftUpdate attributes also refresh the plot
    decorations once the histograms have been drawn.
    """
    # Soft-update properties:
    pnames = SoftUpdate()
    rect = SoftUpdate()
    margin = SoftUpdate()
    ymargin = SoftUpdate()
    fontsize = SoftUpdate()
    figsize = SizeUpdate()
    show_texts = SoftUpdate()
    show_estimates = SoftUpdate()
    # Properties that require re-drawing:
    bestp = BestpUpdate()
    ranges = RangeUpdate()
    theme = ThemeUpdate()
    quantile = QuantileUpdate()
    statistics = StatsUpdate()

    def __init__(
            self, source, posterior, pnames, bestp, ranges, theme,
            nx=None, ny=None,
            statistics='med_central', quantile=0.683,
            bins=25, fontsize=11, linewidth=1.5,
            axes=None,
            show_texts=True, show_estimates=True,
        ):
        # These must exist before the descriptors below are triggered:
        self.source = source
        self.fig = None
        self.hist_axes = None
        self.posterior = posterior
        nsamples, self.npars = np.shape(posterior)
        self.pnames = pnames
        self.bestp = bestp
        self.ranges = ranges
        self.theme = theme
        if nx is None or ny is None:
            # Default layout:
            npars = self.npars
            if npars < 6:  # Single row, N columns
                nx = npars
            elif npars < 13:  # Two rows, up to 6 columns
                nx = (npars+1) // 2
            elif npars < 25:  # Six columns, up to 4 rows
                nx = 6
            elif npars < 56:  # 7 columns, up to 8 rows
                nx = 7
            else:
                nx = 8  # Stick with 8 columns from now on
            ny = 1 + (npars-1) // nx
        self.nx = nx
        self.ny = ny
        # Default layout sizes:
        dx0 = 0.4
        self.figsize = size = [
            dx0 + 1.45*self.nx,
            1.75*self.ny + 0.1
        ]
        self.margin = 0.04 / self.nx
        self.ymargin = 0.275 / self.ny
        self.rect = [
            dx0/size[0], self.ymargin, 1.0 - 0.2/size[0], 1.0 - 0.1/size[1],
        ]
        self.statistics = statistics
        self.quantile = quantile
        self.bins = bins
        self.fontsize = fontsize
        self.linewidth = linewidth
        self.orientation = 'vertical'
        self.show_texts = show_texts
        self.show_estimates = show_estimates
        self.stats_texts = []

    def update(self):
        """
        Re-draw the histograms on the same figure, reusing the
        user-provided axes when the plot was drawn on custom axes.
        """
        # TBD: Need to erase previous axes
        # Without the custom axes, plot() would clear their entire figure
        # and draw a new grid of axes on it:
        axes = None if getattr(self, 'auto_axes', True) else self.hist_axes
        self.plot(
            fignum=self.fignum,
            axes=axes,
        )

    def plot(
            self, fignum=None, axes=None, quantile=None,
            savefile=None,
        ):
        """
        Marginal histogram plot.

        Parameters
        ----------
        fignum: Integer
            Number of the figure to draw on (passed to plt.figure);
            ignored when axes is given.
        axes: 1D iterable of Axes
            Custom axes where to draw the histograms (one per parameter).
            If None, create a grid of axes on a cleared figure.
        quantile: Float
            Currently unused.
        savefile: String
            If not None, save the figure into this file (at 300 dpi).
        """
        # Create new figure unless explicitly point to existing axes:
        if axes is not None:
            self.hist_axes = axes
            self.fig = axes[0].get_figure()
            self.auto_axes = False
        else:
            self.fig = plt.figure(fignum)
            self.fig.set_size_inches(self.figsize)
            self.fig.clear()
            self.auto_axes = True  # False when user inputs custom axes
        self.fignum = self.fig.number
        self.figsize = self.fig.get_size_inches()

        if axes is None:
            self.hist_axes = np.tile(None, self.npars)
            for i in range(self.npars):
                self.hist_axes[i] = subplot(
                    self.rect, self.margin, i+1, self.nx, self.ny,
                    self.ymargin,
                )

        # HPD (density-threshold) shading only for the '*_like' statistics:
        if '_like' in self.statistics:
            hpd_min = self.source.hpd_min
        else:
            hpd_min = None
        estimates = self.source.estimates
        if not self.show_estimates:
            estimates = [None for _ in estimates]
        _histogram(
            self.posterior, estimates, self.ranges,
            self.hist_axes, self.bins,
            self.source.pdf, self.source.xpdf,
            hpd_min, self.source.low_bounds, self.source.high_bounds,
            self.linewidth, self.theme,
            self.orientation,
            top_pad=1.25,
        )
        _plot_marginal(self)
        if savefile is not None:
            self.fig.savefig(savefile, dpi=300)
class ShareUpdate:
    """
    Data descriptor for Posterior attributes that are shared with every
    figure the Posterior has produced.

    Setting a new value (by identity) stores it privately as '_<name>',
    drops the tracked figures that have been closed, and assigns the
    value to each figure that is still open.  Re-assigning the very
    same object is a no-op.
    See https://docs.python.org/3/howto/descriptor.html
    """
    def __set_name__(self, obj, name):
        self.private_name = '_' + name

    def __get__(self, obj, objtype=None):
        return getattr(obj, self.private_name)

    def __set__(self, obj, value):
        private = self.private_name
        public = private[1:]
        try:
            unchanged = value is getattr(obj, private)
        except AttributeError:
            unchanged = False
        if unchanged:
            return
        setattr(obj, private, value)
        # Walk backwards so closed figures can be popped in place:
        for index in range(len(obj.figures)-1, -1, -1):
            figure = obj.figures[index]
            if is_open(figure.fig):
                setattr(figure, public, value)
            else:
                obj.figures.pop(index)
class ShareTheme(ShareUpdate):
    """
    Shared descriptor for the color theme.  Accepts a colors.Theme, the
    name of one of colors.THEMES, or a matplotlib color (wrapped into a
    Theme).  Unlike ShareUpdate, an equal (==) theme is a no-op.
    """
    def __set__(self, obj, value):
        private = self.private_name
        if isinstance(value, colors.Theme):
            theme = value
        elif isinstance(value, str) and value in colors.THEMES:
            theme = colors.THEMES[value]
        elif is_color_like(value):
            theme = colors.Theme(value)
        else:
            theme = value
        if hasattr(obj, private) and theme == getattr(obj, private):
            return
        setattr(obj, private, theme)
        # Walk backwards so closed figures can be popped in place:
        for index in range(len(obj.figures)-1, -1, -1):
            figure = obj.figures[index]
            if is_open(figure.fig):
                setattr(figure, private[1:], theme)
            else:
                obj.figures.pop(index)
class StatisticsUpdate(ShareUpdate):
    """
    Shared descriptor for the attributes that define the parameter
    estimates (bestp, statistics, quantile).

    Besides sharing the new value with the open figures, once all three
    attributes are set it recomputes hpd_min, estimates, low_bounds,
    high_bounds, and tex_estimates on the Posterior.
    """
    def __set__(self, obj, value):
        private = self.private_name
        public = private[1:]
        if hasattr(obj, private) and value is getattr(obj, private):
            return
        setattr(obj, private, value)
        # Walk backwards so closed figures can be popped in place:
        for index in range(len(obj.figures)-1, -1, -1):
            figure = obj.figures[index]
            if is_open(figure.fig):
                setattr(figure, public, value)
            else:
                obj.figures.pop(index)

        required = ('bestp', 'statistics', 'quantile')
        if not all(hasattr(obj, name) for name in required):
            # Still initializing; estimates are computed once all exist.
            return
        self._update_estimates(obj)

    @staticmethod
    def _update_estimates(obj):
        """Recompute the credible regions and parameter estimates of obj."""
        for i in range(obj.npars):
            _, _, hpd = ms.cred_region(
                obj.posterior[:,i],
                quantile=obj.quantile,
            )
            obj.hpd_min[i] = hpd
        estimates, low_bounds, high_bounds = ms.marginal_statistics(
            obj.posterior, obj.statistics, obj.quantile,
            pdf=obj.pdf, xpdf=obj.xpdf,
        )
        # global_* statistics report the global best fit as estimate:
        if obj.statistics.startswith('global_'):
            obj.estimates = obj.bestp
        else:
            obj.estimates = estimates
        obj.low_bounds = low_bounds
        obj.high_bounds = high_bounds
        obj.tex_estimates = u.tex_parameters(
            obj.estimates,
            obj.low_bounds,
            obj.high_bounds,
            significant_digits=2,
        )
class Posterior(object):
    """
    Classification of posterior plotting tools.

    Parameters
    ----------
    posterior: 2D float array-like
        Posterior samples of shape (nsamples, npars).
    pnames: 1D string iterable
        Parameter names (default: 'p00', 'p01', ...).
    bestp: 1D iterable
        Global best-fit parameter values (used by the global_* statistics).
    ranges: 1D iterable
        Per-parameter (min, max) plotting ranges.
    statistics: String
        Statistics to use for parameter estimates and uncertainties:
        global_* use global best-fit (bestp) estimate.
        max_*: Marginal maximum-likelihood (mode) estimate.
        med_*: Marginal median estimate.
        *_like: HPD credible interval.
        *_central: Central quantile interval.
    quantile: Float
        Quantile of the credible intervals.
    sample_size: Integer
        Maximum number of posterior samples to keep; larger inputs are
        randomly sub-sampled (without replacement).
    theme: colors.Theme, string, or matplotlib color
        Color theme of the plots.
    orientation: String
        Orientation of the histograms.
    show_texts, show_estimates, show_colorbar: Bool
        Default display flags for the plots.
    seed: Integer
        Seed of the random sub-sampling.

    Examples
    --------
    >>> import numpy as np
    >>> import mc3
    >>> mcmc = np.load('MCMC_HD209458b_sing_0.29-2.0um_MM2017.npz')
    >>> posterior, zchain, zmask = mc3.utils.burn(mcmc)
    >>> pnames = mcmc['texnames']
    >>> bestp = mcmc['bestp']
    >>> p = mc3.plots.Posterior(posterior, pnames, bestp)
    >>> f1 = p.plot(savefile=f'pairwise_{6:02d}pars.png')
    >>> f2 = p.plot_histogram(savefile=f'histogram_{6:02d}pars.png')
    """
    # Soft-update properties:
    pnames = ShareUpdate()
    ranges = ShareUpdate()
    theme = ShareTheme()
    bestp = StatisticsUpdate()
    statistics = StatisticsUpdate()
    quantile = StatisticsUpdate()
    show_texts = ShareUpdate()
    show_estimates = ShareUpdate()
    show_colorbar = ShareUpdate()

    def __init__(
            self, posterior, pnames=None, bestp=None, ranges=None,
            statistics='med_central', quantile=0.683,
            sample_size=20000,
            theme='blue', orientation='vertical',
            show_texts=True, show_estimates=True,
            show_colorbar=True,
            seed=314159,
        ):
        self.figures = []
        # Accept any array-like input (lists used to fail when
        # sub-sampling), keeping ndarray subclasses as they are:
        posterior = np.asanyarray(posterior)
        if posterior.ndim != 2:
            raise ValueError(
                'Invalid posterior input, it must be a 2D array of shape '
                f'(nsamples, npars); got shape {posterior.shape}'
            )
        nsamples, self.npars = np.shape(posterior)
        rng = np.random.default_rng(seed)
        if sample_size < nsamples:
            sample = rng.choice(nsamples, sample_size, replace=False)
            sampled_posterior = posterior[sample]
        else:
            sampled_posterior = np.copy(posterior)
        self.posterior = sampled_posterior

        # Defaults:
        if pnames is None:
            pnames = [f'p{i:02d}' for i in range(self.npars)]
        self.pnames = pnames
        self.ranges = ranges
        self.theme = theme
        self.orientation = orientation
        self.show_texts = show_texts
        self.show_estimates = show_estimates
        self.show_colorbar = show_colorbar

        self.pdf = [None for _ in range(self.npars)]
        self.xpdf = [None for _ in range(self.npars)]
        self.hpd_min = [None for _ in range(self.npars)]
        for i in range(self.npars):
            pdf, xpdf, hpd = ms.cred_region(
                self.posterior[:,i], quantile=quantile,
            )
            self.pdf[i] = pdf
            self.xpdf[i] = xpdf
        # These will trigger the param estimate calcs in StatisticsUpdate():
        if bestp is None:
            self.bestp = [None for _ in range(self.npars)]
        else:
            self.bestp = bestp
        self.statistics = statistics
        self.quantile = quantile

    def plot(
            self, plot_marginal=True,
            fignum=None, figure=None,
            quantile=None,
            linewidth=None, fontsize=None,
            figsize=None, rect=None,
            margin=None, ymargin=None,
            show_texts=None, show_estimates=None,
            show_colorbar=None,
            savefile=None,
        ):
        """
        Plot marginal histograms and pairwise posteriors.

        Arguments left as None adopt this object's current settings (or
        the Figure defaults).  The new figure is tracked, so later
        attribute changes are shared with it, and is returned.
        """
        # Defaults:
        if quantile is None:
            quantile = self.quantile
        if show_estimates is None:
            show_estimates = self.show_estimates
        if show_texts is None:
            show_texts = self.show_texts
        if show_colorbar is None:
            show_colorbar = self.show_colorbar

        fig = Figure(
            self,
            self.posterior, self.pnames, self.bestp,
            self.ranges, self.theme,
            rect=rect,
            margin=margin,
            ymargin=ymargin,
            statistics=self.statistics,
            quantile=quantile,
            plot_marginal=plot_marginal,
            linewidth=linewidth,
            fontsize=fontsize,
            figsize=figsize,
            fignum=fignum,
            show_texts=show_texts,
            show_estimates=show_estimates,
            show_colorbar=show_colorbar,
            # bins=25, nlevels=20,
        )
        self.figures.append(fig)
        fig.plot(figure=figure, savefile=savefile)
        return fig

    def plot_histogram(
            self, fignum=None, axes=None, quantile=None,
            nx=None, ny=None,
            savefile=None,
            show_texts=None, show_estimates=None,
        ):
        """
        Plot histogram of marginal posteriors.

        Arguments left as None adopt this object's current settings.
        The new Marginal plot is tracked and returned.
        """
        # Forward the current quantile; otherwise the Marginal default
        # would be pushed back here, silently resetting this object's
        # quantile and recomputing the statistics:
        if quantile is None:
            quantile = self.quantile
        if show_estimates is None:
            show_estimates = self.show_estimates
        if show_texts is None:
            show_texts = self.show_texts

        fig = Marginal(
            self,
            self.posterior, self.pnames, self.bestp,
            self.ranges, self.theme,
            statistics=self.statistics,
            quantile=quantile,
            nx=nx, ny=ny,
            show_texts=show_texts,
            show_estimates=show_estimates,
        )
        self.figures.append(fig)
        fig.plot(savefile=savefile, fignum=fignum, axes=axes)
        return fig

    @staticmethod
    def add():
        """TBD: Do not call this method."""
        pass

    def update(self, **kwargs):
        """
        Set the given attributes (e.g., update(theme='red')), then call
        update() on every tracked figure.
        """
        for key, value in kwargs.items():
            # TBD: validate the keys (warn on unknown ones) and decide
            # which ones need a hard re-plot (e.g., bestp, ranges,
            # linewidth, theme, quantile, bins, nlevels).
            print(f'{key} = {value}')
        for key, value in kwargs.items():
            setattr(self, key, value)
        for fig in self.figures:
            fig.update()

    @staticmethod
    def plot_trace():
        """TBD: not implemented."""
        pass