|
| 1 | +# Copyright 2023-2024 SGLang Team |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# |
| 6 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +# |
| 8 | +# Unless required by applicable law or agreed to in writing, software |
| 9 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 10 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 11 | +# See the License for the specific language governing permissions and |
| 12 | +# limitations under the License. |
| 13 | +# ============================================================================== |
| 14 | +""" |
| 15 | +Unit tests for LoRA support in embedding models. |
| 16 | +
|
| 17 | +Validates that EmbeddingReqInput correctly handles LoRA fields through |
| 18 | +normalization, batching, and request splitting. |
| 19 | +""" |
| 20 | + |
| 21 | +import multiprocessing as mp |
| 22 | +import unittest |
| 23 | + |
| 24 | +import numpy as np |
| 25 | +import torch |
| 26 | + |
| 27 | +from sglang.srt.entrypoints.openai.protocol import EmbeddingRequest |
| 28 | +from sglang.srt.managers.io_struct import EmbeddingReqInput, TokenizedEmbeddingReqInput |
| 29 | +from sglang.srt.sampling.sampling_params import SamplingParams |
| 30 | +from sglang.test.ci.ci_register import register_cuda_ci |
| 31 | +from sglang.test.runners import SRTRunner |
| 32 | +from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, CustomTestCase |
| 33 | + |
# Test configuration (same model/LoRA as test_lora_hf_sgl_logprob_diff.py)
MODEL_PATH = "meta-llama/Llama-2-7b-hf"  # base model shared by the HF and SGLang runs
LORA_PATH = "yushengsu/sglang_lora_logprob_diff_without_tuning"  # HF hub LoRA adapter
LORA_BACKEND = "triton"  # SGLang LoRA kernel backend under test
# Minimum acceptable cosine similarity between HF and SGLang embeddings
# (both per-text and on average).
SIMILARITY_THRESHOLD = 0.9999

# Register this test file with the CI scheduler (estimated runtime in
# seconds, single-GPU nightly suite).
register_cuda_ci(
    est_time=150,
    suite="nightly-1-gpu",
)
| 44 | + |
| 45 | + |
class TestEmbeddingLoraSupport(unittest.TestCase):
    """Test LoRA support in embedding request structures."""

    def test_embedding_lora_fields(self):
        """Test LoRA fields exist and work correctly across all embedding structures."""
        # EmbeddingReqInput: a scalar lora_path must be broadcast to the batch
        # size during normalization, and per-index access yields scalar values.
        batch_req = EmbeddingReqInput(
            text=["Hello", "World"], lora_path="my-adapter", lora_id=["id1", "id2"]
        )
        self.assertIsNotNone(batch_req.lora_path)
        batch_req.normalize_batch_and_arguments()
        self.assertEqual(batch_req.lora_path, ["my-adapter", "my-adapter"])
        self.assertEqual(batch_req[0].lora_path, "my-adapter")
        self.assertEqual(batch_req[1].lora_id, "id2")

        # EmbeddingReqInput: a lora_path list whose length disagrees with the
        # batch size must be rejected during normalization.
        mismatched_req = EmbeddingReqInput(
            text=["Hello", "World", "Test"], lora_path=["adapter1"]
        )
        with self.assertRaises(ValueError):
            mismatched_req.normalize_batch_and_arguments()

        # TokenizedEmbeddingReqInput and EmbeddingRequest both expose LoRA fields.
        tokenized_req = TokenizedEmbeddingReqInput(
            input_text="Hello",
            input_ids=[1, 2, 3],
            image_inputs={},
            token_type_ids=[],
            sampling_params=SamplingParams(),
            lora_id="my-lora-id",
        )
        self.assertEqual(tokenized_req.lora_id, "my-lora-id")

        protocol_req = EmbeddingRequest(input="Hello", model="test", lora_path="adapter")
        self.assertEqual(protocol_req.lora_path, "adapter")
