Skip to content

Commit aa8134f

Browse files
add timeout to onyx guardrail (BerriAI#19731)
* add timeout to onyx guardrail * add tests
1 parent 87acdef commit aa8134f

File tree

4 files changed

+313
-4
lines changed
  • docs/my-website/docs/proxy/guardrails
  • litellm
    • proxy/guardrails/guardrail_hooks/onyx
    • types/proxy/guardrails/guardrail_hooks
  • tests/test_litellm/proxy/guardrails/guardrail_hooks

4 files changed

+313
-4
lines changed

docs/my-website/docs/proxy/guardrails/onyx_security.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ guardrails:
128128
mode: ["pre_call", "post_call", "during_call"] # Run at multiple stages
129129
api_key: os.environ/ONYX_API_KEY
130130
api_base: os.environ/ONYX_API_BASE
131+
timeout: 10.0 # Optional, defaults to 10 seconds
131132
```
132133

133134
### Required Parameters
@@ -137,6 +138,7 @@ guardrails:
137138
### Optional Parameters
138139

139140
- **`api_base`**: Onyx API base URL (defaults to `https://ai-guard.onyx.security`)
141+
- **`timeout`**: Request timeout in seconds (defaults to `10.0`)
140142

141143
## Environment Variables
142144

@@ -145,4 +147,5 @@ You can set these environment variables instead of hardcoding values in your con
145147
```shell
146148
export ONYX_API_KEY="your-api-key-here"
147149
export ONYX_API_BASE="https://ai-guard.onyx.security" # Optional
150+
export ONYX_TIMEOUT=10 # Optional, timeout in seconds
148151
```

litellm/proxy/guardrails/guardrail_hooks/onyx/onyx.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import uuid
99
from typing import TYPE_CHECKING, Any, Literal, Optional, Type
1010

11+
import httpx
1112
from fastapi import HTTPException
1213

1314
from litellm._logging import verbose_proxy_logger
@@ -25,10 +26,12 @@
2526

2627
class OnyxGuardrail(CustomGuardrail):
2728
def __init__(
28-
self, api_base: Optional[str] = None, api_key: Optional[str] = None, **kwargs
29+
self, api_base: Optional[str] = None, api_key: Optional[str] = None, timeout: Optional[float] = 10.0, **kwargs
2930
):
31+
timeout = timeout or int(os.getenv("ONYX_TIMEOUT", 10.0))
3032
self.async_handler = get_async_httpx_client(
31-
llm_provider=httpxSpecialProvider.GuardrailCallback
33+
llm_provider=httpxSpecialProvider.GuardrailCallback,
34+
params={"timeout": httpx.Timeout(timeout=timeout, connect=5.0)},
3235
)
3336
self.api_base = api_base or os.getenv(
3437
"ONYX_API_BASE",

litellm/types/proxy/guardrails/guardrail_hooks/onyx.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ class OnyxGuardrailConfigModel(GuardrailConfigModel):
1616
description="The API key for the Onyx Guard server. If not provided, the `ONYX_API_KEY` environment variable is checked.",
1717
)
1818

19+
timeout: Optional[float] = Field(
20+
default=None,
21+
description="The timeout for the Onyx Guard server in seconds. If not provided, the `ONYX_TIMEOUT` environment variable is checked.",
22+
)
23+
1924
@staticmethod
2025
def ui_friendly_name() -> str:
2126
return "Onyx Guardrail"

tests/test_litellm/proxy/guardrails/guardrail_hooks/test_onyx.py

Lines changed: 300 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import uuid
44
from unittest.mock import AsyncMock, MagicMock, patch
55

6+
import httpx
67
import pytest
78
from fastapi import HTTPException
89
from httpx import Request, Response
@@ -47,20 +48,129 @@ def test_onyx_guard_config():
4748
del os.environ["ONYX_API_KEY"]
4849

4950

51+
def test_onyx_guard_with_custom_timeout_from_kwargs():
52+
"""Test Onyx guard instantiation with custom timeout passed via kwargs."""
53+
# Set environment variables for testing
54+
os.environ["ONYX_API_BASE"] = "https://test.onyx.security"
55+
os.environ["ONYX_API_KEY"] = "test-api-key"
56+
57+
with patch(
58+
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
59+
) as mock_get_client:
60+
mock_get_client.return_value = MagicMock()
61+
62+
# Simulate how guardrail is instantiated from config with timeout
63+
guardrail = OnyxGuardrail(
64+
guardrail_name="onyx-guard-custom-timeout",
65+
event_hook="pre_call",
66+
default_on=True,
67+
timeout=45.0,
68+
)
69+
70+
# Verify the client was initialized with custom timeout
71+
mock_get_client.assert_called()
72+
call_kwargs = mock_get_client.call_args.kwargs
73+
timeout_param = call_kwargs["params"]["timeout"]
74+
assert timeout_param.read == 45.0
75+
assert timeout_param.connect == 5.0
76+
77+
# Clean up
78+
if "ONYX_API_BASE" in os.environ:
79+
del os.environ["ONYX_API_BASE"]
80+
if "ONYX_API_KEY" in os.environ:
81+
del os.environ["ONYX_API_KEY"]
82+
83+
84+
def test_onyx_guard_with_timeout_none_uses_env_var():
85+
"""Test Onyx guard with timeout=None uses ONYX_TIMEOUT env var.
86+
87+
When timeout=None is passed (as it would be from config model with default None),
88+
the ONYX_TIMEOUT environment variable should be used.
89+
"""
90+
# Set environment variables for testing
91+
os.environ["ONYX_API_BASE"] = "https://test.onyx.security"
92+
os.environ["ONYX_API_KEY"] = "test-api-key"
93+
os.environ["ONYX_TIMEOUT"] = "60"
94+
95+
with patch(
96+
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
97+
) as mock_get_client:
98+
mock_get_client.return_value = MagicMock()
99+
100+
# Pass timeout=None to simulate config model behavior
101+
guardrail = OnyxGuardrail(
102+
guardrail_name="onyx-guard-env-timeout",
103+
event_hook="pre_call",
104+
default_on=True,
105+
timeout=None, # This triggers env var lookup
106+
)
107+
108+
# Verify the client was initialized with timeout from env var
109+
mock_get_client.assert_called()
110+
call_kwargs = mock_get_client.call_args.kwargs
111+
timeout_param = call_kwargs["params"]["timeout"]
112+
assert timeout_param.read == 60.0
113+
assert timeout_param.connect == 5.0
114+
115+
# Clean up
116+
if "ONYX_API_BASE" in os.environ:
117+
del os.environ["ONYX_API_BASE"]
118+
if "ONYX_API_KEY" in os.environ:
119+
del os.environ["ONYX_API_KEY"]
120+
if "ONYX_TIMEOUT" in os.environ:
121+
del os.environ["ONYX_TIMEOUT"]
122+
123+
124+
def test_onyx_guard_with_timeout_none_defaults_to_10():
125+
"""Test Onyx guard with timeout=None and no env var defaults to 10 seconds."""
126+
# Set environment variables for testing
127+
os.environ["ONYX_API_BASE"] = "https://test.onyx.security"
128+
os.environ["ONYX_API_KEY"] = "test-api-key"
129+
# Ensure ONYX_TIMEOUT is not set
130+
if "ONYX_TIMEOUT" in os.environ:
131+
del os.environ["ONYX_TIMEOUT"]
132+
133+
with patch(
134+
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
135+
) as mock_get_client:
136+
mock_get_client.return_value = MagicMock()
137+
138+
# Pass timeout=None with no env var - should default to 10.0
139+
guardrail = OnyxGuardrail(
140+
guardrail_name="onyx-guard-default-timeout",
141+
event_hook="pre_call",
142+
default_on=True,
143+
timeout=None,
144+
)
145+
146+
# Verify the client was initialized with default timeout of 10.0
147+
mock_get_client.assert_called()
148+
call_kwargs = mock_get_client.call_args.kwargs
149+
timeout_param = call_kwargs["params"]["timeout"]
150+
assert timeout_param.read == 10.0
151+
assert timeout_param.connect == 5.0
152+
153+
# Clean up
154+
if "ONYX_API_BASE" in os.environ:
155+
del os.environ["ONYX_API_BASE"]
156+
if "ONYX_API_KEY" in os.environ:
157+
del os.environ["ONYX_API_KEY"]
158+
159+
50160
class TestOnyxGuardrail:
51161
"""Test suite for Onyx Security Guardrail integration."""
52162

53163
def setup_method(self):
54164
"""Setup test environment."""
55165
# Clean up any existing environment variables
56-
for key in ["ONYX_API_BASE", "ONYX_API_KEY"]:
166+
for key in ["ONYX_API_BASE", "ONYX_API_KEY", "ONYX_TIMEOUT"]:
57167
if key in os.environ:
58168
del os.environ[key]
59169

60170
def teardown_method(self):
61171
"""Clean up test environment."""
62172
# Clean up any environment variables set during tests
63-
for key in ["ONYX_API_BASE", "ONYX_API_KEY"]:
173+
for key in ["ONYX_API_BASE", "ONYX_API_KEY", "ONYX_TIMEOUT"]:
64174
if key in os.environ:
65175
del os.environ[key]
66176

@@ -103,6 +213,95 @@ def test_initialization_fails_when_api_key_missing(self):
103213
):
104214
OnyxGuardrail(guardrail_name="test-guard", event_hook="pre_call")
105215

