Skip to content

Commit 98b5013

Browse files
vedantjh2 and Vedant Jhaveri authored
add support to enable lora with embedding models (sgl-project#17780)
Co-authored-by: Vedant Jhaveri <vjhaveri@linkedin.com>
1 parent 947927b commit 98b5013

File tree

7 files changed

+260
-0
lines changed

7 files changed

+260
-0
lines changed

python/sglang/srt/entrypoints/engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ def encode(
379379
audio_data: Optional[MultimodalDataInputFormat] = None,
380380
video_data: Optional[MultimodalDataInputFormat] = None,
381381
dimensions: Optional[int] = None,
382+
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
382383
external_trace_header: Optional[Dict] = None,
383384
rid: Optional[Union[List[str], str]] = None,
384385
) -> Dict:
@@ -392,6 +393,7 @@ def encode(
392393
audio_data=audio_data,
393394
video_data=video_data,
394395
dimensions=dimensions,
396+
lora_path=lora_path,
395397
external_trace_header=external_trace_header,
396398
rid=rid,
397399
)
@@ -406,6 +408,7 @@ async def async_encode(
406408
audio_data: Optional[MultimodalDataInputFormat] = None,
407409
video_data: Optional[MultimodalDataInputFormat] = None,
408410
dimensions: Optional[int] = None,
411+
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
409412
external_trace_header: Optional[Dict] = None,
410413
rid: Optional[Union[List[str], str]] = None,
411414
) -> Dict:
@@ -421,6 +424,7 @@ async def async_encode(
421424
audio_data=audio_data,
422425
video_data=video_data,
423426
dimensions=dimensions,
427+
lora_path=lora_path,
424428
external_trace_header=external_trace_header,
425429
rid=rid,
426430
)

python/sglang/srt/entrypoints/openai/protocol.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,8 @@ class EmbeddingRequest(BaseModel):
897897
rid: Optional[Union[List[str], str]] = None
898898
# Priority for the request
899899
priority: Optional[int] = None
900+
# LoRA adapter path(s)
901+
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
900902

901903

902904
class EmbeddingObject(BaseModel):

python/sglang/srt/entrypoints/openai/serving_embedding.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,24 @@ def _convert_to_internal_request(
126126
# Other types (should not happen but handle gracefully)
127127
prompt_kwargs = {"input_ids": prompt}
128128

129+
# Resolve LoRA adapter from model parameter or explicit lora_path
130+
lora_path = self._resolve_lora_path(request.model, request.lora_path)
131+
if lora_path:
132+
first_adapter = (
133+
lora_path
134+
if isinstance(lora_path, str)
135+
else next((a for a in lora_path if a), None)
136+
)
137+
if first_adapter:
138+
self._validate_lora_enabled(first_adapter)
139+
129140
adapted_request = EmbeddingReqInput(
130141
**prompt_kwargs,
131142
rid=request.rid,
132143
priority=request.priority,
133144
routing_key=self.extract_routing_key(raw_request),
134145
dimensions=request.dimensions,
146+
lora_path=lora_path,
135147
)
136148

137149
return adapted_request, request

python/sglang/srt/managers/io_struct.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,11 @@ class EmbeddingReqInput(BaseReq, APIServingTimingMixin):
824824
# The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings.
825825
dimensions: Optional[int] = None
826826

827+
# The path to the LoRA adaptors
828+
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
829+
# The uid of LoRA adaptors, should be initialized by tokenizer manager
830+
lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
831+
827832
def normalize_batch_and_arguments(self):
828833
# at least one of text, input_ids, or image should be provided
829834
if self.text is None and self.input_ids is None and self.image_data is None:
@@ -875,6 +880,21 @@ def normalize_batch_and_arguments(self):
875880
for i in range(self.batch_size):
876881
self.sampling_params[i]["max_new_tokens"] = 0
877882

883+
self._normalize_lora_paths(self.batch_size)
884+
885+
def _normalize_lora_paths(self, num):
886+
"""Normalize LoRA paths for batch processing."""
887+
if self.lora_path is not None:
888+
if isinstance(self.lora_path, str):
889+
self.lora_path = [self.lora_path] * num
890+
elif isinstance(self.lora_path, list):
891+
if len(self.lora_path) != num:
892+
raise ValueError(
893+
f"lora_path list length ({len(self.lora_path)}) must match batch size ({num})"
894+
)
895+
else:
896+
raise ValueError("lora_path should be a list or a string.")
897+
878898
def contains_mm_input(self) -> bool:
879899
return (
880900
has_valid_data(self.image_data)
@@ -888,6 +908,8 @@ def __getitem__(self, i):
888908
text=[self.text[i]] if self.text is not None else None,
889909
sampling_params=self.sampling_params[i],
890910
rid=self.rid[i],
911+
lora_path=self.lora_path[i] if self.lora_path is not None else None,
912+
lora_id=self.lora_id[i] if self.lora_id is not None else None,
891913
is_cross_encoder_request=True,
892914
http_worker_ipc=self.http_worker_ipc,
893915
)
@@ -900,6 +922,8 @@ def __getitem__(self, i):
900922
video_data=self.video_data[i] if self.video_data is not None else None,
901923
sampling_params=self.sampling_params[i],
902924
rid=self.rid[i],
925+
lora_path=self.lora_path[i] if self.lora_path is not None else None,
926+
lora_id=self.lora_id[i] if self.lora_id is not None else None,
903927
external_trace_header=self.external_trace_header,
904928
dimensions=self.dimensions,
905929
http_worker_ipc=self.http_worker_ipc,
@@ -928,6 +952,8 @@ class TokenizedEmbeddingReqInput(BaseReq):
928952
priority: Optional[int] = None
929953
# The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings.
930954
dimensions: Optional[int] = None
955+
# LoRA related
956+
lora_id: Optional[str] = None # None means just use the base model
931957

932958

933959
@dataclass

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1762,6 +1762,7 @@ def handle_embedding_request(
17621762
token_type_ids=recv_req.token_type_ids,
17631763
priority=recv_req.priority,
17641764
dimensions=recv_req.dimensions,
1765+
lora_id=recv_req.lora_id,
17651766
http_worker_ipc=recv_req.http_worker_ipc,
17661767
)
17671768
req.tokenizer = self.tokenizer

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,7 @@ def _create_tokenized_object(
958958
rid=obj.rid,
959959
priority=obj.priority,
960960
dimensions=obj.dimensions,
961+
lora_id=obj.lora_id,
961962
http_worker_ipc=obj.http_worker_ipc,
962963
)
963964

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
# Copyright 2023-2024 SGLang Team
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# ==============================================================================
14+
"""
15+
Unit tests for LoRA support in embedding models.
16+
17+
Validates that EmbeddingReqInput correctly handles LoRA fields through
18+
normalization, batching, and request splitting.
19+
"""
20+
21+
import multiprocessing as mp
22+
import unittest
23+
24+
import numpy as np
25+
import torch
26+
27+
from sglang.srt.entrypoints.openai.protocol import EmbeddingRequest
28+
from sglang.srt.managers.io_struct import EmbeddingReqInput, TokenizedEmbeddingReqInput
29+
from sglang.srt.sampling.sampling_params import SamplingParams
30+
from sglang.test.ci.ci_register import register_cuda_ci
31+
from sglang.test.runners import SRTRunner
32+
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, CustomTestCase
33+
34+
# Test configuration (same model/LoRA as test_lora_hf_sgl_logprob_diff.py)
35+
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
36+
LORA_PATH = "yushengsu/sglang_lora_logprob_diff_without_tuning"
37+
LORA_BACKEND = "triton"
38+
SIMILARITY_THRESHOLD = 0.9999
39+
40+
register_cuda_ci(
41+
est_time=150,
42+
suite="nightly-1-gpu",
43+
)
44+
45+
46+
class TestEmbeddingLoraSupport(unittest.TestCase):
47+
"""Test LoRA support in embedding request structures."""
48+
49+
def test_embedding_lora_fields(self):
50+
"""Test LoRA fields exist and work correctly across all embedding structures."""
51+
# EmbeddingReqInput: fields exist, normalization expands single to batch, indexing works
52+
req = EmbeddingReqInput(
53+
text=["Hello", "World"], lora_path="my-adapter", lora_id=["id1", "id2"]
54+
)
55+
self.assertIsNotNone(req.lora_path)
56+
req.normalize_batch_and_arguments()
57+
self.assertEqual(req.lora_path, ["my-adapter", "my-adapter"])
58+
self.assertEqual(req[0].lora_path, "my-adapter")
59+
self.assertEqual(req[1].lora_id, "id2")
60+
61+
# EmbeddingReqInput: mismatched list length raises error
62+
req = EmbeddingReqInput(text=["Hello", "World", "Test"], lora_path=["adapter1"])
63+
with self.assertRaises(ValueError):
64+
req.normalize_batch_and_arguments()
65+
66+
# TokenizedEmbeddingReqInput and EmbeddingRequest have lora fields
67+
tokenized = TokenizedEmbeddingReqInput(
68+
input_text="Hello",
69+
input_ids=[1, 2, 3],
70+
image_inputs={},
71+
token_type_ids=[],
72+
sampling_params=SamplingParams(),
73+
lora_id="my-lora-id",
74+
)
75+
self.assertEqual(tokenized.lora_id, "my-lora-id")
76+
self.assertEqual(
77+
EmbeddingRequest(
78+
input="Hello", model="test", lora_path="adapter"
79+
).lora_path,
80+
"adapter",
81+
)
82+
83+
84+
class TestEmbeddingLoraHFComparison(CustomTestCase):
85+
"""Compare HF+LoRA vs SGLang+LoRA embedding outputs."""
86+
87+
@classmethod
88+
def get_hf_embedding_with_lora(cls, model_path, lora_path, texts, torch_dtype):
89+
"""Get embeddings from HuggingFace model with LoRA adapter."""
90+
from peft import PeftModel
91+
from transformers import AutoModelForCausalLM, AutoTokenizer
92+
93+
# Load base model as CausalLM to match adapter's expected structure
94+
base_model = AutoModelForCausalLM.from_pretrained(
95+
model_path,
96+
torch_dtype=torch_dtype,
97+
trust_remote_code=True,
98+
).cuda()
99+
100+
# Load LoRA adapter
101+
model = PeftModel.from_pretrained(base_model, lora_path)
102+
model.eval()
103+
104+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
105+
if tokenizer.pad_token is None:
106+
tokenizer.pad_token = tokenizer.eos_token
107+
108+
with torch.no_grad():
109+
inputs = tokenizer(
110+
texts, padding=True, truncation=True, return_tensors="pt"
111+
).to("cuda")
112+
113+
# Access the inner model (CausalLM wraps the base model)
114+
outputs = model.model(**inputs, output_hidden_states=True)
115+
hidden_states = outputs.hidden_states[-1]
116+
117+
# Last token pooling with L2 normalization (matching SGLang)
118+
attention_mask = inputs["attention_mask"]
119+
last_token_indices = attention_mask.sum(dim=1) - 1
120+
batch_size = hidden_states.shape[0]
121+
embeddings = hidden_states[
122+
torch.arange(batch_size, device="cuda"), last_token_indices
123+
]
124+
embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)
125+
126+
# Cleanup
127+
del model, base_model
128+
torch.cuda.empty_cache()
129+
130+
return embeddings.cpu().numpy()
131+
132+
@classmethod
133+
def get_sglang_embedding_with_lora(cls, model_path, lora_path, texts, torch_dtype):
134+
"""Get embeddings from SGLang with LoRA adapter."""
135+
with SRTRunner(
136+
model_path,
137+
torch_dtype=torch_dtype,
138+
model_type="embedding",
139+
lora_paths=[lora_path],
140+
lora_backend=LORA_BACKEND,
141+
port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
142+
trust_remote_code=True,
143+
mem_fraction_static=0.88,
144+
) as runner:
145+
# Call engine.encode directly with lora_path
146+
response = runner.engine.encode(prompt=texts, lora_path=lora_path)
147+
if isinstance(response, list):
148+
embeddings = [r["embedding"] for r in response]
149+
else:
150+
embeddings = [response["embedding"]]
151+
152+
return np.array(embeddings)
153+
154+
@staticmethod
155+
def cosine_similarity(a, b):
156+
"""Compute cosine similarity between vectors."""
157+
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
158+
159+
def test_embedding_lora_hf_sglang_similarity(self):
160+
"""Test that HF+LoRA and SGLang+LoRA produce similar embeddings."""
161+
test_texts = [
162+
"Hello world",
163+
"This is a test sentence for embedding comparison",
164+
]
165+
166+
print(f"\nModel: {MODEL_PATH}")
167+
print(f"LoRA: {LORA_PATH}")
168+
169+
# Get SGLang embeddings first (before HF loads model into GPU)
170+
# This order matches test_lora_hf_sgl_logprob_diff.py and avoids OOM
171+
print("\nGetting SGLang embeddings...")
172+
sglang_embeddings = self.get_sglang_embedding_with_lora(
173+
MODEL_PATH, LORA_PATH, test_texts, torch.float16
174+
)
175+
176+
# Clear GPU memory
177+
torch.cuda.empty_cache()
178+
179+
# Get HF embeddings
180+
print("Getting HF embeddings...")
181+
hf_embeddings = self.get_hf_embedding_with_lora(
182+
MODEL_PATH, LORA_PATH, test_texts, torch.float16
183+
)
184+
185+
# Compare embeddings
186+
print("\nHF vs SGLang LoRA Embedding Comparison:")
187+
similarities = []
188+
for i, (hf_emb, sgl_emb) in enumerate(zip(hf_embeddings, sglang_embeddings)):
189+
sim = self.cosine_similarity(hf_emb, sgl_emb)
190+
similarities.append(sim)
191+
print(f" Text {i}: cosine similarity = {sim:.6f}")
192+
self.assertGreater(
193+
sim,
194+
SIMILARITY_THRESHOLD,
195+
f"Text {i} similarity {sim:.6f} below threshold {SIMILARITY_THRESHOLD}",
196+
)
197+
198+
avg_similarity = np.mean(similarities)
199+
print(f" Average similarity: {avg_similarity:.6f}")
200+
print(f" Threshold: {SIMILARITY_THRESHOLD}")
201+
202+
self.assertGreater(
203+
avg_similarity,
204+
SIMILARITY_THRESHOLD,
205+
f"Average similarity {avg_similarity:.4f} below threshold {SIMILARITY_THRESHOLD}",
206+
)
207+
208+
209+
if __name__ == "__main__":
210+
try:
211+
mp.set_start_method("spawn")
212+
except RuntimeError:
213+
pass
214+
unittest.main()

0 commit comments

Comments (0)