From b489ee3ee43b633dc9b226aa3076868cce2db272 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Fri, 14 Mar 2025 19:53:01 +0000 Subject: [PATCH 01/21] refactor: improve docstring formatting and add validation check in NodeFactory --- backend/pyspur/nodes/factory.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/backend/pyspur/nodes/factory.py b/backend/pyspur/nodes/factory.py index aaf7520c8..3eec83d20 100644 --- a/backend/pyspur/nodes/factory.py +++ b/backend/pyspur/nodes/factory.py @@ -12,8 +12,8 @@ class NodeFactory: - """ - Factory for creating node instances from a configuration. + """Create node instances from a configuration. + Supports both decorator-based registration and legacy configured registration. Conventions: @@ -30,8 +30,8 @@ class NodeFactory: @staticmethod def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]: - """ - Returns a dictionary of all available node types grouped by category. + """Return a dictionary of all available node types grouped by category. + Combines both decorator-registered and configured nodes. """ # Get nodes from both sources @@ -44,6 +44,12 @@ def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]: if category not in converted_nodes: converted_nodes[category] = [] for node in nodes: + if ( + node["node_type_name"] is None + or node["module"] is None + or node["class_name"] is None + ): + continue schema = NodeTypeSchema( node_type_name=node["node_type_name"], module=node["module"], @@ -65,8 +71,8 @@ def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]: @staticmethod def create_node(node_name: str, node_type_name: str, config: Any) -> BaseNode: - """ - Creates a node instance from a configuration. + """Create a node instance from a configuration. + Checks both registration methods for the node type. """ if not is_valid_node_type(node_type_name): From 8cf6aba1b559e88639995bacda9f105af9d8ca08 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sat, 15 Mar 2025 01:42:27 +0000 Subject: [PATCH 02/21] docs: refine docstring in get_node_types function for clarity --- backend/pyspur/api/node_management.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/backend/pyspur/api/node_management.py b/backend/pyspur/api/node_management.py index 286d32605..6b116d5bb 100644 --- a/backend/pyspur/api/node_management.py +++ b/backend/pyspur/api/node_management.py @@ -13,9 +13,7 @@ description="Get the schemas for all available node types", ) async def get_node_types() -> Dict[str, List[Dict[str, Any]]]: - """ - Returns the schemas for all available node types. - """ + """Return the schemas for all available node types.""" # get the schemas for each node class node_groups = NodeFactory.get_all_node_types() From c5bb754d65e4c03b9ab791eea8d3a04b4c66d58c Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sat, 15 Mar 2025 01:42:50 +0000 Subject: [PATCH 03/21] feat: enhance NodeRegistry to discover and register tool function nodes - Added functionality to discover tool function nodes in the tools directory. - Introduced a new method for registering tool function nodes, ensuring separation from regular nodes. - Updated docstrings for clarity on the new features and their handling. --- backend/pyspur/nodes/registry.py | 62 +++++++++++++++++++------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/backend/pyspur/nodes/registry.py b/backend/pyspur/nodes/registry.py index cc1506055..56a3745c3 100644 --- a/backend/pyspur/nodes/registry.py +++ b/backend/pyspur/nodes/registry.py @@ -159,6 +159,9 @@ def discover_nodes(cls, package_path: str = "pyspur.nodes") -> None: # Start recursive discovery cls._discover_in_directory(base_path, package_path) + # Also discover tool function nodes + cls.discover_tool_functions() + logger.info( "Node discovery complete." f" Found {len(cls._decorator_registered_classes)} decorated nodes." @@ -174,6 +177,12 @@ def discover_tool_functions(cls) -> None: This method searches recursively through Python files in the PROJECT_ROOT/tools directory for functions decorated with @tool_function and registers their node classes. Works with both package (with __init__.py) and non-package Python files. + + The method handles the following: + 1. Discovers tool function decorated nodes + 2. Ensures proper registration in the NodeRegistry + 3. Handles both package and non-package Python files + 4. Maintains separation between tool function nodes and regular nodes """ # Get PROJECT_ROOT from environment variable project_root = os.getenv("PROJECT_ROOT") @@ -230,9 +239,35 @@ def _get_module_path(file_path: Path, base_path: Path) -> str: # Fallback to absolute path if anything goes wrong return str(file_path) - def _discover_tools_in_directory(path: Path) -> None: - nonlocal registered_tools + def _register_tool_function_node(node_class: Type[BaseNode], category: str) -> None: + """Register a tool function node in the NodeRegistry. + + This function handles the specific registration logic for tool function nodes, + ensuring they are properly integrated into the registry while maintaining + separation from regular nodes. + """ + if category not in cls._nodes: + cls._nodes[category] = [] + # Create node registration info + node_info = { + "node_type_name": node_class.__name__, + "module": node_class.__module__, + "class_name": node_class.__name__, + "subcategory": getattr(node_class, "subcategory", None), + } + + # Add to registry if not already present + if not any(n["node_type_name"] == node_class.__name__ for n in cls._nodes[category]): + cls._nodes[category].append(node_info) + nonlocal registered_tools + registered_tools += 1 + logger.debug( + f"Registered tool function node {node_class.__name__} in category {category}" + ) + + def _discover_tools_in_directory(path: Path) -> None: + """Recursively discover tool functions in a directory.""" for item in path.iterdir(): if item.is_file() and item.suffix == ".py" and not item.name.startswith("_"): try: @@ -263,28 +298,7 @@ def _discover_tools_in_directory(path: Path) -> None: node_class = attr.node_class # Register the node class if it has a category category = getattr(node_class, "category", "Uncategorized") - if category not in cls._nodes: - cls._nodes[category] = [] - - # Create node registration info - node_info = { - "node_type_name": node_class.__name__, - "module": module_path, # Use the full module path - "class_name": node_class.__name__, - "subcategory": getattr(node_class, "subcategory", None), - } - - # Add to registry if not already present - if not any( - n["node_type_name"] == node_class.__name__ - for n in cls._nodes[category] - ): - cls._nodes[category].append(node_info) - registered_tools += 1 - logger.debug( - f"Registered tool function node {node_class.__name__}" - f" from {module_path} in category {category}" - ) + _register_tool_function_node(node_class, category) except Exception as e: logger.error(f"Failed to load module {item}: {e}") From 897c04c70d9367936dca0fa13737b132f1fd5b4a Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sat, 15 Mar 2025 02:06:21 +0000 Subject: [PATCH 04/21] feat: update tool_function decorator to dynamically set class name and attributes --- backend/pyspur/nodes/decorator.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/backend/pyspur/nodes/decorator.py b/backend/pyspur/nodes/decorator.py index 0b669cec2..01fd12daf 100644 --- a/backend/pyspur/nodes/decorator.py +++ b/backend/pyspur/nodes/decorator.py @@ -245,8 +245,14 @@ async def run(self, input: BaseModel) -> BaseModel: # Update the class with the function's docstring FunctionToolNode.__doc__ = func_doc + # Set the display name and logo + FunctionToolNode.display_name = func_display_name + + # Change the name of the class to the function name + new_class_name = type(f"{func_name}", (FunctionToolNode,), {}) + # Set NodeClass attribute to the function - func.node_class = FunctionToolNode # type: ignore + func.node_class = new_class_name # type: ignore # Set the config model to the config_model func.config_model = _config_model # type: ignore From 6d49ac7595478a4996e94b3c5594848ff55fce29 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sat, 15 Mar 2025 02:09:38 +0000 Subject: [PATCH 05/21] fix: update module path retrieval in NodeRegistry to use project_root --- backend/pyspur/nodes/registry.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backend/pyspur/nodes/registry.py b/backend/pyspur/nodes/registry.py index 56a3745c3..140800aff 100644 --- a/backend/pyspur/nodes/registry.py +++ b/backend/pyspur/nodes/registry.py @@ -263,7 +263,7 @@ def _register_tool_function_node(node_class: Type[BaseNode], category: str) -> N nonlocal registered_tools registered_tools += 1 logger.debug( - f"Registered tool function node {node_class.__name__} in category {category}" + f"Registered tool function {node_class.__name__} in category {category}" ) def _discover_tools_in_directory(path: Path) -> None: @@ -272,7 +272,8 @@ def _discover_tools_in_directory(path: Path) -> None: if item.is_file() and item.suffix == ".py" and not item.name.startswith("_"): try: # Get the appropriate module path/name - module_path = _get_module_path(item, tools_dir) + module_path = _get_module_path(item, Path(project_root)) + print(f"Module path: {module_path}") # Create a spec for the module if module_path.endswith(".py"): From 2f6975cc1489da81caeeced1edb61ac82054b5f1 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sat, 15 Mar 2025 02:31:21 +0000 Subject: [PATCH 06/21] feat: bind tool function class to module namespace and set module attribute --- backend/pyspur/nodes/decorator.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/backend/pyspur/nodes/decorator.py b/backend/pyspur/nodes/decorator.py index 01fd12daf..c59e2a1f9 100644 --- a/backend/pyspur/nodes/decorator.py +++ b/backend/pyspur/nodes/decorator.py @@ -1,5 +1,6 @@ import inspect import json +import sys from typing import ( Any, Callable, @@ -248,8 +249,18 @@ async def run(self, input: BaseModel) -> BaseModel: # Set the display name and logo FunctionToolNode.display_name = func_display_name - # Change the name of the class to the function name - new_class_name = type(f"{func_name}", (FunctionToolNode,), {}) + # Change the name of the class to the function name and bind it to the module + new_class_name = type( + f"{func_name}", + (FunctionToolNode,), + { + "__module__": func.__module__ # Set the module to match the decorated func's module + }, + ) + + # Bind the class to the module's namespace + module = sys.modules[func.__module__] + setattr(module, func_name, new_class_name) # Set NodeClass attribute to the function func.node_class = new_class_name # type: ignore From b9d310b8494b72f17be555cdba1f5d53cc28c6c0 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sat, 15 Mar 2025 02:39:21 +0000 Subject: [PATCH 07/21] fix: update docstrings in node_types.py for consistency in tense --- backend/pyspur/nodes/node_types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/pyspur/nodes/node_types.py b/backend/pyspur/nodes/node_types.py index da64e847e..ded151b1e 100644 --- a/backend/pyspur/nodes/node_types.py +++ b/backend/pyspur/nodes/node_types.py @@ -236,7 +236,7 @@ def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]: - """Returns a dictionary of all available node types grouped by category.""" + """Return a dictionary of all available node types grouped by category.""" node_type_groups: Dict[str, List[NodeTypeSchema]] = {} for group_name, node_types in SUPPORTED_NODE_TYPES.items(): node_type_groups[group_name] = [] @@ -247,7 +247,7 @@ def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]: def is_valid_node_type(node_type_name: str) -> bool: - """Checks if a node type is valid (supported, deprecated, or registered via decorator).""" + """Check if a node type is valid (supported, deprecated, or registered via decorator).""" # Check configured nodes first for node_types in SUPPORTED_NODE_TYPES.values(): for node_type in node_types: From cd85554e367c88b0ad05e6f540a96a5baa71299e Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sat, 15 Mar 2025 05:33:52 +0000 Subject: [PATCH 08/21] refactor: simplify tool function discovery to support only Python packages and improve error logging --- backend/pyspur/nodes/registry.py | 101 +++++++------------------------ 1 file changed, 22 insertions(+), 79 deletions(-) diff --git a/backend/pyspur/nodes/registry.py b/backend/pyspur/nodes/registry.py index 140800aff..b81581396 100644 --- a/backend/pyspur/nodes/registry.py +++ b/backend/pyspur/nodes/registry.py @@ -2,6 +2,7 @@ import importlib import importlib.util import os +import traceback from pathlib import Path from typing import Dict, List, Optional, Set, Type, Union @@ -53,7 +54,7 @@ def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]: # Store subcategory as class attribute without type checking if subcategory: - node_class.subcategory = subcategory # type: ignore + node_class.subcategory = subcategory # Initialize category if not exists if category not in cls._nodes: @@ -176,13 +177,7 @@ def discover_tool_functions(cls) -> None: This method searches recursively through Python files in the PROJECT_ROOT/tools directory for functions decorated with @tool_function and registers their node classes. - Works with both package (with __init__.py) and non-package Python files. - - The method handles the following: - 1. Discovers tool function decorated nodes - 2. Ensures proper registration in the NodeRegistry - 3. Handles both package and non-package Python files - 4. Maintains separation between tool function nodes and regular nodes + Only works with proper Python packages (directories with __init__.py). """ # Get PROJECT_ROOT from environment variable project_root = os.getenv("PROJECT_ROOT") @@ -203,53 +198,11 @@ def _is_package_dir(path: Path) -> bool: """Check if a directory is a Python package (has __init__.py).""" return (path / "__init__.py").exists() - def _get_module_path(file_path: Path, base_path: Path) -> str: - """Get the appropriate module path for importing. - - For files in a package (directory with __init__.py), returns the full package path. - For standalone files, returns the absolute file path. - """ - try: - rel_path = file_path.relative_to(base_path) - parts = list(rel_path.parts) - - # Build the module path by checking each parent directory - module_parts: List[str] = [] - current_path = base_path - - # Handle the directory parts - for part in parts[:-1]: # Exclude the file name - current_path = current_path / part - if _is_package_dir(current_path): - module_parts.append(part) - else: - # If we hit a non-package directory, we'll use absolute path - return str(file_path) - - # Add the file name without .py - module_parts.append(parts[-1][:-3]) # Remove .py extension - - # If we have a valid package path, return it with dots - if module_parts: - return ".".join(module_parts) - - # Fallback to absolute path - return str(file_path) - except Exception: - # Fallback to absolute path if anything goes wrong - return str(file_path) - def _register_tool_function_node(node_class: Type[BaseNode], category: str) -> None: - """Register a tool function node in the NodeRegistry. - - This function handles the specific registration logic for tool function nodes, - ensuring they are properly integrated into the registry while maintaining - separation from regular nodes. - """ + """Register a tool function node in the NodeRegistry.""" if category not in cls._nodes: cls._nodes[category] = [] - # Create node registration info node_info = { "node_type_name": node_class.__name__, "module": node_class.__module__, @@ -257,7 +210,6 @@ def _register_tool_function_node(node_class: Type[BaseNode], category: str) -> N "subcategory": getattr(node_class, "subcategory", None), } - # Add to registry if not already present if not any(n["node_type_name"] == node_class.__name__ for n in cls._nodes[category]): cls._nodes[category].append(node_info) nonlocal registered_tools @@ -266,48 +218,39 @@ def _register_tool_function_node(node_class: Type[BaseNode], category: str) -> N f"Registered tool function {node_class.__name__} in category {category}" ) - def _discover_tools_in_directory(path: Path) -> None: - """Recursively discover tool functions in a directory.""" + def _discover_tools_in_directory(path: Path, base_package: str = "tools") -> None: + """Recursively discover tool functions in package directories.""" + # Skip if not a package directory + if not _is_package_dir(path): + return + for item in path.iterdir(): if item.is_file() and item.suffix == ".py" and not item.name.startswith("_"): try: - # Get the appropriate module path/name - module_path = _get_module_path(item, Path(project_root)) - print(f"Module path: {module_path}") - - # Create a spec for the module - if module_path.endswith(".py"): - # For non-package files, use spec_from_file_location - spec = importlib.util.spec_from_file_location(item.stem, str(item)) - else: - # For package files, use find_spec - spec = importlib.util.find_spec(module_path) - - if spec is None or spec.loader is None: - logger.warning(f"Could not create module spec for {item}") - continue - - # Create and execute the module - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Look for tool functions in module attributes + # Get the module path relative to project root + module_path = f"{base_package}.{item.stem}" + + # Import the module using standard import_module + module = importlib.import_module(module_path) + + # Register any tool functions found in the module for attr_name in dir(module): attr = getattr(module, attr_name) - # Check if attribute has node_class (indicating it's a tool function) if hasattr(attr, "node_class"): node_class = attr.node_class - # Register the node class if it has a category category = getattr(node_class, "category", "Uncategorized") _register_tool_function_node(node_class, category) except Exception as e: logger.error(f"Failed to load module {item}: {e}") + logger.error(traceback.format_exc()) # Recursively process subdirectories elif item.is_dir() and not item.name.startswith("_"): - _discover_tools_in_directory(item) + # Update the base package for the subdirectory + subpackage = f"{base_package}.{item.name}" + _discover_tools_in_directory(item, subpackage) - # Start recursive discovery + # Start recursive discovery from tools directory _discover_tools_in_directory(tools_dir) logger.info(f"Tool function discovery complete. Found {registered_tools} tool functions.") From 15ea741c98d918472fb29be9f3bc40589c8575c7 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sat, 15 Mar 2025 05:33:58 +0000 Subject: [PATCH 09/21] feat: add default values for name, display_name, and category in BaseNode class --- backend/pyspur/nodes/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/backend/pyspur/nodes/base.py b/backend/pyspur/nodes/base.py index 6e6a8e5ae..b21eaa8be 100644 --- a/backend/pyspur/nodes/base.py +++ b/backend/pyspur/nodes/base.py @@ -71,10 +71,11 @@ class BaseNode(ABC): - Field types are the corresponding NodeOutputModels """ - name: str - display_name: str = "" # Will be used for config title, defaults to class name if not set + name: str = "" + display_name: str = "" + category: str = "Uncategorized" + subcategory: Optional[str] = None logo: Optional[str] = None - category: Optional[str] = None config_model: Type[BaseModel] output_model: Type[BaseNodeOutput] input_model: Type[BaseNodeInput] From a762c028e34838c2324eb4d6471a55ab3d66bb56 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sat, 15 Mar 2025 05:41:52 +0000 Subject: [PATCH 10/21] feat: create __init__.py file in project directory if it doesn't exist --- backend/pyspur/cli/main.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/backend/pyspur/cli/main.py b/backend/pyspur/cli/main.py index f3a0d00c8..a3e2f25eb 100644 --- a/backend/pyspur/cli/main.py +++ b/backend/pyspur/cli/main.py @@ -66,6 +66,13 @@ def init( f.write("\n# ================================") f.write("\nPROJECT_ROOT=" + str(target_dir) + "\n") + # add __init__.py to the project directory + init_file_path = target_dir / "__init__.py" + if not init_file_path.exists(): + with open(init_file_path, "w") as f: + f.write("# This is an empty __init__.py file") + print("[green]✓[/green] Created __init__.py") + custom_dirs = { "data": target_dir / "data", "tools": target_dir / "tools", From 8844250658e85224d663b2058fd679cddb55d696 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sun, 16 Mar 2025 03:51:58 +0000 Subject: [PATCH 11/21] refactor: move FunctionToolNode class outside --- backend/pyspur/nodes/decorator.py | 130 ++++++++++++++++++------------ 1 file changed, 80 insertions(+), 50 deletions(-) diff --git a/backend/pyspur/nodes/decorator.py b/backend/pyspur/nodes/decorator.py index c59e2a1f9..9a3c80a6d 100644 --- a/backend/pyspur/nodes/decorator.py +++ b/backend/pyspur/nodes/decorator.py @@ -22,6 +22,73 @@ from .base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput, VisualTag +class FunctionToolNode(BaseNode): + """Node class for function-based tools. + + This class is used to wrap Python functions as PySpur nodes. It handles parameter extraction, + template rendering, and function execution. + """ + + name: str + display_name: str + config_model: Type[BaseNodeConfig] + output_model: Type[BaseNodeOutput] + input_model: Type[BaseNodeInput] + function_param_names: Set[str] + is_output_model_defined: bool + _func: Callable[..., Any] + _visual_tag: Optional[Dict[str, str]] + + def __init__( + self, + name: str, + config: Optional[BaseNodeConfig] = None, + context: Optional[WorkflowExecutionContext] = None, + func: Optional[Callable[..., Any]] = None, + visual_tag: Optional[Dict[str, str]] = None, + ): + # Create default config if none provided + if config is None: + config = self.config_model() + + # Call parent init first + super().__init__(name=name, config=config, context=context) + + # Store the function and visual tag + if func is not None: + self._func = func + if visual_tag: + self.visual_tag = VisualTag(**visual_tag) + self._visual_tag = visual_tag + + async def run(self, input: BaseModel) -> BaseModel: + # Extract parameters from config directly using the stored parameter names + # This is more efficient than checking sig.parameters each time + kwargs: Dict[str, Any] = {} + + for param_name in self.function_param_names: + if hasattr(self.config, param_name): + kwargs[param_name] = getattr(self.config, param_name) + + # config values can be jinja2 templates so we need to render them + for param_name, param_value in kwargs.items(): + if isinstance(param_value, str): + template = Template(param_value) + kwargs[param_name] = template.render(input=input) + + # Call the original function + result = self._func(**kwargs) + + # Handle async functions + if hasattr(result, "__await__"): + result = await result + + if self.is_output_model_defined: + return self.output_model.model_validate(result) + else: + return self.output_model.model_validate({"output": result}) + + @runtime_checkable class ToolFunction(Protocol): """Protocol for functions decorated with @tool.""" @@ -187,17 +254,16 @@ def decorator(func: Callable[..., Any]) -> ToolFunction: _is_output_model_defined = is_output_model_defined # Create a Node class for this function - class FunctionToolNode(BaseNode): - # Class attributes + class CustomFunctionToolNode(FunctionToolNode): name = func_name display_name = func_display_name - category = _category + category = _category or "FunctionTools" config_model = _config_model - # ignore the type check for output_model, we know it is a BaseNodeOutput output_model = _output_model # type: ignore input_model = _input_model function_param_names = _function_param_names is_output_model_defined = _is_output_model_defined + __doc__ = func_doc def __init__( self, @@ -205,54 +271,18 @@ def __init__( config: Optional[BaseNodeConfig] = None, context: Optional[WorkflowExecutionContext] = None, ): - # Create default config if none provided - if config is None: - config = _config_model() - - # Call parent init first - super().__init__(name=name, config=config, context=context) - - # Initialize with custom visual tag if provided after parent init - if visual_tag: - self.visual_tag = VisualTag(**visual_tag) - - async def run(self, input: BaseModel) -> BaseModel: - # Extract parameters from config directly using the stored parameter names - # This is more efficient than checking sig.parameters each time - kwargs: Dict[str, Any] = {} - - for param_name in self.function_param_names: - if hasattr(self.config, param_name): - kwargs[param_name] = getattr(self.config, param_name) - - # config values can be jinja2 templates so we need to render them - for param_name, param_value in kwargs.items(): - if isinstance(param_value, str): - template = Template(param_value) - kwargs[param_name] = template.render(input=input) - - # Call the original function - result = func(**kwargs) - - # Handle async functions - if hasattr(result, "__await__"): - result = await result - - if self.is_output_model_defined: - return self.output_model.model_validate(result) - else: - return self.output_model.model_validate({"output": result}) - - # Update the class with the function's docstring - FunctionToolNode.__doc__ = func_doc - - # Set the display name and logo - FunctionToolNode.display_name = func_display_name + super().__init__( + name=name, + config=config, + context=context, + func=func, + visual_tag=visual_tag, + ) # Change the name of the class to the function name and bind it to the module new_class_name = type( f"{func_name}", - (FunctionToolNode,), + (CustomFunctionToolNode,), { "__module__": func.__module__ # Set the module to match the decorated func's module }, @@ -266,7 +296,7 @@ async def run(self, input: BaseModel) -> BaseModel: func.node_class = new_class_name # type: ignore # Set the config model to the config_model - func.config_model = _config_model # type: ignore + func.config_model = config_model # type: ignore # Set the output model to the output_model func.output_model = _output_model # type: ignore @@ -277,7 +307,7 @@ def create_node( config: Optional[BaseNodeConfig] = None, context: Optional[WorkflowExecutionContext] = None, ) -> FunctionToolNode: - return FunctionToolNode(name=name, config=config, context=context) + return new_class_name(name=name, config=config, context=context) func.create_node = create_node # type: ignore From 799c9323aa28d26f1fee811c078b0533d26d317f Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sun, 16 Mar 2025 04:11:12 +0000 Subject: [PATCH 12/21] fix: cast return type of tool_function to ToolFunction for type safety --- backend/pyspur/nodes/decorator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/pyspur/nodes/decorator.py b/backend/pyspur/nodes/decorator.py index 9a3c80a6d..5c381d4bc 100644 --- a/backend/pyspur/nodes/decorator.py +++ b/backend/pyspur/nodes/decorator.py @@ -9,6 +9,7 @@ Protocol, Set, Type, + cast, get_type_hints, runtime_checkable, ) @@ -311,7 +312,7 @@ def create_node( func.create_node = create_node # type: ignore - return func # type: ignore + return cast(ToolFunction, func) return decorator From d0bf5e9a659843ec150fbd88264a258fbb8d50f6 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sun, 16 Mar 2025 05:56:20 +0000 Subject: [PATCH 13/21] feat: introduce NodeInfo schema for node registration and enhance tool function validation --- backend/pyspur/nodes/registry.py | 82 +++++++++++++++++++++++--------- 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/backend/pyspur/nodes/registry.py b/backend/pyspur/nodes/registry.py index b81581396..a4ceca342 100644 --- a/backend/pyspur/nodes/registry.py +++ b/backend/pyspur/nodes/registry.py @@ -4,15 +4,38 @@ import os import traceback from pathlib import Path -from typing import Dict, List, Optional, Set, Type, Union +from typing import Any, Dict, List, Optional, Set, Type, Union from loguru import logger +from pydantic import BaseModel from .base import BaseNode +from .decorator import FunctionToolNode, ToolFunction + + +class NodeInfo(BaseModel): + """Schema for node registration information. + + This schema is used to store metadata about nodes in the NodeRegistry. + + Attributes: + node_type_name: The name of the node type + module: The module path (e.g. "tools.foo") + class_name: The attribute path within the module, + supporting dot notation for nested attributes + (e.g. "MyClass" or "some_var.some_attr.target_class") + subcategory: Optional subcategory for organization + + """ + + node_type_name: str + module: str + class_name: str # Now supports dot notation for nested attributes + subcategory: Optional[str] = None class NodeRegistry: - _nodes: Dict[str, List[Dict[str, Union[str, Optional[str]]]]] = {} + _nodes: Dict[str, List[NodeInfo]] = {} _decorator_registered_classes: Set[Type[BaseNode]] = ( set() ) # Track classes registered via decorator @@ -66,12 +89,12 @@ def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]: if module_path.startswith("pyspur."): module_path = module_path.replace("pyspur.", "", 1) - node_info: Dict[str, Union[str, Optional[str]]] = { - "node_type_name": node_class.__name__, - "module": f".{module_path}", - "class_name": node_class.__name__, - "subcategory": subcategory, - } + node_info = NodeInfo( + node_type_name=node_class.__name__, + module=f".{module_path}", + class_name=node_class.__name__, + subcategory=subcategory, + ) # Handle positioning nodes_list = cls._nodes[category] @@ -83,7 +106,7 @@ def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]: elif position.startswith("after:"): target_node = position[6:] for i, n in enumerate(nodes_list): - if n["node_type_name"] == target_node: + if n.node_type_name == target_node: nodes_list.insert(i + 1, node_info) break else: @@ -91,7 +114,7 @@ def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]: elif position.startswith("before:"): target_node = position[7:] for i, n in enumerate(nodes_list): - if n["node_type_name"] == target_node: + if n.node_type_name == target_node: nodes_list.insert(i, node_info) break else: @@ -100,7 +123,7 @@ def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]: nodes_list.append(node_info) else: # Add to end if no position specified - if not any(n["node_type_name"] == node_class.__name__ for n in nodes_list): + if not any(n.node_type_name == node_class.__name__ for n in nodes_list): nodes_list.append(node_info) logger.debug(f"Registered node {node_class.__name__} in category {category}") cls._decorator_registered_classes.add(node_class) @@ -112,8 +135,9 @@ def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]: @classmethod def get_registered_nodes( cls, - ) -> Dict[str, List[Dict[str, Union[str, Optional[str]]]]]: + ) -> Dict[str, List[NodeInfo]]: """Get all registered nodes.""" + cls.discover_nodes() return cls._nodes @classmethod @@ -198,19 +222,21 @@ def _is_package_dir(path: Path) -> bool: """Check if a directory is a Python package (has __init__.py).""" return (path / "__init__.py").exists() - def _register_tool_function_node(node_class: Type[BaseNode], category: str) -> None: + def _register_tool_function_node(func: ToolFunction, category: str) -> None: """Register a tool function node in the NodeRegistry.""" + node_class = func.node_class if category not in cls._nodes: cls._nodes[category] = [] - node_info = { - "node_type_name": node_class.__name__, - "module": node_class.__module__, - "class_name": node_class.__name__, - "subcategory": getattr(node_class, "subcategory", None), - } + node_info = NodeInfo( + node_type_name=node_class.__name__, + module=node_class.__module__, + # Using dot notation for nested attribute + class_name=f"{func.func_name}.node_class", + subcategory=getattr(node_class, "subcategory", None), + ) - if not any(n["node_type_name"] == node_class.__name__ for n in cls._nodes[category]): + if not any(n.node_type_name == node_class.__name__ for n in cls._nodes[category]): cls._nodes[category].append(node_info) nonlocal registered_tools registered_tools += 1 @@ -218,6 +244,16 @@ def _register_tool_function_node(node_class: Type[BaseNode], category: str) -> N f"Registered tool function {node_class.__name__} in category {category}" ) + def _is_valid_tool_function(attr: Any) -> bool: + """Check if an attribute is a properly decorated tool function.""" + if not isinstance(attr, ToolFunction): + return False + if not issubclass(attr.node_class, FunctionToolNode): + return False # Skip regular functions + # Must have all required node attributes + required_attrs = {"display_name", "config_model", "input_model", "output_model"} + return all(hasattr(attr.node_class, attr_name) for attr_name in required_attrs) + def _discover_tools_in_directory(path: Path, base_package: str = "tools") -> None: """Recursively discover tool functions in package directories.""" # Skip if not a package directory @@ -233,13 +269,13 @@ def _discover_tools_in_directory(path: Path, base_package: str = "tools") -> Non # Import the module using standard import_module module = importlib.import_module(module_path) - # Register any tool functions found in the module + # Register any valid tool functions found in the module for attr_name in dir(module): attr = getattr(module, attr_name) - if hasattr(attr, "node_class"): + if _is_valid_tool_function(attr): node_class = attr.node_class category = getattr(node_class, "category", "Uncategorized") - _register_tool_function_node(node_class, category) + _register_tool_function_node(attr, category) except Exception as e: logger.error(f"Failed to load module {item}: {e}") From 88625e94fdc66b000dbb9302b5cf82c6151d55e7 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sun, 16 Mar 2025 05:56:31 +0000 Subject: [PATCH 14/21] feat: add func_name attribute to ToolFunction for improved node identification --- backend/pyspur/nodes/decorator.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/backend/pyspur/nodes/decorator.py b/backend/pyspur/nodes/decorator.py index 5c381d4bc..029b73513 100644 --- a/backend/pyspur/nodes/decorator.py +++ b/backend/pyspur/nodes/decorator.py @@ -1,6 +1,5 @@ import inspect import json -import sys from typing import ( Any, Callable, @@ -97,6 +96,7 @@ class ToolFunction(Protocol): node_class: Type[BaseNode] config_model: Type[BaseNodeConfig] output_model: Type[BaseNodeOutput] + func_name: str def create_node( self, @@ -289,10 +289,6 @@ def __init__( }, ) - # Bind the class to the module's namespace - module = sys.modules[func.__module__] - setattr(module, func_name, new_class_name) - # Set NodeClass attribute to the function func.node_class = new_class_name # type: ignore @@ -302,6 +298,9 @@ def __init__( # Set the output model to the output_model func.output_model = _output_model # type: ignore + # Set the func_name attribute to the function name + func.func_name = func.__name__ # type: ignore + # Set the create_node function to the func def create_node( name: str = func_name, From d7f4f48412e851804da162e1a72405e391120018 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sun, 16 Mar 2025 05:56:59 +0000 Subject: [PATCH 15/21] feat: enhance node_class property to support nested class attribute traversal --- backend/pyspur/schemas/node_type_schemas.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/backend/pyspur/schemas/node_type_schemas.py b/backend/pyspur/schemas/node_type_schemas.py index 61ec87f15..f159c04a7 100644 --- a/backend/pyspur/schemas/node_type_schemas.py +++ b/backend/pyspur/schemas/node_type_schemas.py @@ -10,8 +10,20 @@ class NodeTypeSchema(BaseModel): @property def node_class(self): + # Import the module module = importlib.import_module(name=f"{self.module}", package="pyspur") - return getattr(module, self.class_name) + + # Split the class name into parts for attribute traversal + parts = self.class_name.split(".") + + # Start with the module + obj = module + + # Traverse the attribute chain + for part in parts: + obj = getattr(obj, part) + + return obj @property def input_model(self): From 346e4feee988c182884c31cd9e0df126e8343ded Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sun, 16 Mar 2025 05:57:15 +0000 Subject: [PATCH 16/21] refactor: update NodeFactory to use dot notation for accessing node attributes --- backend/pyspur/nodes/factory.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/backend/pyspur/nodes/factory.py b/backend/pyspur/nodes/factory.py index 3eec83d20..186f212ef 100644 --- a/backend/pyspur/nodes/factory.py +++ b/backend/pyspur/nodes/factory.py @@ -44,16 +44,10 @@ def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]: if category not in converted_nodes: converted_nodes[category] = [] for node in nodes: - if ( - node["node_type_name"] is None - or node["module"] is None - or node["class_name"] is None - ): - continue schema = NodeTypeSchema( - node_type_name=node["node_type_name"], - module=node["module"], - class_name=node["class_name"], + node_type_name=node.node_type_name, + module=node.module, + class_name=node.class_name, ) converted_nodes[category].append(schema) @@ -96,9 +90,9 @@ def create_node(node_name: str, node_type_name: str, config: Any) -> BaseNode: registered_nodes = NodeRegistry.get_registered_nodes() for nodes in registered_nodes.values(): for node in nodes: - if node["node_type_name"] == node_type_name: - module_name = node["module"] - class_name = node["class_name"] + if node.node_type_name == node_type_name: + module_name = node.module + class_name = node.class_name break if module_name and class_name: break From 2ad77ccc0eaf0cd99e7771fa603bd2264f13aa21 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sun, 16 Mar 2025 06:01:48 +0000 Subject: [PATCH 17/21] fix: update config_model type to BaseNodeConfig for improved type consistency --- backend/pyspur/nodes/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/pyspur/nodes/base.py b/backend/pyspur/nodes/base.py index b21eaa8be..1f9f8c931 100644 --- a/backend/pyspur/nodes/base.py +++ b/backend/pyspur/nodes/base.py @@ -76,7 +76,7 @@ class BaseNode(ABC): category: str = "Uncategorized" subcategory: Optional[str] = None logo: Optional[str] = None - config_model: Type[BaseModel] + config_model: Type[BaseNodeConfig] output_model: Type[BaseNodeOutput] input_model: Type[BaseNodeInput] _config: BaseNodeConfig From f5828dde415ee874c79e4c7c971254d46b115c4a Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sun, 16 Mar 2025 06:13:33 +0000 Subject: [PATCH 18/21] fix: change category type to Optional[str] for better flexibility --- backend/pyspur/nodes/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/pyspur/nodes/base.py b/backend/pyspur/nodes/base.py index 1f9f8c931..842eaf3ef 100644 --- a/backend/pyspur/nodes/base.py +++ b/backend/pyspur/nodes/base.py @@ -73,7 +73,7 @@ class BaseNode(ABC): name: str = "" display_name: str = "" - category: str = "Uncategorized" + category: Optional[str] = None subcategory: Optional[str] = None logo: Optional[str] = None config_model: Type[BaseNodeConfig] From 05c0a7b6fd1904adbb66eb7552ae2cf2a06d24e8 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sun, 16 Mar 2025 06:19:34 +0000 Subject: [PATCH 19/21] fix: set default category to "Custom Tools" when registering tool function nodes --- backend/pyspur/nodes/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/pyspur/nodes/registry.py b/backend/pyspur/nodes/registry.py index a4ceca342..2d5a05c89 100644 --- a/backend/pyspur/nodes/registry.py +++ b/backend/pyspur/nodes/registry.py @@ -225,6 +225,7 @@ def _is_package_dir(path: Path) -> bool: def _register_tool_function_node(func: ToolFunction, category: str) -> None: """Register a tool function node in the NodeRegistry.""" node_class = func.node_class + category = "Custom Tools" if category not in cls._nodes: cls._nodes[category] = [] From 931dcf145c35da31400fe664d6417828a925efd9 Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sun, 16 Mar 2025 06:34:24 +0000 Subject: [PATCH 20/21] feat: enhance MessageCard to support Markdown in messages --- frontend/src/components/chat/MessageCard.tsx | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/frontend/src/components/chat/MessageCard.tsx b/frontend/src/components/chat/MessageCard.tsx index fb199c5a8..681c858c8 100644 --- a/frontend/src/components/chat/MessageCard.tsx +++ b/frontend/src/components/chat/MessageCard.tsx @@ -4,11 +4,12 @@ import { Avatar, Badge, Button, Chip, cn, Link, Tooltip } from '@heroui/react' import { useClipboard } from '@heroui/use-clipboard' import { Icon } from '@iconify/react' import React from 'react' +import ReactMarkdown from 'react-markdown' export type MessageCardProps = React.HTMLAttributes & { avatar?: string showFeedback?: boolean - message?: React.ReactNode + message?: string | React.ReactNode currentAttempt?: number status?: 'success' | 'failed' | 'loading' attempts?: number @@ -54,8 +55,8 @@ const MessageCard = React.forwardRef( const failedMessage = (

