dispersiveXanes/online/sMhelper_upd.py

297 lines
9.3 KiB
Python

import sys
import zmq
import numpy as np
import time
import pickle
import alignment
import matplotlib.pyplot as plt
import threading
import datetime
import copy
plt.rcParams['image.cmap'] = 'viridis'
def histVec(v,oversample=1):
    """Return bin edges that bracket the unique, sorted values of ``v``.

    Interior edges lie midway between neighbouring values; the first and
    last edge are placed half a neighbour-spacing outside the extreme
    values.  With ``oversample > 1`` every bin is divided into
    ``oversample`` equally wide sub-bins.

    At least two distinct values are required (a single value raises
    IndexError because no spacing can be derived).
    """
    centers = np.unique(np.atleast_1d(v))
    gaps = np.diff(centers)
    # The first value has no left neighbour: reuse the first spacing for it.
    widths = np.hstack([gaps[0],gaps])
    edges = np.hstack([centers-widths/2.,centers[-1]+widths[-1]/2.])
    if oversample>1:
        # Subdivide each [lo, hi) bin; the closing edge is appended once at the end.
        pieces = [np.linspace(lo,hi,oversample+1)[:-1]
                  for lo,hi in zip(edges[:-1],edges[1:])]
        pieces.append(edges[-1])
        edges = np.array(np.hstack(pieces))
    return edges
def subtractBkg(imgs,nPix=100,dKtype='corners'):
    """Subtract a per-quadrant background from one image or a stack of images.

    Opals tend to have different background for every quadrant.

    Parameters
    ----------
    imgs : ndarray
        A single 2D image or a 3D stack (image index first).  The input is
        never modified; a float copy is returned.
    nPix : int
        Side length of the square patch, taken at each outer corner, whose
        mean is used as the background level of that quadrant.
    dKtype : str
        'corners' subtracts the corner mean from each 512x512 quadrant.
        'stripes' only converts to float (no subtraction is implemented).
        Any other value returns the input unchanged.

    Returns
    -------
    ndarray
        The processed data with singleton dimensions squeezed out.
    """
    # Compare by value: identity ("is") on string literals is an
    # implementation detail and fails for equal strings built at runtime.
    if dKtype == 'corners':
        if imgs.ndim == 2: imgs = imgs[np.newaxis,:]
        # Builtin float instead of the removed np.float alias.
        imgs = imgs.astype(float)
        head = slice(None,512); tail = slice(-512,None)
        patchHead = slice(None,nPix); patchTail = slice(-nPix,None)
        # (patch rows, patch cols, quadrant rows, quadrant cols), in the
        # original processing order.
        quadrants = (
            (patchHead,patchHead,head,head),
            (patchHead,patchTail,head,tail),
            (patchTail,patchTail,tail,tail),
            (patchTail,patchHead,tail,head),
        )
        for pRows,pCols,qRows,qCols in quadrants:
            level = imgs[:,pRows,pCols].mean(-1).mean(-1)
            imgs[:,qRows,qCols] -= level[:,np.newaxis,np.newaxis]
    elif dKtype == 'stripes':
        # NOTE(review): stripe subtraction was never implemented; the data
        # is only promoted to a float stack.
        if imgs.ndim == 2: imgs = imgs[np.newaxis,:]
        imgs = imgs.astype(float)
    return np.squeeze(imgs)
def getData():
    """Poll the DAQ monitor forever and print the keys of every reply.

    Opens a REQ socket to ``tcp://daq-xpp-mon06:12322``, then repeatedly
    sends ``b"Request"``, unpickles the reply and prints its keys together
    with the time elapsed since the previous reply.  The loop has no exit
    condition; it ends only through an exception (e.g. KeyboardInterrupt),
    at which point the socket and context are released.
    """
    t0 = time.time()
    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    try:
        socket.connect('tcp://daq-xpp-mon06:12322')
        while True:
            socket.send(b"Request")
            ret = socket.recv()
            # NOTE(security): unpickling executes arbitrary code; only
            # acceptable because the peer is a trusted DAQ node.
            ret = pickle.loads(ret, encoding='latin1')
            print('received',ret.keys(),time.time()-t0)
            t0 = time.time()
    finally:
        # linger=0 drops an unanswered request so term() cannot block.
        socket.close(linger=0)
        context.term()
# Script entry point: start the polling loop defined above.  getData()
# contains a loop without an exit condition, so when the file is run as a
# script, execution does not normally proceed past this call.
if __name__ == "__main__":
    getData()
