Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 41 additions & 40 deletions tools/security_audit/config.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,41 @@
from typing import Dict, List, Optional

from pydantic import BaseModel


class LLMConfig(BaseModel):
    """LLM service config for LLM-as-Judge checkers.

    Every field has a default, so the model can be built with no arguments;
    the string fields then hold empty strings rather than ``None``.
    """
    # Model identifier sent to the LLM service (empty string when unset).
    model: str = ""
    # Credential for the LLM service (empty string when unset).
    # NOTE(review): stored as a plain `str`, so it is presumably included in
    # reprs/dumps of this model -- confirm that is acceptable.
    api_key: str = ""
    # Endpoint of the LLM service (empty string when unset).
    api_url: str = ""
    # Sampling temperature; defaults to 0.0.
    temperature: float = 0.0
    # Token limit passed to the service; defaults to 2048.
    max_tokens: int = 2048


class CheckerConfig(BaseModel):
    """Config for a single checker.

    ``name`` is the only required field.
    """
    # Checker identifier, e.g. "PIIRule" or "JailbreakClassifier".
    name: str
    # Whether the checker is switched on; defaults to True.
    enabled: bool = True
    # Free-form, checker-specific parameters (untyped mapping). Pydantic
    # copies field defaults per instance, so the `{}` literal is not shared
    # between instances.
    params: Dict = {}


class ExecutorConfig(BaseModel):
    """Executor engine config.

    All fields are optional and fall back to the defaults below.
    """
    # Worker count for the executor; defaults to 4.
    max_workers: int = 4
    # Batch size; defaults to 100.
    batch_size: int = 100
    # First index to process; defaults to 0.
    start_index: int = 0
    # Last index to process; the sentinel -1 means "process all".
    # NOTE(review): whether a non-negative end_index is inclusive or
    # exclusive is not visible here -- confirm against the executor.
    end_index: int = -1 # -1 means process all


class AuditConfig(BaseModel):
    """Top-level audit config.

    Aggregates the executor settings, the optional LLM service settings, the
    model-path mapping and the checker selection for one audit task. Every
    field has a default.
    """
    # Task name; defaults to "security_audit".
    task_name: str = "security_audit"
    # Output location; defaults to "outputs/".
    output_path: str = "outputs/"
    # Log level name; defaults to "INFO".
    log_level: str = "INFO"

    # Executor settings; defaults to an ExecutorConfig built from its defaults.
    executor: ExecutorConfig = ExecutorConfig()
    # LLM service settings (an LLMConfig, not just a model name); None when
    # unset. Needed by the LLM-based checkers.
    llm: Optional[LLMConfig] = None # LLM service config (needed by LLM-based checkers)
    # Mapping of str keys to a model name or path; needed by the model-based
    # checkers.
    models: Dict[str, str] = {} # model name or path (needed by model-based checkers)
    # Per-checker configuration entries.
    checkers: List[CheckerConfig] = []
    # Checker tags; how they are used for selection is not visible here.
    checker_tags: List[str] = []
from typing import Dict, List, Optional

from pydantic import BaseModel


class LLMConfig(BaseModel):
    """LLM service config for LLM-as-Judge checkers.

    Every field has a default, so the model can be built with no arguments;
    the string fields then hold empty strings rather than ``None``.
    """
    # Model identifier sent to the LLM service (empty string when unset).
    model: str = ""
    # Credential for the LLM service (empty string when unset).
    # NOTE(review): stored as a plain `str`, so it is presumably included in
    # reprs/dumps of this model -- confirm that is acceptable.
    api_key: str = ""
    # Endpoint of the LLM service (empty string when unset).
    api_url: str = ""
    # Sampling temperature; defaults to 0.0.
    temperature: float = 0.0
    # Token limit passed to the service; defaults to 2048.
    max_tokens: int = 2048


class CheckerConfig(BaseModel):
    """Config for a single checker.

    ``name`` is the only required field.
    """
    # Checker identifier, e.g. "PIIRule" or "JailbreakClassifier".
    name: str
    # Whether the checker is switched on; defaults to True.
    enabled: bool = True
    # Free-form, checker-specific parameters (untyped mapping). Pydantic
    # copies field defaults per instance, so the `{}` literal is not shared
    # between instances.
    params: Dict = {}
    # Where this selection came from. Expected values: "explicit", "config"
    # or "auto"; defaults to "config".
    # NOTE(review): declared as a plain `str`, so other values are not
    # rejected here -- consider a Literal/Enum if the set is closed.
    selection_source: str = "config" # explicit | config | auto


