1414 sys .path .insert (0 , nlp_path )
1515
1616
17- from utils_nlp .dataset .cnndm import CNNDMSummarizationDataset
17+ from utils_nlp .dataset .cnndm import CNNDMBertSumProcessedData , CNNDMSummarizationDataset
1818from utils_nlp .models .transformers .extractive_summarization import (
1919 ExtractiveSummarizer ,
2020 ExtSumProcessedData ,
4747parser .add_argument ("--encoder" , type = str .lower , default = 'transformer' ,
4848 choices = ['baseline' , 'classifier' , 'transformer' , 'rnn' ],
4949 help = "Encoder types in the extractive summarizer." )
50+ parser .add_argument (
51+ "--max_pos_length" ,
52+ type = int ,
53+ default = 512 ,
54+ help = "maximum input length in terms of input token numbers in training" ,
55+ )
5056parser .add_argument ("--learning_rate" , type = float , default = 1e-3 ,
5157 help = "Learning rate." )
5258parser .add_argument ("--batch_size" , type = int , default = 3000 ,
@@ -86,8 +92,8 @@ def main():
8692 os .makedirs (args .cache_dir , exist_ok = True )
8793
8894 ngpus_per_node = torch .cuda .device_count ()
89-
90- summarizer = ExtractiveSummarizer (args .model_name , args .encoder , args .cache_dir )
95+ processor = ExtSumProcessor ( model_name = args . model_name )
96+ summarizer = ExtractiveSummarizer (processor , args .model_name , args .encoder , args . max_pos_length , args .cache_dir )
9197
9298 mp .spawn (main_worker , nprocs = ngpus_per_node , args = (ngpus_per_node , summarizer , args ))
9399
@@ -109,9 +115,17 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args):
109115 rank = rank ,
110116 )
111117
112- train_dataset , test_dataset = ExtSumProcessedData ().splits (root = args .data_dir )
118+ if local_rank not in [- 1 , 0 ]:
119+ torch .distributed .barrier ()
120+
121+ download_path = CNNDMBertSumProcessedData .download (local_path = args .data_dir )
122+ train_dataset , test_dataset = ExtSumProcessedData ().splits (root = args .data_dir , train_iterable = True )
123+
124+ if local_rank in [- 1 , 0 ]:
125+ torch .distributed .barrier ()
126+
113127 # total number of steps for training
114- MAX_STEPS = 1e3
128+ MAX_STEPS = 1e2
115129 # number of steps for warm up
116130 WARMUP_STEPS = 5e2
117131 if args .quick_run .lower () == "false" :
@@ -137,17 +151,18 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args):
137151 verbose = True ,
138152 report_every = REPORT_EVERY ,
139153 clip_grad_norm = False ,
140- local_rank = rank ,
154+ local_rank = local_rank ,
141155 save_every = save_every ,
142- world_size = world_size
156+ world_size = world_size ,
157+ rank = rank ,
158+ use_preprocessed_data = True
143159 )
144160
145161 end = time .time ()
146162 print ("rank {0}, duration {1:.6f}s" .format (rank , end - start ))
147- if rank in [ - 1 , 0 ]:
148- summarizer . save_model ( os . path . join ( args . output_dir , args . model_filename ))
163+ torch . distributed . barrier ()
164+ if local_rank in [ - 1 , 0 ]:
149165 prediction = summarizer .predict (test_dataset , num_gpus = ngpus_per_node , batch_size = 128 )
150-
151166 def _write_list_to_file (list_items , filename ):
152167 with open (filename , "w" ) as filehandle :
153168 # for cnt, line in enumerate(filehandle):
0 commit comments