list.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. package models
  2. import (
  3. "cmp"
  4. "fmt"
  5. "slices"
  6. "strings"
  7. tea "github.com/charmbracelet/bubbletea/v2"
  8. "github.com/charmbracelet/catwalk/pkg/catwalk"
  9. "github.com/charmbracelet/crush/internal/config"
  10. "github.com/charmbracelet/crush/internal/tui/exp/list"
  11. "github.com/charmbracelet/crush/internal/tui/styles"
  12. "github.com/charmbracelet/crush/internal/tui/util"
  13. )
  14. type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
  15. type ModelListComponent struct {
  16. list listModel
  17. modelType int
  18. providers []catwalk.Provider
  19. }
  20. func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
  21. t := styles.CurrentTheme()
  22. inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
  23. options := []list.ListOption{
  24. list.WithKeyMap(keyMap),
  25. list.WithWrapNavigation(),
  26. }
  27. if shouldResize {
  28. options = append(options, list.WithResizeByList())
  29. }
  30. modelList := list.NewFilterableGroupedList(
  31. []list.Group[list.CompletionItem[ModelOption]]{},
  32. list.WithFilterInputStyle(inputStyle),
  33. list.WithFilterPlaceholder(inputPlaceholder),
  34. list.WithFilterListOptions(
  35. options...,
  36. ),
  37. )
  38. return &ModelListComponent{
  39. list: modelList,
  40. modelType: LargeModelType,
  41. }
  42. }
  43. func (m *ModelListComponent) Init() tea.Cmd {
  44. var cmds []tea.Cmd
  45. if len(m.providers) == 0 {
  46. cfg := config.Get()
  47. providers, err := config.Providers(cfg)
  48. filteredProviders := []catwalk.Provider{}
  49. for _, p := range providers {
  50. hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
  51. if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
  52. filteredProviders = append(filteredProviders, p)
  53. }
  54. }
  55. m.providers = filteredProviders
  56. if err != nil {
  57. cmds = append(cmds, util.ReportError(err))
  58. }
  59. }
  60. cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
  61. return tea.Batch(cmds...)
  62. }
  63. func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
  64. u, cmd := m.list.Update(msg)
  65. m.list = u.(listModel)
  66. return m, cmd
  67. }
  68. func (m *ModelListComponent) View() string {
  69. return m.list.View()
  70. }
  71. func (m *ModelListComponent) Cursor() *tea.Cursor {
  72. return m.list.Cursor()
  73. }
  74. func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
  75. return m.list.SetSize(width, height)
  76. }
  77. func (m *ModelListComponent) SelectedModel() *ModelOption {
  78. s := m.list.SelectedItem()
  79. if s == nil {
  80. return nil
  81. }
  82. sv := *s
  83. model := sv.Value()
  84. return &model
  85. }
  86. func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
  87. t := styles.CurrentTheme()
  88. m.modelType = modelType
  89. var groups []list.Group[list.CompletionItem[ModelOption]]
  90. // first none section
  91. selectedItemID := ""
  92. cfg := config.Get()
  93. var currentModel config.SelectedModel
  94. if m.modelType == LargeModelType {
  95. currentModel = cfg.Models[config.SelectedModelTypeLarge]
  96. } else {
  97. currentModel = cfg.Models[config.SelectedModelTypeSmall]
  98. }
  99. configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
  100. configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
  101. // Create a map to track which providers we've already added
  102. addedProviders := make(map[string]bool)
  103. // First, add any configured providers that are not in the known providers list
  104. // These should appear at the top of the list
  105. knownProviders, err := config.Providers(cfg)
  106. if err != nil {
  107. return util.ReportError(err)
  108. }
  109. for providerID, providerConfig := range cfg.Providers.Seq2() {
  110. if providerConfig.Disable {
  111. continue
  112. }
  113. // Check if this provider is not in the known providers list
  114. if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
  115. !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
  116. // Convert config provider to provider.Provider format
  117. configProvider := catwalk.Provider{
  118. Name: providerConfig.Name,
  119. ID: catwalk.InferenceProvider(providerID),
  120. Models: make([]catwalk.Model, len(providerConfig.Models)),
  121. }
  122. // Convert models
  123. for i, model := range providerConfig.Models {
  124. configProvider.Models[i] = catwalk.Model{
  125. ID: model.ID,
  126. Name: model.Name,
  127. CostPer1MIn: model.CostPer1MIn,
  128. CostPer1MOut: model.CostPer1MOut,
  129. CostPer1MInCached: model.CostPer1MInCached,
  130. CostPer1MOutCached: model.CostPer1MOutCached,
  131. ContextWindow: model.ContextWindow,
  132. DefaultMaxTokens: model.DefaultMaxTokens,
  133. CanReason: model.CanReason,
  134. ReasoningLevels: model.ReasoningLevels,
  135. DefaultReasoningEffort: model.DefaultReasoningEffort,
  136. SupportsImages: model.SupportsImages,
  137. }
  138. }
  139. // Add this unknown provider to the list
  140. name := configProvider.Name
  141. if name == "" {
  142. name = string(configProvider.ID)
  143. }
  144. section := list.NewItemSection(name)
  145. section.SetInfo(configured)
  146. group := list.Group[list.CompletionItem[ModelOption]]{
  147. Section: section,
  148. }
  149. for _, model := range configProvider.Models {
  150. item := list.NewCompletionItem(model.Name, ModelOption{
  151. Provider: configProvider,
  152. Model: model,
  153. },
  154. list.WithCompletionID(
  155. fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
  156. ),
  157. )
  158. group.Items = append(group.Items, item)
  159. if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
  160. selectedItemID = item.ID()
  161. }
  162. }
  163. groups = append(groups, group)
  164. addedProviders[providerID] = true
  165. }
  166. }
  167. // Then add the known providers from the predefined list
  168. for _, provider := range m.providers {
  169. // Skip if we already added this provider as an unknown provider
  170. if addedProviders[string(provider.ID)] {
  171. continue
  172. }
  173. providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
  174. if providerConfigured && providerConfig.Disable {
  175. continue
  176. }
  177. displayProvider := provider
  178. if providerConfigured {
  179. displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
  180. modelIndex := make(map[string]int, len(displayProvider.Models))
  181. for i, model := range displayProvider.Models {
  182. modelIndex[model.ID] = i
  183. }
  184. for _, model := range providerConfig.Models {
  185. if model.ID == "" {
  186. continue
  187. }
  188. if idx, ok := modelIndex[model.ID]; ok {
  189. if model.Name != "" {
  190. displayProvider.Models[idx].Name = model.Name
  191. }
  192. continue
  193. }
  194. if model.Name == "" {
  195. model.Name = model.ID
  196. }
  197. displayProvider.Models = append(displayProvider.Models, model)
  198. modelIndex[model.ID] = len(displayProvider.Models) - 1
  199. }
  200. }
  201. name := displayProvider.Name
  202. if name == "" {
  203. name = string(displayProvider.ID)
  204. }
  205. section := list.NewItemSection(name)
  206. if providerConfigured {
  207. section.SetInfo(configured)
  208. }
  209. group := list.Group[list.CompletionItem[ModelOption]]{
  210. Section: section,
  211. }
  212. for _, model := range displayProvider.Models {
  213. item := list.NewCompletionItem(model.Name, ModelOption{
  214. Provider: displayProvider,
  215. Model: model,
  216. },
  217. list.WithCompletionID(
  218. fmt.Sprintf("%s:%s", displayProvider.ID, model.ID),
  219. ),
  220. )
  221. group.Items = append(group.Items, item)
  222. if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
  223. selectedItemID = item.ID()
  224. }
  225. }
  226. groups = append(groups, group)
  227. }
  228. var cmds []tea.Cmd
  229. cmd := m.list.SetGroups(groups)
  230. if cmd != nil {
  231. cmds = append(cmds, cmd)
  232. }
  233. cmd = m.list.SetSelected(selectedItemID)
  234. if cmd != nil {
  235. cmds = append(cmds, cmd)
  236. }
  237. return tea.Sequence(cmds...)
  238. }
  239. // GetModelType returns the current model type
  240. func (m *ModelListComponent) GetModelType() int {
  241. return m.modelType
  242. }
  243. func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
  244. m.list.SetInputPlaceholder(placeholder)
  245. }