Skip to content

Commit 63ba9d0

Browse files
committed
update
1 parent d815fce commit 63ba9d0

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

llm/config/qwen/emb_argument.json

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"dataset_name_or_path": "./data",
44
"output_dir": "./checkpoints/sft_ckpts",
55
"per_device_train_batch_size": 1,
6-
"gradient_accumulation_steps": 128,
6+
"gradient_accumulation_steps": 4,
77
"per_device_eval_batch_size": 1,
88
"eval_accumulation_steps": 1,
99
"max_steps": 2000,
@@ -15,7 +15,7 @@
1515
"max_query_len": 1024,
1616
"max_passage_len": 2048,
1717
"group_size": 4,
18-
"bp16": true,
18+
"bf16": true,
1919
"fp16_opt_level": "O2",
2020
"do_train": true,
2121
"do_eval": false,
@@ -30,5 +30,7 @@
3030
"sharding": "stage2",
3131
"zero_padding": false,
3232
"unified_checkpoint": false,
33-
"use_flash_attention": false
33+
"use_flash_attention": true,
34+
"amp_custom_black_list": "elementwise_div",
35+
"release_grads": true,
3436
}

llm/utils/argument.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from dataclasses import dataclass, field
15+
from typing import List, Optional
1516

1617

1718
@dataclass
@@ -83,3 +84,7 @@ class EmbeddingArgument:
8384
default=True,
8485
metadata={"help": "Whether to share the negatives across all GPUs."},
8586
)
87+
embedding_matryoshka_dims: Optional[List[int]] = field(
88+
default=None,
89+
metadata={"help": "The dims for matryoshka training."},
90+
)

0 commit comments

Comments
 (0)