|
3 | 3 | import uuid |
4 | 4 | from unittest.mock import AsyncMock, MagicMock, patch |
5 | 5 |
|
| 6 | +import httpx |
6 | 7 | import pytest |
7 | 8 | from fastapi import HTTPException |
8 | 9 | from httpx import Request, Response |
@@ -47,20 +48,129 @@ def test_onyx_guard_config(): |
47 | 48 | del os.environ["ONYX_API_KEY"] |
48 | 49 |
|
49 | 50 |
|
| 51 | +def test_onyx_guard_with_custom_timeout_from_kwargs(): |
| 52 | + """Test Onyx guard instantiation with custom timeout passed via kwargs.""" |
| 53 | + # Set environment variables for testing |
| 54 | + os.environ["ONYX_API_BASE"] = "https://test.onyx.security" |
| 55 | + os.environ["ONYX_API_KEY"] = "test-api-key" |
| 56 | + |
| 57 | + with patch( |
| 58 | + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" |
| 59 | + ) as mock_get_client: |
| 60 | + mock_get_client.return_value = MagicMock() |
| 61 | + |
| 62 | + # Simulate how guardrail is instantiated from config with timeout |
| 63 | + guardrail = OnyxGuardrail( |
| 64 | + guardrail_name="onyx-guard-custom-timeout", |
| 65 | + event_hook="pre_call", |
| 66 | + default_on=True, |
| 67 | + timeout=45.0, |
| 68 | + ) |
| 69 | + |
| 70 | + # Verify the client was initialized with custom timeout |
| 71 | + mock_get_client.assert_called() |
| 72 | + call_kwargs = mock_get_client.call_args.kwargs |
| 73 | + timeout_param = call_kwargs["params"]["timeout"] |
| 74 | + assert timeout_param.read == 45.0 |
| 75 | + assert timeout_param.connect == 5.0 |
| 76 | + |
| 77 | + # Clean up |
| 78 | + if "ONYX_API_BASE" in os.environ: |
| 79 | + del os.environ["ONYX_API_BASE"] |
| 80 | + if "ONYX_API_KEY" in os.environ: |
| 81 | + del os.environ["ONYX_API_KEY"] |
| 82 | + |
| 83 | + |
| 84 | +def test_onyx_guard_with_timeout_none_uses_env_var(): |
| 85 | + """Test Onyx guard with timeout=None uses ONYX_TIMEOUT env var. |
| 86 | + |
| 87 | + When timeout=None is passed (as it would be from config model with default None), |
| 88 | + the ONYX_TIMEOUT environment variable should be used. |
| 89 | + """ |
| 90 | + # Set environment variables for testing |
| 91 | + os.environ["ONYX_API_BASE"] = "https://test.onyx.security" |
| 92 | + os.environ["ONYX_API_KEY"] = "test-api-key" |
| 93 | + os.environ["ONYX_TIMEOUT"] = "60" |
| 94 | + |
| 95 | + with patch( |
| 96 | + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" |
| 97 | + ) as mock_get_client: |
| 98 | + mock_get_client.return_value = MagicMock() |
| 99 | + |
| 100 | + # Pass timeout=None to simulate config model behavior |
| 101 | + guardrail = OnyxGuardrail( |
| 102 | + guardrail_name="onyx-guard-env-timeout", |
| 103 | + event_hook="pre_call", |
| 104 | + default_on=True, |
| 105 | + timeout=None, # This triggers env var lookup |
| 106 | + ) |
| 107 | + |
| 108 | + # Verify the client was initialized with timeout from env var |
| 109 | + mock_get_client.assert_called() |
| 110 | + call_kwargs = mock_get_client.call_args.kwargs |
| 111 | + timeout_param = call_kwargs["params"]["timeout"] |
| 112 | + assert timeout_param.read == 60.0 |
| 113 | + assert timeout_param.connect == 5.0 |
| 114 | + |
| 115 | + # Clean up |
| 116 | + if "ONYX_API_BASE" in os.environ: |
| 117 | + del os.environ["ONYX_API_BASE"] |
| 118 | + if "ONYX_API_KEY" in os.environ: |
| 119 | + del os.environ["ONYX_API_KEY"] |
| 120 | + if "ONYX_TIMEOUT" in os.environ: |
| 121 | + del os.environ["ONYX_TIMEOUT"] |
| 122 | + |
| 123 | + |
| 124 | +def test_onyx_guard_with_timeout_none_defaults_to_10(): |
| 125 | + """Test Onyx guard with timeout=None and no env var defaults to 10 seconds.""" |
| 126 | + # Set environment variables for testing |
| 127 | + os.environ["ONYX_API_BASE"] = "https://test.onyx.security" |
| 128 | + os.environ["ONYX_API_KEY"] = "test-api-key" |
| 129 | + # Ensure ONYX_TIMEOUT is not set |
| 130 | + if "ONYX_TIMEOUT" in os.environ: |
| 131 | + del os.environ["ONYX_TIMEOUT"] |
| 132 | + |
| 133 | + with patch( |
| 134 | + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" |
| 135 | + ) as mock_get_client: |
| 136 | + mock_get_client.return_value = MagicMock() |
| 137 | + |
| 138 | + # Pass timeout=None with no env var - should default to 10.0 |
| 139 | + guardrail = OnyxGuardrail( |
| 140 | + guardrail_name="onyx-guard-default-timeout", |
| 141 | + event_hook="pre_call", |
| 142 | + default_on=True, |
| 143 | + timeout=None, |
| 144 | + ) |
| 145 | + |
| 146 | + # Verify the client was initialized with default timeout of 10.0 |
| 147 | + mock_get_client.assert_called() |
| 148 | + call_kwargs = mock_get_client.call_args.kwargs |
| 149 | + timeout_param = call_kwargs["params"]["timeout"] |
| 150 | + assert timeout_param.read == 10.0 |
| 151 | + assert timeout_param.connect == 5.0 |
| 152 | + |
| 153 | + # Clean up |
| 154 | + if "ONYX_API_BASE" in os.environ: |
| 155 | + del os.environ["ONYX_API_BASE"] |
| 156 | + if "ONYX_API_KEY" in os.environ: |
| 157 | + del os.environ["ONYX_API_KEY"] |
| 158 | + |
| 159 | + |
50 | 160 | class TestOnyxGuardrail: |
51 | 161 | """Test suite for Onyx Security Guardrail integration.""" |
52 | 162 |
|
53 | 163 | def setup_method(self): |
54 | 164 | """Setup test environment.""" |
55 | 165 | # Clean up any existing environment variables |
56 | | - for key in ["ONYX_API_BASE", "ONYX_API_KEY"]: |
| 166 | + for key in ["ONYX_API_BASE", "ONYX_API_KEY", "ONYX_TIMEOUT"]: |
57 | 167 | if key in os.environ: |
58 | 168 | del os.environ[key] |
59 | 169 |
|
60 | 170 | def teardown_method(self): |
61 | 171 | """Clean up test environment.""" |
62 | 172 | # Clean up any environment variables set during tests |
63 | | - for key in ["ONYX_API_BASE", "ONYX_API_KEY"]: |
| 173 | + for key in ["ONYX_API_BASE", "ONYX_API_KEY", "ONYX_TIMEOUT"]: |
64 | 174 | if key in os.environ: |
65 | 175 | del os.environ[key] |
66 | 176 |
|
@@ -103,6 +213,95 @@ def test_initialization_fails_when_api_key_missing(self): |
103 | 213 | ): |
104 | 214 | OnyxGuardrail(guardrail_name="test-guard", event_hook="pre_call") |
105 | 215 |
|
| 216 | + def test_initialization_with_default_timeout(self): |
| 217 | + """Test that default timeout is 10.0 seconds.""" |
| 218 | + os.environ["ONYX_API_KEY"] = "test-api-key" |
| 219 | + |
| 220 | + with patch( |
| 221 | + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" |
| 222 | + ) as mock_get_client: |
| 223 | + mock_get_client.return_value = MagicMock() |
| 224 | + guardrail = OnyxGuardrail( |
| 225 | + guardrail_name="test-guard", event_hook="pre_call", default_on=True |
| 226 | + ) |
| 227 | + |
| 228 | + # Verify the client was initialized with correct timeout |
| 229 | + mock_get_client.assert_called_once() |
| 230 | + call_kwargs = mock_get_client.call_args.kwargs |
| 231 | + timeout_param = call_kwargs["params"]["timeout"] |
| 232 | + assert timeout_param.read == 10.0 |
| 233 | + assert timeout_param.connect == 5.0 |
| 234 | + |
| 235 | + def test_initialization_with_custom_timeout_parameter(self): |
| 236 | + """Test initialization with custom timeout parameter.""" |
| 237 | + os.environ["ONYX_API_KEY"] = "test-api-key" |
| 238 | + |
| 239 | + with patch( |
| 240 | + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" |
| 241 | + ) as mock_get_client: |
| 242 | + mock_get_client.return_value = MagicMock() |
| 243 | + guardrail = OnyxGuardrail( |
| 244 | + guardrail_name="test-guard", |
| 245 | + event_hook="pre_call", |
| 246 | + default_on=True, |
| 247 | + timeout=30.0, |
| 248 | + ) |
| 249 | + |
| 250 | + # Verify the client was initialized with custom timeout |
| 251 | + mock_get_client.assert_called_once() |
| 252 | + call_kwargs = mock_get_client.call_args.kwargs |
| 253 | + timeout_param = call_kwargs["params"]["timeout"] |
| 254 | + assert timeout_param.read == 30.0 |
| 255 | + assert timeout_param.connect == 5.0 |
| 256 | + |
| 257 | + def test_initialization_with_timeout_from_env_var(self): |
| 258 | + """Test initialization with timeout from ONYX_TIMEOUT environment variable. |
| 259 | + |
| 260 | + Note: The env var is only used when timeout=None is explicitly passed, |
| 261 | + since the default parameter value is 10.0 (not None). |
| 262 | + """ |
| 263 | + os.environ["ONYX_API_KEY"] = "test-api-key" |
| 264 | + os.environ["ONYX_TIMEOUT"] = "25" |
| 265 | + |
| 266 | + with patch( |
| 267 | + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" |
| 268 | + ) as mock_get_client: |
| 269 | + mock_get_client.return_value = MagicMock() |
| 270 | + # Must pass timeout=None explicitly to trigger env var lookup |
| 271 | + guardrail = OnyxGuardrail( |
| 272 | + guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=None |
| 273 | + ) |
| 274 | + |
| 275 | + # Verify the client was initialized with timeout from env var |
| 276 | + mock_get_client.assert_called_once() |
| 277 | + call_kwargs = mock_get_client.call_args.kwargs |
| 278 | + timeout_param = call_kwargs["params"]["timeout"] |
| 279 | + assert timeout_param.read == 25.0 |
| 280 | + assert timeout_param.connect == 5.0 |
| 281 | + |
| 282 | + def test_initialization_timeout_parameter_overrides_env_var(self): |
| 283 | + """Test that timeout parameter overrides ONYX_TIMEOUT environment variable.""" |
| 284 | + os.environ["ONYX_API_KEY"] = "test-api-key" |
| 285 | + os.environ["ONYX_TIMEOUT"] = "25" |
| 286 | + |
| 287 | + with patch( |
| 288 | + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" |
| 289 | + ) as mock_get_client: |
| 290 | + mock_get_client.return_value = MagicMock() |
| 291 | + guardrail = OnyxGuardrail( |
| 292 | + guardrail_name="test-guard", |
| 293 | + event_hook="pre_call", |
| 294 | + default_on=True, |
| 295 | + timeout=15.0, |
| 296 | + ) |
| 297 | + |
| 298 | + # Verify the client was initialized with parameter timeout (not env var) |
| 299 | + mock_get_client.assert_called_once() |
| 300 | + call_kwargs = mock_get_client.call_args.kwargs |
| 301 | + timeout_param = call_kwargs["params"]["timeout"] |
| 302 | + assert timeout_param.read == 15.0 |
| 303 | + assert timeout_param.connect == 5.0 |
| 304 | + |
106 | 305 | @pytest.mark.asyncio |
107 | 306 | async def test_apply_guardrail_request_no_violations(self): |
108 | 307 | """Test apply_guardrail for request with no violations detected.""" |
@@ -388,6 +587,105 @@ async def test_apply_guardrail_api_error_handling(self): |
388 | 587 |
|
389 | 588 | assert result == inputs |
390 | 589 |
|
| 590 | + @pytest.mark.asyncio |
| 591 | + async def test_apply_guardrail_timeout_error_handling(self): |
| 592 | + """Test handling of timeout errors in apply_guardrail (graceful degradation).""" |
| 593 | + # Set required API key |
| 594 | + os.environ["ONYX_API_KEY"] = "test-api-key" |
| 595 | + |
| 596 | + guardrail = OnyxGuardrail( |
| 597 | + guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=1.0 |
| 598 | + ) |
| 599 | + |
| 600 | + inputs = GenericGuardrailAPIInputs() |
| 601 | + |
| 602 | + request_data = { |
| 603 | + "proxy_server_request": { |
| 604 | + "messages": [{"role": "user", "content": "Test message"}], |
| 605 | + "model": "gpt-3.5-turbo", |
| 606 | + } |
| 607 | + } |
| 608 | + |
| 609 | + # Test httpx timeout error |
| 610 | + with patch.object( |
| 611 | + guardrail.async_handler, "post", side_effect=httpx.TimeoutException("Request timed out") |
| 612 | + ): |
| 613 | + # Should return original inputs on timeout (graceful degradation) |
| 614 | + result = await guardrail.apply_guardrail( |
| 615 | + inputs=inputs, |
| 616 | + request_data=request_data, |
| 617 | + input_type="request", |
| 618 | + logging_obj=None, |
| 619 | + ) |
| 620 | + |
| 621 | + assert result == inputs |
| 622 | + |
| 623 | + @pytest.mark.asyncio |
| 624 | + async def test_apply_guardrail_read_timeout_error_handling(self): |
| 625 | + """Test handling of read timeout errors in apply_guardrail.""" |
| 626 | + # Set required API key |
| 627 | + os.environ["ONYX_API_KEY"] = "test-api-key" |
| 628 | + |
| 629 | + guardrail = OnyxGuardrail( |
| 630 | + guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=5.0 |
| 631 | + ) |
| 632 | + |
| 633 | + inputs = GenericGuardrailAPIInputs() |
| 634 | + |
| 635 | + request_data = { |
| 636 | + "proxy_server_request": { |
| 637 | + "messages": [{"role": "user", "content": "Test message"}], |
| 638 | + "model": "gpt-3.5-turbo", |
| 639 | + } |
| 640 | + } |
| 641 | + |
| 642 | + # Test httpx ReadTimeout error |
| 643 | + with patch.object( |
| 644 | + guardrail.async_handler, "post", side_effect=httpx.ReadTimeout("Read timed out") |
| 645 | + ): |
| 646 | + # Should return original inputs on timeout (graceful degradation) |
| 647 | + result = await guardrail.apply_guardrail( |
| 648 | + inputs=inputs, |
| 649 | + request_data=request_data, |
| 650 | + input_type="request", |
| 651 | + logging_obj=None, |
| 652 | + ) |
| 653 | + |
| 654 | + assert result == inputs |
| 655 | + |
| 656 | + @pytest.mark.asyncio |
| 657 | + async def test_apply_guardrail_connect_timeout_error_handling(self): |
| 658 | + """Test handling of connect timeout errors in apply_guardrail.""" |
| 659 | + # Set required API key |
| 660 | + os.environ["ONYX_API_KEY"] = "test-api-key" |
| 661 | + |
| 662 | + guardrail = OnyxGuardrail( |
| 663 | + guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=5.0 |
| 664 | + ) |
| 665 | + |
| 666 | + inputs = GenericGuardrailAPIInputs() |
| 667 | + |
| 668 | + request_data = { |
| 669 | + "proxy_server_request": { |
| 670 | + "messages": [{"role": "user", "content": "Test message"}], |
| 671 | + "model": "gpt-3.5-turbo", |
| 672 | + } |
| 673 | + } |
| 674 | + |
| 675 | + # Test httpx ConnectTimeout error |
| 676 | + with patch.object( |
| 677 | + guardrail.async_handler, "post", side_effect=httpx.ConnectTimeout("Connect timed out") |
| 678 | + ): |
| 679 | + # Should return original inputs on timeout (graceful degradation) |
| 680 | + result = await guardrail.apply_guardrail( |
| 681 | + inputs=inputs, |
| 682 | + request_data=request_data, |
| 683 | + input_type="request", |
| 684 | + logging_obj=None, |
| 685 | + ) |
| 686 | + |
| 687 | + assert result == inputs |
| 688 | + |
391 | 689 | @pytest.mark.asyncio |
392 | 690 | async def test_apply_guardrail_no_logging_obj(self): |
393 | 691 | """Test apply_guardrail without logging object (uses UUID).""" |
|
0 commit comments