provider.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. package config
  2. import (
  3. "cmp"
  4. "encoding/json"
  5. "fmt"
  6. "log/slog"
  7. "os"
  8. "path/filepath"
  9. "runtime"
  10. "strings"
  11. "sync"
  12. "github.com/charmbracelet/catwalk/pkg/catwalk"
  13. "github.com/charmbracelet/catwalk/pkg/embedded"
  14. "github.com/charmbracelet/crush/internal/home"
  15. )
  16. type ProviderClient interface {
  17. GetProviders() ([]catwalk.Provider, error)
  18. }
  19. var (
  20. providerOnce sync.Once
  21. providerList []catwalk.Provider
  22. providerErr error
  23. )
  24. // file to cache provider data
  25. func providerCacheFileData() string {
  26. xdgDataHome := os.Getenv("XDG_DATA_HOME")
  27. if xdgDataHome != "" {
  28. return filepath.Join(xdgDataHome, appName, "providers.json")
  29. }
  30. // return the path to the main data directory
  31. // for windows, it should be in `%LOCALAPPDATA%/crush/`
  32. // for linux and macOS, it should be in `$HOME/.local/share/crush/`
  33. if runtime.GOOS == "windows" {
  34. localAppData := os.Getenv("LOCALAPPDATA")
  35. if localAppData == "" {
  36. localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
  37. }
  38. return filepath.Join(localAppData, appName, "providers.json")
  39. }
  40. return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json")
  41. }
  42. func saveProvidersInCache(path string, providers []catwalk.Provider) error {
  43. slog.Info("Saving provider data to disk", "path", path)
  44. if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
  45. return fmt.Errorf("failed to create directory for provider cache: %w", err)
  46. }
  47. data, err := json.MarshalIndent(providers, "", " ")
  48. if err != nil {
  49. return fmt.Errorf("failed to marshal provider data: %w", err)
  50. }
  51. if err := os.WriteFile(path, data, 0o644); err != nil {
  52. return fmt.Errorf("failed to write provider data to cache: %w", err)
  53. }
  54. return nil
  55. }
  56. func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
  57. data, err := os.ReadFile(path)
  58. if err != nil {
  59. return nil, fmt.Errorf("failed to read provider cache file: %w", err)
  60. }
  61. var providers []catwalk.Provider
  62. if err := json.Unmarshal(data, &providers); err != nil {
  63. return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
  64. }
  65. return providers, nil
  66. }
  67. func UpdateProviders(pathOrUrl string) error {
  68. var providers []catwalk.Provider
  69. pathOrUrl = cmp.Or(pathOrUrl, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
  70. switch {
  71. case pathOrUrl == "embedded":
  72. providers = embedded.GetAll()
  73. case strings.HasPrefix(pathOrUrl, "http://") || strings.HasPrefix(pathOrUrl, "https://"):
  74. var err error
  75. providers, err = catwalk.NewWithURL(pathOrUrl).GetProviders()
  76. if err != nil {
  77. return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
  78. }
  79. default:
  80. content, err := os.ReadFile(pathOrUrl)
  81. if err != nil {
  82. return fmt.Errorf("failed to read file: %w", err)
  83. }
  84. if err := json.Unmarshal(content, &providers); err != nil {
  85. return fmt.Errorf("failed to unmarshal provider data: %w", err)
  86. }
  87. if len(providers) == 0 {
  88. return fmt.Errorf("no providers found in the provided source")
  89. }
  90. }
  91. cachePath := providerCacheFileData()
  92. if err := saveProvidersInCache(cachePath, providers); err != nil {
  93. return fmt.Errorf("failed to save providers to cache: %w", err)
  94. }
  95. slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrUrl, "to", cachePath)
  96. return nil
  97. }
  98. func Providers(cfg *Config) ([]catwalk.Provider, error) {
  99. providerOnce.Do(func() {
  100. catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
  101. client := catwalk.NewWithURL(catwalkURL)
  102. path := providerCacheFileData()
  103. autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
  104. providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
  105. })
  106. return providerList, providerErr
  107. }
  108. func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
  109. catwalkGetAndSave := func() ([]catwalk.Provider, error) {
  110. providers, err := client.GetProviders()
  111. if err != nil {
  112. return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
  113. }
  114. if len(providers) == 0 {
  115. return nil, fmt.Errorf("empty providers list from catwalk")
  116. }
  117. if err := saveProvidersInCache(path, providers); err != nil {
  118. return nil, err
  119. }
  120. return providers, nil
  121. }
  122. switch {
  123. case autoUpdateDisabled:
  124. slog.Warn("Providers auto-update is disabled")
  125. if _, err := os.Stat(path); err == nil {
  126. slog.Warn("Using locally cached providers")
  127. return loadProvidersFromCache(path)
  128. }
  129. slog.Warn("Saving embedded providers to cache")
  130. providers := embedded.GetAll()
  131. if err := saveProvidersInCache(path, providers); err != nil {
  132. return nil, err
  133. }
  134. return providers, nil
  135. default:
  136. slog.Info("Fetching providers from Catwalk.", "path", path)
  137. providers, err := catwalkGetAndSave()
  138. if err != nil {
  139. catwalkUrl := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
  140. return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help. %w", catwalkUrl, err) //nolint:staticcheck
  141. }
  142. return providers, nil
  143. }
  144. }