@@ -155,36 +155,62 @@ def merge_mix_model(self, file_type_list):
         )
         for key in local_keys:
             # Tensor preprocess
-            is_bf16 = False
-            tensor_list = []
-            for i in range(model_num):
-                if self.merge_config.tensor_type == "np" and str(state_dict_list[i][key].dtype) == "uint16":
-                    is_bf16 = True
-                    state_dict_list[i][key] = (
-                        paddle.Tensor(state_dict_list[i][key], zero_copy=True).astype("float32").numpy()
-                    )
-                elif self.merge_config.tensor_type == "pd":
-                    state_dict_list[i][key] = paddle.Tensor(state_dict_list[i][key], zero_copy=True)
-                    if i == 0:
-                        tensor_dtype = state_dict_list[i][key].dtype
-                    # Using float32 to reduce precision loss
-                    state_dict_list[i][key] = state_dict_list[i][key].astype("float32")
-                tensor_list.append(state_dict_list[i].pop(key))
-
-            # Tensor merge
-            if self.merge_config.base_model_path is not None:
-                base_tensor = tensor_list.pop()
-                tensor_list = [tensor - base_tensor for tensor in tensor_list]
-            merge_state_dict[key] = self.merge_method.merge(tensor_list)
-            if self.merge_config.base_model_path is not None:
-                merge_state_dict[key] += base_tensor
+            is_bf16 = str(state_dict_list[0][key].dtype) == "uint16"
+            tensor_list = [state_dict_list[i].pop(key) for i in range(model_num)]
+            tensor_mem = int(np.prod(tensor_list[0].shape) * self.numpy_dtype_map[str(tensor_list[0].dtype)]) / (
+                1024**3
+            )
+            if self.merge_config.tensor_type == "pd" and tensor_mem > self.merge_config.max_tensor_mem:
+                tensor_split_list = [
+                    np.array_split(tensor, self.merge_config.split_pieces, axis=0) for tensor in tensor_list
+                ]
+                merge_split = []
+                for sp in range(self.merge_config.split_pieces):
+                    tensor_list = [tensor_split[sp] for tensor_split in tensor_split_list]
+                    if is_bf16:
+                        tensor_list = [
+                            paddle.Tensor(tensor, zero_copy=True).astype("float32") for tensor in tensor_list
+                        ]
+                    else:
+                        tensor_list = [paddle.Tensor(tensor, zero_copy=True) for tensor in tensor_list]
+                    if self.merge_config.base_model_path is not None:
+                        base_tensor = tensor_list.pop()
+                        tensor_list = [tensor - base_tensor for tensor in tensor_list]
+                    merge_tensor = self.merge_method.merge(tensor_list)
+                    if self.merge_config.base_model_path is not None:
+                        merge_tensor += base_tensor
+                    if is_bf16:
+                        merge_split.append(merge_tensor.astype("bfloat16").numpy())
+                    else:
+                        merge_split.append(merge_tensor.numpy())
+                merge_state_dict[key] = np.concatenate(merge_split, axis=0)
+            else:
+                if self.merge_config.tensor_type == "pd":
+                    if is_bf16:
+                        tensor_list = [
+                            paddle.Tensor(tensor, zero_copy=True).astype("float32") for tensor in tensor_list
+                        ]
+                    else:
+                        tensor_list = [paddle.Tensor(tensor, zero_copy=True) for tensor in tensor_list]
+                elif self.merge_config.tensor_type == "np" and is_bf16:
+                    tensor_list = [
+                        paddle.Tensor(tensor, zero_copy=True).astype("float32").numpy() for tensor in tensor_list
+                    ]
 
-            # Tensor postprocess
-            # dtype==bfloat16: numpy(float32) -> paddle(float32) -> paddle(bfloat16) -> numpy(uint16)
-            if self.merge_config.tensor_type == "np" and is_bf16:
-                merge_state_dict[key] = paddle.Tensor(merge_state_dict[key], zero_copy=True).astype("bfloat16").numpy()
-            elif self.merge_config.tensor_type == "pd":
-                merge_state_dict[key] = merge_state_dict[key].astype(tensor_dtype).numpy()
+                if self.merge_config.base_model_path is not None:
+                    base_tensor = tensor_list.pop()
+                    tensor_list = [tensor - base_tensor for tensor in tensor_list]
+                merge_tensor = self.merge_method.merge(tensor_list)
+                if self.merge_config.base_model_path is not None:
+                    merge_tensor += base_tensor
+                if self.merge_config.tensor_type == "pd":
+                    if is_bf16:
+                        merge_state_dict[key] = merge_tensor.astype("bfloat16").numpy()
+                    else:
+                        merge_state_dict[key] = merge_tensor.numpy()
+                elif self.merge_config.tensor_type == "np" and is_bf16:
+                    # dtype==bfloat16: numpy(float32) -> paddle(float32) -> paddle(bfloat16) -> numpy(uint16)
+                    merge_state_dict[key] = paddle.Tensor(merge_tensor, zero_copy=True).astype("bfloat16").numpy()
 
         # Save safetensor file
         save_file(
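For reference, below is a minimal standalone sketch of the chunked bf16 merge flow introduced by this patch. The helper names (`merge_bf16_chunked`, `mean_merge`) and the default `max_tensor_mem` / `split_pieces` values are illustrative assumptions rather than part of the mergekit API, and a plain mean stands in for `self.merge_method.merge()`. It relies on the same convention the patch uses: Paddle exposes bfloat16 weights to numpy as `uint16` arrays, and `paddle.Tensor(arr, zero_copy=True).astype("float32")` lifts that carrier back to float32 for arithmetic.

```python
import numpy as np
import paddle


def mean_merge(tensor_list):
    # Stand-in for self.merge_method.merge(); any real merge method would slot in here.
    return sum(tensor_list) / len(tensor_list)


def merge_bf16_chunked(tensor_list, max_tensor_mem=1.0, split_pieces=4):
    """Merge bf16 weights stored as numpy uint16 arrays (Paddle's bf16 carrier)."""
    itemsize = 2  # bfloat16 / uint16: 2 bytes per element
    tensor_mem = int(np.prod(tensor_list[0].shape)) * itemsize / (1024**3)  # GiB
    if tensor_mem <= max_tensor_mem:
        # Small tensor: lift everything to float32 paddle tensors and merge in one shot.
        pd_list = [paddle.Tensor(t, zero_copy=True).astype("float32") for t in tensor_list]
        return mean_merge(pd_list).astype("bfloat16").numpy()
    # Large tensor: split along axis 0 and merge piece by piece, so peak memory
    # stays bounded by one piece per model instead of the full tensor per model.
    split_list = [np.array_split(t, split_pieces, axis=0) for t in tensor_list]
    merge_split = []
    for pieces in zip(*split_list):
        pd_pieces = [paddle.Tensor(p, zero_copy=True).astype("float32") for p in pieces]
        merge_split.append(mean_merge(pd_pieces).astype("bfloat16").numpy())
    return np.concatenate(merge_split, axis=0)
```

Under these assumptions, a shard entry would be merged as `merge_state_dict[key] = merge_bf16_chunked([sd.pop(key) for sd in state_dict_list])`, mirroring the per-key loop in the patch.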