Skip to content

Commit 8398e3a

Browse files
authored
Add special tokens for keys in prompt (databrickslabs#39)
This configures the tokenizer so that strings `### Instruction:`, `### Response:`, and `### End` are all represented by a single token ID each. This simplifies the logic for finding the response. `generate` is also now configured to stop generation at `### End`, making generation faster. The default `max_new_tokens` has been doubled to 256. The notebook now has widgets `local_training_root` and `dbfs_output_root` for configuring where data is stored locally and in DBFS. By default, if `local_training_root` is not provided, it now uses `/local_disk0` if it exists and otherwise defaults to a subdirectory of the home directory, as before.
1 parent e8c5175 commit 8398e3a

File tree

4 files changed

+120
-49
lines changed

4 files changed

+120
-49
lines changed

train_dolly.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
from training.trainer import load_training_dataset, load_tokenizer
6060

6161
dbutils.widgets.text("num_gpus", "", "num_gpus")
62+
dbutils.widgets.text("local_training_root", "", "local_training_root")
63+
dbutils.widgets.text("dbfs_output_root", "", "dbfs_output_root")
6264

6365
# COMMAND ----------
6466

@@ -75,12 +77,27 @@
7577
root_path = os.getcwd()
7678
deepspeed_config = os.path.join(root_path, "config/ds_z3_bf16_config.json")
7779

78-
local_training_root = os.path.join(os.path.expanduser('~'), "dolly_training")
80+
dolly_training_dir_name = "dolly_training"
81+
82+
# Use the local training root path if it was provided. Otherwise try to find a sensible default.
83+
local_training_root = dbutils.widgets.get("local_training_root")
84+
if not local_training_root:
85+
# Use preferred path when working in a Databricks cluster if it exists.
86+
if os.path.exists("/local_disk0"):
87+
local_training_root = os.path.join("/local_disk0", dolly_training_dir_name)
88+
# Otherwise use the home directory.
89+
else:
90+
local_training_root = os.path.join(os.path.expanduser('~'), dolly_training_dir_name)
91+
92+
dbfs_output_root = dbutils.widgets.get("dbfs_output_root")
93+
if not dbfs_output_root:
94+
dbfs_output_root = f"/dbfs/{dolly_training_dir_name}"
7995

8096
os.makedirs(local_training_root, exist_ok=True)
97+
os.makedirs(dbfs_output_root, exist_ok=True)
8198

8299
local_output_dir = os.path.join(local_training_root, checkpoint_dir_name)
83-
dbfs_output_dir = os.path.join("/dbfs/dolly_training", checkpoint_dir_name)
100+
dbfs_output_dir = os.path.join(dbfs_output_root, checkpoint_dir_name)
84101

85102
num_gpus_flag = ""
86103
num_gpus = dbutils.widgets.get("num_gpus")

training/consts.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
DEFAULT_TRAINING_DATASET = "tatsu-lab/alpaca"
2+
DEFAULT_INPUT_MODEL = "EleutherAI/gpt-j-6B"
3+
RESPONSE_KEY = "### Response:"
4+
END_KEY = "### End"
5+
INSTRUCTION_KEY = "### Instruction:"
6+
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
7+
DEFAULT_SEED = 42
8+
9+
# The format of the instruction the model has been trained on.
10+
PROMPT_FORMAT = """%s
11+
12+
%s
13+
{instruction}
14+
15+
%s""" % (
16+
"Below is an instruction that describes a task. Write a response that appropriately completes the request.",
17+
INSTRUCTION_KEY,
18+
RESPONSE_KEY_NL,
19+
)

training/generate.py

Lines changed: 56 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,17 @@
11
import logging
2-
import re
32
from typing import Tuple
43

4+
import numpy as np
55
from transformers import (
66
AutoModelForCausalLM,
77
AutoTokenizer,
88
PreTrainedModel,
99
PreTrainedTokenizer,
1010
)
1111

12-
logger = logging.getLogger(__name__)
13-
14-
# The format of the instruction the model has been trained on.
15-
INTRO = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
16-
INSTRUCTION_FORMAT = """{intro}
17-
18-
### Instruction:
19-
{instruction}
12+
from .consts import END_KEY, PROMPT_FORMAT, RESPONSE_KEY
2013

21-
### Response:
22-
"""
14+
logger = logging.getLogger(__name__)
2315

2416

2517
def load_model_tokenizer_for_generate(
@@ -40,13 +32,35 @@ def load_model_tokenizer_for_generate(
4032
return model, tokenizer
4133

4234

35+
def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int:
36+
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
37+
38+
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
39+
treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
40+
41+
Args:
42+
tokenizer (PreTrainedTokenizer): the tokenizer
43+
key (str): the key to convert to a single token
44+
45+
Raises:
46+
RuntimeError: if more than one ID was generated
47+
48+
Returns:
49+
int: the token ID for the given key
50+
"""
51+
token_ids = tokenizer.encode(key)
52+
if len(token_ids) > 1:
53+
raise RuntimeError(f"Expected only a single token for '{key}' but found {token_ids}")
54+
return token_ids[0]
55+
56+
4357
def generate_response(
4458
instruction: str,
4559
*,
4660
model: PreTrainedModel,
4761
tokenizer: PreTrainedTokenizer,
4862
do_sample: bool = True,
49-
max_new_tokens: int = 128,
63+
max_new_tokens: int = 256,
5064
top_p: float = 0.92,
5165
top_k: int = 0,
5266
**kwargs,
@@ -68,34 +82,45 @@ def generate_response(
6882
Returns:
6983
str: the generated response
7084
"""
71-
input_ids = tokenizer(
72-
INSTRUCTION_FORMAT.format(intro=INTRO, instruction=instruction), return_tensors="pt"
73-
).input_ids.to("cuda")
85+
input_ids = tokenizer(PROMPT_FORMAT.format(instruction=instruction), return_tensors="pt").input_ids.to("cuda")
86+
87+
response_key_token_id = get_special_token_id(tokenizer, RESPONSE_KEY)
88+
end_key_token_id = get_special_token_id(tokenizer, END_KEY)
7489

7590
gen_tokens = model.generate(
7691
input_ids,
7792
pad_token_id=tokenizer.pad_token_id,
93+
# Ensure generation stops once it generates "### End"
94+
eos_token_id=end_key_token_id,
7895
do_sample=do_sample,
7996
max_new_tokens=max_new_tokens,
8097
top_p=top_p,
8198
top_k=top_k,
8299
**kwargs,
83-
)
84-
decoded = tokenizer.batch_decode(gen_tokens)[0]
100+
)[0].cpu()
85101

86-
# The response appears after "### Response:". The model has been trained to append "### End" at the end.
87-
m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", decoded, flags=re.DOTALL)
102+
# The response will be set to this variable if we can identify it.
103+
decoded = None
88104

89-
response = None
90-
if m:
91-
response = m.group(1).strip()
105+
# Find where "### Response:" is first found in the generated tokens. Considering this is part of the prompt,
106+
# we should definitely find it. We will return the tokens found after this token.
107+
response_pos = None
108+
response_positions = np.where(gen_tokens == response_key_token_id)[0]
109+
if len(response_positions) == 0:
110+
logger.warn(f"Could not find response key {response_key_token_id} in: {gen_tokens}")
92111
else:
93-
# The model might not generate the "### End" sequence before reaching the max tokens. In this case, return
94-
# everything after "### Response:".
95-
m = re.search(r"#+\s*Response:\s*(.+)", decoded, flags=re.DOTALL)
96-
if m:
97-
response = m.group(1).strip()
98-
else:
99-
logger.warn(f"Failed to find response in:\n{decoded}")
100-
101-
return response
112+
response_pos = response_positions[0]
113+
114+
if response_pos:
115+
# Next find where "### End" is located. The model has been trained to end its responses with this sequence
116+
# (or actually, the token ID it maps to, since it is a special token). We may not find this token, as the
117+
# response could be truncated. If we don't find it then just return everything to the end. Note that
118+
even though we set eos_token_id, we still see this token at the end.
119+
end_pos = None
120+
end_positions = np.where(gen_tokens == end_key_token_id)[0]
121+
if len(end_positions) > 0:
122+
end_pos = end_positions[0]
123+
124+
decoded = tokenizer.decode(gen_tokens[response_pos + 1 : end_pos]).strip()
125+
126+
return decoded

training/trainer.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,26 @@
2929
set_seed,
3030
)
3131

32-
logger = logging.getLogger(__name__)
32+
from .consts import (
33+
DEFAULT_INPUT_MODEL,
34+
DEFAULT_SEED,
35+
DEFAULT_TRAINING_DATASET,
36+
END_KEY,
37+
INSTRUCTION_KEY,
38+
RESPONSE_KEY,
39+
RESPONSE_KEY_NL,
40+
)
3341

34-
DEFAULT_TRAINING_DATASET = "tatsu-lab/alpaca"
35-
DEFAULT_INPUT_MODEL = "EleutherAI/gpt-j-6B"
36-
RESPONSE_KEY = "### Response:\n"
37-
DEFAULT_SEED = 42
38-
MAX_LENGTH = 1024
42+
logger = logging.getLogger(__name__)
3943

4044

4145
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
4246
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
4347
batch = super().torch_call(examples)
4448

45-
response_token_ids = self.tokenizer.encode(RESPONSE_KEY)
49+
# The prompt ends with the response key plus a newline. We encode this and then try to find it in the
50+
# sequence of tokens.
51+
response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)
4652

