diff --git a/backend/pyspur/api/node_management.py b/backend/pyspur/api/node_management.py index 286d3260..6b116d5b 100644 --- a/backend/pyspur/api/node_management.py +++ b/backend/pyspur/api/node_management.py @@ -13,9 +13,7 @@ description="Get the schemas for all available node types", ) async def get_node_types() -> Dict[str, List[Dict[str, Any]]]: - """ - Returns the schemas for all available node types. - """ + """Return the schemas for all available node types.""" # get the schemas for each node class node_groups = NodeFactory.get_all_node_types() diff --git a/backend/pyspur/cli/main.py b/backend/pyspur/cli/main.py index f3a0d00c..a3e2f25e 100644 --- a/backend/pyspur/cli/main.py +++ b/backend/pyspur/cli/main.py @@ -66,6 +66,13 @@ def init( f.write("\n# ================================") f.write("\nPROJECT_ROOT=" + str(target_dir) + "\n") + # add __init__.py to the project directory + init_file_path = target_dir / "__init__.py" + if not init_file_path.exists(): + with open(init_file_path, "w") as f: + f.write("# This is an empty __init__.py file") + print("[green]✓[/green] Created __init__.py") + custom_dirs = { "data": target_dir / "data", "tools": target_dir / "tools", diff --git a/backend/pyspur/nodes/base.py b/backend/pyspur/nodes/base.py index 6e6a8e5a..842eaf3e 100644 --- a/backend/pyspur/nodes/base.py +++ b/backend/pyspur/nodes/base.py @@ -71,11 +71,12 @@ class BaseNode(ABC): - Field types are the corresponding NodeOutputModels """ - name: str - display_name: str = "" # Will be used for config title, defaults to class name if not set - logo: Optional[str] = None + name: str = "" + display_name: str = "" category: Optional[str] = None - config_model: Type[BaseModel] + subcategory: Optional[str] = None + logo: Optional[str] = None + config_model: Type[BaseNodeConfig] output_model: Type[BaseNodeOutput] input_model: Type[BaseNodeInput] _config: BaseNodeConfig diff --git a/backend/pyspur/nodes/decorator.py b/backend/pyspur/nodes/decorator.py index 0b669cec..029b7351 100644 --- a/backend/pyspur/nodes/decorator.py +++ b/backend/pyspur/nodes/decorator.py @@ -8,6 +8,7 @@ Protocol, Set, Type, + cast, get_type_hints, runtime_checkable, ) @@ -21,6 +22,73 @@ from .base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput, VisualTag +class FunctionToolNode(BaseNode): + """Node class for function-based tools. + + This class is used to wrap Python functions as PySpur nodes. It handles parameter extraction, + template rendering, and function execution. + """ + + name: str + display_name: str + config_model: Type[BaseNodeConfig] + output_model: Type[BaseNodeOutput] + input_model: Type[BaseNodeInput] + function_param_names: Set[str] + is_output_model_defined: bool + _func: Callable[..., Any] + _visual_tag: Optional[Dict[str, str]] + + def __init__( + self, + name: str, + config: Optional[BaseNodeConfig] = None, + context: Optional[WorkflowExecutionContext] = None, + func: Optional[Callable[..., Any]] = None, + visual_tag: Optional[Dict[str, str]] = None, + ): + # Create default config if none provided + if config is None: + config = self.config_model() + + # Call parent init first + super().__init__(name=name, config=config, context=context) + + # Store the function and visual tag + if func is not None: + self._func = func + if visual_tag: + self.visual_tag = VisualTag(**visual_tag) + self._visual_tag = visual_tag + + async def run(self, input: BaseModel) -> BaseModel: + # Extract parameters from config directly using the stored parameter names + # This is more efficient than checking sig.parameters each time + kwargs: Dict[str, Any] = {} + + for param_name in self.function_param_names: + if hasattr(self.config, param_name): + kwargs[param_name] = getattr(self.config, param_name) + + # config values can be jinja2 templates so we need to render them + for param_name, param_value in kwargs.items(): + if isinstance(param_value, str): + template = Template(param_value) + kwargs[param_name] = template.render(input=input) + + # Call the original function + result = self._func(**kwargs) + + # Handle async functions + if hasattr(result, "__await__"): + result = await result + + if self.is_output_model_defined: + return self.output_model.model_validate(result) + else: + return self.output_model.model_validate({"output": result}) + + @runtime_checkable class ToolFunction(Protocol): """Protocol for functions decorated with @tool.""" @@ -28,6 +96,7 @@ class ToolFunction(Protocol): node_class: Type[BaseNode] config_model: Type[BaseNodeConfig] output_model: Type[BaseNodeOutput] + func_name: str def create_node( self, @@ -186,17 +255,16 @@ def decorator(func: Callable[..., Any]) -> ToolFunction: _is_output_model_defined = is_output_model_defined # Create a Node class for this function - class FunctionToolNode(BaseNode): - # Class attributes + class CustomFunctionToolNode(FunctionToolNode): name = func_name display_name = func_display_name - category = _category + category = _category or "FunctionTools" config_model = _config_model - # ignore the type check for output_model, we know it is a BaseNodeOutput output_model = _output_model # type: ignore input_model = _input_model function_param_names = _function_param_names is_output_model_defined = _is_output_model_defined + __doc__ = func_doc def __init__( self, @@ -204,67 +272,46 @@ def __init__( config: Optional[BaseNodeConfig] = None, context: Optional[WorkflowExecutionContext] = None, ): - # Create default config if none provided - if config is None: - config = _config_model() - - # Call parent init first - super().__init__(name=name, config=config, context=context) - - # Initialize with custom visual tag if provided after parent init - if visual_tag: - self.visual_tag = VisualTag(**visual_tag) - - async def run(self, input: BaseModel) -> BaseModel: - # Extract parameters from config directly using the stored parameter names - # This is more efficient than checking sig.parameters each time - kwargs: Dict[str, Any] = {} - - for param_name in self.function_param_names: - if hasattr(self.config, param_name): - kwargs[param_name] = getattr(self.config, param_name) - - # config values can be jinja2 templates so we need to render them - for param_name, param_value in kwargs.items(): - if isinstance(param_value, str): - template = Template(param_value) - kwargs[param_name] = template.render(input=input) - - # Call the original function - result = func(**kwargs) - - # Handle async functions - if hasattr(result, "__await__"): - result = await result - - if self.is_output_model_defined: - return self.output_model.model_validate(result) - else: - return self.output_model.model_validate({"output": result}) + super().__init__( + name=name, + config=config, + context=context, + func=func, + visual_tag=visual_tag, + ) - # Update the class with the function's docstring - FunctionToolNode.__doc__ = func_doc + # Change the name of the class to the function name and bind it to the module + new_class_name = type( + f"{func_name}", + (CustomFunctionToolNode,), + { + "__module__": func.__module__ # Set the module to match the decorated func's module + }, + ) # Set NodeClass attribute to the function - func.node_class = FunctionToolNode # type: ignore + func.node_class = new_class_name # type: ignore # Set the config model to the config_model - func.config_model = _config_model # type: ignore + func.config_model = config_model # type: ignore # Set the output model to the output_model func.output_model = _output_model # type: ignore + # Set the func_name attribute to the function name + func.func_name = func.__name__ # type: ignore + # Set the create_node function to the func def create_node( name: str = func_name, config: Optional[BaseNodeConfig] = None, context: Optional[WorkflowExecutionContext] = None, ) -> FunctionToolNode: - return FunctionToolNode(name=name, config=config, context=context) + return new_class_name(name=name, config=config, context=context) func.create_node = create_node # type: ignore - return func # type: ignore + return cast(ToolFunction, func) return decorator diff --git a/backend/pyspur/nodes/factory.py b/backend/pyspur/nodes/factory.py index aaf7520c..186f212e 100644 --- a/backend/pyspur/nodes/factory.py +++ b/backend/pyspur/nodes/factory.py @@ -12,8 +12,8 @@ class NodeFactory: - """ - Factory for creating node instances from a configuration. + """Create node instances from a configuration. + Supports both decorator-based registration and legacy configured registration. Conventions: @@ -30,8 +30,8 @@ class NodeFactory: @staticmethod def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]: - """ - Returns a dictionary of all available node types grouped by category. + """Return a dictionary of all available node types grouped by category. + Combines both decorator-registered and configured nodes. """ # Get nodes from both sources @@ -45,9 +45,9 @@ def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]: converted_nodes[category] = [] for node in nodes: schema = NodeTypeSchema( - node_type_name=node["node_type_name"], - module=node["module"], - class_name=node["class_name"], + node_type_name=node.node_type_name, + module=node.module, + class_name=node.class_name, ) converted_nodes[category].append(schema) @@ -65,8 +65,8 @@ def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]: @staticmethod def create_node(node_name: str, node_type_name: str, config: Any) -> BaseNode: - """ - Creates a node instance from a configuration. + """Create a node instance from a configuration. + Checks both registration methods for the node type. """ if not is_valid_node_type(node_type_name): @@ -90,9 +90,9 @@ def create_node(node_name: str, node_type_name: str, config: Any) -> BaseNode: registered_nodes = NodeRegistry.get_registered_nodes() for nodes in registered_nodes.values(): for node in nodes: - if node["node_type_name"] == node_type_name: - module_name = node["module"] - class_name = node["class_name"] + if node.node_type_name == node_type_name: + module_name = node.module + class_name = node.class_name break if module_name and class_name: break diff --git a/backend/pyspur/nodes/node_types.py b/backend/pyspur/nodes/node_types.py index da64e847..ded151b1 100644 --- a/backend/pyspur/nodes/node_types.py +++ b/backend/pyspur/nodes/node_types.py @@ -236,7 +236,7 @@ def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]: - """Returns a dictionary of all available node types grouped by category.""" + """Return a dictionary of all available node types grouped by category.""" node_type_groups: Dict[str, List[NodeTypeSchema]] = {} for group_name, node_types in SUPPORTED_NODE_TYPES.items(): node_type_groups[group_name] = [] @@ -247,7 +247,7 @@ def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]: def is_valid_node_type(node_type_name: str) -> bool: - """Checks if a node type is valid (supported, deprecated, or registered via decorator).""" + """Check if a node type is valid (supported, deprecated, or registered via decorator).""" # Check configured nodes first for node_types in SUPPORTED_NODE_TYPES.values(): for node_type in node_types: diff --git a/backend/pyspur/nodes/registry.py b/backend/pyspur/nodes/registry.py index cc150605..c78c0b9e 100644 --- a/backend/pyspur/nodes/registry.py +++ b/backend/pyspur/nodes/registry.py @@ -2,16 +2,23 @@ import importlib import importlib.util import os +import traceback from pathlib import Path -from typing import Dict, List, Optional, Set, Type, Union +from typing import Any, Dict, List, Optional, Set, Type, Union from loguru import logger +from ..schemas.node_type_schemas import NodeTypeSchema from .base import BaseNode +from .decorator import FunctionToolNode, ToolFunction + + +class NodeInfo(NodeTypeSchema): + subcategory: Optional[str] = None class NodeRegistry: - _nodes: Dict[str, List[Dict[str, Union[str, Optional[str]]]]] = {} + _nodes: Dict[str, List[NodeInfo]] = {} _decorator_registered_classes: Set[Type[BaseNode]] = ( set() ) # Track classes registered via decorator @@ -53,7 +60,7 @@ def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]: # Store subcategory as class attribute without type checking if subcategory: - node_class.subcategory = subcategory # type: ignore + node_class.subcategory = subcategory # Initialize category if not exists if category not in cls._nodes: @@ -65,12 +72,12 @@ def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]: if module_path.startswith("pyspur."): module_path = module_path.replace("pyspur.", "", 1) - node_info: Dict[str, Union[str, Optional[str]]] = { - "node_type_name": node_class.__name__, - "module": f".{module_path}", - "class_name": node_class.__name__, - "subcategory": subcategory, - } + node_info = NodeInfo( + node_type_name=node_class.__name__, + module=f".{module_path}", + class_name=node_class.__name__, + subcategory=subcategory, + ) # Handle positioning nodes_list = cls._nodes[category] @@ -82,7 +89,7 @@ def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]: elif position.startswith("after:"): target_node = position[6:] for i, n in enumerate(nodes_list): - if n["node_type_name"] == target_node: + if n.node_type_name == target_node: nodes_list.insert(i + 1, node_info) break else: @@ -90,7 +97,7 @@ def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]: elif position.startswith("before:"): target_node = position[7:] for i, n in enumerate(nodes_list): - if n["node_type_name"] == target_node: + if n.node_type_name == target_node: nodes_list.insert(i, node_info) break else: @@ -99,7 +106,7 @@ def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]: nodes_list.append(node_info) else: # Add to end if no position specified - if not any(n["node_type_name"] == node_class.__name__ for n in nodes_list): + if not any(n.node_type_name == node_class.__name__ for n in nodes_list): nodes_list.append(node_info) logger.debug(f"Registered node {node_class.__name__} in category {category}") cls._decorator_registered_classes.add(node_class) @@ -111,8 +118,9 @@ def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]: @classmethod def get_registered_nodes( cls, - ) -> Dict[str, List[Dict[str, Union[str, Optional[str]]]]]: + ) -> Dict[str, List[NodeInfo]]: """Get all registered nodes.""" + cls.discover_nodes() return cls._nodes @classmethod @@ -159,6 +167,9 @@ def discover_nodes(cls, package_path: str = "pyspur.nodes") -> None: # Start recursive discovery cls._discover_in_directory(base_path, package_path) + # Also discover tool function nodes + cls.discover_tool_functions() + logger.info( "Node discovery complete." f" Found {len(cls._decorator_registered_classes)} decorated nodes." @@ -173,7 +184,7 @@ def discover_tool_functions(cls) -> None: This method searches recursively through Python files in the PROJECT_ROOT/tools directory for functions decorated with @tool_function and registers their node classes. - Works with both package (with __init__.py) and non-package Python files. + Only works with proper Python packages (directories with __init__.py). """ # Get PROJECT_ROOT from environment variable project_root = os.getenv("PROJECT_ROOT") @@ -194,105 +205,72 @@ def _is_package_dir(path: Path) -> bool: """Check if a directory is a Python package (has __init__.py).""" return (path / "__init__.py").exists() - def _get_module_path(file_path: Path, base_path: Path) -> str: - """Get the appropriate module path for importing. - - For files in a package (directory with __init__.py), returns the full package path. - For standalone files, returns the absolute file path. - """ - try: - rel_path = file_path.relative_to(base_path) - parts = list(rel_path.parts) - - # Build the module path by checking each parent directory - module_parts: List[str] = [] - current_path = base_path - - # Handle the directory parts - for part in parts[:-1]: # Exclude the file name - current_path = current_path / part - if _is_package_dir(current_path): - module_parts.append(part) - else: - # If we hit a non-package directory, we'll use absolute path - return str(file_path) - - # Add the file name without .py - module_parts.append(parts[-1][:-3]) # Remove .py extension - - # If we have a valid package path, return it with dots - if module_parts: - return ".".join(module_parts) + def _register_tool_function_node(func: ToolFunction, category: str) -> None: + """Register a tool function node in the NodeRegistry.""" + node_class = func.node_class + category = "Custom Tools" + if category not in cls._nodes: + cls._nodes[category] = [] - # Fallback to absolute path - return str(file_path) - except Exception: - # Fallback to absolute path if anything goes wrong - return str(file_path) + node_info = NodeInfo( + node_type_name=node_class.__name__, + module=node_class.__module__, + # Using dot notation for nested attribute + class_name=f"{func.func_name}.node_class", + subcategory=getattr(node_class, "subcategory", None), + ) - def _discover_tools_in_directory(path: Path) -> None: - nonlocal registered_tools + if not any(n.node_type_name == node_class.__name__ for n in cls._nodes[category]): + cls._nodes[category].append(node_info) + nonlocal registered_tools + registered_tools += 1 + logger.debug( + f"Registered tool function {node_class.__name__} in category {category}" + ) + + def _is_valid_tool_function(attr: Any) -> bool: + """Check if an attribute is a properly decorated tool function.""" + if not isinstance(attr, ToolFunction): + return False + if not issubclass(attr.node_class, FunctionToolNode): + return False # Skip regular functions + # Must have all required node attributes + required_attrs = {"display_name", "config_model", "input_model", "output_model"} + return all(hasattr(attr.node_class, attr_name) for attr_name in required_attrs) + + def _discover_tools_in_directory(path: Path, base_package: str = "tools") -> None: + """Recursively discover tool functions in package directories.""" + # Skip if not a package directory + if not _is_package_dir(path): + return for item in path.iterdir(): if item.is_file() and item.suffix == ".py" and not item.name.startswith("_"): try: - # Get the appropriate module path/name - module_path = _get_module_path(item, tools_dir) - - # Create a spec for the module - if module_path.endswith(".py"): - # For non-package files, use spec_from_file_location - spec = importlib.util.spec_from_file_location(item.stem, str(item)) - else: - # For package files, use find_spec - spec = importlib.util.find_spec(module_path) - - if spec is None or spec.loader is None: - logger.warning(f"Could not create module spec for {item}") - continue - - # Create and execute the module - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Look for tool functions in module attributes + # Get the module path relative to project root + module_path = f"{base_package}.{item.stem}" + + # Import the module using standard import_module + module = importlib.import_module(module_path) + + # Register any valid tool functions found in the module for attr_name in dir(module): attr = getattr(module, attr_name) - # Check if attribute has node_class (indicating it's a tool function) - if hasattr(attr, "node_class"): + if _is_valid_tool_function(attr): node_class = attr.node_class - # Register the node class if it has a category category = getattr(node_class, "category", "Uncategorized") - if category not in cls._nodes: - cls._nodes[category] = [] - - # Create node registration info - node_info = { - "node_type_name": node_class.__name__, - "module": module_path, # Use the full module path - "class_name": node_class.__name__, - "subcategory": getattr(node_class, "subcategory", None), - } - - # Add to registry if not already present - if not any( - n["node_type_name"] == node_class.__name__ - for n in cls._nodes[category] - ): - cls._nodes[category].append(node_info) - registered_tools += 1 - logger.debug( - f"Registered tool function node {node_class.__name__}" - f" from {module_path} in category {category}" - ) + _register_tool_function_node(attr, category) except Exception as e: logger.error(f"Failed to load module {item}: {e}") + logger.error(traceback.format_exc()) # Recursively process subdirectories elif item.is_dir() and not item.name.startswith("_"): - _discover_tools_in_directory(item) + # Update the base package for the subdirectory + subpackage = f"{base_package}.{item.name}" + _discover_tools_in_directory(item, subpackage) - # Start recursive discovery + # Start recursive discovery from tools directory _discover_tools_in_directory(tools_dir) logger.info(f"Tool function discovery complete. Found {registered_tools} tool functions.") diff --git a/backend/pyspur/schemas/node_type_schemas.py b/backend/pyspur/schemas/node_type_schemas.py index 61ec87f1..f159c04a 100644 --- a/backend/pyspur/schemas/node_type_schemas.py +++ b/backend/pyspur/schemas/node_type_schemas.py @@ -10,8 +10,20 @@ class NodeTypeSchema(BaseModel): @property def node_class(self): + # Import the module module = importlib.import_module(name=f"{self.module}", package="pyspur") - return getattr(module, self.class_name) + + # Split the class name into parts for attribute traversal + parts = self.class_name.split(".") + + # Start with the module + obj = module + + # Traverse the attribute chain + for part in parts: + obj = getattr(obj, part) + + return obj @property def input_model(self): diff --git a/frontend/src/components/chat/MessageCard.tsx b/frontend/src/components/chat/MessageCard.tsx index fb199c5a..681c858c 100644 --- a/frontend/src/components/chat/MessageCard.tsx +++ b/frontend/src/components/chat/MessageCard.tsx @@ -4,11 +4,12 @@ import { Avatar, Badge, Button, Chip, cn, Link, Tooltip } from '@heroui/react' import { useClipboard } from '@heroui/use-clipboard' import { Icon } from '@iconify/react' import React from 'react' +import ReactMarkdown from 'react-markdown' export type MessageCardProps = React.HTMLAttributes & { avatar?: string showFeedback?: boolean - message?: React.ReactNode + message?: string | React.ReactNode currentAttempt?: number status?: 'success' | 'failed' | 'loading' attempts?: number @@ -54,8 +55,8 @@ const MessageCard = React.forwardRef( const failedMessage = (

Something went wrong, if the issue persists please contact us through our help center at  - - support@acmeai.com + + support@pyspur.dev

) @@ -135,7 +136,15 @@ const MessageCard = React.forwardRef( )} >
- {hasFailed ? failedMessage : isLoading ? loadingMessage : message} + {hasFailed ? ( + failedMessage + ) : isLoading ? ( + loadingMessage + ) : typeof message === 'string' ? ( + {message} + ) : ( + message + )}
{showFeedback && !hasFailed && !isLoading && (