Skip to content

Commit 245c1ef

Browse files
committed
work on kgcnn.io.file updated graph.base and crystal.base to match indices. Added schedule to training. Added coGN to hyper_mp_e_form.py
1 parent 880223b commit 245c1ef

File tree

7 files changed

+172
-28
lines changed

7 files changed

+172
-28
lines changed

changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ v3.0.1
33
* Removed deprecated molecules.
44
* Fix error in ``kgcnn.data.transform.scaler.serial``
55
* Fix error in ``QMDataset`` for molecular features.
6-
*
6+
* Fix error in ``kgcnn.layers.conv.GraphSageNodeLayer`` .
77

88

99
v3.0.0

kgcnn/crystal/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __call__(self, structure: Structure) -> Union[MultiDiGraph, GraphDict]:
5555
g = GraphDict()
5656
g.from_networkx(
5757
nxg, node_attributes=self.node_attributes, edge_attributes=self.edge_attributes,
58-
graph_attributes=self.graph_attributes)
58+
graph_attributes=self.graph_attributes, reverse_edge_indices=True)
5959
return g
6060
return self.call(structure)
6161

kgcnn/graph/base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ def from_networkx(self, graph: nx.Graph,
179179
node_attributes: Union[str, List[str]] = None,
180180
edge_attributes: Union[str, List[str]] = None,
181181
graph_attributes: Union[str, List[str]] = None,
182-
node_labels: str = None):
182+
node_labels: str = None,
183+
reverse_edge_indices: bool = False):
183184
r"""Convert a networkx graph instance into a dictionary of graph-tensors. The networkx graph is always converted
184185
into integer node labels. The former node IDs can be hold in :obj:`node_labels`. Furthermore, node or edge
185186
data can be cast into attributes via :obj:`node_attributes` and :obj:`edge_attributes`.
@@ -195,6 +196,7 @@ def from_networkx(self, graph: nx.Graph,
195196
graph_attributes (str, list): Name of graph attributes to add from graph data. Can also be a list of names.
196197
Default is None.
197198
node_labels (str): Name of the labels of nodes to store former node IDs into. Default is None.
199+
reverse_edge_indices (bool): Whether to reverse edge indices for notation '(ij, i<-j)'. Default is False.
198200
199201
Returns:
200202
self.
@@ -230,7 +232,10 @@ def _attr_to_list(attr):
230232
edges_attr = _attr_to_list(edge_attributes)
231233
edges_attr_dict = {x: [None] * graph_int.number_of_edges() for x in edges_attr}
232234
for i, x in enumerate(graph_int.edges.data()):
233-
edge_id.append(x[:2])
235+
if reverse_edge_indices:
236+
edge_id.append([x[1], x[0]])
237+
else:
238+
edge_id.append([x[0], x[1]])
234239
for d in edges_attr:
235240
if d not in x[2]:
236241
raise KeyError("Edge does not have property '%s'." % d)

kgcnn/io/file.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import h5py
3+
from typing import List, Union
34

45

56
class RaggedArrayNumpyFile:
@@ -8,13 +9,12 @@ def __init__(self, file_path: str, compressed: bool = False):
89
self.file_path = file_path
910
self.compressed = compressed
1011

11-
def write(self, ragged_array: list):
12-
# Only support ragged one for the moment.
13-
leading_shape = ragged_array[0].shape
14-
assert all([leading_shape[1:] == x.shape[1:] for x in ragged_array]), "Only support ragged rank == 1."
12+
def write(self, ragged_array: List[np.ndarray]):
13+
inner_shape = ragged_array[0].shape
1514
values = np.concatenate([x for x in ragged_array], axis=0)
1615
row_splits = np.cumsum(np.array([len(x) for x in ragged_array], dtype="int64"), dtype="int64")
17-
out = {"values": values, "row_splits": row_splits}
16+
row_splits = np.pad(row_splits, [1, 0])
17+
out = {"values": values, "row_splits": row_splits, "shape": np.array([])}
1818
if self.compressed:
1919
np.savez_compressed(self.file_path, **out)
2020
else:
@@ -24,16 +24,19 @@ def read(self):
2424
data = np.load(self.file_path)
2525
values = data.get("values")
2626
row_splits = data.get("row_splits")
27-
return np.split(values, row_splits[:-1])
27+
return np.split(values, row_splits[1:-1])
28+
29+
def __getitem__(self, item):
30+
raise NotImplementedError("Not implemented for file reference load.")
2831

2932

3033
class RaggedArrayHDFile:
3134

32-
def __init__(self, file_path: str, compressed: bool = False):
35+
def __init__(self, file_path: str, compressed: bool = None):
3336
self.file_path = file_path
3437
self.compressed = compressed
3538

36-
def write(self, ragged_array: list):
39+
def write(self, ragged_array: List[np.ndarray]):
3740
"""Write ragged array to file.
3841
3942
.. code-block:: python
@@ -44,30 +47,57 @@ def write(self, ragged_array: list):
4447
f = RaggedArrayHDFile("test.hdf5")
4548
f.write(data)
4649
print(f.read())
47-
print(f[1])
4850
4951
Args:
50-
ragged_array (list): List of numpy arrays.
52+
ragged_array (list, tf.RaggedTensor): List or list of numpy arrays.
5153
5254
Returns:
5355
None.
5456
"""
57+
inner_shape = ragged_array[0].shape
5558
values = np.concatenate([x for x in ragged_array], axis=0)
5659
row_splits = np.cumsum(np.array([len(x) for x in ragged_array], dtype="int64"), dtype="int64")
57-
with h5py.File(self.file_path, "w") as f:
58-
f.create_dataset("values", data=values)
59-
f.create_dataset("row_splits", data=row_splits)
60+
row_splits = np.pad(row_splits, [1, 0])
61+
with h5py.File(self.file_path, "w") as file:
62+
file.create_dataset("values", data=values, maxshape=[None] + list(inner_shape)[1:])
63+
file.create_dataset("row_splits", data=row_splits, maxshape=(None, ))
64+
file.create_dataset("shape", data=np.array([]))
6065

6166
def read(self):
62-
file = h5py.File(self.file_path)
63-
data = np.split(file["values"][()], file["row_splits"][:1])
64-
file.close()
67+
with h5py.File(self.file_path, "r") as file:
68+
data = np.split(file["values"][()], file["row_splits"][1:-1])
6569
return data
6670

67-
def __getitem__(self, item):
68-
file = h5py.File(self.file_path)
69-
row_splits = file["row_splits"]
70-
row_splits = np.pad(row_splits, [1, 0])
71-
out_data = file["values"][row_splits[item]:row_splits[item+1]]
72-
file.close()
71+
def __getitem__(self, item: int):
72+
with h5py.File(self.file_path, "r") as file:
73+
row_splits = file["row_splits"]
74+
out_data = file["values"][row_splits[item]:row_splits[item+1]]
7375
return out_data
76+
77+
def append(self, item):
78+
with h5py.File(self.file_path, "r+") as file:
79+
file["values"].resize(
80+
file["values"].shape[0] + len(item), axis=0
81+
)
82+
split_last = file["row_splits"][-1]
83+
file["row_splits"].resize(
84+
file["row_splits"].shape[0] + 1, axis=0
85+
)
86+
len_last = len(item)
87+
file["row_splits"][-1] = split_last + len_last
88+
file["values"][split_last:split_last+len_last] = item
89+
90+
def append_multiple(self, items: list):
91+
new_values = np.concatenate(items, axis=0)
92+
new_len = len(items)
93+
new_splits = np.cumsum([len(x) for x in items])
94+
with h5py.File(self.file_path, "r+") as file:
95+
file["values"].resize(
96+
file["values"].shape[0] + new_values.shape[0], axis=0
97+
)
98+
split_last = file["row_splits"][-1]
99+
file["row_splits"].resize(
100+
file["row_splits"].shape[0] + new_len, axis=0
101+
)
102+
file["row_splits"][-new_len:] = split_last + new_splits
103+
file["values"][split_last:+split_last+new_splits[-1]] = new_values

kgcnn/training/schedule.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
22
import tensorflow as tf
33

4+
ks = tf.keras
45

5-
@tf.keras.utils.register_keras_serializable(package='kgcnn', name='LinearWarmupExponentialDecay')
6+
7+
@ks.utils.register_keras_serializable(package='kgcnn', name='LinearWarmupExponentialDecay')
68
class LinearWarmupExponentialDecay(tf.optimizers.schedules.LearningRateSchedule):
79
r"""This schedule combines a linear warmup with an exponential decay.
810
Combines :obj:` tf.optimizers.schedules.PolynomialDecay` with an actual increase during warmup
@@ -63,3 +65,31 @@ def get_config(self):
6365
config = {}
6466
config.update(self._input_config_settings)
6567
return config
68+
69+
70+
@ks.utils.register_keras_serializable(package='kgcnn', name='KerasPolynomialDecaySchedule')
71+
class KerasPolynomialDecaySchedule(ks.optimizers.schedules.PolynomialDecay):
72+
r"""This schedule extends :obj:` tf.optimizers.schedules.PolynomialDecay` ."""
73+
74+
def __init__(self, dataset_size: int, batch_size: int, epochs: int, lr_start: float = 0.0005,
75+
lr_stop: float = 1e-5):
76+
"""Initialize class.
77+
78+
Args:
79+
dataset_size (int): Size of the dataset.
80+
batch_size (int): Batch size for training.
81+
epochs (int): Total epochs to run schedule on.
82+
lr_start (int): Learning rate at the start.
83+
lr_stop (int): Final learning rate.
84+
"""
85+
self._input_config_settings = {"lr_start": lr_start, "lr_stop": lr_stop,
86+
"epochs": epochs, "batch_size": batch_size, "dataset_size": dataset_size}
87+
steps_per_epoch = dataset_size / batch_size
88+
num_steps = epochs * steps_per_epoch
89+
super().__init__(initial_learning_rate=lr_start, decay_steps=num_steps, end_learning_rate=lr_stop)
90+
91+
def get_config(self):
92+
"""Get config for this class."""
93+
config = {}
94+
config.update(self._input_config_settings)
95+
return config

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ ase>=3.22.1
1616
click>=7.1.2
1717
visual_graph_datasets>=0.7.1
1818
brotli>=1.0.9
19-
pyxtal>=0.5.5
19+
pyxtal>=0.5.5
20+
h5py

