Skip to content

Commit c0849af

Browse files
authored
feat(context): decrease token in web_search AIMessage (bytedance#827)
This PR addresses token limit issues when web_search is enabled with include_raw_content by implementing a two-pronged approach: changing the default behavior to exclude raw content and adding compression logic for when raw content is included.
1 parent 65cdc18 commit c0849af

File tree

5 files changed

+124
-89
lines changed

5 files changed

+124
-89
lines changed

src/tools/crawl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ def crawl_tool(
3737
"error": "PDF files cannot be crawled directly. Please download and view the PDF manually.",
3838
"crawled_content": None,
3939
"is_pdf": True
40-
})
40+
}, ensure_ascii=False)
4141
return pdf_message
4242

4343
try:
4444
crawler = Crawler()
4545
article = crawler.crawl(url)
46-
return json.dumps({"url": url, "crawled_content": article.to_markdown()[:1000]})
46+
return json.dumps({"url": url, "crawled_content": article.to_markdown()[:1000]}, ensure_ascii=False)
4747
except BaseException as e:
4848
error_msg = f"Failed to crawl. Error: {repr(e)}"
4949
logger.error(error_msg)

src/tools/search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def get_web_search_tool(max_search_results: int):
5757
exclude_domains: Optional[List[str]] = search_config.get("exclude_domains", [])
5858
include_answer: bool = search_config.get("include_answer", False)
5959
search_depth: str = search_config.get("search_depth", "advanced")
60-
include_raw_content: bool = search_config.get("include_raw_content", True)
60+
include_raw_content: bool = search_config.get("include_raw_content", False)
6161
include_images: bool = search_config.get("include_images", True)
6262
include_image_descriptions: bool = include_images and search_config.get(
6363
"include_image_descriptions", True

src/utils/context_manager.py

Lines changed: 76 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -188,77 +188,86 @@ def compress_messages(self, state: dict, runtime: Runtime | None = None) -> dict
188188

189189
def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
    """
    Compress messages to fit within the token limit using two strategies:

    1. Truncate the ``raw_content`` field inside ``web_search`` ToolMessages
       to 1024 characters.
    2. If the conversation is still over the limit, keep the preserved
       prefix messages (e.g. the system prompt) plus the longest suffix of
       the most recent messages that fits; older messages are dropped.

    Args:
        messages: List of messages to compress

    Returns:
        List of messages with compressed content and/or dropped messages
    """
    # Deep-copy so the caller's message objects are never mutated.
    compressed = copy.deepcopy(messages)

    # Step 1: truncate raw_content inside web_search ToolMessages.
    for msg in compressed:
        # Only ToolMessages produced by the web_search tool carry raw_content.
        if not (isinstance(msg, ToolMessage) and getattr(msg, "name", None) == "web_search"):
            continue
        try:
            if isinstance(msg.content, str):
                # Heuristic early exit: content under 2KB cannot contain a
                # raw_content long enough to truncate, so skip JSON parsing.
                if len(msg.content) < 2048:
                    continue
                try:
                    content_data = json.loads(msg.content)
                except json.JSONDecodeError as e:
                    logger.error(f"Failed to parse JSON content in web_search ToolMessage: {e}. Content: {msg.content[:200]}")
                    continue
            elif isinstance(msg.content, list):
                content_data = copy.deepcopy(msg.content)
            else:
                # Unknown content type — leave it untouched.
                continue

            # Truncate each result's raw_content; track whether anything changed
            # so we only pay the re-serialization cost when necessary.
            modified = False
            if isinstance(content_data, list):
                for item in content_data:
                    if isinstance(item, dict) and "raw_content" in item:
                        raw_content = item.get("raw_content")
                        if raw_content and isinstance(raw_content, str) and len(raw_content) > 1024:
                            item["raw_content"] = raw_content[:1024]
                            modified = True

            if modified:
                # NOTE(review): when msg.content was a list this replaces it
                # with a JSON string — confirm downstream consumers accept that.
                msg.content = json.dumps(content_data, ensure_ascii=False)
        except Exception as e:
            # Best-effort compression: never let a malformed message abort the run.
            logger.error(f"Unexpected error during message compression: {e}")
            continue

    # Step 2: if still over the limit, keep the preserved prefix plus the
    # longest suffix of recent messages that fits within the budget.
    if self.is_over_limit(compressed):
        preserved_count = self.preserve_prefix_message_count
        preserved_messages = compressed[:preserved_count]
        remaining_messages = compressed[preserved_count:]

        # Walk from the newest message backwards, keeping each one only while
        # the running result stays within the limit.  (The previous
        # insert-then-break-when-fitting loop kept a single message when the
        # first insert fit, and re-added *every* message when it did not,
        # since each further insert could only grow the token count.)
        kept: List[BaseMessage] = []
        for msg in reversed(remaining_messages):
            candidate = preserved_messages + [msg] + kept
            if self.is_over_limit(candidate):
                break
            kept.insert(0, msg)

        compressed = preserved_messages + kept

    # Step 3: warn when even dropping messages could not reach the limit
    # (e.g. the preserved prefix alone exceeds the budget).
    if self.is_over_limit(compressed):
        current_tokens = self.count_tokens(compressed)
        logger.warning(
            f"Message compression failed to bring tokens below limit: "
            f"{current_tokens} > {self.token_limit} tokens. "
            f"Total messages: {len(compressed)}. "
            f"Consider increasing token_limit or preserve_prefix_message_count."
        )

    return compressed
262271

263272
def _create_summary_message(self, messages: List[BaseMessage]) -> BaseMessage:
264273
"""

tests/unit/tools/test_search.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_get_web_search_tool_tavily(self):
1717
tool = get_web_search_tool(max_search_results=5)
1818
assert tool.name == "web_search"
1919
assert tool.max_results == 5
20-
assert tool.include_raw_content is True
20+
assert tool.include_raw_content is False
2121
assert tool.include_images is True
2222
assert tool.include_image_descriptions is True
2323
assert tool.include_answer is False
@@ -79,7 +79,7 @@ def test_get_web_search_tool_tavily_with_custom_config(self, mock_config):
7979
"SEARCH_ENGINE": {
8080
"include_answer": True,
8181
"search_depth": "basic",
82-
"include_raw_content": False,
82+
"include_raw_content": True,
8383
"include_images": False,
8484
"include_image_descriptions": True,
8585
"include_domains": ["example.com"],
@@ -91,7 +91,7 @@ def test_get_web_search_tool_tavily_with_custom_config(self, mock_config):
9191
assert tool.max_results == 5
9292
assert tool.include_answer is True
9393
assert tool.search_depth == "basic"
94-
assert tool.include_raw_content is False
94+
assert tool.include_raw_content is True
9595
assert tool.include_images is False
9696
# include_image_descriptions should be False because include_images is False
9797
assert tool.include_image_descriptions is False
@@ -108,7 +108,7 @@ def test_get_web_search_tool_tavily_with_empty_config(self, mock_config):
108108
assert tool.max_results == 10
109109
assert tool.include_answer is False
110110
assert tool.search_depth == "advanced"
111-
assert tool.include_raw_content is True
111+
assert tool.include_raw_content is False
112112
assert tool.include_images is True
113113
assert tool.include_image_descriptions is True
114114
assert tool.include_domains == []
@@ -143,7 +143,7 @@ def test_get_web_search_tool_tavily_partial_config(self, mock_config):
143143
tool = get_web_search_tool(max_search_results=3)
144144
assert tool.include_answer is True
145145
assert tool.search_depth == "advanced" # default
146-
assert tool.include_raw_content is True # default
146+
assert tool.include_raw_content is False # default
147147
assert tool.include_domains == ["trusted.com"]
148148
assert tool.exclude_domains == [] # default
149149

@@ -157,7 +157,7 @@ def test_get_web_search_tool_tavily_with_no_config_file(self, mock_config):
157157
assert tool.max_results == 5
158158
assert tool.include_answer is False
159159
assert tool.search_depth == "advanced"
160-
assert tool.include_raw_content is True
160+
assert tool.include_raw_content is False
161161
assert tool.include_images is True
162162

163163
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@@ -184,7 +184,7 @@ def test_tavily_with_no_search_engine_section(self, mock_config):
184184
assert tool.max_results == 5
185185
assert tool.include_answer is False
186186
assert tool.search_depth == "advanced"
187-
assert tool.include_raw_content is True
187+
assert tool.include_raw_content is False
188188
assert tool.include_images is True
189189
assert tool.include_domains == []
190190
assert tool.exclude_domains == []
@@ -199,7 +199,7 @@ def test_tavily_with_completely_empty_config(self, mock_config):
199199
assert tool.max_results == 5
200200
assert tool.include_answer is False
201201
assert tool.search_depth == "advanced"
202-
assert tool.include_raw_content is True
202+
assert tool.include_raw_content is False
203203
assert tool.include_images is True
204204

205205
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@@ -210,7 +210,7 @@ def test_tavily_with_only_include_answer_param(self, mock_config):
210210
tool = get_web_search_tool(max_search_results=5)
211211
assert tool.include_answer is True
212212
assert tool.search_depth == "advanced"
213-
assert tool.include_raw_content is True
213+
assert tool.include_raw_content is False
214214
assert tool.include_images is True
215215

216216
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@@ -221,7 +221,7 @@ def test_tavily_with_only_search_depth_param(self, mock_config):
221221
tool = get_web_search_tool(max_search_results=5)
222222
assert tool.search_depth == "basic"
223223
assert tool.include_answer is False
224-
assert tool.include_raw_content is True
224+
assert tool.include_raw_content is False
225225
assert tool.include_images is True
226226

227227
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@@ -286,6 +286,6 @@ def test_tavily_all_parameters_optional_mix(self, mock_config):
286286
tool.include_image_descriptions is False
287287
) # should be False since include_images is False
288288
assert tool.search_depth == "advanced" # default
289-
assert tool.include_raw_content is True # default
289+
assert tool.include_raw_content is False # default
290290
assert tool.include_domains == [] # default
291291
assert tool.exclude_domains == [] # default

tests/unit/utils/test_context_manager.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,24 +85,35 @@ def test_compress_messages_when_not_over_limit(self):
8585
# Should return the same messages when not over limit
8686
assert len(compressed["messages"]) == len(messages)
8787

88-
def test_compress_messages_with_tool_message(self):
    """Test compress_messages preserves system message and compresses raw_content"""
    # Small token budget so the compression path is actually exercised.
    limited_cm = ContextManager(token_limit=200)

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
        ToolMessage(
            name="web_search",
            content='[{"title": "Test Result", "url": "https://example.com", "raw_content": "' + ("This is a test content that should be compressed if it exceeds 1024 characters. " * 2000) + '"}]',
            tool_call_id="test_search",
        )
    ]

    compressed = limited_cm.compress_messages({"messages": messages})
    # No message is dropped; only the tool message's raw_content shrinks.
    assert len(compressed["messages"]) == 4

    # Every web_search ToolMessage must have its raw_content cut to 1024 chars.
    import json
    for msg in compressed["messages"]:
        if isinstance(msg, ToolMessage) and getattr(msg, "name", None) == "web_search":
            content_data = json.loads(msg.content)
            if isinstance(content_data, list):
                for item in content_data:
                    if isinstance(item, dict) and "raw_content" in item:
                        assert len(item["raw_content"]) == 1024
106117

107118
def test_compress_messages_with_preserve_prefix_message(self):
108119
"""Test compress_messages when no system message is present"""
@@ -201,9 +212,24 @@ def test_compress_messages_with_runtime_when_over_limit(self):
201212
HumanMessage(
202213
content="Can you tell me a very long story that would exceed token limits? " * 100
203214
),
215+
ToolMessage(
216+
name="web_search",
217+
content='[{"title": "Test Result", "url": "https://example.com", "raw_content": "' + ("This is a test content that should be compressed if it exceeds 1024 characters. " * 2000) + '"}]',
218+
tool_call_id="test_search",
219+
)
204220
]
205221
compressed = limited_cm.compress_messages({"messages": messages}, runtime=object())
206222
assert isinstance(compressed, dict)
207223
assert "messages" in compressed
208224
# Should preserve only what fits; with this setup we expect heavy compression
209-
assert len(compressed["messages"]) == 1
225+
assert len(compressed["messages"]) == 5
226+
227+
# Verify raw_content was compressed to 1024 characters
228+
import json
229+
for msg in compressed["messages"]:
230+
if isinstance(msg, ToolMessage) and getattr(msg, "name", None) == "web_search":
231+
content_data = json.loads(msg.content)
232+
if isinstance(content_data, list):
233+
for item in content_data:
234+
if isinstance(item, dict) and "raw_content" in item:
235+
assert len(item["raw_content"]) == 1024

0 commit comments

Comments
 (0)