Skip to content

Commit c6e3c14

Browse files
authored
Merge pull request hiyouga#6395 from hiyouga/hiyouga/fix_genkwargs
[generate] fix generate kwargs
2 parents ffbb4db + d4c1fda commit c6e3c14

File tree

6 files changed

+22
-16
lines changed

6 files changed

+22
-16
lines changed

.gitignore

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,3 @@ saves/
172172
output/
173173
wandb/
174174
generated_predictions.jsonl
175-
176-
# unittest
177-
dummy_dir/

src/llamafactory/hparams/generating_args.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from dataclasses import asdict, dataclass, field
1616
from typing import Any, Dict, Optional
1717

18+
from transformers import GenerationConfig
19+
1820

1921
@dataclass
2022
class GeneratingArguments:
@@ -69,10 +71,17 @@ class GeneratingArguments:
6971
metadata={"help": "Whether or not to remove special tokens in the decoding."},
7072
)
7173

72-
def to_dict(self) -> Dict[str, Any]:
74+
def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
7375
args = asdict(self)
7476
if args.get("max_new_tokens", -1) > 0:
7577
args.pop("max_length", None)
7678
else:
7779
args.pop("max_new_tokens", None)
80+
81+
if obey_generation_config:
82+
generation_config = GenerationConfig()
83+
for key in list(args.keys()):
84+
if not hasattr(generation_config, key):
85+
args.pop(key)
86+
7887
return args

src/llamafactory/train/sft/trainer.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "to
151151
return padded_tensor.contiguous() # in contiguous memory
152152

153153
def save_predictions(
154-
self, dataset: "Dataset", predict_results: "PredictionOutput", gen_kwargs: Dict[str, Any]
154+
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
155155
) -> None:
156156
r"""
157157
Saves model predictions to `output_dir`.
@@ -179,12 +179,8 @@ def save_predictions(
179179
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
180180

181181
decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
182-
decoded_preds = self.processing_class.batch_decode(
183-
preds, skip_special_tokens=gen_kwargs["skip_special_tokens"]
184-
)
185-
decoded_labels = self.processing_class.batch_decode(
186-
labels, skip_special_tokens=gen_kwargs["skip_special_tokens"]
187-
)
182+
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
183+
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
188184

189185
with open(output_prediction_file, "w", encoding="utf-8") as f:
190186
for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):

src/llamafactory/train/sft/workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def run_sft(
9191
)
9292

9393
# Keyword arguments for `model.generate`
94-
gen_kwargs = generating_args.to_dict()
94+
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
9595
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
9696
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
9797
gen_kwargs["logits_processor"] = get_logits_processor()
@@ -130,7 +130,7 @@ def run_sft(
130130
predict_results.metrics.pop("predict_loss", None)
131131
trainer.log_metrics("predict", predict_results.metrics)
132132
trainer.save_metrics("predict", predict_results.metrics)
133-
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, gen_kwargs)
133+
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
134134

135135
# Create model card
136136
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

tests/e2e/test_train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@
6060
],
6161
)
6262
def test_run_exp(stage: str, dataset: str):
63-
output_dir = os.path.join("output", f"dummy_dir/train_{stage}")
63+
output_dir = os.path.join("output", f"train_{stage}")
6464
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
6565
assert os.path.exists(output_dir)
6666

6767

6868
def test_export():
69-
export_dir = os.path.join("output", "dummy_dir/llama3_export")
69+
export_dir = os.path.join("output", "llama3_export")
7070
export_model({"export_dir": export_dir, **INFER_ARGS})
7171
assert os.path.exists(export_dir)

tests/train/test_sft_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
5858
@pytest.mark.parametrize("disable_shuffling", [False, True])
5959
def test_shuffle(disable_shuffling: bool):
6060
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
61-
{"output_dir": f"dummy_dir/{disable_shuffling}", "disable_shuffling": disable_shuffling, **TRAIN_ARGS}
61+
{
62+
"output_dir": os.path.join("output", f"shuffle{str(disable_shuffling).lower()}"),
63+
"disable_shuffling": disable_shuffling,
64+
**TRAIN_ARGS,
65+
}
6266
)
6367
tokenizer_module = load_tokenizer(model_args)
6468
tokenizer = tokenizer_module["tokenizer"]

0 commit comments

Comments (0)