216+
def test_initialization_with_default_timeout(self):
217+
"""Test that default timeout is 10.0 seconds."""
218+
os.environ["ONYX_API_KEY"] = "test-api-key"
219+
220+
with patch(
221+
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
222+
) as mock_get_client:
223+
mock_get_client.return_value = MagicMock()
224+
guardrail = OnyxGuardrail(
225+
guardrail_name="test-guard", event_hook="pre_call", default_on=True
226+
)
227+
228+
# Verify the client was initialized with correct timeout
229+
mock_get_client.assert_called_once()
230+
call_kwargs = mock_get_client.call_args.kwargs
231+
timeout_param = call_kwargs["params"]["timeout"]
232+
assert timeout_param.read == 10.0
233+
assert timeout_param.connect == 5.0
234+
235+
def test_initialization_with_custom_timeout_parameter(self):
236+
"""Test initialization with custom timeout parameter."""
237+
os.environ["ONYX_API_KEY"] = "test-api-key"
238+
239+
with patch(
240+
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
241+
) as mock_get_client:
242+
mock_get_client.return_value = MagicMock()
243+
guardrail = OnyxGuardrail(
244+
guardrail_name="test-guard",
245+
event_hook="pre_call",
246+
default_on=True,
247+
timeout=30.0,
248+
)
249+
250+
# Verify the client was initialized with custom timeout
251+
mock_get_client.assert_called_once()
252+
call_kwargs = mock_get_client.call_args.kwargs
253+
timeout_param = call_kwargs["params"]["timeout"]
254+
assert timeout_param.read == 30.0
255+
assert timeout_param.connect == 5.0
256+
257+
def test_initialization_with_timeout_from_env_var(self):
258+
"""Test initialization with timeout from ONYX_TIMEOUT environment variable.
259+
260+
Note: The env var is only used when timeout=None is explicitly passed,
261+
since the default parameter value is 10.0 (not None).
262+
"""
263+
os.environ["ONYX_API_KEY"] = "test-api-key"
264+
os.environ["ONYX_TIMEOUT"] = "25"
265+
266+
with patch(
267+
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
268+
) as mock_get_client:
269+
mock_get_client.return_value = MagicMock()
270+
# Must pass timeout=None explicitly to trigger env var lookup
271+
guardrail = OnyxGuardrail(
272+
guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=None
273+
)
274+
275+
# Verify the client was initialized with timeout from env var
276+
mock_get_client.assert_called_once()
277+
call_kwargs = mock_get_client.call_args.kwargs
278+
timeout_param = call_kwargs["params"]["timeout"]
279+
assert timeout_param.read == 25.0
280+
assert timeout_param.connect == 5.0
281+
282+
def test_initialization_timeout_parameter_overrides_env_var(self):
283+
"""Test that timeout parameter overrides ONYX_TIMEOUT environment variable."""
284+
os.environ["ONYX_API_KEY"] = "test-api-key"
285+
os.environ["ONYX_TIMEOUT"] = "25"
286+
287+
with patch(
288+
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
289+
) as mock_get_client:
290+
mock_get_client.return_value = MagicMock()
291+
guardrail = OnyxGuardrail(
292+
guardrail_name="test-guard",
293+
event_hook="pre_call",
294+
default_on=True,
295+
timeout=15.0,
296+
)
297+
298+
# Verify the client was initialized with parameter timeout (not env var)
299+
mock_get_client.assert_called_once()
300+
call_kwargs = mock_get_client.call_args.kwargs
301+
timeout_param = call_kwargs["params"]["timeout"]
302+
assert timeout_param.read == 15.0
303+
assert timeout_param.connect == 5.0
304+
106305
@pytest.mark.asyncio
107306
async def test_apply_guardrail_request_no_violations(self):
108307
"""Test apply_guardrail for request with no violations detected."""
@@ -388,6 +587,105 @@ async def test_apply_guardrail_api_error_handling(self):
388587

