Просмотр исходного кода

feat: Add tool restriction flags for non-interactive mode (#29)

Ed Zynda 9 месяцев назад
Родитель
Сommit
307982a099
7 измененных файлов с 372 добавлено и 105 удалено
  1. 2 1
      .gitignore
  2. 29 8
      README.md
  3. 292 0
      cmd/non_interactive_mode.go
  4. 17 96
      cmd/root.go
  5. 5 0
      go.mod
  6. 2 0
      go.sum
  7. 25 0
      internal/tui/components/spinner/spinner.go

+ 2 - 1
.gitignore

@@ -42,4 +42,5 @@ Thumbs.db
 .env.local
 
 .opencode/
-
+# ignore locally built binary
+opencode*

+ 29 - 8
README.md

@@ -225,10 +225,28 @@ opencode -p "Explain the use of context in Go" -f json
 
 # Run without showing the spinner
 opencode -p "Explain the use of context in Go" -q
+
+# Enable verbose logging to stderr
+opencode -p "Explain the use of context in Go" --verbose
+
+# Restrict the agent to only use specific tools
+opencode -p "Explain the use of context in Go" --allowedTools=view,ls,glob
+
+# Prevent the agent from using specific tools
+opencode -p "Explain the use of context in Go" --excludedTools=bash,edit
 ```
 
 In this mode, OpenCode will process your prompt, print the result to standard output, and then exit. All permissions are auto-approved for the session.
 
+### Tool Restrictions
+
+You can control which tools the AI assistant has access to in non-interactive mode:
+
+- `--allowedTools`: Comma-separated list of tools that the agent is allowed to use. Only these tools will be available.
+- `--excludedTools`: Comma-separated list of tools that the agent is not allowed to use. All other tools will be available.
+
+These flags are mutually exclusive - you can use either `--allowedTools` or `--excludedTools`, but not both at the same time.
+
 ### Output Formats
 
 OpenCode supports the following output formats in non-interactive mode:
@@ -242,14 +260,17 @@ The output format is implemented as a strongly-typed `OutputFormat` in the codeb
 
 ## Command-line Flags
 
-| Flag              | Short | Description                                            |
-| ----------------- | ----- | ------------------------------------------------------ |
-| `--help`          | `-h`  | Display help information                               |
-| `--debug`         | `-d`  | Enable debug mode                                      |
-| `--cwd`           | `-c`  | Set current working directory                          |
-| `--prompt`        | `-p`  | Run a single prompt in non-interactive mode            |
-| `--output-format` | `-f`  | Output format for non-interactive mode (text, json)    |
-| `--quiet`         | `-q`  | Hide spinner in non-interactive mode                   |
+| Flag              | Short | Description                                                |
+| ----------------- | ----- | ---------------------------------------------------------- |
+| `--help`          | `-h`  | Display help information                                   |
+| `--debug`         | `-d`  | Enable debug mode                                          |
+| `--cwd`           | `-c`  | Set current working directory                              |
+| `--prompt`        | `-p`  | Run a single prompt in non-interactive mode                |
+| `--output-format` | `-f`  | Output format for non-interactive mode (text, json)        |
+| `--quiet`         | `-q`  | Hide spinner in non-interactive mode                       |
+| `--verbose`       |       | Display logs to stderr in non-interactive mode             |
+| `--allowedTools`  |       | Restrict the agent to only use specified tools             |
+| `--excludedTools` |       | Prevent the agent from using specified tools               |
 
 ## Keyboard Shortcuts
 

+ 292 - 0
cmd/non_interactive_mode.go

@@ -0,0 +1,292 @@
+package cmd
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"os"
+	"sync"
+	"time"
+
+	"log/slog"
+
+	charmlog "github.com/charmbracelet/log"
+	"github.com/sst/opencode/internal/app"
+	"github.com/sst/opencode/internal/config"
+	"github.com/sst/opencode/internal/db"
+	"github.com/sst/opencode/internal/format"
+	"github.com/sst/opencode/internal/llm/agent"
+	"github.com/sst/opencode/internal/llm/tools"
+	"github.com/sst/opencode/internal/message"
+	"github.com/sst/opencode/internal/permission"
+	"github.com/sst/opencode/internal/tui/components/spinner"
+	"github.com/sst/opencode/internal/tui/theme"
+)
+
+// syncWriter is a thread-safe writer that prevents interleaved output
+type syncWriter struct {
+	w  io.Writer
+	mu sync.Mutex
+}
+
+// Write implements io.Writer
+func (sw *syncWriter) Write(p []byte) (n int, err error) {
+	sw.mu.Lock()
+	defer sw.mu.Unlock()
+	return sw.w.Write(p)
+}
+
+// newSyncWriter creates a new synchronized writer
+func newSyncWriter(w io.Writer) io.Writer {
+	return &syncWriter{w: w}
+}
+
+// filterTools filters the provided tools based on allowed or excluded tool names
+func filterTools(allTools []tools.BaseTool, allowedTools, excludedTools []string) []tools.BaseTool {
+	// If neither allowed nor excluded tools are specified, return all tools
+	if len(allowedTools) == 0 && len(excludedTools) == 0 {
+		return allTools
+	}
+
+	// Create a map for faster lookups
+	allowedMap := make(map[string]bool)
+	for _, name := range allowedTools {
+		allowedMap[name] = true
+	}
+
+	excludedMap := make(map[string]bool)
+	for _, name := range excludedTools {
+		excludedMap[name] = true
+	}
+
+	var filteredTools []tools.BaseTool
+
+	for _, tool := range allTools {
+		toolName := tool.Info().Name
+
+		// If we have an allowed list, only include tools in that list
+		if len(allowedTools) > 0 {
+			if allowedMap[toolName] {
+				filteredTools = append(filteredTools, tool)
+			}
+		} else if len(excludedTools) > 0 {
+			// If we have an excluded list, include all tools except those in the list
+			if !excludedMap[toolName] {
+				filteredTools = append(filteredTools, tool)
+			}
+		}
+	}
+
+	return filteredTools
+}
+
+// handleNonInteractiveMode processes a single prompt in non-interactive mode
+func handleNonInteractiveMode(ctx context.Context, prompt string, outputFormat format.OutputFormat, quiet bool, verbose bool, allowedTools, excludedTools []string) error {
+	// Initial log message using standard slog
+	slog.Info("Running in non-interactive mode", "prompt", prompt, "format", outputFormat, "quiet", quiet, "verbose", verbose,
+		"allowedTools", allowedTools, "excludedTools", excludedTools)
+
+	// Sanity check for mutually exclusive flags
+	if quiet && verbose {
+		return fmt.Errorf("--quiet and --verbose flags cannot be used together")
+	}
+
+	// Set up logging to stderr if verbose mode is enabled
+	if verbose {
+		// Create a synchronized writer to prevent interleaved output
+		syncWriter := newSyncWriter(os.Stderr)
+
+		// Create a charmbracelet/log logger that writes to the synchronized writer
+		charmLogger := charmlog.NewWithOptions(syncWriter, charmlog.Options{
+			Level:           charmlog.DebugLevel,
+			ReportCaller:    true,
+			ReportTimestamp: true,
+			TimeFormat:      time.RFC3339,
+			Prefix:          "OpenCode",
+		})
+
+		// Set the global logger for charmbracelet/log
+		charmlog.SetDefault(charmLogger)
+
+		// Create a slog handler that uses charmbracelet/log
+		// This will forward all slog logs to charmbracelet/log
+		slog.SetDefault(slog.New(charmLogger))
+
+		// Log a message to confirm verbose logging is enabled
+		charmLogger.Info("Verbose logging enabled")
+	}
+
+	// Start spinner if not in quiet mode
+	var s *spinner.Spinner
+	if !quiet {
+		// Get the current theme to style the spinner
+		currentTheme := theme.CurrentTheme()
+
+		// Create a themed spinner
+		if currentTheme != nil {
+			// Use the primary color from the theme
+			s = spinner.NewThemedSpinner("Thinking...", currentTheme.Primary())
+		} else {
+			// Fallback to default spinner if no theme is available
+			s = spinner.NewSpinner("Thinking...")
+		}
+
+		s.Start()
+		defer s.Stop()
+	}
+
+	// Connect DB, this will also run migrations
+	conn, err := db.Connect()
+	if err != nil {
+		return err
+	}
+
+	// Create a context with cancellation
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// Create the app
+	app, err := app.New(ctx, conn)
+	if err != nil {
+		slog.Error("Failed to create app", "error", err)
+		return err
+	}
+
+	// Create a new session for this prompt
+	session, err := app.Sessions.Create(ctx, "Non-interactive prompt")
+	if err != nil {
+		return fmt.Errorf("failed to create session: %w", err)
+	}
+
+	// Set the session as current
+	app.CurrentSession = &session
+
+	// Auto-approve all permissions for this session
+	permission.AutoApproveSession(ctx, session.ID)
+
+	// Create the user message
+	_, err = app.Messages.Create(ctx, session.ID, message.CreateMessageParams{
+		Role:  message.User,
+		Parts: []message.ContentPart{message.TextContent{Text: prompt}},
+	})
+	if err != nil {
+		return fmt.Errorf("failed to create message: %w", err)
+	}
+
+	// If tool restrictions are specified, create a new agent with filtered tools
+	if len(allowedTools) > 0 || len(excludedTools) > 0 {
+		// Initialize MCP tools synchronously to ensure they're included in filtering
+		mcpCtx, mcpCancel := context.WithTimeout(ctx, 10*time.Second)
+		agent.GetMcpTools(mcpCtx, app.Permissions)
+		mcpCancel()
+
+		// Get all available tools including MCP tools
+		allTools := agent.PrimaryAgentTools(
+			app.Permissions,
+			app.Sessions,
+			app.Messages,
+			app.History,
+			app.LSPClients,
+		)
+
+		// Filter tools based on allowed/excluded lists
+		filteredTools := filterTools(allTools, allowedTools, excludedTools)
+
+		// Log the filtered tools for debugging
+		var toolNames []string
+		for _, tool := range filteredTools {
+			toolNames = append(toolNames, tool.Info().Name)
+		}
+		slog.Debug("Using filtered tools", "count", len(filteredTools), "tools", toolNames)
+
+		// Create a new agent with the filtered tools
+		restrictedAgent, err := agent.NewAgent(
+			config.AgentPrimary,
+			app.Sessions,
+			app.Messages,
+			filteredTools,
+		)
+		if err != nil {
+			return fmt.Errorf("failed to create restricted agent: %w", err)
+		}
+
+		// Use the restricted agent for this request
+		eventCh, err := restrictedAgent.Run(ctx, session.ID, prompt)
+		if err != nil {
+			return fmt.Errorf("failed to run restricted agent: %w", err)
+		}
+
+		// Wait for the response
+		var response message.Message
+		for event := range eventCh {
+			if event.Err() != nil {
+				return fmt.Errorf("agent error: %w", event.Err())
+			}
+			response = event.Response()
+		}
+
+		// Format and print the output
+		content := ""
+		if textContent := response.Content(); textContent != nil {
+			content = textContent.Text
+		}
+
+		formattedOutput, err := format.FormatOutput(content, outputFormat)
+		if err != nil {
+			return fmt.Errorf("failed to format output: %w", err)
+		}
+
+		// Stop spinner before printing output
+		if !quiet && s != nil {
+			s.Stop()
+		}
+
+		// Print the formatted output to stdout
+		fmt.Println(formattedOutput)
+
+		// Shutdown the app
+		app.Shutdown()
+
+		return nil
+	}
+
+	// Run the default agent if no tool restrictions
+	eventCh, err := app.PrimaryAgent.Run(ctx, session.ID, prompt)
+	if err != nil {
+		return fmt.Errorf("failed to run agent: %w", err)
+	}
+
+	// Wait for the response
+	var response message.Message
+	for event := range eventCh {
+		if event.Err() != nil {
+			return fmt.Errorf("agent error: %w", event.Err())
+		}
+		response = event.Response()
+	}
+
+	// Get the text content from the response
+	content := ""
+	if textContent := response.Content(); textContent != nil {
+		content = textContent.Text
+	}
+
+	// Format the output according to the specified format
+	formattedOutput, err := format.FormatOutput(content, outputFormat)
+	if err != nil {
+		return fmt.Errorf("failed to format output: %w", err)
+	}
+
+	// Stop spinner before printing output
+	if !quiet && s != nil {
+		s.Stop()
+	}
+
+	// Print the formatted output to stdout
+	fmt.Println(formattedOutput)
+
+	// Shutdown the app
+	app.Shutdown()
+
+	return nil
+}

+ 17 - 96
cmd/root.go

@@ -19,11 +19,8 @@ import (
 	"github.com/sst/opencode/internal/llm/agent"
 	"github.com/sst/opencode/internal/logging"
 	"github.com/sst/opencode/internal/lsp/discovery"
-	"github.com/sst/opencode/internal/message"
-	"github.com/sst/opencode/internal/permission"
 	"github.com/sst/opencode/internal/pubsub"
 	"github.com/sst/opencode/internal/tui"
-	"github.com/sst/opencode/internal/tui/components/spinner"
 	"github.com/sst/opencode/internal/version"
 )
 
@@ -100,9 +97,15 @@ to assist developers in writing, debugging, and understanding code directly from
 			if !outputFormat.IsValid() {
 				return fmt.Errorf("invalid output format: %s", outputFormatStr)
 			}
-			
+
 			quiet, _ := cmd.Flags().GetBool("quiet")
-			return handleNonInteractiveMode(cmd.Context(), prompt, outputFormat, quiet)
+			verbose, _ := cmd.Flags().GetBool("verbose")
+
+			// Get tool restriction flags
+			allowedTools, _ := cmd.Flags().GetStringSlice("allowedTools")
+			excludedTools, _ := cmd.Flags().GetStringSlice("excludedTools")
+
+			return handleNonInteractiveMode(cmd.Context(), prompt, outputFormat, quiet, verbose, allowedTools, excludedTools)
 		}
 
 		// Run LSP auto-discovery
@@ -222,97 +225,6 @@ func initMCPTools(ctx context.Context, app *app.App) {
 	}()
 }
 
-// handleNonInteractiveMode processes a single prompt in non-interactive mode
-func handleNonInteractiveMode(ctx context.Context, prompt string, outputFormat format.OutputFormat, quiet bool) error {
-	slog.Info("Running in non-interactive mode", "prompt", prompt, "format", outputFormat, "quiet", quiet)
-	
-	// Start spinner if not in quiet mode
-	var s *spinner.Spinner
-	if !quiet {
-		s = spinner.NewSpinner("Thinking...")
-		s.Start()
-		defer s.Stop()
-	}
-	
-	// Connect DB, this will also run migrations
-	conn, err := db.Connect()
-	if err != nil {
-		return err
-	}
-	
-	// Create a context with cancellation
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
-	
-	// Create the app
-	app, err := app.New(ctx, conn)
-	if err != nil {
-		slog.Error("Failed to create app", "error", err)
-		return err
-	}
-	
-	// Auto-approve all permissions for non-interactive mode
-	permission.AutoApproveSession(ctx, "non-interactive")
-	
-	// Create a new session for this prompt
-	session, err := app.Sessions.Create(ctx, "Non-interactive prompt")
-	if err != nil {
-		return fmt.Errorf("failed to create session: %w", err)
-	}
-	
-	// Set the session as current
-	app.CurrentSession = &session
-	
-	// Create the user message
-	_, err = app.Messages.Create(ctx, session.ID, message.CreateMessageParams{
-		Role:  message.User,
-		Parts: []message.ContentPart{message.TextContent{Text: prompt}},
-	})
-	if err != nil {
-		return fmt.Errorf("failed to create message: %w", err)
-	}
-	
-	// Run the agent to get a response
-	eventCh, err := app.PrimaryAgent.Run(ctx, session.ID, prompt)
-	if err != nil {
-		return fmt.Errorf("failed to run agent: %w", err)
-	}
-	
-	// Wait for the response
-	var response message.Message
-	for event := range eventCh {
-		if event.Err() != nil {
-			return fmt.Errorf("agent error: %w", event.Err())
-		}
-		response = event.Response()
-	}
-	
-	// Get the text content from the response
-	content := ""
-	if textContent := response.Content(); textContent != nil {
-		content = textContent.Text
-	}
-	
-	// Format the output according to the specified format
-	formattedOutput, err := format.FormatOutput(content, outputFormat)
-	if err != nil {
-		return fmt.Errorf("failed to format output: %w", err)
-	}
-	
-	// Stop spinner before printing output
-	if !quiet && s != nil {
-		s.Stop()
-	}
-	
-	// Print the formatted output to stdout
-	fmt.Println(formattedOutput)
-	
-	// Shutdown the app
-	app.Shutdown()
-	
-	return nil
-}
-
 func setupSubscriber[T any](
 	ctx context.Context,
 	wg *sync.WaitGroup,
@@ -407,4 +319,13 @@ func init() {
 	rootCmd.Flags().StringP("prompt", "p", "", "Run a single prompt in non-interactive mode")
 	rootCmd.Flags().StringP("output-format", "f", "text", "Output format for non-interactive mode (text, json)")
 	rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
+	rootCmd.Flags().BoolP("verbose", "", false, "Display logs to stderr in non-interactive mode")
+	rootCmd.Flags().StringSlice("allowedTools", nil, "Restrict the agent to only use the specified tools in non-interactive mode (comma-separated list)")
+	rootCmd.Flags().StringSlice("excludedTools", nil, "Prevent the agent from using the specified tools in non-interactive mode (comma-separated list)")
+
+	// Make allowedTools and excludedTools mutually exclusive
+	rootCmd.MarkFlagsMutuallyExclusive("allowedTools", "excludedTools")
+
+	// Make quiet and verbose mutually exclusive
+	rootCmd.MarkFlagsMutuallyExclusive("quiet", "verbose")
 }

+ 5 - 0
go.mod

@@ -34,6 +34,11 @@ require (
 	github.com/stretchr/testify v1.10.0
 )
 
+require (
+	github.com/charmbracelet/log v0.4.2 // indirect
+	golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
+)
+
 require (
 	cloud.google.com/go v0.116.0 // indirect
 	cloud.google.com/go/auth v0.13.0 // indirect

+ 2 - 0
go.sum

@@ -80,6 +80,8 @@ github.com/charmbracelet/glamour v0.9.1 h1:11dEfiGP8q1BEqvGoIjivuc2rBk+5qEXdPtaQ
 github.com/charmbracelet/glamour v0.9.1/go.mod h1:+SHvIS8qnwhgTpVMiXwn7OfGomSqff1cHBCI8jLOetk=
 github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
 github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
+github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig=
+github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw=
 github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
 github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
 github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=

+ 25 - 0
internal/tui/components/spinner/spinner.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/charmbracelet/bubbles/spinner"
 	tea "github.com/charmbracelet/bubbletea"
+	"github.com/charmbracelet/lipgloss"
 )
 
 // Spinner wraps the bubbles spinner for both interactive and non-interactive mode
@@ -80,6 +81,30 @@ func NewSpinner(message string) *Spinner {
 	}
 }
 
+// NewThemedSpinner creates a new spinner with the given message and color
+func NewThemedSpinner(message string, color lipgloss.AdaptiveColor) *Spinner {
+	s := spinner.New()
+	s.Spinner = spinner.Dot
+	s.Style = s.Style.Foreground(color)
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	model := spinnerModel{
+		spinner: s,
+		message: message,
+	}
+
+	prog := tea.NewProgram(model, tea.WithOutput(os.Stderr), tea.WithoutCatchPanics())
+
+	return &Spinner{
+		model:  s,
+		done:   make(chan struct{}),
+		prog:   prog,
+		ctx:    ctx,
+		cancel: cancel,
+	}
+}
+
 // Start begins the spinner animation
 func (s *Spinner) Start() {
 	go func() {