hyper_test.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. package config
import (
	"context"
	"encoding/json"
	"errors"
	"os"
	"path/filepath"
	"testing"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/stretchr/testify/require"
)
  11. type mockHyperClient struct {
  12. provider catwalk.Provider
  13. err error
  14. callCount int
  15. }
  16. func (m *mockHyperClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) {
  17. m.callCount++
  18. return m.provider, m.err
  19. }
  20. func TestHyperSync_Init(t *testing.T) {
  21. t.Parallel()
  22. syncer := &hyperSync{}
  23. client := &mockHyperClient{}
  24. path := "/tmp/hyper.json"
  25. syncer.Init(client, path, true)
  26. require.True(t, syncer.init.Load())
  27. require.Equal(t, client, syncer.client)
  28. require.Equal(t, path, syncer.cache.path)
  29. }
  30. func TestHyperSync_GetPanicIfNotInit(t *testing.T) {
  31. t.Parallel()
  32. syncer := &hyperSync{}
  33. require.Panics(t, func() {
  34. _, _ = syncer.Get(t.Context())
  35. })
  36. }
  37. func TestHyperSync_GetFreshProvider(t *testing.T) {
  38. t.Parallel()
  39. syncer := &hyperSync{}
  40. client := &mockHyperClient{
  41. provider: catwalk.Provider{
  42. Name: "Hyper",
  43. ID: "hyper",
  44. Models: []catwalk.Model{
  45. {ID: "model-1", Name: "Model 1"},
  46. },
  47. },
  48. }
  49. path := t.TempDir() + "/hyper.json"
  50. syncer.Init(client, path, true)
  51. provider, err := syncer.Get(t.Context())
  52. require.NoError(t, err)
  53. require.Equal(t, "Hyper", provider.Name)
  54. require.Equal(t, 1, client.callCount)
  55. // Verify cache was written.
  56. fileInfo, err := os.Stat(path)
  57. require.NoError(t, err)
  58. require.False(t, fileInfo.IsDir())
  59. }
  60. func TestHyperSync_GetNotModifiedUsesCached(t *testing.T) {
  61. t.Parallel()
  62. tmpDir := t.TempDir()
  63. path := tmpDir + "/hyper.json"
  64. // Create cache file.
  65. cachedProvider := catwalk.Provider{
  66. Name: "Cached Hyper",
  67. ID: "hyper",
  68. }
  69. data, err := json.Marshal(cachedProvider)
  70. require.NoError(t, err)
  71. require.NoError(t, os.WriteFile(path, data, 0o644))
  72. syncer := &hyperSync{}
  73. client := &mockHyperClient{
  74. err: catwalk.ErrNotModified,
  75. }
  76. syncer.Init(client, path, true)
  77. provider, err := syncer.Get(t.Context())
  78. require.NoError(t, err)
  79. require.Equal(t, "Cached Hyper", provider.Name)
  80. require.Equal(t, 1, client.callCount)
  81. }
  82. func TestHyperSync_GetClientError(t *testing.T) {
  83. t.Parallel()
  84. tmpDir := t.TempDir()
  85. path := tmpDir + "/hyper.json"
  86. syncer := &hyperSync{}
  87. client := &mockHyperClient{
  88. err: errors.New("network error"),
  89. }
  90. syncer.Init(client, path, true)
  91. provider, err := syncer.Get(t.Context())
  92. require.NoError(t, err) // Should fall back to embedded.
  93. require.Equal(t, "Charm Hyper", provider.Name)
  94. require.Equal(t, catwalk.InferenceProvider("hyper"), provider.ID)
  95. }
  96. func TestHyperSync_GetEmptyCache(t *testing.T) {
  97. t.Parallel()
  98. tmpDir := t.TempDir()
  99. path := tmpDir + "/hyper.json"
  100. syncer := &hyperSync{}
  101. client := &mockHyperClient{
  102. provider: catwalk.Provider{
  103. Name: "Fresh Hyper",
  104. ID: "hyper",
  105. Models: []catwalk.Model{
  106. {ID: "model-1", Name: "Model 1"},
  107. },
  108. },
  109. }
  110. syncer.Init(client, path, true)
  111. provider, err := syncer.Get(t.Context())
  112. require.NoError(t, err)
  113. require.Equal(t, "Fresh Hyper", provider.Name)
  114. }
  115. func TestHyperSync_GetCalledMultipleTimesUsesOnce(t *testing.T) {
  116. t.Parallel()
  117. syncer := &hyperSync{}
  118. client := &mockHyperClient{
  119. provider: catwalk.Provider{
  120. Name: "Hyper",
  121. ID: "hyper",
  122. Models: []catwalk.Model{
  123. {ID: "model-1", Name: "Model 1"},
  124. },
  125. },
  126. }
  127. path := t.TempDir() + "/hyper.json"
  128. syncer.Init(client, path, true)
  129. // Call Get multiple times.
  130. provider1, err1 := syncer.Get(t.Context())
  131. require.NoError(t, err1)
  132. require.Equal(t, "Hyper", provider1.Name)
  133. provider2, err2 := syncer.Get(t.Context())
  134. require.NoError(t, err2)
  135. require.Equal(t, "Hyper", provider2.Name)
  136. // Client should only be called once due to sync.Once.
  137. require.Equal(t, 1, client.callCount)
  138. }
  139. func TestHyperSync_GetCacheStoreError(t *testing.T) {
  140. t.Parallel()
  141. // Create a file where we want a directory, causing mkdir to fail.
  142. tmpDir := t.TempDir()
  143. blockingFile := tmpDir + "/blocking"
  144. require.NoError(t, os.WriteFile(blockingFile, []byte("block"), 0o644))
  145. // Try to create cache in a subdirectory under the blocking file.
  146. path := blockingFile + "/subdir/hyper.json"
  147. syncer := &hyperSync{}
  148. client := &mockHyperClient{
  149. provider: catwalk.Provider{
  150. Name: "Hyper",
  151. ID: "hyper",
  152. Models: []catwalk.Model{
  153. {ID: "model-1", Name: "Model 1"},
  154. },
  155. },
  156. }
  157. syncer.Init(client, path, true)
  158. provider, err := syncer.Get(t.Context())
  159. require.Error(t, err)
  160. require.Contains(t, err.Error(), "failed to create directory for provider cache")
  161. require.Equal(t, "Hyper", provider.Name) // Provider is still returned.
  162. }