4753
labels = batch["labels"].clone()
4854

@@ -67,7 +73,7 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D
6773
return batch
6874

6975

70-
def preprocess_batch(batch: Dict[str, List], tokenizer: AutoTokenizer, max_length: int = MAX_LENGTH) -> dict:
76+
def preprocess_batch(batch: Dict[str, List], tokenizer: AutoTokenizer, max_length: int) -> dict:
7177
return tokenizer(
7278
batch["text"],
7379
max_length=max_length,
@@ -81,10 +87,10 @@ def load_training_dataset(training_data_id: str = DEFAULT_TRAINING_DATASET, spli
8187
logger.info("Found %d rows", dataset.num_rows)
8288

8389
# Remove empty responses
84-
dataset = dataset.filter(lambda rec: not rec["text"].strip().endswith("### Response:"))
90+
dataset = dataset.filter(lambda rec: not rec["text"].strip().endswith(RESPONSE_KEY))
8591

8692
def _func(rec):
87-
rec["text"] += "\n\n### End"
93+
rec["text"] += f"\n\n{END_KEY}"
8894
return rec
8995

9096
dataset = dataset.map(_func)
@@ -114,15 +120,18 @@ def get_model_tokenizer(
114120
) -> Tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
115121
tokenizer = load_tokenizer(pretrained_model_name_or_path)
116122
model = load_model(pretrained_model_name_or_path, gradient_checkpointing=gradient_checkpointing)
123+
tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY]})
124+
model.resize_token_embeddings(len(tokenizer))
125+
117126
return model, tokenizer
118127

