Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit e0a2d6a

Browse files
author
Emmanuel Awa
committed
move BERT initialization into MTDNNModel
1 parent b09a25a commit e0a2d6a

File tree

4 files changed

+83
-34
lines changed

4 files changed

+83
-34
lines changed

examples/text_classification/test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1+
from utils_nlp.models.mtdnn.common.types import EncoderModelType
12
from utils_nlp.models.mtdnn.configuration_mtdnn import MTDNNConfig
23
from utils_nlp.models.mtdnn.modeling_mtdnn import MTDNNModel
34

45
if __name__ == "__main__":
56
config = MTDNNConfig()
67
b = MTDNNModel(config)
8+
print("Network: ", b.network)
79
print("Config Class: ", b.config_class)
810
print("Config: ", b.config)
9-
print("Embeddings: ", b.embeddings)
10-
print("Encoding: ", b.encoder)
1111
print("Pooler: ", b.pooler)
12+
13+
if config.encoder_type == EncoderModelType.BERT:
14+
print("Encoding: ", b.encoder)
15+
print("Embeddings: ", b.embeddings)
16+
print("Bert Config: ", b.bert_config)
17+
1218
print("Archive Map: ", b.pretrained_model_archive_map)
1319
print("Base Model Prefix: ", b.base_model_prefix)
14-
print("Bert Config: ", b.bert_config)

utils_nlp/models/mtdnn/common/san.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -132,22 +132,12 @@ class SANNetwork(nn.Module):
132132
https://arxiv.org/abs/1804.07888
133133
"""
134134

135-
def __init__(self, config: MTDNNConfig):
135+
def __init__(self, config: MTDNNConfig, pooler):
136136
super(SANNetwork, self).__init__()
137137
self.config = config
138138
self.dropout_list = nn.ModuleList()
139139
self.encoder_type = config.encoder_type
140-
141-
# Setup the baseline network
142-
# Define the encoder based on config options
143-
self.bert_config = BertConfig.from_dict(self.config.to_dict())
144-
self.bert = BertModel(self.bert_config)
145-
self.hidden_size = self.bert_config.hidden_size
146-
147-
if self.encoder_type == EncoderModelType.ROBERTA:
148-
self.bert = FairseqRobertModel.from_pretrained(config.init_checkpoint)
149-
self.hidden_size = self.bert.args.encoder_embed_dim
150-
self.pooler = LinearPooler(self.hidden_size)
140+
self.pooler = pooler
151141

152142
# Dump other features if value is set to true
153143
if config.dump_feature:
@@ -233,7 +223,8 @@ def forward(
233223
logits = self.scoring_list[task_id](pooled_output)
234224
return logits
235225

236-
def generate_scoring_options(self):
226+
# TODO - Move to training step
227+
def generate_tasks_scoring_options(self):
237228
""" Enumerate over tasks and setup of decoding and scoring list for training """
238229
assert len(self.tasks_nclass_list) > 0, "Number of classes to train for cannot be 0"
239230
for idx, task_num_labels in enumerate(self.tasks_nclass_list):

utils_nlp/models/mtdnn/configuration_mtdnn.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
"""MTDNN model configuration"""
1515

16+
encoder_checkpoint_map = {1: "bert", 2: "roberta"}
17+
1618

1719
class MTDNNConfig(PretrainedConfig):
1820
r"""
@@ -49,7 +51,8 @@ class MTDNNConfig(PretrainedConfig):
4951

