|
| 1 | +# Copyright 2023-2025 SGLang Team |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# |
| 6 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +# |
| 8 | +# Unless required by applicable law or agreed to in writing, software |
| 9 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 10 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 11 | +# See the License for the specific language governing permissions and |
| 12 | +# limitations under the License. |
| 13 | +# ============================================================================== |
| 14 | + |
| 15 | +""" |
| 16 | +Test LoRA on models with tied lm_head (tie_word_embeddings=True). |
| 17 | +
|
| 18 | +When tie_word_embeddings=True, lm_head shares the same weight tensor as |
| 19 | +embed_tokens. PyTorch's named_modules() deduplicates by object identity, |
| 20 | +so lm_head won't appear as a separate module. This test validates that |
| 21 | +SGLang correctly handles this case by untying lm_head before LoRA wrapping. |
| 22 | +
|
| 23 | +The test: |
| 24 | +1. Programmatically creates a LoRA adapter with lm_head in target_modules |
| 25 | + using PEFT on a model with tie_word_embeddings=True (Qwen/Qwen2.5-0.5B). |
| 26 | +2. Compares logprobs between HuggingFace+PEFT and SGLang to ensure numerical |
| 27 | + consistency. This implicitly verifies no NaN values are produced and that |
| 28 | + LoRA is actually being applied (since HF+PEFT is the trusted reference). |
| 29 | +""" |
| 30 | + |
| 31 | +import multiprocessing as mp |
| 32 | +import os |
| 33 | +import shutil |
| 34 | +import tempfile |
| 35 | +import unittest |
| 36 | + |
| 37 | +import torch |
| 38 | + |
# PEFT is an optional test dependency: import it if available, otherwise
# install it on the fly so CI images without it can still run this test.
try:
    from peft import LoraConfig, get_peft_model
except ImportError:
    import subprocess
    import sys

    # Invoke pip via the *current* interpreter (sys.executable -m pip) so the
    # package lands in the same environment that is running this test, rather
    # than whatever `pip` happens to be first on PATH.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "peft", "--no-deps"]
    )
    from peft import LoraConfig, get_peft_model
| 46 | + |
| 47 | +from transformers import AutoModelForCausalLM |
| 48 | + |
| 49 | +from sglang.test.ci.ci_register import register_cuda_ci |
| 50 | +from sglang.test.runners import HFRunner, SRTRunner |
| 51 | +from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, CustomTestCase |
| 52 | + |
# Register this file with the CI scheduler: nightly single-GPU suite,
# estimated runtime ~120s.
register_cuda_ci(est_time=120, suite="nightly-1-gpu", nightly=True)

# Use a small model with tie_word_embeddings=True
BASE_MODEL = "Qwen/Qwen2.5-0.5B"

# Short generation prompts; two are enough to exercise batched LoRA paths.
TEST_PROMPTS = [
    "AI is a field of computer science focused on",
    "The capital of France is",
]

# Decode length per prompt.
MAX_NEW_TOKENS = 16
# Max allowed |SGLang - HF| logprob difference (fp16 tolerance).
LOGPROB_THRESHOLD = 2e-1
| 65 | + |
| 66 | + |
def create_lora_adapter_with_lm_head(base_model_name: str, output_dir: str):
    """
    Programmatically create a LoRA adapter that targets lm_head,
    using a model with tie_word_embeddings=True.

    The adapter uses randomly initialized LoRA weights (no training).
    This is sufficient to test that:
    - SGLang can load the adapter without errors
    - lm_head LoRA is applied (output differs from base model)
    - Logprobs match between HF and SGLang

    Args:
        base_model_name: HF hub id of a base model whose config has
            tie_word_embeddings=True (asserted below).
        output_dir: Directory the PEFT adapter is saved into
            (adapter_model.safetensors + adapter_config.json).
    """
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16,
        device_map="cpu",
    )

    # Verify the model actually has tied embeddings; otherwise this test
    # would not exercise the untying code path at all.
    assert (
        model.config.tie_word_embeddings
    ), f"Expected tie_word_embeddings=True for {base_model_name}"

    # Only target lm_head to isolate the test to the tied-embedding scenario.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["lm_head"],
        lora_dropout=0,
        bias="none",
        task_type="CAUSAL_LM",
    )

    peft_model = get_peft_model(model, lora_config)

    # PEFT initializes lora_B to zeros by default, which makes the adapter
    # produce identical output to the base model. Initialize lora_B with
    # non-zero random weights so the adapter has a visible effect.
    with torch.no_grad():
        for name, param in peft_model.named_parameters():
            if "lora_B" in name:
                torch.nn.init.normal_(param, mean=0.0, std=0.02)

    peft_model.save_pretrained(output_dir)

    # Verify the saved adapter contains lm_head keys.
    from safetensors import safe_open

    safetensors_path = os.path.join(output_dir, "adapter_model.safetensors")
    # Use safe_open as a context manager so the file handle is closed
    # deterministically (the previous version leaked it).
    with safe_open(safetensors_path, framework="pt") as f:
        lm_head_keys = [k for k in f.keys() if "lm_head" in k]
        all_keys = sorted(f.keys())
    assert (
        len(lm_head_keys) > 0
    ), f"Expected lm_head LoRA weights in adapter, got keys: {all_keys}"

    print(f"Created LoRA adapter at {output_dir}")
    print(f"  lm_head keys: {lm_head_keys}")

    # Clean up the model to free memory
    del peft_model, model
    torch.cuda.empty_cache()
