
Commit fedb34f

Merge pull request #573 from microsoft/daden/bertsumext
Daden/bertsumext: improvement of BertSum extractive summarization
2 parents: 4372ba7 + 03948df

15 files changed: +2257 −669 lines

examples/text_summarization/extractive_summarization_cnndm_distributed_train.py

Lines changed: 25 additions & 10 deletions
@@ -14,7 +14,7 @@
 sys.path.insert(0, nlp_path)
 
 
-from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset
+from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset
 from utils_nlp.models.transformers.extractive_summarization import (
     ExtractiveSummarizer,
     ExtSumProcessedData,
@@ -47,6 +47,12 @@
4747
parser.add_argument("--encoder", type=str.lower, default='transformer',
4848
choices=['baseline', 'classifier', 'transformer', 'rnn'],
4949
help="Encoder types in the extractive summarizer.")
50+
parser.add_argument(
51+
"--max_pos_length",
52+
type=int,
53+
default=512,
54+
help="maximum input length in terms of input token numbers in training",
55+
)
5056
parser.add_argument("--learning_rate", type=float, default=1e-3,
5157
help="Learning rate.")
5258
parser.add_argument("--batch_size", type=int, default=3000,
@@ -86,8 +92,8 @@ def main():
     os.makedirs(args.cache_dir, exist_ok=True)
 
     ngpus_per_node = torch.cuda.device_count()
-
-    summarizer = ExtractiveSummarizer(args.model_name, args.encoder, args.cache_dir)
+    processor = ExtSumProcessor(model_name=args.model_name)
+    summarizer = ExtractiveSummarizer(processor, args.model_name, args.encoder, args.max_pos_length, args.cache_dir)
 
     mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, summarizer, args))
 
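With this change the summarizer is built from a processor plus the new length cap before mp.spawn forks the workers, so every process starts from the same objects. A minimal sketch of the construction as it appears after this commit, assuming ExtSumProcessor is exported from the same module as ExtractiveSummarizer (the model name and cache path below are illustrative; the script reads them from args):

```python
from utils_nlp.models.transformers.extractive_summarization import (
    ExtractiveSummarizer,
    ExtSumProcessor,
)

model_name = "distilbert-base-uncased"  # illustrative value
processor = ExtSumProcessor(model_name=model_name)
# Positional arguments follow the call in the hunk above:
# processor, model name, encoder type, max input length, cache directory.
summarizer = ExtractiveSummarizer(processor, model_name, "transformer", 512, "./cache")
```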

@@ -109,9 +115,17 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args):
         rank=rank,
     )
 
-    train_dataset, test_dataset = ExtSumProcessedData().splits(root=args.data_dir)
+    if local_rank not in [-1, 0]:
+        torch.distributed.barrier()
+
+    download_path = CNNDMBertSumProcessedData.download(local_path=args.data_dir)
+    train_dataset, test_dataset = ExtSumProcessedData().splits(root=args.data_dir, train_iterable=True)
+
+    if local_rank in [-1, 0]:
+        torch.distributed.barrier()
+
     # total number of steps for training
-    MAX_STEPS = 1e3
+    MAX_STEPS = 1e2
     # number of steps for warm up
     WARMUP_STEPS = 5e2
     if args.quick_run.lower() == "false":
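
The paired torch.distributed.barrier() calls implement the standard "rank 0 downloads first" pattern: non-zero ranks block at the first barrier while rank 0 fetches and unpacks the dataset; when rank 0 reaches the second barrier everyone is released, and the remaining ranks then load from the now-populated data_dir. A generic sketch of the pattern, assuming the process group is already initialized as it is earlier in main_worker (the helper name run_on_rank_zero_first is hypothetical):

```python
import torch.distributed as dist

def run_on_rank_zero_first(local_rank, fn):
    """Run fn on rank 0 (or in single-process mode, local_rank == -1)
    before any other rank runs it."""
    if local_rank not in [-1, 0]:
        dist.barrier()  # wait here until rank 0 has finished fn()
    result = fn()       # rank 0 does the real work; late ranks hit a warm cache
    if local_rank in [-1, 0]:
        dist.barrier()  # rank 0 is done: release the waiting ranks
    return result
```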
@@ -137,17 +151,18 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args):
         verbose=True,
         report_every=REPORT_EVERY,
         clip_grad_norm=False,
-        local_rank=rank,
+        local_rank=local_rank,
         save_every=save_every,
-        world_size=world_size
+        world_size=world_size,
+        rank=rank,
+        use_preprocessed_data=True
     )
 
     end = time.time()
     print("rank {0}, duration {1:.6f}s".format(rank, end - start))
-    if rank in [-1, 0]:
-        summarizer.save_model(os.path.join(args.output_dir, args.model_filename))
+    torch.distributed.barrier()
+    if local_rank in [-1, 0]:
         prediction = summarizer.predict(test_dataset, num_gpus=ngpus_per_node, batch_size=128)
-
 def _write_list_to_file(list_items, filename):
     with open(filename, "w") as filehandle:
         # for cnt, line in enumerate(filehandle):
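
fit now receives the global rank and the per-node local_rank separately, and a barrier keeps any process from starting inference before training has finished everywhere. Pieced together from the hunk above, with explanatory comments added, the tail of main_worker reads roughly as follows (arguments not shown in the diff are omitted):

```python
summarizer.fit(
    # ...arguments unchanged by this commit are omitted...
    verbose=True,
    report_every=REPORT_EVERY,
    clip_grad_norm=False,
    local_rank=local_rank,       # device index within this node
    save_every=save_every,
    world_size=world_size,
    rank=rank,                   # global rank across all nodes
    use_preprocessed_data=True,  # consume the downloaded BertSum tensors directly
)

torch.distributed.barrier()      # wait for every worker to finish training
if local_rank in [-1, 0]:        # run inference once per node, not once per GPU
    prediction = summarizer.predict(test_dataset, num_gpus=ngpus_per_node, batch_size=128)
```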
