| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227 |
- package config
- import (
- "cmp"
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "log/slog"
- "os"
- "path/filepath"
- "runtime"
- "slices"
- "strings"
- "sync"
- "time"
- "github.com/charmbracelet/catwalk/pkg/catwalk"
- "github.com/charmbracelet/catwalk/pkg/embedded"
- "github.com/charmbracelet/crush/internal/agent/hyper"
- "github.com/charmbracelet/crush/internal/csync"
- "github.com/charmbracelet/crush/internal/home"
- "github.com/charmbracelet/x/etag"
- )
- type syncer[T any] interface {
- Get(context.Context) (T, error)
- }
- var (
- providerOnce sync.Once
- providerList []catwalk.Provider
- providerErr error
- )
- // file to cache provider data
- func cachePathFor(name string) string {
- xdgDataHome := os.Getenv("XDG_DATA_HOME")
- if xdgDataHome != "" {
- return filepath.Join(xdgDataHome, appName, name+".json")
- }
- // return the path to the main data directory
- // for windows, it should be in `%LOCALAPPDATA%/crush/`
- // for linux and macOS, it should be in `$HOME/.local/share/crush/`
- if runtime.GOOS == "windows" {
- localAppData := os.Getenv("LOCALAPPDATA")
- if localAppData == "" {
- localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
- }
- return filepath.Join(localAppData, appName, name+".json")
- }
- return filepath.Join(home.Dir(), ".local", "share", appName, name+".json")
- }
- // UpdateProviders updates the Catwalk providers list from a specified source.
- func UpdateProviders(pathOrURL string) error {
- var providers []catwalk.Provider
- pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
- switch {
- case pathOrURL == "embedded":
- providers = embedded.GetAll()
- case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
- var err error
- providers, err = catwalk.NewWithURL(pathOrURL).GetProviders(context.Background(), "")
- if err != nil {
- return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
- }
- default:
- content, err := os.ReadFile(pathOrURL)
- if err != nil {
- return fmt.Errorf("failed to read file: %w", err)
- }
- if err := json.Unmarshal(content, &providers); err != nil {
- return fmt.Errorf("failed to unmarshal provider data: %w", err)
- }
- if len(providers) == 0 {
- return fmt.Errorf("no providers found in the provided source")
- }
- }
- if err := newCache[[]catwalk.Provider](cachePathFor("providers")).Store(providers); err != nil {
- return fmt.Errorf("failed to save providers to cache: %w", err)
- }
- slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor)
- return nil
- }
- // UpdateHyper updates the Hyper provider information from a specified URL.
- func UpdateHyper(pathOrURL string) error {
- if !hyper.Enabled() {
- return fmt.Errorf("hyper not enabled")
- }
- var provider catwalk.Provider
- pathOrURL = cmp.Or(pathOrURL, hyper.BaseURL())
- switch {
- case pathOrURL == "embedded":
- provider = hyper.Embedded()
- case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
- client := realHyperClient{baseURL: pathOrURL}
- var err error
- provider, err = client.Get(context.Background(), "")
- if err != nil {
- return fmt.Errorf("failed to fetch provider from Hyper: %w", err)
- }
- default:
- content, err := os.ReadFile(pathOrURL)
- if err != nil {
- return fmt.Errorf("failed to read file: %w", err)
- }
- if err := json.Unmarshal(content, &provider); err != nil {
- return fmt.Errorf("failed to unmarshal provider data: %w", err)
- }
- }
- if err := newCache[catwalk.Provider](cachePathFor("hyper")).Store(provider); err != nil {
- return fmt.Errorf("failed to save Hyper provider to cache: %w", err)
- }
- slog.Info("Hyper provider updated successfully", "from", pathOrURL, "to", cachePathFor("hyper"))
- return nil
- }
- var (
- catwalkSyncer = &catwalkSync{}
- hyperSyncer = &hyperSync{}
- )
- // Providers returns the list of providers, taking into account cached results
- // and whether or not auto update is enabled.
- //
- // It will:
- // 1. if auto update is disabled, it'll return the embedded providers at the
- // time of release.
- // 2. load the cached providers
- // 3. try to get the fresh list of providers, and return either this new list,
- // the cached list, or the embedded list if all others fail.
- func Providers(cfg *Config) ([]catwalk.Provider, error) {
- providerOnce.Do(func() {
- var wg sync.WaitGroup
- var errs []error
- providers := csync.NewSlice[catwalk.Provider]()
- autoupdate := !cfg.Options.DisableProviderAutoUpdate
- ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
- defer cancel()
- wg.Go(func() {
- catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
- client := catwalk.NewWithURL(catwalkURL)
- path := cachePathFor("providers")
- catwalkSyncer.Init(client, path, autoupdate)
- items, err := catwalkSyncer.Get(ctx)
- if err != nil {
- catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
- errs = append(errs, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr)) //nolint:staticcheck
- return
- }
- providers.Append(items...)
- })
- wg.Go(func() {
- if !hyper.Enabled() {
- return
- }
- path := cachePathFor("hyper")
- hyperSyncer.Init(realHyperClient{baseURL: hyper.BaseURL()}, path, autoupdate)
- item, err := hyperSyncer.Get(ctx)
- if err != nil {
- errs = append(errs, fmt.Errorf("Crush was unable to fetch updated information from Hyper: %w", err)) //nolint:staticcheck
- return
- }
- providers.Append(item)
- })
- wg.Wait()
- providerList = slices.Collect(providers.Seq())
- providerErr = errors.Join(errs...)
- })
- return providerList, providerErr
- }
- type cache[T any] struct {
- path string
- }
- func newCache[T any](path string) cache[T] {
- return cache[T]{path: path}
- }
- func (c cache[T]) Get() (T, string, error) {
- var v T
- data, err := os.ReadFile(c.path)
- if err != nil {
- return v, "", fmt.Errorf("failed to read provider cache file: %w", err)
- }
- if err := json.Unmarshal(data, &v); err != nil {
- return v, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
- }
- return v, etag.Of(data), nil
- }
- func (c cache[T]) Store(v T) error {
- slog.Info("Saving provider data to disk", "path", c.path)
- if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil {
- return fmt.Errorf("failed to create directory for provider cache: %w", err)
- }
- data, err := json.Marshal(v)
- if err != nil {
- return fmt.Errorf("failed to marshal provider data: %w", err)
- }
- if err := os.WriteFile(c.path, data, 0o644); err != nil {
- return fmt.Errorf("failed to write provider data to cache: %w", err)
- }
- return nil
- }
|