Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit 806d5fb

Browse files
authored
Merge pull request #574 from microsoft/daden/bertsumext
Daden/bertsumext Add test for distributed training for BertSum
2 parents fedb34f + 6117e7b commit 806d5fb

17 files changed

+877
-2432
lines changed

examples/text_summarization/abstractive_summarization_bertsum_cnndm_distributed_train.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,19 @@
33

44
import argparse
55
import os
6-
import pickle
76
import sys
87
import time
98
import torch
109
import torch.distributed as dist
1110
import torch.multiprocessing as mp
1211

1312
# torch.set_printoptions(threshold=5000)
14-
from tempfile import TemporaryDirectory
1513

1614
nlp_path = os.path.abspath("../../")
1715
if nlp_path not in sys.path:
1816
sys.path.insert(0, nlp_path)
1917

18+
sys.path.insert(0, "./")
2019

2120
from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
2221
BertSumAbs,
@@ -40,7 +39,7 @@
4039
parser.add_argument(
4140
"--dist_url",
4241
type=str,
43-
default="tcp://127.0.0.1:29500",
42+
default="tcp://127.0.0.1:29507",
4443
help="URL specifying how to initialize the process groupi.",
4544
)
4645
parser.add_argument(
@@ -56,7 +55,7 @@
5655
parser.add_argument(
5756
"--data_dir",
5857
type=str,
59-
default="./",
58+
default="./abstemp",
6059
help="Directory to download the preprocessed data.",
6160
)
6261
parser.add_argument(
@@ -101,8 +100,8 @@
101100
"--max_steps",
102101
type=int,
103102
default=5e4,
104-
help="Maximum number of training steps run in training. \
105-
If quick_run is set, it's not used.",
103+
help="""Maximum number of training steps run in training.
104+
If quick_run is set, it's not used.""",
106105
)
107106
parser.add_argument(
108107
"--warmup_steps_bert",
@@ -176,8 +175,11 @@ def main():
176175
print("data_dir is {}".format(args.data_dir))
177176
print("cache_dir is {}".format(args.cache_dir))
178177

178+
TOP_N = -1
179+
if args.quick_run.lower() == "false":
180+
TOP_N = 10
179181
train_dataset, test_dataset = CNNDMSummarizationDataset(
180-
top_n=-1, local_cache_path=args.data_dir, prepare_extractive=False
182+
top_n=TOP_N, local_cache_path=args.data_dir, prepare_extractive=False
181183
)
182184

183185
ngpus_per_node = torch.cuda.device_count()
@@ -212,6 +214,7 @@ def main_worker(
212214
checkpoint = os.path.join(args.cache_dir, args.checkpoint_filename)
213215
else:
214216
checkpoint = None
217+
215218
# train_sum_dataset, test_sum_dataset = load_processed_cnndm_abs(args.data_dir)
216219
def this_validate(class_obj):
217220
return validate(class_obj, test_dataset)
@@ -225,8 +228,8 @@ def this_validate(class_obj):
225228
fp16 = args.fp16.lower() == "true"
226229
print("fp16 is {}".format(fp16))
227230
# total number of steps for training
228-
MAX_STEPS = 50
229-
SAVE_EVERY = 50
231+
MAX_STEPS = 10
232+
SAVE_EVERY = 10
230233
REPORT_EVERY = 10
231234
# number of steps for warm up
232235
WARMUP_STEPS_BERT = MAX_STEPS
@@ -235,7 +238,7 @@ def this_validate(class_obj):
235238
MAX_STEPS = args.max_steps
236239
WARMUP_STEPS_BERT = args.warmup_steps_bert
237240
WARMUP_STEPS_DEC = args.warmup_steps_dec
238-
SAVE_EVERY = args.save_every
241+
SAVE_EVERY = save_every
239242
REPORT_EVERY = args.report_every
240243

241244
print("max steps is {}".format(MAX_STEPS))
@@ -266,14 +269,16 @@ def this_validate(class_obj):
266269

267270
end = time.time()
268271
print("rank {0}, duration {1:.6f}s".format(rank, end - start))
269-
if rank == 0 or local_rank == -1:
272+
if local_rank in [0, -1] and args.rank == 0:
273+
TOP_N = -1
274+
if args.quick_run.lower() == "false":
275+
TOP_N = ngpus_per_node
270276
saved_model_path = os.path.join(
271277
args.output_dir, "{}_step{}".format(args.model_filename, MAX_STEPS)
272278
)
273279
summarizer.save_model(MAX_STEPS, saved_model_path)
274-
top_n = 8
275280
prediction = summarizer.predict(
276-
test_dataset.shorten(top_n=top_n), batch_size=4, num_gpus=ngpus_per_node
281+
test_dataset.shorten(top_n=TOP_N), batch_size=ngpus_per_node, num_gpus=ngpus_per_node
277282
)
278283
print(prediction[0])
279284

0 commit comments

Comments (0)