class ExecutorConfig(BaseModel):
    """Executor engine config.

    All fields are optional and fall back to the defaults below.
    """
    # Worker count for the executor; defaults to 4.
    max_workers: int = 4
    # Batch size; defaults to 100.
    batch_size: int = 100
    # First index to process; defaults to 0.
    start_index: int = 0
    # Last index to process; the sentinel -1 means "process all".
    # NOTE(review): whether a non-negative end_index is inclusive or
    # exclusive is not visible here -- confirm against the executor.
    end_index: int = -1 # -1 means process all


class AuditConfig(BaseModel):
    """Top-level audit config.

    Aggregates the executor settings, the optional LLM service settings, the
    model-path mapping and the checker selection for one audit task. Every
    field has a default.
    """
    # Task name; defaults to "security_audit".
    task_name: str = "security_audit"
    # Output location; defaults to "outputs/".
    output_path: str = "outputs/"
    # Log level name; defaults to "INFO".
    log_level: str = "INFO"

    # Executor settings; defaults to an ExecutorConfig built from its defaults.
    executor: ExecutorConfig = ExecutorConfig()
    # LLM service settings (an LLMConfig, not just a model name); None when
    # unset. Needed by the LLM-based checkers.
    llm: Optional[LLMConfig] = None # LLM service config (needed by LLM-based checkers)
    # Mapping of str keys to a model name or path; needed by the model-based
    # checkers.
    models: Dict[str, str] = {} # model name or path (needed by model-based checkers)
    # Per-checker configuration entries.
    checkers: List[CheckerConfig] = []
    # Checker tags; how they are used for selection is not visible here.
    checker_tags: List[str] = []
4 changes: 4 additions & 0 deletions tools/security_audit/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ checkers:
# - DPOLabelFlipLLMJudge
# - name: JailbreakLLMJudge
# enabled: false
- name: PIINERDetector
enabled: false
params:
language: en
- name: JailbreakClassifier
enabled: false
params:
Expand Down
160 changes: 116 additions & 44 deletions tools/security_audit/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,91 @@
from .config import CheckerConfig


# Model-based checkers mapped to the key under which a default model path is
# looked up in the tool-default `models` mapping.
_MODEL_PATH_KEYS = {
    "HarmfulContentClassifier": "harmful_content_classifier",
    "JailbreakClassifier": "jailbreak_classifier",
    "PromptInjectionClassifier": "prompt_injection_classifier",
    "BiasClassifier": "bias_classifier",
}

# Checkers grouped by kind. The groups feed both the per-checker minimum
# resource tier table and the default checker pools below.
_RULE_BASED_CHECKERS = [
    "PIIRule",
    "SecretRule",
    "ToxicityKeywordRule",
    "HarmfulKeywordRule",
    "BiasKeywordRule",
]

# NOTE(review): "FactualInconsistancyLLMJudge" is spelled with "-ancy"; it is
# kept byte-for-byte because it presumably has to match the registered
# checker name -- confirm against the checker registry before renaming.
_LLM_JUDGE_CHECKERS = [
    "HarmfulContentLLMJudge",
    "BiasLLMJudge",
    "ToxicityLLMJudge",
    "PIILLMJudge",
    "SycophancyLLMJudge",
    "PromptInjectionLLMJudge",
    "JailbreakLLMJudge",
    "FactualInconsistancyLLMJudge",
    "SelfContradictionLLMJudge",
    "InstructionMismatchLLMJudge",
    "DPOLabelFlipLLMJudge",
]

_STANDARD_MODEL_CHECKERS = [
    "PIINERDetector",
]

_HEAVY_MODEL_CHECKERS = [
    "HarmfulContentClassifier",
    "ToxicityClassifier",
    "BiasClassifier",
    "JailbreakClassifier",
    "PromptInjectionClassifier",
    "GraCeFulBackdoorDefender",
]

# Set view of the heavy checkers for O(1) membership tests.
# NOTE(review): no use of this name is visible in this section -- confirm it
# is referenced elsewhere before keeping it.
_HEAVY_CHECKERS = set(_HEAVY_MODEL_CHECKERS)

# Rank of each resource tier; a higher rank satisfies every lower one.
_RESOURCE_TIER_ORDER = {"light": 0, "standard": 1, "full": 2}

