Skip to content

Commit 3efdfe0

Browse files
committed
Added a base recipe for Llama3.1-405b experiments and a recipe for the Llama3.1-405b dataset.
1 parent 72fd5a8 commit 3efdfe0

File tree

4 files changed

+159
-1
lines changed

4 files changed

+159
-1
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import json
2+
3+
from transformers import AutoTokenizer
4+
5+
def get_accuracy_dict(accuracy_dict_full):
    """Filter a full accuracy report down to the metrics of interest.

    Parameters:
        accuracy_dict_full: mapping of metric name -> value, as produced by
            the MLPerf accuracy evaluation script.

    Returns:
        A new dict containing only the "rougeL", "exact_match" and
        "tokens_per_sample" entries that are present in the input; all other
        keys are dropped.
    """
    # Set membership + dict comprehension instead of a manual key loop with a
    # list-membership test; behavior is unchanged.
    wanted = {"rougeL", "exact_match", "tokens_per_sample"}
    return {k: v for k, v in accuracy_dict_full.items() if k in wanted}
11+
12+
def parse_tokens(
    tokenised_accuracy_log_path: str, output_log_path: str
):
    """Decode the hex-encoded token streams of an accuracy log into integers.

    Each log entry's "data" field is a hex string in which every 8 hex
    characters encode one 4-byte little-endian unsigned integer.

    Parameters:
        tokenised_accuracy_log_path: path to the JSON accuracy log (a list of
            objects, each with a hex "data" field).
        output_log_path: path where the decoded list-of-integer-lists is
            written as indented JSON.

    Returns:
        output_log_path, for convenient chaining.
    """
    with open(tokenised_accuracy_log_path) as log_file:
        entries = json.load(log_file)

    decoded = []
    for entry in entries:
        blob = entry["data"]
        # Walk the hex string in 8-character (4-byte) groups, decoding each
        # group as a little-endian unsigned integer.
        token_values = [
            int.from_bytes(bytes.fromhex(blob[pos : pos + 8]), byteorder="little")
            for pos in range(0, len(blob), 8)
        ]
        decoded.append(token_values)

    with open(output_log_path, "w") as out_file:
        json.dump(decoded, out_file, indent=2)
    return output_log_path
