2
0

models.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. package cmd
  2. import (
  3. "fmt"
  4. "os"
  5. "slices"
  6. "sort"
  7. "strings"
  8. "charm.land/lipgloss/v2/tree"
  9. "github.com/charmbracelet/catwalk/pkg/catwalk"
  10. "github.com/charmbracelet/crush/internal/config"
  11. "github.com/mattn/go-isatty"
  12. "github.com/spf13/cobra"
  13. )
  14. var modelsCmd = &cobra.Command{
  15. Use: "models",
  16. Short: "List all available models from configured providers",
  17. Long: `List all available models from configured providers. Shows provider name and model IDs.`,
  18. Example: `# List all available models
  19. crush models
  20. # Search models
  21. crush models gpt5`,
  22. Args: cobra.ArbitraryArgs,
  23. RunE: func(cmd *cobra.Command, args []string) error {
  24. cwd, err := ResolveCwd(cmd)
  25. if err != nil {
  26. return err
  27. }
  28. dataDir, _ := cmd.Flags().GetString("data-dir")
  29. debug, _ := cmd.Flags().GetBool("debug")
  30. cfg, err := config.Init(cwd, dataDir, debug)
  31. if err != nil {
  32. return err
  33. }
  34. if !cfg.IsConfigured() {
  35. return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
  36. }
  37. term := strings.ToLower(strings.Join(args, " "))
  38. filter := func(p config.ProviderConfig, m catwalk.Model) bool {
  39. for _, s := range []string{p.ID, p.Name, m.ID, m.Name} {
  40. if term == "" || strings.Contains(strings.ToLower(s), term) {
  41. return true
  42. }
  43. }
  44. return false
  45. }
  46. var providerIDs []string
  47. providerModels := make(map[string][]string)
  48. for providerID, provider := range cfg.Providers.Seq2() {
  49. if provider.Disable {
  50. continue
  51. }
  52. var found bool
  53. for _, model := range provider.Models {
  54. if !filter(provider, model) {
  55. continue
  56. }
  57. providerModels[providerID] = append(providerModels[providerID], model.ID)
  58. found = true
  59. }
  60. if !found {
  61. continue
  62. }
  63. slices.Sort(providerModels[providerID])
  64. providerIDs = append(providerIDs, providerID)
  65. }
  66. sort.Strings(providerIDs)
  67. if len(providerIDs) == 0 && len(args) == 0 {
  68. return fmt.Errorf("no enabled providers found")
  69. }
  70. if len(providerIDs) == 0 {
  71. return fmt.Errorf("no enabled providers found matching %q", term)
  72. }
  73. if !isatty.IsTerminal(os.Stdout.Fd()) {
  74. for _, providerID := range providerIDs {
  75. for _, modelID := range providerModels[providerID] {
  76. fmt.Println(providerID + "/" + modelID)
  77. }
  78. }
  79. return nil
  80. }
  81. t := tree.New()
  82. for _, providerID := range providerIDs {
  83. providerNode := tree.Root(providerID)
  84. for _, modelID := range providerModels[providerID] {
  85. providerNode.Child(modelID)
  86. }
  87. t.Child(providerNode)
  88. }
  89. cmd.Println(t)
  90. return nil
  91. },
  92. }
  93. func init() {
  94. rootCmd.AddCommand(modelsCmd)
  95. }