Something went wrong, if the issue persists please contact us through our help center at  - - support@acmeai.com + + support@pyspur.dev

) @@ -135,7 +136,15 @@ const MessageCard = React.forwardRef( )} >
- {hasFailed ? failedMessage : isLoading ? loadingMessage : message} + {hasFailed ? ( + failedMessage + ) : isLoading ? ( + loadingMessage + ) : typeof message === 'string' ? ( + {message} + ) : ( + message + )}
{showFeedback && !hasFailed && !isLoading && (
From 0b02eed3c56bc0f174aeb9a808f6bca7b007cc3f Mon Sep 17 00:00:00 2001 From: Srijan Patel Date: Sun, 16 Mar 2025 06:35:10 +0000 Subject: [PATCH 21/21] refactor: replace NodeInfo schema with NodeTypeSchema for improved consistency --- backend/pyspur/nodes/registry.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/backend/pyspur/nodes/registry.py b/backend/pyspur/nodes/registry.py index 2d5a05c89..c78c0b9ec 100644 --- a/backend/pyspur/nodes/registry.py +++ b/backend/pyspur/nodes/registry.py @@ -7,30 +7,13 @@ from typing import Any, Dict, List, Optional, Set, Type, Union from loguru import logger -from pydantic import BaseModel +from ..schemas.node_type_schemas import NodeTypeSchema from .base import BaseNode from .decorator import FunctionToolNode, ToolFunction -class NodeInfo(BaseModel): - """Schema for node registration information. - - This schema is used to store metadata about nodes in the NodeRegistry. - - Attributes: - node_type_name: The name of the node type - module: The module path (e.g. "tools.foo") - class_name: The attribute path within the module, - supporting dot notation for nested attributes - (e.g. "MyClass" or "some_var.some_attr.target_class") - subcategory: Optional subcategory for organization - - """ - - node_type_name: str - module: str - class_name: str # Now supports dot notation for nested attributes +class NodeInfo(NodeTypeSchema): subcategory: Optional[str] = None