class SpecHandler(object):
    """Fetch images of two spectrometers over ZMQ, optionally align them, and plot.

    Attributes
    ----------
    dataCollector : list of dict
        One entry per aligned shot (keys: time, fom, ratProj, im1Proj,
        im2Proj).  The same list object is shared with ``projsimple``.
    lastDat : dict or None
        The most recent reply returned by ``getData`` (None before the
        first request).
    transformer : dict
        Set by ``getRaw(doAlign=True)`` and ``alignFeatures``; holds
        ``transform``, ``roi1`` and ``roi2``.
    """

    def __init__(self,connectString='tcp://daq-xpp-mon06:12322',spec1name='opal0',spec2name='opal1',roi=None):
        """Store the settings, create the profile plot and open the socket.

        ``roi`` defaults to ``[0,1024,0,1024]``; a fresh list is created
        per instance so instances never share (and cannot corrupt) a
        common default list.
        """
        self.connectString = connectString
        self.spec1name = spec1name
        self.spec2name = spec2name
        self.dataCollector = []
        self.roi = [0,1024,0,1024] if roi is None else roi
        self.projsimple = ProjSimple(spec1name=spec1name,spec2name=spec2name,dataCollector=self.dataCollector)
        self._rawContinuousTime = None
        self.lastDat = None
        self.openSocket()

    def openSocket(self):
        """Create a new ZMQ context and a REQ socket connected to ``connectString``."""
        context = zmq.Context()
        self.context = context
        self.socket = context.socket(zmq.REQ)
        self.socket.connect(self.connectString)

    def closeSocket(self):
        """Close the socket, terminate the context and drop both attributes."""
        # Close explicitly instead of relying on garbage collection.
        # linger=0 discards an unanswered request so term() cannot block.
        self.socket.close(linger=0)
        self.context.term()
        del self.socket
        del self.context

    def getData(self):
        """Request one reply, background-subtract both images and cache it.

        Returns the unpickled reply dict, in which the entries named
        ``spec1name`` and ``spec2name`` have been replaced by the squeezed
        output of ``alignment.subtractBkg``.  The dict is also stored in
        ``lastDat``.
        """
        self.socket.send(b"Request")
        ret = self.socket.recv()
        # NOTE(security): unpickling executes arbitrary code; only
        # acceptable because the peer is a trusted DAQ node.
        ret = pickle.loads(ret, encoding='latin1')
        for sn in [self.spec1name,self.spec2name]:
            ret[sn] = np.squeeze(alignment.subtractBkg(ret[sn], nPix=100, bkg_type='line'))
        self.lastDat = ret
        return ret

    def getRaw(self,repoenConnection=False,doAlign=False,show=False,doFit=False,updateImg=True,updateProj=True,updateFom=True,flipit=False):
        """Fetch one pair of images, optionally align them, and update the profile plot.

        Parameters
        ----------
        repoenConnection : bool
            Close and reopen the socket before requesting data.
        doAlign : bool
            Load the alignment stored in 'last_trafo.npy', apply it through
            ``alignment.doShot``, save the refined transform back to the
            same file and append a summary of the shot to ``dataCollector``.
        show, doFit
            Passed on to ``alignment.doShot``; ``doFit=True`` is translated
            to ``'iminuit'``.
        flipit : bool
            Reverse the first axis of the first image after applying roi1.
        updateImg, updateProj, updateFom
            Currently unused (the survey and figure-of-merit plots are
            disabled); kept so existing callers keep working.  The profile
            plot is always updated.
        """
        if doFit is True: doFit='iminuit'
        if repoenConnection:
            self.closeSocket()
            self.openSocket()
        dat = self.getData()
        im1 = dat[self.spec1name]
        im2 = dat[self.spec2name]
        if doAlign:
            algn = alignment.loadAlignment('last_trafo.npy')
            t = algn['transform']
            roi1 = algn['roi1']
            roi2 = algn['roi2']
            im1 = im1[roi1]
            if flipit:
                im1 = im1[::-1]
            im2 = im2[roi2]
            r = alignment.doShot( im1,im2, t, show=show, doFit=doFit)
            self.transformer = dict(transform=r.final_transform,roi1=roi1,roi2=roi2)
            alignment.saveAlignment('last_trafo.npy',r.final_transform,roi1,roi2)
            im1 = r.im1
            im2 = r.im2
            self.dataCollector.append(
                dict( time=datetime.datetime.now(),
                      fom=r.fom,
                      ratProj = np.nansum(im2/im1,axis=0),
                      im1Proj = np.nansum(im1,axis=0),
                      im2Proj = np.nansum(im2,axis=0)))
        self.projsimple.plotProfiles(im1,im2)

    def alignFeatures(self):
        """Start the GUI alignment on copies of the last fetched images.

        The result is stored in ``transformer``.  Raises RuntimeError if
        no data has been fetched yet.
        """
        if self.lastDat is None:
            raise RuntimeError("no data available yet; call getData() or getRaw() first")
        im1 = copy.copy(self.lastDat[self.spec1name])
        im2 = copy.copy(self.lastDat[self.spec2name])
        roi1 = alignment.findRoi(im1)
        roi2 = alignment.findRoi(im2)
        tra = alignment.GuiAlignment(im1[roi1,:],im2[roi2,:],autostart=False)
        self.transformer = dict(transform=tra.start(),roi1=roi1,roi2=roi2)

    def getRawContinuuous(self,sleepTime,**kwargs):
        """Call ``getRaw(**kwargs)`` repeatedly in a background thread.

        ``sleepTime`` is the pause in seconds between calls; passing None
        stops the loop.  While a loop is running, a further call only
        changes the pause (its ``kwargs`` are ignored).  After the loop has
        stopped, a call with a non-None ``sleepTime`` starts a new thread.
        """
        self._rawContinuousTime = sleepTime
        worker = getattr(self,'_rawContinuousThread',None)
        if worker is not None and worker.is_alive():
            return
        if sleepTime is None:
            # Nothing is running and nothing was requested.
            return
        def upd():
            while self._rawContinuousTime is not None:
                self.getRaw(**kwargs)
                plt.draw()
                # Re-read after the (slow) fetch: the loop may have been
                # stopped meanwhile and time.sleep(None) would raise.
                pause = self._rawContinuousTime
                if pause is None:
                    break
                time.sleep(pause)
        self._rawContinuousThread = threading.Thread(target=upd)
        self._rawContinuousThread.start()