# Minimum resource tier each known checker needs. Checkers missing from this
# table have no tier requirement recorded.
_CHECKER_MIN_RESOURCE_TIERS = {
    **{name: "light" for name in _RULE_BASED_CHECKERS},
    **{name: "standard" for name in _LLM_JUDGE_CHECKERS},
    **{name: "standard" for name in _STANDARD_MODEL_CHECKERS},
    **{name: "full" for name in _HEAVY_MODEL_CHECKERS},
}

# Model-based checkers that need a local model path when running offline.
_LOCAL_MODEL_PATH_CHECKERS = {
    "HarmfulContentClassifier",
    "JailbreakClassifier",
    "PromptInjectionClassifier",
    "BiasClassifier",
}



def validate_selected_checkers(
    *,
    checker_configs: list[CheckerConfig],
    tool_defaults: dict[str, Any],
    runtime_policy: RuntimePolicy,
    context_config: dict[str, Any],
) -> list[PreflightIssue]:
    """Run the preflight checks for every enabled checker.

    For each enabled checker the resource-tier check runs first. The
    network-availability check runs only if the checker is still enabled
    afterwards and the tier check reported no error.

    Args:
        checker_configs: Final checker selection to validate. Entries may be
            disabled in place by the resource-tier check.
        tool_defaults: Tool-level defaults, forwarded unchanged to both checks.
        runtime_policy: Runtime policy, forwarded unchanged to both checks.
        context_config: Context configuration, forwarded to the network check.

    Returns:
        All issues reported by the checks, in checker order.
    """
    issues: list[PreflightIssue] = []
    for checker_config in checker_configs:
        if not checker_config.enabled:
            continue

        resource_issues = validate_checker_resource_tier_availability(
            checker_config=checker_config,
            tool_defaults=tool_defaults,
            runtime_policy=runtime_policy,
        )
        issues.extend(resource_issues)

        # The tier check may have switched the checker off; a checker that is
        # no longer going to run needs no network validation.
        if not checker_config.enabled:
            continue
        # A tier error already blocks this checker, so skip the network check
        # rather than piling a second issue on top of it.
        if any(issue.level == "error" for issue in resource_issues):
            continue

        issues.extend(validate_checker_network_availability(
            checker_config=checker_config,
            tool_defaults=tool_defaults,
            runtime_policy=runtime_policy,
            context_config=context_config,
        ))
    return issues


def validate_checker_network_availability(
*,
checker_config: CheckerConfig,
tool_defaults: dict[str, Any],
runtime_policy: RuntimePolicy,
context_config: dict[str, Any],
) -> list[PreflightIssue]:
Expand All @@ -66,18 +114,17 @@ def validate_checker_network_availability(
),
)]

