33
44import argparse
55import os
6- import pickle
76import sys
87import time
98import torch
109import torch .distributed as dist
1110import torch .multiprocessing as mp
1211
1312# torch.set_printoptions(threshold=5000)
14- from tempfile import TemporaryDirectory
1513
# Make the repository root (two directories up) importable so that the
# `utils_nlp` package resolves when this script is launched from its own
# folder; then put the current directory first on the search path.
nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

sys.path.insert(0, "./")
2019
2120from utils_nlp .models .transformers .abstractive_summarization_bertsum import (
2221 BertSumAbs ,
# Rendezvous endpoint for torch.distributed: every spawned process must be
# given the same tcp://<host>:<port> of the rank-0 machine.
# Fixes: removed the stray trailing space inside the default URL and the
# "process groupi." typo in the help text.
parser.add_argument(
    "--dist_url",
    type=str,
    default="tcp://127.0.0.1:29507",
    help="URL specifying how to initialize the process group.",
)
4645parser .add_argument (
# Local directory where the preprocessed CNN/DM data is downloaded/cached.
parser.add_argument(
    "--data_dir",
    type=str,
    default="./abstemp",
    help="Directory to download the preprocessed data.",
)
# Upper bound on optimizer steps; ignored when --quick_run is set.
# Fix: the default was the float literal 5e4 while type=int, so the default
# value (50000.0) had a different type than any user-supplied value; wrap it
# in int() to keep the option uniformly int-typed.
parser.add_argument(
    "--max_steps",
    type=int,
    default=int(5e4),
    help="""Maximum number of training steps run in training.
            If quick_run is set, it's not used.""",
)
107106parser .add_argument (
108107 "--warmup_steps_bert" ,
@@ -176,8 +175,11 @@ def main():
176175 print ("data_dir is {}" .format (args .data_dir ))
177176 print ("cache_dir is {}" .format (args .cache_dir ))
178177
178+ TOP_N = - 1
179+ if args .quick_run .lower () == "false" :
180+ TOP_N = 10
179181 train_dataset , test_dataset = CNNDMSummarizationDataset (
180- top_n = - 1 , local_cache_path = args .data_dir , prepare_extractive = False
182+ top_n = TOP_N , local_cache_path = args .data_dir , prepare_extractive = False
181183 )
182184
183185 ngpus_per_node = torch .cuda .device_count ()
@@ -212,6 +214,7 @@ def main_worker(
212214 checkpoint = os .path .join (args .cache_dir , args .checkpoint_filename )
213215 else :
214216 checkpoint = None
217+
215218 # train_sum_dataset, test_sum_dataset = load_processed_cnndm_abs(args.data_dir)
216219 def this_validate (class_obj ):
217220 return validate (class_obj , test_dataset )
@@ -225,8 +228,8 @@ def this_validate(class_obj):
225228 fp16 = args .fp16 .lower () == "true"
226229 print ("fp16 is {}" .format (fp16 ))
227230 # total number of steps for training
228- MAX_STEPS = 50
229- SAVE_EVERY = 50
231+ MAX_STEPS = 10
232+ SAVE_EVERY = 10
230233 REPORT_EVERY = 10
231234 # number of steps for warm up
232235 WARMUP_STEPS_BERT = MAX_STEPS
@@ -235,7 +238,7 @@ def this_validate(class_obj):
235238 MAX_STEPS = args .max_steps
236239 WARMUP_STEPS_BERT = args .warmup_steps_bert
237240 WARMUP_STEPS_DEC = args .warmup_steps_dec
238- SAVE_EVERY = args . save_every
241+ SAVE_EVERY = save_every
239242 REPORT_EVERY = args .report_every
240243
241244 print ("max steps is {}" .format (MAX_STEPS ))
@@ -266,14 +269,16 @@ def this_validate(class_obj):
266269
267270 end = time .time ()
268271 print ("rank {0}, duration {1:.6f}s" .format (rank , end - start ))
269- if rank == 0 or local_rank == - 1 :
272+ if local_rank in [0 , - 1 ] and args .rank == 0 :
273+ TOP_N = - 1
274+ if args .quick_run .lower () == "false" :
275+ TOP_N = ngpus_per_node
270276 saved_model_path = os .path .join (
271277 args .output_dir , "{}_step{}" .format (args .model_filename , MAX_STEPS )
272278 )
273279 summarizer .save_model (MAX_STEPS , saved_model_path )
274- top_n = 8
275280 prediction = summarizer .predict (
276- test_dataset .shorten (top_n = top_n ), batch_size = 4 , num_gpus = ngpus_per_node
281+ test_dataset .shorten (top_n = TOP_N ), batch_size = ngpus_per_node , num_gpus = ngpus_per_node
277282 )
278283 print (prediction [0 ])
279284
0 commit comments