models.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. package dialog
  2. import (
  3. "fmt"
  4. "slices"
  5. "strings"
  6. "github.com/charmbracelet/bubbles/key"
  7. tea "github.com/charmbracelet/bubbletea"
  8. "github.com/charmbracelet/lipgloss"
  9. "github.com/sst/opencode/internal/config"
  10. "github.com/sst/opencode/internal/llm/models"
  11. "github.com/sst/opencode/internal/tui/layout"
  12. "github.com/sst/opencode/internal/tui/styles"
  13. "github.com/sst/opencode/internal/tui/theme"
  14. "github.com/sst/opencode/internal/tui/util"
  15. )
  16. const (
  17. numVisibleModels = 10
  18. maxDialogWidth = 40
  19. )
  20. // ModelSelectedMsg is sent when a model is selected
  21. type ModelSelectedMsg struct {
  22. Model models.Model
  23. }
  24. // CloseModelDialogMsg is sent when a model is selected
  25. type CloseModelDialogMsg struct{}
  26. // ModelDialog interface for the model selection dialog
  27. type ModelDialog interface {
  28. tea.Model
  29. layout.Bindings
  30. }
  31. type modelDialogCmp struct {
  32. models []models.Model
  33. provider models.ModelProvider
  34. availableProviders []models.ModelProvider
  35. selectedIdx int
  36. width int
  37. height int
  38. scrollOffset int
  39. hScrollOffset int
  40. hScrollPossible bool
  41. }
  42. type modelKeyMap struct {
  43. Up key.Binding
  44. Down key.Binding
  45. Left key.Binding
  46. Right key.Binding
  47. Enter key.Binding
  48. Escape key.Binding
  49. J key.Binding
  50. K key.Binding
  51. H key.Binding
  52. L key.Binding
  53. }
  54. var modelKeys = modelKeyMap{
  55. Up: key.NewBinding(
  56. key.WithKeys("up"),
  57. key.WithHelp("↑", "previous model"),
  58. ),
  59. Down: key.NewBinding(
  60. key.WithKeys("down"),
  61. key.WithHelp("↓", "next model"),
  62. ),
  63. Left: key.NewBinding(
  64. key.WithKeys("left"),
  65. key.WithHelp("←", "scroll left"),
  66. ),
  67. Right: key.NewBinding(
  68. key.WithKeys("right"),
  69. key.WithHelp("→", "scroll right"),
  70. ),
  71. Enter: key.NewBinding(
  72. key.WithKeys("enter"),
  73. key.WithHelp("enter", "select model"),
  74. ),
  75. Escape: key.NewBinding(
  76. key.WithKeys("esc"),
  77. key.WithHelp("esc", "close"),
  78. ),
  79. J: key.NewBinding(
  80. key.WithKeys("j"),
  81. key.WithHelp("j", "next model"),
  82. ),
  83. K: key.NewBinding(
  84. key.WithKeys("k"),
  85. key.WithHelp("k", "previous model"),
  86. ),
  87. H: key.NewBinding(
  88. key.WithKeys("h"),
  89. key.WithHelp("h", "scroll left"),
  90. ),
  91. L: key.NewBinding(
  92. key.WithKeys("l"),
  93. key.WithHelp("l", "scroll right"),
  94. ),
  95. }
  96. func (m *modelDialogCmp) Init() tea.Cmd {
  97. m.setupModels()
  98. return nil
  99. }
  100. func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
  101. switch msg := msg.(type) {
  102. case tea.KeyMsg:
  103. switch {
  104. case key.Matches(msg, modelKeys.Up) || key.Matches(msg, modelKeys.K):
  105. m.moveSelectionUp()
  106. case key.Matches(msg, modelKeys.Down) || key.Matches(msg, modelKeys.J):
  107. m.moveSelectionDown()
  108. case key.Matches(msg, modelKeys.Left) || key.Matches(msg, modelKeys.H):
  109. if m.hScrollPossible {
  110. m.switchProvider(-1)
  111. }
  112. case key.Matches(msg, modelKeys.Right) || key.Matches(msg, modelKeys.L):
  113. if m.hScrollPossible {
  114. m.switchProvider(1)
  115. }
  116. case key.Matches(msg, modelKeys.Enter):
  117. return m, util.CmdHandler(ModelSelectedMsg{Model: m.models[m.selectedIdx]})
  118. case key.Matches(msg, modelKeys.Escape):
  119. return m, util.CmdHandler(CloseModelDialogMsg{})
  120. }
  121. case tea.WindowSizeMsg:
  122. m.width = msg.Width
  123. m.height = msg.Height
  124. }
  125. return m, nil
  126. }
  127. // moveSelectionUp moves the selection up or wraps to bottom
  128. func (m *modelDialogCmp) moveSelectionUp() {
  129. if m.selectedIdx > 0 {
  130. m.selectedIdx--
  131. } else {
  132. m.selectedIdx = len(m.models) - 1
  133. m.scrollOffset = max(0, len(m.models)-numVisibleModels)
  134. }
  135. // Keep selection visible
  136. if m.selectedIdx < m.scrollOffset {
  137. m.scrollOffset = m.selectedIdx
  138. }
  139. }
  140. // moveSelectionDown moves the selection down or wraps to top
  141. func (m *modelDialogCmp) moveSelectionDown() {
  142. if m.selectedIdx < len(m.models)-1 {
  143. m.selectedIdx++
  144. } else {
  145. m.selectedIdx = 0
  146. m.scrollOffset = 0
  147. }
  148. // Keep selection visible
  149. if m.selectedIdx >= m.scrollOffset+numVisibleModels {
  150. m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
  151. }
  152. }
  153. func (m *modelDialogCmp) switchProvider(offset int) {
  154. newOffset := m.hScrollOffset + offset
  155. // Ensure we stay within bounds
  156. if newOffset < 0 {
  157. newOffset = len(m.availableProviders) - 1
  158. }
  159. if newOffset >= len(m.availableProviders) {
  160. newOffset = 0
  161. }
  162. m.hScrollOffset = newOffset
  163. m.provider = m.availableProviders[m.hScrollOffset]
  164. m.setupModelsForProvider(m.provider)
  165. }
  166. func (m *modelDialogCmp) View() string {
  167. t := theme.CurrentTheme()
  168. baseStyle := styles.BaseStyle()
  169. // Capitalize first letter of provider name
  170. providerName := strings.ToUpper(string(m.provider)[:1]) + string(m.provider[1:])
  171. title := baseStyle.
  172. Foreground(t.Primary()).
  173. Bold(true).
  174. Width(maxDialogWidth).
  175. Padding(0, 0, 1).
  176. Render(fmt.Sprintf("Select %s Model", providerName))
  177. // Render visible models
  178. endIdx := min(m.scrollOffset+numVisibleModels, len(m.models))
  179. modelItems := make([]string, 0, endIdx-m.scrollOffset)
  180. for i := m.scrollOffset; i < endIdx; i++ {
  181. itemStyle := baseStyle.Width(maxDialogWidth)
  182. if i == m.selectedIdx {
  183. itemStyle = itemStyle.Background(t.Primary()).
  184. Foreground(t.Background()).Bold(true)
  185. }
  186. modelItems = append(modelItems, itemStyle.Render(m.models[i].Name))
  187. }
  188. scrollIndicator := m.getScrollIndicators(maxDialogWidth)
  189. content := lipgloss.JoinVertical(
  190. lipgloss.Left,
  191. title,
  192. baseStyle.Width(maxDialogWidth).Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)),
  193. scrollIndicator,
  194. )
  195. return baseStyle.Padding(1, 2).
  196. Border(lipgloss.RoundedBorder()).
  197. BorderBackground(t.Background()).
  198. BorderForeground(t.TextMuted()).
  199. Width(lipgloss.Width(content) + 4).
  200. Render(content)
  201. }
  202. func (m *modelDialogCmp) getScrollIndicators(maxWidth int) string {
  203. var indicator string
  204. if len(m.models) > numVisibleModels {
  205. if m.scrollOffset > 0 {
  206. indicator += "↑ "
  207. }
  208. if m.scrollOffset+numVisibleModels < len(m.models) {
  209. indicator += "↓ "
  210. }
  211. }
  212. if m.hScrollPossible {
  213. if m.hScrollOffset > 0 {
  214. indicator = "← " + indicator
  215. }
  216. if m.hScrollOffset < len(m.availableProviders)-1 {
  217. indicator += "→"
  218. }
  219. }
  220. if indicator == "" {
  221. return ""
  222. }
  223. t := theme.CurrentTheme()
  224. baseStyle := styles.BaseStyle()
  225. return baseStyle.
  226. Foreground(t.Primary()).
  227. Width(maxWidth).
  228. Align(lipgloss.Right).
  229. Bold(true).
  230. Render(indicator)
  231. }
  232. func (m *modelDialogCmp) BindingKeys() []key.Binding {
  233. return layout.KeyMapToSlice(modelKeys)
  234. }
  235. func (m *modelDialogCmp) setupModels() {
  236. cfg := config.Get()
  237. modelInfo := GetSelectedModel(cfg)
  238. m.availableProviders = getEnabledProviders(cfg)
  239. m.hScrollPossible = len(m.availableProviders) > 1
  240. m.provider = modelInfo.Provider
  241. m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider)
  242. m.setupModelsForProvider(m.provider)
  243. }
  244. func GetSelectedModel(cfg *config.Config) models.Model {
  245. agentCfg := cfg.Agents[config.AgentPrimary]
  246. selectedModelId := agentCfg.Model
  247. return models.SupportedModels[selectedModelId]
  248. }
  249. func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
  250. var providers []models.ModelProvider
  251. for providerId, provider := range cfg.Providers {
  252. if !provider.Disabled {
  253. providers = append(providers, providerId)
  254. }
  255. }
  256. // Sort by provider popularity
  257. slices.SortFunc(providers, func(a, b models.ModelProvider) int {
  258. rA := models.ProviderPopularity[a]
  259. rB := models.ProviderPopularity[b]
  260. // models not included in popularity ranking default to last
  261. if rA == 0 {
  262. rA = 999
  263. }
  264. if rB == 0 {
  265. rB = 999
  266. }
  267. return rA - rB
  268. })
  269. return providers
  270. }
  271. // findProviderIndex returns the index of the provider in the list, or -1 if not found
  272. func findProviderIndex(providers []models.ModelProvider, provider models.ModelProvider) int {
  273. for i, p := range providers {
  274. if p == provider {
  275. return i
  276. }
  277. }
  278. return -1
  279. }
  280. func (m *modelDialogCmp) setupModelsForProvider(provider models.ModelProvider) {
  281. cfg := config.Get()
  282. agentCfg := cfg.Agents[config.AgentPrimary]
  283. selectedModelId := agentCfg.Model
  284. m.provider = provider
  285. m.models = getModelsForProvider(provider)
  286. m.selectedIdx = 0
  287. m.scrollOffset = 0
  288. // Try to select the current model if it belongs to this provider
  289. if provider == models.SupportedModels[selectedModelId].Provider {
  290. for i, model := range m.models {
  291. if model.ID == selectedModelId {
  292. m.selectedIdx = i
  293. // Adjust scroll position to keep selected model visible
  294. if m.selectedIdx >= numVisibleModels {
  295. m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
  296. }
  297. break
  298. }
  299. }
  300. }
  301. }
  302. func getModelsForProvider(provider models.ModelProvider) []models.Model {
  303. var providerModels []models.Model
  304. for _, model := range models.SupportedModels {
  305. if model.Provider == provider {
  306. providerModels = append(providerModels, model)
  307. }
  308. }
  309. // reverse alphabetical order (if llm naming was consistent latest would appear first)
  310. slices.SortFunc(providerModels, func(a, b models.Model) int {
  311. if a.Name > b.Name {
  312. return -1
  313. } else if a.Name < b.Name {
  314. return 1
  315. }
  316. return 0
  317. })
  318. return providerModels
  319. }
  320. func NewModelDialogCmp() ModelDialog {
  321. return &modelDialogCmp{}
  322. }