app_test.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. package app
  2. import (
  3. "testing"
  4. "github.com/sst/opencode-sdk-go"
  5. )
  6. // TestFindModelByFullID tests the findModelByFullID function
  7. func TestFindModelByFullID(t *testing.T) {
  8. // Create test providers with models
  9. providers := []opencode.Provider{
  10. {
  11. ID: "anthropic",
  12. Models: map[string]opencode.Model{
  13. "claude-3-opus-20240229": {ID: "claude-3-opus-20240229"},
  14. "claude-3-sonnet-20240229": {ID: "claude-3-sonnet-20240229"},
  15. },
  16. },
  17. {
  18. ID: "openai",
  19. Models: map[string]opencode.Model{
  20. "gpt-4": {ID: "gpt-4"},
  21. "gpt-3.5-turbo": {ID: "gpt-3.5-turbo"},
  22. },
  23. },
  24. }
  25. tests := []struct {
  26. name string
  27. fullModelID string
  28. expectedFound bool
  29. expectedProviderID string
  30. expectedModelID string
  31. }{
  32. {
  33. name: "valid full model ID",
  34. fullModelID: "anthropic/claude-3-opus-20240229",
  35. expectedFound: true,
  36. expectedProviderID: "anthropic",
  37. expectedModelID: "claude-3-opus-20240229",
  38. },
  39. {
  40. name: "valid full model ID with slash in model name",
  41. fullModelID: "openai/gpt-3.5-turbo",
  42. expectedFound: true,
  43. expectedProviderID: "openai",
  44. expectedModelID: "gpt-3.5-turbo",
  45. },
  46. {
  47. name: "invalid format - missing slash",
  48. fullModelID: "anthropic",
  49. expectedFound: false,
  50. },
  51. {
  52. name: "invalid format - empty string",
  53. fullModelID: "",
  54. expectedFound: false,
  55. },
  56. {
  57. name: "provider not found",
  58. fullModelID: "nonexistent/model",
  59. expectedFound: false,
  60. },
  61. {
  62. name: "model not found",
  63. fullModelID: "anthropic/nonexistent-model",
  64. expectedFound: false,
  65. },
  66. }
  67. for _, tt := range tests {
  68. t.Run(tt.name, func(t *testing.T) {
  69. provider, model := findModelByFullID(providers, tt.fullModelID)
  70. if tt.expectedFound {
  71. if provider == nil || model == nil {
  72. t.Errorf("Expected to find provider/model, but got nil")
  73. return
  74. }
  75. if provider.ID != tt.expectedProviderID {
  76. t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
  77. }
  78. if model.ID != tt.expectedModelID {
  79. t.Errorf("Expected model ID %s, got %s", tt.expectedModelID, model.ID)
  80. }
  81. } else {
  82. if provider != nil || model != nil {
  83. t.Errorf("Expected not to find provider/model, but got provider: %v, model: %v", provider, model)
  84. }
  85. }
  86. })
  87. }
  88. }
  89. // TestFindModelByProviderAndModelID tests the findModelByProviderAndModelID function
  90. func TestFindModelByProviderAndModelID(t *testing.T) {
  91. // Create test providers with models
  92. providers := []opencode.Provider{
  93. {
  94. ID: "anthropic",
  95. Models: map[string]opencode.Model{
  96. "claude-3-opus-20240229": {ID: "claude-3-opus-20240229"},
  97. "claude-3-sonnet-20240229": {ID: "claude-3-sonnet-20240229"},
  98. },
  99. },
  100. {
  101. ID: "openai",
  102. Models: map[string]opencode.Model{
  103. "gpt-4": {ID: "gpt-4"},
  104. "gpt-3.5-turbo": {ID: "gpt-3.5-turbo"},
  105. },
  106. },
  107. }
  108. tests := []struct {
  109. name string
  110. providerID string
  111. modelID string
  112. expectedFound bool
  113. expectedProviderID string
  114. expectedModelID string
  115. }{
  116. {
  117. name: "valid provider and model",
  118. providerID: "anthropic",
  119. modelID: "claude-3-opus-20240229",
  120. expectedFound: true,
  121. expectedProviderID: "anthropic",
  122. expectedModelID: "claude-3-opus-20240229",
  123. },
  124. {
  125. name: "provider not found",
  126. providerID: "nonexistent",
  127. modelID: "claude-3-opus-20240229",
  128. expectedFound: false,
  129. },
  130. {
  131. name: "model not found",
  132. providerID: "anthropic",
  133. modelID: "nonexistent-model",
  134. expectedFound: false,
  135. },
  136. {
  137. name: "both provider and model not found",
  138. providerID: "nonexistent",
  139. modelID: "nonexistent-model",
  140. expectedFound: false,
  141. },
  142. }
  143. for _, tt := range tests {
  144. t.Run(tt.name, func(t *testing.T) {
  145. provider, model := findModelByProviderAndModelID(providers, tt.providerID, tt.modelID)
  146. if tt.expectedFound {
  147. if provider == nil || model == nil {
  148. t.Errorf("Expected to find provider/model, but got nil")
  149. return
  150. }
  151. if provider.ID != tt.expectedProviderID {
  152. t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
  153. }
  154. if model.ID != tt.expectedModelID {
  155. t.Errorf("Expected model ID %s, got %s", tt.expectedModelID, model.ID)
  156. }
  157. } else {
  158. if provider != nil || model != nil {
  159. t.Errorf("Expected not to find provider/model, but got provider: %v, model: %v", provider, model)
  160. }
  161. }
  162. })
  163. }
  164. }
  165. // TestFindProviderByID tests the findProviderByID function
  166. func TestFindProviderByID(t *testing.T) {
  167. // Create test providers
  168. providers := []opencode.Provider{
  169. {ID: "anthropic"},
  170. {ID: "openai"},
  171. {ID: "google"},
  172. }
  173. tests := []struct {
  174. name string
  175. providerID string
  176. expectedFound bool
  177. expectedProviderID string
  178. }{
  179. {
  180. name: "provider found",
  181. providerID: "anthropic",
  182. expectedFound: true,
  183. expectedProviderID: "anthropic",
  184. },
  185. {
  186. name: "provider not found",
  187. providerID: "nonexistent",
  188. expectedFound: false,
  189. },
  190. }
  191. for _, tt := range tests {
  192. t.Run(tt.name, func(t *testing.T) {
  193. provider := findProviderByID(providers, tt.providerID)
  194. if tt.expectedFound {
  195. if provider == nil {
  196. t.Errorf("Expected to find provider, but got nil")
  197. return
  198. }
  199. if provider.ID != tt.expectedProviderID {
  200. t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
  201. }
  202. } else {
  203. if provider != nil {
  204. t.Errorf("Expected not to find provider, but got %v", provider)
  205. }
  206. }
  207. })
  208. }
  209. }