Add kwargs to the pipeline function
The user can now provide a 'pipeline' function that takes only '**kwargs' as its arguments, which is more flexible and easier to write. The 'sweep_index' keyword is now always added automatically, so the 'passindex' option has been removed: it was redundant with the index of the final DataFrame object. The user-defined 'pipeline' function can now return anything; its output is no longer limited to the ([x,],[y,]) format.
This commit is contained in:
		
							parent
							
								
									db6ee27699
								
							
						
					
					
						commit
						feaaabc9c4
					
				|  | @ -17,8 +17,8 @@ | ||||||
| # along with this msspec.  If not, see <http://www.gnu.org/licenses/>. | # along with this msspec.  If not, see <http://www.gnu.org/licenses/>. | ||||||
| # | # | ||||||
| # Source file  : src/msspec/looper.py | # Source file  : src/msspec/looper.py | ||||||
| # Last modified: Mon, 27 Sep 2021 17:49:48 +0200 | # Last modified: Wed, 26 Feb 2025 11:15:54 +0100 | ||||||
| # Committed by : sylvain tricot <sylvain.tricot@univ-rennes1.fr> | # Committed by : Sylvain Tricot <sylvain.tricot@univ-rennes.fr> | ||||||
| 
 | 
 | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| from functools import partial | from functools import partial | ||||||
|  | @ -92,9 +92,8 @@ class Sweep: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class SweepRange: | class SweepRange: | ||||||
|     def __init__(self, *sweeps, passindex=False): |     def __init__(self, *sweeps): | ||||||
|         self.sweeps = sweeps |         self.sweeps = sweeps | ||||||
|         self.passindex = passindex |  | ||||||
|         self.index = 0 |         self.index = 0 | ||||||
| 
 | 
 | ||||||
|         # First check that sweeps that are linked to another on are all included |         # First check that sweeps that are linked to another on are all included | ||||||
|  | @ -158,17 +157,15 @@ class SweepRange: | ||||||
|                     for s in [sweep,] + children: |                     for s in [sweep,] + children: | ||||||
|                         key, value = s[idx] |                         key, value = s[idx] | ||||||
|                         row[key] = value |                         row[key] = value | ||||||
|                 if self.passindex: |                 row['sweep_index'] = i | ||||||
|                     row['sweep_index'] = i |  | ||||||
|             return row |             return row | ||||||
|         else: |         else: | ||||||
|             raise StopIteration |             raise StopIteration | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def columns(self): |     def columns(self): | ||||||
|         cols = [sweep.key for sweep in self.sweeps] |         cols  = ['sweep_index'] | ||||||
|         if self.passindex: |         cols += [sweep.key for sweep in self.sweeps] | ||||||
|             cols.append('sweep_index') |  | ||||||
|         return cols |         return cols | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|  | @ -202,31 +199,27 @@ class Looper: | ||||||
|         logger.debug("Pipeline called with {}".format(x)) |         logger.debug("Pipeline called with {}".format(x)) | ||||||
|         return self.pipeline(**x) |         return self.pipeline(**x) | ||||||
| 
 | 
 | ||||||
|     def run(self, *sweeps, ncpu=1, passindex=False): |     def run(self, *sweeps, ncpu=1, **kwargs): | ||||||
|         logger.info("Loop starts...") |         logger.info("Loop starts...") | ||||||
|         # prepare the list of inputs |         # prepare the list of inputs | ||||||
|         sr = SweepRange(*sweeps, passindex=passindex) |         sr = SweepRange(*sweeps) | ||||||
|         items = sr.items |         items = sr.items | ||||||
| 
 | 
 | ||||||
|         data = [] |         data = [] | ||||||
| 
 | 
 | ||||||
|  |         t0 = time.time() | ||||||
|  | 
 | ||||||
|         if ncpu == 1: |         if ncpu == 1: | ||||||
|             # serial processing... |             # serial processing... | ||||||
|             logger.info("serial processing...") |             logger.info("serial processing...") | ||||||
|             t0 = time.time() |  | ||||||
| 
 |  | ||||||
|             for i, values in enumerate(items): |             for i, values in enumerate(items): | ||||||
|  |                 values.update(kwargs) | ||||||
|                 result = self._wrapper(values) |                 result = self._wrapper(values) | ||||||
|                 data.append(result) |                 data.append(result) | ||||||
| 
 |  | ||||||
|             t1 = time.time() |  | ||||||
|             dt = t1 - t0 |  | ||||||
|             logger.info("Processed {:d} sets of inputs in {:.3f} seconds".format( |  | ||||||
|                 len(sr), dt)) |  | ||||||
| 
 |  | ||||||