5052
def __init__(
5153
self,
52-
encoder_type=EncoderModelType.BERT,
54+
use_pretrained_model=False,
55+
encoder_type=EncoderModelType.ROBERTA,
5356
vocab_size=30522,
5457
hidden_size=768,
5558
num_hidden_layers=12,
@@ -70,7 +73,7 @@ def __init__(
7073
tasks_dropout_p=[],
7174
enable_variational_dropout=True,
7275
init_ratio=1.0,
73-
init_checkpoint="bert-base-uncased",
76+
init_checkpoint="roberta.base",
7477
# Training config
7578
cuda=torch.cuda.is_available(),
7679
multi_gpu_on=False,
@@ -89,6 +92,7 @@ def __init__(
8992
warmup=0.1,
9093
warmup_schedule="warmup_linear",
9194
adam_eps=1e-6,
95+
pooler=None,
9296
# Scheduler config
9397
have_lr_scheduler=True,
9498
multi_step_lr="10,20,30",
@@ -109,7 +113,16 @@ def __init__(
109113
weighted_on=False,
110114
**kwargs,
111115
):
116+
# basic Configuration validation
117+
# assert initial checkpoint and encoder type are the same
118+
assert init_checkpoint.startswith(
119+
encoder_checkpoint_map[encoder_type]
120+
), """Encoder type and initial checkpoint mismatch.
121+
1 - Bert models
122+
2 - Roberta models
123+
"""
112124
super(MTDNNConfig, self).__init__(**kwargs)
125+
self.use_pretrained_model = use_pretrained_model
113126
self.encoder_type = encoder_type
114127
self.vocab_size = vocab_size
115128
self.hidden_size = hidden_size
@@ -148,6 +161,7 @@ def __init__(
148161
self.momentum = momentum
149162
self.warmup = warmup
150163
self.warmup_schedule = warmup_schedule
164+
self.pooler = pooler
151165
self.adam_eps = adam_eps
152166
self.have_lr_scheduler = have_lr_scheduler
153167
self.multi_step_lr = multi_step_lr

utils_nlp/models/mtdnn/modeling_mtdnn.py

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@
4141
logger = logging.getLogger(__name__)
4242

4343

44-
class MTDNNPretrainedModel(BertPreTrainedModel):
44+
class MTDNNPretrainedModel(nn.Module):
4545
config_class = MTDNNConfig
4646
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
4747
load_tf_weights = lambda model, config, path: None
4848
base_model_prefix = "mtdnn"
4949

5050
def __init__(self, config):
51-
super(MTDNNPretrainedModel, self).__init__(config)
51+
super(MTDNNPretrainedModel, self).__init__()
5252
if not isinstance(config, PretrainedConfig):
5353
raise ValueError(
5454
"Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
@@ -61,43 +61,69 @@ def __init__(self, config):
6161
self.config = config
6262

6363

64-
class MTDNNModel(MTDNNPretrainedModel, BertModel):
64+
class MTDNNModel(MTDNNPretrainedModel):
6565
def __init__(
6666
self,
6767
config: MTDNNConfig,
6868
pretrained_model_name: str = "mtdnn-base-uncased",
6969
num_train_step: int = -1,
7070
):
71+
assert (
72+
config.init_checkpoint in self.supported_init_checkpoints()
73+
), f"Initial checkpoint must be in {self.supported_init_checkpoints()}"
7174
super(MTDNNModel, self).__init__(config)
7275
self.config = config
7376

74-
# Set the config base on encoder type set for initial checkpoint
75-
76-
# Download pretrained model
77+
# Setup the baseline network
78+
# - Define the encoder based on config options
79+
# - Set state dictionary based on configuration setting
80+
# - Download pretrained model if flag is set
7781
# TODO - Use Model.pretrained_model() after configuration file is hosted.
78-
with download_path() as file_path:
79-
path = pathlib.Path(file_path)
80-
self.local_model_path = maybe_download(
81-
url=self.pretrained_model_archive_map[pretrained_model_name]
82-
)
83-
self.mtdnn_model = MTDNNCommonUtils.load_pytorch_model(self.local_model_path)
82+
if self.config.use_pretrained_model:
83+
with download_path() as file_path:
84+
path = pathlib.Path(file_path)
85+
self.local_model_path = maybe_download(
86+
url=self.pretrained_model_archive_map[pretrained_model_name]
87+
)
88+
self.mtdnn_model = MTDNNCommonUtils.load_pytorch_model(self.local_model_path)
89+
self.state_dict = self.mtdnn_model["state"]
90+
else:
91+
# Set the config base on encoder type set for initial checkpoint
92+
if config.encoder_type == EncoderModelType.BERT:
93+
self.bert_config = BertConfig.from_dict(self.config.to_dict())
94+
self.bert_model = BertModel(self.bert_config)
95+
self.state_dict = self.bert_model.state_dict()
96+
if config.encoder_type == EncoderModelType.ROBERTA:
97+
# Download and extract from PyTorch hub if not downloaded before
98+
self.bert_model = torch.hub.load("pytorch/fairseq", config.init_checkpoint)
99+
self.config.hidden_size = self.bert_model.args.encoder_embed_dim
100+
self.pooler = LinearPooler(self.config.hidden_size)
101+
new_state_dict = {}
102+
for key, val in self.bert_model.state_dict().items():
103+
if key.startswith("model.decoder.sentence_encoder") or key.startswith(
104+
"model.classification_heads"
105+
):
106+
key = f"bert.{key}"
107+
new_state_dict[key] = val
108+
# backward compatibility PyTorch <= 1.0.0
109+
if key.startswith("classification_heads"):
110+
key = f"bert.model.{key}"
111+
new_state_dict[key] = val
112+
self.state_dict = new_state_dict
84113

85-
self.state_dict = self.mtdnn_model["state"]
86114
self.updates = (
87115
self.state_dict["updates"] if self.state_dict and "updates" in self.state_dict else 0
88116
)
89117
self.local_updates = 0
90118
self.train_loss = AverageMeter()
91-
self.network = SANNetwork(self.config)
119+
self.network = SANNetwork(self.config, self.pooler)
92120
if self.state_dict:
93121
self.network.load_state_dict(self.state_dict, strict=False)
94122
self.mnetwork = nn.DataParallel(self.network) if self.config.multi_gpu_on else self.network
95123
self.total_param = sum([p.nelement() for p in self.network.parameters() if p.requires_grad])
96124

97125
# Move network to GPU if device available and flag set
98-
print(f" =======> Can move to cuda {self.config.cuda} and {torch.cuda.is_available()}")
99126
if self.config.cuda:
100-
print(" =======> Moving to cuda")
101127
self.network.cuda()
102128
self.optimizer_parameters = self._get_param_groups()
103129
self._setup_optim(self.optimizer_parameters, self.state_dict, num_train_step)
@@ -383,3 +409,16 @@ def load(self, checkpoint):
383409

384410
def cuda(self):
385411
self.network.cuda()
412+
413+
def supported_init_checkpoints(self):
414+
"""List of allowed check points
415+
"""
416+
return [
417+
"bert-base-uncased",
418+
"bert-base-cased",
419+
"bert-large-uncased",
420+
"mtdnn-base-uncased",
421+
"mtdnn-large-uncased",
422+
"roberta.base",
423+
"roberta.large",
424+
]

0 commit comments

Comments
 (0)