models.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. package dialog
  2. import (
  3. "fmt"
  4. "slices"
  5. "strings"
  6. "github.com/charmbracelet/bubbles/key"
  7. tea "github.com/charmbracelet/bubbletea"
  8. "github.com/charmbracelet/lipgloss"
  9. "github.com/sst/opencode/internal/config"
  10. "github.com/sst/opencode/internal/llm/models"
  11. "github.com/sst/opencode/internal/status"
  12. "github.com/sst/opencode/internal/tui/layout"
  13. "github.com/sst/opencode/internal/tui/styles"
  14. "github.com/sst/opencode/internal/tui/theme"
  15. "github.com/sst/opencode/internal/tui/util"
  16. )
  17. const (
  18. numVisibleModels = 10
  19. maxDialogWidth = 40
  20. )
  21. // ModelSelectedMsg is sent when a model is selected
  22. type ModelSelectedMsg struct {
  23. Model models.Model
  24. }
  25. // CloseModelDialogMsg is sent when a model is selected
  26. type CloseModelDialogMsg struct{}
  27. // ModelDialog interface for the model selection dialog
  28. type ModelDialog interface {
  29. tea.Model
  30. layout.Bindings
  31. }
  32. type modelDialogCmp struct {
  33. models []models.Model
  34. provider models.ModelProvider
  35. availableProviders []models.ModelProvider
  36. selectedIdx int
  37. width int
  38. height int
  39. scrollOffset int
  40. hScrollOffset int
  41. hScrollPossible bool
  42. }
