Skip to content
Prev Previous commit
Next Next commit
run pre-commit
  • Loading branch information
yushengsu-thu committed Feb 11, 2026
commit 418e183aaa8fc7df87a179e0a06435616e464d4c
5 changes: 4 additions & 1 deletion python/sglang/srt/lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ def _process_weight(self, name: str, loaded_weight: torch.Tensor):
# "all-linear"), we allow loading since the server-level
# --lora-target-modules will govern which modules are active.
module_name = "embed_tokens" if "embed_tokens" in name else "lm_head"
if not normalized_target_modules or module_name in normalized_target_modules:
if (
not normalized_target_modules
or module_name in normalized_target_modules
):
self.embedding_layers[name] = loaded_weight.cpu()
else:
logger.debug(
Expand Down
22 changes: 9 additions & 13 deletions test/registered/lora/test_lora_tied_lm_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, DEFAULT_PORT_FOR_SRT_TEST_RUNNER
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, CustomTestCase

register_cuda_ci(est_time=120, suite="nightly-1-gpu", nightly=True)

Expand Down Expand Up @@ -82,9 +82,9 @@ def create_lora_adapter_with_lm_head(base_model_name: str, output_dir: str):
)

# Verify the model actually has tied embeddings
assert model.config.tie_word_embeddings, (
f"Expected tie_word_embeddings=True for {base_model_name}"
)
assert (
model.config.tie_word_embeddings
), f"Expected tie_word_embeddings=True for {base_model_name}"

lora_config = LoraConfig(
r=8,
Expand Down Expand Up @@ -113,9 +113,9 @@ def create_lora_adapter_with_lm_head(base_model_name: str, output_dir: str):
safetensors_path = os.path.join(output_dir, "adapter_model.safetensors")
f = safe_open(safetensors_path, framework="pt")
lm_head_keys = [k for k in f.keys() if "lm_head" in k]
assert len(lm_head_keys) > 0, (
f"Expected lm_head LoRA weights in adapter, got keys: {sorted(f.keys())}"
)
assert (
len(lm_head_keys) > 0
), f"Expected lm_head LoRA weights in adapter, got keys: {sorted(f.keys())}"

print(f"Created LoRA adapter at {output_dir}")
print(f" lm_head keys: {lm_head_keys}")
Expand Down Expand Up @@ -293,9 +293,7 @@ def test_tied_lm_head_lora_hf_sgl_logprob_match(self):
srt_logprobs = torch.tensor(srt_outputs.top_input_logprobs[i])
hf_logprobs = torch.tensor(hf_outputs.top_input_logprobs[i])
max_diff = torch.max(torch.abs(srt_logprobs - hf_logprobs)).item()
print(
f"Prompt {i} prefill logprob max_diff (SGLang vs HF): {max_diff:.6e}"
)
print(f"Prompt {i} prefill logprob max_diff (SGLang vs HF): {max_diff:.6e}")
self.assertLess(
max_diff,
LOGPROB_THRESHOLD,
Expand All @@ -308,9 +306,7 @@ def test_tied_lm_head_lora_hf_sgl_logprob_match(self):
srt_logprobs = torch.tensor(srt_outputs.top_output_logprobs[i])
hf_logprobs = torch.tensor(hf_outputs.top_output_logprobs[i])
max_diff = torch.max(torch.abs(srt_logprobs - hf_logprobs)).item()
print(
f"Prompt {i} decode logprob max_diff (SGLang vs HF): {max_diff:.6e}"
)
print(f"Prompt {i} decode logprob max_diff (SGLang vs HF): {max_diff:.6e}")
self.assertLess(
max_diff,
LOGPROB_THRESHOLD,
Expand Down
Loading