Feat/agents #241
Merged
42 commits, all by srijanpatel:

a2d98c8 feat: add function_schema property to BaseNode
8f2d277 refactor: improve docstrings and comments for clarity in LLM utility …
5ae66a7 refactor: enhance docstrings and improve formatting in SingleLLMCallNode
9196847 refactor: rename functions as tools as functions API is deprecated
90a9027 refactor: update completion_with_backoff to return Message type and a…
b56c3fe refactor: update generate_text to return Message type and adjust resp…
056e381 feat: add call_as_tool method to enable node invocation with arguments
a4775ea feat: implement AgentNode for LLM-based agent execution with tool-calling capabilities
359494a refactor: update is_valid_node_type to use attribute access for node_…
fa01f39 refactor: update function schema structure in BaseNode class
b379e76 fix: handle None tool_calls
adf6d78 refactor: normalize tool titles to lowercase in AgentNode and improve…
f44bc3b feat: add AgentNode type and update environment example configuration
b7d27d4 feat: add visual tag to AgentNode for improved identification
367c74e feat: implement AgentNode component for tool management and interaction
98859fe feat: enhance CollapsibleNodePanel with custom node addition support …
377c1a2 feat: add addToolToAgent action to support dynamic tool addition in f…
596037a feat: implement tool addition functionality in AgentNode with Collaps…
4e694f3 feat: enhance CollapsibleNodePanel with controlled expansion and upda…
9fe169b Merge remote-tracking branch 'origin/main' into feat/agents
653de45 fix: get tools from config.tools
c4931bf feat: add hideHandles prop to BaseNode for conditional handle rendering
bd5264e feat: enhance node creation with unique ID generation and isTool flag
b39a7d2 feat: integrate NodeResizer component and simplify tool management UI…
9df6f51 feat: add CollapsibleNodePanel for tool addition in AgentNode
fb0c1be feat: update ID generation to replace 'Node' with 'Tool' for better c…
ba7c963 feat: add isTool flag to BaseNode for conditional handle rendering
f034cd6 fix: normalize tool name lookup to be case-insensitive in AgentNode
d04c995 feat: implement processing of tool nodes for storage and response in …
0964a7e Merge remote-tracking branch 'origin/main' into feat/agents
9a4f261 feat: move tools from conf to node prop
a34f0d2 fix: correct type hint for add_tools method in AgentNode
8ada38f add tools to agent in workflow execution
7b9aeac refactor: remove tool node processing from workflow management
bc1dc79 refactor: simplify tools initialization in AgentNode
37fbb24 refactor: use subworkflows for tools
bf94a09 no need to add tools to config.tools
b2a8f48 feat: enable message history in Agent
84cdb96 feat: add tool_calls to AgentNode output
76b209a feat: add NodeOutputModal to AgentNode for displaying output
56523b9 chore: remove unused .env.example file
e124a8c fix: correct formatting of system message in generate_message function
srijanpatel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
feat: implement AgentNode for LLM-based agent execution with tool-calling capabilities

commit a4775ea2d9315aa70718de99c1eef2b8a54aad86
New file, 357 lines:

```python
import asyncio
import json
from typing import Any, Dict, List, Optional, cast

from jinja2 import Template
from litellm import ChatCompletionMessageToolCall, ChatCompletionToolMessage
from pydantic import BaseModel, Field

from ...schemas.workflow_schemas import WorkflowNodeSchema
from ...utils.pydantic_utils import get_nested_field
from ..base import BaseNode, BaseNodeInput, BaseNodeOutput
from ..factory import NodeFactory
from ._utils import create_messages, generate_text
from .single_llm_call import (
    LLMModels,
    ModelInfo,
    SingleLLMCallNode,
    SingleLLMCallNodeConfig,
    repair_json,
)


class AgentNodeConfig(SingleLLMCallNodeConfig):
    """Configuration for the AgentNode.

    Extends SingleLLMCallNodeConfig with support for tools.
    """

    tools: Optional[List[WorkflowNodeSchema]] = Field(
        None, description="List of tool nodes that the agent can use"
    )
    max_iterations: int = Field(
        10, description="Maximum number of tool calls the agent can make in a single run"
    )


class AgentNodeInput(BaseNodeInput):
    class Config:
        extra = "allow"


class AgentNodeOutput(BaseNodeOutput):
    pass


class AgentNode(SingleLLMCallNode):
    """Node for executing an LLM-based agent with tool-calling capabilities.

    Features:
    - All features from SingleLLMCallNode
    - Support for tool calling with other workflow nodes
    - Control over the number of iterations and tool choice
    - Tool results are fed back to the LLM for further reasoning
    """

    name = "agent_node"
    display_name = "Agent"
    config_model = AgentNodeConfig
    input_model = AgentNodeInput
    output_model = AgentNodeOutput

    def setup(self) -> None:
        super().setup()
        # Create a dictionary of tool nodes for easy access
        self.tools_dict: Dict[str, WorkflowNodeSchema] = {}
        tools: List[WorkflowNodeSchema] = self.config.tools or []
        for tool in tools:
            self.tools_dict[tool.title] = tool

        # Create instances of the tools
        self.tools_instances: Dict[str, BaseNode] = {}
        for tool in tools:
            tool_node_instance = NodeFactory.create_node(
                node_name=tool.id,
                node_type_name=tool.node_type,
                config=tool.config,
            )
            self.tools_instances[tool.title] = tool_node_instance

        # Create a list of tool schemas to pass to the LLM
        self.tools_schemas: List[Dict[str, Any]] = []
        for tool in tools:
            tool_node_instance = self.tools_instances[tool.title]
            self.tools_schemas.append(tool_node_instance.function_schema)

    def _render_template(self, template_str: str, data: Dict[str, Any]) -> str:
        """Render a template with the given data."""
        try:
            return Template(template_str).render(**data)
        except Exception as e:
            print(f"[ERROR] Failed to render template: {e}")
            return template_str

    async def _call_tool(self, tool_call: ChatCompletionMessageToolCall) -> Any:
        """Call a tool with the provided parameters."""
        tool_name = tool_call.function.name
        tool_args = tool_call.function.arguments
        tool_call_id = tool_call.id

        assert tool_name is not None, "Tool name cannot be None"
        assert tool_call_id is not None, "Tool call ID cannot be None"
        assert tool_args is not None, "Tool arguments cannot be None"

        # Get the tool node from the dictionary
        tool_node = self.tools_dict.get(tool_name)
        if not tool_node:
            raise ValueError(f"Tool {tool_name} not found in tools dictionary")

        # Create a fresh node instance for this call
        tool_node_instance = NodeFactory.create_node(
            node_name=tool_node.id,
            node_type_name=tool_node.node_type,
            config=tool_node.config,
        )
        tool_args = json.loads(tool_args)
        return await tool_node_instance.call_as_tool(arguments=tool_args)

    async def execute_parallel_tool_calls(
        self, tool_calls: List[ChatCompletionMessageToolCall]
    ) -> List[ChatCompletionToolMessage]:
        """Execute multiple tool calls in parallel."""

        async def process_tool_call(
            tool_call: ChatCompletionMessageToolCall,
        ) -> ChatCompletionToolMessage:
            tool_response = await self._call_tool(tool_call)
            return ChatCompletionToolMessage(
                role="tool",
                content=str(tool_response),
                tool_call_id=tool_call.id,
            )

        # Use asyncio.gather to run all tool calls concurrently
        tool_messages: List[ChatCompletionToolMessage] = await asyncio.gather(
            *[process_tool_call(tool_call) for tool_call in tool_calls]
        )
        return tool_messages

    async def run(self, input: BaseModel) -> BaseModel:
        # Get the raw input dictionary
        raw_input_dict = input.model_dump()

        # Render the system message with the input data
        system_message = self._render_template(self.config.system_message, raw_input_dict)
        try:
            # If user_message is empty, dump the entire raw dictionary
            if not self.config.user_message.strip():
                user_message = json.dumps(raw_input_dict, indent=2)
            else:
                user_message = Template(self.config.user_message).render(**raw_input_dict)
        except Exception as e:
            print(f"[ERROR] Failed to render user_message {self.name}")
            print(f"[ERROR] user_message: {self.config.user_message} with input: {raw_input_dict}")
            raise e

        # Extract message history from input if enabled
        history: Optional[List[Dict[str, str]]] = None

        messages = create_messages(
            system_message=system_message,
            user_message=user_message,
            few_shot_examples=self.config.few_shot_examples,
            history=history,
        )

        model_name = LLMModels(self.config.llm_info.model).value

        url_vars: Optional[Dict[str, str]] = None
        # Process URL variables if they exist and we're using a Gemini model
        if self.config.url_variables:
            url_vars = {}
            if "file" in self.config.url_variables:
                # Resolve the input variable reference (e.g. "input_node.video_url")
                # using the nested-field helper
                file_value = get_nested_field(self.config.url_variables["file"], input)
                # Always use image_url format regardless of file type
                url_vars["image"] = file_value

        # Prepare thinking parameters if enabled
        thinking_params = None
        if self.config.enable_thinking:
            model_info = LLMModels.get_model_info(model_name)
            if model_info and model_info.constraints.supports_thinking:
                thinking_params = {
                    "type": "enabled",
                    "budget_tokens": self.config.thinking_budget_tokens
                    or model_info.constraints.thinking_budget_tokens
                    or 1024,
                }

        try:
            num_iterations = 0
            model_response = ""
            # Loop until either the maximum number of iterations is reached
            # or the model responds with an assistant message
            while num_iterations < self.config.max_iterations:
                message_response = await generate_text(
                    messages=messages,
                    model_name=model_name,
                    temperature=self.config.llm_info.temperature,
                    max_tokens=self.config.llm_info.max_tokens,
                    json_mode=True,
                    url_variables=url_vars,
                    output_json_schema=self.config.output_json_schema,
                    thinking=thinking_params,
                    tools=self.tools_schemas,
                )
                num_iterations += 1
                # Check if the response is a tool call
                if message_response.tool_calls and len(message_response.tool_calls) > 0:
                    tool_responses = await self.execute_parallel_tool_calls(
                        message_response.tool_calls
                    )
                    # Add the tool responses to the messages and call the LLM for the next turn
                    messages.extend(cast(List[Dict[str, str]], tool_responses))
                    continue

                # Extract the assistant message
                message = json.loads(str(message_response.content))
                if message.get("role") == "assistant":
                    model_response = str(message_response.content)
                    break
        except Exception as e:
            error_str = str(e)

            # Handle all LiteLLM errors
            if "litellm" in error_str.lower():
                error_message = "An error occurred with the LLM service"
                error_type = "unknown"

                # Extract provider from model name
                provider = model_name.split("/")[0] if "/" in model_name else "unknown"

                # Handle specific known error cases
                if "VertexAIError" in error_str and "The model is overloaded" in error_str:
                    error_type = "overloaded"
                    error_message = "The model is currently overloaded. Please try again later."
                elif "rate limit" in error_str.lower():
                    error_type = "rate_limit"
                    error_message = "Rate limit exceeded. Please try again in a few minutes."
                elif "context length" in error_str.lower() or "maximum token" in error_str.lower():
                    error_type = "context_length"
                    error_message = (
                        "Input is too long for the model's context window."
                        " Please reduce the input length."
                    )
                elif (
                    "invalid api key" in error_str.lower() or "authentication" in error_str.lower()
                ):
                    error_type = "auth"
                    error_message = (
                        "Authentication error with the LLM service. Please check your API key."
                    )
                elif "bad gateway" in error_str.lower() or "503" in error_str:
                    error_type = "service_unavailable"
                    error_message = (
                        "The LLM service is temporarily unavailable. Please try again later."
                    )

                raise Exception(
                    json.dumps(
                        {
                            "type": "model_provider_error",
                            "provider": provider,
                            "error_type": error_type,
                            "message": error_message,
                            "original_error": error_str,
                        }
                    )
                ) from e
            raise e

        try:
            assistant_message_dict = json.loads(model_response)
        except Exception:
            try:
                repaired_str = repair_json(model_response)
                assistant_message_dict = json.loads(repaired_str)
            except Exception as inner_e:
                error_str = str(inner_e)
                error_message = (
                    "An error occurred while parsing and repairing the assistant message"
                )
                error_type = "json_parse_error"
                raise Exception(
                    json.dumps(
                        {
                            "type": "parsing_error",
                            "error_type": error_type,
                            "message": error_message,
                            "original_error": error_str,
                            "assistant_message_str": model_response,
                        }
                    )
                ) from inner_e

        # Validate and return
        assistant_message = self.output_model.model_validate(assistant_message_dict)
        return assistant_message


if __name__ == "__main__":

    async def test_agent_node():
        # Example tool node schema
        tool_schema = WorkflowNodeSchema(
            id="calculator",
            title="Calculator",
            node_type="CalculatorNode",
            config={"operations": ["add", "subtract", "multiply", "divide"]},
        )

        # Create agent node
        agent_node = AgentNode(
            name="MathHelper",
            config=AgentNodeConfig(
                llm_info=ModelInfo(model=LLMModels.GPT_4O, temperature=0.7, max_tokens=1000),
                system_message=(
                    "You are a helpful assistant that can use tools to solve math problems."
                ),
                user_message="I need help with this math problem: {{ problem }}",
                tools=[tool_schema],
                max_iterations=5,
                url_variables=None,
                enable_thinking=False,
                thinking_budget_tokens=None,
                enable_message_history=False,
                message_history_variable=None,
                output_json_schema=json.dumps(
                    {
                        "type": "object",
                        "properties": {
                            "answer": {"type": "string"},
                            "explanation": {"type": "string"},
                        },
                        "required": ["answer", "explanation"],
                    }
                ),
            ),
        )

        # Create input
        test_input = AgentNodeInput.model_validate({"problem": "What is 25 × 13?"})

        # Run the agent
        print("[DEBUG] Testing agent_node...")
        output = await agent_node(test_input)
        print("[DEBUG] Agent output:", output)

    asyncio.run(test_agent_node())
```
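The agent loop at the heart of this commit (generate a response, detect tool calls, execute them concurrently with asyncio.gather, feed the tool results back, repeat until an assistant answer arrives) can be sketched without litellm or the workflow machinery. Everything below (fake_model, call_tool, the TOOLS table) is illustrative stand-in code, not this repository's API:

```python
import asyncio
import json

# Hypothetical tool table: plain callables keyed by name.
TOOLS = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
}

async def call_tool(name: str, args: dict) -> str:
    # Look up and invoke a tool; AgentNode instead instantiates a node per call.
    return json.dumps({"tool": name, "result": TOOLS[name](**args)})

async def fake_model(messages: list) -> dict:
    # Stand-in for the LLM: one round of tool calls, then a final answer.
    if not any(m.get("role") == "tool" for m in messages):
        return {"tool_calls": [("add", {"a": 2, "b": 3}), ("mul", {"a": 4, "b": 5})]}
    results = [json.loads(m["content"])["result"] for m in messages if m["role"] == "tool"]
    return {"content": json.dumps({"answer": results})}

async def agent_loop(max_iterations: int = 5) -> dict:
    messages = [{"role": "user", "content": "compute"}]
    for _ in range(max_iterations):
        response = await fake_model(messages)
        calls = response.get("tool_calls")
        if calls:
            # Run all requested tool calls concurrently, as in
            # execute_parallel_tool_calls, then loop for the next turn.
            contents = await asyncio.gather(*[call_tool(n, a) for n, a in calls])
            messages.extend({"role": "tool", "content": c} for c in contents)
            continue
        return json.loads(response["content"])
    raise RuntimeError("max iterations reached")

print(asyncio.run(agent_loop()))
```

asyncio.gather preserves the order of its awaitables, so each tool message lines up with the call that produced it; the real code additionally carries a tool_call_id to make that pairing explicit.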
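The parse-then-repair fallback at the end of run() (try json.loads, retry after repair_json, and surface a structured parsing_error if both fail) is a general pattern for handling truncated model output. A minimal sketch, using a naive stand-in for the real repair_json helper:

```python
import json

def naive_repair(s: str) -> str:
    # Illustrative stand-in for a real repair helper:
    # close any unbalanced braces left by a truncated response.
    return s + "}" * (s.count("{") - s.count("}"))

def parse_with_fallback(raw: str) -> dict:
    try:
        return json.loads(raw)
    except Exception:
        try:
            return json.loads(naive_repair(raw))
        except Exception as inner:
            # Surface a structured error payload, as AgentNode.run does.
            raise ValueError(json.dumps({
                "type": "parsing_error",
                "message": "failed to parse and repair assistant message",
                "assistant_message_str": raw,
            })) from inner

print(parse_with_fallback('{"answer": "65"'))
```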