

def merge_splited_param(
-    state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
+    state_dict,
+    partial_tensor_list,
+    param_shape_info,
+    send_table,
+    recv_table,
+    is_master_weights=False,
+    ckpt_quant_stage="O0",
):
    """Merge the splited param in sharding group."""
    global_rank = dist.get_rank()
    for key in list(state_dict.keys()):
-        if state_dict[key].numel().item() == 1:  # for example: beta1, beta2
+        if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
            continue

        static_name = key if is_master_weights else generate_base_static_name(key)[0]
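Note: the scalar-state skip above (beta1/beta2 accumulators hold a single element and need no cross-rank merge) can be illustrated with a small standalone sketch; the keys and tensors below are made up and are not the library code:

# Standalone sketch, not the library code: keys and values are hypothetical.
import paddle

state_dict = {
    "linear_0.w_0_moment1_0": paddle.zeros([8]),               # sliced optimizer state
    "linear_0.w_0_beta1_pow_acc_0": paddle.to_tensor([0.9]),   # scalar accumulator
}
for key in list(state_dict.keys()):
    if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
        continue  # scalar states are left as-is; only sliced tensors are merged
    print("would merge across sharding ranks:", key)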
@@ -89,10 +95,21 @@ def merge_splited_param(
            )
            dist.stream.send(tensor, dst=recv_rank)
            state_dict.pop(key)
+
+    if ckpt_quant_stage != "O0":
+        for key in list(state_dict.keys()):
+            if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
+                static_name = key if is_master_weights else generate_base_static_name(key)[0]
+                if static_name in partial_tensor_list:
+                    recv_rank = recv_table[static_name]
+                    send_info = send_table[static_name]
+                    if global_rank != recv_rank:
+                        state_dict.pop(key)
+
    return state_dict


-def gather_splited_param_for_optimizer(optimizer):
+def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
    hcg = fleet.get_hybrid_communicate_group()
    sharding_group = hcg.get_sharding_parallel_group()
    global_rank = dist.get_rank()
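Note: the new `ckpt_quant_stage != "O0"` pass drops the scalar states of partially-held parameters on every rank except the designated receiver, so the quantized checkpoint is written from a single rank. A minimal sketch of that rank filter with plain Python values (ranks, tables, and keys are hypothetical, not the library code):

# Standalone sketch, not the library code: all names and ranks are made up.
global_rank = 1
partial_tensor_list = ["linear_0.w_0"]
recv_table = {"linear_0.w_0": 0}                 # rank that gathers this param
state_dict = {"linear_0.w_0_beta1_pow_acc_0": 0.9}

for key in list(state_dict.keys()):
    static_name = "linear_0.w_0"                 # derived from the key in the real code
    if static_name in partial_tensor_list and global_rank != recv_table[static_name]:
        state_dict.pop(key)                      # non-receiving ranks drop the duplicate scalar

print(state_dict)  # {} on rank 1; unchanged on rank 0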
@@ -127,7 +144,7 @@ def gather_splited_param_for_optimizer(optimizer):
    for key in list(optim_state_dict.keys()):
        static_name, _ = generate_base_static_name(key)
        if static_name in param_slice_info.keys():
-            if optim_state_dict[key].numel().item() == 1:  # for example: beta1, beta2
+            if int(optim_state_dict[key].numel()) == 1:  # for example: beta1, beta2
                continue
            begin, end = param_slice_info[static_name]
            shape, numel, _, _ = param_shape_info[static_name]
@@ -149,13 +166,17 @@ def gather_splited_param_for_optimizer(optimizer):
        recv_table[key] = sharding_ranklist[0][0]  # which sharding_rank to recv the splited tensor
        send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist]

-    merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False)
+    merge_splited_param(
+        optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False, ckpt_quant_stage
+    )
    if master_weights is not None:
-        merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True)
+        merge_splited_param(
+            master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True, ckpt_quant_stage
+        )
    return optim_state_dict, master_weights


-def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
+def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
    returned_optim_state_dict = nested_copy(optimizer.state_dict())

    index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
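Note: as the hunk above shows, `ckpt_quant_stage` is simply threaded from `load_unified_optimizer_split_param` and `gather_splited_param_for_optimizer` down to `merge_splited_param`, with a default of `"O0"` (no quantization). A standalone sketch of that default-forwarding pattern, using stand-in functions rather than the library code:

# Standalone sketch of the parameter threading, not the library code.
def merge(state, ckpt_quant_stage="O0"):
    return {"stage": ckpt_quant_stage, **state}

def gather(state, ckpt_quant_stage="O0"):
    # forwarded unchanged, so existing callers that omit it keep the old behaviour
    return merge(state, ckpt_quant_stage)

assert gather({"m": 1})["stage"] == "O0"         # default: unquantized checkpoint
assert gather({"m": 1}, "O1")["stage"] == "O1"   # opt-in quantized checkpoint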
@@ -217,7 +238,9 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
    if len(resolved_archive_file_mw) > 1:
        resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards")

-    def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):
+    def load_resolved_archive_file(
+        resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0"
+    ):
        returned_state_dict = {}

        if model.config.tensor_parallel_degree > 1:
@@ -232,24 +255,38 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
            if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
                continue
            if model.config.tensor_parallel_degree > 1:
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
+                state_dict = load_state_dict(
+                    shard_file,
+                    tp_actions,
+                    expected_keys,
+                    device="cpu",
+                    ckpt_quant_stage=ckpt_quant_stage,
+                )
            else:
-                state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
+                state_dict = load_state_dict(
+                    shard_file,
+                    None,
+                    expected_keys,
+                    device="cpu",
+                    ckpt_quant_stage=ckpt_quant_stage,
+                )
            returned_state_dict.update(state_dict)
            del state_dict
            gc.collect()

        return returned_state_dict

    # get tp params
-    state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
+    state_dict_optim = load_resolved_archive_file(
+        resolved_archive_file, sharded_metadata, expected_keys_optim, ckpt_quant_stage=ckpt_quant_stage
+    )

    # need to split param for different sharding rank, maybe need to deal with oom issue.
    for key in list(state_dict_optim.keys()):
        key_name = key.split("/")
        static_name = struct2static_name_mappings.get(key_name[0], None)

-        if state_dict_optim[key].numel().item() > 1:
+        if int(state_dict_optim[key].numel()) > 1:
            begin, end = param_slice_info[static_name]
            shape, numel, index, padded_size = param_shape_info[static_name]
            state_dict_optim[key] = state_dict_optim[key].reshape([-1])
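Note: after the shards are loaded (now with `ckpt_quant_stage` forwarded to `load_state_dict`), each optimizer state is flattened and each sharding rank keeps only its `[begin, end)` slice of the padded buffer. A rough sketch of that slicing step; the explicit padding here is an assumption based on `padded_size`, and all sizes are made up:

# Standalone sketch, not the library code: sizes and offsets are hypothetical.
import paddle

numel, padded_size = 10, 12            # original numel and sharding-padded size
begin, end = 6, 12                     # this rank's slice of the padded buffer
tensor = paddle.arange(numel, dtype="float32").reshape([-1])
tensor = paddle.concat([tensor, paddle.zeros([padded_size - numel], dtype="float32")])
local_slice = tensor[begin:end]        # what this sharding rank keeps
print(local_slice.shape)               # [6]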
@@ -284,7 +321,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected

    for key in list(state_dict_master_weight.keys()):
        static_name = struct2static_name_mappings.get(key, None)
-        if state_dict_master_weight[key].numel().item() > 1:
+        if int(state_dict_master_weight[key].numel()) > 1:
            begin, end = param_slice_info[static_name]
            shape, numel, index, padded_size = param_shape_info[static_name]
            state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])