
Commit 2231feb

[Embedding] Add embedding training (#9508)
* add Qwen2SentenceEmbedding
* add embedding trainer

Co-authored-by: DrownFish19 <DrownFish19@gmail.com>
1 parent da7a7d2 commit 2231feb

File tree: 7 files changed, +835 -1 lines changed


llm/config/qwen/emb_argument.json

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
{
    "model_name_or_path": "Qwen/Qwen2-0.5B",
    "dataset_name_or_path": "./dureader_data",
    "output_dir": "./checkpoints/sft_ckpts",
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 4,
    "per_device_eval_batch_size": 1,
    "eval_accumulation_steps": 1,
    "max_steps": 2000,
    "learning_rate": 3e-5,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "no",
    "save_strategy": "epoch",
    "max_query_len": 512,
    "max_passage_len": 512,
    "group_size": 4,
    "bf16": true,
    "fp16_opt_level": "O2",
    "do_train": true,
    "do_eval": false,
    "disable_tqdm": true,
    "load_best_model_at_end": false,
    "eval_with_do_generation": false,
    "metric_for_best_model": "accuracy",
    "recompute": true,
    "save_total_limit": 1,
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "sharding": "stage1",
    "zero_padding": false,
    "unified_checkpoint": true,
    "use_flash_attention": true,
    "amp_custom_black_list": "elementwise_div",
    "release_grads": true
}
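
This config uses the same flat-JSON layout as the other files under llm/config/: the keys are split across ModelConfig, DataConfig, SFTConfig, and the new EmbeddingArgument by PdArgumentParser in run_embedding.py below, with max_query_len, max_passage_len, and group_size landing in EmbeddingArgument. A minimal launch sketch, assuming it is run from the llm/ directory and that run_embedding.py is invoked with a single JSON config like the other llm/ scripts (the --gpus value is an illustrative choice):

# Minimal sketch: launch embedding training with the config above.
# Assumes the working directory is llm/; the --gpus value is illustrative.
import subprocess

subprocess.run(
    [
        "python", "-m", "paddle.distributed.launch",
        "--gpus", "0",
        "run_embedding.py",
        "./config/qwen/emb_argument.json",
    ],
    check=True,
)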

llm/run_embedding.py

Lines changed: 288 additions & 0 deletions
@@ -0,0 +1,288 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# import inspect
import os
import sys

import paddle
from utils.argument import EmbeddingArgument

from paddlenlp.data import DataCollatorForEmbedding
from paddlenlp.datasets import EmbeddingIterableDataset, load_dataset
from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed
from paddlenlp.trainer.trainer_callback import TrainerState
from paddlenlp.transformers import (
    AutoConfig,
    AutoTokenizer,
    Qwen2Config,
    Qwen2SentenceEmbedding,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.transformers.refined_recompute import update_refined_recompute
from paddlenlp.trl import DataConfig, EmbeddingTrainer, ModelConfig, SFTConfig
from paddlenlp.trl.llm_utils import compute_metrics, init_chat_template
from paddlenlp.utils.log import logger

# Fine-tune Environment Variables to support sharding stage1 overlap optimization.
os.environ["USE_CASUAL_MASK"] = "False"


def main():
    parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig, EmbeddingArgument))
    if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args, embedding_args = parser.parse_json_file_and_cmd_lines()
    else:
        model_args, data_args, training_args, embedding_args = parser.parse_args_into_dataclasses()

    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    # Setup GPU & distributed training
    paddle.set_device(training_args.device)
    set_seed(seed=training_args.seed)
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
    )

    if training_args.pipeline_parallel_degree > 1:
        raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Load model
    if training_args.fp16_opt_level == "O2":
        if training_args.fp16:
            dtype = "float16"
        elif training_args.bf16:
            dtype = "bfloat16"
        else:
            raise ValueError("Please specific dtype: --fp16 or --bf16")
    else:
        dtype = "float32"

    model_config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        dtype=dtype,
        from_aistudio=model_args.from_aistudio,
    )
    assert isinstance(model_config, Qwen2Config), "Now only qwen2 supported"

    LlmMetaConfig.set_llm_config(model_config, training_args)
    model_config.refined_recompute = update_refined_recompute(training_args.refined_recompute)
    model_config.use_fast_layer_norm = model_args.use_fast_layer_norm

    # Config for model using dropout, such as GPT.
    if hasattr(model_config, "hidden_dropout_prob"):
        model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
    if hasattr(model_config, "attention_probs_dropout_prob"):
        model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
    if hasattr(model_config, "ignore_index"):
        model_config.ignore_index = -100

    if model_args.fuse_attention_qkv is not None:
        model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
    if model_args.fuse_attention_ffn is not None:
        model_config.fuse_attention_ffn = model_args.fuse_attention_ffn

    model_config.seq_length = data_args.max_length
    model_config.embedding_negatives_cross_device = embedding_args.embedding_negatives_cross_device
    logger.info(f"Final model config: {model_config}")

    model_class = Qwen2SentenceEmbedding

    if model_args.continue_training and not training_args.autotuner_benchmark:
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
            config=model_config,
            from_aistudio=model_args.from_aistudio,
        )
    else:
        model = model_class.from_config(model_config, dtype=dtype)

    if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention):
        logger.warning("`flash_mask` must use with zero padding and flash attention.")
        data_args.zero_padding = True
        model.config.use_flash_attention = True

    # Load tokenizer & dataset
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio)

    # init chat_template for tokenizer
    init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template)

    # if using chat_template, data_args.eval_with_do_generation must be false
    if tokenizer.chat_template is not None:
        data_args.eval_with_do_generation = False

    if training_args.do_eval:
        logger.warning("Warning: 'do_eval' is set to True, but will be set to False for Embedding training currently.")
        training_args.do_eval = False
        training_args.evaluation_strategy = "no"

    if data_args.dataset_name_or_path is None:
        raise ValueError(f"Please specific dataset name or path (got {data_args.dataset_name_or_path})")
    elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")) or os.path.exists(
        os.path.join(data_args.dataset_name_or_path, "dev.json")
    ):
        if training_args.do_train:
            train_ds = load_dataset(
                "json",
                data_files=os.path.join(data_args.dataset_name_or_path, "train.json"),
                lazy=data_args.lazy,
            )[0]
        else:
            train_ds = None
        if training_args.do_eval:
            dev_ds = load_dataset(
                "json",
                data_files=os.path.join(data_args.dataset_name_or_path, "dev.json"),
                lazy=data_args.lazy,
            )[0]
        else:
            dev_ds = None

    elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train")) or os.path.exists(
        os.path.join(data_args.dataset_name_or_path, "dev")
    ):
        import glob

        if training_args.do_train:
            train_ds = load_dataset(
                "json",
                data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "train", "*.json")),
                lazy=data_args.lazy,
            )[0]
        else:
            train_ds = None
        if training_args.do_eval:
            dev_ds = load_dataset(
                "json",
                data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "dev", "*.json")),
                lazy=data_args.lazy,
            )[0]
        else:
            dev_ds = None

    else:
        if training_args.do_train:
            train_ds = load_dataset(data_args.dataset_name_or_path, splits=["train"])[0]
        else:
            train_ds = None
        if training_args.do_eval:
            dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["dev"])[0]
        else:
            dev_ds = None

    # TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later.
    if training_args.resume_from_checkpoint is not None and data_args.lazy:
        logger.info(
            f"Loading from '{training_args.resume_from_checkpoint}' with `lazy=True`, manually skipping dataset and setting `ignore_data_skip` to True."
        )
        training_args.ignore_data_skip = True
        state = TrainerState.load_from_json(os.path.join(training_args.resume_from_checkpoint, "trainer_state.json"))
        if state.trial_params is not None and "zero_padding_global_step" in state.trial_params:
            consumed_samples = state.trial_params["zero_padding_global_step"]
        else:
            consumed_samples = (
                state.global_step
                * training_args.per_device_train_batch_size
                * training_args.gradient_accumulation_steps
                * training_args.dataset_world_size
            )
        logger.info(
            f"Skipping the first {consumed_samples} samples to warmup the dataset from checkpoint '{training_args.resume_from_checkpoint}'."
        )
        train_ds = train_ds.skip(consumed_samples)

    if train_ds is not None:
        train_ds = EmbeddingIterableDataset(
            train_ds,
            tokenizer,
            max_query_len=embedding_args.max_query_len,
            max_passage_len=embedding_args.max_passage_len,
            group_size=embedding_args.group_size,
            query_template=embedding_args.query_template,
            passage_template=embedding_args.passage_template,
        )

    if dev_ds is not None:
        dev_ds = EmbeddingIterableDataset(
            dev_ds,
            tokenizer,
            max_query_len=embedding_args.max_query_len,
            max_passage_len=embedding_args.max_passage_len,
            group_size=embedding_args.group_size,
            query_template=embedding_args.query_template,
            passage_template=embedding_args.passage_template,
        )

    # Create trainer
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        padding = True

    data_collator_fn = DataCollatorForEmbedding(
        tokenizer=tokenizer,
        max_query_len=embedding_args.max_query_len,
        padding=padding,
        max_passage_len=embedding_args.max_passage_len,
        return_tensors="np",
        return_attention_mask=not model_args.flash_mask,
        pad_to_multiple_of=data_args.pad_to_multiple_of,
    )
    trainer = EmbeddingTrainer(
        model=model,
        model_args=embedding_args,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        data_collator=data_collator_fn,
    )
    trainable_parameters = [p for p in model.parameters() if not p.stop_gradient]
    trainer.set_optimizer_grouped_parameters(trainable_parameters)

    # Train
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    # Evaluation dev set
    if training_args.do_eval:
        logger.info("*** Evaluate result after train ***")
        eval_result = trainer.evaluate(dev_ds)
        trainer.log_metrics("eval", eval_result)


if __name__ == "__main__":
    main()
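
The script looks for train.json (or a train/ directory of JSON files) under dataset_name_or_path and wraps it in EmbeddingIterableDataset with group_size passages per query. A hypothetical sketch of one training record follows; the field names (query, pos_passage, neg_passage) are assumptions for illustration only, so check EmbeddingIterableDataset for the exact schema it expects.

# Hypothetical sketch of one line in ./dureader_data/train.json.
# The field names below are assumed for illustration; verify them against
# EmbeddingIterableDataset before building a real dataset.
import json

record = {
    "query": "What is the capital city of France?",
    "pos_passage": ["Paris is the capital and most populous city of France."],
    "neg_passage": [
        "Berlin is the capital of Germany.",
        "Lyon is the third-largest city in France.",
        "Madrid is the capital of Spain.",
    ],
}
# group_size = 4 in emb_argument.json, i.e. 1 positive + 3 negatives per query.
with open("dureader_data/train.json", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")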

llm/utils/argument.py

Lines changed: 52 additions & 0 deletions
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
+from typing import List, Optional


@dataclass
@@ -36,3 +37,54 @@ class GenerateArgument:
    top_p: float = field(
        default=1.0, metadata={"help": "The cumulative probability for top-p-filtering in the sampling strategy."}
    )
+
+
+@dataclass
+class EmbeddingArgument:
+    max_query_len: int = field(
+        default=1,
+        metadata={"help": "The maximum length (in tokens) of the query side for embedding training."},
+    )
+    max_passage_len: int = field(
+        default=1,
+        metadata={"help": "The maximum length (in tokens) of the passage side for embedding training."},
+    )
+    group_size: int = field(
+        default=8,
+        metadata={
+            "help": (
+                "Number of total positive and negative samples associated with each query for embedding training."
+            )
+        },
+    )
+    query_template: str = field(
+        default="Query: {text}\nUse one word to summarize the query's relevant information. The word is: \"",
+        metadata={
+            "help": (
+                "Query template. Ensure the template includes the placeholder "
+                "'{text}' to insert the actual query text."
+            )
+        },
+    )
+    passage_template: str = field(
+        default="Text: {text}\nUse one word to summarize the text's content. The word is: \"",
+        metadata={
+            "help": (
+                "Passage template. Ensure the template includes the placeholder "
+                "'{text}' to insert the actual passage text."
+            )
+        },
+    )
+    embedding_temperature: float = field(
+        default=0.02,
+        metadata={"help": "The temperature used in embedding learning."},
+    )
+    embedding_negatives_cross_device: bool = field(
+        default=True,
+        metadata={"help": "Whether to share the negatives across all GPUs."},
+    )
+    embedding_matryoshka_dims: Optional[List[int]] = field(
+        default=None,
+        metadata={"help": "The dims for matryoshka training."},
+    )
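
embedding_temperature and embedding_negatives_cross_device describe a temperature-scaled contrastive (InfoNCE-style) objective: each query is scored against its group_size passages (and, when negatives are shared across devices, against the other ranks' passages as well) and trained to rank its positive passage first. The sketch below is illustrative only and is not the EmbeddingTrainer implementation.

# Illustrative sketch of a temperature-scaled contrastive loss for embedding
# training; this is not the actual EmbeddingTrainer implementation.
import paddle
import paddle.nn.functional as F


def contrastive_loss(q_reps, p_reps, group_size=4, temperature=0.02):
    """q_reps: [num_queries, dim]; p_reps: [num_queries * group_size, dim],
    where the first passage in each group is the positive."""
    q_reps = F.normalize(q_reps, axis=-1)
    p_reps = F.normalize(p_reps, axis=-1)
    # Cosine similarity of every query against every passage, sharpened by the
    # temperature (0.02 by default, matching embedding_temperature above).
    scores = paddle.matmul(q_reps, p_reps, transpose_y=True) / temperature
    # The positive for query i sits at index i * group_size in p_reps.
    labels = paddle.arange(0, q_reps.shape[0], dtype="int64") * group_size
    return F.cross_entropy(scores, labels)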
