added better handling of recursive storage; doc strings and added

This commit is contained in:
marco cammarata 2017-03-03 17:51:47 +01:00
parent e1dfaec2a4
commit 61c07dde55
1 changed files with 21 additions and 10 deletions

View File

@ -8,14 +8,15 @@ import h5py
import collections import collections
import logging import logging
log = logging.getLogger(__name__) # __name__ is "foo.bar" here log = logging.getLogger(__name__)
def unwrapArray(a,recursive=True,readH5pyDataset=True): def unwrapArray(a,recursive=True,readH5pyDataset=True):
""" This function takes an object (like a dictionary) and recursively """ This function takes an object (like a dictionary) and recursively
unwraps it solving many issues like the fact that many objects are unwraps it solving issues like:
packaged as 0d array * the fact that many objects are packaged as 0d array
This funciton has also some specific hack for handling h5py limit to This funciton has also some specific hack for handling h5py limits:
handle for example the None object or the numpy unicode ... * handle the None python object
* numpy unicode ...
""" """
# is h5py dataset convert to array # is h5py dataset convert to array
if isinstance(a,h5py.Dataset) and readH5pyDataset: a = a[...] if isinstance(a,h5py.Dataset) and readH5pyDataset: a = a[...]
@ -140,11 +141,16 @@ class DataStorage(dict):
To initialize it: To initialize it:
data = DataStorage( dict( a=(1,2,3),b="add"),filename='store.npz' ) data = DataStorage( a=(1,2,3),b="add",filename='store.npz' )
data = DataStorage( a=(1,2,3), b="add" ) # recursively by default
# data.a will be a DataStorage instance
data = DataStorage( a=dict( b = 1)) );
reads from file if it exists # data.a will be a dictionary
data = DataStorage( a=dict( b = 1),recursive=False )
# reads from file if it exists
data = DataStorage( 'mysaveddata.npz' ) ; data = DataStorage( 'mysaveddata.npz' ) ;
DOES NOT READ FROM FILE (even if it exists)!! DOES NOT READ FROM FILE (even if it exists)!!
@ -156,6 +162,7 @@ class DataStorage(dict):
def __init__(self,*args,filename='data_storage.npz',recursive=True,**kwargs): def __init__(self,*args,filename='data_storage.npz',recursive=True,**kwargs):
# self.filename = kwargs.pop('filename',"data_storage.npz") # self.filename = kwargs.pop('filename',"data_storage.npz")
self.filename = filename self.filename = filename
self._recursive = recursive
# interpret kwargs as dict if there are # interpret kwargs as dict if there are
if len(kwargs) != 0: if len(kwargs) != 0:
fileOrDict = dict(kwargs) fileOrDict = dict(kwargs)
@ -164,6 +171,7 @@ class DataStorage(dict):
else: else:
fileOrDict = dict() fileOrDict = dict()
d = dict(); # data dictionary d = dict(); # data dictionary
if isinstance(fileOrDict,dict): if isinstance(fileOrDict,dict):
d = fileOrDict d = fileOrDict
@ -194,8 +202,9 @@ class DataStorage(dict):
def __setattr__(self, key, value): def __setattr__(self, key, value):
""" allows to add fields with data.test=4 """ """ allows to add fields with data.test=4 """
#print("__setattr__") # check if attr exists is essential (or it fails when defining an instance)
if isinstance(value,(dict,collections.OrderedDict)): value = DataStorage(value) if hasattr(self,"_recursive") and self._recursive and \
isinstance(value,(dict,collections.OrderedDict)): value = DataStorage(value)
super().__setitem__(key, value) super().__setitem__(key, value)
super().__setattr__(key,value) super().__setattr__(key,value)
@ -216,6 +225,7 @@ class DataStorage(dict):
fmt = "%%%ds %%s" % (nchars) fmt = "%%%ds %%s" % (nchars)
s = ["DataStorage obj containing (sorted): ",] s = ["DataStorage obj containing (sorted): ",]
for k in keys: for k in keys:
if k[0] == "_": continue
obj = self[k] obj = self[k]
if isinstance(obj,np.ndarray): if isinstance(obj,np.ndarray):
value_str = "array, size %s, type %s"% ("x".join(map(str,obj.shape)),obj.dtype) value_str = "array, size %s, type %s"% ("x".join(map(str,obj.shape)),obj.dtype)
@ -236,6 +246,7 @@ class DataStorage(dict):
def keys(self): def keys(self):
keys = list(super().keys()) keys = list(super().keys())
keys = [k for k in keys if k != 'filename' ] keys = [k for k in keys if k != 'filename' ]
keys = [k for k in keys if k[0] != '_' ]
return keys return keys
def save(self,fname=None): def save(self,fname=None):