models.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package models
  import (
  	"context"
  	"errors"
  	"fmt"

  	"github.com/cloudwego/eino-ext/components/model/claude"
  	"github.com/cloudwego/eino-ext/components/model/openai"
  	"github.com/cloudwego/eino/components/model"
  	"github.com/spf13/viper"
  )
  // ModelID is the application-level identifier of a model and is the key type
  // of SupportedModels. It may differ from the model string sent to the
  // provider's API (see Model.APIModel).
  type ModelID string

  // ModelProvider names the vendor that serves a model, e.g. "openai".
  type ModelProvider string
  14. type Model struct {
  15. ID ModelID `json:"id"`
  16. Name string `json:"name"`
  17. Provider ModelProvider `json:"provider"`
  18. APIModel string `json:"api_model"` // Actual value used when calling the API
  19. }
  20. const (
  21. DefaultBigModel = GPT4oMini
  22. DefaultLittleModel = GPT4oMini
  23. )
  24. // Model IDs
  25. const (
  26. // OpenAI
  27. GPT4o ModelID = "gpt-4o"
  28. GPT4oMini ModelID = "gpt-4o-mini"
  29. GPT45 ModelID = "gpt-4.5"
  30. O1 ModelID = "o1"
  31. O1Mini ModelID = "o1-mini"
  32. // Anthropic
  33. Claude35Sonnet ModelID = "claude-3.5-sonnet"
  34. Claude3Haiku ModelID = "claude-3-haiku"
  35. Claude37Sonnet ModelID = "claude-3.7-sonnet"
  36. // Google
  37. Gemini20Pro ModelID = "gemini-2.0-pro"
  38. Gemini15Flash ModelID = "gemini-1.5-flash"
  39. Gemini20Flash ModelID = "gemini-2.0-flash"
  40. // xAI
  41. Grok3 ModelID = "grok-3"
  42. Grok2Mini ModelID = "grok-2-mini"
  43. // DeepSeek
  44. DeepSeekR1 ModelID = "deepseek-r1"
  45. DeepSeekCoder ModelID = "deepseek-coder"
  46. // Meta
  47. Llama3 ModelID = "llama-3"
  48. Llama270B ModelID = "llama-2-70b"
  49. )
  50. const (
  51. ProviderOpenAI ModelProvider = "openai"
  52. ProviderAnthropic ModelProvider = "anthropic"
  53. ProviderGoogle ModelProvider = "google"
  54. ProviderXAI ModelProvider = "xai"
  55. ProviderDeepSeek ModelProvider = "deepseek"
  56. ProviderMeta ModelProvider = "meta"
  57. )
  58. var SupportedModels = map[ModelID]Model{
  59. // OpenAI
  60. GPT4o: {
  61. ID: GPT4o,
  62. Name: "GPT-4o",
  63. Provider: ProviderOpenAI,
  64. APIModel: "gpt-4o",
  65. },
  66. GPT4oMini: {
  67. ID: GPT4oMini,
  68. Name: "GPT-4o Mini",
  69. Provider: ProviderOpenAI,
  70. APIModel: "gpt-4o-mini",
  71. },
  72. GPT45: {
  73. ID: GPT45,
  74. Name: "GPT-4.5",
  75. Provider: ProviderOpenAI,
  76. APIModel: "gpt-4.5",
  77. },
  78. O1: {
  79. ID: O1,
  80. Name: "o1",
  81. Provider: ProviderOpenAI,
  82. APIModel: "o1",
  83. },
  84. O1Mini: {
  85. ID: O1Mini,
  86. Name: "o1 Mini",
  87. Provider: ProviderOpenAI,
  88. APIModel: "o1-mini",
  89. },
  90. // Anthropic
  91. Claude35Sonnet: {
  92. ID: Claude35Sonnet,
  93. Name: "Claude 3.5 Sonnet",
  94. Provider: ProviderAnthropic,
  95. APIModel: "claude-3.5-sonnet",
  96. },
  97. Claude3Haiku: {
  98. ID: Claude3Haiku,
  99. Name: "Claude 3 Haiku",
  100. Provider: ProviderAnthropic,
  101. APIModel: "claude-3-haiku",
  102. },
  103. Claude37Sonnet: {
  104. ID: Claude37Sonnet,
  105. Name: "Claude 3.7 Sonnet",
  106. Provider: ProviderAnthropic,
  107. APIModel: "claude-3-7-sonnet-20250219",
  108. },
  109. // Google
  110. Gemini20Pro: {
  111. ID: Gemini20Pro,
  112. Name: "Gemini 2.0 Pro",
  113. Provider: ProviderGoogle,
  114. APIModel: "gemini-2.0-pro",
  115. },
  116. Gemini15Flash: {
  117. ID: Gemini15Flash,
  118. Name: "Gemini 1.5 Flash",
  119. Provider: ProviderGoogle,
  120. APIModel: "gemini-1.5-flash",
  121. },
  122. Gemini20Flash: {
  123. ID: Gemini20Flash,
  124. Name: "Gemini 2.0 Flash",
  125. Provider: ProviderGoogle,
  126. APIModel: "gemini-2.0-flash",
  127. },
  128. // xAI
  129. Grok3: {
  130. ID: Grok3,
  131. Name: "Grok 3",
  132. Provider: ProviderXAI,
  133. APIModel: "grok-3",
  134. },
  135. Grok2Mini: {
  136. ID: Grok2Mini,
  137. Name: "Grok 2 Mini",
  138. Provider: ProviderXAI,
  139. APIModel: "grok-2-mini",
  140. },
  141. // DeepSeek
  142. DeepSeekR1: {
  143. ID: DeepSeekR1,
  144. Name: "DeepSeek R1",
  145. Provider: ProviderDeepSeek,
  146. APIModel: "deepseek-r1",
  147. },
  148. DeepSeekCoder: {
  149. ID: DeepSeekCoder,
  150. Name: "DeepSeek Coder",
  151. Provider: ProviderDeepSeek,
  152. APIModel: "deepseek-coder",
  153. },
  154. // Meta
  155. Llama3: {
  156. ID: Llama3,
  157. Name: "LLaMA 3",
  158. Provider: ProviderMeta,
  159. APIModel: "llama-3",
  160. },
  161. Llama270B: {
  162. ID: Llama270B,
  163. Name: "LLaMA 2 70B",
  164. Provider: ProviderMeta,
  165. APIModel: "llama-2-70b",
  166. },
  167. }
  168. func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
  169. provider := SupportedModels[model].Provider
  170. maxTokens := viper.GetInt("providers.common.max_tokens")
  171. switch provider {
  172. case ProviderOpenAI:
  173. return openai.NewChatModel(ctx, &openai.ChatModelConfig{
  174. APIKey: viper.GetString("providers.openai.key"),
  175. Model: string(SupportedModels[model].APIModel),
  176. MaxTokens: &maxTokens,
  177. })
  178. case ProviderAnthropic:
  179. return claude.NewChatModel(ctx, &claude.Config{
  180. APIKey: viper.GetString("providers.anthropic.key"),
  181. Model: string(SupportedModels[model].APIModel),
  182. MaxTokens: maxTokens,
  183. })
  184. }
  185. return nil, errors.New("unsupported provider")
  186. }