Skip to content

Commit bd9dc9b

Browse files
fix: address review comments - security, context, error handling
- Fixed Action constructor TypeError by removing **kwargs - Added security verification for action callbacks in server endpoint - Fixed missing MessageContext during action dispatch for action.remove() - Added user-facing error notifications via toast in ActionButtons - Added async function validation for action callbacks - Added duplicate action name protection with warnings - Fixed all ruff linting issues Co-authored-by: Mervin Praison <MervinPraison@users.noreply.github.com>
1 parent 4a65a19 commit bd9dc9b

5 files changed

Lines changed: 248 additions & 160 deletions

File tree

src/frontend/src/chat/ActionButtons.tsx

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,24 @@
11
import { useState, useCallback } from 'react'
22
import type { ActionButton } from '../types'
33

4+
// Simple toast notification function
5+
function showToast(message: string, type: 'error' | 'success' = 'error') {
6+
const toast = document.createElement('div')
7+
toast.textContent = message
8+
toast.className = `fixed top-4 right-4 z-50 px-4 py-2 rounded-md text-white max-w-sm transition-opacity duration-300 ${
9+
type === 'error' ? 'bg-red-500' : 'bg-green-500'
10+
}`
11+
document.body.appendChild(toast)
12+
13+
// Auto-remove after 4 seconds
14+
setTimeout(() => {
15+
toast.style.opacity = '0'
16+
setTimeout(() => {
17+
document.body.removeChild(toast)
18+
}, 300)
19+
}, 4000)
20+
}
21+
422
interface ActionButtonsProps {
523
actions: ActionButton[]
624
messageId: string
@@ -54,7 +72,8 @@ export function ActionButtons({ actions, messageId, sessionId }: ActionButtonsPr
5472

5573
} catch (error) {
5674
console.error(`Failed to execute action '${action.name}':`, error)
57-
// TODO: Show user-friendly error notification
75+
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
76+
showToast(`Failed to execute action "${action.label}": ${errorMessage}`)
5877
} finally {
5978
// Remove pending state
6079
setPendingActions(prev => {

src/praisonaiui/actions.py

Lines changed: 94 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ async def handler(message):
2525

2626
from __future__ import annotations
2727

28+
import asyncio
2829
import uuid
2930
from dataclasses import dataclass, field
30-
from typing import Any, Awaitable, Callable, Dict, Optional
3131
from functools import wraps
32+
from typing import Any, Awaitable, Callable, Dict, Optional
3233

3334
from praisonaiui.server import MessageContext
3435

35-
3636
# Global registry for action callbacks
3737
# Key: action name, Value: async callback function
3838
_ACTION_REGISTRY: Dict[str, Callable[["Action"], Awaitable[None]]] = {}
@@ -41,50 +41,50 @@ async def handler(message):
4141
@dataclass
4242
class Action:
4343
"""An interactive action button attached to a message.
44-
44+
4545
Attributes:
4646
name: Unique identifier for the action (used for callback registration)
4747
label: Display text for the button
4848
payload: Optional data passed to the callback (default: None)
49-
icon: Optional icon name (default: None)
49+
icon: Optional icon name (default: None)
5050
variant: Button style variant (default: "secondary")
5151
id: Auto-generated unique ID for this action instance
5252
message_id: ID of the message this action belongs to (set automatically)
53-
53+
5454
Example:
5555
action = Action(
5656
name="approve_pr",
57-
label="Approve",
57+
label="Approve",
5858
payload={"pr_number": 42},
5959
icon="check",
6060
variant="primary"
6161
)
6262
"""
63-
63+
6464
name: str
6565
label: str
6666
payload: Optional[Dict[str, Any]] = None
6767
icon: Optional[str] = None
6868
variant: str = "secondary"
69-
69+
7070
# Auto-generated fields
7171
id: str = field(default_factory=lambda: str(uuid.uuid4()))
7272
message_id: Optional[str] = None
73-
73+
7474
# Internal state
7575
_context: Optional[MessageContext] = field(default=None, repr=False)
76-
76+
7777
def __post_init__(self):
7878
"""Initialize action with context from current callback."""
7979
from praisonaiui.callbacks import _get_context
8080
self._context = _get_context()
81-
81+
8282
def to_dict(self) -> Dict[str, Any]:
8383
"""Serialize action to deterministic dict format.
84-
84+
8585
Returns dict with keys sorted alphabetically for stable serialization.
8686
This ensures the action survives persistence and page reloads.
87-
87+
8888
Returns:
8989
Dict representation suitable for JSON serialization
9090
"""
@@ -94,26 +94,26 @@ def to_dict(self) -> Dict[str, Any]:
9494
"label": self.label,
9595
"variant": self.variant,
9696
}
97-
97+
9898
# Add optional fields only if they have values (deterministic output)
9999
if self.payload is not None:
100100
result["payload"] = self.payload
101101
if self.icon is not None:
102102
result["icon"] = self.icon
103103
if self.message_id is not None:
104104
result["message_id"] = self.message_id
105-
105+
106106
return result
107-
107+
108108
async def remove(self) -> None:
109109
"""Remove this action button from the rendered message.
110-
110+
111111
Emits a server-side event that removes the button from the UI.
112112
The action becomes non-functional after removal.
113113
"""
114114
if not self._context or not self._context._stream_queue:
115115
return
116-
116+
117117
await self._context._stream_queue.put({
118118
"type": "action_remove",
119119
"action_id": self.id,
@@ -123,112 +123,151 @@ async def remove(self) -> None:
123123

124124
def action_callback(name: str) -> Callable[[Callable], Callable]:
125125
"""Decorator to register an async action callback handler.
126-
126+
127127
The decorated function will be called when an action with the given name
128128
is clicked by the user. The function receives an Action instance with
129129
the original payload and context.
130-
130+
131131
Args:
132132
name: The action name to register for (must match Action.name)
133-
133+
134134
Returns:
135135
Decorator function that registers the callback
136-
136+
137137
Example:
138138
@action_callback("approve_pr")
139139
async def on_approve(action: Action):
140140
pr_number = action.payload["pr_number"]
141141
await action.remove() # Hide button after click
142142
await Message(content=f"✅ PR #{pr_number} approved").send()
143-
143+
144144
Raises:
145145
ValueError: If name is empty or callback is not async
146146
"""
147147
if not name:
148148
raise ValueError("Action name cannot be empty")
149-
150-
def decorator(func: Callable[["Action"], Awaitable[None]]) -> Callable[["Action"], Awaitable[None]]:
149+
150+
def decorator(
151+
func: Callable[["Action"], Awaitable[None]]
152+
) -> Callable[["Action"], Awaitable[None]]:
151153
if not callable(func):
152154
raise ValueError("Action callback must be callable")
153-
155+
import inspect
156+
if not inspect.iscoroutinefunction(func):
157+
raise ValueError("Action callback must be an async function (coroutine)")
158+
159+
# Check for duplicate registration and warn
160+
if name in _ACTION_REGISTRY:
161+
import warnings
162+
warnings.warn(
163+
f"Action callback '{name}' is already registered and will be overwritten",
164+
UserWarning,
165+
stacklevel=3,
166+
)
167+
154168
# Register the callback in global registry
155169
_ACTION_REGISTRY[name] = func
156-
170+
157171
@wraps(func)
158172
async def wrapper(action: "Action") -> None:
159173
return await func(action)
160-
174+
161175
return wrapper
162-
176+
163177
return decorator
164178

165179

166-
def register_action_callback(name: str, callback: Callable[["Action"], Awaitable[None]]) -> None:
180+
def register_action_callback(
181+
name: str, callback: Callable[["Action"], Awaitable[None]]
182+
) -> None:
167183
"""Programmatically register an action callback (alternative to decorator).
168-
184+
169185
Args:
170186
name: The action name to register for
171187
callback: Async function that handles the action
172-
188+
173189
Raises:
174190
ValueError: If name is empty or callback is not callable
175-
191+
176192
Example:
177193
async def my_handler(action: Action):
178194
print(f"Action {action.name} clicked with payload: {action.payload}")
179-
195+
180196
register_action_callback("my_action", my_handler)
181197
"""
182198
if not name:
183199
raise ValueError("Action name cannot be empty")
184200
if not callable(callback):
185201
raise ValueError("Callback must be callable")
186-
202+
187203
_ACTION_REGISTRY[name] = callback
188204

189205

190206
async def dispatch_action_callback(
191-
action_name: str,
207+
action_name: str,
192208
action_id: str,
193209
payload: Optional[Dict[str, Any]] = None,
194210
message_id: Optional[str] = None,
195-
session_id: Optional[str] = None
211+
session_id: Optional[str] = None,
212+
stream_queue: Optional[asyncio.Queue] = None
196213
) -> None:
197214
"""Dispatch an action callback by name.
198-
215+
199216
Called by the server endpoint when an action button is clicked.
200217
Creates an Action instance with the provided data and calls the registered callback.
201-
218+
202219
Args:
203220
action_name: Name of the action (must be registered)
204221
action_id: Unique ID of the clicked action instance
205222
payload: Optional data from the original action
206223
message_id: ID of the message containing the action
207-
session_id: Session ID for context (currently unused)
208-
224+
session_id: Session ID for context
225+
stream_queue: Stream queue for server-side events (for action.remove())
226+
209227
Raises:
210228
ValueError: If no callback is registered for the action name (HTTP 404 equivalent)
211229
"""
212230
if action_name not in _ACTION_REGISTRY:
213231
raise ValueError(f"No callback registered for action '{action_name}'")
214-
232+
215233
callback = _ACTION_REGISTRY[action_name]
216-
217-
# Reconstruct the action for the callback
218-
action = Action(
219-
name=action_name,
220-
label="", # Label not needed for callback dispatch
221-
payload=payload,
222-
id=action_id,
223-
message_id=message_id,
234+
235+
# Create a MessageContext for server-side side effects (like action.remove())
236+
if stream_queue is None:
237+
stream_queue = asyncio.Queue()
238+
239+
msg_context = MessageContext(
240+
text="", # Not needed for action dispatch
241+
session_id=session_id or "",
242+
agent_name="action_callback",
224243
)
225-
226-
await callback(action)
244+
msg_context._stream_queue = stream_queue
245+
246+
# Set the context for the duration of the callback
247+
from praisonaiui.callbacks import _set_context
248+
_set_context(msg_context)
249+
250+
try:
251+
# Reconstruct the action for the callback
252+
action = Action(
253+
name=action_name,
254+
label="", # Label not needed for callback dispatch
255+
payload=payload,
256+
id=action_id,
257+
message_id=message_id,
258+
)
259+
# Set the context directly on the action
260+
action._context = msg_context
261+
262+
await callback(action)
263+
finally:
264+
# Clear the context after callback execution
265+
_set_context(None)
227266

228267

229268
def get_registered_actions() -> Dict[str, Callable[["Action"], Awaitable[None]]]:
230269
"""Get all registered action callbacks (for testing/debugging).
231-
270+
232271
Returns:
233272
Copy of the action registry dict
234273
"""
@@ -237,8 +276,8 @@ def get_registered_actions() -> Dict[str, Callable[["Action"], Awaitable[None]]]
237276

238277
def clear_action_registry() -> None:
239278
"""Clear all registered action callbacks (for testing).
240-
279+
241280
Warning: This will remove all action handlers. Use only in tests.
242281
"""
243282
global _ACTION_REGISTRY
244-
_ACTION_REGISTRY.clear()
283+
_ACTION_REGISTRY.clear()

src/praisonaiui/message.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,8 +354,7 @@ def add_action(
354354
label=label,
355355
icon=icon,
356356
payload=payload,
357-
variant=variant,
358-
**kwargs
357+
variant=variant
359358
)
360359
self.actions.append(action)
361360
except ImportError:

0 commit comments

Comments
 (0)