Skip to content

Commit f99a518

Browse files
committed
Set seed
1 parent 36c82fb commit f99a518

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

benchmarks/cugraph-dgl/python-script/dgl_dataloading_benchmark/dgl_benchmark.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import os
1919
import time
2020
import json
21+
import random
22+
import numpy as np
2123
from argparse import ArgumentParser
2224

2325

@@ -218,6 +220,11 @@ def dataloading_benchmark(g, train_idx, fanouts, batch_sizes, use_uva):
218220
print("==============================================")
219221
return time_ls
220222

223+
def set_seed(seed):
224+
random.seed(seed)
225+
np.random.seed(seed)
226+
torch.manual_seed(seed)
227+
torch.cuda.manual_seed_all(seed)
221228

222229
if __name__ == "__main__":
223230
parser = ArgumentParser()
@@ -230,13 +237,14 @@ def dataloading_benchmark(g, train_idx, fanouts, batch_sizes, use_uva):
230237
)
231238
parser.add_argument("--batch_sizes", type=str, default="512,1024")
232239
parser.add_argument("--do_not_use_uva", action="store_true")
240+
parser.add_argument("--seed", type=int, default=42)
233241
args = parser.parse_args()
234242

235243
if args.do_not_use_uva:
236244
use_uva = False
237245
else:
238246
use_uva = True
239-
247+
set_seed(args.seed)
240248
replication_factors = [int(x) for x in args.replication_factors.split(",")]
241249
fanouts = [[int(y) for y in x.split("_")] for x in args.fanouts.split(",")]
242250
batch_sizes = [int(x) for x in args.batch_sizes.split(",")]

0 commit comments

Comments
 (0)