Add kwargs to the pipeline function
The user can now provide a 'pipeline' function that takes only '**kwargs' as its arguments, which is more flexible and easier to write. The 'sweep_index' keyword is always added automatically, so the 'passindex' option has been removed, since it was redundant with the index of the final dataframe object. The user-defined 'pipeline' function can now return anything; it is no longer limited to the ([x,],[y,]) format.
This commit is contained in:
parent
db6ee27699
commit
feaaabc9c4
|
@ -17,8 +17,8 @@
|
|||
# along with this msspec. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
# Source file : src/msspec/looper.py
|
||||
# Last modified: Mon, 27 Sep 2021 17:49:48 +0200
|
||||
# Committed by : sylvain tricot <sylvain.tricot@univ-rennes1.fr>
|
||||
# Last modified: Wed, 26 Feb 2025 11:15:54 +0100
|
||||
# Committed by : Sylvain Tricot <sylvain.tricot@univ-rennes.fr>
|
||||
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
@ -92,9 +92,8 @@ class Sweep:
|
|||
|
||||
|
||||
class SweepRange:
|
||||
def __init__(self, *sweeps, passindex=False):
|
||||
def __init__(self, *sweeps):
|
||||
self.sweeps = sweeps
|
||||
self.passindex = passindex
|
||||
self.index = 0
|
||||
|
||||
# First check that sweeps that are linked to another one are all included
|
||||
|
@ -158,7 +157,6 @@ class SweepRange:
|
|||
for s in [sweep,] + children:
|
||||
key, value = s[idx]
|
||||
row[key] = value
|
||||
if self.passindex:
|
||||
row['sweep_index'] = i
|
||||
return row
|
||||
else:
|
||||
|
@ -166,9 +164,8 @@ class SweepRange:
|
|||
|
||||
@property
|
||||
def columns(self):
|
||||
cols = [sweep.key for sweep in self.sweeps]
|
||||
if self.passindex:
|
||||
cols.append('sweep_index')
|
||||
cols = ['sweep_index']
|
||||
cols += [sweep.key for sweep in self.sweeps]
|
||||
return cols
|
||||
|
||||
@property
|
||||
|
@ -202,31 +199,27 @@ class Looper:
|
|||
logger.debug("Pipeline called with {}".format(x))
|
||||
return self.pipeline(**x)
|
||||
|
||||
def run(self, *sweeps, ncpu=1, passindex=False):
|
||||
def run(self, *sweeps, ncpu=1, **kwargs):
|
||||
logger.info("Loop starts...")
|
||||
# prepare the list of inputs
|
||||
sr = SweepRange(*sweeps, passindex=passindex)
|
||||
sr = SweepRange(*sweeps)
|
||||
items = sr.items
|
||||
|
||||
data = []
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
if ncpu == 1:
|
||||
# serial processing...
|
||||
logger.info("serial processing...")
|
||||
t0 = time.time()
|
||||
|
||||
for i, values in enumerate(items):
|
||||
values.update(kwargs)
|
||||
result = self._wrapper(values)
|
||||
data.append(result)
|
||||
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
logger.info("Processed {:d} sets of inputs in {:.3f} seconds".format(
|
||||
len(sr), dt))
|
||||
|
||||
else:
|
||||
# Parallel processing...
|
||||
chunksize = 1 #int(nsets/ncpu)
|
||||
[values.update(kwargs) for values in items]
|
||||
logger.info(("Parallel processing over {:d} cpu (chunksize={:d})..."
|
||||
"").format(ncpu, chunksize))
|
||||
t0 = time.time()
|
||||
|
@ -243,14 +236,16 @@ class Looper:
|
|||
|
||||
# Create the DataFrame
|
||||
dfdata = []
|
||||
columns = sr.columns + ['output',]
|
||||
columns = sr.columns + list(kwargs.keys()) + ['output',]
|
||||
|
||||
for i in range(len(sr)):
|
||||
row = list(items[i].values())
|
||||
row.append(data[i])
|
||||
dfdata.append(row)
|
||||
|
||||
|
||||
df = pd.DataFrame(dfdata, columns=columns)
|
||||
df = df.drop(columns=['sweep_index'])
|
||||
|
||||
self.data = df
|
||||
|
||||
|
@ -259,14 +254,14 @@ class Looper:
|
|||
# of corresponding dict of parameters {'keyA': [val0,...valn],
|
||||
# 'keyB': [val0,...valn], ...}
|
||||
|
||||
all_xy = []
|
||||
for irow, row in df.iterrows():
|
||||
all_xy.append(row.output[0])
|
||||
all_xy.append(row.output[1])
|
||||
parameters = df.to_dict()
|
||||
parameters.pop('output')
|
||||
# all_xy = []
|
||||
# for irow, row in df.iterrows():
|
||||
# all_xy.append(row.output[0])
|
||||
# all_xy.append(row.output[1])
|
||||
# parameters = df.to_dict()
|
||||
# parameters.pop('output')
|
||||
|
||||
return all_xy, parameters
|
||||
return self.data #all_xy, parameters
|
||||
|
||||
|
||||
|
||||
|
@ -276,17 +271,16 @@ class Looper:
|
|||
if __name__ == "__main__":
|
||||
import numpy as np
|
||||
import time
|
||||
import logging
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def bar(**kwargs):
|
||||
return 0
|
||||
|
||||
def post_process(data):
|
||||
x = data.x.unique()
|
||||
y = data.y.unique()
|
||||
|
||||
i = kwargs.get('sweep_index')
|
||||
return np.linspace(0,i,10)
|
||||
|
||||
theta = Sweep(key='theta', comments="The polar angle",
|
||||
start=-70, stop=70, num=3,
|
||||
|
@ -314,7 +308,16 @@ if __name__ == "__main__":
|
|||
|
||||
looper = Looper()
|
||||
looper.pipeline = bar
|
||||
data = looper.run(emitter, emitter_plane, uij, theta, levels, ncpu=4,
|
||||
passindex=True)
|
||||
other_kws = {'un':1, 'deux':2}
|
||||
data = looper.run(emitter, emitter_plane, uij, theta, levels, ncpu=4, **other_kws)
|
||||
|
||||
# Print the dataframe
|
||||
print(data)
|
||||
#print(data[data.emitter_plane.eq(0)].theta.unique())
|
||||
|
||||
# Accessing the parameters and output values for a given sweep (e.g. the last one)
|
||||
print(looper.data.iloc[-1])
|
||||
|
||||
# Post-process the output values. For example here, the output is a 1D-array,
|
||||
# make the sum of sweeps with 'Sr' emitter
|
||||
X = np.array([ x for x in data[data.emitter == 'Sr'].output]).sum(axis=0)
|
||||
print(X)
|
||||
|
|
Loading…
Reference in New Issue