Просмотр исходного кода

feat(tui): handle --model and --prompt flags

adamdottv 7 месяцев назад
Родитель
Сommit
ea96ead346

+ 6 - 1
packages/tui/cmd/opencode/main.go

@@ -9,6 +9,7 @@ import (
 	"strings"
 
 	tea "github.com/charmbracelet/bubbletea/v2"
+	flag "github.com/spf13/pflag"
 	"github.com/sst/opencode-sdk-go"
 	"github.com/sst/opencode-sdk-go/option"
 	"github.com/sst/opencode/internal/app"
@@ -23,6 +24,10 @@ func main() {
 		version = "v" + Version
 	}
 
+	var model *string = flag.String("model", "", "model to begin with")
+	var prompt *string = flag.String("prompt", "", "prompt to begin with")
+	flag.Parse()
+
 	url := os.Getenv("OPENCODE_SERVER")
 
 	appInfoStr := os.Getenv("OPENCODE_APP_INFO")
@@ -65,7 +70,7 @@ func main() {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	app_, err := app.New(ctx, version, appInfo, httpClient)
+	app_, err := app.New(ctx, version, appInfo, httpClient, model, prompt)
 	if err != nil {
 		panic(err)
 	}

+ 1 - 1
packages/tui/go.mod

@@ -17,7 +17,6 @@ require (
 	github.com/muesli/termenv v0.16.0
 	github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3
 	github.com/sst/opencode-sdk-go v0.1.0-alpha.8
-	github.com/tidwall/gjson v1.14.4
 	rsc.io/qr v0.2.0
 )
 
@@ -50,6 +49,7 @@ require (
 	github.com/sosodev/duration v1.3.1 // indirect
 	github.com/speakeasy-api/openapi-overlay v0.9.0 // indirect
 	github.com/spf13/cobra v1.9.1 // indirect
+	github.com/tidwall/gjson v1.14.4 // indirect
 	github.com/tidwall/match v1.1.1 // indirect
 	github.com/tidwall/pretty v1.2.1 // indirect
 	github.com/tidwall/sjson v1.2.5 // indirect

+ 93 - 63
packages/tui/internal/app/app.go

@@ -21,17 +21,19 @@ import (
 )
 
 type App struct {
-	Info      opencode.App
-	Version   string
-	StatePath string
-	Config    *opencode.Config
-	Client    *opencode.Client
-	State     *config.State
-	Provider  *opencode.Provider
-	Model     *opencode.Model
-	Session   *opencode.Session
-	Messages  []opencode.MessageUnion
-	Commands  commands.CommandRegistry
+	Info          opencode.App
+	Version       string
+	StatePath     string
+	Config        *opencode.Config
+	Client        *opencode.Client
+	State         *config.State
+	Provider      *opencode.Provider
+	Model         *opencode.Model
+	Session       *opencode.Session
+	Messages      []opencode.MessageUnion
+	Commands      commands.CommandRegistry
+	InitialModel  *string
+	InitialPrompt *string
 }
 
 type SessionSelectedMsg = *opencode.Session
@@ -58,6 +60,8 @@ func New(
 	version string,
 	appInfo opencode.App,
 	httpClient *opencode.Client,
+	model *string,
+	prompt *string,
 ) (*App, error) {
 	util.RootPath = appInfo.Path.Root
 	util.CwdPath = appInfo.Path.Cwd
@@ -109,15 +113,17 @@ func New(
 	slog.Debug("Loaded config", "config", configInfo)
 
 	app := &App{
-		Info:      appInfo,
-		Version:   version,
-		StatePath: appStatePath,
-		Config:    configInfo,
-		State:     appState,
-		Client:    httpClient,
-		Session:   &opencode.Session{},
-		Messages:  []opencode.MessageUnion{},
-		Commands:  commands.LoadFromConfig(configInfo),
+		Info:          appInfo,
+		Version:       version,
+		StatePath:     appStatePath,
+		Config:        configInfo,
+		State:         appState,
+		Client:        httpClient,
+		Session:       &opencode.Session{},
+		Messages:      []opencode.MessageUnion{},
+		Commands:      commands.LoadFromConfig(configInfo),
+		InitialModel:  model,
+		InitialPrompt: prompt,
 	}
 
 	return app, nil
@@ -141,65 +147,89 @@ func (a *App) Key(commandName commands.CommandName) string {
 }
 
 func (a *App) InitializeProvider() tea.Cmd {
-	return func() tea.Msg {
-		providersResponse, err := a.Client.Config.Providers(context.Background())
-		if err != nil {
-			slog.Error("Failed to list providers", "error", err)
-			// TODO: notify user
-			return nil
+	providersResponse, err := a.Client.Config.Providers(context.Background())
+	if err != nil {
+		slog.Error("Failed to list providers", "error", err)
+		// TODO: notify user
+		return nil
+	}
+	providers := providersResponse.Providers
+	var defaultProvider *opencode.Provider
+	var defaultModel *opencode.Model
+
+	var anthropic *opencode.Provider
+	for _, provider := range providers {
+		if provider.ID == "anthropic" {
+			anthropic = &provider
 		}
-		providers := providersResponse.Providers
-		var defaultProvider *opencode.Provider
-		var defaultModel *opencode.Model
+	}
 
-		var anthropic *opencode.Provider
-		for _, provider := range providers {
-			if provider.ID == "anthropic" {
-				anthropic = &provider
-			}
-		}
+	// default to anthropic if available
+	if anthropic != nil {
+		defaultProvider = anthropic
+		defaultModel = getDefaultModel(providersResponse, *anthropic)
+	}
 
-		// default to anthropic if available
-		if anthropic != nil {
-			defaultProvider = anthropic
-			defaultModel = getDefaultModel(providersResponse, *anthropic)
+	for _, provider := range providers {
+		if defaultProvider == nil || defaultModel == nil {
+			defaultProvider = &provider
+			defaultModel = getDefaultModel(providersResponse, provider)
 		}
+		providers = append(providers, provider)
+	}
+	if len(providers) == 0 {
+		slog.Error("No providers configured")
+		return nil
+	}
 
-		for _, provider := range providers {
-			if defaultProvider == nil || defaultModel == nil {
-				defaultProvider = &provider
-				defaultModel = getDefaultModel(providersResponse, provider)
+	var currentProvider *opencode.Provider
+	var currentModel *opencode.Model
+	for _, provider := range providers {
+		if provider.ID == a.State.Provider {
+			currentProvider = &provider
+
+			for _, model := range provider.Models {
+				if model.ID == a.State.Model {
+					currentModel = &model
+				}
 			}
-			providers = append(providers, provider)
-		}
-		if len(providers) == 0 {
-			slog.Error("No providers configured")
-			return nil
 		}
+	}
+	if currentProvider == nil || currentModel == nil {
+		currentProvider = defaultProvider
+		currentModel = defaultModel
+	}
 
-		var currentProvider *opencode.Provider
-		var currentModel *opencode.Model
+	var initialProvider *opencode.Provider
+	var initialModel *opencode.Model
+	if a.InitialModel != nil && *a.InitialModel != "" {
+		splits := strings.Split(*a.InitialModel, "/")
 		for _, provider := range providers {
-			if provider.ID == a.State.Provider {
-				currentProvider = &provider
-
+			if provider.ID == splits[0] {
+				initialProvider = &provider
 				for _, model := range provider.Models {
-					if model.ID == a.State.Model {
-						currentModel = &model
+					if model.ID == splits[1] {
+						initialModel = &model
 					}
 				}
 			}
 		}
-		if currentProvider == nil || currentModel == nil {
-			currentProvider = defaultProvider
-			currentModel = defaultModel
-		}
+	}
 
-		return ModelSelectedMsg{
-			Provider: *currentProvider,
-			Model:    *currentModel,
-		}
+	if initialProvider != nil && initialModel != nil {
+		currentProvider = initialProvider
+		currentModel = initialModel
+	}
+
+	var cmds []tea.Cmd
+	cmds = append(cmds, util.CmdHandler(ModelSelectedMsg{
+		Provider: *currentProvider,
+		Model:    *currentModel,
+	}))
+	if a.InitialPrompt != nil && *a.InitialPrompt != "" {
+		cmds = append(cmds, util.CmdHandler(SendMsg{Text: *a.InitialPrompt}))
 	}
+	return tea.Sequence(cmds...)
 }
 
 func getDefaultModel(

+ 12 - 16
packages/tui/internal/components/chat/editor.go

@@ -64,7 +64,7 @@ func (m *editorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			return m, tea.Batch(cmds...)
 		}
 	case dialog.ThemeSelectedMsg:
-		m.textarea = createTextArea(&m.textarea)
+		m.textarea = m.resetTextareaStyles()
 		m.spinner = createSpinner()
 		return m, tea.Batch(m.spinner.Tick, m.textarea.Focus())
 	case dialog.CompletionSelectedMsg:
@@ -306,13 +306,13 @@ func (m *editorComponent) getSubmitKeyText() string {
 	return m.app.Commands[commands.InputSubmitCommand].Keys()[0]
 }
 
-func createTextArea(existing *textarea.Model) textarea.Model {
+func (m *editorComponent) resetTextareaStyles() textarea.Model {
 	t := theme.CurrentTheme()
 	bgColor := t.BackgroundElement()
 	textColor := t.Text()
 	textMutedColor := t.TextMuted()
 
-	ta := textarea.New()
+	ta := m.textarea
 
 	ta.Styles.Blurred.Base = styles.NewStyle().Foreground(textColor).Background(bgColor).Lipgloss()
 	ta.Styles.Blurred.CursorLine = styles.NewStyle().Background(bgColor).Lipgloss()
@@ -337,17 +337,6 @@ func createTextArea(existing *textarea.Model) textarea.Model {
 		Background(t.Secondary()).
 		Lipgloss()
 	ta.Styles.Cursor.Color = t.Primary()
-
-	ta.Prompt = " "
-	ta.ShowLineNumbers = false
-	ta.CharLimit = -1
-
-	if existing != nil {
-		ta.SetValue(existing.Value())
-		// ta.SetWidth(existing.Width())
-		ta.SetHeight(existing.Height())
-	}
-
 	return ta
 }
 
@@ -367,12 +356,19 @@ func createSpinner() spinner.Model {
 
 func NewEditorComponent(app *app.App) EditorComponent {
 	s := createSpinner()
-	ta := createTextArea(nil)
 
-	return &editorComponent{
+	ta := textarea.New()
+	ta.Prompt = " "
+	ta.ShowLineNumbers = false
+	ta.CharLimit = -1
+
+	m := &editorComponent{
 		app:                    app,
 		textarea:               ta,
 		spinner:                s,
 		interruptKeyInDebounce: false,
 	}
+	m.resetTextareaStyles()
+
+	return m
 }