# mcutils/xray/mask.py
#
# (source listing metadata: 140 lines, 4.0 KiB, Python)

from __future__ import print_function
import sys
if sys.version_info.major == 2: input=raw_input
import logging
log = logging.getLogger(__name__)
import os
import numpy as np
import matplotlib.pyplot as plt
class MyMask(object):
    """Boolean mask built from an ordered list of geometric components.

    Components (circles and rectangles) are recorded by the ``add*`` /
    ``subtract*`` methods and evaluated by :meth:`getMask` in the order they
    were added: an "add" component sets its pixels to True, a "subtract"
    component sets them to False, so later components override earlier ones.

    Coordinates follow the matplotlib ``imshow``/``ginput`` convention:
    ``x`` is the column index and ``y`` is the row index.
    """

    def __init__(self, img=None):
        # Each entry is [operation, kind, parameters], e.g.
        # ["add", "circle", [xcen, ycen, radius]].
        self.comp = []
        # Optional image whose shape is used when getMask() gets no shape.
        self.img = img
        # Last mask computed by getMask() (None until it is first called).
        self.mask = None

    def addCircle(self, xcen, ycen, radius):
        """Mask pixels closer than ``radius`` to (xcen, ycen)."""
        self.comp.append(["add", "circle", [xcen, ycen, radius]])

    def subtractCircle(self, xcen, ycen, radius):
        """Unmask pixels closer than ``radius`` to (xcen, ycen)."""
        self.comp.append(["subtract", "circle", [xcen, ycen, radius]])

    @staticmethod
    def _sortedCorners(x1, y1, x2, y2):
        """Return rectangle corners ordered as [xmin, ymin, xmax, ymax]."""
        return [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]

    def addRectangle(self, x1, y1, x2, y2):
        """Mask the rectangle with opposite corners (x1, y1) and (x2, y2)."""
        self.comp.append(["add", "rectangle", self._sortedCorners(x1, y1, x2, y2)])

    def subtractRectangle(self, x1, y1, x2, y2):
        """Unmask the rectangle with opposite corners (x1, y1) and (x2, y2)."""
        self.comp.append(["subtract", "rectangle", self._sortedCorners(x1, y1, x2, y2)])

    def getMask(self, shape=None):
        """Evaluate all components into a boolean array.

        Parameters
        ----------
        shape : tuple, optional
            (nrows, ncols) of the output. Defaults to ``self.img.shape``.

        Returns
        -------
        numpy.ndarray
            Boolean array of shape (nrows, ncols), True where masked. It is
            also stored in ``self.mask``.

        Raises
        ------
        ValueError
            If no shape is given and no image is attached, or if a component
            has an unknown kind.
        """
        if shape is None:
            if self.img is None:
                raise ValueError("shape is required when MyMask has no image")
            shape = self.img.shape
        nrows, ncols = shape[0], shape[1]
        # Y[r, c] == r (row / vertical) and X[r, c] == c (column / horizontal),
        # matching imshow coordinates. Both grids have the mask's own shape,
        # so non-square images index correctly.
        Y, X = np.indices((nrows, ncols))
        mask = np.zeros((nrows, ncols), dtype=bool)
        for whattodo, kind, pars in self.comp:
            if kind == "circle":
                xc, yc, r = pars
                region = np.hypot(X - xc, Y - yc) < r
            elif kind == "rectangle":
                x1, y1, x2, y2 = pars
                # Inclusive bounds so that a corner snapped onto the image
                # border (e.g. x1 == 0) still covers the border row/column.
                region = (X >= x1) & (X <= x2) & (Y >= y1) & (Y <= y2)
            else:
                raise ValueError("unknown mask component kind: %r" % (kind,))
            # "add" sets the pixels, anything else ("subtract") clears them.
            mask[region] = (whattodo == "add")
        self.mask = mask
        return mask

    def getMatplotlibMask(self, shape=None):
        """Return the mask as an RGBA overlay suitable for ``plt.imshow``.

        Masked pixels are gray with alpha 0.5; unmasked pixels are fully
        transparent. ``shape`` is forwarded to :meth:`getMask`.
        """
        mask = self.getMask(shape=shape)
        mpl_mask = np.zeros((mask.shape[0], mask.shape[1], 4))
        mpl_mask[:, :, :3] = 0.5  # gray color
        # Multiply rather than divide: bool/2 floors to 0 on Python 2.
        mpl_mask[:, :, 3] = mask * 0.5  # give some transparency
        return mpl_mask

    def save(self, fname, inverted=False):
        """Write the mask to an EDF file using fabio.

        The mask is re-evaluated from the current components first, so
        components added after the last getMask() call are included. If a
        mask was computed before, its shape is reused; otherwise the
        attached image's shape is used.

        ``inverted=True`` writes the logical negation. Pixels are stored as
        uint8 (0/1) because EDF does not support bool.
        """
        import fabio
        shape = None if self.mask is None else self.mask.shape
        mask = self.getMask(shape=shape)
        if inverted:
            mask = ~mask
        edf = fabio.edfimage.edfimage(mask.astype(np.uint8))  # edf does not support bool
        edf.save(fname)
