Skip to content

Commit 1511509

Browse files
committed
sdk/python: Bug fixes for lifecycle hooks
1 parent 71c89a4 commit 1511509

File tree

7 files changed

+9
-144
lines changed

7 files changed

+9
-144
lines changed

sdk/python/polos/agents/stream.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,6 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
157157
if agent and agent.on_agent_step_start:
158158
hook_context = HookContext(
159159
workflow_id=ctx.agent_id,
160-
agent_workflow_id=ctx.agent_id,
161-
agent_run_id=agent_run_id,
162160
session_id=ctx.session_id,
163161
user_id=ctx.user_id,
164162
agent_config=agent_config,
@@ -173,8 +171,6 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
173171
)
174172

175173
# Apply modifications
176-
if hook_result.modified_agent_config:
177-
agent_config.update(hook_result.modified_agent_config)
178174
if hook_result.modified_payload and "messages" in hook_result.modified_payload:
179175
conversation_messages = hook_result.modified_payload["messages"]
180176

@@ -284,8 +280,6 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
284280
if agent and agent.on_tool_start:
285281
hook_context = HookContext(
286282
workflow_id=ctx.agent_id,
287-
agent_workflow_id=ctx.agent_id,
288-
agent_run_id=agent_run_id,
289283
session_id=ctx.session_id,
290284
user_id=ctx.user_id,
291285
agent_config=agent_config,
@@ -298,8 +292,6 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
298292
)
299293

300294
# Apply modifications
301-
if hook_result.modified_agent_config:
302-
agent_config.update(hook_result.modified_agent_config)
303295
if hook_result.modified_payload:
304296
tool_args.update(hook_result.modified_payload)
305297

@@ -353,8 +345,6 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
353345
if agent and agent.on_tool_end:
354346
hook_context = HookContext(
355347
workflow_id=ctx.agent_id,
356-
agent_workflow_id=ctx.agent_id,
357-
agent_run_id=agent_run_id,
358348
session_id=ctx.session_id,
359349
user_id=ctx.user_id,
360350
agent_config=agent_config,
@@ -368,8 +358,6 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
368358
)
369359

370360
# Apply modifications
371-
if hook_result.modified_agent_config:
372-
agent_config.update(hook_result.modified_agent_config)
373361
if hook_result.modified_output is not None:
374362
tool_result = hook_result.modified_output
375363

@@ -442,8 +430,6 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
442430
if agent and agent.on_agent_step_end:
443431
hook_context = HookContext(
444432
workflow_id=ctx.agent_id,
445-
agent_workflow_id=ctx.agent_id,
446-
agent_run_id=agent_run_id,
447433
session_id=ctx.session_id,
448434
user_id=ctx.user_id,
449435
agent_config=agent_config,
@@ -456,11 +442,9 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
456442
)
457443

458444
# Apply modifications
459-
if hook_result.modified_agent_config:
460-
agent_config.update(hook_result.modified_agent_config)
461445
if hook_result.modified_output:
462446
new_result = hook_result.modified_output
463-
steps[-1].update(new_result)
447+
steps[-1] = new_result
464448

465449
# Check hook action
466450
if hook_result.action == HookAction.FAIL:

sdk/python/polos/core/step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ async def _publish_step_event(
188188
data={
189189
"step_key": step_key,
190190
"step_type": step_type,
191-
"input_params": safe_serialize(input_params) if input_params else {},
191+
"data": safe_serialize(input_params) if input_params else {},
192192
"_metadata": {
193193
"execution_id": self.ctx.execution_id,
194194
"workflow_id": self.ctx.workflow_id,

sdk/python/polos/middleware/guardrail.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class GuardrailContext(BaseModel):
2626
tool_calls: list[ToolCall] | None = None # LLM tool calls
2727

2828
# Execution context (for guardrail to make decisions)
29-
agent_workflow_id: str = ""
30-
agent_run_id: str = ""
29+
agent_workflow_id: str | None = ""
30+
agent_run_id: str | None = ""
3131
session_id: str | None = None
3232
user_id: str | None = None
3333
llm_config: AgentConfig = AgentConfig(name="", provider="", model="")

sdk/python/polos/middleware/hook.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,17 @@ class HookContext(BaseModel):
3232

3333
# Immutable identifiers
3434
workflow_id: str
35-
agent_workflow_id: str | None = None # Available for agents
36-
agent_run_id: str | None = None # Available for agents
3735
session_id: str | None = None
3836
user_id: str | None = None
39-
agent_config: AgentConfig | None = None
37+
agent_config: AgentConfig | None = None # Not available to agent's on_start and on_end hooks
4038

4139
# Current state
4240
steps: list[Step] = [] # All previous steps
4341

4442
# For workflow/tool hooks
4543
current_tool: str | None = None
46-
current_payload: dict[str, Any] | None = None
47-
current_output: dict[str, Any] | None = None
44+
current_payload: dict[str, Any] | BaseModel | None = None
45+
current_output: dict[str, Any] | BaseModel | None = None
4846

4947
def to_dict(self) -> dict[str, Any]:
5048
"""Convert hook context to dictionary for serialization."""
@@ -72,7 +70,6 @@ class HookResult(BaseModel):
7270
action: HookAction = HookAction.CONTINUE
7371

7472
# Optional modifications
75-
modified_agent_config: AgentConfig | None = None
7673
modified_payload: dict[str, Any] | None = None
7774
modified_output: Any | None = None
7875

@@ -84,7 +81,7 @@ def continue_with(cls, **modifications) -> "HookResult":
8481
"""Continue with optional modifications.
8582
8683
Args:
87-
**modifications: Can include modified_agent_config, modified_payload, modified_output
84+
**modifications: Can include modified_payload, modified_output
8885
8986
Returns:
9087
HookResult with CONTINUE action and modifications

sdk/python/polos/middleware/hooks.py

Lines changed: 0 additions & 109 deletions
This file was deleted.

sdk/python/tests/unit/test_core/test_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ async def test_publish_step_event(self, mock_workflow_context):
243243
assert event.event_type == event_type
244244
assert event.data["step_key"] == step_key
245245
assert event.data["step_type"] == event_name # Note: it's "step_type" not "event_name"
246-
assert "input_params" in event.data
246+
assert "data" in event.data
247247

248248
@pytest.mark.asyncio
249249
async def test_publish_step_event_topic(self, mock_workflow_context):

sdk/python/tests/unit/test_middleware/test_hook.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ def test_hook_context_initialization(self):
2828
"""Test HookContext initialization."""
2929
ctx = HookContext(workflow_id="test-workflow")
3030
assert ctx.workflow_id == "test-workflow"
31-
assert ctx.agent_workflow_id is None
32-
assert ctx.agent_run_id is None
3331
assert ctx.session_id is None
3432
assert ctx.user_id is None
3533
assert ctx.agent_config is None
@@ -46,8 +44,6 @@ def test_hook_context_full_initialization(self):
4644
steps = [Step(step=1, content="test")]
4745
ctx = HookContext(
4846
workflow_id="test-workflow",
49-
agent_workflow_id="test-agent",
50-
agent_run_id="test-run",
5147
session_id="test-session",
5248
user_id="test-user",
5349
agent_config=agent_config,
@@ -57,8 +53,6 @@ def test_hook_context_full_initialization(self):
5753
current_output={"result": "output"},
5854
)
5955
assert ctx.workflow_id == "test-workflow"
60-
assert ctx.agent_workflow_id == "test-agent"
61-
assert ctx.agent_run_id == "test-run"
6256
assert ctx.session_id == "test-session"
6357
assert ctx.user_id == "test-user"
6458
assert ctx.agent_config == agent_config
@@ -106,7 +100,6 @@ def test_hook_result_default(self):
106100
"""Test HookResult with default values."""
107101
result = HookResult()
108102
assert result.action == HookAction.CONTINUE
109-
assert result.modified_agent_config is None
110103
assert result.modified_payload is None
111104
assert result.modified_output is None
112105
assert result.error_message is None

0 commit comments

Comments
 (0)