Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions astrbot/builtin_stars/builtin_commands/commands/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.core.platform.message_type import MessageType
from astrbot.core.utils.active_event_registry import active_event_registry

from .utils.rst_scene import RstScene

Expand Down Expand Up @@ -62,6 +63,7 @@ async def reset(self, message: AstrMessageEvent) -> None:

agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
active_event_registry.stop_all(umo, exclude=message)
await sp.remove_async(
scope="umo",
scope_id=umo,
Expand All @@ -86,6 +88,8 @@ async def reset(self, message: AstrMessageEvent) -> None:
)
return

active_event_registry.stop_all(umo, exclude=message)

await self.context.conversation_manager.update_conversation(
umo,
cid,
Expand Down Expand Up @@ -221,6 +225,7 @@ async def new_conv(self, message: AstrMessageEvent) -> None:
cfg = self.context.get_config(umo=message.unified_msg_origin)
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
active_event_registry.stop_all(message.unified_msg_origin, exclude=message)
await sp.remove_async(
scope="umo",
scope_id=message.unified_msg_origin,
Expand All @@ -229,6 +234,7 @@ async def new_conv(self, message: AstrMessageEvent) -> None:
message.set_result(MessageEventResult().message("已创建新对话。"))
return

active_event_registry.stop_all(message.unified_msg_origin, exclude=message)
cpersona = await self._get_current_persona_id(message.unified_msg_origin)
cid = await self.context.conversation_manager.new_conversation(
message.unified_msg_origin,
Expand Down Expand Up @@ -321,7 +327,8 @@ async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> No

async def del_conv(self, message: AstrMessageEvent) -> None:
"""删除当前对话"""
cfg = self.context.get_config(umo=message.unified_msg_origin)
umo = message.unified_msg_origin
cfg = self.context.get_config(umo=umo)
is_unique_session = cfg["platform_settings"]["unique_session"]
if message.get_group_id() and not is_unique_session and message.role != "admin":
# 群聊,没开独立会话,发送人不是管理员
Expand All @@ -334,18 +341,17 @@ async def del_conv(self, message: AstrMessageEvent) -> None:

agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
active_event_registry.stop_all(umo, exclude=message)
await sp.remove_async(
scope="umo",
scope_id=message.unified_msg_origin,
scope_id=umo,
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
)
message.set_result(MessageEventResult().message("重置对话成功。"))
return

session_curr_cid = (
await self.context.conversation_manager.get_curr_conversation_id(
message.unified_msg_origin,
)
await self.context.conversation_manager.get_curr_conversation_id(umo)
)

if not session_curr_cid:
Expand All @@ -356,8 +362,10 @@ async def del_conv(self, message: AstrMessageEvent) -> None:
)
return

active_event_registry.stop_all(umo, exclude=message)

await self.context.conversation_manager.delete_conversation(
message.unified_msg_origin,
umo,
session_curr_cid,
)

Expand Down
15 changes: 10 additions & 5 deletions astrbot/core/pipeline/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
WecomAIBotMessageEvent,
)
from astrbot.core.utils.active_event_registry import active_event_registry

from . import STAGES_ORDER
from .context import PipelineContext
Expand Down Expand Up @@ -79,10 +80,14 @@ async def execute(self, event: AstrMessageEvent) -> None:
event (AstrMessageEvent): 事件对象

"""
await self._process_stages(event)
active_event_registry.register(event)
try:
await self._process_stages(event)

# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent):
await event.send(None)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent):
await event.send(None)

logger.debug("pipeline 执行完毕。")
logger.debug("pipeline 执行完毕。")
finally:
active_event_registry.unregister(event)
50 changes: 50 additions & 0 deletions astrbot/core/utils/active_event_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

from collections import defaultdict
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from astrbot.core.platform import AstrMessageEvent


class ActiveEventRegistry:
"""维护 unified_msg_origin 到活跃事件的映射。

用于在 reset 等场景下终止该会话正在处理的事件。
"""

def __init__(self) -> None:
self._events: dict[str, set[AstrMessageEvent]] = defaultdict(set)

def register(self, event: AstrMessageEvent) -> None:
self._events[event.unified_msg_origin].add(event)

def unregister(self, event: AstrMessageEvent) -> None:
umo = event.unified_msg_origin
self._events[umo].discard(event)
if not self._events[umo]:
del self._events[umo]

def stop_all(
self,
umo: str,
exclude: AstrMessageEvent | None = None,
) -> int:
"""终止指定 UMO 的所有活跃事件。

Args:
umo: 统一消息来源标识符。
exclude: 需要排除的事件(通常是发起 reset 的事件本身)。

Returns:
被终止的事件数量。
"""
count = 0
for event in list(self._events.get(umo, [])):
if event is not exclude:
event.stop_event()
count += 1
return count


active_event_registry = ActiveEventRegistry()