|         else: |         else: | ||||||
|             # Parallel processing... |             # Parallel processing... | ||||||
|             chunksize = 1 #int(nsets/ncpu) |             chunksize = 1 #int(nsets/ncpu) | ||||||
|  |             [values.update(kwargs) for values in items] | ||||||
|             logger.info(("Parallel processing over {:d} cpu (chunksize={:d})..." |             logger.info(("Parallel processing over {:d} cpu (chunksize={:d})..." | ||||||
|                          "").format(ncpu, chunksize)) |                          "").format(ncpu, chunksize)) | ||||||
|             t0 = time.time() |             t0 = time.time() | ||||||
|  | @ -236,21 +229,23 @@ class Looper: | ||||||
|             pool.close() |             pool.close() | ||||||
|             pool.join() |             pool.join() | ||||||
| 
 | 
 | ||||||
|             t1 = time.time() |         t1 = time.time() | ||||||
|             dt = t1 - t0 |         dt = t1 - t0 | ||||||
|             logger.info(("Processed {:d} sets of inputs in {:.3f} seconds" |         logger.info(("Processed {:d} sets of inputs in {:.3f} seconds" | ||||||
|                          "").format(len(sr), dt)) |                         "").format(len(sr), dt)) | ||||||
| 
 | 
 | ||||||
|         # Create the DataFrame |         # Create the DataFrame | ||||||
|         dfdata = [] |         dfdata = [] | ||||||
|         columns = sr.columns + ['output',] |         columns = sr.columns + list(kwargs.keys()) + ['output',] | ||||||
| 
 | 
 | ||||||
|         for i in range(len(sr)): |         for i in range(len(sr)): | ||||||
|             row = list(items[i].values()) |             row = list(items[i].values()) | ||||||
|             row.append(data[i]) |             row.append(data[i]) | ||||||
|             dfdata.append(row) |             dfdata.append(row) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|         df = pd.DataFrame(dfdata, columns=columns) |         df = pd.DataFrame(dfdata, columns=columns) | ||||||
|  |         df = df.drop(columns=['sweep_index']) | ||||||
| 
 | 
 | ||||||
|         self.data = df |         self.data = df | ||||||
| 
 | 
 | ||||||
|  | @ -259,14 +254,14 @@ class Looper: | ||||||
|         # of corresponding dict of parameters {'keyA': [val0,...valn], |         # of corresponding dict of parameters {'keyA': [val0,...valn], | ||||||
|         # 'keyB': [val0,...valn], ...} |         # 'keyB': [val0,...valn], ...} | ||||||
| 
 | 
 | ||||||
|         all_xy = [] |         # all_xy = [] | ||||||
|         for irow, row in df.iterrows(): |         # for irow, row in df.iterrows(): | ||||||
|             all_xy.append(row.output[0]) |             # all_xy.append(row.output[0]) | ||||||
|             all_xy.append(row.output[1]) |             # all_xy.append(row.output[1]) | ||||||
|         parameters = df.to_dict() |         # parameters = df.to_dict() | ||||||
|         parameters.pop('output') |         # parameters.pop('output') | ||||||
| 
 | 
 | ||||||
|         return all_xy, parameters |         return self.data #all_xy, parameters | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -276,17 +271,16 @@ class Looper: | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     import numpy as np |     import numpy as np | ||||||
|     import time |     import time | ||||||
|  |     import logging | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     logging.basicConfig(level=logging.DEBUG) | ||||||
| 
 | 
 | ||||||
|     logger.setLevel(logging.DEBUG) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     def bar(**kwargs): |     def bar(**kwargs): | ||||||
|         return 0 |         i = kwargs.get('sweep_index') | ||||||
| 
 |         return np.linspace(0,i,10) | ||||||
|     def post_process(data): |  | ||||||
|         x = data.x.unique() |  | ||||||
|         y = data.y.unique() |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
|     theta = Sweep(key='theta', comments="The polar angle", |     theta = Sweep(key='theta', comments="The polar angle", | ||||||
|                   start=-70, stop=70, num=3, |                   start=-70, stop=70, num=3, | ||||||
|  | @ -314,7 +308,16 @@ if __name__ == "__main__": | ||||||
| 
 | 
 | ||||||
|     looper = Looper() |     looper = Looper() | ||||||
|     looper.pipeline = bar |     looper.pipeline = bar | ||||||
|     data = looper.run(emitter, emitter_plane, uij, theta, levels, ncpu=4, |     other_kws = {'un':1, 'deux':2} | ||||||
|                       passindex=True) |     data = looper.run(emitter, emitter_plane, uij, theta, levels, ncpu=4, **other_kws) | ||||||
|  | 
 | ||||||
|  |     # Print the dataframe | ||||||
|     print(data) |     print(data) | ||||||
|     #print(data[data.emitter_plane.eq(0)].theta.unique()) | 
 | ||||||
|  |     # Accessing the parameters and output values for a given sweep (e.g. the last one) | ||||||
|  |     print(looper.data.iloc[-1]) | ||||||
|  | 
 | ||||||
|  |     # Post-process the output values. For example here, the output is a 1D-array, | ||||||
|  |     # make the sum of sweeps with 'Sr' emitter | ||||||
|  |     X = np.array([ x for x in data[data.emitter == 'Sr'].output]).sum(axis=0) | ||||||
|  |     print(X) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue