hyper.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package config
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "log/slog"
  8. "net/http"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. "charm.land/catwalk/pkg/catwalk"
  13. "github.com/charmbracelet/crush/internal/agent/hyper"
  14. xetag "github.com/charmbracelet/x/etag"
  15. )
  16. type hyperClient interface {
  17. Get(context.Context, string) (catwalk.Provider, error)
  18. }
  19. var _ syncer[catwalk.Provider] = (*hyperSync)(nil)
  20. type hyperSync struct {
  21. once sync.Once
  22. result catwalk.Provider
  23. cache cache[catwalk.Provider]
  24. client hyperClient
  25. autoupdate bool
  26. init atomic.Bool
  27. }
  28. func (s *hyperSync) Init(client hyperClient, path string, autoupdate bool) {
  29. s.client = client
  30. s.cache = newCache[catwalk.Provider](path)
  31. s.autoupdate = autoupdate
  32. s.init.Store(true)
  33. }
  34. func (s *hyperSync) Get(ctx context.Context) (catwalk.Provider, error) {
  35. if !s.init.Load() {
  36. panic("called Get before Init")
  37. }
  38. var throwErr error
  39. s.once.Do(func() {
  40. if !s.autoupdate {
  41. slog.Info("Using embedded Hyper provider")
  42. s.result = hyper.Embedded()
  43. return
  44. }
  45. cached, etag, cachedErr := s.cache.Get()
  46. if cached.ID == "" || cachedErr != nil {
  47. // if cached file is empty, default to embedded provider
  48. cached = hyper.Embedded()
  49. }
  50. slog.Info("Fetching Hyper provider")
  51. result, err := s.client.Get(ctx, etag)
  52. if errors.Is(err, context.DeadlineExceeded) {
  53. slog.Warn("Hyper provider not updated in time")
  54. s.result = cached
  55. return
  56. }
  57. if errors.Is(err, catwalk.ErrNotModified) {
  58. slog.Info("Hyper provider not modified")
  59. s.result = cached
  60. return
  61. }
  62. if len(result.Models) == 0 {
  63. slog.Warn("Hyper did not return any models")
  64. s.result = cached
  65. return
  66. }
  67. s.result = result
  68. throwErr = s.cache.Store(result)
  69. })
  70. return s.result, throwErr
  71. }
  72. var _ hyperClient = realHyperClient{}
  73. type realHyperClient struct {
  74. baseURL string
  75. }
  76. // Get implements hyperClient.
  77. func (r realHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) {
  78. var result catwalk.Provider
  79. req, err := http.NewRequestWithContext(
  80. ctx,
  81. http.MethodGet,
  82. r.baseURL+"/api/v1/provider",
  83. nil,
  84. )
  85. if err != nil {
  86. return result, fmt.Errorf("could not create request: %w", err)
  87. }
  88. xetag.Request(req, etag)
  89. client := &http.Client{Timeout: 30 * time.Second}
  90. resp, err := client.Do(req)
  91. if err != nil {
  92. return result, fmt.Errorf("failed to make request: %w", err)
  93. }
  94. defer resp.Body.Close() //nolint:errcheck
  95. if resp.StatusCode == http.StatusNotModified {
  96. return result, catwalk.ErrNotModified
  97. }
  98. if resp.StatusCode != http.StatusOK {
  99. return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
  100. }
  101. if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
  102. return result, fmt.Errorf("failed to decode response: %w", err)
  103. }
  104. return result, nil
  105. }