
DPO training fails on GLM4 #4090

@Orion-zhen

Description

Reminder

  • I have read the README and searched the existing issues.

System Info

  • transformers version: 4.41.2
  • Platform: Linux-6.9.3-arch1-1-x86_64-with-glibc2.39
  • Python version: 3.10.14
  • Huggingface_hub version: 0.23.3
  • Safetensors version: 0.4.3
  • Accelerate version: 0.30.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No
  • GPU: RTX 2080TI 22G
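The environment list above does not include the installed trl version, which is the package implicated in the traceback further down. A stdlib-only snippet for printing the relevant versions (the package names are assumed to be the usual PyPI names):

```python
# Print installed versions of the packages involved in this report.
# Package names are assumed PyPI names; adjust if installed differently.
from importlib.metadata import PackageNotFoundError, version

for pkg in ("transformers", "trl", "peft", "accelerate", "bitsandbytes"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```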

Reproduction

Training config:

### model
model_name_or_path: /home/orion/ai/Models/glm4-9b
quantization_bit: 8

### method
stage: dpo
do_train: true
finetuning_type: lora
lora_target: query_key_value
pref_beta: 0.1
pref_loss: orpo

### dataset
dataset: dpo_zh_emoji, dpo_toxic, dpo_code, dpo_physical_reasoning
template: glm4
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/glm4-9b/lora/dpo
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
fp16: true

### eval
val_size: 0.1
per_device_eval_batch_size: 1
evaluation_strategy: steps
eval_steps: 500

Terminal output:

/home/orion/repo/llama-factory/.venv/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
/home/orion/repo/llama-factory/.venv/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
Traceback (most recent call last):
  File "/home/orion/repo/llama-factory/.venv/bin/llamafactory-cli", line 8, in <module>
    sys.exit(main())
  File "/home/orion/repo/llama-factory/src/llamafactory/cli.py", line 93, in main
    run_exp()
  File "/home/orion/repo/llama-factory/src/llamafactory/train/tuner.py", line 39, in run_exp
    run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
  File "/home/orion/repo/llama-factory/src/llamafactory/train/dpo/workflow.py", line 64, in run_dpo
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
  File "/home/orion/repo/llama-factory/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
  File "/home/orion/repo/llama-factory/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/orion/repo/llama-factory/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 3238, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/orion/repo/llama-factory/.venv/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1257, in compute_loss
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
  File "/home/orion/repo/llama-factory/src/llamafactory/train/dpo/trainer.py", line 223, in get_batch_loss_metrics
    ) = self.concatenated_forward(model, batch)
  File "/home/orion/repo/llama-factory/src/llamafactory/train/dpo/trainer.py", line 170, in concatenated_forward
    all_logps = self.get_batch_logps(
TypeError: DPOTrainer.get_batch_logps() got an unexpected keyword argument 'average_log_prob'
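The TypeError at the bottom of the traceback is the signature-mismatch pattern typical of version skew between LLaMA-Factory and the installed trl: the subclassed DPO trainer forwards a keyword argument that the installed `DPOTrainer.get_batch_logps()` no longer accepts. A minimal sketch of the same failure mode, with hypothetical class names standing in for the trl and LLaMA-Factory trainers:

```python
# Hypothetical classes illustrating the version-skew failure; this is
# not LLaMA-Factory or trl code, just the same call pattern.
class UpstreamTrainer:
    # Newer upstream signature: the 'average_log_prob' keyword was dropped.
    @staticmethod
    def get_batch_logps(logits, labels):
        return logits  # placeholder body

class DownstreamTrainer(UpstreamTrainer):
    def concatenated_forward(self, logits, labels):
        # Written against the older upstream API, so it still passes the keyword:
        return self.get_batch_logps(logits, labels, average_log_prob=False)

try:
    DownstreamTrainer().concatenated_forward(logits=1.0, labels=None)
except TypeError as exc:
    print(exc)  # message mentions: unexpected keyword argument 'average_log_prob'
```

Pinning trl to the version listed in LLaMA-Factory's requirements file, or updating LLaMA-Factory itself, usually resolves this class of error.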

Expected behavior

No response

Others

No response

    Labels

    solved (This problem has been already solved)
