# Module-level logger, namespaced to this module per the stdlib logging convention.
logger = logging.getLogger(__name__)
4242
4343
44- class MTDNNPretrainedModel (BertPreTrainedModel ):
44+ class MTDNNPretrainedModel (nn . Module ):
4545 config_class = MTDNNConfig
4646 pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
4747 load_tf_weights = lambda model , config , path : None
4848 base_model_prefix = "mtdnn"
4949
5050 def __init__ (self , config ):
51- super (MTDNNPretrainedModel , self ).__init__ (config )
51+ super (MTDNNPretrainedModel , self ).__init__ ()
5252 if not isinstance (config , PretrainedConfig ):
5353 raise ValueError (
5454 "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
@@ -61,43 +61,69 @@ def __init__(self, config):
6161 self .config = config
6262
6363
class MTDNNModel(MTDNNPretrainedModel):
    """Multi-task DNN model: an encoder (pretrained MT-DNN, BERT, or RoBERTa)
    feeding a SAN answer network, with optimizer state wired up for training.

    NOTE(review): this view is a diff chunk; __init__ may continue past the
    visible lines — confirm against the full file.
    """

    def __init__(
        self,
        config: MTDNNConfig,
        pretrained_model_name: str = "mtdnn-base-uncased",
        num_train_step: int = -1,
    ):
        # NOTE(review): `assert` is stripped under `python -O`; a ValueError
        # would validate init_checkpoint unconditionally — consider changing.
        assert (
            config.init_checkpoint in self.supported_init_checkpoints()
        ), f"Initial checkpoint must be in {self.supported_init_checkpoints()}"
        super(MTDNNModel, self).__init__(config)
        self.config = config

        # Setup the baseline network
        # - Define the encoder based on config options
        # - Set state dictionary based on configuration setting
        # - Download pretrained model if flag is set
        # TODO - Use Model.pretrained_model() after configuration file is hosted.
        if self.config.use_pretrained_model:
            # Fetch the full MT-DNN checkpoint and take its saved state dict.
            with download_path() as file_path:
                path = pathlib.Path(file_path)
                self.local_model_path = maybe_download(
                    url=self.pretrained_model_archive_map[pretrained_model_name]
                )
            self.mtdnn_model = MTDNNCommonUtils.load_pytorch_model(self.local_model_path)
            self.state_dict = self.mtdnn_model["state"]
        else:
            # Set the config base on encoder type set for initial checkpoint
            if config.encoder_type == EncoderModelType.BERT:
                self.bert_config = BertConfig.from_dict(self.config.to_dict())
                self.bert_model = BertModel(self.bert_config)
                self.state_dict = self.bert_model.state_dict()
            if config.encoder_type == EncoderModelType.ROBERTA:
                # Download and extract from PyTorch hub if not downloaded before
                self.bert_model = torch.hub.load("pytorch/fairseq", config.init_checkpoint)
                # fairseq exposes embedding width via args, not a hidden_size field.
                self.config.hidden_size = self.bert_model.args.encoder_embed_dim
                self.pooler = LinearPooler(self.config.hidden_size)
                # Re-key fairseq parameters under a "bert." prefix so they match
                # the names SANNetwork expects when loading the state dict.
                new_state_dict = {}
                for key, val in self.bert_model.state_dict().items():
                    if key.startswith("model.decoder.sentence_encoder") or key.startswith(
                        "model.classification_heads"
                    ):
                        key = f"bert.{key}"
                        new_state_dict[key] = val
                    # backward compatibility PyTorch <= 1.0.0
                    if key.startswith("classification_heads"):
                        key = f"bert.model.{key}"
                        new_state_dict[key] = val
                self.state_dict = new_state_dict

        # Resume the update counter when the loaded state carries one.
        self.updates = (
            self.state_dict["updates"] if self.state_dict and "updates" in self.state_dict else 0
        )
        self.local_updates = 0
        self.train_loss = AverageMeter()
        # NOTE(review): self.pooler is assigned only in the ROBERTA branch above;
        # the pretrained-model and BERT paths reach this line without it and
        # would raise AttributeError — confirm and initialize a default.
        self.network = SANNetwork(self.config, self.pooler)
        if self.state_dict:
            # strict=False: encoder-only state dicts may lack task-head weights.
            self.network.load_state_dict(self.state_dict, strict=False)
        self.mnetwork = nn.DataParallel(self.network) if self.config.multi_gpu_on else self.network
        self.total_param = sum([p.nelement() for p in self.network.parameters() if p.requires_grad])

        # Move network to GPU if device available and flag set
        if self.config.cuda:
            self.network.cuda()
        self.optimizer_parameters = self._get_param_groups()
        self._setup_optim(self.optimizer_parameters, self.state_dict, num_train_step)
@@ -383,3 +409,16 @@ def load(self, checkpoint):
383409
def cuda(self):
    """Move the wrapped SAN network onto the GPU in place."""
    net = self.network
    net.cuda()
412+
def supported_init_checkpoints(self):
    """Return the list of checkpoint names accepted as `init_checkpoint`."""
    bert_family = ["bert-base-uncased", "bert-base-cased", "bert-large-uncased"]
    mtdnn_family = ["mtdnn-base-uncased", "mtdnn-large-uncased"]
    roberta_family = ["roberta.base", "roberta.large"]
    return bert_family + mtdnn_family + roberta_family
0 commit comments