models.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. package models
  2. import (
  3. "fmt"
  4. "slices"
  5. "time"
  6. "github.com/charmbracelet/bubbles/v2/help"
  7. "github.com/charmbracelet/bubbles/v2/key"
  8. "github.com/charmbracelet/bubbles/v2/spinner"
  9. tea "github.com/charmbracelet/bubbletea/v2"
  10. "github.com/charmbracelet/catwalk/pkg/catwalk"
  11. "github.com/charmbracelet/crush/internal/config"
  12. "github.com/charmbracelet/crush/internal/tui/components/core"
  13. "github.com/charmbracelet/crush/internal/tui/components/dialogs"
  14. "github.com/charmbracelet/crush/internal/tui/exp/list"
  15. "github.com/charmbracelet/crush/internal/tui/styles"
  16. "github.com/charmbracelet/crush/internal/tui/util"
  17. "github.com/charmbracelet/lipgloss/v2"
  18. )
  19. const (
  20. ModelsDialogID dialogs.DialogID = "models"
  21. defaultWidth = 60
  22. )
  23. const (
  24. LargeModelType int = iota
  25. SmallModelType
  26. largeModelInputPlaceholder = "Choose a model for large, complex tasks"
  27. smallModelInputPlaceholder = "Choose a model for small, simple tasks"
  28. )
  29. // ModelSelectedMsg is sent when a model is selected
  30. type ModelSelectedMsg struct {
  31. Model config.SelectedModel
  32. ModelType config.SelectedModelType
  33. }
  34. // CloseModelDialogMsg is sent when a model is selected
  35. type CloseModelDialogMsg struct{}
  36. // ModelDialog interface for the model selection dialog
  37. type ModelDialog interface {
  38. dialogs.DialogModel
  39. }
  40. type ModelOption struct {
  41. Provider catwalk.Provider
  42. Model catwalk.Model
  43. }
  44. type modelDialogCmp struct {
  45. width int
  46. wWidth int
  47. wHeight int
  48. modelList *ModelListComponent
  49. keyMap KeyMap
  50. help help.Model
  51. // API key state
  52. needsAPIKey bool
  53. apiKeyInput *APIKeyInput
  54. selectedModel *ModelOption
  55. selectedModelType config.SelectedModelType
  56. isAPIKeyValid bool
  57. apiKeyValue string
  58. }
  59. func NewModelDialogCmp() ModelDialog {
  60. keyMap := DefaultKeyMap()
  61. listKeyMap := list.DefaultKeyMap()
  62. listKeyMap.Down.SetEnabled(false)
  63. listKeyMap.Up.SetEnabled(false)
  64. listKeyMap.DownOneItem = keyMap.Next
  65. listKeyMap.UpOneItem = keyMap.Previous
  66. t := styles.CurrentTheme()
  67. modelList := NewModelListComponent(listKeyMap, "Choose a model for large, complex tasks", true)
  68. apiKeyInput := NewAPIKeyInput()
  69. apiKeyInput.SetShowTitle(false)
  70. help := help.New()
  71. help.Styles = t.S().Help
  72. return &modelDialogCmp{
  73. modelList: modelList,
  74. apiKeyInput: apiKeyInput,
  75. width: defaultWidth,
  76. keyMap: DefaultKeyMap(),
  77. help: help,
  78. }
  79. }
  80. func (m *modelDialogCmp) Init() tea.Cmd {
  81. providers, err := config.Providers()
  82. if err == nil {
  83. filteredProviders := []catwalk.Provider{}
  84. simpleProviders := []string{
  85. "anthropic",
  86. "openai",
  87. "gemini",
  88. "xai",
  89. "groq",
  90. "openrouter",
  91. }
  92. for _, p := range providers {
  93. if slices.Contains(simpleProviders, string(p.ID)) {
  94. filteredProviders = append(filteredProviders, p)
  95. }
  96. }
  97. m.modelList.SetProviders(filteredProviders)
  98. }
  99. return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
  100. }
  101. func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
  102. switch msg := msg.(type) {
  103. case tea.WindowSizeMsg:
  104. m.wWidth = msg.Width
  105. m.wHeight = msg.Height
  106. m.apiKeyInput.SetWidth(m.width - 2)
  107. m.help.Width = m.width - 2
  108. return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
  109. case APIKeyStateChangeMsg:
  110. u, cmd := m.apiKeyInput.Update(msg)
  111. m.apiKeyInput = u.(*APIKeyInput)
  112. return m, cmd
  113. case tea.KeyPressMsg:
  114. switch {
  115. case key.Matches(msg, m.keyMap.Select):
  116. if m.isAPIKeyValid {
  117. return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
  118. }
  119. if m.needsAPIKey {
  120. // Handle API key submission
  121. m.apiKeyValue = m.apiKeyInput.Value()
  122. provider, err := m.getProvider(m.selectedModel.Provider.ID)
  123. if err != nil || provider == nil {
  124. return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
  125. }
  126. providerConfig := config.ProviderConfig{
  127. ID: string(m.selectedModel.Provider.ID),
  128. Name: m.selectedModel.Provider.Name,
  129. APIKey: m.apiKeyValue,
  130. Type: provider.Type,
  131. BaseURL: provider.APIEndpoint,
  132. }
  133. return m, tea.Sequence(
  134. util.CmdHandler(APIKeyStateChangeMsg{
  135. State: APIKeyInputStateVerifying,
  136. }),
  137. func() tea.Msg {
  138. start := time.Now()
  139. err := providerConfig.TestConnection(config.Get().Resolver())
  140. // intentionally wait for at least 750ms to make sure the user sees the spinner
  141. elapsed := time.Since(start)
  142. if elapsed < 750*time.Millisecond {
  143. time.Sleep(750*time.Millisecond - elapsed)
  144. }
  145. if err == nil {
  146. m.isAPIKeyValid = true
  147. return APIKeyStateChangeMsg{
  148. State: APIKeyInputStateVerified,
  149. }
  150. }
  151. return APIKeyStateChangeMsg{
  152. State: APIKeyInputStateError,
  153. }
  154. },
  155. )
  156. }
  157. // Normal model selection
  158. selectedItem := m.modelList.SelectedModel()
  159. var modelType config.SelectedModelType
  160. if m.modelList.GetModelType() == LargeModelType {
  161. modelType = config.SelectedModelTypeLarge
  162. } else {
  163. modelType = config.SelectedModelTypeSmall
  164. }
  165. // Check if provider is configured
  166. if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
  167. return m, tea.Sequence(
  168. util.CmdHandler(dialogs.CloseDialogMsg{}),
  169. util.CmdHandler(ModelSelectedMsg{
  170. Model: config.SelectedModel{
  171. Model: selectedItem.Model.ID,
  172. Provider: string(selectedItem.Provider.ID),
  173. },
  174. ModelType: modelType,
  175. }),
  176. )
  177. } else {
  178. // Provider not configured, show API key input
  179. m.needsAPIKey = true
  180. m.selectedModel = selectedItem
  181. m.selectedModelType = modelType
  182. m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
  183. return m, nil
  184. }
  185. case key.Matches(msg, m.keyMap.Tab):
  186. if m.needsAPIKey {
  187. u, cmd := m.apiKeyInput.Update(msg)
  188. m.apiKeyInput = u.(*APIKeyInput)
  189. return m, cmd
  190. }
  191. if m.modelList.GetModelType() == LargeModelType {
  192. m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
  193. return m, m.modelList.SetModelType(SmallModelType)
  194. } else {
  195. m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
  196. return m, m.modelList.SetModelType(LargeModelType)
  197. }
  198. case key.Matches(msg, m.keyMap.Close):
  199. if m.needsAPIKey {
  200. if m.isAPIKeyValid {
  201. return m, nil
  202. }
  203. // Go back to model selection
  204. m.needsAPIKey = false
  205. m.selectedModel = nil
  206. m.isAPIKeyValid = false
  207. m.apiKeyValue = ""
  208. m.apiKeyInput.Reset()
  209. return m, nil
  210. }
  211. return m, util.CmdHandler(dialogs.CloseDialogMsg{})
  212. default:
  213. if m.needsAPIKey {
  214. u, cmd := m.apiKeyInput.Update(msg)
  215. m.apiKeyInput = u.(*APIKeyInput)
  216. return m, cmd
  217. } else {
  218. u, cmd := m.modelList.Update(msg)
  219. m.modelList = u
  220. return m, cmd
  221. }
  222. }
  223. case tea.PasteMsg:
  224. if m.needsAPIKey {
  225. u, cmd := m.apiKeyInput.Update(msg)
  226. m.apiKeyInput = u.(*APIKeyInput)
  227. return m, cmd
  228. } else {
  229. var cmd tea.Cmd
  230. m.modelList, cmd = m.modelList.Update(msg)
  231. return m, cmd
  232. }
  233. case spinner.TickMsg:
  234. u, cmd := m.apiKeyInput.Update(msg)
  235. m.apiKeyInput = u.(*APIKeyInput)
  236. return m, cmd
  237. }
  238. return m, nil
  239. }
  240. func (m *modelDialogCmp) View() string {
  241. t := styles.CurrentTheme()
  242. if m.needsAPIKey {
  243. // Show API key input
  244. m.keyMap.isAPIKeyHelp = true
  245. m.keyMap.isAPIKeyValid = m.isAPIKeyValid
  246. apiKeyView := m.apiKeyInput.View()
  247. apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
  248. content := lipgloss.JoinVertical(
  249. lipgloss.Left,
  250. t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
  251. apiKeyView,
  252. "",
  253. t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
  254. )
  255. return m.style().Render(content)
  256. }
  257. // Show model selection
  258. listView := m.modelList.View()
  259. radio := m.modelTypeRadio()
  260. content := lipgloss.JoinVertical(
  261. lipgloss.Left,
  262. t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
  263. listView,
  264. "",
  265. t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
  266. )
  267. return m.style().Render(content)
  268. }
  269. func (m *modelDialogCmp) Cursor() *tea.Cursor {
  270. if m.needsAPIKey {
  271. cursor := m.apiKeyInput.Cursor()
  272. if cursor != nil {
  273. cursor = m.moveCursor(cursor)
  274. return cursor
  275. }
  276. } else {
  277. cursor := m.modelList.Cursor()
  278. if cursor != nil {
  279. cursor = m.moveCursor(cursor)
  280. return cursor
  281. }
  282. }
  283. return nil
  284. }
  285. func (m *modelDialogCmp) style() lipgloss.Style {
  286. t := styles.CurrentTheme()
  287. return t.S().Base.
  288. Width(m.width).
  289. Border(lipgloss.RoundedBorder()).
  290. BorderForeground(t.BorderFocus)
  291. }
  292. func (m *modelDialogCmp) listWidth() int {
  293. return m.width - 2
  294. }
  295. func (m *modelDialogCmp) listHeight() int {
  296. return m.wHeight / 2
  297. }
  298. func (m *modelDialogCmp) Position() (int, int) {
  299. row := m.wHeight/4 - 2 // just a bit above the center
  300. col := m.wWidth / 2
  301. col -= m.width / 2
  302. return row, col
  303. }
  304. func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
  305. row, col := m.Position()
  306. if m.needsAPIKey {
  307. offset := row + 3 // Border + title + API key input offset
  308. cursor.Y += offset
  309. cursor.X = cursor.X + col + 2
  310. } else {
  311. offset := row + 3 // Border + title
  312. cursor.Y += offset
  313. cursor.X = cursor.X + col + 2
  314. }
  315. return cursor
  316. }
  317. func (m *modelDialogCmp) ID() dialogs.DialogID {
  318. return ModelsDialogID
  319. }
  320. func (m *modelDialogCmp) modelTypeRadio() string {
  321. t := styles.CurrentTheme()
  322. choices := []string{"Large Task", "Small Task"}
  323. iconSelected := "◉"
  324. iconUnselected := "○"
  325. if m.modelList.GetModelType() == LargeModelType {
  326. return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
  327. }
  328. return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
  329. }
  330. func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
  331. cfg := config.Get()
  332. if _, ok := cfg.Providers.Get(providerID); ok {
  333. return true
  334. }
  335. return false
  336. }
  337. func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
  338. providers, err := config.Providers()
  339. if err != nil {
  340. return nil, err
  341. }
  342. for _, p := range providers {
  343. if p.ID == providerID {
  344. return &p, nil
  345. }
  346. }
  347. return nil, nil
  348. }
  349. func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
  350. if m.selectedModel == nil {
  351. return util.ReportError(fmt.Errorf("no model selected"))
  352. }
  353. cfg := config.Get()
  354. err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
  355. if err != nil {
  356. return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
  357. }
  358. // Reset API key state and continue with model selection
  359. selectedModel := *m.selectedModel
  360. return tea.Sequence(
  361. util.CmdHandler(dialogs.CloseDialogMsg{}),
  362. util.CmdHandler(ModelSelectedMsg{
  363. Model: config.SelectedModel{
  364. Model: selectedModel.Model.ID,
  365. Provider: string(selectedModel.Provider.ID),
  366. },
  367. ModelType: m.selectedModelType,
  368. }),
  369. )
  370. }