improved plots for online analysis + new little helper functions

This commit is contained in:
Marco Cammarata 2017-02-09 17:12:48 +01:00
parent e78d5fc2ce
commit 02186d220f
1 changed files with 117 additions and 23 deletions

View File

@ -97,63 +97,88 @@ def removeBackground(x,data,xlims=None,max_iter=100,background_regions=[],**kw):
return x,np.squeeze(data) return x,np.squeeze(data)
def plotdata(q,data,x=None,plot=True,showTrend=True,title=None,clim='auto'): def plotdata(*args,x=None,plot=True,showTrend=True,title=None,clim='auto',fig=None):
if isinstance(args[0],storage.DataStorage):
q = args[0].q; data=args[0].data;
if title is None: title = args[0].folder
else:
q,data = args
if not (plot or showTrend): return if not (plot or showTrend): return
if x is None: x = np.arange(data.shape[0]) if x is None: x = np.arange(data.shape[0])
if clim == 'auto': clim = np.nanpercentile(data,(1.5,98.5)) if clim == 'auto': clim = np.nanpercentile(data,(1.5,98.5))
one_plot = showTrend or plot one_plot = showTrend or plot
two_plot = showTrend and plot two_plot = showTrend and plot
if one_plot and not two_plot: if one_plot and not two_plot:
if fig is None:
fig,ax = plt.subplots(1,1) fig,ax = plt.subplots(1,1)
else:
fig.clear()
ax = fig.axes
if two_plot: if two_plot:
fig,ax = plt.subplots(1,2,sharey=True) if fig is None:
fig,ax = plt.subplots(2,1,sharex=True)
else:
ax = fig.axes
ax = np.atleast_1d(ax) ax = np.atleast_1d(ax)
if showTrend: if showTrend:
plt.sca(ax[0]) plt.sca(ax[1])
plt.pcolormesh(x,q,data.T) plt.pcolormesh(q,x,data)
plt.xlabel("image number, 0 being older") plt.ylabel("image number, 0 being older")
plt.ylabel(r"q ($\AA^{-1}$)") plt.xlabel(r"q ($\AA^{-1}$)")
plt.clim( *clim ) plt.clim( *clim )
if plot: if plot:
if showTrend: ax[0].plot(q,np.nanmean(data,axis=0))
ax[1].plot(data.mean(axis=0),q)
else:
ax[0].plot(q,data.mean(axis=0))
if (plot or showTrend) and title is not None: if (plot or showTrend) and title is not None:
plt.title(title) plt.title(title)
def plotdiffs(q,diffs,t,select=None,err=None,absSignal=None,absSignalScale=10, def plotdiffs(*args,select=None,err=None,absSignal=None,absSignalScale=10,
showErr=False,cmap=plt.cm.jet): showErr=False,cmap=plt.cm.jet,fig=None,title=None):
# this selection trick done in this way allows to keep the same colors when # this selection trick done in this way allows to keep the same colors when
# subselecting (because I do not change the size of diffs) # subselecting (because I do not change the size of diffs)
if isinstance(args[0],storage.DataStorage):
q = args[0].q; t = args[0].scan; err = args[0].err
diffs = args[0].data
diffs_abs = args[0].dataAbsAvScanPoint
else:
q,diffs,t = args
diffs_abs = None
if select is not None: if select is not None:
indices = range(*select.indices(t.shape[0])) indices = range(*select.indices(t.shape[0]))
else: else:
indices = range(len(t)) indices = range(len(t))
lines = []
if fig is None: fig = plt.gcf()
# fig.clear()
lines_diff = []
lines_abs = []
if absSignal is not None: if absSignal is not None:
line = plt.plot(q,absSignal/absSignalScale,lw=3, line = plt.plot(q,absSignal/absSignalScale,lw=3,
color='k',label="absSignal/%s"%str(absSignalScale))[0] color='k',label="absSignal/%s"%str(absSignalScale))[0]
lines.append(line) lines.append(line)
for idiff in indices: for linenum,idiff in enumerate(indices):
color = cmap(idiff/(len(diffs)-1)) color = cmap(idiff/(len(diffs)-1))
label = timeToStr(t[idiff]) label = timeToStr(t[idiff])
kw = dict( color = color, label = label ) kw = dict( color = color, label = label )
if err is not None and showErr: if err is not None and showErr:
line = plt.errorbar(q,diffs[idiff],err[idiff],**kw)[0] line = plt.errorbar(q,diffs[idiff],err[idiff],**kw)[0]
lines_diff.append(line)
else: else:
line = plt.plot(q,diffs[idiff],**kw)[0] line = plt.plot(q,diffs[idiff],**kw)[0]
lines.append(line) lines_diff.append(line)
if diffs_abs is not None:
fig = plt.gcf() line = plt.plot(q,diffs_abs[idiff],color=color)[0]
legend = plt.legend() lines_abs.append(line)
if title is not None: fig.axes[0].set_title(title)
legend = plt.legend(loc=4)
plt.grid() plt.grid()
plt.xlabel(r"q ($\AA^{-1}$)") plt.xlabel(r"q ($\AA^{-1}$)")
# we will set up a dict mapping legend line to orig line, and enable # we will set up a dict mapping legend line to orig line, and enable
# picking on the legend line # picking on the legend line
lined = dict() lined = dict()
for legline, origline in zip(legend.get_lines(), lines): for legline, origline in zip(legend.get_lines(), lines_diff):
legline.set_picker(5) # 5 pts tolerance legline.set_picker(5) # 5 pts tolerance
lined[legline] = origline lined[legline] = origline
@ -174,7 +199,14 @@ def plotdiffs(q,diffs,t,select=None,err=None,absSignal=None,absSignalScale=10,
fig.canvas.draw() fig.canvas.draw()
fig.canvas.mpl_connect('pick_event', onpick) fig.canvas.mpl_connect('pick_event', onpick)
return lines_diff,lines_abs
def updateLines(lines,data):
for l,d in zip(lines,data):
l.set_ydata(d)
#def getScan
def saveTxt(fname,q,data,headerv=None,info=None,overwrite=True,columns=''): def saveTxt(fname,q,data,headerv=None,info=None,overwrite=True,columns=''):
""" Write data to file 'fname' in text format. """ Write data to file 'fname' in text format.
@ -212,6 +244,7 @@ def reshapeToBroadcast(what,ref):
multidimentional array 'ref'. The two arrays have to same the same multidimentional array 'ref'. The two arrays have to same the same
dimensions along the first axis dimensions along the first axis
""" """
if what.shape == ref.shape: return what
assert what.shape[0] == ref.shape[0] assert what.shape[0] == ref.shape[0]
shape = [ref.shape[0],] + [1,]*(ref.ndim-1) shape = [ref.shape[0],] + [1,]*(ref.ndim-1)
return what.reshape(shape) return what.reshape(shape)
@ -232,11 +265,72 @@ def degToQ(theta,**kw):
return radToQ(theta,**kw) return radToQ(theta,**kw)
degToQ.__doc__ = radToQ.__doc__ degToQ.__doc__ = radToQ.__doc__
def qToTheta(q,**kw): def qToTheta(q,asDeg=False,**kw):
""" Return scattering angle from q (given E or wavelength) """ """ Return scattering angle from q (given E or wavelength) """
# Energy or wavelength should be in kw # Energy or wavelength should be in kw
assert "E" in kw or "wavelength" in kw assert "E" in kw or "wavelength" in kw
# but not both # but not both
assert not ("E" in kw and "wavelength" in kw) assert not ("E" in kw and "wavelength" in kw)
if "E" in kw: kw["wavelength"] = 12.398/kw["E"] if "E" in kw: kw["wavelength"] = 12.398/kw["E"]
return np.arcsin(q*kw["wavelength"]/4/np.pi) theta = np.arcsin(q*kw["wavelength"]/4/np.pi)
if asDeg: theta = np.rad2deg(theta)
return theta
def attenuation_length(compound, density=None, natural_density=None,energy=None, wavelength=None):
""" extend periodictable.xsf capabilities """
import periodictable.xsf
if energy is not None: wavelength = periodictable.xsf.xray_wavelength(energy)
assert wavelength is not None, "scattering calculation needs energy or wavelength"
if (np.isscalar(wavelength)): wavelength=np.array( [wavelength] )
n = periodictable.xsf.index_of_refraction(compound=compound,
density=density, natural_density=natural_density,
wavelength=wavelength)
attenuation_length = (wavelength*1e-10)/ (4*np.pi*np.imag(n))
return np.abs(attenuation_length)
def transmission(material='Si',thickness=100e-6, density=None, natural_density=None,energy=None, wavelength=None):
""" extend periodictable.xsf capabilities """
att_len = attenuation_length(compound,density=density,
natural_density=natural_density,energy=energy,wavelength=wavelength)
return np.exp(-thickness/att_len)
def chargeToPhoton(chargeOrCurrent,material="Si",thickness=100e-6,energy=10,e_hole_pair=3.6):
"""
Function to convert charge (or current to number of photons (or number
of photons per second)
Parameters
----------
chargeOrCurrent: float or array
material : str
Used to calculate
"""
# calculate absortption
A = 1-transmission(material=material,energy=energy)
chargeOrCurrent = chargeOrCurrent/A
e_hole_pair_energy = 3.6e-3
n_charge_per_photon = energy/e_hole_pair_energy
# convert to Q
charge_per_photon = n_charge_per_photon*1.60217662e-19
nphoton = chargeOrCurrent/charge_per_photon
if len(nphoton) == 1: nphoton = float(nphoton)
return nphoton
def logToScreen():
""" It allows printing to terminal on top of logfile """
# define a Handler which writes INFO messages or higher to the sys.stderr
console = logging.StreamHandler()
console.setLevel(logging.INFO)
# set a format which is simpler for console use
formatter = logging.Formatter('%(message)s')
# tell the handler to use this format
console.setFormatter(formatter)
# add the handler to the root logger (if needed)
if len(logging.getLogger('').handlers)==1:
logging.getLogger('').addHandler(console)