# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
#IMPORTS
import os,re
import sys
from scipy import ndimage
import astropy.io.ascii as asciitable
from astropy.table import Table
import math
import numpy as np
np.random.seed(seed=1234)
# LOGGING CONFIGURATION
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('GalPaK: Plots')
from .galaxy_parameters import GalaxyParameters
from .hyperspectral_cube import HyperspectralCube as HyperCube
# OPTIONAL IMPORTS
try:
import bottleneck as bn
except ImportError:
logger.info("bottleneck (optional) not installed, performances will be degraded")
import numpy as bn
#matplotlib
import matplotlib as mpl
from mpl_toolkits.axes_grid import inset_locator as inl
# from matplotlib.ticker import MaxNLocator
from matplotlib import pyplot as plot
class Plots:
def plot_mcmc(self, filepath=None, plot_likelihood=False, adapt_range='5stdev', fontsize=10):
"""
Plot the MCMC chain details, and then either show it or save it to a file.
filepath: string
If specified, will write the plot to a file instead of showing it.
The file will be created at the provided filepath, be it absolute or relative.
The extension of the file must be either png or pdf.
plot_likelihood: bool
True to plot -log[chi2] instead
adapt_range: string
'boundaries' to adapt the range to boundaries
'minmax' [default] to adapt the range to min/max values
'3stdev' to adapt the range to 3 x stdev
'5stdev' to adapt the range to 5 x stdev
fontsize: int
to change the fontsize
"""
if self.chain is None:
raise RuntimeError(self.NO_CHAIN_ERROR)
if filepath is not None:
name, extension = os.path.splitext(filepath)
supported_extensions = ['.png', '.pdf']
if not extension in supported_extensions:
raise ValueError("Extension '%s' is not supported, use %s.",
extension, ', '.join(supported_extensions))
chain = self.chain.copy()
xmin = self.chain.xmin
xmax = self.chain.xmax
if self.method == 'chi_sorted':
chain.sort('reduced_chi')
chain_size = np.size(chain, 0)
xmax = int(chain_size * self.chain_fraction / 100.) # number of samples (last_fraction(%) of total)
xmin = 0
elif self.method == 'last':
chain_size = np.size(self.chain, 0)
xmin = chain_size - int(chain_size * self.chain_fraction / 100.) # number of samples (last_fraction(%) of total)
xmax = chain_size
elif self.method == 'chi_min':
min_chi_index = self._get_min_chi_index()
chain_size = np.size(self.chain, 0)
n = int(chain_size * self.chain_fraction / 100.)
xmin = np.max([0, min_chi_index - n / 2])
xmax = np.min([chain_size, min_chi_index + n / 2])
if fontsize is None:
fontsize = plot.rcParams['font.size']
fig = plot.figure(1, figsize=(16, 9))
names = chain.dtype.names
short_dict = self.galaxy.short_dict()
plot.clf() # clear current figure
plot.subplots_adjust(wspace=0.32, hspace=0.32,
bottom=0.05, top=0.95, left=0.05, right=0.95)
n = np.size(names)
rows = 3
cols = int(math.ceil((n-1) / (rows - 1.)))
for i, par in enumerate(names):
if i < n - 1:
# Parameters
plot.subplot2grid((rows, cols), (int(math.floor(i / cols)), i % cols))
plot.plot(chain[par].data, zorder=0)
plot.hlines(y=self.galaxy[i], xmin=xmin, xmax=xmax, color='r', label=r'$\hat P$', zorder=10)
plot.hlines(y=self.galaxy[i]+self.galaxy.stdev[i],xmin=xmin, xmax=xmax,color='k',lw=2, label=r'$\sigma$', zorder=10)
plot.hlines(y=self.galaxy[i]-self.galaxy.stdev[i], xmin=xmin, xmax=xmax, color='k',lw=2, zorder=10)
if self.galaxy.lower is not None:
plot.hlines(y=self.galaxy.lower[i], xmin=xmin, xmax=xmax, color='k', linestyles='dotted',lw=2, zorder=10, label='%d ' % (self.percentile) + '% CI')
if self.galaxy.upper is not None:
plot.hlines(y=self.galaxy.upper[i], xmin=xmin, xmax=xmax, color='k', linestyles='dotted',lw=2, zorder=10)
if i == n-2:
#plot.legend(loc=0, fontsize=13)
#plot.legend(loc=0, fontsize=13, bbox_to_anchor=(0.02, 0.3))
plot.legend(loc='lower center', fontsize=fontsize, bbox_to_anchor=(0.5, 0.0), ncol=3, fancybox=True, shadow=True)
if (adapt_range == '5stdev') and (self.galaxy.stdev is not None):
plot.ylim(self.galaxy[i]-5.*self.galaxy.stdev[i], self.galaxy[i]+5.*self.galaxy.stdev[i])
elif (adapt_range == '3stdev') and (self.galaxy.stdev is not None):
plot.ylim(self.galaxy[i]-3.*self.galaxy.stdev[i], self.galaxy[i]+3.*self.galaxy.stdev[i])
elif (adapt_range =='boundaries') and (self.min_boundaries is not None and self.max_boundaries is not None):
plot.ylim(self.min_boundaries[i], self.max_boundaries[i])
elif (adapt_range == 'minmax') or (adapt_range is None):
plot.ylim(np.min(chain[par]), np.max(chain[par]) )
title = short_dict[par]
title += "(%s)" % (self.galaxy.unit_dict()[par])
plot.title(title, fontsize=fontsize)
else:
# Last row, reduced ki
plot.subplot2grid((rows, cols), (rows - 1, 0), colspan=cols)
if (plot_likelihood):
plot.plot(np.exp(-chain['reduced_chi'] ))
plot.ylim([0,1])
plot.title(r'${\cal L}=\exp$[-$\chi^2$]')
else:
plot.plot(np.log10(chain['reduced_chi'] - np.min(chain['reduced_chi'])))
plot.title(r'$\log$ [$\chi^2 - \chi^2_{min}$]')
plot.plot(0,-1,color='grey',lw=2,label=r'$1\sigma$')
if self.percentile is not None:
plot.plot(0,-2,color='grey',ls='-.',lw=2,label='%.2f' % (self.percentile) +'% CI' )
plot.xticks(np.arange(2.) / 2 * np.size(chain['flux']))
fig.subplots_adjust(wspace=0.3)
if filepath is None:
plot.show()
else:
plot.savefig(filepath)
def plot_geweke(self, filepath=None, fontsize=10, Nsigma=3, Nintervals=25, full_chain=False):
"""
Plot the geweke score for each parameter, and then either show it or save it to a file.
https://pymc-devs.github.io/pymc/modelchecking.html
filepath: string
If specified, will write the plot to a file instead of showing it.
The file will be created at the provided filepath, be it absolute or relative.
The extension of the file must be either png or pdf.
fontsize: int [None default]
to change the fontsize for plot
if None will use rcParams
Nsigma: int [3 default]
sigma range for convergence test
Nintervals
Number of intervals for geweke statistics
full_chain: boolean
True if score for full chain
"""
if self.chain is None:
raise RuntimeError(self.NO_CHAIN_ERROR)
if filepath is not None:
name, extension = os.path.splitext(filepath)
supported_extensions = ['.png', '.pdf']
if not extension in supported_extensions:
raise ValueError("Extension '%s' is not supported, use %s.",
extension, ', '.join(supported_extensions))
chain = self.chain.copy()
#plot geweke has no meaning with chi_sorted
if self.method is not None and self.method != 'chi_sorted':
if fontsize is None:
fontsize = plot.rcParams['font.size']
fig = plot.figure(2, figsize=(16, 9))
names = chain.dtype.names
short_dict = self.galaxy.short_dict()
plot.clf() # clear current figure
plot.subplots_adjust(wspace=0.32, hspace=0.32,
bottom=0.05, top=0.95, left=0.05, right=0.95)
n = np.size(names)
score = np.zeros( len(self.galaxy.names), dtype='<f8')
rows = 3
cols = int(math.ceil((n-1) / (rows - 1.)))
for i, par in enumerate(self.galaxy.names):
#if i<n-1:
# Parameters
plot.subplot2grid((rows, cols), (int(math.floor(i / cols)), i % cols))
if i==0:
plot.ylabel('$\sigma$')
if full_chain is False:
zscore = self._geweke_score(self.sub_chain[par], intervals=Nintervals)
else:
zscore = self._geweke_score(self.chain[par], intervals=Nintervals)
_x = zscore[:,0]
_y = zscore[:,1]
if full_chain is False:
plot.plot(_x, _y, color='k', label='Sub-chain')
else:
plot.plot(_x, _y, color='k', label='Full-chain')
plot.plot(_x[0], _y[0], 'ks',ms=4,label=' Initial at 100\%')
plot.plot(_x[-1],_y[-1],'ko',ms=7,label=' Final at $>50$\%')
#in red masked array
_mask = np.ma.masked_inside(_y, -Nsigma, Nsigma)
score[i] = _mask.mask.sum()/Nintervals #should be 1 if True
if score[i]==1.0:
self.logger.info('Parameter %s has converged ? %s' % (par, score[i]))
else:
self.logger.warning('Parameter %s has not converged ? %s' % (par, score[i]))
plot.plot(_x,_mask, color='r')
plot.plot(_x[-1],_mask[-1],'ro',ms=7)
plot.ylim([-5,5])
plot.axhline(-Nsigma,ls='-.',lw=2.5)
plot.axhline(+Nsigma,ls='-.',lw=2.5, label=r'%s$\sigma$' % (Nsigma) )
plot.axhline(-1,ls='-.',lw=1)
plot.axhline(+1,ls='-.',lw=1,label=r'$1\sigma$')
title = short_dict[par]
title += "(%s)" % (self.galaxy.unit_dict()[par])
plot.title(title, fontsize=fontsize)
#elif i==n-1:
plot.xlabel('Start index')
plot.legend(loc='lower center', fontsize=fontsize, bbox_to_anchor=(0.0, -0.8), ncol=2, fancybox=True, shadow=True)
self.convergence=Table(score,names=self.galaxy.names)
#fig.subplots_adjust(wspace=0.3)
if filepath is None:
plot.show()
else:
plot.savefig(filepath)
else:
logger.warning('Plot geweke not allowed')
def plot_corner(self, filepath=None, smooth=None, nsigma=4, fontsize=10):
'''
using corner package to plot chain
filepath: string [default:None]
filename for saving the plot
smooth: [default: None]
smooth option from corner
nsigma: int [default:4]
N-sigma (nsigma x stdev) to determine the ranges for each plot
using the range option from corner
'''
if np.size(self.chain)>200:
try:
import corner
corner_true = True
except:
corner_true = False
self.logger.warning('plot_corner(): Need corner package for this')
if corner_true:
shorts_dict = self.galaxy.short_dict()
shorts_dict.pop('x') #remove x
shorts_dict.pop('y') #remove y
shorts_dict.pop('z') #remove z
params = shorts_dict.keys()
show = list(params)
s=show[0]
ls=[shorts_dict[s]] # also
#n sigma
r = [(self.galaxy[s] - self.galaxy.stdev[s] * nsigma, \
self.galaxy[s] + self.galaxy.stdev[s] * nsigma)]
best = list([self.galaxy[s]])
for s in show[1:]:
r.append((self.galaxy[s] - self.galaxy.stdev[s] * nsigma, \
self.galaxy[s] + self.galaxy.stdev[s] * nsigma))
best.append(self.galaxy[s])
ls.append(shorts_dict[s])
if fontsize is None:
fontsize = plot.rcParams['font.size']
#@fixme how to make onto fig=plot.figure(3)
corner.corner(self.sub_chain[show].as_array().tolist(),\
labels=ls, truths=best, range=r, \
smooth=smooth, label_kwargs={'fontname':'sans-serif', 'fontsize': fontsize})
if filepath is None:
plot.show()
else:
plot.savefig(filepath)
else:
raise Warning('package corner not installed')
else:
corner_true = False
raise self.logger.warning('plot_corner skipped, chain too small (<200)')
return corner_true
def plot_correlations(self, filepath=None, nbins=15, fontsize=10):
"""
Plot the correlations between parameters.
filepath: string
If specified, will write the plot to a file instead of showing it,
at the provided absolute or relative filepath.
The extension of the file must be either png or pdf.
Eg: filepath='/home/me/my_galaxy/galpak/run_correlations.png'
fontsize: int
to change the fontsize
"""
if self.chain is None or self.sub_chain is None:
raise RuntimeError(self.NO_CHAIN_ERROR)
if filepath is not None:
name, extension = os.path.splitext(filepath)
supported_extensions = ['.png', '.pdf']
if not extension in supported_extensions:
raise ValueError("Extension '%s' is not supported, use %s.",
extension, ', '.join(supported_extensions))
if fontsize is None:
fontsize = plot.rcParams['font.size']
fig=plot.figure(2, figsize=(16, 9))
plot.clf() # clear current figure
names = self.chain.dtype.names
short_dict = self.galaxy.short_dict()
# 'x', 'y', 'z', 'flux', 'radius', 'inclination', 'pa', 'rv',
# 'maximum_velocity', 'velocity_dispersion','reduced_chi'
idx_selection = np.arange(4,np.size(names)-1) #4:-1]
#plot.subplots_adjust(wspace=0.25, hspace=0.25, bottom=0.1, top=0.85, left=0.1, right=0.85)
rows = np.shape(idx_selection)[0] - 1
cols = np.shape(idx_selection)[0] - 1
for i in np.arange(rows):
for j in np.arange(i, cols):
jj = names[idx_selection[i] ]
ii = names[idx_selection[j + 1]]
xdat = self.sub_chain[ii].data
ydat = self.sub_chain[jj].data # chain[names[jj]]
x_edges = np.array([self.galaxy[ii] - self.stdev[ii] * 5, self.galaxy[ii] + self.stdev[ii] * 5.])
y_edges = np.array([self.galaxy[jj] - self.stdev[jj] * 5, self.galaxy[jj] + self.stdev[jj] * 5.])
extend = [x_edges[0], x_edges[1], y_edges[0], y_edges[1]]
#plot parameters
#plot.plot(xdat,ydat,lw=1)
# from top,left (y,x)
plot.subplot2grid((rows, cols), (i, j))
#plot.plot(xdat,ydat,'k.',lw=1,alpha=0.5)
#plot.axis(extend)
plot.axvline(self.galaxy[ii], ls='-', c='k')
plot.axvline(self.galaxy[ii] + self.galaxy.stdev[ii], ls='-.', c='grey')
plot.axvline(self.galaxy[ii] - self.galaxy.stdev[ii], ls='-.', c='grey')
plot.axhline(self.galaxy[jj], ls='-', c='k')
plot.axhline(self.galaxy[jj] + self.galaxy.stdev[jj], ls='-.', c='grey')
plot.axhline(self.galaxy[jj] - self.galaxy.stdev[jj], ls='-.', c='grey')
image2d, _edg = np.histogramdd((ydat, xdat), bins=nbins,
range=((y_edges[0], y_edges[1]), (x_edges[0], x_edges[1])))
plot.contour(image2d, vmin=np.min(image2d), vmax=bn.nanmax(image2d), extent=extend, color='k')
#plot.plot(xdat,ydat,'k.',alpha=0.5,lw=1)
if (x_edges[1] - x_edges[0])>0:
xticks = np.rint(2 * np.arange(4) / 4. * (x_edges[1] - x_edges[0]) + x_edges[0] * 2) / 2.0
plot.xticks(np.unique(xticks))
plot.xlim(x_edges)
if (y_edges[1] - y_edges[0])>0:
yticks = np.rint(2 * np.arange(4) / 4. * (y_edges[1] - y_edges[0]) + y_edges[0] * 2) / 2.0
plot.yticks(np.unique(yticks))
plot.ylim(y_edges)
if i == j:
plot.xlabel(short_dict[ii])
plot.ylabel(short_dict[jj])
fig.subplots_adjust(wspace=0.3)
if filepath is None:
plot.show()
else:
plot.savefig(filepath)
def plot_true_vfield(self, filepath=None, mask=None, contours=None,
fontsize=10, slitwidth=3):
"""
plot the 2d maps and the 1d along the major axis
mask: [optional] 2d nd-array
mask for display purposes
None: [default] default mask is flux>max(flux)/20.
1: not apply any mask
nd array: mask to be used
contours: [optional] 2d nd-array
external map to overlay
fontsize: int
to change the fontsize
slitwidth: 3 [default] the slit width used to extract 1d profile
"""
fmap = self.true_flux_map.data
vmap = self.true_velocity_map.data
smap = self.true_disp_map.data
try:
vmax = self.galaxy.maximum_velocity
except:
try:
vmax = self.galaxy.virial_velocity
except:
try:
vmax = self.galaxy.halo_velocity
except:
vmax = None
self.logger.warning("Galaxy has no Vmax Vvir halo_velocity")
if mask is None:
# default mask is flux>max(flux)/30.
mask = (fmap > bn.nanmax(fmap) / 50.) * (fmap != 0)
xc = self.galaxy.x
yc = self.galaxy.y
pixscale = self.instrument.xy_step
ny, nx = np.shape(fmap)
#matplotlib display tickmarks on half pixels
x0 = pixscale * (-xc -0.5)
x1 = pixscale * (nx - xc -0.5)
y0 = pixscale * (-yc -0.5)
y1 = pixscale * (ny - yc -0.5)
extent = [x0, x1, y0, y1]
if fontsize is None:
fontsize = plot.rcParams['font.size']
plot.figure(1)
plot.clf()
plot.subplot(2, 3, 1)
ax1=self._plot2dimage(fmap / mask,
vmin=0, vmax=np.nanmax(fmap),
xlabel=r'$\Delta \alpha$(")',
ylabel=r'$\Delta \delta$(")', interpolation='nearest',
extent=extent, contour=contours, title='True Flux map')
plot.subplot(2, 3, 2)
ax2=self._plot2dimage(vmap / mask,
vmin=np.nanmin(vmap), vmax=bn.nanmax(vmap),
xlabel=r'$\Delta \alpha$(")',
ylabel=r'$\Delta \delta$(")', interpolation='nearest',
extent=extent, contour=contours, title='True Vel. map')
plot.subplot(2, 3, 3)
s1 = bn.nanmax(smap)
try:
s0 = bn.nanmedian(smap) - 2 * bn.nanstd(smap)
except:
s0 = np.nanmean(smap) - 2 * np.nanstd(smap)
ax3=self._plot2dimage(smap / mask, vmin=s0, vmax=s1,
xlabel=r'$\Delta \alpha$(")',
ylabel=r'$\Delta \delta$(")',
extent=extent, contour=contours,
interpolation='nearest', title='True Disp. map')
xx, slice_v = self._slit(fmap, slitwidth, ax=ax1)
good = (slice_v != 0)
flux_profile = slice_v[good]
plot.subplot(2, 3, 4)
if self.error_maps:
fmap_err = self.true_flux_map_error.data
ax4=self._plot2dimage(fmap/fmap_err/mask,
vmin=0, vmax=np.nanmax(fmap/fmap_err),
xlabel=r'$\Delta \alpha$(")',
ylabel=r'$\Delta \delta$(")',
extent=extent, contour=fmap,
interpolation='nearest', title='True Fmap SNR')
else:
plot.plot(pixscale * xx[good], np.log10(slice_v[good]), 'k-')
plot.xlim([-3, 3])
plot.ylim([np.min(np.log10(slice_v[good])), bn.nanmax(np.log10(1.3 * slice_v[good]))])
plot.xlabel(r'$\delta x$')
plot.ylabel(r'$\log$ I(r)')
plot.subplot(2, 3, 5)
xx, slice_v = self._slit(vmap, slitwidth, ax=ax2)
rotation_curve=[pixscale * xx[good], slice_v[good], flux_profile]
if self.error_maps:
vmap_err = self.true_velocity_map_error.data
ax5=self._plot2dimage(abs(vmap/vmap_err)/mask,
vmin=0,
vmax=np.nanmax(abs(vmap/vmap_err)),
xlabel=r'$\Delta \alpha$(")',
extent=extent, contour=fmap,
interpolation='nearest', title='True Vmap SNR')
else:
plot.plot(pixscale * xx[good], slice_v[good], 'k-')
plot.xlim([-3, 3])
if vmax is not None:
plot.ylim([-vmax, vmax])
plot.axhline(vmax * np.sin(np.radians(self.galaxy.inclination)), color='k', ls='--')
plot.axhline(-vmax * np.sin(np.radians(self.galaxy.inclination)), color='k', ls='--')
plot.xlabel(r'$\delta x$')
plot.ylabel(r'$V_{los}(x)$')
plot.subplot(2, 3, 6)
xx, slice_v = self._slit(smap, slitwidth, ax=ax3)
if self.error_maps:
smap_err = self.true_disp_map_error.data
ax6=self._plot2dimage(smap / smap_err / mask,
vmin=0,
vmax=np.nanmax(smap / smap_err),
xlabel=r'$\Delta \alpha$(")',
extent=extent, contour=fmap,
interpolation='nearest', title='True Disp SNR')
else:
plot.plot(pixscale * xx[good], slice_v[good], 'k-')
if 'velocity_dispersion' in self.galaxy.short_dict().keys():
plot.axhline(self.galaxy.velocity_dispersion, ls='-.', label='Vturb')
plot.xlim([-3, 3])
plot.ylim([s0, s1])
plot.xlabel(r'$\delta x$')
plot.legend()
if filepath is None:
plot.show()
else:
plot.savefig(filepath)
rotation_name=re.sub('true_maps','true_Vrot',filepath[:-4])+'.dat'
asciitable.write(rotation_curve, \
output=rotation_name,Writer=asciitable.FixedWidth, \
names=['dx_arcsec','v_kms', 'flux_slit'])
def plot_obs_vfield(self, filepath=None, mask=None, contours=None,
fontsize=10, slitwidth=3):
"""
plot the observed 2d maps and the 1d along the major axis
mask: [optional] 2d nd-array
mask for display purposes
None: [default] default mask is flux>max(flux)/20.
1: not apply any mask
nd array: mask to be used
contours: [optional] 2d nd-array
external map to overlay
fontsize: int
to change the fontsize
"""
#Fmap, Vmap, Smap = self._make_moment_maps(self.convolved_cube, mask=True)
Fmap = self.obs_flux_map
Vmap = self.obs_velocity_map
Smap = self.obs_disp_map
fmap = Fmap.data
vmap = Vmap.data
smap = Smap.data
try:
vmax = self.galaxy.maximum_velocity
except:
try:
vmax = self.galaxy.virial_velocity
except:
try:
vmax = self.galaxy.halo_velocity
except:
vmax = None
self.logger.warning("Galaxy has no Vmax Vvir halo_velocity")
if mask is None:
# default mask is flux>max(flux)/30.
mask = (fmap > bn.nanmax(fmap) / 50.) * (fmap != 0)
xc = self.galaxy.x
yc = self.galaxy.y
pixscale = self.instrument.xy_step
## Matplotlib uses a grid with tickmarks on the middle of the pixels
## hence the -0.5
ny, nx = np.shape(fmap)
x0 = pixscale * (-xc -0.5)
x1 = pixscale * (nx - xc -0.5)
y0 = pixscale * (-yc -0.5)
y1 = pixscale * (ny - yc -0.5)
extent = [x0, x1, y0, y1]
if fontsize is None:
fontsize = plot.rcParams['font.size']
plot.figure(1)
plot.clf()
plot.subplot(2, 3, 1)
ax1=self._plot2dimage(fmap / mask,
vmin=0, vmax=np.nanmax(fmap),
xlabel=r'$\Delta \alpha$(")',
ylabel=r'$\Delta \delta$(")', interpolation='nearest',
extent=extent, contour=contours, title='Obs. Flux map')
plot.subplot(2, 3, 2)
ax2=self._plot2dimage(vmap / mask,
vmin=np.nanmin(vmap), vmax=bn.nanmax(vmap),
xlabel=r'$\Delta \alpha$(")',
ylabel=r'$\Delta \delta$(")', interpolation='nearest',
extent=extent, contour=contours, title='Obs. Vel. map')
plot.subplot(2, 3, 3)
s1 = bn.nanmax(smap)
try:
s0 = bn.nanmedian(smap) - 2 * bn.nanstd(smap)
except:
s0 = np.nanmean(smap) - 2 * np.nanstd(smap)
ax3=self._plot2dimage(smap / mask, vmin=s0, vmax=s1,
xlabel=r'$\Delta \alpha$(")',
ylabel=r'$\Delta \delta$(")',
extent=extent, contour=contours,
interpolation='nearest', title='Obs. Disp. map')
plot.subplot(2, 3, 4)
xx, slice_v = self._slit(fmap, slitwidth, ax=ax1)
good = (slice_v != 0)
plot.plot(pixscale * xx[good], np.log10(slice_v[good]), 'k-')
plot.xlim([-3, 3])
plot.ylim([np.min(np.log10(slice_v[good])), bn.nanmax(np.log10(1.3 * slice_v[good]))])
plot.xlabel(r'$\delta x$')
plot.ylabel(r'$\log$ I(r)')
plot.subplot(2, 3, 5)
xx, slice_v = self._slit(vmap, slitwidth, ax=ax2)
plot.plot(pixscale * xx[good], slice_v[good], 'k-')
plot.xlim([-3, 3])
if vmax is not None:
plot.ylim([-vmax, vmax])
plot.axhline(vmax * np.sin(np.radians(self.galaxy.inclination)), color='k', ls='--')
plot.axhline(-vmax * np.sin(np.radians(self.galaxy.inclination)), color='k', ls='--')
plot.xlabel(r'$\delta x$')
#p.ylabel(r'$V_{los}(x)$')
plot.subplot(2, 3, 6)
xx, slice_v = self._slit(smap, slitwidth, ax=ax3)
plot.plot(pixscale * xx[good], slice_v[good], 'k-')
if 'velocity_dispersion' in self.galaxy.short_dict().keys():
plot.axhline(self.galaxy.velocity_dispersion, ls='-.', label='Vturb')
plot.xlim([-3, 3])
plot.ylim([s0, s1])
plot.xlabel(r'$\delta x$')
plot.legend()
if filepath is None:
plot.show()
else:
plot.savefig(filepath)
def plot_images(self, filepath=None, z_crop=None):
"""
Plot a mosaic of images of the cropped (along z) cubes,
and then either show it or save it to a file.
filepath: string
If specified, will write the plot to a file instead of showing it.
The file will be created at the provided absolute or relative filepath.
The extension of the file must be either png or pdf.
z_crop: None|int
The maximum and total length of the crop (in pixels) along z,
centered on the galaxy's z position.
If you provide zero or an even value (2n),
the closest bigger odd value will be used (2n+1).
By default, will not crop.
"""
if self.chain is None:
raise RuntimeError(self.NO_CHAIN_ERROR)
if filepath is not None:
name, extension = os.path.splitext(filepath)
supported_extensions = ['.png', '.pdf']
if not extension in supported_extensions:
raise ValueError("Extension '%s' is not supported, "
"you may use one of %s",
extension, ', '.join(supported_extensions))
self._plot_images(self.cube, self.galaxy, self.convolved_cube,
self.deconvolved_cube, self.residuals_cube,
z_crop=z_crop)
if filepath is None:
plot.show()
else:
plot.savefig(filepath)
def _plot_images(self, cube, galaxy, convolved_cube, deconvolved_cube,
residuals_cube, z_crop=None):
"""
Plot a mosaic of images of the cropped (along z) cubes.
z_crop: None|int
The maximum and total length of the crop (in pixels) along z,
centered on the galaxy's z position.
If you provide zero or an even value (2n),
the closest bigger odd value will be used (2n+1).
By default, will not crop.
"""
if z_crop is None:
zmin = 0
zmax = cube.shape[0] - 1
else:
if not (z_crop & 1):
z_crop += 1
z0 = galaxy.z
zd = (z_crop - 1) / 2
zmin = max(0, z0 - zd)
zmax = z0 + zd + 1
fig = plot.figure(1, figsize=(16, 9))
plot.clf()
plot.subplots_adjust(wspace=0.25, hspace=0.25, bottom=0.05, top=0.95, left=0.05, right=0.95)
# MEASURE
sub = fig.add_subplot(2, 2, 1)
sub.set_title('Measured')
measured_cube_cropped = cube.data[zmin:zmax, :, :]
image_measure = (measured_cube_cropped.sum(0) / measured_cube_cropped.shape[0])
plot.imshow(image_measure, interpolation='nearest', origin='lower')
plot.xticks(fontsize=8)
plot.yticks(fontsize=8)
colorbar = plot.colorbar()
colorbar.ax.tick_params(labelsize=8)
# CONVOLVED
sub = fig.add_subplot(2, 2, 2)
sub.set_title('Convolved')
convolved_cube_cropped = convolved_cube.data[zmin:zmax, :, :]
image_convolved = (convolved_cube_cropped.sum(0) / convolved_cube_cropped.shape[0])
plot.imshow(image_convolved, interpolation='nearest', origin='lower')
plot.xticks(fontsize=8)
plot.yticks(fontsize=8)
colorbar = plot.colorbar()
colorbar.ax.tick_params(labelsize=8)
# DECONVOLVED
sub = fig.add_subplot(2, 2, 3)
sub.set_title('Deconvolved')
deconvolved_cube_cropped = deconvolved_cube.data[zmin:zmax, :, :]
image_deconvolved = (deconvolved_cube_cropped.sum(0) / deconvolved_cube_cropped.shape[0])
plot.imshow(image_deconvolved, interpolation='nearest', origin='lower')
plot.xticks(fontsize=8)
plot.yticks(fontsize=8)
colorbar = plot.colorbar()
colorbar.ax.tick_params(labelsize=8)
# ERROR
sub = fig.add_subplot(2, 2, 4)
sub.set_title('Error')
square_error_cube_cropped = residuals_cube.data[zmin:zmax, :, :]
nz = square_error_cube_cropped.shape[0]
#normalized error image in sigmas:
image_error = (square_error_cube_cropped.sum(0) / nz) * np.sqrt(nz)
#vmin = np.amin(image_measure)
#vmax = np.amax(image_measure)
#plot.imshow(image_error, interpolation='nearest', origin='lower', vmin=vmin, vmax=vmax)
plot.imshow(image_error, vmin=-2.5, vmax=2.5, interpolation='nearest', origin='lower')
plot.xticks(fontsize=8)
plot.yticks(fontsize=8)
colorbar = plot.colorbar()
colorbar.ax.tick_params(labelsize=8)
return fig
def _plot2dimage(self, image, vmin=-250., vmax=250., pos=None,
xlabel=None, ylabel=None, contour=None,
extent=None, title=None, interpolation=None):
"""
Plot a 2D image with colorbars.
"""
# Matplotlib's default origin is not ours, we need to use 'lower'
origin = 'lower'
if extent is not None:
plot.imshow(image, aspect='equal', vmin=vmin, vmax=vmax,
extent=extent, origin=origin,
interpolation=interpolation)
plot.axhline(0,color='k',lw=1,alpha=0.5)
plot.axvline(0,color='k',lw=1,alpha=0.5)
if contour is not None:
cmax = bn.nanmax(contour)
plot.contour(contour, extent=extent, color='k',
levels=[cmax / 10., cmax / 5., cmax / 2.],
origin=origin)
else:
plot.imshow(image, aspect='equal', vmin=vmin, vmax=vmax,
origin=origin)
plot.axhline(0,color='k',lw=1,alpha=0.5)
plot.axvline(0,color='k',lw=1,alpha=0.5)
if contour is not None:
cmax = bn.nanmax(contour)
plot.contour(contour, levels=[cmax / 10., cmax / 5., cmax / 2.],
color='k', origin=origin)
if title is not None:
plot.title(title)
ax = plot.gca()
v = ax.axis()
if pos is not None:
plot.plot(pos[0], pos[1], 'k+', ms=25, mew=3)
plot.axis(v)
if xlabel is not None:
plot.xlabel(xlabel)
if ylabel is not None:
plot.ylabel(ylabel)
#ax=p.gca()
# Add colorbar
cax = inl.inset_axes(
ax,
width="80%", # width = 10% of parent_bbox width
height="10%", # height : 50%
#borderpad=1,
# bbox_to_anchor=(1,0,1,2),
# bbox_transform=ax.transAxes,
loc=1)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
if vmax > 1:
tick_range = np.rint(np.r_[vmin:vmax:5j])
else:
tick_range = (np.r_[vmin:vmax:5j])
# matplotlib will crash and burn if ticks are not unique
tick_range = np.unique(tick_range)
cb = mpl.colorbar.ColorbarBase(cax, cmap=plot.cm.jet,
norm=norm, orientation='horizontal',
ticks=tick_range[1:-1])
#cax.xaxis.set_major_locator( MaxNLocator(nbins = 6) )
return ax
def _slit(self, mapdata, slit_width, reshape=False, ax=None):
"""
Create 1d profile along the major axis.
Note that PA = 0 is vertical, anti-clockwise from y-axis.
"""
#orig code
#ang=np.radians(PA)
#prof=np.where(s<=(slit_width/2.0+0.5),data,0.)
#profile=ndimage.rotate(prof,PA,mode='constant',order=1).sum(1)
#profile=profile/np.max(profile)
y, x = np.indices(mapdata.shape)
ang = -np.radians(self.galaxy.pa - 90)
dx = x - self.galaxy.x
dy = y - self.galaxy.y
dx_p = dx * np.cos(ang) - dy * np.sin(ang)
dy_p = dx * np.sin(ang) + dy * np.cos(ang)
if (ax!=None):
#plot slit
#axis labels in arcsec
xlim=ax.get_xlim()
ylim=ax.get_ylim()
pixscale = self.instrument.xy_step
xar= (x[0] - self.galaxy.x) * pixscale
ang = np.radians(self.galaxy.pa + 90.)
xx= xar * np.cos(ang) - (slit_width/2.) * pixscale * np.sin(ang)
yy= xar * np.sin(ang) + (slit_width/2.) * pixscale * np.cos(ang)
ax.plot(xx,yy,'-',lw=2,c='grey',alpha=0.6)
xx= xar * np.cos(ang) + (slit_width/2.) * pixscale * np.sin(ang)
yy= xar * np.sin(ang) - (slit_width/2.) * pixscale * np.cos(ang)
ax.plot(xx,yy,'-',lw=2,c='grey',alpha=0.6)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
slit_vert = np.abs(dy_p) # used for PA=0
slit_horiz = dx_p # distance along the slit
slit_data = np.where(slit_vert <= (slit_width / 2.0 ), mapdata, 0.)
slit_distance = np.where(slit_vert <= (slit_width / 2.0 ), slit_horiz,0.)
# PA is anti-clockwise, rotate is clockwise
rot_ang = self.galaxy.pa + 90 # to have horizontal slit image
slit_rot = ndimage.rotate(slit_data, rot_ang, mode='constant', order=0,reshape=reshape) # 0 because spline introduces errors!
dist_rot = ndimage.rotate(slit_distance, rot_ang, mode='constant', order=0,reshape=reshape)
a = np.sum(slit_rot != 0,axis=(0)) #nansum returns a bool with bottelneck; not with numpy!!
profile = np.where(a!=0, bn.nansum(slit_rot, axis=(0)) / a, 0.) #mean of the slit
b = np.sum(dist_rot != 0,axis=(0)) # pixels not zero and not nan
xaxis = np.where(b!=0, bn.nansum(dist_rot, axis=(0)) / b, 0.)
return xaxis, profile
def _make_moment_maps(self, cube=None, mask=False, cut_level=50, parameters=None, instrument=None, remove_LSF_from_disp=False):
"""
make moment maps from a noiseless cube
assumes no continuum to be removed
:param mask: boolean [default: None] to apply masking
:param cube: HyperspectralCube
:return:
"""
if (cube is None):
cube = self.convolved_cube
if (instrument is None):
vstep = self.instrument.z_step_kms
lsf_vector = self.instrument.lsf.as_vector(cube)
if remove_LSF_from_disp:
lsf_fwhm = self.instrument.lsf.fwhm / self.instrument.z_step #in pixel !
else:
vstep = instrument.z_step_kms
lsf_vector = instrument.lsf.as_vector(cube)
if remove_LSF_from_disp:
x=np.arange(np.size(lsf_vector))
sigma = np.sqrt(np.sum(x**2*lsf_vector)-np.sum(x*lsf_vector)**2) #in pixel !
lsf_fwhm = sigma * 2.35
zgrid, _, _ = np.indices(cube.shape)
if parameters is None:
zo = self.galaxy.z
else:
zo = parameters.z
if mask is True:
# default mask is flux>max(flux)/30.
mask = (cube.data > bn.nanmax(cube.data) / cut_level )
else:
mask = np.ones_like(cube.data)
#if the line is a doublet use a weighted average
if self.model.line is not None:
line = self.model.line
delta = 3e5 * (line['wave'][1]-line['wave'][0])/(line['wave'][0]+line['wave'][1])*2
r1 = line['ratio'][0]
r2 = line['ratio'][1]
p = r1 / (r1+r2) #weight of first/blue line
# weighted averaged zo
#zo_avg = ((zo-delta/vstep) * r1 + zo * r2) / (r1 + r2)
#print zo,zo-delta/vstep,zo_avg
#zo = zo_avg
#in km/s
# M1 = p * mu1 + (1-p) * mu2 ; m1 = mu2-Delta
# mu2 = (M1 - p mu1 ) / (1-p)
# mu2 = M1 + p delta
Fmap = bn.nansum(cube.data * mask,axis=0)
cube_norm = np.where(np.isfinite(cube.data / Fmap), cube.data / Fmap, 0)
M1map = bn.nansum(cube_norm * mask * (zgrid - zo), axis=0) * vstep
Vmap = (M1map + p * delta) # Vmap
# M2 = p * (mu1^2+sig^2) + (1-p) * (mu2^2+sig^2); m1 = mu2-Delta
# M2 = sig^2 + (p mu1^2 + (1-p) mu2^2)
# M2 - p (mu2-delta)**2 - (1-p) mu2^2) = sig^2
M2map = bn.nansum(cube.data / Fmap * mask * (zgrid -zo)**2, axis=0) * vstep**2
S2map = M2map - (Vmap-delta)**2 * p - (Vmap)**2 * (1-p)
Smap = np.sqrt(S2map)
else:
Fmap = bn.nansum(cube.data * mask,axis=0)
cube_norm = np.where(np.isfinite(cube.data / Fmap), cube.data / Fmap, 0)
Vmap = bn.nansum(cube_norm * mask * (zgrid-zo), axis=0) * vstep # mu
S2map = bn.nansum(cube_norm * mask *(zgrid-zo)**2, axis=0) * vstep**2 # mu^2 + sigma^2
Smap = np.sqrt(S2map-Vmap**2)
#Remove instrument LSF
if (remove_LSF_from_disp):
self.logger.info('removing LSF from dispersion map in quadrature; with FWHM %.5f pix' % (lsf_fwhm) )
Smap=np.sqrt(Smap**2 - (lsf_fwhm/2.35 * vstep)**2)
self.obs_flux_map = HyperCube(Fmap)
self.obs_velocity_map = HyperCube(Vmap)
self.obs_disp_map = HyperCube(Smap)
return self.obs_flux_map, self.obs_velocity_map, self.obs_disp_map
def _geweke_score(self, x, first=.1, last=.5, intervals=20):
"""Return z-scores for convergence diagnostics.
Compare the mean of the first % of series with the mean of the last % of
series. x is divided into a number of segments for which this difference is
computed. If the series is converged, this score should oscillate between
-1 and 1.
Parameters
----------
x : array-like
The trace of some stochastic parameter.
first : float
The fraction of series at the beginning of the trace.
last : float
The fraction of series at the end to be compared with the section
at the beginning.
intervals : int
The number of segments.
maxlag : int
Maximum autocorrelation lag for estimation of spectral variance
Returns
-------
scores : list [[]]
Return a list of [i, score], where i is the starting index for each
interval and score the Geweke score on the interval.
Notes
-----
The Geweke score on some series x is computed by:
.. math:: \frac{E[x_s] - E[x_e]}{\sqrt{V[x_s] + V[x_e]}}
where :math:`E` stands for the mean, :math:`V` the variance,
:math:`x_s` a section at the start of the series and
:math:`x_e` a section at the end of the series.
References
----------
Geweke (1992)
"""
if np.ndim(x) > 1:
return [self._geweke_score(y, first, last, intervals) for y in np.transpose(x)]
# Filter out invalid intervals
if first + last >= 1:
raise ValueError(
"Invalid intervals for Geweke convergence analysis",
(first, last))
# Initialize list of z-scores
zscores = [None] * intervals
# Starting points for calculations
starts = np.linspace(0, int(len(x)*(1.-last)), intervals).astype(int)
# Loop over start indices
for i,s in enumerate(starts):
# Size of remaining array
x_trunc = x[s:]
n = len(x_trunc)
# Calculate slices
first_slice = x_trunc[:int(first * n)]
last_slice = x_trunc[int(last * n):]
z = (first_slice.mean() - last_slice.mean())
#to avoid numerical errors
if np.var(last_slice)>1e-6*last_slice.mean()**2:
z /= np.sqrt(np.var(first_slice) +
np.var(last_slice))
else:
z = 0
zscores[i] = len(x) - n, z
return np.array(zscores)
#@fixme unsupported
def film_images(self, name, frames_skipped=0, fps=25):
"""
This is still a WiP.
TODO:
- print => logging
- heavily document how to solve ffmpeg issue (even if it's not our job)
Generates an video of the evolution of plot_images() through the chain,
skipping [frames_skipped] between each draw, at [fps] frames per second.
.. warning::
The generated file may take up a lot of disk space.
Known issue with ffmpeg :
https://stackoverflow.com/questions/17887117/python-matplotlib-basemap-animation-with-ffmpegwriter-stops-after-820-frames
Fix :
sudo apt-add-repository ppa:jon-severinsson/ffmpeg
sudo apt-get update
sudo apt-get install ffmpeg
"""
import matplotlib.animation as animation
# Initial plot
fig = self._plot_images(self.cube, self.galaxy, self.convolved_cube,
self.deconvolved_cube, self.residuals_cube)
# Convert the chain to a list of parameters
chain_list = np.array(self.chain.tolist())
params = chain_list[:, :-1] # remove the reduced chi (last element)
# Sanity check (otherwise, divisions by 0 will happen)
if len(params) < 2:
raise RuntimeError("Chain is not long enough for video.")
self.logger.info("Encoding video...")
def animate(i, skipped, count, me, chain):
"""
animation.FuncAnimation requires a function to generate the figure
on each frame. This is the bottleneck of our film generation.
i: int
0 0 1 2 3 ... up to [count]
skipped: int
Between each frame, skip [skipped] parameters in the chain
count: int
The total [count] of frames to show.
me: GalPaK3D
Our local self.
chain: ndarray
The list of parameters, as long as the MCMC chain.
"""
# Get the parameters from the chain
index = int(math.floor(i / (1 + skipped)))
galaxy = GalaxyParameters.from_ndarray(chain[index])
# Print the progression in %
sys.stdout.write("\r%d: %2.2f%%" % (i, 100. * i / (count - 1)))
sys.stdout.flush()
# Compute the cubes
deconvolved_cube = me.create_clean_cube(galaxy, me.cube.shape)
convolved_cube = me.instrument.convolve(deconvolved_cube.copy())
residuals_cube = (me.cube - convolved_cube) / me.error_cube
# Plot and return the plot
# (this is the expensive step, and it can be optimized further)
return me._plot_images(me.cube, galaxy, convolved_cube,
deconvolved_cube, residuals_cube)
if frames_skipped < 0:
frames_skipped *= -1
frames_count = int(math.floor(len(params) / (1. + frames_skipped)))
ani = animation.FuncAnimation(fig, animate, frames_count,
fargs=(int(frames_skipped), frames_count,
self, params),
repeat=False)
# Still unsure if that metadata is really written in the file
metadata = {
'title': 'GalPaK\'s cubes timelapse',
'author': 'galpak',
}
writer = animation.FFMpegWriter(fps=fps)
ani.save(name + '_images.avi', writer=writer, metadata=metadata)