Commit 3407cad

add
1 parent abd525c commit 3407cad

2 files changed (+56, -30 lines)

paddlenlp/mergekit/merge_config.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ class MergeConfig:
     merge_type: str = field(default="linear", metadata={"help": "The type of merge process."})
     sparsify_type: str = field(default=None, metadata={"help": "The type of sparsify process."})
     split_pieces: int = field(default=8, metadata={"help": "Split large tensor to multi-piece"})
-    max_tensor_mem: float = field(default=1, metadata={"help": "Split tensor if exceed setting max_tensor_mem."})
+    max_tensor_mem: float = field(default=0.5, metadata={"help": "Split tensor if exceed setting max_tensor_mem."})
 
     # Model parameters
     model_path_list: Optional[List[str]] = field(default=None, metadata={"help": "Merge model name or path list"})
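
Note on the new 0.5 default: the merge loop (see merge_model.py below) splits a tensor whenever its estimated in-memory size in GiB exceeds max_tensor_mem. A minimal sketch of that size check, assuming a dtype-to-bytes map like the self.numpy_dtype_map referenced in merge_model.py (the map itself is not part of this diff, so its entries and the needs_split helper here are illustrative):

import numpy as np

# Illustrative dtype -> bytes-per-element map; the real self.numpy_dtype_map
# in merge_model.py is not shown in this diff, so these entries are assumptions.
numpy_dtype_map = {"float32": 4, "float16": 2, "uint16": 2}

def needs_split(shape, dtype: str, max_tensor_mem: float = 0.5) -> bool:
    """Return True if a tensor of this shape/dtype exceeds max_tensor_mem GiB."""
    tensor_mem = int(np.prod(shape) * numpy_dtype_map[dtype]) / (1024**3)
    return tensor_mem > max_tensor_mem

# A 40000 x 4096 float32 tensor is ~0.61 GiB: kept whole under the old 1 GiB default,
# split into split_pieces chunks under the new 0.5 GiB default.
print(needs_split((40000, 4096), "float32"))  # True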

paddlenlp/mergekit/merge_model.py

Lines changed: 55 additions & 29 deletions
@@ -155,36 +155,62 @@ def merge_mix_model(self, file_type_list):
             )
             for key in local_keys:
                 # Tensor preprocess
-                is_bf16 = False
-                tensor_list = []
-                for i in range(model_num):
-                    if self.merge_config.tensor_type == "np" and str(state_dict_list[i][key].dtype) == "uint16":
-                        is_bf16 = True
-                        state_dict_list[i][key] = (
-                            paddle.Tensor(state_dict_list[i][key], zero_copy=True).astype("float32").numpy()
-                        )
-                    elif self.merge_config.tensor_type == "pd":
-                        state_dict_list[i][key] = paddle.Tensor(state_dict_list[i][key], zero_copy=True)
-                        if i == 0:
-                            tensor_dtype = state_dict_list[i][key].dtype
-                        # Using float32 to reduce precision loss
-                        state_dict_list[i][key] = state_dict_list[i][key].astype("float32")
-                    tensor_list.append(state_dict_list[i].pop(key))
-
-                # Tensor merge
-                if self.merge_config.base_model_path is not None:
-                    base_tensor = tensor_list.pop()
-                    tensor_list = [tensor - base_tensor for tensor in tensor_list]
-                merge_state_dict[key] = self.merge_method.merge(tensor_list)
-                if self.merge_config.base_model_path is not None:
-                    merge_state_dict[key] += base_tensor
+                is_bf16 = str(state_dict_list[0][key].dtype) == "uint16"
+                tensor_list = [state_dict_list[i].pop(key) for i in range(model_num)]
+                tensor_mem = int(np.prod(tensor_list[0].shape) * self.numpy_dtype_map[str(tensor_list[0].dtype)]) / (
+                    1024**3
+                )
+                if self.merge_config.tensor_type == "pd" and tensor_mem > self.merge_config.max_tensor_mem:
+                    tensor_split_list = [
+                        np.array_split(tensor, self.merge_config.split_pieces, axis=0) for tensor in tensor_list
+                    ]
+                    merge_split = []
+                    for sp in range(self.merge_config.split_pieces):
+                        tensor_list = [tensor_split[sp] for tensor_split in tensor_split_list]
+                        if is_bf16:
+                            tensor_list = [
+                                paddle.Tensor(tensor, zero_copy=True).astype("float32") for tensor in tensor_list
+                            ]
+                        else:
+                            tensor_list = [paddle.Tensor(tensor, zero_copy=True) for tensor in tensor_list]
+                        if self.merge_config.base_model_path is not None:
+                            base_tensor = tensor_list.pop()
+                            tensor_list = [tensor - base_tensor for tensor in tensor_list]
+                        merge_tensor = self.merge_method.merge(tensor_list)
+                        if self.merge_config.base_model_path is not None:
+                            merge_tensor += base_tensor
+                        if is_bf16:
+                            merge_split.append(merge_tensor.astype("bfloat16").numpy())
+                        else:
+                            merge_split.append(merge_tensor.numpy())
+                    merge_state_dict[key] = np.concatenate(merge_split, axis=0)
+                else:
+                    if self.merge_config.tensor_type == "pd":
+                        if is_bf16:
+                            tensor_list = [
+                                paddle.Tensor(tensor, zero_copy=True).astype("float32") for tensor in tensor_list
+                            ]
+                        else:
+                            tensor_list = [paddle.Tensor(tensor, zero_copy=True) for tensor in tensor_list]
+                    elif self.merge_config.tensor_type == "np" and is_bf16:
+                        tensor_list = [
+                            paddle.Tensor(tensor, zero_copy=True).astype("float32").numpy() for tensor in tensor_list
+                        ]
 
-                # Tensor postprocess
-                # dtype==bfloat16: numpy(float32) -> paddle(float32) -> paddle(bfloat16) -> numpy(uint16)
-                if self.merge_config.tensor_type == "np" and is_bf16:
-                    merge_state_dict[key] = paddle.Tensor(merge_state_dict[key], zero_copy=True).astype("bfloat16").numpy()
-                elif self.merge_config.tensor_type == "pd":
-                    merge_state_dict[key] = merge_state_dict[key].astype(tensor_dtype).numpy()
+                    if self.merge_config.base_model_path is not None:
+                        base_tensor = tensor_list.pop()
+                        tensor_list = [tensor - base_tensor for tensor in tensor_list]
+                    merge_tensor = self.merge_method.merge(tensor_list)
+                    if self.merge_config.base_model_path is not None:
+                        merge_tensor += base_tensor
+                    if self.merge_config.tensor_type == "pd":
+                        if is_bf16:
+                            merge_state_dict[key] = merge_tensor.astype("bfloat16").numpy()
+                        else:
+                            merge_state_dict[key] = merge_tensor.numpy()
+                    elif self.merge_config.tensor_type == "np" and is_bf16:
+                        # dtype==bfloat16: numpy(float32) -> paddle(float32) -> paddle(bfloat16) -> numpy(uint16)
+                        merge_state_dict[key] = paddle.Tensor(merge_tensor, zero_copy=True).astype("bfloat16").numpy()
 
             # Save safetensor file
             save_file(
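
For readers skimming the diff: in the "pd" path, an oversized tensor is now split along axis 0 into split_pieces chunks, each chunk set is merged independently (with the optional base-model delta applied per chunk), and the merged pieces are concatenated back. A standalone sketch of that flow under simplified assumptions: merge_linear and merge_key are hypothetical stand-ins, plain NumPy replaces paddle tensors, and an unweighted average replaces self.merge_method.merge, which is outside this diff.

import numpy as np

def merge_linear(tensors):
    # Stand-in for self.merge_method.merge (not shown in this diff): unweighted average.
    return sum(tensors) / len(tensors)

def merge_key(tensor_list, base_tensor=None, max_tensor_mem=0.5, split_pieces=8):
    """Merge one parameter across models, chunking along axis 0 when it is too large.

    tensor_list: one np.ndarray per model for the same key.
    base_tensor: optional base-model tensor; if given, merging is done on deltas.
    """
    # Estimated size in GiB of one tensor for this key.
    tensor_mem = int(np.prod(tensor_list[0].shape) * tensor_list[0].dtype.itemsize) / (1024**3)
    if tensor_mem > max_tensor_mem:
        chunks_per_model = [np.array_split(t, split_pieces, axis=0) for t in tensor_list]
        base_chunks = np.array_split(base_tensor, split_pieces, axis=0) if base_tensor is not None else None
    else:
        chunks_per_model = [[t] for t in tensor_list]  # a single "piece"
        base_chunks = [base_tensor] if base_tensor is not None else None

    merged_pieces = []
    for sp in range(len(chunks_per_model[0])):
        pieces = [chunks[sp] for chunks in chunks_per_model]
        if base_chunks is not None:
            pieces = [p - base_chunks[sp] for p in pieces]  # merge deltas against the base
        merged = merge_linear(pieces)
        if base_chunks is not None:
            merged = merged + base_chunks[sp]  # add the base back after merging
        merged_pieces.append(merged)
    return np.concatenate(merged_pieces, axis=0)

# Two toy "models" plus a base; a tiny max_tensor_mem forces the chunked path.
a, b, base = (np.random.rand(16, 4).astype("float32") for _ in range(3))
out = merge_key([a, b], base_tensor=base, max_tensor_mem=1e-9, split_pieces=4)
assert out.shape == (16, 4)

The committed code additionally round-trips bfloat16 weights (stored as uint16 numpy arrays) through paddle float32 tensors before merging and casts back afterwards, and takes the base tensor as the last entry of tensor_list rather than a separate argument; both details are omitted above for clarity.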
