Add title to pyplot window

When using pyplot to plot the results, the window
name is based on the dataset title and the view
title.
This commit is contained in:
Sylvain Tricot 2025-06-18 11:46:41 +02:00
parent b15a5424d2
commit c2e1384a5c
1 changed files with 68 additions and 109 deletions

View File

@ -17,7 +17,7 @@
# along with this msspec. If not, see <http://www.gnu.org/licenses/>. # along with this msspec. If not, see <http://www.gnu.org/licenses/>.
# #
# Source file : src/msspec/iodata.py # Source file : src/msspec/iodata.py
# Last modified: Mon, 16 Jun 2025 14:42:03 +0200 # Last modified: Wed, 18 Jun 2025 11:46:41 +0200
# Committed by : Sylvain Tricot <sylvain.tricot@univ-rennes.fr> # Committed by : Sylvain Tricot <sylvain.tricot@univ-rennes.fr>
@ -84,6 +84,7 @@ from lxml import etree
from matplotlib.backends.backend_agg import FigureCanvasAgg from matplotlib.backends.backend_agg import FigureCanvasAgg
#from matplotlib.backends.backend_cairo import FigureCanvasCairo as FigureCanvasAgg #from matplotlib.backends.backend_cairo import FigureCanvasCairo as FigureCanvasAgg
from matplotlib.figure import Figure from matplotlib.figure import Figure
from matplotlib import pyplot as plt
from terminaltables import AsciiTable from terminaltables import AsciiTable
import msspec import msspec
@ -336,6 +337,15 @@ class DataSet(object):
except: except:
pass pass
def get_views(self):
"""Returns all the defined views in the dataset.
:return: A list of view
:rtype: List of :py:class:`iodata._DataSetView`
"""
return self._views
@property
def views(self): def views(self):
"""Returns all the defined views in the dataset. """Returns all the defined views in the dataset.
@ -365,6 +375,12 @@ class DataSet(object):
mydset.add_parameter(name='Spectrometer', group='misc', value='Omicron', unit='') mydset.add_parameter(name='Spectrometer', group='misc', value='Omicron', unit='')
""" """
group = kwargs.get('group')
name = kwargs.get('name')
r = self.get_parameter(group=group, name=name)
if r:
r.update(**kwargs)
else:
self._parameters.append(kwargs) self._parameters.append(kwargs)
def parameters(self): def parameters(self):
@ -398,12 +414,19 @@ class DataSet(object):
p.append(_) p.append(_)
return p[0] if len(p) == 1 else p return p[0] if len(p) == 1 else p
def set_cluster(self, cluster):
clusbuf = StringIO()
cluster.info['absorber'] = cluster.absorber
write_xyz(clusbuf, cluster)
self.add_parameter(group='Cluster', name='cluster', value=clusbuf.getvalue(), hidden="True")
def get_cluster(self): def get_cluster(self):
"""Get all the atoms in the cluster. """Get all the atoms in the cluster.
:return: The cluster :return: The cluster
:rtype: :py:class:`ase.Atoms` :rtype: :py:class:`ase.Atoms`
""" """
try:
p = self.get_parameter(group='Cluster', name='cluster')['value'] p = self.get_parameter(group='Cluster', name='cluster')['value']
s = StringIO() s = StringIO()
s.write(self.get_parameter(group='Cluster', name='cluster')['value']) s.write(self.get_parameter(group='Cluster', name='cluster')['value'])
@ -411,17 +434,19 @@ class DataSet(object):
#return ase.io.read(s, format='xyz') #return ase.io.read(s, format='xyz')
cluster = list(read_xyz(s))[-1] cluster = list(read_xyz(s))[-1]
return cluster return cluster
except:
return None
def select(self, *args, **kwargs): def select(self, *args, **kwargs):
condition = kwargs.get('where', 'True') condition = kwargs.get('where', 'True')
indices = [] indices = []
def export_views(self, folder): def export_views(self, folder, dpi=100):
for view in self.views(): for view in self.get_views():
f = view.get_figure() f = view.get_figure()
fname = os.path.join(folder, view.title) + '.png' fname = os.path.join(folder, view.title) + '.png'
f.savefig(fname) f.savefig(fname, dpi=dpi)
def export(self, filename="", mode="w"): def export(self, filename="", mode="w"):
@ -671,6 +696,7 @@ class Data(object):
return return
else: else:
data_grp = fd.create_group('DATA') data_grp = fd.create_group('DATA')
data_grp.attrs['dset_names'] = titles
meta_grp = fd.create_group('MsSpec viewer metainfo') meta_grp = fd.create_group('MsSpec viewer metainfo')
data_grp.attrs['title'] = self.title data_grp.attrs['title'] = self.title
@ -681,6 +707,7 @@ class Data(object):
continue continue
grp = data_grp.create_group(dset.title) grp = data_grp.create_group(dset.title)
grp.attrs['notes'] = dset.notes grp.attrs['notes'] = dset.notes
grp.attrs['col_names'] = dset.columns()
for col_name in dset.columns(): for col_name in dset.columns():
data = dset[col_name] data = dset[col_name]
grp.create_dataset(col_name, data=data) grp.create_dataset(col_name, data=data)
@ -691,7 +718,7 @@ class Data(object):
# xmlize views # xmlize views
for dset in self._datasets: for dset in self._datasets:
views_node = etree.SubElement(root, 'views', dataset=dset.title) views_node = etree.SubElement(root, 'views', dataset=dset.title)
for view in dset.views(): for view in dset.get_views():
view_el = etree.fromstring(view.to_xml()) view_el = etree.fromstring(view.to_xml())
views_node.append(view_el) views_node.append(view_el)
@ -712,7 +739,7 @@ class Data(object):
self._dirty = False self._dirty = False
LOGGER.info('Data saved in {}'.format(os.path.abspath(filename))) LOGGER.info('Data saved in {}'.format(os.path.abspath(filename)))
def export(self, folder, overwrite=False): def export(self, folder, overwrite=False, dpi=150):
os.makedirs(folder, exist_ok=overwrite) os.makedirs(folder, exist_ok=overwrite)
for dset in self._datasets: for dset in self._datasets:
dset_name = dset.title.replace(' ', '_') dset_name = dset.title.replace(' ', '_')
@ -720,7 +747,7 @@ class Data(object):
os.makedirs(p, exist_ok=overwrite) os.makedirs(p, exist_ok=overwrite)
fname = os.path.join(p, dset_name) + '.txt' fname = os.path.join(p, dset_name) + '.txt'
dset.export(fname) dset.export(fname)
dset.export_views(p) dset.export_views(p, dpi=dpi)
@staticmethod @staticmethod
def load(filename): def load(filename):
@ -737,12 +764,20 @@ class Data(object):
views = {} views = {}
output.title = fd['DATA'].attrs['title'] output.title = fd['DATA'].attrs['title']
for dset_name in fd['DATA'] : try:
dset_names = fd['DATA'].attrs['dset_names']
except:
dset_names = [_ for _ in fd['DATA']]
for dset_name in dset_names:
parameters[dset_name] = [] parameters[dset_name] = []
views[dset_name] = [] views[dset_name] = []
dset = output.add_dset(dset_name) dset = output.add_dset(dset_name)
dset.notes = fd['DATA'][dset_name].attrs['notes'] dset.notes = fd['DATA'][dset_name].attrs['notes']
for h5dset in fd['DATA'][dset_name]: try:
col_names = fd['DATA'][dset_name].attrs['col_names']
except:
col_names = [_ for _ in fd['DATA'][dset_name]]
for h5dset in col_names:
dset.add_columns(**{h5dset: fd['DATA'][dset_name][h5dset][...]}) dset.add_columns(**{h5dset: fd['DATA'][dset_name][h5dset][...]})
try: try:
@ -870,10 +905,19 @@ class _DataSetView(object):
data.append(values) data.append(values)
return data return data
def get_figure(self): def plot(self):
f = self.get_figure(backend='plt')
return f, f.get_axes()[0]
def get_figure(self, backend=None):
opts = self._plotopts opts = self._plotopts
figure = Figure(figsize=(3,2)) if backend is None:
figure = Figure()
else:
figure = plt.figure(num="[{}][{}]".format(self.dataset.title, self.title))
axes = None axes = None
proj = opts['projection'] proj = opts['projection']
scale = opts['scale'] scale = opts['scale']
@ -1113,7 +1157,7 @@ if has_gui:
self.notebooks[dset.title] = nb self.notebooks[dset.title] = nb
#self.GetSizer().Add(nb, 1, wx.ALL|wx.EXPAND) #self.GetSizer().Add(nb, 1, wx.ALL|wx.EXPAND)
self.GetSizer().Add(nb, proportion=1, flag=wx.ALL|wx.EXPAND) self.GetSizer().Add(nb, proportion=1, flag=wx.ALL|wx.EXPAND)
for view in dset.views(): for view in dset.get_views():
self.create_page(nb, view) self.create_page(nb, view)
self.create_menu() self.create_menu()
@ -1291,6 +1335,10 @@ if has_gui:
self.Layout() self.Layout()
self.update_statusbar() self.update_statusbar()
self._current_dset = name self._current_dset = name
has_cluster = True if self.data[self._current_dset].get_cluster() is not None else False
menu_item = self.GetMenuBar().FindItemById(302)
menu_item.Enable(has_cluster)
def create_page(self, nb, view): def create_page(self, nb, view):
# Get the matplotlib figure # Get the matplotlib figure
@ -1324,95 +1372,6 @@ if has_gui:
nb.AddPage(p, view.title) nb.AddPage(p, view.title)
canvas.draw() canvas.draw()
def OLDcreate_page(self, nb, view):
opts = view._plotopts
p = wx.Panel(nb, -1)
figure = Figure()
axes = None
proj = opts['projection']
scale = opts['scale']
if proj == 'rectilinear':
axes = figure.add_subplot(111, projection='rectilinear')
elif proj in ('polar', 'ortho', 'stereo'):
axes = figure.add_subplot(111, projection='polar')
canvas = FigureCanvas(p, -1, figure)
sizer = wx.BoxSizer(wx.VERTICAL)
toolbar = NavigationToolbar2WxAgg(canvas)
toolbar.Realize()
sizer.Add(toolbar, 0, wx.ALL|wx.EXPAND)
toolbar.update()
sizer.Add(canvas, 5, wx.ALL|wx.EXPAND)
p.SetSizer(sizer)
p.Fit()
p.Show()
for values, label in zip(view.get_data(), opts['legend']):
# if we have only one column to plot, select a bar graph
if np.shape(values)[0] == 1:
xvalues = list(range(len(values[0])))
axes.bar(xvalues, values[0], label=label,
picker=5)
axes.set_xticks(xvalues)
else:
if proj in ('ortho', 'stereo'):
theta, phi, Xsec = cols2matrix(*values)
theta_ticks = np.arange(0, 91, 15)
if proj == 'ortho':
R = np.sin(np.radians(theta))
R_ticks = np.sin(np.radians(theta_ticks))
elif proj == 'stereo':
R = 2 * np.tan(np.radians(theta/2.))
R_ticks = 2 * np.tan(np.radians(theta_ticks/2.))
#R = np.tan(np.radians(theta/2.))
X, Y = np.meshgrid(np.radians(phi), R)
im = axes.pcolormesh(X, Y, Xsec)
axes.set_yticks(R_ticks)
axes.set_yticklabels(theta_ticks)
figure.colorbar(im)
elif proj == 'polar':
values[0] = np.radians(values[0])
axes.plot(*values, label=label, picker=5,
marker=opts['marker'])
else:
if scale == 'semilogx':
pltcmd = axes.semilogx
elif scale == 'semilogy':
pltcmd = axes.semilogy
elif scale == 'log':
pltcmd = axes.loglog
else:
pltcmd = axes.plot
pltcmd(*values, label=label, picker=5,
marker=opts['marker'])
axes.grid(opts['grid'])
axes.set_title(opts['title'])
axes.set_xlabel(opts['xlabel'])
axes.set_ylabel(opts['ylabel'])
axes.set_xlim(*opts['xlim'])
axes.set_ylim(*opts['ylim'])
if label:
axes.legend()
axes.autoscale(enable=opts['autoscale'])
# MPL events
figure.canvas.mpl_connect('motion_notify_event', self.on_mpl_motion)
figure.canvas.mpl_connect('pick_event', self.on_mpl_pick)
nb.AddPage(p, view.title)
def update_statusbar(self): def update_statusbar(self):
sb = self.GetStatusBar() sb = self.GetStatusBar()
menu_id = self.GetMenuBar().FindMenu('Datasets') menu_id = self.GetMenuBar().FindMenu('Datasets')