if name in _MODEL_PATH_KEYS:
model_key = _MODEL_PATH_KEYS[name]
model_path = _resolve_model_path(checker_config, tool_defaults, model_key)
if name in _LOCAL_MODEL_PATH_CHECKERS:
model_path = _resolve_model_path(checker_config)
if not model_path:
return [PreflightIssue(
level="error",
code="offline_checker_missing_model_path",
checker_name=name,
message=(
"Offline model-based checker requires a local model path. "
f"Set tool_defaults.security_audit.models.{model_key} or pass "
"checker params.model_name_or_path."
"Set checker params.model_name_or_path in "
"tools/security_audit/default.yaml."
),
)]
if not _path_exists(model_path):
Expand Down Expand Up @@ -118,27 +165,60 @@ def validate_checker_network_availability(
def validate_checker_resource_tier_availability(
    *,
    checker_config: CheckerConfig,
    tool_defaults: dict[str, Any],
    runtime_policy: RuntimePolicy,
) -> list[PreflightIssue]:
    """Check that the runtime resource tier is high enough for a checker.

    Args:
        checker_config: The checker to validate. Its ``enabled`` flag is set
            to False in place when the checker is filtered out (see Returns).
        tool_defaults: Not read by this function; kept in the signature so
            existing keyword call sites keep working.
        runtime_policy: Supplies ``resource_tier``, the currently deployed tier.

    Returns:
        An empty list when the checker has no recorded tier requirement, when
        the current tier is not a recognised tier, or when the current tier
        satisfies the requirement. Otherwise a single issue:
        a ``warning`` (``checker_filtered_by_resource_tier``) after disabling
        the checker if its ``selection_source`` is anything other than
        ``"explicit"``, or an ``error`` (``checker_resource_tier_too_low``)
        if the checker was selected explicitly.
    """
    name = checker_config.name
    required_tier = _CHECKER_MIN_RESOURCE_TIERS.get(name)
    if required_tier is None:
        return []

    current_tier = runtime_policy.resource_tier
    # Normalise case/whitespace the same way the default-pool resolver does;
    # otherwise a value such as "Light" would not be recognised and the tier
    # requirement would silently not be enforced.
    tier_key = current_tier.strip().lower() if isinstance(current_tier, str) else current_tier
    current_rank = _RESOURCE_TIER_ORDER.get(tier_key)
    required_rank = _RESOURCE_TIER_ORDER[required_tier]
    # Unrecognised tiers are deliberately not treated as a violation.
    if current_rank is None or current_rank >= required_rank:
        return []

    # Configs without the attribute are treated like config-provided ones.
    source = getattr(checker_config, "selection_source", "config")
    if source != "explicit":
        # The checker was not asked for explicitly: drop it and warn instead
        # of failing the whole preflight.
        checker_config.enabled = False
        return [PreflightIssue(
            level="warning",
            code="checker_filtered_by_resource_tier",
            checker_name=name,
            message=(
                f"Checker `{name}` requires deployment.resource_tier >= {required_tier!r}, "
                f"but current resource_tier is {current_tier!r}; "
                f"it was disabled from the {source} checker selection."
            ),
        )]

    return [PreflightIssue(
        level="error",
        code="checker_resource_tier_too_low",
        checker_name=name,
        message=(
            f"Checker `{name}` requires deployment.resource_tier >= {required_tier!r}, "
            f"but current resource_tier is {current_tier!r}."
        ),
    )]



def resolve_default_checkers_for_resource_tier(resource_tier: str) -> list[str]:
    """Return the default checker names for a deployment resource tier.

    The tier is matched case-insensitively with surrounding whitespace
    ignored; an empty or missing value is treated as ``"light"``.

    Args:
        resource_tier: Tier name, expected to be "light", "standard" or "full".

    Returns:
        A new list each call:
        "standard" gives the standard model checkers plus the LLM judges;
        "full" gives those plus the heavy model checkers;
        anything else (including "light") gives the rule-based checkers.
    """
    # TODO: (resource_tier) Add funnel-routing strategy selection on top of
    # these static pools.
    # NOTE(review): the "standard" and "full" pools do not include the
    # rule-based checkers -- confirm that is intended.
    normalized = (resource_tier or "light").strip().lower()
    if normalized == "standard":
        return [
            *_STANDARD_MODEL_CHECKERS,
            *_LLM_JUDGE_CHECKERS,
        ]
    if normalized == "full":
        return [
            *_STANDARD_MODEL_CHECKERS,
            *_LLM_JUDGE_CHECKERS,
            *_HEAVY_MODEL_CHECKERS,
        ]
    # Copy so callers cannot mutate the module-level list.
    return list(_RULE_BASED_CHECKERS)


def _has_local_llm_config(context_config: dict[str, Any]) -> bool:
Expand All @@ -150,17 +230,9 @@ def _has_local_llm_config(context_config: dict[str, Any]) -> bool:
)


def _resolve_model_path(
    checker_config: CheckerConfig,
    tool_defaults: dict[str, Any] | None = None,
    model_key: str | None = None,
) -> str | None:
    """Return the model path configured for a checker, or None.

    The checker's own ``params["model_name_or_path"]`` always wins. The
    fallback to ``tool_defaults["models"][model_key]`` is consulted only when
    both optional arguments are supplied, so the one-argument call form looks
    at the checker params alone.

    Args:
        checker_config: Checker whose ``params`` mapping is inspected.
        tool_defaults: Optional tool-level defaults holding a ``models`` dict.
        model_key: Optional key into ``tool_defaults["models"]``.

    Returns:
        The path as a string, or None when no non-empty value is configured.
    """
    explicit = checker_config.params.get("model_name_or_path")
    if explicit:
        return str(explicit)

    if not tool_defaults or model_key is None:
        return None
    models = tool_defaults.get("models")
    # A malformed (non-dict) `models` entry is treated as "nothing configured".
    if not isinstance(models, dict):
        return None
    value = models.get(model_key)
    return str(value) if value else None


def _path_exists(value: str) -> bool:
Expand Down
Loading