30+
31+
def detokenise(
    checkpoint_path: str, tokenised_accuracy_log_path: str, output_log_path: str
):
    """Decode a tokenised accuracy log back into text with a HF tokeniser.

    Parameters:
        checkpoint_path: location understood by
            transformers.AutoTokenizer.from_pretrained.
        tokenised_accuracy_log_path: path to the JSON accuracy log; each
            entry carries "seq_id", "qsl_idx", "token_count" and a hex
            "data" field (8 hex characters per 4-byte little-endian token).
        output_log_path: path where the detokenised JSON log is written.

    Returns:
        output_log_path, for convenient chaining.
    """
    tokeniser = AutoTokenizer.from_pretrained(checkpoint_path)

    with open(tokenised_accuracy_log_path) as log_file:
        entries = json.load(log_file)

    detokenised = []
    for entry in entries:
        blob = entry["data"]
        # Decode each 8-hex-character (4-byte) group as a little-endian
        # unsigned integer token id, then let the tokeniser render text.
        token_ids = [
            int.from_bytes(bytes.fromhex(blob[pos : pos + 8]), byteorder="little")
            for pos in range(0, len(blob), 8)
        ]
        detokenised.append({
            "seq_id" : entry["seq_id"],
            "qsl_idx" : entry["qsl_idx"],
            "data": tokeniser.decode(token_ids),
            "token_count" : entry["token_count"]
        })

    with open(output_log_path, "w") as out_file:
        json.dump(detokenised, out_file, indent=2)
    return output_log_path
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
{
2+
"_parent_entries": [ [ "^", "byname", "base_loadgen_experiment" ] ],
3+
4+
"transformers_query": [ "python_package", "package_name=transformers", ["desired_python_version", ["^", "kernel_python_major_dot_minor"]] ],
5+
6+
"_BEFORE_CODE_LOADING": [ "^^", "execute", [[
7+
[ "get_kernel" ],
8+
[ "byquery", [[ "^^", "get", "transformers_query" ]] ],
9+
[ "use" ]
10+
]] ],
11+
12+
"desired_python_version": "3.8",
13+
14+
"mlperf_inference_git_entry": [ "^", "byquery", "git_repo,repo_name=mlperf_inference_git" ],
15+
16+
"abs_script_path": [ "^^", "execute", [[
17+
[ "get", "mlperf_inference_git_entry" ],
18+
[ "get_path_of", "llama3_1_accuracy_script" ]
19+
]] ],
20+
21+
"accuracy_log_path": ["^^", "get_path", "mlperf_log_accuracy.json"],
22+
23+
"dataset_name": "llrg",
24+
"dataset_query": [ "downloaded", [ "^^", "substitute", "dataset_name=#{dataset_name}#" ]],
25+
"dataset_entry": [ "^", "byquery", [[ "^^", "get", "dataset_query" ]], {}, ["dataset_query"] ],
26+
27+
"dataset_path": [ "^^", "execute", [[
28+
[ "get", "dataset_entry" ],
29+
[ "get_path" ],
30+
[ "__add__", "/mlperf_llama3.1_405b_dataset_8313_processed_fp16_eval.pkl" ]
31+
]] ],
32+
33+
"model_family": "llama3_1",
34+
"model_variant": "405b",
35+
"variant": [ "^^", "get", "model_variant" ],
36+
"checkpoint_path_query": [ "^^", "substitute", "downloaded,hf_tokeniser,model_family=#{model_family}#,variant=#{variant}#" ],
37+
"checkpoint_path": [ "^^", "execute", [[
38+
[ "get_kernel" ],
39+
[ "byquery", [[ "^^", "get", "checkpoint_path_query" ]] ],
40+
[ "get_path" ]
41+
]] ],
42+
43+
"accuracy_log_dtype": "int32",
44+
45+
"extract_accuracy_report": [ "^^", "execute", [[
46+
[ "get_kernel" ],
47+
[ "byname", "python_script" ],
48+
[ "run", [], {
49+
"python_deps": [
50+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=protobuf" ],
51+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=torch" ],
52+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=transformers" ],
53+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=nltk" ],
54+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=rouge_score" ],
55+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=sentencepiece" ],
56+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=pillow" ],
57+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=evaluate" ]
58+
],
59+
"abs_script_path": ["^^", "get", "abs_script_path"],
60+
"script_extra_params": [ "^^", "substitute", "--mlperf-accuracy-file #{accuracy_log_path}# --dataset-file #{dataset_path}# --dtype #{accuracy_log_dtype}# --checkpoint-path #{checkpoint_path}#" ],
61+
"desired_python_version": ["^", "kernel_python_major_dot_minor"],
62+
"capture_output": true
63+
} ],
64+
0,
65+
[ "func", [ "ufun.rematch", "(\\{.*\\})" ] ],
66+
0,
67+
[ "denumpify_dict" ],
68+
0,
69+
[ "func", "str" ]
70+
]], {} ],
71+
72+
"accuracy_dict_full": [ "^^", "execute", [[
73+
["get", "accuracy_report" ],
74+
0,
75+
[ "func", "eval" ]
76+
]], {} ],
77+
"accuracy_dict": [ "^^", "get_accuracy_dict" ],
78+
"rougeL": [ "^^" , "dig","accuracy_dict.rougeL" ],
79+
"exact_match": [ "^^" , "dig","accuracy_dict.exact_match" ],
80+
"tokens_per_sample": [ "^^" , "dig","accuracy_dict.tokens_per_sample" ],
81+
82+
"accuracy_range_dict": { "rougeL": [ 21.449934, null ], "exact_match": [ 89.232165, null ], "tokens_per_sample": [ 616.212, null ] },
83+
84+
"tokenised_accuracy_log_path": [ "^^", "get_path", "mlperf_log_accuracy.json" ],
85+
"output_log_path": [ "^^", "get_path", "detokenised_mlperf_log.json" ],
86+
87+
"detokenised_log": [ "^^", "detokenise" ]
88+
}

data_axs.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@
6464
"model_training_llama2_recipe": "model_training_llama2_recipe",
6565
"dataset_scrolls_gov_report_8k_recipe": "dataset_scrolls_gov_report_8k_recipe",
6666
"rclone_mlc_llama2_config": "rclone_mlc_llama2_config",
67-
"explore_recipe": "explore_recipe"
67+
"explore_recipe": "explore_recipe",
68+
"base_llama3_1_loadgen_experiment": "base_llama3_1_loadgen_experiment",
69+
"dataset_llrg_mlperf_recipe": "dataset_llrg_mlperf_recipe"
6870
},
6971
"repo_name": "axs2mlperf",
7072
"submodules": false
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"_producer_rules": [
3+
[ [ "downloaded", "dataset_name=llrg", "model_name=llama3_1", "variant=405b" ], [["get_kernel"],["byname","downloader"],["download"]], {
4+
"downloading_tool_query": "shell_tool,can_download_url_from_rclone",
5+
"url": "mlc_inference:mlcommons-inference-wg-public/llama3.1_405b",
6+
"downloading_tool_params": {
7+
"rclone_remote_name": "mlc_inference"
8+
},
9+
"newborn_entry_name": "downloaded_mlc_llrg",
10+
"file_path": "mlperf_llama3.1_405b_dataset_8313_processed_fp16_eval.pkl"
11+
}, [] ]
12+
]
13+
}

0 commit comments

Comments
 (0)