|
12 | 12 | from langgraph.types import Command |
13 | 13 |
|
14 | 14 | from src.config.report_style import ReportStyle |
15 | | -from src.server.app import _astream_workflow_generator, _make_event, app |
| 15 | +from src.server.app import ( |
| 16 | + _astream_workflow_generator, |
| 17 | + _create_interrupt_event, |
| 18 | + _make_event, |
| 19 | + app, |
| 20 | +) |
16 | 21 |
|
17 | 22 |
|
18 | 23 | @pytest.fixture |
@@ -657,9 +662,9 @@ async def mock_astream(*args, **kwargs): |
657 | 662 | @pytest.mark.asyncio |
658 | 663 | @patch("src.server.app.graph") |
659 | 664 | async def test_astream_workflow_generator_interrupt_event(self, mock_graph): |
660 | | - # Mock interrupt data |
| 665 | + # Mock interrupt data with the new 'id' attribute (LangGraph 1.0+) |
661 | 666 | mock_interrupt = MagicMock() |
662 | | - mock_interrupt.ns = ["interrupt_id"] |
| 667 | + mock_interrupt.id = "interrupt_id" |
663 | 668 | mock_interrupt.value = "Plan requires approval" |
664 | 669 |
|
665 | 670 | interrupt_data = {"__interrupt__": [mock_interrupt]} |
@@ -920,3 +925,94 @@ def test_generate_prose_error(self, mock_build_graph, client): |
920 | 925 | response = client.post("/api/prose/generate", json=request_data) |
921 | 926 | assert response.status_code == 500 |
922 | 927 | assert response.json()["detail"] == "Internal Server Error" |
| 928 | + |
| 929 | + |
| 930 | +class TestCreateInterruptEvent: |
| 931 | + """Tests for _create_interrupt_event function (Issue #730 fix).""" |
| 932 | + |
| 933 | + def test_create_interrupt_event_with_id_attribute(self): |
| 934 | + """Test that _create_interrupt_event works with LangGraph 1.0+ Interrupt objects that have 'id' attribute.""" |
| 935 | + # Create a mock Interrupt object with the new 'id' attribute (LangGraph 1.0+) |
| 936 | + mock_interrupt = MagicMock() |
| 937 | + mock_interrupt.id = "interrupt-123" |
| 938 | + mock_interrupt.value = "Please review the research plan" |
| 939 | + |
| 940 | + event_data = {"__interrupt__": [mock_interrupt]} |
| 941 | + thread_id = "thread-456" |
| 942 | + |
| 943 | + result = _create_interrupt_event(thread_id, event_data) |
| 944 | + |
| 945 | + # Verify the result is a properly formatted SSE event |
| 946 | + assert "event: interrupt\n" in result |
| 947 | + assert '"thread_id": "thread-456"' in result |
| 948 | + assert '"id": "interrupt-123"' in result |
| 949 | + assert '"content": "Please review the research plan"' in result |
| 950 | + assert '"finish_reason": "interrupt"' in result |
| 951 | + assert '"role": "assistant"' in result |
| 952 | + |
| 953 | + def test_create_interrupt_event_fallback_to_thread_id(self): |
| 954 | + """Test that _create_interrupt_event falls back to thread_id when 'id' attribute is None.""" |
| 955 | + # Create a mock Interrupt object where id is None |
| 956 | + mock_interrupt = MagicMock() |
| 957 | + mock_interrupt.id = None |
| 958 | + mock_interrupt.value = "Plan review needed" |
| 959 | + |
| 960 | + event_data = {"__interrupt__": [mock_interrupt]} |
| 961 | + thread_id = "thread-789" |
| 962 | + |
| 963 | + result = _create_interrupt_event(thread_id, event_data) |
| 964 | + |
| 965 | + # Verify it falls back to thread_id |
| 966 | + assert '"id": "thread-789"' in result |
| 967 | + assert '"thread_id": "thread-789"' in result |
| 968 | + assert '"content": "Plan review needed"' in result |
| 969 | + |
| 970 | + def test_create_interrupt_event_without_id_attribute(self): |
| 971 | + """Test that _create_interrupt_event handles objects without 'id' attribute (backward compatibility).""" |
| 972 | + # Create a mock object that doesn't have 'id' attribute at all |
| 973 | + class MockInterrupt: |
| 974 | + pass |
| 975 | + mock_interrupt = MockInterrupt() |
| 976 | + mock_interrupt.value = "Waiting for approval" |
| 977 | + |
| 978 | + event_data = {"__interrupt__": [mock_interrupt]} |
| 979 | + thread_id = "thread-abc" |
| 980 | + |
| 981 | + result = _create_interrupt_event(thread_id, event_data) |
| 982 | + |
| 983 | + # Verify it falls back to thread_id when id attribute doesn't exist |
| 984 | + assert '"id": "thread-abc"' in result |
| 985 | + assert '"content": "Waiting for approval"' in result |
| 986 | + |
| 987 | + def test_create_interrupt_event_options(self): |
| 988 | + """Test that _create_interrupt_event includes correct options.""" |
| 989 | + mock_interrupt = MagicMock() |
| 990 | + mock_interrupt.id = "int-001" |
| 991 | + mock_interrupt.value = "Review plan" |
| 992 | + |
| 993 | + event_data = {"__interrupt__": [mock_interrupt]} |
| 994 | + thread_id = "thread-xyz" |
| 995 | + |
| 996 | + result = _create_interrupt_event(thread_id, event_data) |
| 997 | + |
| 998 | + # Verify options are included |
| 999 | + assert '"options":' in result |
| 1000 | + assert '"text": "Edit plan"' in result |
| 1001 | + assert '"value": "edit_plan"' in result |
| 1002 | + assert '"text": "Start research"' in result |
| 1003 | + assert '"value": "accepted"' in result |
| 1004 | + |
| 1005 | + def test_create_interrupt_event_with_complex_value(self): |
| 1006 | + """Test that _create_interrupt_event handles complex content values.""" |
| 1007 | + mock_interrupt = MagicMock() |
| 1008 | + mock_interrupt.id = "int-complex" |
| 1009 | + mock_interrupt.value = {"plan": "Research AI", "steps": ["step1", "step2"]} |
| 1010 | + |
| 1011 | + event_data = {"__interrupt__": [mock_interrupt]} |
| 1012 | + thread_id = "thread-complex" |
| 1013 | + |
| 1014 | + result = _create_interrupt_event(thread_id, event_data) |
| 1015 | + |
| 1016 | + # Verify complex value is included (will be serialized as JSON) |
| 1017 | + assert '"id": "int-complex"' in result |
| 1018 | + assert "Research AI" in result or "plan" in result |
0 commit comments