forked from ZhuLinsen/daily_stock_analysis
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_config_validate_structured.py
More file actions
341 lines (287 loc) · 13.6 KB
/
test_config_validate_structured.py
File metadata and controls
341 lines (287 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
# -*- coding: utf-8 -*-
"""Tests for Config.validate_structured() and backward-compatible validate().
Covers:
- ConfigIssue dataclass basics
- validate_structured() severity classifications
- LLM availability check honours all three config tiers (YAML / channels /
legacy keys) via llm_model_list
- validate() backward-compat: still returns List[str] with the same messages
"""
import pytest
from unittest.mock import patch
from src.config import Config, ConfigIssue
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_config(**kwargs) -> Config:
    """Return a ``Config`` pre-filled with test-friendly defaults.

    The baseline describes a healthy setup:
    - one stock;
    - one entry in ``llm_model_list``;
    - a WeChat webhook as the notification channel;
    - every other key/channel empty or ``None``.

    Keyword arguments replace the matching ``Config`` field. Each test
    therefore only spells out the fields relevant to its scenario.
    """
    baseline = {
        "stock_list": ["600519"],
        "tushare_token": None,
        # A populated llm_model_list stands in for any of the three LLM
        # configuration tiers (YAML / channels / legacy keys).
        "llm_model_list": [{"model_name": "gemini/gemini-2.0-flash", "litellm_params": {"api_key": "sk-test"}}],
        "litellm_model": "gemini/gemini-2.0-flash",
        # Legacy provider key lists
        "gemini_api_keys": [],
        "anthropic_api_keys": [],
        "openai_api_keys": [],
        "deepseek_api_keys": [],
        # Search engine keys
        "bocha_api_keys": [],
        "tavily_api_keys": [],
        "brave_api_keys": [],
        "serpapi_keys": [],
        # Notification channels (only WeChat is configured)
        "wechat_webhook_url": "https://example.com/webhook",
        "feishu_webhook_url": None,
        "telegram_bot_token": None,
        "telegram_chat_id": None,
        "email_sender": None,
        "email_password": None,
        "pushover_user_key": None,
        "pushover_api_token": None,
        "pushplus_token": None,
        "serverchan3_sendkey": None,
        "custom_webhook_urls": [],
        "discord_bot_token": None,
        "discord_main_channel_id": None,
        "discord_webhook_url": None,
        # Remaining LLM configuration sources
        "llm_channels": [],
        "litellm_config_path": None,
        "gemini_api_key": None,
        "anthropic_api_key": None,
        "openai_api_key": None,
        "openai_base_url": None,
        "openai_vision_model": None,
    }
    return Config(**{**baseline, **kwargs})
def _severities(issues):
return [i.severity for i in issues]
def _fields(issues):
return [i.field for i in issues]
# ---------------------------------------------------------------------------
# ConfigIssue basics
# ---------------------------------------------------------------------------
class TestConfigIssue:
    """Unit tests for the ``ConfigIssue`` value object."""

    def test_str_equals_message(self):
        text = "something went wrong"
        rendered = str(ConfigIssue(severity="error", message=text, field="FOO"))
        assert rendered == text

    def test_severity_values(self):
        # Build one issue per supported severity, then check each round-trips.
        built = {
            level: ConfigIssue(severity=level, message="test", field="F")
            for level in ("error", "warning", "info")
        }
        for level, issue in built.items():
            assert issue.severity == level

    def test_default_field(self):
        # ``field`` is optional and falls back to the empty string.
        assert ConfigIssue(severity="info", message="hello").field == ""
# ---------------------------------------------------------------------------
# validate_structured() — happy path (all good)
# ---------------------------------------------------------------------------
class TestValidateStructuredHappyPath:
    """A fully configured ``Config`` must not yield errors or warnings."""

    def test_no_issues_when_fully_configured(self):
        issues = _make_config().validate_structured()
        # Info-level notes (e.g. about tushare or search engines) are allowed;
        # anything at error or warning level is not.
        blocking = {"error": [], "warning": []}
        for issue in issues:
            if issue.severity in blocking:
                blocking[issue.severity].append(issue)
        assert blocking["error"] == []
        assert blocking["warning"] == []
# ---------------------------------------------------------------------------
# validate_structured() — stock list
# ---------------------------------------------------------------------------
class TestValidateStructuredStockList:
    """STOCK_LIST checks: an empty list is an error, a populated one is not."""

    def test_empty_stock_list_is_error(self):
        cfg = _make_config(stock_list=[])
        errors = [i for i in cfg.validate_structured() if i.severity == "error"]
        assert any("STOCK_LIST" in i.field for i in errors)

    def test_configured_stock_list_no_stock_error(self):
        cfg = _make_config(stock_list=["600519", "000001"])
        errors = [i for i in cfg.validate_structured() if i.severity == "error"]
        # Use the same substring match as the positive test above. An exact
        # ``== "STOCK_LIST"`` comparison would miss an error reported under a
        # more specific field name that the positive test happily accepts.
        assert not any("STOCK_LIST" in i.field for i in errors)
# ---------------------------------------------------------------------------
# validate_structured() — LLM availability (three-tier check)
# ---------------------------------------------------------------------------
class TestValidateStructuredLLM:
    """LLM availability checks driven by ``llm_model_list`` (three-tier config)."""

    @staticmethod
    def _has_llm_error(issues):
        """Return True if any error-severity issue mentions "LLM" in its message."""
        return any(i.severity == "error" and "LLM" in i.message for i in issues)

    def test_no_llm_is_error(self):
        """Empty llm_model_list must produce an error regardless of legacy keys."""
        cfg = _make_config(llm_model_list=[])
        issues = cfg.validate_structured()
        assert self._has_llm_error(issues)

    def test_no_llm_is_error_even_with_legacy_keys(self):
        """Legacy *_API_KEY values alone must not mask an empty llm_model_list.

        ``test_no_llm_is_error`` leaves every legacy key list empty. It
        therefore cannot demonstrate the "regardless of legacy keys" part of
        the contract, so this test supplies legacy keys explicitly.
        """
        cfg = _make_config(
            llm_model_list=[],
            gemini_api_keys=["sk-gemini-testkey-1234"],
            deepseek_api_keys=["sk-deepseek-testkey-1234"],
        )
        issues = cfg.validate_structured()
        assert self._has_llm_error(issues)

    def test_llm_channels_only_no_error(self):
        """LLM_CHANNELS populated via llm_model_list must NOT trigger an error.

        This is the primary regression guard. A user who only configures
        LLM_CHANNELS (no legacy *_API_KEY) should not see the
        'AI 功能不可用' ("AI feature unavailable") error.
        """
        channel_model_list = [
            {"model_name": "openai/gpt-4o-mini", "litellm_params": {"api_key": "sk-chan", "api_base": "https://aihubmix.com/v1"}},
        ]
        cfg = _make_config(
            llm_model_list=channel_model_list,
            gemini_api_keys=[],
            anthropic_api_keys=[],
            openai_api_keys=[],
            deepseek_api_keys=[],
        )
        issues = cfg.validate_structured()
        assert not self._has_llm_error(issues)

    def test_yaml_config_only_no_error(self):
        """LITELLM_CONFIG (YAML) path: populated llm_model_list = no error."""
        yaml_model_list = [
            {"model_name": "gemini/gemini-2.5-flash", "litellm_params": {"api_key": "sk-yaml"}},
        ]
        cfg = _make_config(
            llm_model_list=yaml_model_list,
            litellm_config_path="/tmp/litellm.yaml",
            gemini_api_keys=[],
            anthropic_api_keys=[],
            openai_api_keys=[],
        )
        issues = cfg.validate_structured()
        assert not self._has_llm_error(issues)

    def test_legacy_gemini_key_no_error(self):
        """Legacy GEMINI_API_KEY path: llm_model_list populated = no error."""
        model_list = [
            {"model_name": "__legacy_gemini__", "litellm_params": {"model": "__legacy_gemini__", "api_key": "sk-gem"}},
        ]
        cfg = _make_config(llm_model_list=model_list, gemini_api_keys=["sk-gem"])
        issues = cfg.validate_structured()
        assert not self._has_llm_error(issues)

    def test_deepseek_only_no_error(self):
        """DEEPSEEK_API_KEY path (was missing in old validate()): no error."""
        model_list = [
            {"model_name": "__legacy_deepseek__", "litellm_params": {"model": "__legacy_deepseek__", "api_key": "sk-ds"}},
        ]
        cfg = _make_config(
            llm_model_list=model_list,
            deepseek_api_keys=["sk-ds"],
            gemini_api_keys=[],
            anthropic_api_keys=[],
            openai_api_keys=[],
        )
        issues = cfg.validate_structured()
        assert not self._has_llm_error(issues)

    def test_missing_litellm_model_is_info_not_error(self):
        """llm_model_list present but litellm_model unset = info, not error."""
        cfg = _make_config(litellm_model="")
        issues = cfg.validate_structured()
        llm_issues = [i for i in issues if "LITELLM_MODEL" in i.field]
        assert llm_issues, "Expected an info issue about LITELLM_MODEL"
        assert all(i.severity == "info" for i in llm_issues)
# ---------------------------------------------------------------------------
# validate_structured() — notification & search
# ---------------------------------------------------------------------------
class TestValidateStructuredNotification:
    """Missing notification channels warn; missing search engines only inform."""

    def test_no_notification_is_warning(self):
        issues = _make_config(wechat_webhook_url=None).validate_structured()
        # "通知渠道" = "notification channel" in the validator's message text.
        warned = any(
            issue.severity == "warning" and "通知渠道" in issue.message
            for issue in issues
        )
        assert warned

    def test_notification_configured_no_warning(self):
        issues = _make_config(wechat_webhook_url="https://example.com/wh").validate_structured()
        offending = [
            issue
            for issue in issues
            if issue.severity == "warning" and "通知渠道" in issue.message
        ]
        assert not offending

    def test_no_search_engine_is_info(self):
        issues = _make_config().validate_structured()
        # "搜索引擎" = "search engine" in the validator's message text.
        assert any(
            issue.severity == "info" and "搜索引擎" in issue.message
            for issue in issues
        )
# ---------------------------------------------------------------------------
# Deprecated field migration hints
# ---------------------------------------------------------------------------
class TestDeprecatedFieldHints:
    """Migration hints for deprecated environment variables."""

    def test_openai_vision_model_deprecation_when_env_set(self):
        """When OPENAI_VISION_MODEL is in env, validate_structured reports deprecation hint."""
        cfg = _make_config()
        with patch.dict("os.environ", {"OPENAI_VISION_MODEL": "openai/gpt-4o"}, clear=False):
            issues = cfg.validate_structured()
        deprec = [i for i in issues if i.field == "OPENAI_VISION_MODEL"]
        assert deprec, "Expected deprecation hint when OPENAI_VISION_MODEL is set"
        assert deprec[0].severity == "info"
        assert "VISION_MODEL" in deprec[0].message

    def test_no_deprecation_when_openai_vision_model_not_in_env(self):
        """When OPENAI_VISION_MODEL is not in env, no deprecation hint."""
        cfg = _make_config()
        # Remove the variable from the real process environment for the
        # duration of the call. Mocking only ``os.getenv`` would not hide it
        # from ``os.environ.get(...)`` or ``"..." in os.environ``. patch.dict
        # covers every lookup style and restores the environment on exit.
        with patch.dict("os.environ") as env:
            env.pop("OPENAI_VISION_MODEL", None)
            issues = cfg.validate_structured()
        deprec = [i for i in issues if i.field == "OPENAI_VISION_MODEL"]
        assert not deprec, "Should not report deprecation when OPENAI_VISION_MODEL is unset"
# ---------------------------------------------------------------------------
# Vision key validation
# ---------------------------------------------------------------------------
class TestVisionKeyValidation:
    """VISION_MODEL validation against the available provider API keys."""

    @staticmethod
    def _vision_model_issues(cfg):
        """Run validation and keep only the issues reported against VISION_MODEL."""
        return [issue for issue in cfg.validate_structured() if issue.field == "VISION_MODEL"]

    def test_vision_model_set_no_key_is_warning(self):
        found = self._vision_model_issues(
            _make_config(
                vision_model="gemini/gemini-2.0-flash",
                gemini_api_keys=[],
                anthropic_api_keys=[],
                openai_api_keys=[],
                deepseek_api_keys=[],
            )
        )
        assert found
        assert found[0].severity == "warning"

    def test_vision_model_set_with_key_no_warning(self):
        found = self._vision_model_issues(
            _make_config(
                vision_model="gemini/gemini-2.0-flash",
                gemini_api_keys=["sk-gemini-testkey-1234"],
            )
        )
        assert not any(issue.severity == "warning" for issue in found)

    def test_vision_model_set_with_short_key_still_warns(self):
        """Short keys (len < 8) are filtered at runtime; validation should warn."""
        found = self._vision_model_issues(
            _make_config(
                vision_model="gemini/gemini-2.0-flash",
                gemini_api_keys=["x"],
                anthropic_api_keys=[],
                openai_api_keys=[],
                deepseek_api_keys=[],
            )
        )
        assert found
        assert found[0].severity == "warning"

    def test_primary_provider_key_sufficient_even_if_not_in_priority(self):
        """Primary model's provider key is checked even when absent from VISION_PROVIDER_PRIORITY."""
        found = self._vision_model_issues(
            _make_config(
                vision_model="openai/gpt-4o",
                vision_provider_priority="gemini,anthropic",  # openai left out of priority
                openai_api_keys=["sk-openai-validkey-xyz"],
                gemini_api_keys=[],
                anthropic_api_keys=[],
                deepseek_api_keys=[],
            )
        )
        # The primary model's provider (openai) has a valid key, so no warning.
        assert not any(issue.severity == "warning" for issue in found)

    def test_no_vision_model_no_warning(self):
        """When VISION_MODEL is not set, no Vision key warning is raised."""
        found = self._vision_model_issues(_make_config(vision_model="", gemini_api_keys=[]))
        assert not found
# ---------------------------------------------------------------------------
# validate() backward compatibility
# ---------------------------------------------------------------------------
class TestValidateBackwardCompat:
    """``validate()`` still returns plain strings mirroring ``validate_structured()``."""

    def test_returns_list_of_str(self):
        result = _make_config().validate()
        assert isinstance(result, list)
        non_strings = [item for item in result if not isinstance(item, str)]
        assert not non_strings

    def test_empty_llm_model_list_message_in_validate(self):
        messages = _make_config(llm_model_list=[]).validate()
        assert any("LLM" in text for text in messages)

    def test_messages_match_validate_structured(self):
        """validate() strings must be the message field of each ConfigIssue."""
        cfg = _make_config(llm_model_list=[], stock_list=[])
        expected = [issue.message for issue in cfg.validate_structured()]
        assert cfg.validate() == expected