def test():
    """Build and display a demo mask: a ring plus a rectangle.

    The mask has no attached image, so an explicit shape is passed to
    getMask(). The result is drawn with ``plt.imshow``; calling
    ``plt.show()`` is left to the caller.

    Returns the MyMask instance.
    """
    mask = MyMask()
    mask.addCircle(400, 300, 250)
    mask.subtractCircle(400, 300, 150)
    mask.addRectangle(350, 250, 1500, 700)
    # MyMask has no show() method: evaluate the mask and display it directly.
    # The (rows, cols) shape is large enough to contain every component above.
    plt.imshow(mask.getMask(shape=(1600, 1600)))
    return mask
def snap(point, shape, snapRange=20):
    """Snap a clicked (x, y) point onto the image border when it is close to it.

    ``point`` is an (x, y) pair in matplotlib/imshow data coordinates, so x
    runs along columns (bounded by ``shape[1]``) and y along rows (bounded by
    ``shape[0]``). A coordinate closer than ``snapRange`` to the low edge
    becomes 0. One closer than ``snapRange`` to the high edge becomes the
    image size along that axis. ``snapRange <= 0`` disables snapping.

    Returns a new list [x, y]; ``point`` itself is not modified.
    """
    snapped = list(point)
    if snapRange <= 0:
        return snapped
    # Index 0 is x (columns -> shape[1]); index 1 is y (rows -> shape[0]).
    limits = (shape[1], shape[0])
    for axis in (0, 1):
        if snapped[axis] < snapRange:
            snapped[axis] = 0
        if snapped[axis] > limits[axis] - snapRange:
            snapped[axis] = limits[axis]
    return snapped
def getPoint(shape, snapRange):
    """Wait for one mouse click on the current figure and return it snapped.

    Parameters
    ----------
    shape : tuple
        (nrows, ncols) of the displayed image, forwarded to ``snap``.
    snapRange : float
        Border snapping distance in pixels (<= 0 disables snapping).

    Returns
    -------
    list
        [x, y] in imshow data coordinates after border snapping.

    Raises
    ------
    RuntimeError
        If ``plt.ginput`` returned no point (e.g. it timed out).
    """
    clicks = plt.ginput(n=1)
    if not clicks:
        # ginput returns an empty list when no click is collected; indexing
        # it directly would surface as an unexplained IndexError.
        raise RuntimeError("No point selected (ginput returned no clicks)")
    return snap(clicks[0], shape, snapRange=snapRange)
def makeMaskGui(img, snapRange=60):
    """Interactively build a MyMask on top of ``img``.

    Shows ``img`` (color limits at its 2nd/98th percentiles) with the current
    mask overlaid, then loops on a console prompt:

    * ``c`` adds a circle (click the center, then a point on its edge);
    * ``r`` adds a rectangle (click two opposite corners);
    * ``done`` ends the session.

    ``snapRange`` controls border snapping, in pixels (use <= 0 to disable).

    Afterwards the user may enter a filename. A ``.edf`` name is saved via
    ``MyMask.save`` and a ``.npy`` name via ``numpy.save``. An empty name
    skips saving. Saving failures are logged rather than raised.

    Returns the MyMask instance.
    """
    mask = MyMask(img)
    ans = 'ok'
    while ans != 'done':
        plt.imshow(img)
        plt.clim(np.percentile(img, (2, 98)))
        plt.imshow(mask.getMatplotlibMask())
        plt.pause(0.01)
        ans = input("What's next c/r/done? ")
        if ans == "c":
            print("Adding circle, click on center")
            c = getPoint(img.shape, snapRange)
            print("Adding circle, click on another point to define radius")
            p = getPoint(img.shape, snapRange)
            r = np.hypot(p[0] - c[0], p[1] - c[1])
            mask.addCircle(c[0], c[1], r)
        elif ans == "r":
            print("Adding rectangle, click on one corner")
            c1 = getPoint(img.shape, snapRange)
            print("Adding rectangle, click on opposite corner")
            c2 = getPoint(img.shape, snapRange)
            mask.addRectangle(c1[0], c1[1], c2[0], c2[1])
    plt.imshow(mask.getMatplotlibMask())
    plt.pause(0.01)
    fname = input("Enter a valid filename (ext .edf or .npy) to save the mask")
    if fname != '':
        _saveMask(mask, fname)
    # Returned outside any try/finally so unexpected exceptions (including
    # KeyboardInterrupt) are no longer silently discarded.
    return mask


def _saveMask(mask, fname):
    """Save ``mask`` to ``fname`` based on its extension; log any failure."""
    ext = os.path.splitext(fname)[1].lower()
    try:
        if ext == '.edf':
            mask.save(fname)
        elif ext == '.npy':
            np.save(fname, mask.getMask())
        else:
            log.error("Unsupported mask file extension %r (use .edf or .npy)", ext)
    except Exception:
        # Saving is best-effort at the end of an interactive session: report
        # the problem with its traceback but still return the mask.
        log.exception("Error in saving mask")
if __name__ == "__main__":
    # Run the demo, display the figure, then wait for the user before exiting.
    demo_mask = test()
    plt.show()
    input("Enter to finish")