|
| 1 | +import json |
| 2 | +import logging |
| 3 | +import re |
| 4 | +from typing import List |
| 5 | + |
| 6 | +from sglang.srt.entrypoints.openai.protocol import Tool |
| 7 | +from sglang.srt.function_call.base_format_detector import BaseFormatDetector |
| 8 | +from sglang.srt.function_call.core_types import ( |
| 9 | + StreamingParseResult, |
| 10 | + ToolCallItem, |
| 11 | + _GetInfoFunc, |
| 12 | +) |
| 13 | + |
| 14 | +logger = logging.getLogger(__name__) |
| 15 | + |
# Everything after the "function call<|role_sep|>\n" marker is captured as the
# raw JSON payload; DOTALL lets the payload span multiple lines.
REGEX_FUNCTION_CALL = re.compile(r"function call<\|role_sep\|>\n(.*)", flags=re.DOTALL)

# Shortest prefix of the text that precedes the first "<|message_sep|>".
REGEX_CONTENT_PATTERN = re.compile(r"^(.*?)<\|message_sep\|>", flags=re.DOTALL)

# Value of a *completed* `"name": "..."` pair (the closing quote must be present).
NAME_REGEX = re.compile(r'"name"\s*:\s*"([^"]*)"', flags=re.DOTALL)

