api_key_input.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. package dialog
  2. import (
  3. "fmt"
  4. "strings"
  5. "time"
  6. "charm.land/bubbles/v2/help"
  7. "charm.land/bubbles/v2/key"
  8. "charm.land/bubbles/v2/spinner"
  9. "charm.land/bubbles/v2/textinput"
  10. tea "charm.land/bubbletea/v2"
  11. "charm.land/catwalk/pkg/catwalk"
  12. "github.com/charmbracelet/crush/internal/config"
  13. "github.com/charmbracelet/crush/internal/ui/common"
  14. "github.com/charmbracelet/crush/internal/ui/styles"
  15. "github.com/charmbracelet/crush/internal/ui/util"
  16. uv "github.com/charmbracelet/ultraviolet"
  17. "github.com/charmbracelet/x/exp/charmtone"
  18. )
  19. type APIKeyInputState int
  20. const (
  21. APIKeyInputStateInitial APIKeyInputState = iota
  22. APIKeyInputStateVerifying
  23. APIKeyInputStateVerified
  24. APIKeyInputStateError
  25. )
  26. // APIKeyInputID is the identifier for the model selection dialog.
  27. const APIKeyInputID = "api_key_input"
  28. // APIKeyInput represents a model selection dialog.
  29. type APIKeyInput struct {
  30. com *common.Common
  31. isOnboarding bool
  32. provider catwalk.Provider
  33. model config.SelectedModel
  34. modelType config.SelectedModelType
  35. width int
  36. state APIKeyInputState
  37. keyMap struct {
  38. Submit key.Binding
  39. Close key.Binding
  40. }
  41. input textinput.Model
  42. spinner spinner.Model
  43. help help.Model
  44. }
  45. var _ Dialog = (*APIKeyInput)(nil)
  46. // NewAPIKeyInput creates a new Models dialog.
  47. func NewAPIKeyInput(
  48. com *common.Common,
  49. isOnboarding bool,
  50. provider catwalk.Provider,
  51. model config.SelectedModel,
  52. modelType config.SelectedModelType,
  53. ) (*APIKeyInput, tea.Cmd) {
  54. t := com.Styles
  55. m := APIKeyInput{}
  56. m.com = com
  57. m.isOnboarding = isOnboarding
  58. m.provider = provider
  59. m.model = model
  60. m.modelType = modelType
  61. m.width = 60
  62. innerWidth := m.width - t.Dialog.View.GetHorizontalFrameSize() - 2
  63. m.input = textinput.New()
  64. m.input.SetVirtualCursor(false)
  65. m.input.Placeholder = "Enter your API key..."
  66. m.input.SetStyles(com.Styles.TextInput)
  67. m.input.Focus()
  68. m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
  69. m.spinner = spinner.New(
  70. spinner.WithSpinner(spinner.Dot),
  71. spinner.WithStyle(t.Base.Foreground(t.Green)),
  72. )
  73. m.help = help.New()
  74. m.help.Styles = t.DialogHelpStyles()
  75. m.keyMap.Submit = key.NewBinding(
  76. key.WithKeys("enter", "ctrl+y"),
  77. key.WithHelp("enter", "submit"),
  78. )
  79. m.keyMap.Close = CloseKey
  80. return &m, nil
  81. }
  82. // ID implements Dialog.
  83. func (m *APIKeyInput) ID() string {
  84. return APIKeyInputID
  85. }
  86. // HandleMsg implements [Dialog].
  87. func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action {
  88. switch msg := msg.(type) {
  89. case ActionChangeAPIKeyState:
  90. m.state = msg.State
  91. switch m.state {
  92. case APIKeyInputStateVerifying:
  93. cmd := tea.Batch(m.spinner.Tick, m.verifyAPIKey)
  94. return ActionCmd{cmd}
  95. }
  96. case spinner.TickMsg:
  97. switch m.state {
  98. case APIKeyInputStateVerifying:
  99. var cmd tea.Cmd
  100. m.spinner, cmd = m.spinner.Update(msg)
  101. if cmd != nil {
  102. return ActionCmd{cmd}
  103. }
  104. }
  105. case tea.KeyPressMsg:
  106. switch {
  107. case m.state == APIKeyInputStateVerifying:
  108. // do nothing
  109. case key.Matches(msg, m.keyMap.Close):
  110. switch m.state {
  111. case APIKeyInputStateVerified:
  112. return m.saveKeyAndContinue()
  113. default:
  114. return ActionClose{}
  115. }
  116. case key.Matches(msg, m.keyMap.Submit):
  117. switch m.state {
  118. case APIKeyInputStateInitial, APIKeyInputStateError:
  119. return ActionChangeAPIKeyState{State: APIKeyInputStateVerifying}
  120. case APIKeyInputStateVerified:
  121. return m.saveKeyAndContinue()
  122. }
  123. default:
  124. var cmd tea.Cmd
  125. m.input, cmd = m.input.Update(msg)
  126. if cmd != nil {
  127. return ActionCmd{cmd}
  128. }
  129. }
  130. case tea.PasteMsg:
  131. var cmd tea.Cmd
  132. m.input, cmd = m.input.Update(msg)
  133. if cmd != nil {
  134. return ActionCmd{cmd}
  135. }
  136. }
  137. return nil
  138. }
  139. // Draw implements [Dialog].
  140. func (m *APIKeyInput) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
  141. t := m.com.Styles
  142. textStyle := t.Dialog.SecondaryText
  143. helpStyle := t.Dialog.HelpView
  144. dialogStyle := t.Dialog.View.Width(m.width)
  145. inputStyle := t.Dialog.InputPrompt
  146. helpStyle = helpStyle.Width(m.width - dialogStyle.GetHorizontalFrameSize())
  147. m.input.Prompt = m.spinner.View()
  148. content := strings.Join([]string{
  149. m.headerView(),
  150. inputStyle.Render(m.inputView()),
  151. textStyle.Render("This will be written in your global configuration:"),
  152. textStyle.Render(config.GlobalConfigData()),
  153. "",
  154. helpStyle.Render(m.help.View(m)),
  155. }, "\n")
  156. cur := m.Cursor()
  157. if m.isOnboarding {
  158. view := content
  159. DrawOnboardingCursor(scr, area, view, cur)
  160. // FIXME(@andreynering): Figure it out how to properly fix this
  161. if cur != nil {
  162. cur.Y -= 1
  163. cur.X -= 1
  164. }
  165. } else {
  166. view := dialogStyle.Render(content)
  167. DrawCenterCursor(scr, area, view, cur)
  168. }
  169. return cur
  170. }
  171. func (m *APIKeyInput) headerView() string {
  172. var (
  173. t = m.com.Styles
  174. titleStyle = t.Dialog.Title
  175. textStyle = t.Dialog.PrimaryText
  176. dialogStyle = t.Dialog.View.Width(m.width)
  177. )
  178. if m.isOnboarding {
  179. return textStyle.Render(m.dialogTitle())
  180. }
  181. headerOffset := titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
  182. return common.DialogTitle(t, titleStyle.Render(m.dialogTitle()), m.width-headerOffset, m.com.Styles.Primary, m.com.Styles.Secondary)
  183. }
  184. func (m *APIKeyInput) dialogTitle() string {
  185. var (
  186. t = m.com.Styles
  187. textStyle = t.Dialog.TitleText
  188. errorStyle = t.Dialog.TitleError
  189. accentStyle = t.Dialog.TitleAccent
  190. )
  191. switch m.state {
  192. case APIKeyInputStateInitial:
  193. return textStyle.Render("Enter your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(".")
  194. case APIKeyInputStateVerifying:
  195. return textStyle.Render("Verifying your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render("...")
  196. case APIKeyInputStateVerified:
  197. return accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(" validated.")
  198. case APIKeyInputStateError:
  199. return errorStyle.Render("Invalid ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + errorStyle.Render(". Try again?")
  200. }
  201. return ""
  202. }
  203. func (m *APIKeyInput) inputView() string {
  204. t := m.com.Styles
  205. switch m.state {
  206. case APIKeyInputStateInitial:
  207. m.input.Prompt = "> "
  208. m.input.SetStyles(t.TextInput)
  209. m.input.Focus()
  210. case APIKeyInputStateVerifying:
  211. ts := t.TextInput
  212. ts.Blurred.Prompt = ts.Focused.Prompt
  213. m.input.Prompt = m.spinner.View()
  214. m.input.SetStyles(ts)
  215. m.input.Blur()
  216. case APIKeyInputStateVerified:
  217. ts := t.TextInput
  218. ts.Blurred.Prompt = ts.Focused.Prompt
  219. m.input.Prompt = styles.CheckIcon + " "
  220. m.input.SetStyles(ts)
  221. m.input.Blur()
  222. case APIKeyInputStateError:
  223. ts := t.TextInput
  224. ts.Focused.Prompt = ts.Focused.Prompt.Foreground(charmtone.Cherry)
  225. m.input.Prompt = styles.LSPErrorIcon + " "
  226. m.input.SetStyles(ts)
  227. m.input.Focus()
  228. }
  229. return m.input.View()
  230. }
  231. // Cursor returns the cursor position relative to the dialog.
  232. func (m *APIKeyInput) Cursor() *tea.Cursor {
  233. return InputCursor(m.com.Styles, m.input.Cursor())
  234. }
  235. // FullHelp returns the full help view.
  236. func (m *APIKeyInput) FullHelp() [][]key.Binding {
  237. return [][]key.Binding{
  238. {
  239. m.keyMap.Submit,
  240. m.keyMap.Close,
  241. },
  242. }
  243. }
  244. // ShortHelp returns the full help view.
  245. func (m *APIKeyInput) ShortHelp() []key.Binding {
  246. return []key.Binding{
  247. m.keyMap.Submit,
  248. m.keyMap.Close,
  249. }
  250. }
  251. func (m *APIKeyInput) verifyAPIKey() tea.Msg {
  252. start := time.Now()
  253. providerConfig := config.ProviderConfig{
  254. ID: string(m.provider.ID),
  255. Name: m.provider.Name,
  256. APIKey: m.input.Value(),
  257. Type: m.provider.Type,
  258. BaseURL: m.provider.APIEndpoint,
  259. }
  260. err := providerConfig.TestConnection(m.com.Store().Resolver())
  261. // intentionally wait for at least 750ms to make sure the user sees the spinner
  262. elapsed := time.Since(start)
  263. minimum := 750 * time.Millisecond
  264. if elapsed < minimum {
  265. time.Sleep(minimum - elapsed)
  266. }
  267. if err == nil {
  268. return ActionChangeAPIKeyState{APIKeyInputStateVerified}
  269. }
  270. return ActionChangeAPIKeyState{APIKeyInputStateError}
  271. }
  272. func (m *APIKeyInput) saveKeyAndContinue() Action {
  273. store := m.com.Store()
  274. err := store.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.input.Value())
  275. if err != nil {
  276. return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))}
  277. }
  278. return ActionSelectModel{
  279. Provider: m.provider,
  280. Model: m.model,
  281. ModelType: m.modelType,
  282. }
  283. }