Skip to content

Commit e950ab3

Browse files
authored
Improve tokenization to work with other tokenizers (databrickslabs#40)
This addresses databrickslabs#4. The tokenizer used by bloom appears to combine the newline after `## Response:` with the following character, which does not happen with GPT-J 6b. This results in the tokens for `### Response:\n` being different when appearing in the text compared to when it is tokenized in isolation. My solution here is to change the key to `### Response:\n` so that this becomes a single token. The other fix is to try getting a different config setting for the max length, or fall back to 1024 if none can be found. I've tested this on [bloomz-7b1-mt](https://huggingface.co/bigscience/bloomz-7b1-mt) and it produces similar generation quality. It also still trains successfully using GPT-J 6B as the base model.
1 parent 8398e3a commit e950ab3

File tree

3 files changed

+18
-15
lines changed

3 files changed

+18
-15
lines changed

training/consts.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
DEFAULT_TRAINING_DATASET = "tatsu-lab/alpaca"
22
DEFAULT_INPUT_MODEL = "EleutherAI/gpt-j-6B"
3-
RESPONSE_KEY = "### Response:"
43
END_KEY = "### End"
54
INSTRUCTION_KEY = "### Instruction:"
6-
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
5+
RESPONSE_KEY_NL = f"### Response:\n"
76
DEFAULT_SEED = 42
87

98
# The format of the instruction the model has been trained on.

training/generate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
PreTrainedTokenizer,
1010
)
1111

12-
from .consts import END_KEY, PROMPT_FORMAT, RESPONSE_KEY
12+
from .consts import END_KEY, PROMPT_FORMAT, RESPONSE_KEY_NL
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -84,7 +84,7 @@ def generate_response(
8484
"""
8585
input_ids = tokenizer(PROMPT_FORMAT.format(instruction=instruction), return_tensors="pt").input_ids.to("cuda")
8686

87-
response_key_token_id = get_special_token_id(tokenizer, RESPONSE_KEY)
87+
response_key_token_id = get_special_token_id(tokenizer, RESPONSE_KEY_NL)
8888
end_key_token_id = get_special_token_id(tokenizer, END_KEY)
8989

9090
gen_tokens = model.generate(

training/trainer.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
DEFAULT_TRAINING_DATASET,
3636
END_KEY,
3737
INSTRUCTION_KEY,
38-
RESPONSE_KEY,
3938
RESPONSE_KEY_NL,
4039
)
4140

@@ -47,7 +46,7 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D
4746
batch = super().torch_call(examples)
4847

4948
# The prompt ends with the response key plus a newline. We encode this and then try to find it in the
50-
# sequence of tokens.
49+
# sequence of tokens. This should just be a single token.
5150
response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)
5251

5352
labels = batch["labels"].clone()
@@ -56,14 +55,15 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D
5655

5756
response_token_ids_start_idx = None
5857
for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
59-
if np.array_equal(response_token_ids, batch["labels"][i, idx : idx + len(response_token_ids)]):
60-
response_token_ids_start_idx = idx
61-
break
58+
response_token_ids_start_idx = idx
59+
break
6260

6361
if response_token_ids_start_idx is None:
64-
raise RuntimeError("Could not find response key token IDs")
62+
raise RuntimeError(
63+
f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}'
64+
)
6565

66-
response_token_ids_end_idx = response_token_ids_start_idx + len(response_token_ids)
66+
response_token_ids_end_idx = response_token_ids_start_idx + 1
6767

6868
# Make pytorch loss function ignore all tokens up through the end of the response key
6969
labels[i, :response_token_ids_end_idx] = -100
@@ -87,7 +87,8 @@ def load_training_dataset(training_data_id: str = DEFAULT_TRAINING_DATASET, spli
8787
logger.info("Found %d rows", dataset.num_rows)
8888

8989
# Remove empty responses
90-
dataset = dataset.filter(lambda rec: not rec["text"].strip().endswith(RESPONSE_KEY))
90+
response_key_stripped = RESPONSE_KEY_NL.strip()
91+
dataset = dataset.filter(lambda rec: not rec["text"].strip().endswith(response_key_stripped))
9192

9293
def _func(rec):
9394
rec["text"] += f"\n\n{END_KEY}"
@@ -102,6 +103,7 @@ def load_tokenizer(pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL) ->
102103
logger.info(f"Loading tokenizer for {pretrained_model_name_or_path}")
103104
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
104105
tokenizer.pad_token = tokenizer.eos_token
106+
tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})
105107
return tokenizer
106108

107109

@@ -120,7 +122,6 @@ def get_model_tokenizer(
120122
) -> Tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
121123
tokenizer = load_tokenizer(pretrained_model_name_or_path)
122124
model = load_model(pretrained_model_name_or_path, gradient_checkpointing=gradient_checkpointing)
123-
tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY]})
124125
model.resize_token_embeddings(len(tokenizer))
125126

126127
return model, tokenizer
@@ -173,8 +174,11 @@ def train(
173174

174175
model, tokenizer = get_model_tokenizer(gradient_checkpointing=gradient_checkpointing)
175176

176-
# Use the same max length that the model supports
177-
max_length: int = model.config.n_positions
177+
# Use the same max length that the model supports. Try a couple different keys in case a different
178+
# model is used. The default model uses n_positions. If no config settings can be found just default
179+
# to 1024 as this is probably supported by most models.
180+
conf = model.config
181+
max_length: int = getattr(conf, "n_positions", getattr(conf, "seq_length", 1024))
178182

179183
processed_dataset = preprocess_dataset(tokenizer=tokenizer, max_length=max_length, seed=seed)
180184

0 commit comments

Comments
 (0)