Commit 20be84b

poolish

committed
1 parent 29b05b0 commit 20be84b

File tree

4 files changed (+36, -38 lines):

- llm/auto_parallel/deepseek-v2/run_pretrain_auto.py
- llm/auto_parallel/deepseek-v2/run_pretrain_auto.sh
- paddlenlp/transformers/deepseek_v2/__init__.py
- paddlenlp/transformers/deepseek_v2/modeling_network.py (renamed to paddlenlp/transformers/deepseek_v2/modeling_auto.py)

llm/auto_parallel/deepseek-v2/run_pretrain_auto.py

Lines changed: 4 additions & 4 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -39,14 +39,14 @@
     AutoTokenizer,
     CosineAnnealingWithWarmupDecay,
     DeepseekV2Config,
-    DeepseekV2ForCausalLMNet,
+    DeepseekV2ForCausalLMAuto,
     DeepseekV2PretrainingCriterion,
     LinearAnnealingWithWarmupDecay,
 )
 from paddlenlp.utils.log import logger

 MODEL_CLASSES = {
-    "deepseekv2_network": (DeepseekV2Config, DeepseekV2ForCausalLMNet, DeepseekV2PretrainingCriterion),
+    "deepseekv2_auto": (DeepseekV2Config, DeepseekV2ForCausalLMAuto, DeepseekV2PretrainingCriterion),
 }


@@ -90,7 +90,7 @@ class PreTrainingArguments(AutoTrainingArguments):
     )
     sr: Optional[int] = field(default=0, metadata={"help": "The count of chunks without recompute."})
     virtual_pipeline_seg_method: str = field(
-        default="DeepseekV2DecoderLayerNet",
+        default="DeepseekV2DecoderLayerAuto",
         metadata={"help": "The seg method of spliting pp layer for virtual pipeline."},
     )
     # NOTE(gongenlei): new add autotuner_benchmark
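
For context, the renamed registry key is what the launch script's `--model_type` flag selects. Below is a minimal sketch of that lookup, under the assumption that the classes resolve from `paddlenlp.transformers` as in the import block above; the real script also parses `--model_type` and sets up the distributed environment.

```python
# Minimal sketch of the MODEL_CLASSES lookup this commit renames
# ("deepseekv2_network" -> "deepseekv2_auto"). Assumes the classes resolve from
# paddlenlp.transformers as in the diff above; not the full run_pretrain_auto.py flow.
from paddlenlp.transformers import (
    DeepseekV2Config,
    DeepseekV2ForCausalLMAuto,
    DeepseekV2PretrainingCriterion,
)

MODEL_CLASSES = {
    "deepseekv2_auto": (DeepseekV2Config, DeepseekV2ForCausalLMAuto, DeepseekV2PretrainingCriterion),
}

# The key is the value passed via --model_type in run_pretrain_auto.sh.
config_class, model_class, criterion_class = MODEL_CLASSES["deepseekv2_auto"]
print(config_class.__name__, model_class.__name__, criterion_class.__name__)
```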

llm/auto_parallel/deepseek-v2/run_pretrain_auto.sh

Lines changed: 5 additions & 5 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -41,12 +41,12 @@ unset CUDA_VISIBLE_DEVICES
 task_name="deepseekv2"
 rm -rf output/$task_name/
 rm -rf "output/$task_name""_log"
-rm -rf /root/paddlejob/workspace/env_run/xuxinyi/PaddleNLP/llm/auto_parallel/deepseek-v2/deepseek_single_card
+rm -rf /root/paddlejob/workspace/env_run/xuxinyi/PaddleNLP/llm/auto_parallel/deepseek-v2/log

 export SOT_LOG_LEVEL=4
 export PYTHONPATH=/root/paddlejob/workspace/env_run/xuxinyi/PaddleNLP:$PYTHONPATH
 #ulimit -c unlimited
-export GLOG_v=7
+# export GLOG_v=3

 # export FLAGS_call_stack_level=3
 # export FLAGS_use_cuda_managed_memory=true
@@ -59,9 +59,9 @@ to_static=0 # 是否开启动转静训练

 python -u -m paddle.distributed.launch \
     --gpus "0,1,2,3" \
-    --log_dir "deepseek_single_card" \
+    --log_dir "log" \
     run_pretrain_auto.py \
-    --model_type "deepseekv2_network" \
+    --model_type "deepseekv2_auto" \
     --model_name_or_path "deepseek-ai/DeepSeek-V2-Lite" \
     --tokenizer_name_or_path "deepseek-ai/DeepSeek-V2-Lite" \
     --input_dir "./data" \

paddlenlp/transformers/deepseek_v2/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,6 +14,6 @@

 from .configuration import *
 from .modeling import *
-from .modeling_network import *
+from .modeling_auto import *
 from .modeling_pp import *
 from .tokenizer_fast import *
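
Because `__init__.py` keeps a star-import of the renamed module, the Auto classes stay reachable from the `deepseek_v2` package. A quick import check, as a sketch (assumes a PaddleNLP checkout that includes this commit):

```python
# Sketch: confirm the renamed module is re-exported by the package __init__.
from paddlenlp.transformers.deepseek_v2 import DeepseekV2ForCausalLMAuto
from paddlenlp.transformers.deepseek_v2.modeling_auto import DeepseekV2ModelAuto

# Both names now come from modeling_auto.py, the renamed modeling_network.py.
print(DeepseekV2ForCausalLMAuto.__module__)  # paddlenlp.transformers.deepseek_v2.modeling_auto
print(DeepseekV2ModelAuto.__module__)        # paddlenlp.transformers.deepseek_v2.modeling_auto
```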

paddlenlp/transformers/deepseek_v2/modeling_network.py renamed to paddlenlp/transformers/deepseek_v2/modeling_auto.py

Lines changed: 26 additions & 28 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -69,10 +69,10 @@
 )

 __all__ = [
-    "DeepseekV2LMHeadNet",
-    "DeepseekV2ForCausalLMNet",
-    "DeepseekV2ModelNet",
-    "DeepseekV2PretrainedModelNet",
+    "DeepseekV2LMHeadAuto",
+    "DeepseekV2ForCausalLMAuto",
+    "DeepseekV2ModelAuto",
+    "DeepseekV2PretrainedModelAuto",
 ]


@@ -169,7 +169,7 @@ def scaled_dot_product_attention(
     return (attn_output, attn_weights) if output_attentions else attn_output


-class DeepseekV2MLPNet(nn.Layer):
+class DeepseekV2MLPAuto(nn.Layer):
     def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False):
         super().__init__()
         self.config = config
@@ -187,7 +187,7 @@ def forward(self, x):
         return down_proj


-class DeepseekV2MoENet(MoELayer):
+class DeepseekV2MoEAuto(MoELayer):
     """
     A mixed expert module containing shared experts.
     """
@@ -209,15 +209,15 @@ def __init__(self, config: DeepseekV2Config):
         super().__init__(
             config=config,
             moe_num_experts=config.n_routed_experts,
-            expert_class=DeepseekV2MLPNet,
+            expert_class=DeepseekV2MLPAuto,
             expert_kwargs={"config": config, "intermediate_size": config.moe_intermediate_size},
             gate=gate,
             capacity=2.0,
         )
         self.alpha = config.aux_loss_alpha
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV2MLPNet(config=config, intermediate_size=intermediate_size, is_moe=True)
+            self.shared_experts = DeepseekV2MLPAuto(config=config, intermediate_size=intermediate_size, is_moe=True)

     def forward(self, hidden_states):
         final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
@@ -231,7 +231,7 @@ def forward(self, hidden_states):


 # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2
-class DeepseekV2AttentionNet(nn.Layer):
+class DeepseekV2AttentionAuto(nn.Layer):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
@@ -393,9 +393,7 @@ def forward(
         # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
         # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

-        print("k_nope.shape", k_nope.shape, "k_pe.shape", k_pe.shape)
         key_states = paddle.empty([bsz, q_len, self.num_heads, self.q_head_dim], dtype=self.config.dtype)
-        print("key_states.shape:", key_states.shape)
         # input[0]'s shape = [1, 2048, 16, 128], input[1]'s shape = [1, 2048, 1, 64].
         key_states = paddle.concat([k_nope, k_pe.expand([bsz, q_len, self.num_heads, k_pe.shape[-1]])], axis=3)

@@ -456,7 +454,7 @@ def forward(
         return attn_output, attn_weights, past_key_value


-class DeepseekV2DecoderLayerNet(nn.Layer):
+class DeepseekV2DecoderLayerAuto(nn.Layer):
     def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute: bool = False):
         super().__init__()
         self.config = config
@@ -467,16 +465,16 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute

         self.hidden_size = config.hidden_size

-        self.self_attn = DeepseekV2AttentionNet(config=config, layerwise_recompute=layerwise_recompute)
+        self.self_attn = DeepseekV2AttentionAuto(config=config, layerwise_recompute=layerwise_recompute)

         self.mlp = (
-            DeepseekV2MoENet(config)
+            DeepseekV2MoEAuto(config)
             if (
                 config.n_routed_experts is not None
                 and layer_idx >= config.first_k_dense_replace
                 and layer_idx % config.moe_layer_freq == 0
             )
-            else DeepseekV2MLPNet(config)
+            else DeepseekV2MLPAuto(config)
         )
         self.input_layernorm = DeepseekV2RMSNorm(config)
         self.post_attention_layernorm = DeepseekV2RMSNorm(config)
@@ -566,16 +564,16 @@ def forward(
         return outputs


-class DeepseekV2PretrainedModelNet(PretrainedModel):
+class DeepseekV2PretrainedModelAuto(PretrainedModel):
     config_class = DeepseekV2Config
     base_model_prefix = "deepseek_v2"
-    _no_split_modules = ["DeepseekV2DecoderLayerNet"]
+    _no_split_modules = ["DeepseekV2DecoderLayerAuto"]


 @register_base_model
-class DeepseekV2ModelNet(DeepseekV2PretrainedModelNet):
+class DeepseekV2ModelAuto(DeepseekV2PretrainedModelAuto):
     """
-    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayerNet`]
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayerAuto`]

     Args:
         config: DeepseekV2Config
@@ -597,7 +595,7 @@ def __init__(self, config: DeepseekV2Config):

         self.layers = nn.LayerList(
             [
-                DeepseekV2DecoderLayerNet(config, layer_idx, layer_idx not in self.no_recompute_layers)
+                DeepseekV2DecoderLayerAuto(config, layer_idx, layer_idx not in self.no_recompute_layers)
                 for layer_idx in range(config.num_hidden_layers)
             ]
         )
@@ -777,9 +775,9 @@ def forward(
         )


-class DeepseekV2LMHeadNet(nn.Layer):
+class DeepseekV2LMHeadAuto(nn.Layer):
     def __init__(self, config: DeepseekV2Config):
-        super(DeepseekV2LMHeadNet, self).__init__()
+        super(DeepseekV2LMHeadAuto, self).__init__()

         self.config = config

@@ -796,15 +794,15 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         return logits


-class DeepseekV2ForCausalLMNet(DeepseekV2PretrainedModelNet):
+class DeepseekV2ForCausalLMAuto(DeepseekV2PretrainedModelAuto):
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config: DeepseekV2Config):
         super().__init__(config)
         self.config = config
-        self.deepseek_v2 = DeepseekV2ModelNet(config)
+        self.deepseek_v2 = DeepseekV2ModelAuto(config)
         self.vocab_size = config.vocab_size
-        self.lm_head = DeepseekV2LMHeadNet(config)
+        self.lm_head = DeepseekV2LMHeadAuto(config)
         self.criterion = DeepseekV2PretrainingCriterion(config)

     def get_input_embeddings(self):
@@ -851,9 +849,9 @@ def forward(
         Example:

         ```python
-        >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLMNet
+        >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLMAuto

-        >>> model = DeepseekV2ForCausalLMNet.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> model = DeepseekV2ForCausalLMAuto.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

         >>> prompt = "Hey, are you conscious? Can you talk to me?"
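
Beyond the rename and the removed debug prints, the layer wiring is unchanged: `DeepseekV2DecoderLayerAuto` still chooses between the MoE block and the dense MLP by layer index. A standalone sketch of that selection rule, with illustrative (not authoritative) config values:

```python
# Standalone sketch of the MoE/dense selection in DeepseekV2DecoderLayerAuto.__init__;
# the real code reads these values from DeepseekV2Config. Numbers below are
# illustrative, not taken from a released checkpoint.
def uses_moe(layer_idx, n_routed_experts, first_k_dense_replace, moe_layer_freq):
    """True -> the layer gets DeepseekV2MoEAuto, False -> plain DeepseekV2MLPAuto."""
    return (
        n_routed_experts is not None
        and layer_idx >= first_k_dense_replace
        and layer_idx % moe_layer_freq == 0
    )

layout = [uses_moe(i, n_routed_experts=64, first_k_dense_replace=1, moe_layer_freq=1) for i in range(4)]
print(layout)  # [False, True, True, True]: the first layer stays dense, later layers are MoE
```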
