Skip to content

Commit b382890

Browse files
committed
[Fix] Add lora tied lm head support (for Qwen2.5, Gemma, etc model need) (sgl-project#18634)
1 parent e326729 commit b382890

File tree

5 files changed

+313
-4
lines changed

5 files changed

+313
-4
lines changed

python/sglang/srt/lora/lora.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,24 @@ def _process_weight(self, name: str, loaded_weight: torch.Tensor):
102102
self.config.target_modules
103103
)
104104

105+
# Remap PEFT "unembed_tokens" key to "lm_head" so the weight is
106+
# recognized and loaded into the correct buffer.
107+
if "unembed_tokens" in name:
108+
name = name.replace("unembed_tokens", "lm_head")
109+
105110
layer_id = get_layer_id(name)
106111
if layer_id is not None:
107112
self.layers[layer_id].weights[name] = loaded_weight.cpu()
108113
elif "embed_tokens" in name or "lm_head" in name:
109-
# Check if this module is declared in target_modules before loading
114+
# Check if this module is declared in target_modules before loading.
115+
# When normalized_target_modules is {"all"} (e.g. target_modules was
116+
# "all-linear"), we allow loading since the server-level
117+
# --lora-target-modules will govern which modules are active.
110118
module_name = "embed_tokens" if "embed_tokens" in name else "lm_head"
111-
if module_name in normalized_target_modules:
119+
if (
120+
"all" in normalized_target_modules
121+
or module_name in normalized_target_modules
122+
):
112123
self.embedding_layers[name] = loaded_weight.cpu()
113124
else:
114125
logger.debug(

python/sglang/srt/lora/lora_manager.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,33 @@ def init_lora_shapes(
387387
)
388388

389389
for lora_id, config in self.configs.items():
390+
# Handle PEFT shorthand strings like "all-linear" or "all".
391+
# These cannot be resolved to concrete module names without
392+
# inspecting the base model, so we require the user to specify
393+
# --lora-target-modules explicitly when such shorthands are used.
394+
if isinstance(config.target_modules, str):
395+
if config.target_modules in ("all-linear", "all"):
396+
if target_modules is not None:
397+
# CLI --lora-target-modules already provided; skip
398+
# per-adapter inference for this adapter.
399+
continue
400+
else:
401+
lora_name = self.lora_refs[lora_id].lora_name
402+
raise ValueError(
403+
f"LoRA adapter '{lora_name}' uses "
404+
f"target_modules='{config.target_modules}' which cannot "
405+
"be resolved automatically. Please explicitly specify "
406+
"--lora-target-modules during server startup. You can "
407+
"specify 'all' to enable all supported module types."
408+
)
409+
else:
410+
raise ValueError(
411+
f"SGLang does not recognize target_modules="
412+
f"'{config.target_modules}'. Please use a list of module "
413+
"name suffixes in the adapter's PEFT config, or explicitly "
414+
"specify --lora-target-modules during server startup."
415+
)
416+
390417
if not isinstance(config.target_modules, list):
391418
raise ValueError(
392419
f"SGLang currently only supports inferring LoRA target modules when a list of "
@@ -541,6 +568,40 @@ def init_lora_modules(self):
541568
self.embed_tokens_module: Optional[BaseLayerWithLoRA] = None
542569
self.lm_head_module: Optional[BaseLayerWithLoRA] = None
543570

571+
# When tie_word_embeddings=True, lm_head is the same Python object as
572+
# embed_tokens. PyTorch's named_modules() deduplicates by object identity,
573+
# so lm_head will not appear as a separate entry in the scan below,
574+
# preventing LoRA from wrapping it. To fix this, we create a new
575+
# ParallelLMHead that shares the same base weight tensor (no extra GPU
576+
# memory) so that named_modules() yields it as an independent module.
577+
if "lm_head" in self.target_modules:
578+
lm_head = getattr(self.base_model, "lm_head", None)
579+
embed_tokens = None
580+
for name, mod in self.base_model.named_modules():
581+
if name.endswith("embed_tokens"):
582+
embed_tokens = mod
583+
break
584+
if (
585+
lm_head is not None
586+
and embed_tokens is not None
587+
and lm_head is embed_tokens
588+
):
589+
logger.info(
590+
"lm_head is tied with embed_tokens. Creating a separate "
591+
"ParallelLMHead that shares the base weight for LoRA support."
592+
)
593+
untied_lm_head = ParallelLMHead(
594+
num_embeddings=embed_tokens.org_vocab_size,
595+
embedding_dim=embed_tokens.embedding_dim,
596+
params_dtype=embed_tokens.weight.dtype,
597+
org_num_embeddings=embed_tokens.org_vocab_size,
598+
)
599+
# Share the base weight tensor — no additional GPU memory.
600+
untied_lm_head.weight = embed_tokens.weight
601+
# Replace the model attribute so named_modules() sees it
602+
# independently.
603+
self.base_model.lm_head = untied_lm_head
604+
544605
for module_name, module in self.base_model.named_modules():
545606
# TODO (lifuhuang): in the future, we should consider generalizing the
546607
# should_apply_lora function to support mapping by full module name instead

python/sglang/srt/lora/mem_pool.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ def _can_support(config: LoRAConfig) -> bool:
115115
if config.lora_added_tokens_size > self.lora_added_tokens_size:
116116
return False
117117
target_module_names = get_normalized_target_modules(config.target_modules)
118+
if "all" in target_module_names:
119+
return True
118120
return target_module_names.issubset(self.target_modules)
119121

120122
if isinstance(config, LoRAConfig):

python/sglang/srt/lora/utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass
22
from enum import Enum
3-
from typing import Iterable, Optional, Set, Tuple
3+
from typing import Iterable, Optional, Set, Tuple, Union
44

55
import torch
66

@@ -98,12 +98,22 @@ def get_hidden_dim(
9898

9999

100100
def get_normalized_target_modules(
101-
target_modules: Iterable[str],
101+
target_modules: Union[str, Iterable[str]],
102102
) -> set[str]:
103103
"""
104104
Mapping a list of target module name to names of the normalized LoRA weights.
105105
Handles both base module names (e.g., "gate_proj") and prefixed module names (e.g., "feed_forward.gate_proj").
106+
107+
Also handles PEFT shorthand strings like "all-linear" or "all" by returning
108+
{"all"} as a sentinel value (the caller should check for "all" and fall
109+
back to the CLI --lora-target-modules to determine the concrete module set).
106110
"""
111+
# Handle PEFT shorthand strings — these cannot be resolved to concrete
112+
# module names without inspecting the base model, so we return {"all"}
113+
# and let the caller fall back to the CLI --lora-target-modules.
114+
if isinstance(target_modules, str):
115+
return {"all"}
116+
107117
params_mapping = {
108118
"q_proj": "qkv_proj",
109119
"k_proj": "qkv_proj",
@@ -116,6 +126,7 @@ def get_normalized_target_modules(
116126
"word_embeddings": "embed_tokens",
117127
"lm_head": "lm_head",
118128
"output": "lm_head",
129+
"unembed_tokens": "lm_head",
119130
}
120131

121132
result = set()
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Test LoRA on models with tied lm_head (tie_word_embeddings=True).

When tie_word_embeddings=True, lm_head shares the same weight tensor as
embed_tokens. PyTorch's named_modules() deduplicates by object identity,
so lm_head won't appear as a separate module. This test validates that
SGLang correctly handles this case by untying lm_head before LoRA wrapping.

The test:
1. Programmatically creates a LoRA adapter with lm_head in target_modules
   using PEFT on a model with tie_word_embeddings=True (Qwen/Qwen2.5-0.5B).
2. Compares logprobs between HuggingFace+PEFT and SGLang to ensure numerical
   consistency. This implicitly verifies no NaN values are produced and that
   LoRA is actually being applied (since HF+PEFT is the trusted reference).
"""

import multiprocessing as mp
import os
import shutil
import subprocess
import sys
import tempfile
import unittest

import torch

try:
    from peft import LoraConfig, get_peft_model
except ImportError:
    # Install into the *current* interpreter's environment. Invoking a bare
    # "pip" executable may resolve to a different Python installation, so run
    # pip as a module of sys.executable instead.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "peft", "--no-deps"])
    from peft import LoraConfig, get_peft_model

from transformers import AutoModelForCausalLM

from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, CustomTestCase

register_cuda_ci(est_time=120, suite="nightly-1-gpu", nightly=True)

# Use a small model with tie_word_embeddings=True
BASE_MODEL = "Qwen/Qwen2.5-0.5B"

TEST_PROMPTS = [
    "AI is a field of computer science focused on",
    "The capital of France is",
]

MAX_NEW_TOKENS = 16
# Maximum allowed absolute difference between HF+PEFT and SGLang logprobs.
LOGPROB_THRESHOLD = 2e-1
67+
def create_lora_adapter_with_lm_head(base_model_name: str, output_dir: str):
    """
    Programmatically create a LoRA adapter that targets lm_head,
    using a model with tie_word_embeddings=True.

    The adapter uses randomly initialized LoRA weights (no training).
    This is sufficient to test that:
    - SGLang can load the adapter without errors
    - lm_head LoRA is applied (output differs from base model)
    - Logprobs match between HF and SGLang

    Args:
        base_model_name: HF hub id of a base model with tie_word_embeddings=True.
        output_dir: Directory where the PEFT adapter is saved.

    Raises:
        AssertionError: if the model is not tied, or the saved adapter
            contains no lm_head LoRA weights.
    """
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16,
        device_map="cpu",
    )

    # Verify the model actually has tied embeddings
    assert (
        model.config.tie_word_embeddings
    ), f"Expected tie_word_embeddings=True for {base_model_name}"

    # Only target lm_head to isolate the test to the tied-embedding scenario.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["lm_head"],
        lora_dropout=0,
        bias="none",
        task_type="CAUSAL_LM",
    )

    peft_model = get_peft_model(model, lora_config)

    # PEFT initializes lora_B to zeros by default, which makes the adapter
    # produce identical output to the base model. Initialize lora_B with
    # non-zero random weights so the adapter has a visible effect.
    with torch.no_grad():
        for name, param in peft_model.named_parameters():
            if "lora_B" in name:
                torch.nn.init.normal_(param, mean=0.0, std=0.02)

    peft_model.save_pretrained(output_dir)

    # Verify the saved adapter contains lm_head keys. Use a context manager
    # so the safetensors file handle is closed deterministically (the
    # original code leaked the handle).
    from safetensors import safe_open

    safetensors_path = os.path.join(output_dir, "adapter_model.safetensors")
    with safe_open(safetensors_path, framework="pt") as f:
        lm_head_keys = [k for k in f.keys() if "lm_head" in k]
        assert (
            len(lm_head_keys) > 0
        ), f"Expected lm_head LoRA weights in adapter, got keys: {sorted(f.keys())}"

    print(f"Created LoRA adapter at {output_dir}")
    print(f"  lm_head keys: {lm_head_keys}")

    # Clean up the model to free memory
    del peft_model, model
    torch.cuda.empty_cache()
127+
128+
129+
class TestLoRATiedLMHead(CustomTestCase):
    """
    Test that LoRA works correctly on models with tied lm_head.
    """

    # Path of the temporary PEFT adapter created in setUpClass.
    _adapter_dir = None

    @classmethod
    def setUpClass(cls):
        """Create a temporary LoRA adapter with lm_head targeting."""
        super().setUpClass()
        cls._adapter_dir = tempfile.mkdtemp(prefix="sglang_test_lora_tied_lm_head_")
        create_lora_adapter_with_lm_head(BASE_MODEL, cls._adapter_dir)

    @classmethod
    def tearDownClass(cls):
        """Clean up the temporary adapter directory."""
        if cls._adapter_dir and os.path.exists(cls._adapter_dir):
            shutil.rmtree(cls._adapter_dir)
        super().tearDownClass()

    def _assert_logprobs_close(self, prompt_idx, phase, srt_vals, hf_vals):
        # Compare one prompt's logprobs between SGLang and HF+PEFT; `phase`
        # is "prefill" or "decode" and only affects the reported messages.
        srt_tensor = torch.tensor(srt_vals)
        hf_tensor = torch.tensor(hf_vals)
        max_diff = torch.max(torch.abs(srt_tensor - hf_tensor)).item()
        print(
            f"Prompt {prompt_idx} {phase} logprob max_diff (SGLang vs HF): {max_diff:.6e}"
        )
        self.assertLess(
            max_diff,
            LOGPROB_THRESHOLD,
            f"Prompt {prompt_idx}: {phase} logprob diff {max_diff:.6e} "
            f"exceeds threshold {LOGPROB_THRESHOLD:.0e}",
        )

    def test_tied_lm_head_lora_hf_sgl_logprob_match(self):
        """
        Compare logprobs between HuggingFace+PEFT and SGLang+LoRA
        for a tied lm_head adapter, ensuring numerical consistency.
        """
        prompts = TEST_PROMPTS[:2]
        adapter_per_prompt = [self._adapter_dir] * len(prompts)

        # Run SGLang with LoRA
        with SRTRunner(
            BASE_MODEL,
            torch_dtype=torch.float16,
            model_type="generation",
            lora_paths=[self._adapter_dir],
            max_loras_per_batch=1,
            lora_backend="triton",
            lora_target_modules=["lm_head"],
            disable_cuda_graph=True,
            disable_radix_cache=True,
            mem_fraction_static=0.80,
            port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(
                prompts,
                max_new_tokens=MAX_NEW_TOKENS,
                lora_paths=adapter_per_prompt,
            )

        torch.cuda.empty_cache()

        # Run HuggingFace with LoRA (via PEFT)
        with HFRunner(
            BASE_MODEL,
            torch_dtype=torch.float16,
            model_type="generation",
        ) as hf_runner:
            hf_outputs = hf_runner.forward(
                prompts,
                max_new_tokens=MAX_NEW_TOKENS,
                lora_paths=adapter_per_prompt,
            )

        # Compare prefill logprobs
        for idx in range(len(prompts)):
            self._assert_logprobs_close(
                idx,
                "prefill",
                srt_outputs.top_input_logprobs[idx],
                hf_outputs.top_input_logprobs[idx],
            )

        # Compare decode logprobs
        for idx in range(len(prompts)):
            self._assert_logprobs_close(
                idx,
                "decode",
                srt_outputs.top_output_logprobs[idx],
                hf_outputs.top_output_logprobs[idx],
            )
216+
217+
218+
if __name__ == "__main__":
    # Set the multiprocessing start method before any workers are spawned.
    # set_start_method raises RuntimeError if the method was already set
    # (e.g. by the test harness); in that case keep the existing method.
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        pass

    unittest.main(warnings="ignore")

0 commit comments

Comments
 (0)