119128

120-
def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int = MAX_LENGTH, seed=DEFAULT_SEED) -> Dataset:
129+
def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed=DEFAULT_SEED) -> Dataset:
121130
"""Loads the training dataset and tokenizes it so it is ready for training.
122131
123132
Args:
124133
tokenizer (AutoTokenizer): Tokenizer tied to the model.
125-
max_length (int, optional): Maximum number of tokens to emit from tokenizer. Defaults to MAX_INPUT_LENGTH.
134+
max_length (int): Maximum number of tokens to emit from tokenizer.
126135
127136
Returns:
128137
Dataset: HuggingFace dataset
@@ -164,7 +173,10 @@ def train(
164173

165174
model, tokenizer = get_model_tokenizer(gradient_checkpointing=gradient_checkpointing)
166175

167-
processed_dataset = preprocess_dataset(tokenizer=tokenizer, seed=seed)
176+
# Use the same max length that the model supports
177+
max_length: int = model.config.n_positions
178+
179+
processed_dataset = preprocess_dataset(tokenizer=tokenizer, max_length=max_length, seed=seed)
168180

169181
split_dataset = processed_dataset.train_test_split(test_size=test_size, seed=seed)
170182

@@ -225,9 +237,7 @@ def train(
225237

226238

227239
@click.command()
228-
@click.option(
229-
"--local-output-dir", type=str, help="Write directly to this local path", required=True
230-
)
240+
@click.option("--local-output-dir", type=str, help="Write directly to this local path", required=True)
231241
@click.option("--dbfs-output-dir", type=str, help="Sync data to this path on DBFS")
232242
@click.option("--epochs", type=int, default=3, help="Number of epochs to train for.")
233243
@click.option("--per-device-train-batch-size", type=int, default=8, help="Batch size to use for training.")

0 commit comments

Comments
 (0)