389588
assert result == inputs
390589

590+
@pytest.mark.asyncio
591+
async def test_apply_guardrail_timeout_error_handling(self):
592+
"""Test handling of timeout errors in apply_guardrail (graceful degradation)."""
593+
# Set required API key
594+
os.environ["ONYX_API_KEY"] = "test-api-key"
595+
596+
guardrail = OnyxGuardrail(
597+
guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=1.0
598+
)
599+
600+
inputs = GenericGuardrailAPIInputs()
601+
602+
request_data = {
603+
"proxy_server_request": {
604+
"messages": [{"role": "user", "content": "Test message"}],
605+
"model": "gpt-3.5-turbo",
606+
}
607+
}
608+
609+
# Test httpx timeout error
610+
with patch.object(
611+
guardrail.async_handler, "post", side_effect=httpx.TimeoutException("Request timed out")
612+
):
613+
# Should return original inputs on timeout (graceful degradation)
614+
result = await guardrail.apply_guardrail(
615+
inputs=inputs,
616+
request_data=request_data,
617+
input_type="request",
618+
logging_obj=None,
619+
)
620+
621+
assert result == inputs
622+
623+
@pytest.mark.asyncio
624+
async def test_apply_guardrail_read_timeout_error_handling(self):
625+
"""Test handling of read timeout errors in apply_guardrail."""
626+
# Set required API key
627+
os.environ["ONYX_API_KEY"] = "test-api-key"
628+
629+
guardrail = OnyxGuardrail(
630+
guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=5.0
631+
)
632+
633+
inputs = GenericGuardrailAPIInputs()
634+
635+
request_data = {
636+
"proxy_server_request": {
637+
"messages": [{"role": "user", "content": "Test message"}],
638+
"model": "gpt-3.5-turbo",
639+
}
640+
}
641+
642+
# Test httpx ReadTimeout error
643+
with patch.object(
644+
guardrail.async_handler, "post", side_effect=httpx.ReadTimeout("Read timed out")
645+
):
646+
# Should return original inputs on timeout (graceful degradation)
647+
result = await guardrail.apply_guardrail(
648+
inputs=inputs,
649+
request_data=request_data,
650+
input_type="request",
651+
logging_obj=None,
652+
)
653+
654+
assert result == inputs
655+
656+
@pytest.mark.asyncio
657+
async def test_apply_guardrail_connect_timeout_error_handling(self):
658+
"""Test handling of connect timeout errors in apply_guardrail."""
659+
# Set required API key
660+
os.environ["ONYX_API_KEY"] = "test-api-key"
661+
662+
guardrail = OnyxGuardrail(
663+
guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=5.0
664+
)
665+
666+
inputs = GenericGuardrailAPIInputs()
667+
668+
request_data = {
669+
"proxy_server_request": {
670+
"messages": [{"role": "user", "content": "Test message"}],
671+
"model": "gpt-3.5-turbo",
672+
}
673+
}
674+
675+
# Test httpx ConnectTimeout error
676+
with patch.object(
677+
guardrail.async_handler, "post", side_effect=httpx.ConnectTimeout("Connect timed out")
678+
):
679+
# Should return original inputs on timeout (graceful degradation)
680+
result = await guardrail.apply_guardrail(
681+
inputs=inputs,
682+
request_data=request_data,
683+
input_type="request",
684+
logging_obj=None,
685+
)
686+
687+
assert result == inputs
688+
391689
@pytest.mark.asyncio
392690
async def test_apply_guardrail_no_logging_obj(self):
393691
"""Test apply_guardrail without logging object (uses UUID)."""

0 commit comments

Comments
 (0)