11import numpy as np
22import h5py
3+ from typing import List , Union
34
45
56class RaggedArrayNumpyFile :
@@ -8,13 +9,12 @@ def __init__(self, file_path: str, compressed: bool = False):
89 self .file_path = file_path
910 self .compressed = compressed
1011
11- def write (self , ragged_array : list ):
12- # Only support ragged one for the moment.
13- leading_shape = ragged_array [0 ].shape
14- assert all ([leading_shape [1 :] == x .shape [1 :] for x in ragged_array ]), "Only support ragged rank == 1."
12+ def write (self , ragged_array : List [np .ndarray ]):
13+ inner_shape = ragged_array [0 ].shape
1514 values = np .concatenate ([x for x in ragged_array ], axis = 0 )
1615 row_splits = np .cumsum (np .array ([len (x ) for x in ragged_array ], dtype = "int64" ), dtype = "int64" )
17- out = {"values" : values , "row_splits" : row_splits }
16+ row_splits = np .pad (row_splits , [1 , 0 ])
17+ out = {"values" : values , "row_splits" : row_splits , "shape" : np .array ([])}
1818 if self .compressed :
1919 np .savez_compressed (self .file_path , ** out )
2020 else :
@@ -24,16 +24,19 @@ def read(self):
2424 data = np .load (self .file_path )
2525 values = data .get ("values" )
2626 row_splits = data .get ("row_splits" )
27- return np .split (values , row_splits [:- 1 ])
27+ return np .split (values , row_splits [1 :- 1 ])
28+
29+ def __getitem__ (self , item ):
30+ raise NotImplementedError ("Not implemented for file reference load." )
2831
2932
3033class RaggedArrayHDFile :
3134
32- def __init__ (self , file_path : str , compressed : bool = False ):
35+ def __init__ (self , file_path : str , compressed : bool = None ):
3336 self .file_path = file_path
3437 self .compressed = compressed
3538
36- def write (self , ragged_array : list ):
39+ def write (self , ragged_array : List [ np . ndarray ] ):
3740 """Write ragged array to file.
3841
3942 .. code-block:: python
@@ -44,30 +47,57 @@ def write(self, ragged_array: list):
4447 f = RaggedArrayHDFile("test.hdf5")
4548 f.write(data)
4649 print(f.read())
47- print(f[1])
4850
4951 Args:
50- ragged_array (list): List of numpy arrays.
52+ ragged_array (list, tf.RaggedTensor ): List or list of numpy arrays.
5153
5254 Returns:
5355 None.
5456 """
57+ inner_shape = ragged_array [0 ].shape
5558 values = np .concatenate ([x for x in ragged_array ], axis = 0 )
5659 row_splits = np .cumsum (np .array ([len (x ) for x in ragged_array ], dtype = "int64" ), dtype = "int64" )
57- with h5py .File (self .file_path , "w" ) as f :
58- f .create_dataset ("values" , data = values )
59- f .create_dataset ("row_splits" , data = row_splits )
60+ row_splits = np .pad (row_splits , [1 , 0 ])
61+ with h5py .File (self .file_path , "w" ) as file :
62+ file .create_dataset ("values" , data = values , maxshape = [None ] + list (inner_shape )[1 :])
63+ file .create_dataset ("row_splits" , data = row_splits , maxshape = (None , ))
64+ file .create_dataset ("shape" , data = np .array ([]))
6065
6166 def read (self ):
62- file = h5py .File (self .file_path )
63- data = np .split (file ["values" ][()], file ["row_splits" ][:1 ])
64- file .close ()
67+ with h5py .File (self .file_path , "r" ) as file :
68+ data = np .split (file ["values" ][()], file ["row_splits" ][1 :- 1 ])
6569 return data
6670
67- def __getitem__ (self , item ):
68- file = h5py .File (self .file_path )
69- row_splits = file ["row_splits" ]
70- row_splits = np .pad (row_splits , [1 , 0 ])
71- out_data = file ["values" ][row_splits [item ]:row_splits [item + 1 ]]
72- file .close ()
71+ def __getitem__ (self , item : int ):
72+ with h5py .File (self .file_path , "r" ) as file :
73+ row_splits = file ["row_splits" ]
74+ out_data = file ["values" ][row_splits [item ]:row_splits [item + 1 ]]
7375 return out_data
76+
77+ def append (self , item ):
78+ with h5py .File (self .file_path , "r+" ) as file :
79+ file ["values" ].resize (
80+ file ["values" ].shape [0 ] + len (item ), axis = 0
81+ )
82+ split_last = file ["row_splits" ][- 1 ]
83+ file ["row_splits" ].resize (
84+ file ["row_splits" ].shape [0 ] + 1 , axis = 0
85+ )
86+ len_last = len (item )
87+ file ["row_splits" ][- 1 ] = split_last + len_last
88+ file ["values" ][split_last :split_last + len_last ] = item
89+
90+ def append_multiple (self , items : list ):
91+ new_values = np .concatenate (items , axis = 0 )
92+ new_len = len (items )
93+ new_splits = np .cumsum ([len (x ) for x in items ])
94+ with h5py .File (self .file_path , "r+" ) as file :
95+ file ["values" ].resize (
96+ file ["values" ].shape [0 ] + new_values .shape [0 ], axis = 0
97+ )
98+ split_last = file ["row_splits" ][- 1 ]
99+ file ["row_splits" ].resize (
100+ file ["row_splits" ].shape [0 ] + new_len , axis = 0
101+ )
102+ file ["row_splits" ][- new_len :] = split_last + new_splits
103+ file ["values" ][split_last :+ split_last + new_splits [- 1 ]] = new_values
0 commit comments