// modelKeyMap holds the key bindings used by the model dialog. Each arrow key
// and its vim-style letter (h/j/k/l) are separate bindings that Update matches
// together.
// NOTE(review): layout.KeyMapToSlice presumably walks these fields in
// declaration order to build BindingKeys, so reordering them may change help
// output — confirm before rearranging.
type modelKeyMap struct {
	Up     key.Binding // previous model
	Down   key.Binding // next model
	Left   key.Binding // previous provider
	Right  key.Binding // next provider
	Enter  key.Binding // select highlighted model
	Escape key.Binding // close without selecting
	J      key.Binding // vim-style Down
	K      key.Binding // vim-style Up
	H      key.Binding // vim-style Left
	L      key.Binding // vim-style Right
}
  55. var modelKeys = modelKeyMap{
  56. Up: key.NewBinding(
  57. key.WithKeys("up"),
  58. key.WithHelp("↑", "previous model"),
  59. ),
  60. Down: key.NewBinding(
  61. key.WithKeys("down"),
  62. key.WithHelp("↓", "next model"),
  63. ),
  64. Left: key.NewBinding(
  65. key.WithKeys("left"),
  66. key.WithHelp("←", "scroll left"),
  67. ),
  68. Right: key.NewBinding(
  69. key.WithKeys("right"),
  70. key.WithHelp("→", "scroll right"),
  71. ),
  72. Enter: key.NewBinding(
  73. key.WithKeys("enter"),
  74. key.WithHelp("enter", "select model"),
  75. ),
  76. Escape: key.NewBinding(
  77. key.WithKeys("esc"),
  78. key.WithHelp("esc", "close"),
  79. ),
  80. J: key.NewBinding(
  81. key.WithKeys("j"),
  82. key.WithHelp("j", "next model"),
  83. ),
  84. K: key.NewBinding(
  85. key.WithKeys("k"),
  86. key.WithHelp("k", "previous model"),
  87. ),
  88. H: key.NewBinding(
  89. key.WithKeys("h"),
  90. key.WithHelp("h", "scroll left"),
  91. ),
  92. L: key.NewBinding(
  93. key.WithKeys("l"),
  94. key.WithHelp("l", "scroll right"),
  95. ),
  96. }
  97. func (m *modelDialogCmp) Init() tea.Cmd {
  98. m.setupModels()
  99. return nil
  100. }
  101. func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
  102. switch msg := msg.(type) {
  103. case tea.KeyMsg:
  104. switch {
  105. case key.Matches(msg, modelKeys.Up) || key.Matches(msg, modelKeys.K):
  106. m.moveSelectionUp()
  107. case key.Matches(msg, modelKeys.Down) || key.Matches(msg, modelKeys.J):
  108. m.moveSelectionDown()
  109. case key.Matches(msg, modelKeys.Left) || key.Matches(msg, modelKeys.H):
  110. if m.hScrollPossible {
  111. m.switchProvider(-1)
  112. }
  113. case key.Matches(msg, modelKeys.Right) || key.Matches(msg, modelKeys.L):
  114. if m.hScrollPossible {
  115. m.switchProvider(1)
  116. }
  117. case key.Matches(msg, modelKeys.Enter):
  118. status.Info(fmt.Sprintf("selected model: %s", m.models[m.selectedIdx].Name))
  119. return m, util.CmdHandler(ModelSelectedMsg{Model: m.models[m.selectedIdx]})
  120. case key.Matches(msg, modelKeys.Escape):
  121. return m, util.CmdHandler(CloseModelDialogMsg{})
  122. }
  123. case tea.WindowSizeMsg:
  124. m.width = msg.Width
  125. m.height = msg.Height
  126. }
  127. return m, nil
  128. }
  129. // moveSelectionUp moves the selection up or wraps to bottom
  130. func (m *modelDialogCmp) moveSelectionUp() {
  131. if m.selectedIdx > 0 {
  132. m.selectedIdx--
  133. } else {
  134. m.selectedIdx = len(m.models) - 1
  135. m.scrollOffset = max(0, len(m.models)-numVisibleModels)
  136. }
  137. // Keep selection visible
  138. if m.selectedIdx < m.scrollOffset {
  139. m.scrollOffset = m.selectedIdx
  140. }
  141. }
  142. // moveSelectionDown moves the selection down or wraps to top
  143. func (m *modelDialogCmp) moveSelectionDown() {
  144. if m.selectedIdx < len(m.models)-1 {
  145. m.selectedIdx++
  146. } else {
  147. m.selectedIdx = 0
  148. m.scrollOffset = 0
  149. }
  150. // Keep selection visible
  151. if m.selectedIdx >= m.scrollOffset+numVisibleModels {
  152. m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
  153. }
  154. }
  155. func (m *modelDialogCmp) switchProvider(offset int) {
  156. newOffset := m.hScrollOffset + offset
  157. // Ensure we stay within bounds
  158. if newOffset < 0 {
  159. newOffset = len(m.availableProviders) - 1
  160. }
  161. if newOffset >= len(m.availableProviders) {
  162. newOffset = 0
  163. }
  164. m.hScrollOffset = newOffset
  165. m.provider = m.availableProviders[m.hScrollOffset]
  166. m.setupModelsForProvider(m.provider)
  167. }
  168. func (m *modelDialogCmp) View() string {
  169. t := theme.CurrentTheme()
  170. baseStyle := styles.BaseStyle()
  171. // Capitalize first letter of provider name
  172. providerName := strings.ToUpper(string(m.provider)[:1]) + string(m.provider[1:])
  173. title := baseStyle.
  174. Foreground(t.Primary()).
  175. Bold(true).
  176. Width(maxDialogWidth).
  177. Padding(0, 0, 1).
  178. Render(fmt.Sprintf("Select %s Model", providerName))
  179. // Render visible models
  180. endIdx := min(m.scrollOffset+numVisibleModels, len(m.models))
  181. modelItems := make([]string, 0, endIdx-m.scrollOffset)
  182. for i := m.scrollOffset; i < endIdx; i++ {
  183. itemStyle := baseStyle.Width(maxDialogWidth)
  184. if i == m.selectedIdx {
  185. itemStyle = itemStyle.Background(t.Primary()).
  186. Foreground(t.Background()).Bold(true)
  187. }
  188. modelItems = append(modelItems, itemStyle.Render(m.models[i].Name))
  189. }
  190. scrollIndicator := m.getScrollIndicators(maxDialogWidth)
  191. content := lipgloss.JoinVertical(
  192. lipgloss.Left,
  193. title,
  194. baseStyle.Width(maxDialogWidth).Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)),
  195. scrollIndicator,
  196. )
  197. return baseStyle.Padding(1, 2).
  198. Border(lipgloss.RoundedBorder()).
  199. BorderBackground(t.Background()).
  200. BorderForeground(t.TextMuted()).
  201. Width(lipgloss.Width(content) + 4).
  202. Render(content)
  203. }
  204. func (m *modelDialogCmp) getScrollIndicators(maxWidth int) string {
  205. var indicator string
  206. if len(m.models) > numVisibleModels {
  207. if m.scrollOffset > 0 {
  208. indicator += "↑ "
  209. }
  210. if m.scrollOffset+numVisibleModels < len(m.models) {
  211. indicator += "↓ "
  212. }
  213. }
  214. if m.hScrollPossible {
  215. if m.hScrollOffset > 0 {
  216. indicator = "← " + indicator
  217. }
  218. if m.hScrollOffset < len(m.availableProviders)-1 {
  219. indicator += "→"
  220. }
  221. }
  222. if indicator == "" {
  223. return ""
  224. }
  225. t := theme.CurrentTheme()
  226. baseStyle := styles.BaseStyle()
  227. return baseStyle.
  228. Foreground(t.Primary()).
  229. Width(maxWidth).
  230. Align(lipgloss.Right).
  231. Bold(true).
  232. Render(indicator)
  233. }
  234. func (m *modelDialogCmp) BindingKeys() []key.Binding {
  235. return layout.KeyMapToSlice(modelKeys)
  236. }
  237. func (m *modelDialogCmp) setupModels() {
  238. cfg := config.Get()
  239. modelInfo := GetSelectedModel(cfg)
  240. m.availableProviders = getEnabledProviders(cfg)
  241. m.hScrollPossible = len(m.availableProviders) > 1
  242. m.provider = modelInfo.Provider
  243. m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider)
  244. m.setupModelsForProvider(m.provider)
  245. }
  246. func GetSelectedModel(cfg *config.Config) models.Model {
  247. agentCfg := cfg.Agents[config.AgentPrimary]
  248. selectedModelId := agentCfg.Model
  249. return models.SupportedModels[selectedModelId]
  250. }
  251. func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
  252. var providers []models.ModelProvider
  253. for providerId, provider := range cfg.Providers {
  254. if !provider.Disabled {
  255. providers = append(providers, providerId)
  256. }
  257. }
  258. // Sort by provider popularity
  259. slices.SortFunc(providers, func(a, b models.ModelProvider) int {
  260. rA := models.ProviderPopularity[a]
  261. rB := models.ProviderPopularity[b]
  262. // models not included in popularity ranking default to last
  263. if rA == 0 {
  264. rA = 999
  265. }
  266. if rB == 0 {
  267. rB = 999
  268. }
  269. return rA - rB
  270. })
  271. return providers
  272. }
  273. // findProviderIndex returns the index of the provider in the list, or -1 if not found
  274. func findProviderIndex(providers []models.ModelProvider, provider models.ModelProvider) int {
  275. for i, p := range providers {
  276. if p == provider {
  277. return i
  278. }
  279. }
  280. return -1
  281. }
  282. func (m *modelDialogCmp) setupModelsForProvider(provider models.ModelProvider) {
  283. cfg := config.Get()
  284. agentCfg := cfg.Agents[config.AgentPrimary]
  285. selectedModelId := agentCfg.Model
  286. m.provider = provider
  287. m.models = getModelsForProvider(provider)
  288. m.selectedIdx = 0
  289. m.scrollOffset = 0
  290. // Try to select the current model if it belongs to this provider
  291. if provider == models.SupportedModels[selectedModelId].Provider {
  292. for i, model := range m.models {
  293. if model.ID == selectedModelId {
  294. m.selectedIdx = i
  295. // Adjust scroll position to keep selected model visible
  296. if m.selectedIdx >= numVisibleModels {
  297. m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
  298. }
  299. break
  300. }
  301. }
  302. }
  303. }
  304. func getModelsForProvider(provider models.ModelProvider) []models.Model {
  305. var providerModels []models.Model
  306. for _, model := range models.SupportedModels {
  307. if model.Provider == provider {
  308. providerModels = append(providerModels, model)
  309. }
  310. }
  311. // reverse alphabetical order (if llm naming was consistent latest would appear first)
  312. slices.SortFunc(providerModels, func(a, b models.Model) int {
  313. if a.Name > b.Name {
  314. return -1
  315. } else if a.Name < b.Name {
  316. return 1
  317. }
  318. return 0
  319. })
  320. return providerModels
  321. }
  322. func NewModelDialogCmp() ModelDialog {
  323. return &modelDialogCmp{}
  324. }