diff --git a/cmd/gh-aw/main.go b/cmd/gh-aw/main.go index e8576e3307d..af2a16cbfcc 100644 --- a/cmd/gh-aw/main.go +++ b/cmd/gh-aw/main.go @@ -8,6 +8,7 @@ import ( "github.com/github/gh-aw/pkg/cli" "github.com/github/gh-aw/pkg/console" "github.com/github/gh-aw/pkg/constants" + "github.com/github/gh-aw/pkg/parser" "github.com/github/gh-aw/pkg/workflow" "github.com/spf13/cobra" ) @@ -24,8 +25,23 @@ var bannerFlag bool // validateEngine validates the engine flag value func validateEngine(engine string) error { - if engine != "" && engine != "claude" && engine != "codex" && engine != "copilot" && engine != "custom" { - return fmt.Errorf("invalid engine value '%s'. Must be 'claude', 'codex', 'copilot', or 'custom'", engine) + // Get the global engine registry + registry := workflow.GetGlobalEngineRegistry() + validEngines := registry.GetSupportedEngines() + + if engine != "" && !registry.IsValidEngine(engine) { + // Try to find close matches for "did you mean" suggestion + suggestions := parser.FindClosestMatches(engine, validEngines, 1) + + errMsg := fmt.Sprintf("invalid engine value '%s'. Must be '%s'", + engine, strings.Join(validEngines, "', '")) + + if len(suggestions) > 0 { + errMsg = fmt.Sprintf("invalid engine value '%s'. Must be '%s'.\n\nDid you mean: %s?", + engine, strings.Join(validEngines, "', '"), suggestions[0]) + } + + return fmt.Errorf("%s", errMsg) } return nil } diff --git a/pkg/workflow/engine_validation.go b/pkg/workflow/engine_validation.go index ea1762e1d27..99f4dfa109a 100644 --- a/pkg/workflow/engine_validation.go +++ b/pkg/workflow/engine_validation.go @@ -36,9 +36,11 @@ package workflow import ( "encoding/json" "fmt" + "strings" "github.com/github/gh-aw/pkg/constants" "github.com/github/gh-aw/pkg/logger" + "github.com/github/gh-aw/pkg/parser" ) var engineValidationLog = logger.New("workflow:engine_validation") @@ -66,8 +68,32 @@ func (c *Compiler) validateEngine(engineID string) error { } engineValidationLog.Printf("Engine ID %s not found: %v", engineID, err) - // Provide helpful error with valid options - return fmt.Errorf("invalid engine: %s. Valid engines are: copilot, claude, codex, custom.\n\nExample:\nengine: copilot\n\nSee: %s", engineID, constants.DocsEnginesURL) + + // Get list of valid engine IDs from the engine registry + validEngines := c.engineRegistry.GetSupportedEngines() + + // Try to find close matches for "did you mean" suggestion + suggestions := parser.FindClosestMatches(engineID, validEngines, 1) + + // Build comma-separated list of valid engines for error message + enginesStr := strings.Join(validEngines, ", ") + + // Build error message with helpful context + errMsg := fmt.Sprintf("invalid engine: %s. Valid engines are: %s.\n\nExample:\nengine: copilot\n\nSee: %s", + engineID, + enginesStr, + constants.DocsEnginesURL) + + // Add "did you mean" suggestion if we found a close match + if len(suggestions) > 0 { + errMsg = fmt.Sprintf("invalid engine: %s. Valid engines are: %s.\n\nDid you mean: %s?\n\nExample:\nengine: copilot\n\nSee: %s", + engineID, + enginesStr, + suggestions[0], + constants.DocsEnginesURL) + } + + return fmt.Errorf("%s", errMsg) } // validateSingleEngineSpecification validates that only one engine field exists across all files diff --git a/pkg/workflow/engine_validation_test.go b/pkg/workflow/engine_validation_test.go index 612b1058efc..1fae91338d2 100644 --- a/pkg/workflow/engine_validation_test.go +++ b/pkg/workflow/engine_validation_test.go @@ -300,3 +300,97 @@ func TestValidateSingleEngineSpecificationErrorMessageQuality(t *testing.T) { } }) } + +// TestValidateEngineDidYouMean tests the "did you mean" suggestion feature +func TestValidateEngineDidYouMean(t *testing.T) { + tests := []struct { + name string + invalidEngine string + expectedSuggestion string + shouldHaveSuggestion bool + }{ + { + name: "typo copiilot suggests copilot", + invalidEngine: "copiilot", + expectedSuggestion: "copilot", + shouldHaveSuggestion: true, + }, + { + name: "typo claud suggests claude", + invalidEngine: "claud", + expectedSuggestion: "claude", + shouldHaveSuggestion: true, + }, + { + name: "typo codec suggests codex", + invalidEngine: "codec", + expectedSuggestion: "codex", + shouldHaveSuggestion: true, + }, + { + name: "typo custon suggests custom", + invalidEngine: "custon", + expectedSuggestion: "custom", + shouldHaveSuggestion: true, + }, + { + name: "case difference no suggestion (case-insensitive match)", + invalidEngine: "Copilot", + expectedSuggestion: "", + shouldHaveSuggestion: false, + }, + { + name: "completely wrong gets no suggestion", + invalidEngine: "gpt4", + expectedSuggestion: "", + shouldHaveSuggestion: false, + }, + { + name: "totally different gets no suggestion", + invalidEngine: "xyz", + expectedSuggestion: "", + shouldHaveSuggestion: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + compiler := NewCompiler() + err := compiler.validateEngine(tt.invalidEngine) + + if err == nil { + t.Fatal("Expected validation to fail for invalid engine") + } + + errorMsg := err.Error() + + if tt.shouldHaveSuggestion { + // Should have "Did you mean: X?" suggestion + if !strings.Contains(errorMsg, "Did you mean:") { + t.Errorf("Expected 'Did you mean:' in error message, got: %s", errorMsg) + } + + if !strings.Contains(errorMsg, tt.expectedSuggestion) { + t.Errorf("Expected suggestion '%s' in error message, got: %s", + tt.expectedSuggestion, errorMsg) + } + } else { + // Should NOT have "Did you mean:" suggestion + if strings.Contains(errorMsg, "Did you mean:") { + t.Errorf("Should not suggest anything for '%s', but got: %s", + tt.invalidEngine, errorMsg) + } + } + + // All errors should still list valid engines + if !strings.Contains(errorMsg, "copilot") { + t.Errorf("Error should always list valid engines, got: %s", errorMsg) + } + + // All errors should still include an example + if !strings.Contains(errorMsg, "Example:") { + t.Errorf("Error should always include an example, got: %s", errorMsg) + } + }) + } +} diff --git a/pkg/workflow/github_tool_to_toolset.go b/pkg/workflow/github_tool_to_toolset.go index 0383a4c6c20..2fae897bf66 100644 --- a/pkg/workflow/github_tool_to_toolset.go +++ b/pkg/workflow/github_tool_to_toolset.go @@ -3,8 +3,11 @@ package workflow import ( _ "embed" "encoding/json" + "fmt" + "sort" "github.com/github/gh-aw/pkg/logger" + "github.com/github/gh-aw/pkg/parser" ) var githubToolToToolsetLog = logger.New("workflow:github_tool_to_toolset") @@ -45,10 +48,37 @@ func ValidateGitHubToolsAgainstToolsets(allowedTools []string, enabledToolsets [ // Track missing toolsets and which tools need them missingToolsets := make(map[string][]string) // toolset -> list of tools that need it + // Track unknown tools for suggestions + var unknownTools []string + var suggestions []string + for _, tool := range allowedTools { + // Skip wildcard - it means "allow all tools" + if tool == "*" { + continue + } + requiredToolset, exists := GitHubToolToToolsetMap[tool] if !exists { - githubToolToToolsetLog.Printf("Tool %s not found in mapping, skipping validation", tool) + githubToolToToolsetLog.Printf("Tool %s not found in mapping, checking for typo", tool) + + // Get all valid tool names for suggestion + validTools := make([]string, 0, len(GitHubToolToToolsetMap)) + for validTool := range GitHubToolToToolsetMap { + validTools = append(validTools, validTool) + } + sort.Strings(validTools) + + // Try to find close matches + matches := parser.FindClosestMatches(tool, validTools, 1) + if len(matches) > 0 { + githubToolToToolsetLog.Printf("Found suggestion for unknown tool %s: %s", tool, matches[0]) + unknownTools = append(unknownTools, tool) + suggestions = append(suggestions, fmt.Sprintf("%s → %s", tool, matches[0])) + } else { + githubToolToToolsetLog.Printf("No suggestion found for unknown tool: %s", tool) + unknownTools = append(unknownTools, tool) + } // Tool not in our mapping - this could be a new tool or a typo // We'll skip validation for unknown tools to avoid false positives continue @@ -60,6 +90,36 @@ func ValidateGitHubToolsAgainstToolsets(allowedTools []string, enabledToolsets [ } } + // Report unknown tools with suggestions if any were found + if len(unknownTools) > 0 { + githubToolToToolsetLog.Printf("Found %d unknown tools", len(unknownTools)) + errMsg := fmt.Sprintf("Unknown GitHub tool(s): %s\n\n", formatList(unknownTools)) + + if len(suggestions) > 0 { + errMsg += "Did you mean:\n" + for _, s := range suggestions { + errMsg += fmt.Sprintf(" %s\n", s) + } + errMsg += "\n" + } + + // Show a few examples of valid tools + validTools := make([]string, 0, len(GitHubToolToToolsetMap)) + for tool := range GitHubToolToToolsetMap { + validTools = append(validTools, tool) + } + sort.Strings(validTools) + + exampleCount := 10 + if len(validTools) < exampleCount { + exampleCount = len(validTools) + } + errMsg += fmt.Sprintf("Valid GitHub tools include: %s\n\n", formatList(validTools[:exampleCount])) + errMsg += "See all tools: https://github.com/github/gh-aw/blob/main/pkg/workflow/data/github_tool_to_toolset.json" + + return fmt.Errorf("%s", errMsg) + } + if len(missingToolsets) > 0 { githubToolToToolsetLog.Printf("Validation failed: missing %d toolsets", len(missingToolsets)) return NewGitHubToolsetValidationError(missingToolsets) @@ -68,3 +128,17 @@ func ValidateGitHubToolsAgainstToolsets(allowedTools []string, enabledToolsets [ githubToolToToolsetLog.Print("Validation successful: all tools have required toolsets") return nil } + +// formatList formats a list of strings as a comma-separated list +func formatList(items []string) string { + if len(items) == 0 { + return "" + } + if len(items) == 1 { + return items[0] + } + if len(items) == 2 { + return items[0] + " and " + items[1] + } + return fmt.Sprintf("%s, and %s", formatList(items[:len(items)-1]), items[len(items)-1]) +} diff --git a/pkg/workflow/github_tool_to_toolset_test.go b/pkg/workflow/github_tool_to_toolset_test.go index b9a32189d88..8815d1082da 100644 --- a/pkg/workflow/github_tool_to_toolset_test.go +++ b/pkg/workflow/github_tool_to_toolset_test.go @@ -76,14 +76,15 @@ func TestValidateGitHubToolsAgainstToolsets(t *testing.T) { name: "Unknown tool is ignored", allowedTools: []string{"get_repository", "unknown_tool_xyz"}, enabledToolsets: []string{"repos"}, - expectError: false, + expectError: true, + errorContains: []string{"Unknown GitHub tool", "unknown_tool_xyz"}, }, { name: "Mix of known and unknown tools", allowedTools: []string{"get_repository", "unknown_tool", "list_issues"}, enabledToolsets: []string{"repos"}, // issues missing expectError: true, - errorContains: []string{"issues", "list_issues"}, + errorContains: []string{"Unknown GitHub tool", "unknown_tool"}, }, { name: "Actions toolset tools", @@ -308,3 +309,129 @@ func expandToolsetsForTesting(toolsets []string) []string { return expanded } + +// TestValidateGitHubToolsDidYouMean tests the "did you mean" suggestion feature for GitHub tools +func TestValidateGitHubToolsDidYouMean(t *testing.T) { + tests := []struct { + name string + invalidTool string + expectedSuggestion string + shouldHaveSuggestion bool + }{ + { + name: "typo issue_raed suggests issue_read", + invalidTool: "issue_raed", + expectedSuggestion: "issue_read", + shouldHaveSuggestion: true, + }, + { + name: "typo crate_issue suggests create_issue", + invalidTool: "crate_issue", + expectedSuggestion: "create_issue", + shouldHaveSuggestion: true, + }, + { + name: "typo get_repositry suggests get_repository", + invalidTool: "get_repositry", + expectedSuggestion: "get_repository", + shouldHaveSuggestion: true, + }, + { + name: "typo list_workflos suggests list_workflows", + invalidTool: "list_workflos", + expectedSuggestion: "list_workflows", + shouldHaveSuggestion: true, + }, + { + name: "typo serch_code suggests search_code", + invalidTool: "serch_code", + expectedSuggestion: "search_code", + shouldHaveSuggestion: true, + }, + { + name: "completely wrong tool gets no suggestion", + invalidTool: "xyz_abc_123", + expectedSuggestion: "", + shouldHaveSuggestion: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test with the invalid tool + allowedTools := []string{"get_repository", tt.invalidTool} + enabledToolsets := []string{"repos"} + + err := ValidateGitHubToolsAgainstToolsets(allowedTools, enabledToolsets) + + if err == nil { + t.Fatal("Expected validation to fail for unknown tool") + } + + errorMsg := err.Error() + + // Should mention the unknown tool + if !strings.Contains(errorMsg, "Unknown GitHub tool") { + t.Errorf("Expected 'Unknown GitHub tool' in error message, got: %s", errorMsg) + } + + if !strings.Contains(errorMsg, tt.invalidTool) { + t.Errorf("Expected invalid tool '%s' in error message, got: %s", + tt.invalidTool, errorMsg) + } + + if tt.shouldHaveSuggestion { + // Should have "Did you mean:" suggestion + if !strings.Contains(errorMsg, "Did you mean:") { + t.Errorf("Expected 'Did you mean:' in error message, got: %s", errorMsg) + } + + if !strings.Contains(errorMsg, tt.expectedSuggestion) { + t.Errorf("Expected suggestion '%s' in error message, got: %s", + tt.expectedSuggestion, errorMsg) + } + } else { + // Should NOT have "Did you mean:" suggestion + if strings.Contains(errorMsg, "Did you mean:") { + t.Errorf("Should not suggest anything for '%s', but got: %s", + tt.invalidTool, errorMsg) + } + } + + // All errors should list some valid tools + if !strings.Contains(errorMsg, "Valid GitHub tools") { + t.Errorf("Error should list valid GitHub tools, got: %s", errorMsg) + } + }) + } +} + +// TestValidateGitHubToolsMultipleUnknown tests error message when multiple unknown tools are used +func TestValidateGitHubToolsMultipleUnknown(t *testing.T) { + allowedTools := []string{"get_repository", "issue_raed", "crate_issue", "unknown_xyz"} + enabledToolsets := []string{"repos", "issues"} + + err := ValidateGitHubToolsAgainstToolsets(allowedTools, enabledToolsets) + + if err == nil { + t.Fatal("Expected validation to fail for unknown tools") + } + + errorMsg := err.Error() + + // Should mention all unknown tools + if !strings.Contains(errorMsg, "issue_raed") { + t.Errorf("Expected 'issue_raed' in error message, got: %s", errorMsg) + } + if !strings.Contains(errorMsg, "crate_issue") { + t.Errorf("Expected 'crate_issue' in error message, got: %s", errorMsg) + } + if !strings.Contains(errorMsg, "unknown_xyz") { + t.Errorf("Expected 'unknown_xyz' in error message, got: %s", errorMsg) + } + + // Should have suggestions section + if !strings.Contains(errorMsg, "Did you mean:") { + t.Errorf("Expected 'Did you mean:' in error message, got: %s", errorMsg) + } +}