Skip to content

Commit 7ec5028

Browse files
authored
feat(federated): full thread replies + direct URL fetch in Slack search (onyx-dot-app#9940)
1 parent 5b2ba5c commit 7ec5028

5 files changed

Lines changed: 350 additions & 50 deletions

File tree

backend/onyx/context/search/federated/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from dataclasses import dataclass
12
from datetime import datetime
23
from typing import TypedDict
34

@@ -6,6 +7,14 @@
67
from onyx.onyxbot.slack.models import ChannelType
78

89

10+
@dataclass(frozen=True)
class DirectThreadFetch:
    """Immutable request describing one Slack thread to fetch directly.

    A thread is addressed by the channel that contains it together with a
    Slack message timestamp, rather than being located via keyword search.
    """

    # ID of the Slack channel containing the thread.
    channel_id: str
    # Slack message timestamp (dotted ``ts`` string) identifying the thread.
    thread_ts: str
16+
17+
918
class ChannelMetadata(TypedDict):
1019
"""Type definition for cached channel metadata."""
1120

backend/onyx/context/search/federated/slack_search.py

Lines changed: 124 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from onyx.connectors.models import IndexingDocument
2020
from onyx.connectors.models import TextSection
2121
from onyx.context.search.federated.models import ChannelMetadata
22+
from onyx.context.search.federated.models import DirectThreadFetch
2223
from onyx.context.search.federated.models import SlackMessage
2324
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
2425
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
@@ -49,7 +50,6 @@
4950
from onyx.utils.logger import setup_logger
5051
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
5152
from onyx.utils.timing import log_function_time
52-
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
5353

5454
logger = setup_logger()
5555

@@ -58,7 +58,6 @@
5858

5959
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
6060
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
61-
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
6261
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
6362
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
6463

@@ -421,6 +420,94 @@ class SlackQueryResult(BaseModel):
421420
filtered_channels: list[str] # Channels filtered out during this query
422421

423422

423+
def _fetch_thread_from_url(
    thread_fetch: DirectThreadFetch,
    access_token: str,
    channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
    """Fetch a thread directly from a Slack URL via conversations.replies.

    Args:
        thread_fetch: Channel ID and message timestamp of the thread to fetch.
        access_token: Slack token used to build the WebClient.
        channel_metadata_dict: Optional cached channel metadata keyed by
            channel ID; when it contains the channel, its "name" is used and
            no conversations.info call is made.

    Returns:
        A SlackQueryResult holding a single SlackMessage whose text is the
        whole thread, or an empty result when the Slack call fails or returns
        no messages. ``filtered_channels`` is always empty.
    """
    channel_id = thread_fetch.channel_id
    thread_ts = thread_fetch.thread_ts

    slack_client = WebClient(token=access_token)
    try:
        response = slack_client.conversations_replies(
            channel=channel_id,
            ts=thread_ts,
        )
        response.validate()
        messages: list[dict[str, Any]] = response.get("messages", [])
    except SlackApiError as e:
        logger.warning(
            f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
        )
        return SlackQueryResult(messages=[], filtered_channels=[])

    if not messages:
        logger.warning(
            f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
        )
        return SlackQueryResult(messages=[], filtered_channels=[])

    # Build thread text from all messages
    thread_text = _build_thread_text(messages, access_token, None, slack_client)

    # Get channel name from metadata cache or API
    channel_name = "unknown"
    if channel_metadata_dict and channel_id in channel_metadata_dict:
        channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
    else:
        try:
            ch_response = slack_client.conversations_info(channel=channel_id)
            ch_response.validate()
            channel_info: dict[str, Any] = ch_response.get("channel", {})
            channel_name = channel_info.get("name", "unknown")
        except SlackApiError as e:
            # The name is cosmetic, so the lookup stays best-effort, but the
            # failure is recorded instead of being silently discarded.
            logger.debug(
                f"Could not resolve channel name for channel={channel_id}, "
                f"falling back to 'unknown': {e}"
            )

    # Build the SlackMessage
    parent_msg = messages[0]
    message_ts = parent_msg.get("ts", thread_ts)
    username = parent_msg.get("user", "unknown_user")
    parent_text = parent_msg.get("text", "")
    # Short single-line preview of the first message for the identifier.
    snippet = (
        parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
    ).replace("\n", " ")

    # Both datetimes are naive, so the age subtraction is consistent.
    doc_time = datetime.fromtimestamp(float(message_ts))
    decay_factor = DOC_TIME_DECAY
    doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
    # Recency bias decays with age but never drops below 0.75.
    recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)

    permalink = (
        f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
    )

    slack_message = SlackMessage(
        document_id=f"{channel_id}_{message_ts}",
        channel_id=channel_id,
        message_id=message_ts,
        thread_id=None,  # Prevent double-enrichment in thread context fetch
        link=permalink,
        metadata={
            "channel": channel_name,
            "time": doc_time.isoformat(),
        },
        timestamp=doc_time,
        recency_bias=recency_bias,
        semantic_identifier=f"{username} in #{channel_name}: {snippet}",
        text=thread_text,
        highlighted_texts=set(),
        slack_score=100000.0,  # High priority — user explicitly asked for this thread
    )

    logger.info(
        f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
    )

    return SlackQueryResult(messages=[slack_message], filtered_channels=[])
509+
510+
424511
def query_slack(
425512
query_string: str,
426513
access_token: str,
@@ -432,7 +519,6 @@ def query_slack(
432519
available_channels: list[str] | None = None,
433520
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
434521
) -> SlackQueryResult:
435-
436522
# Check if query has channel override (user specified channels in query)
437523
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
438524

@@ -662,7 +748,6 @@ def _fetch_thread_context(
662748
"""
663749
channel_id = message.channel_id
664750
thread_id = message.thread_id
665-
message_id = message.message_id
666751

667752
# If not a thread, return original text as success
668753
if thread_id is None:
@@ -695,62 +780,37 @@ def _fetch_thread_context(
695780
if len(messages) <= 1:
696781
return ThreadContextResult.success(message.text)
697782

698-
# Build thread text from thread starter + context window around matched message
699-
thread_text = _build_thread_text(
700-
messages, message_id, thread_id, access_token, team_id, slack_client
701-
)
783+
# Build thread text from thread starter + all replies
784+
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
702785
return ThreadContextResult.success(thread_text)
703786

704787

705788
def _build_thread_text(
706789
messages: list[dict[str, Any]],
707-
message_id: str,
708-
thread_id: str,
709790
access_token: str,
710791
team_id: str | None,
711792
slack_client: WebClient,
712793
) -> str:
713-
"""Build the thread text from messages."""
794+
"""Build thread text including all replies.
795+
796+
Includes the thread parent message followed by all replies in order.
797+
"""
714798
msg_text = messages[0].get("text", "")
715799
msg_sender = messages[0].get("user", "")
716800
thread_text = f"<@{msg_sender}>: {msg_text}"
717801

718-
thread_text += "\n\nReplies:"
719-
if thread_id == message_id:
720-
message_id_idx = 0
721-
else:
722-
message_id_idx = next(
723-
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
724-
)
725-
if not message_id_idx:
726-
return thread_text
727-
728-
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
802+
# All messages after index 0 are replies
803+
replies = messages[1:]
804+
if not replies:
805+
return thread_text
729806

730-
if start_idx > 1:
731-
thread_text += "\n..."
732-
733-
for i in range(start_idx, message_id_idx):
734-
msg_text = messages[i].get("text", "")
735-
msg_sender = messages[i].get("user", "")
736-
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
737-
738-
msg_text = messages[message_id_idx].get("text", "")
739-
msg_sender = messages[message_id_idx].get("user", "")
740-
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
807+
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
808+
thread_text += "\n\nReplies:"
741809

742-
# Add following replies
743-
len_replies = 0
744-
for msg in messages[message_id_idx + 1 :]:
810+
for msg in replies:
745811
msg_text = msg.get("text", "")
746812
msg_sender = msg.get("user", "")
747-
reply = f"\n\n<@{msg_sender}>: {msg_text}"
748-
thread_text += reply
749-
750-
len_replies += len(reply)
751-
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
752-
thread_text += "\n..."
753-
break
813+
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
754814

755815
# Replace user IDs with names using cached lookups
756816
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
@@ -976,7 +1036,16 @@ def slack_retrieval(
9761036

9771037
# Query slack with entity filtering
9781038
llm = get_default_llm()
979-
query_strings = build_slack_queries(query, llm, entities, available_channels)
1039+
query_items = build_slack_queries(query, llm, entities, available_channels)
1040+
1041+
# Partition into direct thread fetches and search query strings
1042+
direct_fetches: list[DirectThreadFetch] = []
1043+
query_strings: list[str] = []
1044+
for item in query_items:
1045+
if isinstance(item, DirectThreadFetch):
1046+
direct_fetches.append(item)
1047+
else:
1048+
query_strings.append(item)
9801049

9811050
# Determine filtering based on entities OR context (bot)
9821051
include_dm = False
@@ -993,8 +1062,16 @@ def slack_retrieval(
9931062
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
9941063
)
9951064

996-
# Build search tasks
997-
search_tasks = [
1065+
# Build search tasks — direct thread fetches + keyword searches
1066+
search_tasks: list[tuple] = [
1067+
(
1068+
_fetch_thread_from_url,
1069+
(fetch, access_token, channel_metadata_dict),
1070+
)
1071+
for fetch in direct_fetches
1072+
]
1073+
1074+
search_tasks.extend(
9981075
(
9991076
query_slack,
10001077
(
@@ -1010,7 +1087,7 @@ def slack_retrieval(
10101087
),
10111088
)
10121089
for query_string in query_strings
1013-
]
1090+
)
10141091

10151092
# If include_dm is True AND we're not already searching all channels,
10161093
# add additional searches without channel filters.

backend/onyx/context/search/federated/slack_search_utils.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
1212
from onyx.context.search.federated.models import ChannelMetadata
13+
from onyx.context.search.federated.models import DirectThreadFetch
1314
from onyx.context.search.models import ChunkIndexRequest
1415
from onyx.federated_connectors.slack.models import SlackEntities
1516
from onyx.llm.interfaces import LLM
@@ -638,12 +639,38 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
638639
return [query_text]
639640

640641

642+
SLACK_URL_PATTERN = re.compile(
    r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
)

# Query string of a permalink copied from a thread reply, e.g.
# "?thread_ts=1775491616.524769&cid=C097NBWMY8Y". Captures the parent ts.
# Scanning stops at whitespace and at "<", ">", "|" so it never runs past
# the end of the URL token (including Slack's "<url|label>" link markup).
_SLACK_THREAD_TS_PARAM_PATTERN = re.compile(
    r"\?[^\s<>|]*?\bthread_ts=(\d{10}\.\d{6})"
)


def extract_slack_message_urls(
    query_text: str,
) -> list[tuple[str, str]]:
    """Extract Slack message URLs from query text.

    Parses URLs like:
        https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769

    When the URL carries a ``thread_ts`` query parameter (the form Slack
    produces for a link to a reply inside a thread), that value is returned
    instead of the path timestamp so the caller addresses the thread parent
    rather than the single reply.

    Args:
        query_text: Free-form text that may contain Slack permalinks.

    Returns:
        List of (channel_id, thread_ts) tuples in order of first appearance,
        without duplicates. A 16-digit path timestamp is converted to Slack
        ts format (with dot).
    """
    results: list[tuple[str, str]] = []
    seen: set[tuple[str, str]] = set()
    for match in SLACK_URL_PATTERN.finditer(query_text):
        channel_id = match.group(1)
        raw_ts = match.group(2)

        # Look for "?...thread_ts=<ts>" immediately after the matched URL.
        parent_match = _SLACK_THREAD_TS_PARAM_PATTERN.match(query_text, match.end())
        if parent_match:
            thread_ts = parent_match.group(1)
        else:
            # Convert p1775491616524769 -> 1775491616.524769
            thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"

        key = (channel_id, thread_ts)
        # The same link pasted twice should trigger only one fetch.
        if key not in seen:
            seen.add(key)
            results.append(key)
    return results
666+
667+
641668
def build_slack_queries(
642669
query: ChunkIndexRequest,
643670
llm: LLM,
644671
entities: dict[str, Any] | None = None,
645672
available_channels: list[str] | None = None,
646-
) -> list[str]:
673+
) -> list[str | DirectThreadFetch]:
647674
"""Build Slack query strings with date filtering and query expansion."""
648675
default_search_days = 30
649676
if entities:
@@ -668,6 +695,15 @@ def build_slack_queries(
668695
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
669696
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
670697

698+
# Check for Slack message URLs — if found, add direct fetch requests
699+
url_fetches: list[DirectThreadFetch] = []
700+
slack_urls = extract_slack_message_urls(query.query)
701+
for channel_id, thread_ts in slack_urls:
702+
url_fetches.append(
703+
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
704+
)
705+
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
706+
671707
# ALWAYS extract channel references from the query (not just for recency queries)
672708
channel_references = extract_channel_references_from_query(query.query)
673709

@@ -684,7 +720,9 @@ def build_slack_queries(
684720

685721
# If valid channels detected, use ONLY those channels with NO keywords
686722
# Return query with ONLY time filter + channel filter (no keywords)
687-
return [build_channel_override_query(channel_references, time_filter)]
723+
return url_fetches + [
724+
build_channel_override_query(channel_references, time_filter)
725+
]
688726
except ValueError as e:
689727
# If validation fails, log the error and continue with normal flow
690728
logger.warning(f"Channel reference validation failed: {e}")
@@ -702,7 +740,8 @@ def build_slack_queries(
702740
rephrased_queries = expand_query_with_llm(query.query, llm)
703741

704742
# Build final query strings with time filters
705-
return [
743+
search_queries = [
706744
rephrased_query.strip() + time_filter
707745
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
708746
]
747+
return url_fetches + search_queries

0 commit comments

Comments
 (0)