| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307 |
- package dialog
- import (
- "context"
- "fmt"
- "maps"
- "slices"
- "strings"
- "github.com/charmbracelet/bubbles/v2/key"
- tea "github.com/charmbracelet/bubbletea/v2"
- "github.com/charmbracelet/lipgloss/v2"
- "github.com/sst/opencode/internal/app"
- "github.com/sst/opencode/internal/components/modal"
- "github.com/sst/opencode/internal/layout"
- "github.com/sst/opencode/internal/styles"
- "github.com/sst/opencode/internal/theme"
- "github.com/sst/opencode/internal/util"
- "github.com/sst/opencode/pkg/client"
- )
// Layout limits for the model picker.
const (
	// numVisibleModels is how many model rows are rendered at once; longer
	// lists scroll vertically.
	numVisibleModels = 6

	// maxDialogWidth is the width applied to every model row and to the
	// scroll-indicator line.
	maxDialogWidth = 40
)
// ModelDialog is the model selection dialog. It adds no methods of its own
// beyond layout.Modal; NewModelDialog returns the concrete implementation.
type ModelDialog interface {
	layout.Modal
}
- type modelDialog struct {
- app *app.App
- availableProviders []client.ProviderInfo
- provider client.ProviderInfo
- selectedIdx int
- width int
- height int
- scrollOffset int
- hScrollOffset int
- hScrollPossible bool
- modal *modal.Modal
- }
- type modelKeyMap struct {
- Up key.Binding
- Down key.Binding
- Left key.Binding
- Right key.Binding
- Enter key.Binding
- Escape key.Binding
- }
- var modelKeys = modelKeyMap{
- Up: key.NewBinding(
- key.WithKeys("up", "k"),
- key.WithHelp("↑", "previous model"),
- ),
- Down: key.NewBinding(
- key.WithKeys("down", "j"),
- key.WithHelp("↓", "next model"),
- ),
- Left: key.NewBinding(
- key.WithKeys("left", "h"),
- key.WithHelp("←", "scroll left"),
- ),
- Right: key.NewBinding(
- key.WithKeys("right", "l"),
- key.WithHelp("→", "scroll right"),
- ),
- Enter: key.NewBinding(
- key.WithKeys("enter"),
- key.WithHelp("enter", "select model"),
- ),
- Escape: key.NewBinding(
- key.WithKeys("esc"),
- key.WithHelp("esc", "close"),
- ),
- }
- func (m *modelDialog) Init() tea.Cmd {
- // cfg := config.Get()
- // modelInfo := GetSelectedModel(cfg)
- // m.availableProviders = getEnabledProviders(cfg)
- // m.hScrollPossible = len(m.availableProviders) > 1
- // m.provider = modelInfo.Provider
- // m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider)
- // m.setupModelsForProvider(m.provider)
- return nil
- }
- func (m *modelDialog) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
- switch msg := msg.(type) {
- case tea.KeyMsg:
- switch {
- case key.Matches(msg, modelKeys.Up):
- m.moveSelectionUp()
- case key.Matches(msg, modelKeys.Down):
- m.moveSelectionDown()
- case key.Matches(msg, modelKeys.Left):
- if m.hScrollPossible {
- m.switchProvider(-1)
- }
- case key.Matches(msg, modelKeys.Right):
- if m.hScrollPossible {
- m.switchProvider(1)
- }
- case key.Matches(msg, modelKeys.Enter):
- models := m.models()
- return m, tea.Sequence(
- util.CmdHandler(modal.CloseModalMsg{}),
- util.CmdHandler(
- app.ModelSelectedMsg{
- Provider: m.provider,
- Model: models[m.selectedIdx],
- }),
- )
- case key.Matches(msg, modelKeys.Escape):
- return m, util.CmdHandler(modal.CloseModalMsg{})
- }
- case tea.WindowSizeMsg:
- m.width = msg.Width
- m.height = msg.Height
- }
- return m, nil
- }
- func (m *modelDialog) models() []client.ModelInfo {
- models := slices.SortedFunc(maps.Values(m.provider.Models), func(a, b client.ModelInfo) int {
- return strings.Compare(a.Name, b.Name)
- })
- return models
- }
- // moveSelectionUp moves the selection up or wraps to bottom
- func (m *modelDialog) moveSelectionUp() {
- if m.selectedIdx > 0 {
- m.selectedIdx--
- } else {
- m.selectedIdx = len(m.provider.Models) - 1
- m.scrollOffset = max(0, len(m.provider.Models)-numVisibleModels)
- }
- // Keep selection visible
- if m.selectedIdx < m.scrollOffset {
- m.scrollOffset = m.selectedIdx
- }
- }
- // moveSelectionDown moves the selection down or wraps to top
- func (m *modelDialog) moveSelectionDown() {
- if m.selectedIdx < len(m.provider.Models)-1 {
- m.selectedIdx++
- } else {
- m.selectedIdx = 0
- m.scrollOffset = 0
- }
- // Keep selection visible
- if m.selectedIdx >= m.scrollOffset+numVisibleModels {
- m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
- }
- }
- func (m *modelDialog) switchProvider(offset int) {
- newOffset := m.hScrollOffset + offset
- // Ensure we stay within bounds
- if newOffset < 0 {
- newOffset = len(m.availableProviders) - 1
- }
- if newOffset >= len(m.availableProviders) {
- newOffset = 0
- }
- m.hScrollOffset = newOffset
- m.provider = m.availableProviders[m.hScrollOffset]
- m.modal.SetTitle(fmt.Sprintf("Select %s Model", m.provider.Name))
- m.setupModelsForProvider(m.provider.Id)
- }
- func (m *modelDialog) View() string {
- t := theme.CurrentTheme()
- baseStyle := lipgloss.NewStyle().
- Background(t.BackgroundElement()).
- Foreground(t.Text())
- // Render visible models
- endIdx := min(m.scrollOffset+numVisibleModels, len(m.provider.Models))
- modelItems := make([]string, 0, endIdx-m.scrollOffset)
- models := m.models()
- for i := m.scrollOffset; i < endIdx; i++ {
- itemStyle := baseStyle.Width(maxDialogWidth)
- if i == m.selectedIdx {
- itemStyle = itemStyle.
- Background(t.Primary()).
- Foreground(t.BackgroundElement()).
- Bold(true)
- }
- modelItems = append(modelItems, itemStyle.Render(models[i].Name))
- }
- scrollIndicator := m.getScrollIndicators(maxDialogWidth)
- content := lipgloss.JoinVertical(
- lipgloss.Left,
- baseStyle.
- Width(maxDialogWidth).
- Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)),
- scrollIndicator,
- )
- return content
- }
- func (m *modelDialog) getScrollIndicators(maxWidth int) string {
- var indicator string
- if len(m.provider.Models) > numVisibleModels {
- if m.scrollOffset > 0 {
- indicator += "↑ "
- }
- if m.scrollOffset+numVisibleModels < len(m.provider.Models) {
- indicator += "↓ "
- }
- }
- if m.hScrollPossible {
- indicator = "← " + indicator + "→"
- }
- if indicator == "" {
- return ""
- }
- t := theme.CurrentTheme()
- baseStyle := styles.BaseStyle()
- return baseStyle.
- Foreground(t.Primary()).
- Width(maxWidth).
- Align(lipgloss.Right).
- Bold(true).
- Render(indicator)
- }
- // findProviderIndex returns the index of the provider in the list, or -1 if not found
- // func findProviderIndex(providers []string, provider string) int {
- // for i, p := range providers {
- // if p == provider {
- // return i
- // }
- // }
- // return -1
- // }
- func (m *modelDialog) setupModelsForProvider(_ string) {
- m.selectedIdx = 0
- m.scrollOffset = 0
- // cfg := config.Get()
- // agentCfg := cfg.Agents[config.AgentPrimary]
- // selectedModelId := agentCfg.Model
- // m.provider = provider
- // m.models = getModelsForProvider(provider)
- // Try to select the current model if it belongs to this provider
- // if provider == models.SupportedModels[selectedModelId].Provider {
- // for i, model := range m.models {
- // if model.ID == selectedModelId {
- // m.selectedIdx = i
- // // Adjust scroll position to keep selected model visible
- // if m.selectedIdx >= numVisibleModels {
- // m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
- // }
- // break
- // }
- // }
- // }
- }
- func (m *modelDialog) Render(background string) string {
- return m.modal.Render(m.View(), background)
- }
- func (s *modelDialog) Close() tea.Cmd {
- return nil
- }
- func NewModelDialog(app *app.App) ModelDialog {
- availableProviders, _ := app.ListProviders(context.Background())
- return &modelDialog{
- availableProviders: availableProviders,
- hScrollOffset: 0,
- hScrollPossible: len(availableProviders) > 1,
- provider: availableProviders[0],
- modal: modal.New(modal.WithTitle(fmt.Sprintf("Select %s Model", availableProviders[0].Name))),
- }
- }
|