Просмотр исходного кода

feat: filter adaptor mode (#234)

* feat: ignore adaptors that do not support the requested mode

* fix: channel model test needs to check adaptor mode support

* fix: filter case
zijiren 6 месяцев назад
Родитель
Коммит
bd8cc38991

+ 10 - 0
core/controller/channel-test.go

@@ -71,6 +71,16 @@ func testSingleModel(
 		}
 	}
 
+	if modelConfig.Type != mode.Unknown {
+		a, ok := adaptors.GetAdaptor(channel.Type)
+		if !ok {
+			return nil, errors.New("adaptor not found")
+		}
+		if !a.SupportMode(modelConfig.Type) {
+			return nil, fmt.Errorf("%s not supported by adaptor", modelConfig.Type)
+		}
+	}
+
 	if modelConfig.ExcludeFromTests {
 		return &model.ChannelTest{
 			TestAt:      time.Now(),

+ 84 - 7
core/controller/relay-controller.go

@@ -268,6 +268,7 @@ func GetChannelFromHeader(
 	mc *model.ModelCaches,
 	availableSet []string,
 	model string,
+	m mode.Mode,
 ) (*model.Channel, error) {
 	channelIDInt, err := strconv.ParseInt(header, 10, 64)
 	if err != nil {
@@ -279,6 +280,13 @@ func GetChannelFromHeader(
 		if len(enabledChannels) > 0 {
 			for _, channel := range enabledChannels {
 				if int64(channel.ID) == channelIDInt {
+					a, ok := adaptors.GetAdaptor(channel.Type)
+					if !ok {
+						return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
+					}
+					if !a.SupportMode(m) {
+						return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
+					}
 					return channel, nil
 				}
 			}
@@ -288,6 +296,13 @@ func GetChannelFromHeader(
 		if len(disabledChannels) > 0 {
 			for _, channel := range disabledChannels {
 				if int64(channel.ID) == channelIDInt {
+					a, ok := adaptors.GetAdaptor(channel.Type)
+					if !ok {
+						return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
+					}
+					if !a.SupportMode(m) {
+						return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
+					}
 					return channel, nil
 				}
 			}
@@ -316,6 +331,16 @@ func GetChannelFromRequest(
 			if len(enabledChannels) > 0 {
 				for _, channel := range enabledChannels {
 					if channel.ID == channelID {
+						a, ok := adaptors.GetAdaptor(channel.Type)
+						if !ok {
+							return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
+						}
+						if !a.SupportMode(m) {
+							return nil, fmt.Errorf(
+								"channel %d not supported by adaptor",
+								channel.ID,
+							)
+						}
 						return channel, nil
 					}
 				}
@@ -332,6 +357,16 @@ func GetChannelFromRequest(
 			if len(enabledChannels) > 0 {
 				for _, channel := range enabledChannels {
 					if channel.ID == channelID {
+						a, ok := adaptors.GetAdaptor(channel.Type)
+						if !ok {
+							return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
+						}
+						if !a.SupportMode(m) {
+							return nil, fmt.Errorf(
+								"channel %d not supported by adaptor",
+								channel.ID,
+							)
+						}
 						return channel, nil
 					}
 				}
@@ -454,12 +489,23 @@ func notifyChannelIssue(
 	)
 }
 
-func filterChannels(channels []*model.Channel, ignoreChannel ...int64) []*model.Channel {
+func filterChannels(
+	channels []*model.Channel,
+	mode mode.Mode,
+	ignoreChannel ...int64,
+) []*model.Channel {
 	filtered := make([]*model.Channel, 0)
 	for _, channel := range channels {
 		if channel.Status != model.ChannelStatusEnabled {
 			continue
 		}
+		a, ok := adaptors.GetAdaptor(channel.Type)
+		if !ok {
+			continue
+		}
+		if !a.SupportMode(mode) {
+			continue
+		}
 		if slices.Contains(ignoreChannel, int64(channel.ID)) {
 			continue
 		}
@@ -477,6 +523,7 @@ func GetRandomChannel(
 	mc *model.ModelCaches,
 	availableSet []string,
 	modelName string,
+	mode mode.Mode,
 	errorRates map[int64]float64,
 	ignoreChannel ...int64,
 ) (*model.Channel, []*model.Channel, error) {
@@ -484,12 +531,26 @@ func GetRandomChannel(
 	if len(availableSet) != 0 {
 		for _, set := range availableSet {
 			for _, channel := range mc.EnabledModel2ChannelsBySet[set][modelName] {
+				a, ok := adaptors.GetAdaptor(channel.Type)
+				if !ok {
+					continue
+				}
+				if !a.SupportMode(mode) {
+					continue
+				}
 				channelMap[channel.ID] = channel
 			}
 		}
 	} else {
 		for _, sets := range mc.EnabledModel2ChannelsBySet {
 			for _, channel := range sets[modelName] {
+				a, ok := adaptors.GetAdaptor(channel.Type)
+				if !ok {
+					continue
+				}
+				if !a.SupportMode(mode) {
+					continue
+				}
 				channelMap[channel.ID] = channel
 			}
 		}
@@ -498,7 +559,7 @@ func GetRandomChannel(
 	for _, channel := range channelMap {
 		migratedChannels = append(migratedChannels, channel)
 	}
-	channel, err := getRandomChannel(migratedChannels, errorRates, ignoreChannel...)
+	channel, err := getRandomChannel(migratedChannels, mode, errorRates, ignoreChannel...)
 	return channel, migratedChannels, err
 }
 
@@ -512,10 +573,9 @@ func getPriority(channel *model.Channel, errorRate float64) int32 {
 	return int32(float64(priority) / errorRate)
 }
 
-//
-
 func getRandomChannel(
 	channels []*model.Channel,
+	mode mode.Mode,
 	errorRates map[int64]float64,
 	ignoreChannel ...int64,
 ) (*model.Channel, error) {
@@ -523,7 +583,7 @@ func getRandomChannel(
 		return nil, ErrChannelsNotFound
 	}
 
-	channels = filterChannels(channels, ignoreChannel...)
+	channels = filterChannels(channels, mode, ignoreChannel...)
 	if len(channels) == 0 {
 		return nil, ErrChannelsExhausted
 	}
@@ -559,6 +619,7 @@ func getChannelWithFallback(
 	cache *model.ModelCaches,
 	availableSet []string,
 	modelName string,
+	mode mode.Mode,
 	errorRates map[int64]float64,
 	ignoreChannelIDs ...int64,
 ) (*model.Channel, []*model.Channel, error) {
@@ -566,6 +627,7 @@ func getChannelWithFallback(
 		cache,
 		availableSet,
 		modelName,
+		mode,
 		errorRates,
 		ignoreChannelIDs...)
 	if err == nil {
@@ -574,7 +636,13 @@ func getChannelWithFallback(
 	if !errors.Is(err, ErrChannelsExhausted) {
 		return nil, migratedChannels, err
 	}
-	channel, migratedChannels, err = GetRandomChannel(cache, availableSet, modelName, errorRates)
+	channel, migratedChannels, err = GetRandomChannel(
+		cache,
+		availableSet,
+		modelName,
+		mode,
+		errorRates,
+	)
 	return channel, migratedChannels, err
 }
 
@@ -776,6 +844,7 @@ func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialC
 			middleware.GetModelCaches(c),
 			availableSet,
 			modelName,
+			m,
 		)
 		if err != nil {
 			return nil, err
@@ -815,6 +884,7 @@ func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialC
 		mc,
 		availableSet,
 		modelName,
+		m,
 		errorRates,
 		ids...)
 	if err != nil {
@@ -844,7 +914,13 @@ func getWebSearchChannel(c *gin.Context, modelName string) (*model.Channel, erro
 		log.Errorf("get channel model error rates failed: %+v", err)
 	}
 
-	channel, _, err := getChannelWithFallback(mc, nil, modelName, errorRates, ids...)
+	channel, _, err := getChannelWithFallback(
+		mc,
+		nil,
+		modelName,
+		mode.ChatCompletions,
+		errorRates,
+		ids...)
 	if err != nil {
 		return nil, err
 	}
@@ -1006,6 +1082,7 @@ func getRetryChannel(state *retryState) (*model.Channel, error) {
 
 	newChannel, err := getRandomChannel(
 		state.migratedChannels,
+		state.meta.Mode,
 		state.errorRates,
 		state.ignoreChannelIDs...)
 	if err != nil {

+ 11 - 0
core/relay/adaptor/ali/adaptor.go

@@ -29,6 +29,17 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.ChatCompletions ||
+		m == mode.Completions ||
+		m == mode.Embeddings ||
+		m == mode.ImagesGenerations ||
+		m == mode.Rerank ||
+		m == mode.AudioSpeech ||
+		m == mode.AudioTranscription ||
+		m == mode.AudioTranslation
+}
+
 func (a *Adaptor) GetRequestURL(meta *meta.Meta, _ adaptor.Store) (adaptor.RequestURL, error) {
 	u := meta.Channel.BaseURL
 	if u == "" {

+ 4 - 0
core/relay/adaptor/anthropic/adaptor.go

@@ -25,6 +25,10 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.ChatCompletions || m == mode.Anthropic
+}
+
 func (a *Adaptor) GetRequestURL(meta *meta.Meta, _ adaptor.Store) (adaptor.RequestURL, error) {
 	return adaptor.RequestURL{
 		Method: http.MethodPost,

+ 5 - 0
core/relay/adaptor/aws/adaptor.go

@@ -10,6 +10,7 @@ import (
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/aws/utils"
 	"github.com/labring/aiproxy/core/relay/meta"
+	"github.com/labring/aiproxy/core/relay/mode"
 	relaymodel "github.com/labring/aiproxy/core/relay/model"
 )
 
@@ -19,6 +20,10 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return ""
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.ChatCompletions || m == mode.Completions
+}
+
 func (a *Adaptor) ConvertRequest(
 	meta *meta.Meta,
 	store adaptor.Store,

+ 7 - 0
core/relay/adaptor/baidu/adaptor.go

@@ -26,6 +26,13 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.ChatCompletions ||
+		m == mode.Embeddings ||
+		m == mode.Rerank ||
+		m == mode.ImagesGenerations
+}
+
 // Get model-specific endpoint using map
 var modelEndpointMap = map[string]string{
 	"ERNIE-4.0-8K":         "completions_pro",

+ 4 - 0
core/relay/adaptor/baiduv2/adaptor.go

@@ -26,6 +26,10 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.ChatCompletions || m == mode.Rerank
+}
+
 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Fm2vrveyu
 var v2ModelMap = map[string]string{
 	"ERNIE-Character-8K":         "ernie-char-8k",

+ 8 - 10
core/relay/adaptor/cohere/adaptor.go

@@ -10,7 +10,6 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
-	"github.com/labring/aiproxy/core/relay/adaptor/openai"
 	"github.com/labring/aiproxy/core/relay/meta"
 	"github.com/labring/aiproxy/core/relay/mode"
 	"github.com/labring/aiproxy/core/relay/utils"
@@ -24,6 +23,10 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.ChatCompletions
+}
+
 func (a *Adaptor) GetRequestURL(meta *meta.Meta, _ adaptor.Store) (adaptor.RequestURL, error) {
 	return adaptor.RequestURL{
 		Method: http.MethodPost,
@@ -83,15 +86,10 @@ func (a *Adaptor) DoResponse(
 	c *gin.Context,
 	resp *http.Response,
 ) (usage model.Usage, err adaptor.Error) {
-	switch meta.Mode {
-	case mode.Rerank:
-		usage, err = openai.RerankHandler(meta, c, resp)
-	default:
-		if utils.IsStreamResponse(resp) {
-			usage, err = StreamHandler(meta, c, resp)
-		} else {
-			usage, err = Handler(meta, c, resp)
-		}
+	if utils.IsStreamResponse(resp) {
+		usage, err = StreamHandler(meta, c, resp)
+	} else {
+		usage, err = Handler(meta, c, resp)
 	}
 	return
 }

+ 4 - 0
core/relay/adaptor/coze/adaptor.go

@@ -24,6 +24,10 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.ChatCompletions
+}
+
 func (a *Adaptor) GetRequestURL(meta *meta.Meta, _ adaptor.Store) (adaptor.RequestURL, error) {
 	return adaptor.RequestURL{
 		Method: http.MethodPost,

+ 4 - 0
core/relay/adaptor/doc2x/adaptor.go

@@ -23,6 +23,10 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.ParsePdf
+}
+
 func (a *Adaptor) GetRequestURL(meta *meta.Meta, _ adaptor.Store) (adaptor.RequestURL, error) {
 	switch meta.Mode {
 	case mode.ParsePdf:

+ 4 - 0
core/relay/adaptor/doubaoaudio/main.go

@@ -33,6 +33,10 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.AudioSpeech
+}
+
 func (a *Adaptor) Metadata() adaptor.Metadata {
 	return adaptor.Metadata{
 		Features: []string{

+ 4 - 0
core/relay/adaptor/gemini/adaptor.go

@@ -21,6 +21,10 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.ChatCompletions || m == mode.Embeddings
+}
+
 var v1ModelMap = map[string]struct{}{}
 
 func getRequestURL(meta *meta.Meta, action string) adaptor.RequestURL {

+ 2 - 0
core/relay/adaptor/interface.go

@@ -10,6 +10,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/meta"
+	"github.com/labring/aiproxy/core/relay/mode"
 )
 
 type StoreCache struct {
@@ -70,6 +71,7 @@ type DoResponse interface {
 
 type Adaptor interface {
 	Metadata() Metadata
+	SupportMode(mode mode.Mode) bool
 	DefaultBaseURL() string
 	GetRequestURL
 	SetupRequestHeader

+ 4 - 0
core/relay/adaptor/ollama/adaptor.go

@@ -22,6 +22,10 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.Embeddings || m == mode.ChatCompletions || m == mode.Completions
+}
+
 func (a *Adaptor) GetRequestURL(meta *meta.Meta, _ adaptor.Store) (adaptor.RequestURL, error) {
 	// https://github.com/ollama/ollama/blob/main/docs/api.md
 	u := meta.Channel.BaseURL

+ 17 - 0
core/relay/adaptor/openai/adaptor.go

@@ -24,6 +24,23 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.ChatCompletions ||
+		m == mode.Completions ||
+		m == mode.Embeddings ||
+		m == mode.Moderations ||
+		m == mode.ImagesGenerations ||
+		m == mode.ImagesEdits ||
+		m == mode.AudioSpeech ||
+		m == mode.AudioTranscription ||
+		m == mode.AudioTranslation ||
+		m == mode.Rerank ||
+		m == mode.ParsePdf ||
+		m == mode.VideoGenerationsJobs ||
+		m == mode.VideoGenerationsGetJobs ||
+		m == mode.VideoGenerationsContent
+}
+
 func (a *Adaptor) GetRequestURL(meta *meta.Meta, _ adaptor.Store) (adaptor.RequestURL, error) {
 	u := meta.Channel.BaseURL
 

+ 4 - 0
core/relay/adaptor/text-embeddings-inference/adaptor.go

@@ -25,6 +25,10 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.Rerank || m == mode.Embeddings
+}
+
 func (a *Adaptor) Metadata() adaptor.Metadata {
 	return adaptor.Metadata{
 		Features: []string{

+ 5 - 0
core/relay/adaptor/vertexai/adaptor.go

@@ -11,6 +11,7 @@ import (
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
+	"github.com/labring/aiproxy/core/relay/mode"
 	relaymodel "github.com/labring/aiproxy/core/relay/model"
 	"github.com/labring/aiproxy/core/relay/utils"
 )
@@ -21,6 +22,10 @@ func (a *Adaptor) DefaultBaseURL() string {
 	return ""
 }
 
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.ChatCompletions || m == mode.Anthropic
+}
+
 type Config struct {
 	Region    string
 	ProjectID string