class Surveyplot(object):
    """Four stacked image panels: both spectrometers, their difference and their ratio."""

    def __init__(self,spec1name='spec1',spec2name='spec2'):
        """Create the 4x1 figure (shared x and y axes) and title the panels."""
        self.fig,self.axs = plt.subplots(4,1,sharex=True,sharey=True)
        titles = (spec1name,spec2name,"Difference","Ratio")
        for ax,title in zip(self.axs,titles):
            ax.set_title(title)

    def _setImage(self,attr,axIndex,img):
        """Update the image handle stored as ``attr``, or create it on ``axs[axIndex]``; return it."""
        if hasattr(self,attr):
            handle = getattr(self,attr)
            handle.set_data(img)
        else:
            handle = self.axs[axIndex].imshow(img,origin='lower',interpolation='none')
            setattr(self,attr,handle)
        return handle

    def plotImages(self,img1,img2,showDiff=False,showRatio=False):
        """Show ``img1`` and ``img2`` and, on request, ``img2-img1`` and ``img2/img1``.

        The colour limits of the difference and ratio panels are set to
        the 30th and 70th percentile of the displayed data on every call.
        NaN pixels (e.g. from 0/0 in the ratio) are ignored when computing
        these limits; a plain percentile would yield NaN limits.
        """
        self._setImage('i1',0,img1)
        self._setImage('i2',1,img2)
        if showDiff:
            tdiff = img2-img1
            diffHandle = self._setImage('idiff',2,tdiff)
            diffHandle.set_clim(np.nanpercentile(tdiff,[30,70]))
        if showRatio:
            tratio = img2/img1
            ratioHandle = self._setImage('iratio',3,tratio)
            ratioHandle.set_clim(np.nanpercentile(tratio,[30,70]))
