models.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. package dialog
  2. import (
  3. "context"
  4. "fmt"
  5. "maps"
  6. "slices"
  7. "strings"
  8. "github.com/charmbracelet/bubbles/v2/key"
  9. tea "github.com/charmbracelet/bubbletea/v2"
  10. "github.com/charmbracelet/lipgloss/v2"
  11. "github.com/sst/opencode/internal/app"
  12. "github.com/sst/opencode/internal/components/modal"
  13. "github.com/sst/opencode/internal/layout"
  14. "github.com/sst/opencode/internal/styles"
  15. "github.com/sst/opencode/internal/theme"
  16. "github.com/sst/opencode/internal/util"
  17. "github.com/sst/opencode/pkg/client"
  18. )
  // Layout constants for the model picker.
const (
	// numVisibleModels is how many model rows are rendered at once; the
	// remaining rows are reached by moving the selection past the window.
	numVisibleModels = 6
	// maxDialogWidth is the width passed to lipgloss for every model row and
	// for the scroll-indicator line.
	maxDialogWidth = 40
)
  23. // ModelDialog interface for the model selection dialog
  24. type ModelDialog interface {
  25. layout.Modal
  26. }
  27. type modelDialog struct {
  28. app *app.App
  29. availableProviders []client.ProviderInfo
  30. provider client.ProviderInfo
  31. selectedIdx int
  32. width int
  33. height int
  34. scrollOffset int
  35. hScrollOffset int
  36. hScrollPossible bool
  37. modal *modal.Modal
  38. }
  39. type modelKeyMap struct {
  40. Up key.Binding
  41. Down key.Binding
  42. Left key.Binding
  43. Right key.Binding
  44. Enter key.Binding
  45. Escape key.Binding
  46. }
  47. var modelKeys = modelKeyMap{
  48. Up: key.NewBinding(
  49. key.WithKeys("up", "k"),
  50. key.WithHelp("↑", "previous model"),
  51. ),
  52. Down: key.NewBinding(
  53. key.WithKeys("down", "j"),
  54. key.WithHelp("↓", "next model"),
  55. ),
  56. Left: key.NewBinding(
  57. key.WithKeys("left", "h"),
  58. key.WithHelp("←", "scroll left"),
  59. ),
  60. Right: key.NewBinding(
  61. key.WithKeys("right", "l"),
  62. key.WithHelp("→", "scroll right"),
  63. ),
  64. Enter: key.NewBinding(
  65. key.WithKeys("enter"),
  66. key.WithHelp("enter", "select model"),
  67. ),
  68. Escape: key.NewBinding(
  69. key.WithKeys("esc"),
  70. key.WithHelp("esc", "close"),
  71. ),
  72. }
  73. func (m *modelDialog) Init() tea.Cmd {
  74. // cfg := config.Get()
  75. // modelInfo := GetSelectedModel(cfg)
  76. // m.availableProviders = getEnabledProviders(cfg)
  77. // m.hScrollPossible = len(m.availableProviders) > 1
  78. // m.provider = modelInfo.Provider
  79. // m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider)
  80. // m.setupModelsForProvider(m.provider)
  81. return nil
  82. }
  83. func (m *modelDialog) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
  84. switch msg := msg.(type) {
  85. case tea.KeyMsg:
  86. switch {
  87. case key.Matches(msg, modelKeys.Up):
  88. m.moveSelectionUp()
  89. case key.Matches(msg, modelKeys.Down):
  90. m.moveSelectionDown()
  91. case key.Matches(msg, modelKeys.Left):
  92. if m.hScrollPossible {
  93. m.switchProvider(-1)
  94. }
  95. case key.Matches(msg, modelKeys.Right):
  96. if m.hScrollPossible {
  97. m.switchProvider(1)
  98. }
  99. case key.Matches(msg, modelKeys.Enter):
  100. models := m.models()
  101. return m, tea.Sequence(
  102. util.CmdHandler(modal.CloseModalMsg{}),
  103. util.CmdHandler(
  104. app.ModelSelectedMsg{
  105. Provider: m.provider,
  106. Model: models[m.selectedIdx],
  107. }),
  108. )
  109. case key.Matches(msg, modelKeys.Escape):
  110. return m, util.CmdHandler(modal.CloseModalMsg{})
  111. }
  112. case tea.WindowSizeMsg:
  113. m.width = msg.Width
  114. m.height = msg.Height
  115. }
  116. return m, nil
  117. }
  118. func (m *modelDialog) models() []client.ModelInfo {
  119. models := slices.SortedFunc(maps.Values(m.provider.Models), func(a, b client.ModelInfo) int {
  120. return strings.Compare(a.Name, b.Name)
  121. })
  122. return models
  123. }
  124. // moveSelectionUp moves the selection up or wraps to bottom
  125. func (m *modelDialog) moveSelectionUp() {
  126. if m.selectedIdx > 0 {
  127. m.selectedIdx--
  128. } else {
  129. m.selectedIdx = len(m.provider.Models) - 1
  130. m.scrollOffset = max(0, len(m.provider.Models)-numVisibleModels)
  131. }
  132. // Keep selection visible
  133. if m.selectedIdx < m.scrollOffset {
  134. m.scrollOffset = m.selectedIdx
  135. }
  136. }
  137. // moveSelectionDown moves the selection down or wraps to top
  138. func (m *modelDialog) moveSelectionDown() {
  139. if m.selectedIdx < len(m.provider.Models)-1 {
  140. m.selectedIdx++
  141. } else {
  142. m.selectedIdx = 0
  143. m.scrollOffset = 0
  144. }
  145. // Keep selection visible
  146. if m.selectedIdx >= m.scrollOffset+numVisibleModels {
  147. m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
  148. }
  149. }
  150. func (m *modelDialog) switchProvider(offset int) {
  151. newOffset := m.hScrollOffset + offset
  152. // Ensure we stay within bounds
  153. if newOffset < 0 {
  154. newOffset = len(m.availableProviders) - 1
  155. }
  156. if newOffset >= len(m.availableProviders) {
  157. newOffset = 0
  158. }
  159. m.hScrollOffset = newOffset
  160. m.provider = m.availableProviders[m.hScrollOffset]
  161. m.modal.SetTitle(fmt.Sprintf("Select %s Model", m.provider.Name))
  162. m.setupModelsForProvider(m.provider.Id)
  163. }
  164. func (m *modelDialog) View() string {
  165. t := theme.CurrentTheme()
  166. baseStyle := lipgloss.NewStyle().
  167. Background(t.BackgroundElement()).
  168. Foreground(t.Text())
  169. // Render visible models
  170. endIdx := min(m.scrollOffset+numVisibleModels, len(m.provider.Models))
  171. modelItems := make([]string, 0, endIdx-m.scrollOffset)
  172. models := m.models()
  173. for i := m.scrollOffset; i < endIdx; i++ {
  174. itemStyle := baseStyle.Width(maxDialogWidth)
  175. if i == m.selectedIdx {
  176. itemStyle = itemStyle.
  177. Background(t.Primary()).
  178. Foreground(t.BackgroundElement()).
  179. Bold(true)
  180. }
  181. modelItems = append(modelItems, itemStyle.Render(models[i].Name))
  182. }
  183. scrollIndicator := m.getScrollIndicators(maxDialogWidth)
  184. content := lipgloss.JoinVertical(
  185. lipgloss.Left,
  186. baseStyle.
  187. Width(maxDialogWidth).
  188. Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)),
  189. scrollIndicator,
  190. )
  191. return content
  192. }
  193. func (m *modelDialog) getScrollIndicators(maxWidth int) string {
  194. var indicator string
  195. if len(m.provider.Models) > numVisibleModels {
  196. if m.scrollOffset > 0 {
  197. indicator += "↑ "
  198. }
  199. if m.scrollOffset+numVisibleModels < len(m.provider.Models) {
  200. indicator += "↓ "
  201. }
  202. }
  203. if m.hScrollPossible {
  204. indicator = "← " + indicator + "→"
  205. }
  206. if indicator == "" {
  207. return ""
  208. }
  209. t := theme.CurrentTheme()
  210. baseStyle := styles.BaseStyle()
  211. return baseStyle.
  212. Foreground(t.Primary()).
  213. Width(maxWidth).
  214. Align(lipgloss.Right).
  215. Bold(true).
  216. Render(indicator)
  217. }
  218. // findProviderIndex returns the index of the provider in the list, or -1 if not found
  219. // func findProviderIndex(providers []string, provider string) int {
  220. // for i, p := range providers {
  221. // if p == provider {
  222. // return i
  223. // }
  224. // }
  225. // return -1
  226. // }
  227. func (m *modelDialog) setupModelsForProvider(_ string) {
  228. m.selectedIdx = 0
  229. m.scrollOffset = 0
  230. // cfg := config.Get()
  231. // agentCfg := cfg.Agents[config.AgentPrimary]
  232. // selectedModelId := agentCfg.Model
  233. // m.provider = provider
  234. // m.models = getModelsForProvider(provider)
  235. // Try to select the current model if it belongs to this provider
  236. // if provider == models.SupportedModels[selectedModelId].Provider {
  237. // for i, model := range m.models {
  238. // if model.ID == selectedModelId {
  239. // m.selectedIdx = i
  240. // // Adjust scroll position to keep selected model visible
  241. // if m.selectedIdx >= numVisibleModels {
  242. // m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
  243. // }
  244. // break
  245. // }
  246. // }
  247. // }
  248. }
  249. func (m *modelDialog) Render(background string) string {
  250. return m.modal.Render(m.View(), background)
  251. }
  252. func (s *modelDialog) Close() tea.Cmd {
  253. return nil
  254. }
  255. func NewModelDialog(app *app.App) ModelDialog {
  256. availableProviders, _ := app.ListProviders(context.Background())
  257. return &modelDialog{
  258. availableProviders: availableProviders,
  259. hScrollOffset: 0,
  260. hScrollPossible: len(availableProviders) > 1,
  261. provider: availableProviders[0],
  262. modal: modal.New(modal.WithTitle(fmt.Sprintf("Select %s Model", availableProviders[0].Name))),
  263. }
  264. }