# Everything that follows `"arguments":` up to the end of the text, which may
# include the closing brace of the enclosing call object.
ARGS_REGEX = re.compile(r'"arguments"\s*:\s*(.*)', flags=re.DOTALL)
| 35 | + |
| 36 | + |
class GigaChat3Detector(BaseFormatDetector):
    """Tool-call detector for GigaChat3-style output.

    The patterns used here recognise text shaped like::

        <content><|message_sep|>...function call<|role_sep|>\\n{"name": ..., "arguments": {...}}

    At most one tool call is tracked; it is always reported with
    ``tool_index=0``.
    """

    def __init__(self) -> None:
        super().__init__()
        # True once the function-call marker has been seen in the stream.
        self.tool_started: bool = False
        # True once the tool name has been emitted to the caller.
        self.tool_name_sent: bool = False
        # True once "<|message_sep|>" has been seen in a streamed chunk.
        self.end_content: bool = False
        # All text received so far through parse_streaming_increment.
        self._buffer: str = ""
        # Slot 0 holds "name", "arguments_str" (argument text streamed so
        # far) and "arguments" (its parsed form, or {} while incomplete).
        self.prev_tool_call_arr: list[dict] = []

    def has_tool_call(self, text: str) -> bool:
        """Check if text contains a tool call marker"""
        return "function call<|role_sep|>\n" in text

    def detect_and_parse(
        self,
        text: str,
        tools: List[Tool],
    ) -> StreamingParseResult:
        """
        Non-streaming parsing of complete model output.
        Extracts tool calls and content from the full text.

        Args:
            text: Complete model output; one trailing "</s>" is dropped.
            tools: Tool definitions, forwarded to ``self.parse_base_json``
                (not defined in this class).

        Returns:
            A result whose ``normal_text`` is the text before
            "<|message_sep|>" (or, without a separator, before the call
            marker when a call was parsed, else the whole text) and whose
            ``calls`` come from ``parse_base_json``. If the payload is not
            valid JSON, the whole text is returned as ``normal_text`` with no
            calls.
        """
        logger.debug("[GigaChat3] detect_and_parse: %s", text)
        model_output = text
        if model_output.rstrip().endswith("</s>"):
            model_output = model_output[: model_output.rfind("</s>")]

        function_call = None
        m_func = REGEX_FUNCTION_CALL.search(model_output)
        if m_func:
            try:
                # strict=False tolerates raw control characters in strings.
                function_call = json.loads(m_func.group(1), strict=False)
            except json.JSONDecodeError as e:
                logger.warning("[GigaChat3] JSON decode error: %s", e)
                return StreamingParseResult(
                    normal_text=model_output,
                    calls=[],
                )
            # Accept only {"name": ..., "arguments": {...}} objects.
            if not (
                isinstance(function_call, dict)
                and "name" in function_call
                and isinstance(function_call.get("arguments"), dict)
            ):
                function_call = None

        before_sep, sep, _ = model_output.partition("<|message_sep|>")
        if sep:
            content = before_sep
        elif function_call:
            # No separator: everything from the call marker onwards is tool
            # markup and must not be echoed back as normal text.
            content = model_output[: m_func.start()]
        else:
            content = model_output

        if not function_call:
            return StreamingParseResult(normal_text=content, calls=[])

        match_result = {
            "name": function_call["name"],
            "arguments": function_call["arguments"],
        }
        calls = self.parse_base_json(match_result, tools)
        return StreamingParseResult(normal_text=content, calls=calls)

    def parse_streaming_increment(
        self,
        new_text: str,
        tools: List[Tool],
    ) -> StreamingParseResult:
        """
        Streaming parser for incremental text chunks.
        Maintains state across calls to build complete tool calls.

        Args:
            new_text: The newly generated chunk; it is appended to the
                internal buffer.
            tools: Unused by this method.

        Returns:
            A result carrying either normal text (before the tool call
            starts) or tool-call items: one item with the function name,
            and/or one item with the not-yet-sent part of the argument text.
            An empty result means there is nothing new to emit.
        """
        if not new_text:
            return StreamingParseResult()
        logger.debug("[GigaChat3] parse_streaming_increment: '%s'", new_text)
        self._buffer += new_text
        m_func = REGEX_FUNCTION_CALL.search(self._buffer)

        if not self.tool_started:
            content = None
            before_sep, sep, _ = new_text.partition("<|message_sep|>")
            if sep:
                # Only the part of the chunk before the separator is content.
                content = before_sep
                self.end_content = True
            elif not self.end_content:
                content = new_text
            if m_func:
                self.tool_started = True
                logger.debug("[GigaChat3] Tool call started")
            if content:
                return StreamingParseResult(normal_text=content)

        if not m_func:
            return StreamingParseResult()

        json_tail = m_func.group(1).strip()
        name_match = NAME_REGEX.search(json_tail)
        func_name = name_match.group(1) if name_match else None

        cur_args = None
        args_match = ARGS_REGEX.search(json_tail)
        if args_match:
            cur_args = args_match.group(1).strip()
            if cur_args.endswith("</s>"):
                # Also drop whitespace that preceded "</s>", otherwise the
                # closing-brace check below would be skipped.
                cur_args = cur_args[: -len("</s>")].rstrip()
            if cur_args.endswith("}"):
                # ARGS_REGEX captures through the end of the text, so the
                # final "}" may belong to the enclosing call object. It does
                # exactly when the text without it is already valid JSON.
                candidate = cur_args[:-1].strip()
                try:
                    json.loads(candidate, strict=False)
                except json.JSONDecodeError:
                    pass
                else:
                    cur_args = candidate

        if not self.prev_tool_call_arr:
            self.prev_tool_call_arr.append({})

        calls: List[ToolCallItem] = []
        if not self.tool_name_sent:
            if not func_name:
                return StreamingParseResult()
            self.tool_name_sent = True
            self.prev_tool_call_arr[0]["name"] = func_name
            logger.debug("[GigaChat3] Sending tool name: %s", func_name)
            calls.append(
                ToolCallItem(
                    tool_index=0,
                    name=func_name,
                    parameters="",
                )
            )
            # Do not return yet: the same chunk may already carry argument
            # text, and if no further chunk arrives it would never be emitted.

        delta_args = self._take_args_delta(cur_args)
        if delta_args:
            logger.debug("[GigaChat3] Sending args delta: '%s...'", delta_args[:100])
            calls.append(
                ToolCallItem(
                    tool_index=0,
                    name=None,
                    parameters=delta_args,
                )
            )

        if not calls:
            return StreamingParseResult()
        return StreamingParseResult(calls=calls)

    def _take_args_delta(self, cur_args) -> str:
        """Return the unsent suffix of ``cur_args`` and record it as sent.

        Returns "" when there is nothing new, or when ``cur_args`` does not
        extend the argument text that was already streamed (a warning is
        logged in that case and the recorded state is left untouched).
        """
        if cur_args is None:
            return ""
        state = self.prev_tool_call_arr[0]
        prev_args = state.get("arguments_str", "")
        if not prev_args:
            delta_args = cur_args
        elif cur_args.startswith(prev_args):
            delta_args = cur_args[len(prev_args) :]
        else:
            logger.warning(
                "[GigaChat3] Arguments overlap mismatch. prev='%s...' cur='%s...'",
                prev_args[:50],
                cur_args[:50],
            )
            return ""
        if not delta_args:
            return ""
        state["arguments_str"] = cur_args
        try:
            state["arguments"] = json.loads(cur_args, strict=False)
        except json.JSONDecodeError:
            # Argument text is still incomplete; keep a placeholder.
            state["arguments"] = {}
        return delta_args

    def supports_structural_tag(self) -> bool:
        """GigaChat3 does not use structural tags"""
        return False

    def structure_info(self) -> _GetInfoFunc:
        """Not applicable for GigaChat3"""
        raise NotImplementedError(
            "GigaChat3Detector does not support structural_tag format."
        )
0 commit comments