models.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. package dialog
  2. import (
  3. "fmt"
  4. "slices"
  5. "strings"
  6. "github.com/charmbracelet/bubbles/key"
  7. tea "github.com/charmbracelet/bubbletea"
  8. "github.com/charmbracelet/lipgloss"
  9. "github.com/opencode-ai/opencode/internal/config"
  10. "github.com/opencode-ai/opencode/internal/llm/models"
  11. "github.com/opencode-ai/opencode/internal/tui/layout"
  12. "github.com/opencode-ai/opencode/internal/tui/styles"
  13. "github.com/opencode-ai/opencode/internal/tui/theme"
  14. "github.com/opencode-ai/opencode/internal/tui/util"
  15. )
  const (
  	// maxDialogWidth is the content width given to the dialog title, to each
  	// model row and to the scroll-indicator line.
  	maxDialogWidth = 40
  	// numVisibleModels is how many model rows are rendered at once; longer
  	// lists scroll vertically.
  	numVisibleModels = 10
  )
  20. // ModelSelectedMsg is sent when a model is selected
  21. type ModelSelectedMsg struct {
  22. Model models.Model
  23. }
  24. // CloseModelDialogMsg is sent when a model is selected
  25. type CloseModelDialogMsg struct{}
  26. // ModelDialog interface for the model selection dialog
  27. type ModelDialog interface {
  28. tea.Model
  29. layout.Bindings
  30. }
  31. type modelDialogCmp struct {
  32. models []models.Model
  33. provider models.ModelProvider
  34. availableProviders []models.ModelProvider
  35. selectedIdx int
  36. width int
  37. height int
  38. scrollOffset int
  39. hScrollOffset int
  40. hScrollPossible bool
  41. }
  42. type modelKeyMap struct {
  43. Up key.Binding
  44. Down key.Binding
  45. Left key.Binding
  46. Right key.Binding
  47. Enter key.Binding
  48. Escape key.Binding
  49. J key.Binding
  50. K key.Binding
  51. H key.Binding
  52. L key.Binding
  53. }
  54. var modelKeys = modelKeyMap{
  55. Up: key.NewBinding(
  56. key.WithKeys("up"),
  57. key.WithHelp("↑", "previous model"),
  58. ),
  59. Down: key.NewBinding(
  60. key.WithKeys("down"),
  61. key.WithHelp("↓", "next model"),
  62. ),
  63. Left: key.NewBinding(
  64. key.WithKeys("left"),
  65. key.WithHelp("←", "scroll left"),
  66. ),
  67. Right: key.NewBinding(
  68. key.WithKeys("right"),
  69. key.WithHelp("→", "scroll right"),
  70. ),
  71. Enter: key.NewBinding(
  72. key.WithKeys("enter"),
  73. key.WithHelp("enter", "select model"),
  74. ),
  75. Escape: key.NewBinding(
  76. key.WithKeys("esc"),
  77. key.WithHelp("esc", "close"),
  78. ),
  79. J: key.NewBinding(
  80. key.WithKeys("j"),
  81. key.WithHelp("j", "next model"),
  82. ),
  83. K: key.NewBinding(
  84. key.WithKeys("k"),
  85. key.WithHelp("k", "previous model"),
  86. ),
  87. H: key.NewBinding(
  88. key.WithKeys("h"),
  89. key.WithHelp("h", "scroll left"),
  90. ),
  91. L: key.NewBinding(
  92. key.WithKeys("l"),
  93. key.WithHelp("l", "scroll right"),
  94. ),
  95. }
  96. func (m *modelDialogCmp) Init() tea.Cmd {
  97. m.setupModels()
  98. return nil
  99. }
  100. func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
  101. switch msg := msg.(type) {
  102. case tea.KeyMsg:
  103. switch {
  104. case key.Matches(msg, modelKeys.Up) || key.Matches(msg, modelKeys.K):
  105. m.moveSelectionUp()
  106. case key.Matches(msg, modelKeys.Down) || key.Matches(msg, modelKeys.J):
  107. m.moveSelectionDown()
  108. case key.Matches(msg, modelKeys.Left) || key.Matches(msg, modelKeys.H):
  109. if m.hScrollPossible {
  110. m.switchProvider(-1)
  111. }
  112. case key.Matches(msg, modelKeys.Right) || key.Matches(msg, modelKeys.L):
  113. if m.hScrollPossible {
  114. m.switchProvider(1)
  115. }
  116. case key.Matches(msg, modelKeys.Enter):
  117. util.ReportInfo(fmt.Sprintf("selected model: %s", m.models[m.selectedIdx].Name))
  118. return m, util.CmdHandler(ModelSelectedMsg{Model: m.models[m.selectedIdx]})
  119. case key.Matches(msg, modelKeys.Escape):
  120. return m, util.CmdHandler(CloseModelDialogMsg{})
  121. }
  122. case tea.WindowSizeMsg:
  123. m.width = msg.Width
  124. m.height = msg.Height
  125. }
  126. return m, nil
  127. }
  128. // moveSelectionUp moves the selection up or wraps to bottom
  129. func (m *modelDialogCmp) moveSelectionUp() {
  130. if m.selectedIdx > 0 {
  131. m.selectedIdx--
  132. } else {
  133. m.selectedIdx = len(m.models) - 1
  134. m.scrollOffset = max(0, len(m.models)-numVisibleModels)
  135. }
  136. // Keep selection visible
  137. if m.selectedIdx < m.scrollOffset {
  138. m.scrollOffset = m.selectedIdx
  139. }
  140. }
  141. // moveSelectionDown moves the selection down or wraps to top
  142. func (m *modelDialogCmp) moveSelectionDown() {
  143. if m.selectedIdx < len(m.models)-1 {
  144. m.selectedIdx++
  145. } else {
  146. m.selectedIdx = 0
  147. m.scrollOffset = 0
  148. }
  149. // Keep selection visible
  150. if m.selectedIdx >= m.scrollOffset+numVisibleModels {
  151. m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
  152. }
  153. }
  154. func (m *modelDialogCmp) switchProvider(offset int) {
  155. newOffset := m.hScrollOffset + offset
  156. // Ensure we stay within bounds
  157. if newOffset < 0 {
  158. newOffset = len(m.availableProviders) - 1
  159. }
  160. if newOffset >= len(m.availableProviders) {
  161. newOffset = 0
  162. }
  163. m.hScrollOffset = newOffset
  164. m.provider = m.availableProviders[m.hScrollOffset]
  165. m.setupModelsForProvider(m.provider)
  166. }
  167. func (m *modelDialogCmp) View() string {
  168. t := theme.CurrentTheme()
  169. baseStyle := styles.BaseStyle()
  170. // Capitalize first letter of provider name
  171. providerName := strings.ToUpper(string(m.provider)[:1]) + string(m.provider[1:])
  172. title := baseStyle.
  173. Foreground(t.Primary()).
  174. Bold(true).
  175. Width(maxDialogWidth).
  176. Padding(0, 0, 1).
  177. Render(fmt.Sprintf("Select %s Model", providerName))
  178. // Render visible models
  179. endIdx := min(m.scrollOffset+numVisibleModels, len(m.models))
  180. modelItems := make([]string, 0, endIdx-m.scrollOffset)
  181. for i := m.scrollOffset; i < endIdx; i++ {
  182. itemStyle := baseStyle.Width(maxDialogWidth)
  183. if i == m.selectedIdx {
  184. itemStyle = itemStyle.Background(t.Primary()).
  185. Foreground(t.Background()).Bold(true)
  186. }
  187. modelItems = append(modelItems, itemStyle.Render(m.models[i].Name))
  188. }
  189. scrollIndicator := m.getScrollIndicators(maxDialogWidth)
  190. content := lipgloss.JoinVertical(
  191. lipgloss.Left,
  192. title,
  193. baseStyle.Width(maxDialogWidth).Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)),
  194. scrollIndicator,
  195. )
  196. return baseStyle.Padding(1, 2).
  197. Border(lipgloss.RoundedBorder()).
  198. BorderBackground(t.Background()).
  199. BorderForeground(t.TextMuted()).
  200. Width(lipgloss.Width(content) + 4).
  201. Render(content)
  202. }
  203. func (m *modelDialogCmp) getScrollIndicators(maxWidth int) string {
  204. var indicator string
  205. if len(m.models) > numVisibleModels {
  206. if m.scrollOffset > 0 {
  207. indicator += "↑ "
  208. }
  209. if m.scrollOffset+numVisibleModels < len(m.models) {
  210. indicator += "↓ "
  211. }
  212. }
  213. if m.hScrollPossible {
  214. if m.hScrollOffset > 0 {
  215. indicator = "← " + indicator
  216. }
  217. if m.hScrollOffset < len(m.availableProviders)-1 {
  218. indicator += "→"
  219. }
  220. }
  221. if indicator == "" {
  222. return ""
  223. }
  224. t := theme.CurrentTheme()
  225. baseStyle := styles.BaseStyle()
  226. return baseStyle.
  227. Foreground(t.Primary()).
  228. Width(maxWidth).
  229. Align(lipgloss.Right).
  230. Bold(true).
  231. Render(indicator)
  232. }
  233. func (m *modelDialogCmp) BindingKeys() []key.Binding {
  234. return layout.KeyMapToSlice(modelKeys)
  235. }
  236. func (m *modelDialogCmp) setupModels() {
  237. cfg := config.Get()
  238. m.availableProviders = getEnabledProviders(cfg)
  239. m.hScrollPossible = len(m.availableProviders) > 1
  240. agentCfg := cfg.Agents[config.AgentCoder]
  241. selectedModelId := agentCfg.Model
  242. modelInfo := models.SupportedModels[selectedModelId]
  243. m.provider = modelInfo.Provider
  244. m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider)
  245. m.setupModelsForProvider(m.provider)
  246. }
  247. func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
  248. var providers []models.ModelProvider
  249. for providerId, provider := range cfg.Providers {
  250. if !provider.Disabled {
  251. providers = append(providers, providerId)
  252. }
  253. }
  254. // Sort by provider popularity
  255. slices.SortFunc(providers, func(a, b models.ModelProvider) int {
  256. rA := models.ProviderPopularity[a]
  257. rB := models.ProviderPopularity[b]
  258. // models not included in popularity ranking default to last
  259. if rA == 0 {
  260. rA = 999
  261. }
  262. if rB == 0 {
  263. rB = 999
  264. }
  265. return rA - rB
  266. })
  267. return providers
  268. }
  269. // findProviderIndex returns the index of the provider in the list, or -1 if not found
  270. func findProviderIndex(providers []models.ModelProvider, provider models.ModelProvider) int {
  271. for i, p := range providers {
  272. if p == provider {
  273. return i
  274. }
  275. }
  276. return -1
  277. }
  278. func (m *modelDialogCmp) setupModelsForProvider(provider models.ModelProvider) {
  279. cfg := config.Get()
  280. agentCfg := cfg.Agents[config.AgentCoder]
  281. selectedModelId := agentCfg.Model
  282. m.provider = provider
  283. m.models = getModelsForProvider(provider)
  284. m.selectedIdx = 0
  285. m.scrollOffset = 0
  286. // Try to select the current model if it belongs to this provider
  287. if provider == models.SupportedModels[selectedModelId].Provider {
  288. for i, model := range m.models {
  289. if model.ID == selectedModelId {
  290. m.selectedIdx = i
  291. // Adjust scroll position to keep selected model visible
  292. if m.selectedIdx >= numVisibleModels {
  293. m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
  294. }
  295. break
  296. }
  297. }
  298. }
  299. }
  300. func getModelsForProvider(provider models.ModelProvider) []models.Model {
  301. var providerModels []models.Model
  302. for _, model := range models.SupportedModels {
  303. if model.Provider == provider {
  304. providerModels = append(providerModels, model)
  305. }
  306. }
  307. // reverse alphabetical order (if llm naming was consistent latest would appear first)
  308. slices.SortFunc(providerModels, func(a, b models.Model) int {
  309. if a.Name > b.Name {
  310. return -1
  311. } else if a.Name < b.Name {
  312. return 1
  313. }
  314. return 0
  315. })
  316. return providerModels
  317. }
  318. func NewModelDialogCmp() ModelDialog {
  319. return &modelDialogCmp{}
  320. }