models.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. package dialog
  2. import (
  3. "context"
  4. "fmt"
  5. "sort"
  6. "time"
  7. "github.com/charmbracelet/bubbles/v2/key"
  8. tea "github.com/charmbracelet/bubbletea/v2"
  9. "github.com/lithammer/fuzzysearch/fuzzy"
  10. "github.com/sst/opencode-sdk-go"
  11. "github.com/sst/opencode/internal/app"
  12. "github.com/sst/opencode/internal/components/list"
  13. "github.com/sst/opencode/internal/components/modal"
  14. "github.com/sst/opencode/internal/layout"
  15. "github.com/sst/opencode/internal/styles"
  16. "github.com/sst/opencode/internal/theme"
  17. "github.com/sst/opencode/internal/util"
  18. )
  19. const (
  20. numVisibleModels = 10
  21. minDialogWidth = 40
  22. maxDialogWidth = 80
  23. maxRecentModels = 5
  24. )
  25. // ModelDialog interface for the model selection dialog
  26. type ModelDialog interface {
  27. layout.Modal
  28. }
  29. type modelDialog struct {
  30. app *app.App
  31. allModels []ModelWithProvider
  32. width int
  33. height int
  34. modal *modal.Modal
  35. searchDialog *SearchDialog
  36. dialogWidth int
  37. }
  38. type ModelWithProvider struct {
  39. Model opencode.Model
  40. Provider opencode.Provider
  41. }
  42. // modelItem is a custom list item for model selections
  43. type modelItem struct {
  44. model ModelWithProvider
  45. }
  46. func (m modelItem) Render(
  47. selected bool,
  48. width int,
  49. baseStyle styles.Style,
  50. ) string {
  51. t := theme.CurrentTheme()
  52. itemStyle := baseStyle.
  53. Background(t.BackgroundPanel()).
  54. Foreground(t.Text())
  55. if selected {
  56. itemStyle = itemStyle.Foreground(t.Primary())
  57. }
  58. providerStyle := baseStyle.
  59. Foreground(t.TextMuted()).
  60. Background(t.BackgroundPanel())
  61. modelPart := itemStyle.Render(m.model.Model.Name)
  62. providerPart := providerStyle.Render(fmt.Sprintf(" %s", m.model.Provider.Name))
  63. combinedText := modelPart + providerPart
  64. return baseStyle.
  65. Background(t.BackgroundPanel()).
  66. PaddingLeft(1).
  67. Render(combinedText)
  68. }
  69. func (m modelItem) Selectable() bool {
  70. return true
  71. }
  72. type modelKeyMap struct {
  73. Enter key.Binding
  74. Escape key.Binding
  75. }
  76. var modelKeys = modelKeyMap{
  77. Enter: key.NewBinding(
  78. key.WithKeys("enter"),
  79. key.WithHelp("enter", "select model"),
  80. ),
  81. Escape: key.NewBinding(
  82. key.WithKeys("esc"),
  83. key.WithHelp("esc", "close"),
  84. ),
  85. }
  86. func (m *modelDialog) Init() tea.Cmd {
  87. m.setupAllModels()
  88. return m.searchDialog.Init()
  89. }
  90. func (m *modelDialog) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
  91. switch msg := msg.(type) {
  92. case SearchSelectionMsg:
  93. // Handle selection from search dialog
  94. if item, ok := msg.Item.(modelItem); ok {
  95. return m, tea.Sequence(
  96. util.CmdHandler(modal.CloseModalMsg{}),
  97. util.CmdHandler(
  98. app.ModelSelectedMsg{
  99. Provider: item.model.Provider,
  100. Model: item.model.Model,
  101. }),
  102. )
  103. }
  104. return m, util.CmdHandler(modal.CloseModalMsg{})
  105. case SearchCancelledMsg:
  106. return m, util.CmdHandler(modal.CloseModalMsg{})
  107. case SearchRemoveItemMsg:
  108. if item, ok := msg.Item.(modelItem); ok {
  109. if m.isModelInRecentSection(item.model, msg.Index) {
  110. m.app.State.RemoveModelFromRecentlyUsed(item.model.Provider.ID, item.model.Model.ID)
  111. items := m.buildDisplayList(m.searchDialog.GetQuery())
  112. m.searchDialog.SetItems(items)
  113. return m, m.app.SaveState()
  114. }
  115. }
  116. return m, nil
  117. case SearchQueryChangedMsg:
  118. // Update the list based on search query
  119. items := m.buildDisplayList(msg.Query)
  120. m.searchDialog.SetItems(items)
  121. return m, nil
  122. case tea.WindowSizeMsg:
  123. m.width = msg.Width
  124. m.height = msg.Height
  125. m.searchDialog.SetWidth(m.dialogWidth)
  126. m.searchDialog.SetHeight(msg.Height)
  127. }
  128. updatedDialog, cmd := m.searchDialog.Update(msg)
  129. m.searchDialog = updatedDialog.(*SearchDialog)
  130. return m, cmd
  131. }
  132. func (m *modelDialog) View() string {
  133. return m.searchDialog.View()
  134. }
  135. func (m *modelDialog) calculateOptimalWidth(models []ModelWithProvider) int {
  136. maxWidth := minDialogWidth
  137. for _, model := range models {
  138. // Calculate the width needed for this item: "ModelName (ProviderName)"
  139. // Add 4 for the parentheses, space, and some padding
  140. itemWidth := len(model.Model.Name) + len(model.Provider.Name) + 4
  141. if itemWidth > maxWidth {
  142. maxWidth = itemWidth
  143. }
  144. }
  145. if maxWidth > maxDialogWidth {
  146. maxWidth = maxDialogWidth
  147. }
  148. return maxWidth
  149. }
  150. func (m *modelDialog) setupAllModels() {
  151. providers, _ := m.app.ListProviders(context.Background())
  152. m.allModels = make([]ModelWithProvider, 0)
  153. for _, provider := range providers {
  154. for _, model := range provider.Models {
  155. m.allModels = append(m.allModels, ModelWithProvider{
  156. Model: model,
  157. Provider: provider,
  158. })
  159. }
  160. }
  161. m.sortModels()
  162. // Calculate optimal width based on all models
  163. m.dialogWidth = m.calculateOptimalWidth(m.allModels)
  164. // Initialize search dialog
  165. m.searchDialog = NewSearchDialog("Search models...", numVisibleModels)
  166. m.searchDialog.SetWidth(m.dialogWidth)
  167. // Build initial display list (empty query shows grouped view)
  168. items := m.buildDisplayList("")
  169. m.searchDialog.SetItems(items)
  170. }
  171. func (m *modelDialog) sortModels() {
  172. sort.Slice(m.allModels, func(i, j int) bool {
  173. modelA := m.allModels[i]
  174. modelB := m.allModels[j]
  175. usageA := m.getModelUsageTime(modelA.Provider.ID, modelA.Model.ID)
  176. usageB := m.getModelUsageTime(modelB.Provider.ID, modelB.Model.ID)
  177. // If both have usage times, sort by most recent first
  178. if !usageA.IsZero() && !usageB.IsZero() {
  179. return usageA.After(usageB)
  180. }
  181. // If only one has usage time, it goes first
  182. if !usageA.IsZero() && usageB.IsZero() {
  183. return true
  184. }
  185. if usageA.IsZero() && !usageB.IsZero() {
  186. return false
  187. }
  188. // If neither has usage time, sort by release date desc if available
  189. if modelA.Model.ReleaseDate != "" && modelB.Model.ReleaseDate != "" {
  190. dateA := m.parseReleaseDate(modelA.Model.ReleaseDate)
  191. dateB := m.parseReleaseDate(modelB.Model.ReleaseDate)
  192. if !dateA.IsZero() && !dateB.IsZero() {
  193. return dateA.After(dateB)
  194. }
  195. }
  196. // If only one has release date, it goes first
  197. if modelA.Model.ReleaseDate != "" && modelB.Model.ReleaseDate == "" {
  198. return true
  199. }
  200. if modelA.Model.ReleaseDate == "" && modelB.Model.ReleaseDate != "" {
  201. return false
  202. }
  203. // If neither has usage time nor release date, fall back to alphabetical sorting
  204. return modelA.Model.Name < modelB.Model.Name
  205. })
  206. }
  207. func (m *modelDialog) parseReleaseDate(dateStr string) time.Time {
  208. if parsed, err := time.Parse("2006-01-02", dateStr); err == nil {
  209. return parsed
  210. }
  211. return time.Time{}
  212. }
  213. func (m *modelDialog) getModelUsageTime(providerID, modelID string) time.Time {
  214. for _, usage := range m.app.State.RecentlyUsedModels {
  215. if usage.ProviderID == providerID && usage.ModelID == modelID {
  216. return usage.LastUsed
  217. }
  218. }
  219. return time.Time{}
  220. }
  221. // buildDisplayList creates the list items based on search query
  222. func (m *modelDialog) buildDisplayList(query string) []list.Item {
  223. if query != "" {
  224. // Search mode: use fuzzy matching
  225. return m.buildSearchResults(query)
  226. } else {
  227. // Grouped mode: show Recent section and provider groups
  228. return m.buildGroupedResults()
  229. }
  230. }
  231. // buildSearchResults creates a flat list of search results using fuzzy matching
  232. func (m *modelDialog) buildSearchResults(query string) []list.Item {
  233. type modelMatch struct {
  234. model ModelWithProvider
  235. score int
  236. }
  237. modelNames := []string{}
  238. modelMap := make(map[string]ModelWithProvider)
  239. // Create search strings and perform fuzzy matching
  240. for _, model := range m.allModels {
  241. searchStr := fmt.Sprintf("%s %s", model.Model.Name, model.Provider.Name)
  242. modelNames = append(modelNames, searchStr)
  243. modelMap[searchStr] = model
  244. searchStr = fmt.Sprintf("%s %s", model.Provider.Name, model.Model.Name)
  245. modelNames = append(modelNames, searchStr)
  246. modelMap[searchStr] = model
  247. }
  248. matches := fuzzy.RankFindFold(query, modelNames)
  249. sort.Sort(matches)
  250. items := []list.Item{}
  251. seenModels := make(map[string]bool)
  252. for _, match := range matches {
  253. model := modelMap[match.Target]
  254. // Create a unique key to avoid duplicates
  255. key := fmt.Sprintf("%s:%s", model.Provider.ID, model.Model.ID)
  256. if seenModels[key] {
  257. continue
  258. }
  259. seenModels[key] = true
  260. items = append(items, modelItem{model: model})
  261. }
  262. return items
  263. }
  264. // buildGroupedResults creates a grouped list with Recent section and provider groups
  265. func (m *modelDialog) buildGroupedResults() []list.Item {
  266. var items []list.Item
  267. // Add Recent section
  268. recentModels := m.getRecentModels(maxRecentModels)
  269. if len(recentModels) > 0 {
  270. items = append(items, list.HeaderItem("Recent"))
  271. for _, model := range recentModels {
  272. items = append(items, modelItem{model: model})
  273. }
  274. }
  275. // Group models by provider
  276. providerGroups := make(map[string][]ModelWithProvider)
  277. for _, model := range m.allModels {
  278. providerName := model.Provider.Name
  279. providerGroups[providerName] = append(providerGroups[providerName], model)
  280. }
  281. // Get sorted provider names for consistent order
  282. var providerNames []string
  283. for name := range providerGroups {
  284. providerNames = append(providerNames, name)
  285. }
  286. sort.Strings(providerNames)
  287. // Add provider groups
  288. for _, providerName := range providerNames {
  289. models := providerGroups[providerName]
  290. // Sort models within provider group
  291. sort.Slice(models, func(i, j int) bool {
  292. modelA := models[i]
  293. modelB := models[j]
  294. usageA := m.getModelUsageTime(modelA.Provider.ID, modelA.Model.ID)
  295. usageB := m.getModelUsageTime(modelB.Provider.ID, modelB.Model.ID)
  296. // Sort by usage time first, then by release date, then alphabetically
  297. if !usageA.IsZero() && !usageB.IsZero() {
  298. return usageA.After(usageB)
  299. }
  300. if !usageA.IsZero() && usageB.IsZero() {
  301. return true
  302. }
  303. if usageA.IsZero() && !usageB.IsZero() {
  304. return false
  305. }
  306. // Sort by release date if available
  307. if modelA.Model.ReleaseDate != "" && modelB.Model.ReleaseDate != "" {
  308. dateA := m.parseReleaseDate(modelA.Model.ReleaseDate)
  309. dateB := m.parseReleaseDate(modelB.Model.ReleaseDate)
  310. if !dateA.IsZero() && !dateB.IsZero() {
  311. return dateA.After(dateB)
  312. }
  313. }
  314. return modelA.Model.Name < modelB.Model.Name
  315. })
  316. // Add provider header
  317. items = append(items, list.HeaderItem(providerName))
  318. // Add models in this provider group
  319. for _, model := range models {
  320. items = append(items, modelItem{model: model})
  321. }
  322. }
  323. return items
  324. }
  325. // getRecentModels returns the most recently used models
  326. func (m *modelDialog) getRecentModels(limit int) []ModelWithProvider {
  327. var recentModels []ModelWithProvider
  328. // Get recent models from app state
  329. for _, usage := range m.app.State.RecentlyUsedModels {
  330. if len(recentModels) >= limit {
  331. break
  332. }
  333. // Find the corresponding model
  334. for _, model := range m.allModels {
  335. if model.Provider.ID == usage.ProviderID && model.Model.ID == usage.ModelID {
  336. recentModels = append(recentModels, model)
  337. break
  338. }
  339. }
  340. }
  341. return recentModels
  342. }
  343. func (m *modelDialog) isModelInRecentSection(model ModelWithProvider, index int) bool {
  344. // Only check if we're in grouped mode (no search query)
  345. if m.searchDialog.GetQuery() != "" {
  346. return false
  347. }
  348. recentModels := m.getRecentModels(maxRecentModels)
  349. if len(recentModels) == 0 {
  350. return false
  351. }
  352. // Index 0 is the "Recent" header, so recent models are at indices 1 to len(recentModels)
  353. if index >= 1 && index <= len(recentModels) {
  354. if index-1 < len(recentModels) {
  355. recentModel := recentModels[index-1]
  356. return recentModel.Provider.ID == model.Provider.ID &&
  357. recentModel.Model.ID == model.Model.ID
  358. }
  359. }
  360. return false
  361. }
  362. func (m *modelDialog) Render(background string) string {
  363. return m.modal.Render(m.View(), background)
  364. }
  365. func (s *modelDialog) Close() tea.Cmd {
  366. return nil
  367. }
  368. func NewModelDialog(app *app.App) ModelDialog {
  369. dialog := &modelDialog{
  370. app: app,
  371. }
  372. dialog.setupAllModels()
  373. dialog.modal = modal.New(
  374. modal.WithTitle("Select Model"),
  375. modal.WithMaxWidth(dialog.dialogWidth+4),
  376. )
  377. return dialog
  378. }