from __future__ import print_function
import sys
# Python 2 compatibility: make input() return the typed text as a string
# (Python 3 semantics) instead of evaluating it.
if sys.version_info.major == 2: input=raw_input
import logging
log = logging.getLogger(__name__)
import os
import numpy as np
import matplotlib.pyplot as plt


class MyMask(object):
    """Mask described by an ordered list of geometric components.

    Components are appended to ``self.comp`` as ``[operation, kind, params]``:
      - operation: ``"add"`` or ``"subtract"``
      - kind ``"circle"``:    params ``[xcen, ycen, radius]``
      - kind ``"rectangle"``: params ``[x1, y1, x2, y2]``, normalised so that
        x1 <= x2 and y1 <= y2
    """

    def __init__(self,img=None):
        """
        img : optional image; getMask() uses ``img.shape`` as the output shape
              when no explicit shape is passed.
        """
        self.comp = []  # ordered [operation, kind, params] entries
        self.img = img
        # NOTE(review): never assigned in the visible code; presumably caches
        # the last computed mask (set in the lost part of getMask?) — confirm.
        self.mask = None

    def addCircle(self,xcen,ycen,radius):
        """Record a circle (centre xcen, ycen) whose pixels are added to the mask."""
        self.comp.append( ["add","circle", [xcen,ycen,radius] ] )

    def subtractCircle(self,xcen,ycen,radius):
        """Record a circle (centre xcen, ycen) whose pixels are removed from the mask."""
        self.comp.append( ["subtract","circle", [xcen,ycen,radius] ] )

    def addRectangle(self,x1,y1,x2,y2):
        """Record a rectangle, given by two opposite corners, to add to the mask.

        The corners may be given in any order; they are swapped so the stored
        parameters satisfy x1 <= x2 and y1 <= y2.
        """
        if x1>x2: x1,x2=x2,x1
        if y1>y2: y1,y2=y2,y1
        self.comp.append( ["add","rectangle", [x1,y1,x2,y2] ])

    def subtractRectangle(self,x1,y1,x2,y2):
        """Record a rectangle, given by two opposite corners, to remove from the mask.

        Corners are normalised exactly as in addRectangle.
        """
        if x1>x2: x1,x2=x2,x1
        if y1>y2: y1,y2=y2,y1
        self.comp.append( ["subtract","rectangle", [x1,y1,x2,y2] ])

    def getMask(self,shape=None):
        """Rasterise the recorded components on a pixel grid.

        shape : grid shape; defaults to ``self.img.shape``.

        NOTE(review): the end of this method was lost (see below). Its return
        value is what makeMaskGui writes with ``np.save(fname, mask.getMask())``.
        """
        if shape is None: shape = self.img.shape
        m = []  # presumably one boolean array per component — confirm
        # NOTE(review): np.meshgrid defaults to 'xy' indexing, so X and Y have
        # shape (shape[1], shape[0]), with X running over range(shape[0]).
        # For a non-square shape this grid is transposed relative to `shape`.
        # Confirm that the lost remainder of this method accounts for that.
        X,Y = np.meshgrid ( range(shape[0]),range(shape[1]) )
        for o in self.comp:
            whattodo = o[0]  # "add" / "subtract"; not used in the surviving text
            kind=o[1]
            pars=o[2]
            if kind == "circle":
                (xc,yc,r) = pars
                # distance of every grid point from the circle centre
                d = np.sqrt((X-xc)**2+(Y-yc)**2)
                # NOTE(review): SOURCE IS CORRUPTED HERE. Text that lay between
                # '<' and '>' characters appears to have been stripped. The lost
                # code is:
                #   - the rest of getMask (rectangle branch, add/subtract
                #     combination, return);
                #   - getMatplotlibMask();
                #   - save();
                #   - test();
                #   - the `def` line and first statement(s) of snap().
                # The surviving fragment is kept verbatim on the next line.
                # Restore the lost code from version control; it cannot be
                # reconstructed reliably from this file.
                #plt.imshow(dx1) & (Xy1) & (Y shape[0]-snapRange:
# NOTE(review): what follows is the surviving tail of snap(), which getPoint
# calls as snap(c, shape, snapRange=snapRange). The `if` condition guarding the
# next statement survives only partially, in the fragment above.
# The visible lines move a coordinate that lies within snapRange of the far
# border onto that border.
        snapped[0] = shape[0]
    if snapped[1] < snapRange: snapped[1] = 0
    if snapped[1] > shape[1]-snapRange: snapped[1] = shape[1]
    return snapped


def getPoint(shape,snapRange):
    """Wait for one mouse click on the current figure and return it snapped.

    The first point returned by ``plt.ginput()`` is passed through
    ``snap(c, shape, snapRange=snapRange)``.

    NOTE(review): ginput returns an empty list if no click arrives before its
    timeout, so ``[0]`` would raise IndexError.
    NOTE(review): ginput returns (x, y), i.e. (column, row) for an imshow
    image. The surviving snap() tail compares index 1 against shape[1]. For
    non-square images the snapping bounds may therefore be on swapped axes.
    Confirm.
    """
    c = plt.ginput()[0]
    c = snap(c,shape,snapRange=snapRange)
    return c


def makeMaskGui(img,snapRange=60):
    """Interactively build a MyMask over ``img`` using mouse clicks.

    On each round, the image is drawn with colour limits at its 2nd/98th
    percentiles and the current mask overlaid. The user is then prompted on
    stdin:
      - ``c``    : add a circle (click the centre, then a point on its rim);
      - ``r``    : add a rectangle (click two opposite corners);
      - ``done`` : stop adding shapes.
    Any other answer simply redraws.

    Afterwards the user is asked for a file name:
      - a ``.edf`` extension saves via ``MyMask.save``;
      - a ``.npy`` extension saves ``getMask()`` via ``np.save``;
      - an empty name or any other extension saves nothing.
    Errors while saving are logged, not raised.

    snapRange controls border snapping (in pixels, use <= 0 to disable).

    Returns the MyMask instance.
    """
    mask = MyMask(img)
    ans='ok'
    while (ans != 'done'):
        # NOTE: each imshow call adds a new image to the current axes on top of
        # the previous ones; nothing is cleared between rounds.
        plt.imshow(img)
        plt.clim(np.percentile(img,(2,98)))
        plt.imshow(mask.getMatplotlibMask())
        plt.pause(0.01)  # let the GUI event loop draw before blocking on input()
        ans = input("What's next c/r/done? \n")
        if ans == "c":
            print("Adding circle, click on center")
            c = getPoint(img.shape,snapRange)
            print("Adding circle, click on another point to define radius")
            p = getPoint(img.shape,snapRange)
            # radius = distance between the centre and the second click
            r = np.sqrt( (p[0]-c[0])**2 + (p[1]-c[1])**2 )
            mask.addCircle(c[0],c[1],r)
        if ans == "r":
            print("Adding rectangle, click on one corner")
            c1 = getPoint(img.shape,snapRange)
            print("Adding rectangle, click on opposite corner")
            c2 = getPoint(img.shape,snapRange)
            mask.addRectangle(c1[0],c1[1],c2[0],c2[1])
        plt.imshow(mask.getMatplotlibMask())
        plt.pause(0.01)
    fname = input("Enter a valid filename (ext .edf or .npy) to save the mask")
    try:
        if fname != '':
            ext = os.path.splitext(fname)[1]
            if ext == '.edf':
                mask.save(fname)
            elif ext == '.npy':
                np.save(fname,mask.getMask())
    except Exception as e:
        log.error("Error in saving mask")
        log.error(e)
    finally:
        # NOTE(review): a `return` inside `finally` discards any exception still
        # propagating, including ones not caught above (e.g. KeyboardInterrupt).
        # Consider moving the return after the try statement.
        return mask


if __name__ == "__main__":
    # NOTE(review): test() lives in the corrupted/lost part of this file.
    test()
    plt.show()
    # wait for the user to press Enter before the script exits
    ans=input("Enter to finish")