@@ -25,14 +25,14 @@ async def handler(message):
2525
2626from __future__ import annotations
2727
28+ import asyncio
2829import uuid
2930from dataclasses import dataclass , field
30- from typing import Any , Awaitable , Callable , Dict , Optional
3131from functools import wraps
32+ from typing import Any , Awaitable , Callable , Dict , Optional
3233
3334from praisonaiui .server import MessageContext
3435
35-
3636# Global registry for action callbacks
3737# Key: action name, Value: async callback function
3838_ACTION_REGISTRY : Dict [str , Callable [["Action" ], Awaitable [None ]]] = {}
@@ -41,50 +41,50 @@ async def handler(message):
4141@dataclass
4242class Action :
4343 """An interactive action button attached to a message.
44-
44+
4545 Attributes:
4646 name: Unique identifier for the action (used for callback registration)
4747 label: Display text for the button
4848 payload: Optional data passed to the callback (default: None)
49- icon: Optional icon name (default: None)
49+ icon: Optional icon name (default: None)
5050 variant: Button style variant (default: "secondary")
5151 id: Auto-generated unique ID for this action instance
5252 message_id: ID of the message this action belongs to (set automatically)
53-
53+
5454 Example:
5555 action = Action(
5656 name="approve_pr",
57- label="Approve",
57+ label="Approve",
5858 payload={"pr_number": 42},
5959 icon="check",
6060 variant="primary"
6161 )
6262 """
63-
63+
6464 name : str
6565 label : str
6666 payload : Optional [Dict [str , Any ]] = None
6767 icon : Optional [str ] = None
6868 variant : str = "secondary"
69-
69+
7070 # Auto-generated fields
7171 id : str = field (default_factory = lambda : str (uuid .uuid4 ()))
7272 message_id : Optional [str ] = None
73-
73+
7474 # Internal state
7575 _context : Optional [MessageContext ] = field (default = None , repr = False )
76-
76+
7777 def __post_init__ (self ):
7878 """Initialize action with context from current callback."""
7979 from praisonaiui .callbacks import _get_context
8080 self ._context = _get_context ()
81-
81+
8282 def to_dict (self ) -> Dict [str , Any ]:
8383 """Serialize action to deterministic dict format.
84-
84+
8585 Returns dict with keys sorted alphabetically for stable serialization.
8686 This ensures the action survives persistence and page reloads.
87-
87+
8888 Returns:
8989 Dict representation suitable for JSON serialization
9090 """
@@ -94,26 +94,26 @@ def to_dict(self) -> Dict[str, Any]:
9494 "label" : self .label ,
9595 "variant" : self .variant ,
9696 }
97-
97+
9898 # Add optional fields only if they have values (deterministic output)
9999 if self .payload is not None :
100100 result ["payload" ] = self .payload
101101 if self .icon is not None :
102102 result ["icon" ] = self .icon
103103 if self .message_id is not None :
104104 result ["message_id" ] = self .message_id
105-
105+
106106 return result
107-
107+
108108 async def remove (self ) -> None :
109109 """Remove this action button from the rendered message.
110-
110+
111111 Emits a server-side event that removes the button from the UI.
112112 The action becomes non-functional after removal.
113113 """
114114 if not self ._context or not self ._context ._stream_queue :
115115 return
116-
116+
117117 await self ._context ._stream_queue .put ({
118118 "type" : "action_remove" ,
119119 "action_id" : self .id ,
@@ -123,112 +123,151 @@ async def remove(self) -> None:
123123
124124def action_callback (name : str ) -> Callable [[Callable ], Callable ]:
125125 """Decorator to register an async action callback handler.
126-
126+
127127 The decorated function will be called when an action with the given name
128128 is clicked by the user. The function receives an Action instance with
129129 the original payload and context.
130-
130+
131131 Args:
132132 name: The action name to register for (must match Action.name)
133-
133+
134134 Returns:
135135 Decorator function that registers the callback
136-
136+
137137 Example:
138138 @action_callback("approve_pr")
139139 async def on_approve(action: Action):
140140 pr_number = action.payload["pr_number"]
141141 await action.remove() # Hide button after click
142142 await Message(content=f"✅ PR #{pr_number} approved").send()
143-
143+
144144 Raises:
145145 ValueError: If name is empty or callback is not async
146146 """
147147 if not name :
148148 raise ValueError ("Action name cannot be empty" )
149-
150- def decorator (func : Callable [["Action" ], Awaitable [None ]]) -> Callable [["Action" ], Awaitable [None ]]:
149+
150+ def decorator (
151+ func : Callable [["Action" ], Awaitable [None ]]
152+ ) -> Callable [["Action" ], Awaitable [None ]]:
151153 if not callable (func ):
152154 raise ValueError ("Action callback must be callable" )
153-
155+ import inspect
156+ if not inspect .iscoroutinefunction (func ):
157+ raise ValueError ("Action callback must be an async function (coroutine)" )
158+
159+ # Check for duplicate registration and warn
160+ if name in _ACTION_REGISTRY :
161+ import warnings
162+ warnings .warn (
163+ f"Action callback '{ name } ' is already registered and will be overwritten" ,
164+ UserWarning ,
165+ stacklevel = 3 ,
166+ )
167+
154168 # Register the callback in global registry
155169 _ACTION_REGISTRY [name ] = func
156-
170+
157171 @wraps (func )
158172 async def wrapper (action : "Action" ) -> None :
159173 return await func (action )
160-
174+
161175 return wrapper
162-
176+
163177 return decorator
164178
165179
166- def register_action_callback (name : str , callback : Callable [["Action" ], Awaitable [None ]]) -> None :
180+ def register_action_callback (
181+ name : str , callback : Callable [["Action" ], Awaitable [None ]]
182+ ) -> None :
167183 """Programmatically register an action callback (alternative to decorator).
168-
184+
169185 Args:
170186 name: The action name to register for
171187 callback: Async function that handles the action
172-
188+
173189 Raises:
174190 ValueError: If name is empty or callback is not callable
175-
191+
176192 Example:
177193 async def my_handler(action: Action):
178194 print(f"Action {action.name} clicked with payload: {action.payload}")
179-
195+
180196 register_action_callback("my_action", my_handler)
181197 """
182198 if not name :
183199 raise ValueError ("Action name cannot be empty" )
184200 if not callable (callback ):
185201 raise ValueError ("Callback must be callable" )
186-
202+
187203 _ACTION_REGISTRY [name ] = callback
188204
189205
190206async def dispatch_action_callback (
191- action_name : str ,
207+ action_name : str ,
192208 action_id : str ,
193209 payload : Optional [Dict [str , Any ]] = None ,
194210 message_id : Optional [str ] = None ,
195- session_id : Optional [str ] = None
211+ session_id : Optional [str ] = None ,
212+ stream_queue : Optional [asyncio .Queue ] = None
196213) -> None :
197214 """Dispatch an action callback by name.
198-
215+
199216 Called by the server endpoint when an action button is clicked.
200217 Creates an Action instance with the provided data and calls the registered callback.
201-
218+
202219 Args:
203220 action_name: Name of the action (must be registered)
204221 action_id: Unique ID of the clicked action instance
205222 payload: Optional data from the original action
206223 message_id: ID of the message containing the action
207- session_id: Session ID for context (currently unused)
208-
224+ session_id: Session ID for context
225+ stream_queue: Stream queue for server-side events (for action.remove())
226+
209227 Raises:
210228 ValueError: If no callback is registered for the action name (HTTP 404 equivalent)
211229 """
212230 if action_name not in _ACTION_REGISTRY :
213231 raise ValueError (f"No callback registered for action '{ action_name } '" )
214-
232+
215233 callback = _ACTION_REGISTRY [action_name ]
216-
217- # Reconstruct the action for the callback
218- action = Action (
219- name = action_name ,
220- label = "" , # Label not needed for callback dispatch
221- payload = payload ,
222- id = action_id ,
223- message_id = message_id ,
234+
235+ # Create a MessageContext for server-side side effects (like action.remove())
236+ if stream_queue is None :
237+ stream_queue = asyncio .Queue ()
238+
239+ msg_context = MessageContext (
240+ text = "" , # Not needed for action dispatch
241+ session_id = session_id or "" ,
242+ agent_name = "action_callback" ,
224243 )
225-
226- await callback (action )
244+ msg_context ._stream_queue = stream_queue
245+
246+ # Set the context for the duration of the callback
247+ from praisonaiui .callbacks import _set_context
248+ _set_context (msg_context )
249+
250+ try :
251+ # Reconstruct the action for the callback
252+ action = Action (
253+ name = action_name ,
254+ label = "" , # Label not needed for callback dispatch
255+ payload = payload ,
256+ id = action_id ,
257+ message_id = message_id ,
258+ )
259+ # Set the context directly on the action
260+ action ._context = msg_context
261+
262+ await callback (action )
263+ finally :
264+ # Clear the context after callback execution
265+ _set_context (None )
227266
228267
229268def get_registered_actions () -> Dict [str , Callable [["Action" ], Awaitable [None ]]]:
230269 """Get all registered action callbacks (for testing/debugging).
231-
270+
232271 Returns:
233272 Copy of the action registry dict
234273 """
@@ -237,8 +276,8 @@ def get_registered_actions() -> Dict[str, Callable[["Action"], Awaitable[None]]]
237276
238277def clear_action_registry () -> None :
239278 """Clear all registered action callbacks (for testing).
240-
279+
241280 Warning: This will remove all action handlers. Use only in tests.
242281 """
243282 global _ACTION_REGISTRY
244- _ACTION_REGISTRY .clear ()
283+ _ACTION_REGISTRY .clear ()
0 commit comments