training/hyper/hyper_mp_e_form.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,4 +419,82 @@
419419
"kgcnn_version": "2.1.1"
420420
}
421421
},
422+
"coGN": {
423+
"model": {
424+
"module_name": "kgcnn.literature.coGN",
425+
"class_name": "make_model",
426+
"config": {
427+
"name": "coGN",
428+
"inputs": {
429+
"offset": {"shape": (None, 3), "name": "offset", "dtype": "float32", "ragged": True},
430+
"cell_translation": None,
431+
"affine_matrix": None,
432+
"voronoi_ridge_area": None,
433+
"atomic_number": {"shape": (None,), "name": "atomic_number", "dtype": "int32", "ragged": True},
434+
"frac_coords": None,
435+
"coords": None,
436+
"multiplicity": {"shape": (None,), "name": "multiplicity", "dtype": "int32", "ragged": True},
437+
"lattice_matrix": None,
438+
"edge_indices": {"shape": (None, 2), "name": "edge_indices", "dtype": "int32", "ragged": True},
439+
"line_graph_edge_indices": None,
440+
},
441+
# All default.
442+
}
443+
},
444+
"training": {
445+
"cross_validation": {"class_name": "KFold",
446+
"config": {"n_splits": 5, "random_state": 42, "shuffle": True}},
447+
"fit": {
448+
"batch_size": 64, "epochs": 800, "validation_freq": 10, "verbose": 2,
449+
"callbacks": [
450+
# {"class_name": "kgcnn>LinearLearningRateScheduler", "config": {
451+
# "learning_rate_start": 0.0005, "learning_rate_stop": 0.5e-05, "epo_min": 0, "epo": 800,
452+
# "verbose": 0}
453+
# }
454+
]
455+
},
456+
"compile": {
457+
"optimizer": {
458+
"class_name": "Adam",
459+
"config": {
460+
"learning_rate": {
461+
"class_name": "kgcnn>KerasPolynomialDecaySchedule",
462+
"config": {
463+
"dataset_size": 106.201, "batch_size": 64, "epochs": 800,
464+
"lr_start": 0.0005, "lr_stop": 1.0e-05
465+
}
466+
}
467+
}
468+
},
469+
"loss": "mean_absolute_error"
470+
},
471+
"scaler": {
472+
"class_name": "StandardScaler",
473+
"module_name": "kgcnn.data.transform.scaler.standard",
474+
"config": {"with_std": True, "with_mean": True, "copy": True}
475+
},
476+
"multi_target_indices": None
477+
},
478+
"data": {
479+
"dataset": {
480+
"class_name": "MatProjectEFormDataset",
481+
"module_name": "kgcnn.data.datasets.MatProjectEFormDataset",
482+
"config": {},
483+
"methods": [
484+
{"set_representation": {
485+
"pre_processor": {"class_name": "KNNAsymmetricUnitCell",
486+
"module_name": "kgcnn.crystal.preprocessor",
487+
"config": {"k": 24}
488+
},
489+
"reset_graphs": False}}
490+
]
491+
},
492+
"data_unit": "eV/atom"
493+
},
494+
"info": {
495+
"postfix": "",
496+
"postfix_file": "",
497+
"kgcnn_version": "3.0.1"
498+
}
499+
},
422500
}

0 commit comments

Comments
 (0)