| 127 | + |
| 128 | + |
class TestLoRATiedLMHead(CustomTestCase):
    """
    Test that LoRA works correctly on models with tied lm_head.
    """

    # Temp directory holding the generated adapter; created in setUpClass.
    _adapter_dir = None

    @classmethod
    def setUpClass(cls):
        """Create a temporary LoRA adapter with lm_head targeting."""
        super().setUpClass()
        cls._adapter_dir = tempfile.mkdtemp(prefix="sglang_test_lora_tied_lm_head_")
        create_lora_adapter_with_lm_head(BASE_MODEL, cls._adapter_dir)

    @classmethod
    def tearDownClass(cls):
        """Clean up the temporary adapter directory."""
        if cls._adapter_dir and os.path.exists(cls._adapter_dir):
            shutil.rmtree(cls._adapter_dir)
        super().tearDownClass()

    def _assert_logprobs_close(self, srt_logprob_lists, hf_logprob_lists, stage: str):
        """Assert the per-prompt max |SGLang - HF| logprob diff is under threshold.

        Args:
            srt_logprob_lists: Per-prompt top-logprob lists from SGLang.
            hf_logprob_lists: Matching per-prompt lists from HF+PEFT.
            stage: "prefill" or "decode" (used only in messages).
        """
        for i, (srt_vals, hf_vals) in enumerate(
            zip(srt_logprob_lists, hf_logprob_lists)
        ):
            srt_logprobs = torch.tensor(srt_vals)
            hf_logprobs = torch.tensor(hf_vals)
            max_diff = torch.max(torch.abs(srt_logprobs - hf_logprobs)).item()
            print(f"Prompt {i} {stage} logprob max_diff (SGLang vs HF): {max_diff:.6e}")
            self.assertLess(
                max_diff,
                LOGPROB_THRESHOLD,
                f"Prompt {i}: {stage} logprob diff {max_diff:.6e} "
                f"exceeds threshold {LOGPROB_THRESHOLD:.0e}",
            )

    def test_tied_lm_head_lora_hf_sgl_logprob_match(self):
        """
        Compare logprobs between HuggingFace+PEFT and SGLang+LoRA
        for a tied lm_head adapter, ensuring numerical consistency.
        """
        prompts = TEST_PROMPTS[:2]

        # Run SGLang with LoRA
        with SRTRunner(
            BASE_MODEL,
            torch_dtype=torch.float16,
            model_type="generation",
            lora_paths=[self._adapter_dir],
            max_loras_per_batch=1,
            lora_backend="triton",
            lora_target_modules=["lm_head"],
            disable_cuda_graph=True,
            disable_radix_cache=True,
            mem_fraction_static=0.80,
            port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(
                prompts,
                max_new_tokens=MAX_NEW_TOKENS,
                lora_paths=[self._adapter_dir] * len(prompts),
            )

        torch.cuda.empty_cache()

        # Run HuggingFace with LoRA (via PEFT)
        with HFRunner(
            BASE_MODEL,
            torch_dtype=torch.float16,
            model_type="generation",
        ) as hf_runner:
            hf_outputs = hf_runner.forward(
                prompts,
                max_new_tokens=MAX_NEW_TOKENS,
                lora_paths=[self._adapter_dir] * len(prompts),
            )

        # Compare prefill and decode logprobs with the same tolerance; the
        # two loops previously duplicated here are folded into one helper.
        self._assert_logprobs_close(
            srt_outputs.top_input_logprobs, hf_outputs.top_input_logprobs, "prefill"
        )
        self._assert_logprobs_close(
            srt_outputs.top_output_logprobs, hf_outputs.top_output_logprobs, "decode"
        )
| 216 | + |
| 217 | + |
| 218 | +if __name__ == "__main__": |
| 219 | + try: |
| 220 | + mp.set_start_method("spawn") |
| 221 | + except RuntimeError: |
| 222 | + pass |
| 223 | + |
| 224 | + unittest.main(warnings="ignore") |