Skip to content

Commit 3bcf5ec

Browse files
authored
support regex in xgrammar backend (sgl-project#2983)
1 parent 2c05f81 commit 3bcf5ec

File tree

7 files changed

+200
-8
lines changed

7 files changed

+200
-8
lines changed

docs/backend/openai_api_completions.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@
219219
"SGLang supports two grammar backends:\n",
220220
"\n",
221221
"- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n",
222-
"- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints.\n",
222+
"- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n",
223223
" - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n",
224224
"\n",
225225
"Initialize the XGrammar backend using `--grammar-backend xgrammar` flag\n",

docs/backend/structured_outputs.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
"SGLang supports two grammar backends:\n",
1717
"\n",
1818
"- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n",
19-
"- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints and currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md).\n",
19+
"- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n",
20+
" - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n",
2021
"\n",
2122
"We suggest using XGrammar whenever possible for its better performance. For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n",
2223
"\n",

docs/references/sampling_params.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ You can specify a JSON schema, regular expression or [EBNF](https://en.wikipedia
189189
SGLang supports two grammar backends:
190190

191191
- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.
192-
- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints.
192+
- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.
193193
- XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)
194194

195195
Initialize the XGrammar backend using `--grammar-backend xgrammar` flag

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ runtime_common = [
2323
"packaging", "pillow", "prometheus-client>=0.20.0",
2424
"psutil", "pydantic", "python-multipart",
2525
"pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
26-
"xgrammar>=0.1.6"
26+
"xgrammar>=0.1.10"
2727
]
2828
srt = [
2929
"sglang[runtime_common]", "cuda-python",

python/sglang/srt/constrained/xgrammar_backend.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020
from xgrammar import (
2121
CompiledGrammar,
22+
Grammar,
2223
GrammarCompiler,
2324
GrammarMatcher,
2425
TokenizerInfo,
@@ -133,10 +134,13 @@ def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
133134
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
134135
return None
135136
elif key_type == "regex":
136-
logger.warning(
137-
"regex hasn't been supported by xgrammar yet. This is skipped."
138-
)
139-
return None
137+
try:
138+
ctx = self.grammar_compiler.compile_grammar(
139+
Grammar.from_regex(key_string)
140+
)
141+
except RuntimeError as e:
142+
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
143+
return None
140144
else:
141145
raise ValueError(f"Invalid key_type: {key_type}")
142146

test/srt/run_suite.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"test_openai_server.py",
3232
"test_pytorch_sampling_backend.py",
3333
"test_radix_attention.py",
34+
"test_regex_constrained.py",
3435
"test_release_memory_occupation.py",
3536
"test_request_length_validation.py",
3637
"test_retract_decode.py",

test/srt/test_regex_constrained.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
"""
2+
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email
3+
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting
4+
"""
5+
6+
import json
7+
import unittest
8+
9+
import requests
10+
11+
from sglang.srt.utils import kill_process_tree
12+
from sglang.test.test_utils import (
13+
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
14+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
15+
DEFAULT_URL_FOR_TEST,
16+
popen_launch_server,
17+
)
18+
19+
20+
def setup_class(cls, disable_overlap: bool):
21+
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
22+
cls.base_url = DEFAULT_URL_FOR_TEST
23+
24+
other_args = [
25+
"--max-running-requests",
26+
"10",
27+
"--grammar-backend",
28+
"xgrammar",
29+
]
30+
31+
if disable_overlap:
32+
other_args += ["--disable-overlap-schedule"]
33+
34+
cls.process = popen_launch_server(
35+
cls.model,
36+
cls.base_url,
37+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
38+
other_args=other_args,
39+
)
40+
41+
42+
class TestRegexConstrained(unittest.TestCase):
43+
@classmethod
44+
def setUpClass(cls):
45+
setup_class(cls, disable_overlap=False)
46+
cls.check_jump_forward = False
47+
48+
@classmethod
49+
def tearDownClass(cls):
50+
kill_process_tree(cls.process.pid)
51+
52+
def run_decode(
53+
self,
54+
regex,
55+
prompt,
56+
return_logprob=False,
57+
top_logprobs_num=0,
58+
n=1,
59+
):
60+
response = requests.post(
61+
self.base_url + "/generate",
62+
json={
63+
"text": prompt,
64+
"sampling_params": {
65+
"temperature": 0 if n == 1 else 0.5,
66+
"max_new_tokens": 128,
67+
"n": n,
68+
"regex": regex,
69+
},
70+
"stream": False,
71+
"return_logprob": return_logprob,
72+
"top_logprobs_num": top_logprobs_num,
73+
"logprob_start_len": 0,
74+
},
75+
)
76+
77+
ret = response.json()
78+
print(json.dumps(ret, indent=2))
79+
print("=" * 100)
80+
81+
if not isinstance(ret, list):
82+
self.fail(f"Expected response to be a list, but got {type(ret)}")
83+
84+
for item in ret:
85+
text = item.get("text", "").strip()
86+
if not text:
87+
self.fail("Generated text is empty.")
88+
89+
if not self.regex_match(text, regex):
90+
self.fail(f"Text '{text}' does not match regex pattern.")
91+
92+
def regex_match(self, text, pattern):
93+
import re
94+
95+
return re.match(pattern, text) is not None
96+
97+
def test_regex_generate_email(self):
98+
pattern = r"^user@example\.com$"
99+
prompt = "Generate an email address:"
100+
101+
self.run_decode(
102+
regex=pattern,
103+
prompt=prompt,
104+
n=3,
105+
)
106+
107+
def test_regex_generate_greeting(self):
108+
pattern = r"^(Hello|Hi|Hey)$"
109+
prompt = "Generate a greeting:"
110+
111+
self.run_decode(
112+
regex=pattern,
113+
prompt=prompt,
114+
n=3,
115+
)
116+
117+
def test_regex_generate_number(self):
118+
pattern = r"^\d{3}$"
119+
prompt = "Generate a three-digit number:"
120+
121+
self.run_decode(
122+
regex=pattern,
123+
prompt=prompt,
124+
n=3,
125+
)
126+
127+
def test_regex_generate_phone(self):
128+
pattern = r"^\(\d{3}\) \d{3}-\d{4}$"
129+
prompt = "Generate a phone number:"
130+
131+
self.run_decode(
132+
regex=pattern,
133+
prompt=prompt,
134+
n=3,
135+
)
136+
137+
def test_regex_generate_date(self):
138+
pattern = r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"
139+
prompt = "Generate a date in YYYY-MM-DD format:"
140+
141+
self.run_decode(
142+
regex=pattern,
143+
prompt=prompt,
144+
n=3,
145+
)
146+
147+
def test_regex_generate_hex_color(self):
148+
pattern = r"^#[0-9A-F]{6}$"
149+
prompt = "Generate a hex color code:"
150+
151+
self.run_decode(
152+
regex=pattern,
153+
prompt=prompt,
154+
n=3,
155+
)
156+
157+
def test_regex_generate_complex_json(self):
158+
pattern = r'^\{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*\}$'
159+
prompt = "Generate a simple JSON with name, age, and city:"
160+
161+
self.run_decode(
162+
regex=pattern,
163+
prompt=prompt,
164+
n=3,
165+
)
166+
167+
def test_regex_generate_custom_log_format(self):
168+
pattern = r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$"
169+
prompt = "Generate a log entry:"
170+
171+
self.run_decode(
172+
regex=pattern,
173+
prompt=prompt,
174+
n=3,
175+
)
176+
177+
178+
class TestJumpForward(TestRegexConstrained):
179+
@classmethod
180+
def setUpClass(cls):
181+
setup_class(cls, disable_overlap=True)
182+
cls.check_jump_forward = True
183+
184+
185+
if __name__ == "__main__":
186+
unittest.main()

0 commit comments

Comments
 (0)