provider_test.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. package app
  2. import (
  3. "testing"
  4. "charm.land/catwalk/pkg/catwalk"
  5. "github.com/charmbracelet/crush/internal/config"
  6. "github.com/stretchr/testify/require"
  7. )
  8. func TestParseModelStr(t *testing.T) {
  9. tests := []struct {
  10. name string
  11. modelStr string
  12. expectedFilter string
  13. expectedModelID string
  14. setupProviders func() map[string]config.ProviderConfig
  15. }{
  16. {
  17. name: "simple model with no slashes",
  18. modelStr: "gpt-4o",
  19. expectedFilter: "",
  20. expectedModelID: "gpt-4o",
  21. setupProviders: setupMockProviders,
  22. },
  23. {
  24. name: "valid provider and model",
  25. modelStr: "openai/gpt-4o",
  26. expectedFilter: "openai",
  27. expectedModelID: "gpt-4o",
  28. setupProviders: setupMockProviders,
  29. },
  30. {
  31. name: "model with multiple slashes and first part is invalid provider",
  32. modelStr: "moonshot/kimi-k2",
  33. expectedFilter: "",
  34. expectedModelID: "moonshot/kimi-k2",
  35. setupProviders: setupMockProviders,
  36. },
  37. {
  38. name: "full path with valid provider and model with slashes",
  39. modelStr: "synthetic/moonshot/kimi-k2",
  40. expectedFilter: "synthetic",
  41. expectedModelID: "moonshot/kimi-k2",
  42. setupProviders: setupMockProvidersWithSlashes,
  43. },
  44. {
  45. name: "empty model string",
  46. modelStr: "",
  47. expectedFilter: "",
  48. expectedModelID: "",
  49. setupProviders: setupMockProviders,
  50. },
  51. {
  52. name: "model with trailing slash but valid provider",
  53. modelStr: "openai/",
  54. expectedFilter: "openai",
  55. expectedModelID: "",
  56. setupProviders: setupMockProviders,
  57. },
  58. }
  59. for _, tt := range tests {
  60. t.Run(tt.name, func(t *testing.T) {
  61. providers := tt.setupProviders()
  62. filter, modelID := parseModelStr(providers, tt.modelStr)
  63. require.Equal(t, tt.expectedFilter, filter, "provider filter mismatch")
  64. require.Equal(t, tt.expectedModelID, modelID, "model ID mismatch")
  65. })
  66. }
  67. }
  68. func setupMockProviders() map[string]config.ProviderConfig {
  69. return map[string]config.ProviderConfig{
  70. "openai": {
  71. ID: "openai",
  72. Name: "OpenAI",
  73. Models: []catwalk.Model{{ID: "gpt-4o"}, {ID: "gpt-4o-mini"}},
  74. },
  75. "anthropic": {
  76. ID: "anthropic",
  77. Name: "Anthropic",
  78. Models: []catwalk.Model{{ID: "claude-3-sonnet"}, {ID: "claude-3-opus"}},
  79. },
  80. }
  81. }
  82. func setupMockProvidersWithSlashes() map[string]config.ProviderConfig {
  83. return map[string]config.ProviderConfig{
  84. "synthetic": {
  85. ID: "synthetic",
  86. Name: "Synthetic",
  87. Models: []catwalk.Model{
  88. {ID: "moonshot/kimi-k2"},
  89. {ID: "deepseek/deepseek-chat"},
  90. },
  91. },
  92. "openai": {
  93. ID: "openai",
  94. Name: "OpenAI",
  95. Models: []catwalk.Model{{ID: "gpt-4o"}},
  96. },
  97. }
  98. }
  99. func TestFindModels(t *testing.T) {
  100. tests := []struct {
  101. name string
  102. modelStr string
  103. expectedProvider string
  104. expectedModelID string
  105. expectError bool
  106. errorContains string
  107. setupProviders func() map[string]config.ProviderConfig
  108. }{
  109. {
  110. name: "simple model found in one provider",
  111. modelStr: "gpt-4o",
  112. expectedProvider: "openai",
  113. expectedModelID: "gpt-4o",
  114. expectError: false,
  115. setupProviders: setupMockProviders,
  116. },
  117. {
  118. name: "model with slashes in ID",
  119. modelStr: "moonshot/kimi-k2",
  120. expectedProvider: "synthetic",
  121. expectedModelID: "moonshot/kimi-k2",
  122. expectError: false,
  123. setupProviders: setupMockProvidersWithSlashes,
  124. },
  125. {
  126. name: "provider and model with slashes in ID",
  127. modelStr: "synthetic/moonshot/kimi-k2",
  128. expectedProvider: "synthetic",
  129. expectedModelID: "moonshot/kimi-k2",
  130. expectError: false,
  131. setupProviders: setupMockProvidersWithSlashes,
  132. },
  133. {
  134. name: "model not found",
  135. modelStr: "nonexistent-model",
  136. expectError: true,
  137. errorContains: "not found",
  138. setupProviders: setupMockProviders,
  139. },
  140. {
  141. name: "invalid provider specified",
  142. modelStr: "nonexistent-provider/gpt-4o",
  143. expectError: true,
  144. errorContains: "provider",
  145. setupProviders: setupMockProviders,
  146. },
  147. {
  148. name: "model found in multiple providers without provider filter",
  149. modelStr: "shared-model",
  150. expectError: true,
  151. errorContains: "multiple providers",
  152. setupProviders: func() map[string]config.ProviderConfig {
  153. return map[string]config.ProviderConfig{
  154. "openai": {
  155. ID: "openai",
  156. Models: []catwalk.Model{{ID: "shared-model"}},
  157. },
  158. "anthropic": {
  159. ID: "anthropic",
  160. Models: []catwalk.Model{{ID: "shared-model"}},
  161. },
  162. }
  163. },
  164. },
  165. {
  166. name: "empty model string",
  167. modelStr: "",
  168. expectError: true,
  169. errorContains: "not found",
  170. setupProviders: setupMockProviders,
  171. },
  172. }
  173. for _, tt := range tests {
  174. t.Run(tt.name, func(t *testing.T) {
  175. providers := tt.setupProviders()
  176. // Use findModels with the model as "large" and empty "small".
  177. matches, _, err := findModels(providers, tt.modelStr, "")
  178. if err != nil {
  179. if tt.expectError {
  180. require.Contains(t, err.Error(), tt.errorContains)
  181. } else {
  182. require.NoError(t, err)
  183. }
  184. return
  185. }
  186. // Validate the matches.
  187. match, err := validateMatches(matches, tt.modelStr, "large")
  188. if tt.expectError {
  189. require.Error(t, err)
  190. require.Contains(t, err.Error(), tt.errorContains)
  191. } else {
  192. require.NoError(t, err)
  193. require.Equal(t, tt.expectedProvider, match.provider)
  194. require.Equal(t, tt.expectedModelID, match.modelID)
  195. }
  196. })
  197. }
  198. }