| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280 |
- package models
- import (
- "cmp"
- "fmt"
- "slices"
- "strings"
- tea "github.com/charmbracelet/bubbletea/v2"
- "github.com/charmbracelet/catwalk/pkg/catwalk"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/tui/exp/list"
- "github.com/charmbracelet/crush/internal/tui/styles"
- "github.com/charmbracelet/crush/internal/tui/util"
- )
- type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
- type ModelListComponent struct {
- list listModel
- modelType int
- providers []catwalk.Provider
- }
- func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
- t := styles.CurrentTheme()
- inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
- options := []list.ListOption{
- list.WithKeyMap(keyMap),
- list.WithWrapNavigation(),
- }
- if shouldResize {
- options = append(options, list.WithResizeByList())
- }
- modelList := list.NewFilterableGroupedList(
- []list.Group[list.CompletionItem[ModelOption]]{},
- list.WithFilterInputStyle(inputStyle),
- list.WithFilterPlaceholder(inputPlaceholder),
- list.WithFilterListOptions(
- options...,
- ),
- )
- return &ModelListComponent{
- list: modelList,
- modelType: LargeModelType,
- }
- }
- func (m *ModelListComponent) Init() tea.Cmd {
- var cmds []tea.Cmd
- if len(m.providers) == 0 {
- cfg := config.Get()
- providers, err := config.Providers(cfg)
- filteredProviders := []catwalk.Provider{}
- for _, p := range providers {
- hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
- if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
- filteredProviders = append(filteredProviders, p)
- }
- }
- m.providers = filteredProviders
- if err != nil {
- cmds = append(cmds, util.ReportError(err))
- }
- }
- cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
- return tea.Batch(cmds...)
- }
- func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
- u, cmd := m.list.Update(msg)
- m.list = u.(listModel)
- return m, cmd
- }
- func (m *ModelListComponent) View() string {
- return m.list.View()
- }
- func (m *ModelListComponent) Cursor() *tea.Cursor {
- return m.list.Cursor()
- }
- func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
- return m.list.SetSize(width, height)
- }
- func (m *ModelListComponent) SelectedModel() *ModelOption {
- s := m.list.SelectedItem()
- if s == nil {
- return nil
- }
- sv := *s
- model := sv.Value()
- return &model
- }
- func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
- t := styles.CurrentTheme()
- m.modelType = modelType
- var groups []list.Group[list.CompletionItem[ModelOption]]
- // first none section
- selectedItemID := ""
- cfg := config.Get()
- var currentModel config.SelectedModel
- if m.modelType == LargeModelType {
- currentModel = cfg.Models[config.SelectedModelTypeLarge]
- } else {
- currentModel = cfg.Models[config.SelectedModelTypeSmall]
- }
- configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
- configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
- // Create a map to track which providers we've already added
- addedProviders := make(map[string]bool)
- // First, add any configured providers that are not in the known providers list
- // These should appear at the top of the list
- knownProviders, err := config.Providers(cfg)
- if err != nil {
- return util.ReportError(err)
- }
- for providerID, providerConfig := range cfg.Providers.Seq2() {
- if providerConfig.Disable {
- continue
- }
- // Check if this provider is not in the known providers list
- if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
- !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
- // Convert config provider to provider.Provider format
- configProvider := catwalk.Provider{
- Name: providerConfig.Name,
- ID: catwalk.InferenceProvider(providerID),
- Models: make([]catwalk.Model, len(providerConfig.Models)),
- }
- // Convert models
- for i, model := range providerConfig.Models {
- configProvider.Models[i] = catwalk.Model{
- ID: model.ID,
- Name: model.Name,
- CostPer1MIn: model.CostPer1MIn,
- CostPer1MOut: model.CostPer1MOut,
- CostPer1MInCached: model.CostPer1MInCached,
- CostPer1MOutCached: model.CostPer1MOutCached,
- ContextWindow: model.ContextWindow,
- DefaultMaxTokens: model.DefaultMaxTokens,
- CanReason: model.CanReason,
- ReasoningLevels: model.ReasoningLevels,
- DefaultReasoningEffort: model.DefaultReasoningEffort,
- SupportsImages: model.SupportsImages,
- }
- }
- // Add this unknown provider to the list
- name := configProvider.Name
- if name == "" {
- name = string(configProvider.ID)
- }
- section := list.NewItemSection(name)
- section.SetInfo(configured)
- group := list.Group[list.CompletionItem[ModelOption]]{
- Section: section,
- }
- for _, model := range configProvider.Models {
- item := list.NewCompletionItem(model.Name, ModelOption{
- Provider: configProvider,
- Model: model,
- },
- list.WithCompletionID(
- fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
- ),
- )
- group.Items = append(group.Items, item)
- if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
- selectedItemID = item.ID()
- }
- }
- groups = append(groups, group)
- addedProviders[providerID] = true
- }
- }
- // Then add the known providers from the predefined list
- for _, provider := range m.providers {
- // Skip if we already added this provider as an unknown provider
- if addedProviders[string(provider.ID)] {
- continue
- }
- providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
- if providerConfigured && providerConfig.Disable {
- continue
- }
- displayProvider := provider
- if providerConfigured {
- displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
- modelIndex := make(map[string]int, len(displayProvider.Models))
- for i, model := range displayProvider.Models {
- modelIndex[model.ID] = i
- }
- for _, model := range providerConfig.Models {
- if model.ID == "" {
- continue
- }
- if idx, ok := modelIndex[model.ID]; ok {
- if model.Name != "" {
- displayProvider.Models[idx].Name = model.Name
- }
- continue
- }
- if model.Name == "" {
- model.Name = model.ID
- }
- displayProvider.Models = append(displayProvider.Models, model)
- modelIndex[model.ID] = len(displayProvider.Models) - 1
- }
- }
- name := displayProvider.Name
- if name == "" {
- name = string(displayProvider.ID)
- }
- section := list.NewItemSection(name)
- if providerConfigured {
- section.SetInfo(configured)
- }
- group := list.Group[list.CompletionItem[ModelOption]]{
- Section: section,
- }
- for _, model := range displayProvider.Models {
- item := list.NewCompletionItem(model.Name, ModelOption{
- Provider: displayProvider,
- Model: model,
- },
- list.WithCompletionID(
- fmt.Sprintf("%s:%s", displayProvider.ID, model.ID),
- ),
- )
- group.Items = append(group.Items, item)
- if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
- selectedItemID = item.ID()
- }
- }
- groups = append(groups, group)
- }
- var cmds []tea.Cmd
- cmd := m.list.SetGroups(groups)
- if cmd != nil {
- cmds = append(cmds, cmd)
- }
- cmd = m.list.SetSelected(selectedItemID)
- if cmd != nil {
- cmds = append(cmds, cmd)
- }
- return tea.Sequence(cmds...)
- }
- // GetModelType returns the current model type
- func (m *ModelListComponent) GetModelType() int {
- return m.modelType
- }
- func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
- m.list.SetInputPlaceholder(placeholder)
- }
|