class ProjSimple(object):
    """Two-panel line plot: column profiles of both images (top) and their ratio (bottom)."""

    def __init__(self,spec1name='spec1',spec2name='spec2',roi1=None,roi2=None,dataCollector=None):
        """Create the figure and store the settings.

        ``roi1``/``roi2`` are ``[row0,row1,col0,col1]`` and default to
        ``[0,1024,0,1024]``.  ``dataCollector`` is a list of per-shot dicts
        with 'im1Proj' and 'im2Proj' entries; a list passed in is used
        as-is (not copied) so that entries appended by the owner later are
        seen here.  All defaults are created per instance rather than
        shared between instances.
        """
        self.fig,self.axs = plt.subplots(2,1,sharex=True)
        self.roi1 = [0,1024,0,1024] if roi1 is None else roi1
        self.roi2 = [0,1024,0,1024] if roi2 is None else roi2
        self.spec1name=spec1name
        self.spec2name=spec2name
        # "is None" (not truthiness): an empty list handed in must be kept.
        self.dataCollector = [] if dataCollector is None else dataCollector

    def _roiit(self,img,roi):
        """Return the ``roi = [row0,row1,col0,col1]`` cut-out of ``img``."""
        return img[roi[0]:roi[1],roi[2]:roi[3]]

    def plotProfiles(self,img1,img2):
        """Plot or update the NaN-ignoring column sums of both ROIs and their ratio.

        If ``dataCollector`` is not empty, the ratio of the medians of the
        last (up to) 10 collected projections is drawn in red as well.
        """
        prof1 = np.nansum(self._roiit(img1,self.roi1),0)
        prof2 = np.nansum(self._roiit(img2,self.roi2),0)
        if hasattr(self,'l1'):
            self.l1.set_ydata(prof1)
        else:
            self.l1 = self.axs[0].plot(prof1,label=self.spec1name)[0]
        if hasattr(self,'l2'):
            self.l2.set_ydata(prof2)
        else:
            self.l2 = self.axs[0].plot(prof2,label=self.spec2name)[0]
            # Address the axes that holds the labelled lines; plt.legend()
            # would act on whichever axes happens to be current.
            self.axs[0].legend()
        ratio = prof2/prof1
        if hasattr(self,'lrat'):
            self.lrat.set_ydata(ratio)
        else:
            self.lrat = self.axs[1].plot(ratio,'k',label='ratio')[0]
            self.axs[1].set_ylim(0,2)
        if len(self.dataCollector) > 0 :
            im1Proj = np.asarray([i['im1Proj'] for i in self.dataCollector])
            im2Proj = np.asarray([i['im2Proj'] for i in self.dataCollector])
            ratAv = np.median(im2Proj[-10:,:],0)/np.median(im1Proj[-10:,:],0)
            if hasattr(self,'lratAv'):
                self.lratAv.set_ydata(ratAv)
            else:
                self.lratAv = self.axs[1].plot(ratAv,'r',label='ratio Avg')[0]
                self.axs[1].set_ylim(0,2)
class RunningPlot(object):
    """Line plot of the figure of merit ('fom') of the collected shots versus time."""

    def __init__(self,dataCollector):
        """Keep a reference to ``dataCollector`` (list of dicts with 'time' and 'fom') and create the figure."""
        self.dataCollector = dataCollector
        self.fig,self.axs = plt.subplots(1,1)

    def updatePlot(self):
        """Plot 'fom' against 'time' for all collected entries; do nothing if there are none.

        On the first call the line is created and x autoscaling enabled;
        afterwards the line data is replaced and the x range is set to the
        covered time span plus a 10 second margin on both sides.
        """
        if len(self.dataCollector) == 0:
            return
        # Only 'time' and 'fom' are plotted, so the (much larger)
        # projection arrays are no longer rebuilt on every refresh.
        times = np.asarray([entry['time'] for entry in self.dataCollector])
        foms = np.asarray([entry['fom'] for entry in self.dataCollector])
        if hasattr(self,'fomline'):
            self.fomline.set_ydata(foms)
            self.fomline.set_xdata(times)
            margin = datetime.timedelta(0,10)
            self.axs.set_xlim(np.min(times)-margin,np.max(times)+margin)
        else:
            self.fomline = self.axs.plot(times,foms,'o-')[0]
            self.axs.autoscale(enable=True,axis='x')
#if hasattr(self,'ratioimg'):
#self.ratioimg.set_data()
#self.ratioimg.set_ydata(foms)
#else:
#self.axs[0].plot(times,foms,'o-')
#def plotOrUpdate(img1,img2):
#if hasattr(self,i1):
#self.i1.set_data(img1)
#else:
#self.i1 = self.axs.imshow(img1,interpolate='none',origin='bottom')
#if hasattr(self,i1):
#self.i1.set_data(img1)
#else:
#self.i1 = self.axs.imshow(img1,interpolate='none',origin='bottom')