| 82 | + |
| 83 | + |
class TestEmbeddingLoraHFComparison(CustomTestCase):
    """Compare HF+LoRA vs SGLang+LoRA embedding outputs."""

    @classmethod
    def get_hf_embedding_with_lora(cls, model_path, lora_path, texts, torch_dtype):
        """Get embeddings from HuggingFace model with LoRA adapter."""
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # Load base model as CausalLM to match adapter's expected structure
        causal_lm = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        ).cuda()

        # Attach the LoRA adapter and switch to inference mode.
        peft_model = PeftModel.from_pretrained(causal_lm, lora_path)
        peft_model.eval()

        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        with torch.no_grad():
            encoded = tokenizer(
                texts, padding=True, truncation=True, return_tensors="pt"
            ).to("cuda")

            # Access the inner model (CausalLM wraps the base model)
            outputs = peft_model.model(**encoded, output_hidden_states=True)
            final_hidden = outputs.hidden_states[-1]

            # Last token pooling with L2 normalization (matching SGLang).
            # NOTE(review): assumes right-padding, i.e. the last real token of
            # row i sits at attention_mask[i].sum() - 1 — verify the tokenizer's
            # padding_side for this model.
            mask = encoded["attention_mask"]
            last_positions = mask.sum(dim=1) - 1
            row_indices = torch.arange(final_hidden.shape[0], device="cuda")
            pooled = final_hidden[row_indices, last_positions]
            pooled = pooled / pooled.norm(dim=1, keepdim=True)

            # Free GPU memory before the caller loads anything else.
            del peft_model, causal_lm
            torch.cuda.empty_cache()

            return pooled.cpu().numpy()

    @classmethod
    def get_sglang_embedding_with_lora(cls, model_path, lora_path, texts, torch_dtype):
        """Get embeddings from SGLang with LoRA adapter."""
        with SRTRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="embedding",
            lora_paths=[lora_path],
            lora_backend=LORA_BACKEND,
            port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
            trust_remote_code=True,
            mem_fraction_static=0.88,
        ) as srt:
            # Call engine.encode directly with lora_path
            raw = srt.engine.encode(prompt=texts, lora_path=lora_path)
            # encode() returns a list of dicts for batched prompts and a single
            # dict for a lone prompt; normalize to a list either way.
            if not isinstance(raw, list):
                raw = [raw]
            vectors = [item["embedding"] for item in raw]

        return np.array(vectors)

    @staticmethod
    def cosine_similarity(a, b):
        """Compute cosine similarity between vectors."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        return np.dot(a, b) / denom

    def test_embedding_lora_hf_sglang_similarity(self):
        """Test that HF+LoRA and SGLang+LoRA produce similar embeddings."""
        prompts = [
            "Hello world",
            "This is a test sentence for embedding comparison",
        ]

        print(f"\nModel: {MODEL_PATH}")
        print(f"LoRA: {LORA_PATH}")

        # Get SGLang embeddings first (before HF loads model into GPU)
        # This order matches test_lora_hf_sgl_logprob_diff.py and avoids OOM
        print("\nGetting SGLang embeddings...")
        sglang_embeddings = self.get_sglang_embedding_with_lora(
            MODEL_PATH, LORA_PATH, prompts, torch.float16
        )

        # Clear GPU memory
        torch.cuda.empty_cache()

        # Get HF embeddings
        print("Getting HF embeddings...")
        hf_embeddings = self.get_hf_embedding_with_lora(
            MODEL_PATH, LORA_PATH, prompts, torch.float16
        )

        # Compare embeddings pairwise; every text must clear the threshold.
        print("\nHF vs SGLang LoRA Embedding Comparison:")
        similarities = []
        for i, (hf_emb, sgl_emb) in enumerate(zip(hf_embeddings, sglang_embeddings)):
            sim = self.cosine_similarity(hf_emb, sgl_emb)
            similarities.append(sim)
            print(f"  Text {i}: cosine similarity = {sim:.6f}")
            self.assertGreater(
                sim,
                SIMILARITY_THRESHOLD,
                f"Text {i} similarity {sim:.6f} below threshold {SIMILARITY_THRESHOLD}",
            )

        avg_similarity = np.mean(similarities)
        print(f"  Average similarity: {avg_similarity:.6f}")
        print(f"  Threshold: {SIMILARITY_THRESHOLD}")

        self.assertGreater(
            avg_similarity,
            SIMILARITY_THRESHOLD,
            f"Average similarity {avg_similarity:.4f} below threshold {SIMILARITY_THRESHOLD}",
        )
| 207 | + |
| 208 | + |
if __name__ == "__main__":
    try:
        # CUDA does not survive fork(); force the "spawn" start method so any
        # worker processes start with a clean CUDA context.
        mp.set_start_method("spawn")
    except RuntimeError:
        # Start method was already set (e.g. by an outer test runner); keep it.
        pass
    unittest.main()
0 commit comments