Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 10 additions & 19 deletions llm/alignment/dpo/dpo_argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
from paddlenlp.trainer import TrainingArguments
from paddlenlp.trainer.trainer_utils import IntervalStrategy
from paddlenlp.trainer.utils.doc import add_start_docstrings
from paddlenlp.transformers.configuration_utils import llmmetaclass


@dataclass
@llmmetaclass
@add_start_docstrings(TrainingArguments.__doc__)
class DPOTrainingArguments(TrainingArguments):
"""DPOTrainingArguments"""
Expand Down Expand Up @@ -122,30 +124,19 @@ class DPOModelArgument:
tokenizer_name_or_path: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
recompute_granularity: str = field(
default="full",
metadata={
"help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
},
)
flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash mask in flash attention."})
virtual_pp_degree: int = field(
default=1,
metadata={"help": "virtual_pp_degree"},
)
sequence_parallel: bool = field(
default=False,
metadata={"help": "whether to use sequence parallel"},
)
tensor_parallel_output: bool = field(
default=True,
metadata={"help": "whether to use tensor_parallel_output"},
)
weight_quantize_algo: str = field(
default=None,
metadata={"help": "Model weight quantization algorithm including 'nf4'(qlora), 'weight_only_int8'."},
)
fuse_attention_qkv: bool = field(
default=None,
metadata={"help": "whether to fuse attention qkv"},
)
fuse_attention_ffn: bool = field(
default=None,
metadata={"help": "whether to fuse first up and gate proj in mlp block"},
)
# LoRA
lora_rank: int = field(default=8, metadata={"help": "Lora rank."})
lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
Expand Down
12 changes: 8 additions & 4 deletions llm/alignment/dpo/run_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,13 @@ def main():
dtype=dtype,
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
recompute_granularity=model_args.recompute_granularity,
use_flash_attention=model_args.use_flash_attention,
tensor_parallel_output=model_args.tensor_parallel_output,
recompute_granularity=training_args.recompute_granularity,
use_flash_attention=training_args.use_flash_attention,
tensor_parallel_output=training_args.tensor_parallel_output,
use_fused_rms_norm=training_args.use_fused_rms_norm,
use_fused_rope=training_args.use_fused_rope,
use_fused_linear=training_args.use_fused_linear,
use_fused_dropout_add=training_args.use_fused_dropout_add,
Comment on lines +114 to +117
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use the llm meta config / set_config to avoid these complicated assignments. It is also fine to leave it as is.

)
if training_args.pipeline_parallel_degree > 1:
raise ValueError("DPO does not support pipeline parallelism yet.")
Expand Down Expand Up @@ -143,7 +147,7 @@ def main():
if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
raise NotImplementedError(f"{model.__class__} not support flash mask.")

if model_args.sequence_parallel:
if training_args.sequence_parallel:
register_sequence_parallel_allreduce_hooks(
model, training_args.gradient_accumulation_steps, model_args.fuse_sequence_parallel_allreduce
)
Expand Down