-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Expand file tree
/
Copy pathtest_skills_middleware.py
More file actions
153 lines (113 loc) · 4.14 KB
/
test_skills_middleware.py
File metadata and controls
153 lines (113 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Test script for SkillsMiddlewareV2.
This script demonstrates how to use the new middleware system
to load and use skills in DB-GPT agents.
"""
import asyncio
import os
import sys
sys.path.insert(
0, os.path.join(os.path.dirname(__file__), "../../packages/dbgpt-core/src")
)
from dbgpt.agent.core.agent import AgentContext
from dbgpt.agent.core.profile.base import ProfileConfig
from dbgpt.agent.middleware.agent import AgentConfig, MiddlewareAgent
from dbgpt.agent.skill.middleware_v2 import SkillsMiddlewareV2
async def test_skills_middleware():
    """Demonstrate SkillsMiddlewareV2 loading and matching skills.

    Builds a ``MiddlewareAgent`` with skills enabled and sourced from the
    ``skills/user`` directory next to this script. It then prints the loaded
    skills summary and shows which skills match a few sample inputs.

    Returns early (after printing a notice) if the skills directory does not
    exist, or if no ``SkillsMiddlewareV2`` instance is registered on the
    built agent.
    """
    skills_path = os.path.join(os.path.dirname(__file__), "skills/user")
    if not os.path.exists(skills_path):
        # Nothing is created here; the demo cannot run without skills on disk.
        print(f"Skills directory not found: {skills_path}")
        print("Skipping skills test: please create test skills first.")
        return

    config = AgentConfig(
        enable_middleware=True,
        enable_skills=True,
        skill_sources=[skills_path],
        skill_auto_load=True,
        skill_auto_match=True,
        skill_inject_to_prompt=True,
    )
    profile = ProfileConfig(
        name="assistant",
        role="AI Assistant",
        goal="Help users with their tasks using available skills.",
    )
    agent = MiddlewareAgent(
        profile=profile,
        agent_config=config,
    )
    agent_context = AgentContext(
        conv_id="test_conv_001",
        language="zh-CN",
    )
    await agent.bind(agent_context).build()

    print("\n=== Skills Summary ===")
    # Find the skills middleware by type instead of assuming it sits at
    # index 0. NOTE(review): ``_middlewares`` is a private attribute of the
    # manager; switch to a public accessor if one exists.
    skills = next(
        (
            middleware
            for middleware in agent.middleware_manager._middlewares
            if isinstance(middleware, SkillsMiddlewareV2)
        ),
        None,
    )
    if skills is None:
        print("SkillsMiddlewareV2 is not registered on the agent")
        return
    print(skills.get_skills_summary())

    print("\n=== Testing Skill Matching ===")
    test_inputs = [
        "research quantum computing",
        "review my python code",
        "analyze this data",
    ]
    for test_input in test_inputs:
        print(f"\nInput: {test_input}")
        matched = skills.match_skills(test_input)
        if matched:
            print(f"Matched skills: {[s.metadata.name for s in matched]}")
        else:
            print("No skills matched")

    print("\n=== Test Complete ===")
async def test_custom_middleware():
    """Demonstrate registering a user-defined middleware on an agent.

    Defines ``LoggingMiddleware``. It logs before and after reply generation
    and prepends a marker section to the system prompt. The middleware is
    registered on a ``MiddlewareAgent`` with skills disabled, the agent is
    built, and the function reports whether the middleware is present in the
    agent's middleware list.
    """
    from dbgpt.agent.middleware.base import AgentMiddleware

    class LoggingMiddleware(AgentMiddleware):
        """Custom middleware for logging."""

        async def before_generate_reply(self, agent, context, **kwargs):
            print("[LoggingMiddleware] Before generate reply")
            # ``context`` may be missing, may lack a ``message``, or the
            # message content may be None (presumably Optional) -- only
            # preview actual text instead of crashing on slicing None.
            message = getattr(context, "message", None) if context else None
            content = getattr(message, "content", None)
            if content is not None:
                print(f" Message: {content[:50]}...")

        async def after_generate_reply(self, agent, context, reply_message, **kwargs):
            print("[LoggingMiddleware] After generate reply")
            content = getattr(reply_message, "content", None) if reply_message else None
            if content is not None:
                print(f" Reply: {content[:50]}...")

        async def modify_system_prompt(self, agent, original_prompt, context=None):
            # Prepend a visible marker so the middleware's effect on the final
            # system prompt is easy to spot.
            return (
                f"\n[LoggingMiddleware] Custom prompt section\n\n{original_prompt}"
            )

    profile = ProfileConfig(
        name="assistant",
        role="AI Assistant",
    )
    config = AgentConfig(
        enable_middleware=True,
        enable_skills=False,  # Disable skills for this test
    )
    agent = MiddlewareAgent(
        profile=profile,
        agent_config=config,
    )
    logging_middleware = LoggingMiddleware()
    agent.register_middleware(logging_middleware)

    agent_context = AgentContext(
        conv_id="test_logging_conv",
        language="en",
    )
    await agent.bind(agent_context).build()

    print("\n=== Custom Middleware Test ===")
    # NOTE(review): ``_middlewares`` is a private attribute of the manager.
    middlewares = agent.middleware_manager._middlewares
    # Check by identity rather than announcing success unconditionally.
    if any(m is logging_middleware for m in middlewares):
        print("LoggingMiddleware has been registered")
    else:
        print("LoggingMiddleware was NOT registered")
    print(f"Total middleware: {len(middlewares)}")
    print("\n=== Test Complete ===")
async def main():
    """Print a banner, await each demo in order, then print a closing line."""
    separator = "=" * 80
    for header_line in (separator, "DB-GPT Skills Middleware Test", separator):
        print(header_line)

    # Each entry pairs a section heading with the coroutine function to await.
    demos = (
        ("\n\n### Test 1: SkillsMiddlewareV2 ###", test_skills_middleware),
        ("\n\n### Test 2: Custom Middleware ###", test_custom_middleware),
    )
    for heading, demo in demos:
        print(heading)
        await demo()

    print("\n\n### All Tests Complete ###")
if __name__ == "__main__":
    # Run the demos on a fresh event loop only when executed as a script,
    # not when the module is imported.
    asyncio.run(main())