| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- package models
- import (
- "context"
- "errors"
- "github.com/cloudwego/eino-ext/components/model/claude"
- "github.com/cloudwego/eino-ext/components/model/openai"
- "github.com/cloudwego/eino/components/model"
- "github.com/spf13/viper"
- )
// ModelID is the string key that identifies a model within this package
// (for example as the key of SupportedModels).
type ModelID string

// ModelProvider is the string name of the vendor that serves a model.
type ModelProvider string
- type Model struct {
- ID ModelID `json:"id"`
- Name string `json:"name"`
- Provider ModelProvider `json:"provider"`
- APIModel string `json:"api_model"` // Actual value used when calling the API
- }
- const (
- DefaultBigModel = GPT4oMini
- DefaultLittleModel = GPT4oMini
- )
- // Model IDs
- const (
- // OpenAI
- GPT4o ModelID = "gpt-4o"
- GPT4oMini ModelID = "gpt-4o-mini"
- GPT45 ModelID = "gpt-4.5"
- O1 ModelID = "o1"
- O1Mini ModelID = "o1-mini"
- // Anthropic
- Claude35Sonnet ModelID = "claude-3.5-sonnet"
- Claude3Haiku ModelID = "claude-3-haiku"
- Claude37Sonnet ModelID = "claude-3.7-sonnet"
- // Google
- Gemini20Pro ModelID = "gemini-2.0-pro"
- Gemini15Flash ModelID = "gemini-1.5-flash"
- Gemini20Flash ModelID = "gemini-2.0-flash"
- // xAI
- Grok3 ModelID = "grok-3"
- Grok2Mini ModelID = "grok-2-mini"
- // DeepSeek
- DeepSeekR1 ModelID = "deepseek-r1"
- DeepSeekCoder ModelID = "deepseek-coder"
- // Meta
- Llama3 ModelID = "llama-3"
- Llama270B ModelID = "llama-2-70b"
- )
- const (
- ProviderOpenAI ModelProvider = "openai"
- ProviderAnthropic ModelProvider = "anthropic"
- ProviderGoogle ModelProvider = "google"
- ProviderXAI ModelProvider = "xai"
- ProviderDeepSeek ModelProvider = "deepseek"
- ProviderMeta ModelProvider = "meta"
- )
- var SupportedModels = map[ModelID]Model{
- // OpenAI
- GPT4o: {
- ID: GPT4o,
- Name: "GPT-4o",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4o",
- },
- GPT4oMini: {
- ID: GPT4oMini,
- Name: "GPT-4o Mini",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4o-mini",
- },
- GPT45: {
- ID: GPT45,
- Name: "GPT-4.5",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4.5",
- },
- O1: {
- ID: O1,
- Name: "o1",
- Provider: ProviderOpenAI,
- APIModel: "o1",
- },
- O1Mini: {
- ID: O1Mini,
- Name: "o1 Mini",
- Provider: ProviderOpenAI,
- APIModel: "o1-mini",
- },
- // Anthropic
- Claude35Sonnet: {
- ID: Claude35Sonnet,
- Name: "Claude 3.5 Sonnet",
- Provider: ProviderAnthropic,
- APIModel: "claude-3.5-sonnet",
- },
- Claude3Haiku: {
- ID: Claude3Haiku,
- Name: "Claude 3 Haiku",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-haiku",
- },
- Claude37Sonnet: {
- ID: Claude37Sonnet,
- Name: "Claude 3.7 Sonnet",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-7-sonnet-20250219",
- },
- // Google
- Gemini20Pro: {
- ID: Gemini20Pro,
- Name: "Gemini 2.0 Pro",
- Provider: ProviderGoogle,
- APIModel: "gemini-2.0-pro",
- },
- Gemini15Flash: {
- ID: Gemini15Flash,
- Name: "Gemini 1.5 Flash",
- Provider: ProviderGoogle,
- APIModel: "gemini-1.5-flash",
- },
- Gemini20Flash: {
- ID: Gemini20Flash,
- Name: "Gemini 2.0 Flash",
- Provider: ProviderGoogle,
- APIModel: "gemini-2.0-flash",
- },
- // xAI
- Grok3: {
- ID: Grok3,
- Name: "Grok 3",
- Provider: ProviderXAI,
- APIModel: "grok-3",
- },
- Grok2Mini: {
- ID: Grok2Mini,
- Name: "Grok 2 Mini",
- Provider: ProviderXAI,
- APIModel: "grok-2-mini",
- },
- // DeepSeek
- DeepSeekR1: {
- ID: DeepSeekR1,
- Name: "DeepSeek R1",
- Provider: ProviderDeepSeek,
- APIModel: "deepseek-r1",
- },
- DeepSeekCoder: {
- ID: DeepSeekCoder,
- Name: "DeepSeek Coder",
- Provider: ProviderDeepSeek,
- APIModel: "deepseek-coder",
- },
- // Meta
- Llama3: {
- ID: Llama3,
- Name: "LLaMA 3",
- Provider: ProviderMeta,
- APIModel: "llama-3",
- },
- Llama270B: {
- ID: Llama270B,
- Name: "LLaMA 2 70B",
- Provider: ProviderMeta,
- APIModel: "llama-2-70b",
- },
- }
- func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
- provider := SupportedModels[model].Provider
- maxTokens := viper.GetInt("providers.common.max_tokens")
- switch provider {
- case ProviderOpenAI:
- return openai.NewChatModel(ctx, &openai.ChatModelConfig{
- APIKey: viper.GetString("providers.openai.key"),
- Model: string(SupportedModels[model].APIModel),
- MaxTokens: &maxTokens,
- })
- case ProviderAnthropic:
- return claude.NewChatModel(ctx, &claude.Config{
- APIKey: viper.GetString("providers.anthropic.key"),
- Model: string(SupportedModels[model].APIModel),
- MaxTokens: maxTokens,
- })
- }
- return nil, errors.New("unsupported provider")
- }
|