-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathconfig.py
More file actions
201 lines (165 loc) · 7.82 KB
/
config.py
File metadata and controls
201 lines (165 loc) · 7.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
LLMEx Testing Tool - Configuration Management
===========================================
Centralized configuration loading and management.
"""
import json
import logging
import os
from enum import Enum
from typing import Dict, List, Optional, Union, Any
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class VulnerabilityCategory(Enum):
    """Vulnerability categories covered by the tool.

    The set is based on the OWASP Top 10 for LLMs, extended with two extra
    entries (malicious content generation and jailbreak attempts). Each
    member's value is a display string.
    """

    PROMPT_INJECTION = "Prompt Injection"
    DATA_LEAKAGE = "Data Leakage"
    INSECURE_OUTPUT = "Insecure Output Handling"
    TRAINING_DATA_POISONING = "Training Data Poisoning"
    MODEL_DOS = "Model Denial of Service"
    SUPPLY_CHAIN = "Supply Chain Vulnerabilities"
    MALICIOUS_PLUGIN = "Malicious Plugin or Integration"
    EXCESSIVE_AGENCY = "Excessive Agency"
    OVERRELIANCE = "Overreliance"
    MODEL_THEFT = "Model Theft"

    # Additional categories beyond the ten listed above.
    MALICIOUS_CONTENT = "Malicious or Illicit Content Generation"
    # NOTE(review): unlike its siblings, this value reads as a description
    # rather than a short label -- confirm that is intended before changing it.
    JAILBREAK_ATTEMPTS = "Leverage tactics inspired by L1B3RT4S Project."
class SeverityLevel(Enum):
    """Severity ratings that can be attached to an identified vulnerability.

    Members are declared from least to most severe; each value is the
    capitalized display string for that level.
    """

    INFO = "Info"
    LOW = "Low"
    MEDIUM = "Medium"
    HIGH = "High"
    CRITICAL = "Critical"
class ConfigManager:
    """Centralized configuration manager.

    Reads JSON files from ``config_dir``:

    * ``config.json`` -- the main configuration (``main_config``);
    * ``<category>.json`` -- one optional file per ``VulnerabilityCategory``
      member, named after the lower-cased member name (``category_configs``);
    * ``patterns/*.json`` -- pattern files keyed by file stem
      (``pattern_configs``);
    * ``refusal_patterns.json`` -- loaded lazily by the refusal-pattern
      accessors and cached on the instance.

    Constructing an instance performs no file I/O.
    """

    # Fallbacks for the "reporting" section of config.json. They only fill
    # in keys the user did not set.
    _REPORTING_DEFAULTS: Dict[str, Any] = {
        "enable_html": True,
        "enable_csv": True,
        "enable_json": True,
        "enable_pdf": True,
        "html_template": "templates/report.html",
    }

    # Fallback for reporting["pdf_export"] when that key is absent.
    _PDF_EXPORT_DEFAULTS: Dict[str, Any] = {
        "page_size": "A4",
        "margin": "0.75in",
        "font_size": 10,
        "include_charts": True,
    }

    def __init__(self, config_dir: str = "config") -> None:
        """Store the configuration directory and start with empty caches.

        Args:
            config_dir: Directory that holds the JSON configuration files.
        """
        self.config_dir: str = config_dir
        self.category_configs: Dict[str, Dict] = {}
        self.pattern_configs: Dict[str, Dict] = {}
        self.main_config: Dict[str, Any] = {}
        # None means "not loaded yet"; see get_refusal_patterns().
        self._refusal_patterns: Optional[Dict[str, List[str]]] = None

    def load_all_configs(self) -> bool:
        """Load all configuration files from the configured directory.

        Loads the main config, fills in missing reporting defaults, then
        loads the per-category and pattern files.

        Returns:
            True when loading completed, False when any exception was raised
            along the way (the error is logged, not propagated).
        """
        try:
            # Load main config
            main_config_path = os.path.join(self.config_dir, "config.json")
            self.main_config = self._load_json_file(main_config_path)
            self._apply_reporting_defaults()
            self._load_category_configs()
            self._load_pattern_configs()
            return True
        except Exception as e:
            # Deliberate catch-all: callers get a boolean, never an exception.
            logger.error(f"Error loading configurations: {e}")
            return False

    def _apply_reporting_defaults(self) -> None:
        """Fill in missing ``reporting`` settings without touching set ones."""
        reporting = self.main_config.setdefault("reporting", {})
        # setdefault (not update) so values the user put in config.json,
        # e.g. "enable_pdf": false, are not overwritten by the defaults.
        for key, default in self._REPORTING_DEFAULTS.items():
            reporting.setdefault(key, default)
        if "pdf_export" not in reporting:
            # Copy so instances never share (and mutate) the class-level dict.
            reporting["pdf_export"] = dict(self._PDF_EXPORT_DEFAULTS)

    def _load_category_configs(self) -> None:
        """Load ``<category>.json`` for every category whose file exists."""
        for category in VulnerabilityCategory:
            filename = f"{category.name.lower()}.json"
            file_path = os.path.join(self.config_dir, filename)
            if os.path.exists(file_path):
                self.category_configs[category.name] = self._load_json_file(file_path)
                logger.info(f"Loaded configuration for {category.name}")
            else:
                logger.warning(f"Configuration file for {category.name} not found: {file_path}")

    def _load_pattern_configs(self) -> None:
        """Load every ``*.json`` file in ``patterns/``, keyed by file stem."""
        pattern_dir = os.path.join(self.config_dir, "patterns")
        if not os.path.exists(pattern_dir):
            return
        for pattern_file in os.listdir(pattern_dir):
            if not pattern_file.endswith(".json"):
                continue
            pattern_name = os.path.splitext(pattern_file)[0]
            path_full = os.path.join(pattern_dir, pattern_file)
            self.pattern_configs[pattern_name] = self._load_json_file(path_full)
            logger.info(f"Loaded pattern configuration: {pattern_name}")

    def _load_json_file(self, file_path: str) -> Dict:
        """Load and parse a JSON file.

        Args:
            file_path: Path of the file to read (decoded as UTF-8).

        Returns:
            The parsed JSON content, or an empty dict when the file is
            missing or is not valid JSON (both cases are logged). Other
            errors, such as permission problems, propagate to the caller.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            logger.warning(f"File not found: {file_path}")
            return {}
        except json.JSONDecodeError:
            logger.error(f"Invalid JSON in file: {file_path}")
            return {}

    def get_tests_for_category(self, category: Union[str, VulnerabilityCategory]) -> List[Dict]:
        """Retrieve test configurations for a specific category.

        Args:
            category: A ``VulnerabilityCategory`` member or its member name
                (the key used in ``category_configs``).

        Returns:
            The category's ``"tests"`` list, or an empty list when the
            category has no loaded config or no ``"tests"`` entry.
        """
        if isinstance(category, VulnerabilityCategory):
            category = category.name
        config = self.category_configs.get(category, {})
        return config.get("tests", [])

    def get_detection_patterns(self, category: Union[str, VulnerabilityCategory],
                               pattern_type: Optional[str] = None) -> Dict:
        """Retrieve detection patterns for a given category.

        Args:
            category: A ``VulnerabilityCategory`` member or its member name.
            pattern_type: Optional key inside ``"detection_patterns"``.

        Returns:
            A single-entry dict ``{pattern_type: ...}`` when ``pattern_type``
            is given and present. Otherwise -- including when
            ``pattern_type`` is given but unknown -- the category's whole
            ``"detection_patterns"`` dict (empty when none is configured).
        """
        if isinstance(category, VulnerabilityCategory):
            category = category.name
        config = self.category_configs.get(category, {})
        patterns = config.get("detection_patterns", {})
        if pattern_type and pattern_type in patterns:
            return {pattern_type: patterns[pattern_type]}
        return patterns

    def get_all_tests(self) -> List[Dict]:
        """Retrieve all tests from all category configurations.

        Each test dict is tagged in place with a ``"category"`` key holding
        the ``category_configs`` key it came from, so the stored configs are
        modified as a side effect.

        Returns:
            One flat list containing every category's tests.
        """
        all_tests: List[Dict] = []
        for category, config in self.category_configs.items():
            tests = config.get("tests", [])
            for test in tests:
                test["category"] = category
            all_tests.extend(tests)
        return all_tests

    def get_test_by_id(self, test_id: str) -> Optional[Dict]:
        """Retrieve a specific test by its ID.

        Args:
            test_id: Value to match against each test's ``"id"`` entry.

        Returns:
            The first matching test dict, or None when nothing matches.
        """
        for test in self.get_all_tests():
            if test.get("id") == test_id:
                return test
        return None

    @staticmethod
    def _empty_refusal_patterns() -> Dict[str, List[str]]:
        """Return a fresh fallback refusal-pattern mapping (all lists empty)."""
        return {"basic_refusal": [], "explanation_indicators": [], "alternative_indicators": []}

    def load_refusal_patterns(self) -> Dict[str, List[str]]:
        """Load refusal patterns from configuration file.

        Reads ``refusal_patterns.json`` from ``config_dir`` on every call;
        use ``get_refusal_patterns`` for the cached variant.

        Returns:
            The file's ``"refusal_patterns"`` entry (an empty dict when that
            key is absent), or the empty fallback mapping when the file does
            not exist or loading raises.
        """
        refusal_file_path = os.path.join(self.config_dir, "refusal_patterns.json")
        if not os.path.exists(refusal_file_path):
            logger.warning(f"Refusal patterns file not found: {refusal_file_path}")
            return self._empty_refusal_patterns()
        try:
            refusal_config = self._load_json_file(refusal_file_path)
            logger.debug(f"Loaded refusal patterns from {refusal_file_path}")
            return refusal_config.get("refusal_patterns", {})
        except Exception as e:
            logger.error(f"Error loading refusal patterns: {e}")
            return self._empty_refusal_patterns()

    def get_refusal_patterns(self) -> Dict[str, List[str]]:
        """Get refusal patterns from cache or load from file.

        Returns:
            The cached mapping; the file is read only on the first call.
        """
        if self._refusal_patterns is None:
            self._refusal_patterns = self.load_refusal_patterns()
        return self._refusal_patterns

    def flatten_refusal_patterns(self) -> List[str]:
        """Flatten all refusal patterns into a single list for easier matching.

        Walks the (possibly nested) pattern mapping and concatenates every
        list value it finds; values that are neither lists nor dicts are
        ignored.

        Returns:
            All patterns in mapping order.
        """
        patterns = self.get_refusal_patterns()
        # Accept a mapping that is still wrapped in a "refusal_patterns" key.
        if "refusal_patterns" in patterns:
            patterns = patterns["refusal_patterns"]

        all_patterns: List[str] = []

        def extract_patterns(pattern_dict):
            for value in pattern_dict.values():
                if isinstance(value, list):
                    all_patterns.extend(value)
                elif isinstance(value, dict):
                    extract_patterns(value)

        extract_patterns(patterns)
        return all_patterns
# Global configuration instance, created at import time with the default
# config directory ("config"). The constructor only initializes attributes;
# no files are read until load_all_configs() or one of the refusal-pattern
# accessors is called.
config_manager = ConfigManager()