models.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. package dialog
  2. import (
  3. "context"
  4. "fmt"
  5. "maps"
  6. "slices"
  7. "strings"
  8. "github.com/charmbracelet/bubbles/key"
  9. tea "github.com/charmbracelet/bubbletea"
  10. "github.com/charmbracelet/lipgloss"
  11. "github.com/sst/opencode/internal/app"
  12. "github.com/sst/opencode/internal/layout"
  13. "github.com/sst/opencode/internal/styles"
  14. "github.com/sst/opencode/internal/theme"
  15. "github.com/sst/opencode/internal/util"
  16. "github.com/sst/opencode/pkg/client"
  17. )
  const (
	// maxDialogWidth is the width given to the title, each model row and the
	// scroll-indicator line when the dialog is rendered.
	maxDialogWidth = 40
	// numVisibleModels is how many model rows are shown at once; the rest are
	// reached by scrolling the list.
	numVisibleModels = 10
)
  22. // CloseModelDialogMsg is sent when a model is selected
  23. type CloseModelDialogMsg struct {
  24. Provider *client.ProviderInfo
  25. Model *client.ProviderModel
  26. }
  27. // ModelDialog interface for the model selection dialog
  28. type ModelDialog interface {
  29. tea.Model
  30. layout.Bindings
  31. SetProviders(providers []client.ProviderInfo)
  32. }
  33. type modelDialogCmp struct {
  34. app *app.App
  35. availableProviders []client.ProviderInfo
  36. provider client.ProviderInfo
  37. selectedIdx int
  38. width int
  39. height int
  40. scrollOffset int
  41. hScrollOffset int
  42. hScrollPossible bool
  43. }
  44. type modelKeyMap struct {
  45. Up key.Binding
  46. Down key.Binding
  47. Left key.Binding
  48. Right key.Binding
  49. Enter key.Binding
  50. Escape key.Binding
  51. J key.Binding
  52. K key.Binding
  53. H key.Binding
  54. L key.Binding
  55. }
  56. var modelKeys = modelKeyMap{
  57. Up: key.NewBinding(
  58. key.WithKeys("up"),
  59. key.WithHelp("↑", "previous model"),
  60. ),
  61. Down: key.NewBinding(
  62. key.WithKeys("down"),
  63. key.WithHelp("↓", "next model"),
  64. ),
  65. Left: key.NewBinding(
  66. key.WithKeys("left"),
  67. key.WithHelp("←", "scroll left"),
  68. ),
  69. Right: key.NewBinding(
  70. key.WithKeys("right"),
  71. key.WithHelp("→", "scroll right"),
  72. ),
  73. Enter: key.NewBinding(
  74. key.WithKeys("enter"),
  75. key.WithHelp("enter", "select model"),
  76. ),
  77. Escape: key.NewBinding(
  78. key.WithKeys("esc"),
  79. key.WithHelp("esc", "close"),
  80. ),
  81. J: key.NewBinding(
  82. key.WithKeys("j"),
  83. key.WithHelp("j", "next model"),
  84. ),
  85. K: key.NewBinding(
  86. key.WithKeys("k"),
  87. key.WithHelp("k", "previous model"),
  88. ),
  89. H: key.NewBinding(
  90. key.WithKeys("h"),
  91. key.WithHelp("h", "scroll left"),
  92. ),
  93. L: key.NewBinding(
  94. key.WithKeys("l"),
  95. key.WithHelp("l", "scroll right"),
  96. ),
  97. }
  98. func (m *modelDialogCmp) Init() tea.Cmd {
  99. // cfg := config.Get()
  100. // modelInfo := GetSelectedModel(cfg)
  101. // m.availableProviders = getEnabledProviders(cfg)
  102. // m.hScrollPossible = len(m.availableProviders) > 1
  103. // m.provider = modelInfo.Provider
  104. // m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider)
  105. // m.setupModelsForProvider(m.provider)
  106. m.availableProviders, _ = m.app.ListProviders(context.Background())
  107. m.hScrollOffset = 0
  108. m.hScrollPossible = len(m.availableProviders) > 1
  109. m.provider = m.availableProviders[m.hScrollOffset]
  110. return nil
  111. }
  112. func (m *modelDialogCmp) SetProviders(providers []client.ProviderInfo) {
  113. m.availableProviders = providers
  114. }
  115. func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
  116. switch msg := msg.(type) {
  117. case tea.KeyMsg:
  118. switch {
  119. case key.Matches(msg, modelKeys.Up) || key.Matches(msg, modelKeys.K):
  120. m.moveSelectionUp()
  121. case key.Matches(msg, modelKeys.Down) || key.Matches(msg, modelKeys.J):
  122. m.moveSelectionDown()
  123. case key.Matches(msg, modelKeys.Left) || key.Matches(msg, modelKeys.H):
  124. if m.hScrollPossible {
  125. m.switchProvider(-1)
  126. }
  127. case key.Matches(msg, modelKeys.Right) || key.Matches(msg, modelKeys.L):
  128. if m.hScrollPossible {
  129. m.switchProvider(1)
  130. }
  131. case key.Matches(msg, modelKeys.Enter):
  132. models := m.models()
  133. return m, util.CmdHandler(CloseModelDialogMsg{Provider: &m.provider, Model: &models[m.selectedIdx]})
  134. case key.Matches(msg, modelKeys.Escape):
  135. return m, util.CmdHandler(CloseModelDialogMsg{})
  136. }
  137. case tea.WindowSizeMsg:
  138. m.width = msg.Width
  139. m.height = msg.Height
  140. }
  141. return m, nil
  142. }
  143. func (m *modelDialogCmp) models() []client.ProviderModel {
  144. models := slices.SortedFunc(maps.Values(m.provider.Models), func(a, b client.ProviderModel) int {
  145. return strings.Compare(*a.Name, *b.Name)
  146. })
  147. return models
  148. }
  149. // moveSelectionUp moves the selection up or wraps to bottom
  150. func (m *modelDialogCmp) moveSelectionUp() {
  151. if m.selectedIdx > 0 {
  152. m.selectedIdx--
  153. } else {
  154. m.selectedIdx = len(m.provider.Models) - 1
  155. m.scrollOffset = max(0, len(m.provider.Models)-numVisibleModels)
  156. }
  157. // Keep selection visible
  158. if m.selectedIdx < m.scrollOffset {
  159. m.scrollOffset = m.selectedIdx
  160. }
  161. }
  162. // moveSelectionDown moves the selection down or wraps to top
  163. func (m *modelDialogCmp) moveSelectionDown() {
  164. if m.selectedIdx < len(m.provider.Models)-1 {
  165. m.selectedIdx++
  166. } else {
  167. m.selectedIdx = 0
  168. m.scrollOffset = 0
  169. }
  170. // Keep selection visible
  171. if m.selectedIdx >= m.scrollOffset+numVisibleModels {
  172. m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
  173. }
  174. }
  175. func (m *modelDialogCmp) switchProvider(offset int) {
  176. newOffset := m.hScrollOffset + offset
  177. // Ensure we stay within bounds
  178. if newOffset < 0 {
  179. newOffset = len(m.availableProviders) - 1
  180. }
  181. if newOffset >= len(m.availableProviders) {
  182. newOffset = 0
  183. }
  184. m.hScrollOffset = newOffset
  185. m.provider = m.availableProviders[m.hScrollOffset]
  186. m.setupModelsForProvider(m.provider.Id)
  187. }
  188. func (m *modelDialogCmp) View() string {
  189. t := theme.CurrentTheme()
  190. baseStyle := styles.BaseStyle()
  191. // Capitalize first letter of provider name
  192. title := baseStyle.
  193. Foreground(t.Primary()).
  194. Bold(true).
  195. Width(maxDialogWidth).
  196. Padding(0, 0, 1).
  197. Render(fmt.Sprintf("Select %s Model", m.provider.Name))
  198. // Render visible models
  199. endIdx := min(m.scrollOffset+numVisibleModels, len(m.provider.Models))
  200. modelItems := make([]string, 0, endIdx-m.scrollOffset)
  201. models := m.models()
  202. for i := m.scrollOffset; i < endIdx; i++ {
  203. itemStyle := baseStyle.Width(maxDialogWidth)
  204. if i == m.selectedIdx {
  205. itemStyle = itemStyle.Background(t.Primary()).
  206. Foreground(t.Background()).Bold(true)
  207. }
  208. modelItems = append(modelItems, itemStyle.Render(*models[i].Name))
  209. }
  210. scrollIndicator := m.getScrollIndicators(maxDialogWidth)
  211. content := lipgloss.JoinVertical(
  212. lipgloss.Left,
  213. title,
  214. baseStyle.Width(maxDialogWidth).Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)),
  215. scrollIndicator,
  216. )
  217. return baseStyle.Padding(1, 2).
  218. Border(lipgloss.RoundedBorder()).
  219. BorderBackground(t.Background()).
  220. BorderForeground(t.TextMuted()).
  221. Width(lipgloss.Width(content) + 4).
  222. Render(content)
  223. }
  224. func (m *modelDialogCmp) getScrollIndicators(maxWidth int) string {
  225. var indicator string
  226. if len(m.provider.Models) > numVisibleModels {
  227. if m.scrollOffset > 0 {
  228. indicator += "↑ "
  229. }
  230. if m.scrollOffset+numVisibleModels < len(m.provider.Models) {
  231. indicator += "↓ "
  232. }
  233. }
  234. if m.hScrollPossible {
  235. if m.hScrollOffset > 0 {
  236. indicator = "← " + indicator
  237. }
  238. if m.hScrollOffset < len(m.availableProviders)-1 {
  239. indicator += "→"
  240. }
  241. }
  242. if indicator == "" {
  243. return ""
  244. }
  245. t := theme.CurrentTheme()
  246. baseStyle := styles.BaseStyle()
  247. return baseStyle.
  248. Foreground(t.Primary()).
  249. Width(maxWidth).
  250. Align(lipgloss.Right).
  251. Bold(true).
  252. Render(indicator)
  253. }
  254. func (m *modelDialogCmp) BindingKeys() []key.Binding {
  255. return layout.KeyMapToSlice(modelKeys)
  256. }
  257. // findProviderIndex returns the index of the provider in the list, or -1 if not found
  258. // func findProviderIndex(providers []string, provider string) int {
  259. // for i, p := range providers {
  260. // if p == provider {
  261. // return i
  262. // }
  263. // }
  264. // return -1
  265. // }
  266. func (m *modelDialogCmp) setupModelsForProvider(_ string) {
  267. m.selectedIdx = 0
  268. m.scrollOffset = 0
  269. // cfg := config.Get()
  270. // agentCfg := cfg.Agents[config.AgentPrimary]
  271. // selectedModelId := agentCfg.Model
  272. // m.provider = provider
  273. // m.models = getModelsForProvider(provider)
  274. // Try to select the current model if it belongs to this provider
  275. // if provider == models.SupportedModels[selectedModelId].Provider {
  276. // for i, model := range m.models {
  277. // if model.ID == selectedModelId {
  278. // m.selectedIdx = i
  279. // // Adjust scroll position to keep selected model visible
  280. // if m.selectedIdx >= numVisibleModels {
  281. // m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
  282. // }
  283. // break
  284. // }
  285. // }
  286. // }
  287. }
  288. func NewModelDialogCmp(app *app.App) ModelDialog {
  289. return &modelDialogCmp{
  290. app: app,
  291. }
  292. }