Skip to content

Commit 41e05f5

Browse files
committed
Typing for python=3.8
1 parent eace3f3 commit 41e05f5

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

kgcnn/data/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import tensorflow as tf
44
import pandas as pd
55
import os
6-
from typing import Union
6+
from typing import Union, List
77
from collections.abc import MutableMapping
88
from kgcnn.data.utils import save_pickle_file, load_pickle_file, ragged_tensor_from_nested_numpy
99
from kgcnn.graph.base import GraphNumpyContainer, GraphDict
@@ -493,7 +493,7 @@ def message_error(msg):
493493
message_error("Can not check shape for '%s'." % x["name"])
494494
return
495495

496-
def set_methods(self, method_list: list[dict]) -> None:
496+
def set_methods(self, method_list: List[dict]) -> None:
497497
r"""Apply a list of serialized class-methods on the dataset.
498498
499499
This can extend the config-serialization scheme in :obj:`kgcnn.utils.serial`.

kgcnn/io/loader.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
import tensorflow as tf
3-
from typing import Union
3+
from typing import Union, List
44
from kgcnn.data.utils import ragged_tensor_from_nested_numpy
55
from kgcnn.data.base import MemoryGraphDataset
66
ks = tf.keras
@@ -10,9 +10,9 @@ class GraphBatchLoader(ks.utils.Sequence):
1010
r"""Example (minimal) implementation of a graph batch loader based on :obj:`ks.utils.Sequence`."""
1111

1212
def __init__(self,
13-
data: Union[list[dict], MemoryGraphDataset],
14-
inputs: Union[dict, list[dict]],
15-
outputs: Union[dict, list[dict]],
13+
data: Union[List[dict], MemoryGraphDataset],
14+
inputs: Union[dict, List[dict]],
15+
outputs: Union[dict, List[dict]],
1616
batch_size: int = 32,
1717
shuffle: bool = False):
1818
"""Initialization with data and input information.
@@ -65,7 +65,7 @@ def on_epoch_end(self):
6565
self._shuffle_indices()
6666

6767
@staticmethod
68-
def _to_tensor(item, is_ragged):
68+
def _to_tensor(item: Union[np.ndarray, list], is_ragged: bool):
6969
if is_ragged:
7070
return ragged_tensor_from_nested_numpy(item)
7171
else:

0 commit comments

Comments
 (0)