File tree Expand file tree Collapse file tree 2 files changed +7
-7
lines changed
Expand file tree Collapse file tree 2 files changed +7
-7
lines changed Original file line number Diff line number Diff line change 33import tensorflow as tf
44import pandas as pd
55import os
6- from typing import Union
6+ from typing import Union , List
77from collections .abc import MutableMapping
88from kgcnn .data .utils import save_pickle_file , load_pickle_file , ragged_tensor_from_nested_numpy
99from kgcnn .graph .base import GraphNumpyContainer , GraphDict
@@ -493,7 +493,7 @@ def message_error(msg):
493493 message_error ("Can not check shape for '%s'." % x ["name" ])
494494 return
495495
496- def set_methods (self , method_list : list [dict ]) -> None :
496+ def set_methods (self , method_list : List [dict ]) -> None :
497497 r"""Apply a list of serialized class-methods on the dataset.
498498
499499 This can extend the config-serialization scheme in :obj:`kgcnn.utils.serial`.
Original file line number Diff line number Diff line change 11import numpy as np
22import tensorflow as tf
3- from typing import Union
3+ from typing import Union , List
44from kgcnn .data .utils import ragged_tensor_from_nested_numpy
55from kgcnn .data .base import MemoryGraphDataset
66ks = tf .keras
@@ -10,9 +10,9 @@ class GraphBatchLoader(ks.utils.Sequence):
1010 r"""Example (minimal) implementation of a graph batch loader based on :obj:`ks.utils.Sequence`."""
1111
1212 def __init__ (self ,
13- data : Union [list [dict ], MemoryGraphDataset ],
14- inputs : Union [dict , list [dict ]],
15- outputs : Union [dict , list [dict ]],
13+ data : Union [List [dict ], MemoryGraphDataset ],
14+ inputs : Union [dict , List [dict ]],
15+ outputs : Union [dict , List [dict ]],
1616 batch_size : int = 32 ,
1717 shuffle : bool = False ):
1818 """Initialization with data and input information.
@@ -65,7 +65,7 @@ def on_epoch_end(self):
6565 self ._shuffle_indices ()
6666
6767 @staticmethod
68- def _to_tensor (item , is_ragged ):
68+ def _to_tensor (item : Union [ np . ndarray , list ], is_ragged : bool ):
6969 if is_ragged :
7070 return ragged_tensor_from_nested_numpy (item )
7171 else :
You can’t perform that action at this time.
0 commit comments