|
| 1 | +import os |
| 2 | +import unittest |
| 3 | +from types import SimpleNamespace |
| 4 | + |
| 5 | +import requests |
| 6 | + |
| 7 | +from sglang.srt.utils import kill_process_tree |
| 8 | +from sglang.test.ci.ci_register import register_amd_ci |
| 9 | +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k |
| 10 | +from sglang.test.send_one import BenchArgs, send_one_prompt |
| 11 | +from sglang.test.test_utils import ( |
| 12 | + DEFAULT_URL_FOR_TEST, |
| 13 | + CustomTestCase, |
| 14 | + is_in_ci, |
| 15 | + popen_launch_server, |
| 16 | + write_github_step_summary, |
| 17 | +) |
| 18 | + |
| 19 | +register_amd_ci(est_time=3600, suite="stage-c-test-large-8-gpu-amd-mi35x") |
| 20 | + |
| 21 | +KIMI_K2_MODEL_PATH = "moonshotai/Kimi-K2-Instruct-0905" |
| 22 | +SERVER_LAUNCH_TIMEOUT = 3600 |
| 23 | + |
| 24 | + |
| 25 | +class TestKimiK2Instruct0905(CustomTestCase): |
| 26 | + @classmethod |
| 27 | + def setUpClass(cls): |
| 28 | + cls.model = KIMI_K2_MODEL_PATH |
| 29 | + cls.base_url = DEFAULT_URL_FOR_TEST |
| 30 | + other_args = [ |
| 31 | + "--tp", |
| 32 | + "8", |
| 33 | + "--decode-attention-backend", |
| 34 | + "triton", |
| 35 | + "--prefill-attention-backend", |
| 36 | + "aiter", |
| 37 | + "--trust-remote-code", |
| 38 | + "--model-loader-extra-config", |
| 39 | + '{"enable_multithread_load": true}', |
| 40 | + ] |
| 41 | + env = os.environ.copy() |
| 42 | + env["SGLANG_USE_AITER"] = "1" |
| 43 | + env["SGLANG_ROCM_FUSED_DECODE_MLA"] = "0" |
| 44 | + cls.process = popen_launch_server( |
| 45 | + cls.model, |
| 46 | + cls.base_url, |
| 47 | + timeout=SERVER_LAUNCH_TIMEOUT, |
| 48 | + other_args=other_args, |
| 49 | + env=env, |
| 50 | + ) |
| 51 | + |
| 52 | + @classmethod |
| 53 | + def tearDownClass(cls): |
| 54 | + kill_process_tree(cls.process.pid) |
| 55 | + |
| 56 | + def test_a_gsm8k( |
| 57 | + self, |
| 58 | + ): # Append an "a" to make this test run first (alphabetically) to warm up the server |
| 59 | + requests.get(self.base_url + "/flush_cache") |
| 60 | + |
| 61 | + args = SimpleNamespace( |
| 62 | + num_shots=8, |
| 63 | + data_path=None, |
| 64 | + num_questions=1319, |
| 65 | + parallel=1319, |
| 66 | + max_new_tokens=512, |
| 67 | + host="http://127.0.0.1", |
| 68 | + port=int(self.base_url.split(":")[-1]), |
| 69 | + ) |
| 70 | + metrics = run_eval_few_shot_gsm8k(args) |
| 71 | + print(f"{metrics=}") |
| 72 | + |
| 73 | + if is_in_ci(): |
| 74 | + write_github_step_summary( |
| 75 | + f"### test_gsm8k (Kimi-K2-Instruct-0905)\n" |
| 76 | + f'{metrics["accuracy"]=:.3f}\n' |
| 77 | + ) |
| 78 | + self.assertGreater(metrics["accuracy"], 0.94) |
| 79 | + |
| 80 | + def test_bs_1_speed(self): |
| 81 | + args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) |
| 82 | + _, speed = send_one_prompt(args) |
| 83 | + |
| 84 | + print(f"{speed=:.2f}") |
| 85 | + |
| 86 | + if is_in_ci(): |
| 87 | + write_github_step_summary( |
| 88 | + f"### test_bs_1_speed (Kimi-K2-Instruct-0905)\n" |
| 89 | + f"{speed=:.2f} token/s\n" |
| 90 | + ) |
| 91 | + self.assertGreater(speed, 45) |
| 92 | + |
| 93 | + |
| 94 | +if __name__ == "__main__": |
| 95 | + unittest.main() |
0 commit comments