1- # Copyright (c) 2024, NVIDIA CORPORATION.
1+ # Copyright (c) 2024-2025, NVIDIA CORPORATION.
22# Licensed under the Apache License, Version 2.0 (the "License");
33# you may not use this file except in compliance with the License.
44# You may obtain a copy of the License at
@@ -79,9 +79,15 @@ def get_reader(
7979 return DistSampleReader (self ._directory , format = self ._format , rank = rank )
8080
8181 def __write_minibatches_coo (self , minibatch_dict ):
82- has_edge_ids = minibatch_dict ["edge_id" ] is not None
83- has_edge_types = minibatch_dict ["edge_type" ] is not None
84- has_weights = minibatch_dict ["weight" ] is not None
82+ has_edge_ids = (
83+ "edge_id" in minibatch_dict and minibatch_dict ["edge_id" ] is not None
84+ )
85+ has_edge_types = (
86+ "edge_type" in minibatch_dict and minibatch_dict ["edge_type" ] is not None
87+ )
88+ has_weights = (
89+ "weight" in minibatch_dict and minibatch_dict ["weight" ] is not None
90+ )
8591
8692 if minibatch_dict ["renumber_map" ] is None :
8793 raise ValueError (
@@ -92,22 +98,22 @@ def __write_minibatches_coo(self, minibatch_dict):
9298 if len (minibatch_dict ["batch_id" ]) == 0 :
9399 return
94100
95- fanout_length = (len (minibatch_dict ["label_hop_offsets" ]) - 1 ) // len (
96- minibatch_dict ["batch_id" ]
97- )
101+ fanout_length = len (minibatch_dict ["fanout" ])
102+ total_num_batches = (
103+ len (minibatch_dict ["label_hop_offsets" ]) - 1
104+ ) / fanout_length
98105
99- for p in range (
100- 0 , int (ceil (len (minibatch_dict ["batch_id" ]) / self .__batches_per_partition ))
101- ):
106+ for p in range (0 , int (ceil (total_num_batches / self .__batches_per_partition ))):
102107 partition_start = p * (self .__batches_per_partition )
103108 partition_end = (p + 1 ) * (self .__batches_per_partition )
104109
105110 label_hop_offsets_array_p = minibatch_dict ["label_hop_offsets" ][
106111 partition_start * fanout_length : partition_end * fanout_length + 1
107112 ]
108113
109- batch_id_array_p = minibatch_dict ["batch_id" ][partition_start :partition_end ]
110- start_batch_id = batch_id_array_p [0 ]
114+ num_batches_p = len (label_hop_offsets_array_p ) - 1
115+
116+ start_batch_id = minibatch_dict ["batch_start" ]
111117
112118 input_offsets_p = minibatch_dict ["input_offsets" ][
113119 partition_start : (partition_end + 1 )
@@ -171,7 +177,7 @@ def __write_minibatches_coo(self, minibatch_dict):
171177 }
172178 )
173179
174- end_batch_id = start_batch_id + len ( batch_id_array_p ) - 1
180+ end_batch_id = start_batch_id + num_batches_p - 1
175181 rank = minibatch_dict ["rank" ] if "rank" in minibatch_dict else 0
176182
177183 full_output_path = os .path .join (
@@ -188,9 +194,15 @@ def __write_minibatches_coo(self, minibatch_dict):
188194 )
189195
190196 def __write_minibatches_csr (self , minibatch_dict ):
191- has_edge_ids = minibatch_dict ["edge_id" ] is not None
192- has_edge_types = minibatch_dict ["edge_type" ] is not None
193- has_weights = minibatch_dict ["weight" ] is not None
197+ has_edge_ids = (
198+ "edge_id" in minibatch_dict and minibatch_dict ["edge_id" ] is not None
199+ )
200+ has_edge_types = (
201+ "edge_type" in minibatch_dict and minibatch_dict ["edge_type" ] is not None
202+ )
203+ has_weights = (
204+ "weight" in minibatch_dict and minibatch_dict ["weight" ] is not None
205+ )
194206
195207 if minibatch_dict ["renumber_map" ] is None :
196208 raise ValueError (
@@ -201,22 +213,22 @@ def __write_minibatches_csr(self, minibatch_dict):
201213 if len (minibatch_dict ["batch_id" ]) == 0 :
202214 return
203215
204- fanout_length = (len (minibatch_dict ["label_hop_offsets" ]) - 1 ) // len (
205- minibatch_dict ["batch_id" ]
206- )
216+ fanout_length = len (minibatch_dict ["fanout" ])
217+ total_num_batches = (
218+ len (minibatch_dict ["label_hop_offsets" ]) - 1
219+ ) / fanout_length
207220
208- for p in range (
209- 0 , int (ceil (len (minibatch_dict ["batch_id" ]) / self .__batches_per_partition ))
210- ):
221+ for p in range (0 , int (ceil (total_num_batches / self .__batches_per_partition ))):
211222 partition_start = p * (self .__batches_per_partition )
212223 partition_end = (p + 1 ) * (self .__batches_per_partition )
213224
214225 label_hop_offsets_array_p = minibatch_dict ["label_hop_offsets" ][
215226 partition_start * fanout_length : partition_end * fanout_length + 1
216227 ]
217228
218- batch_id_array_p = minibatch_dict ["batch_id" ][partition_start :partition_end ]
219- start_batch_id = batch_id_array_p [0 ]
229+ num_batches_p = len (label_hop_offsets_array_p ) - 1
230+
231+ start_batch_id = minibatch_dict ["batch_start" ]
220232
221233 input_offsets_p = minibatch_dict ["input_offsets" ][
222234 partition_start : (partition_end + 1 )
@@ -292,7 +304,7 @@ def __write_minibatches_csr(self, minibatch_dict):
292304 }
293305 )
294306
295- end_batch_id = start_batch_id + len ( batch_id_array_p ) - 1
307+ end_batch_id = start_batch_id + num_batches_p - 1
296308 rank = minibatch_dict ["rank" ] if "rank" in minibatch_dict else 0
297309
298310 full_output_path = os .path .join (
@@ -309,12 +321,19 @@ def __write_minibatches_csr(self, minibatch_dict):
309321 )
310322
311323 def write_minibatches (self , minibatch_dict ):
312- if (minibatch_dict ["majors" ] is not None ) and (
313- minibatch_dict ["minors" ] is not None
314- ):
324+ if "minors" not in minibatch_dict :
325+ raise ValueError ("invalid columns" )
326+
327+ # PLC API specifies this behavior for empty input
328+ # This needs to be handled here to avoid causing a hang
329+ if len (minibatch_dict ["minors" ]) == 0 :
330+ return
331+
332+ if "majors" in minibatch_dict and minibatch_dict ["majors" ] is not None :
315333 self .__write_minibatches_coo (minibatch_dict )
316- elif (minibatch_dict ["major_offsets" ] is not None ) and (
317- minibatch_dict ["minors" ] is not None
334+ elif (
335+ "major_offsets" in minibatch_dict
336+ and minibatch_dict ["major_offsets" ] is not None
318337 ):
319338 self .__write_minibatches_csr (minibatch_dict )
320339 else :
0 commit comments