Просмотр исходного кода

✨ feat: Enhance model listing and retrieval with support for Anthropic and Gemini models; refactor routes for better API key handling

nekohy 4 месяцев назад
Родитель
Сommit
fdb6a3ce16
4 измененных файлов с 120 добавлено и 18 удалено
  1. 49 7
      controller/model.go
  2. 24 0
      dto/pricing.go
  3. 9 7
      middleware/auth.go
  4. 38 4
      router/relay-router.go

+ 49 - 7
controller/model.go

@@ -16,6 +16,7 @@ import (
 	"one-api/relay/channel/moonshot"
 	relaycommon "one-api/relay/common"
 	"one-api/setting"
+	"time"
 )
 
 // https://platform.openai.com/docs/api-reference/models/list
@@ -102,7 +103,7 @@ func init() {
 	})
 }
 
-func ListModels(c *gin.Context) {
+func ListModels(c *gin.Context, modelType int) {
 	userOpenAiModels := make([]dto.OpenAIModels, 0)
 
 	modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
@@ -171,10 +172,41 @@ func ListModels(c *gin.Context) {
 			}
 		}
 	}
-	c.JSON(200, gin.H{
-		"success": true,
-		"data":    userOpenAiModels,
-	})
+	switch modelType {
+	case constant.ChannelTypeAnthropic:
+		useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
+		for i, model := range userOpenAiModels {
+			useranthropicModels[i] = dto.AnthropicModel{
+				ID:          model.Id,
+				CreatedAt:   time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339),
+				DisplayName: model.Id,
+				Type:        "model",
+			}
+		}
+		c.JSON(200, gin.H{
+			"data":     useranthropicModels,
+			"first_id": useranthropicModels[0].ID,
+			"has_more": false,
+			"last_id":  useranthropicModels[len(useranthropicModels)-1].ID,
+		})
+	case constant.ChannelTypeGemini:
+		userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
+		for i, model := range userOpenAiModels {
+			userGeminiModels[i] = dto.GeminiModel{
+				Name:        model.Id,
+				DisplayName: model.Id,
+			}
+		}
+		c.JSON(200, gin.H{
+			"models":        userGeminiModels,
+			"nextPageToken": nil,
+		})
+	default:
+		c.JSON(200, gin.H{
+			"success": true,
+			"data":    userOpenAiModels,
+		})
+	}
 }
 
 func ChannelListModels(c *gin.Context) {
@@ -198,10 +230,20 @@ func EnabledListModels(c *gin.Context) {
 	})
 }
 
-func RetrieveModel(c *gin.Context) {
+func RetrieveModel(c *gin.Context, modelType int) {
 	modelId := c.Param("model")
 	if aiModel, ok := openAIModelsMap[modelId]; ok {
-		c.JSON(200, aiModel)
+		switch modelType {
+		case constant.ChannelTypeAnthropic:
+			c.JSON(200, dto.AnthropicModel{
+				ID:          aiModel.Id,
+				CreatedAt:   time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339),
+				DisplayName: aiModel.Id,
+				Type:        "model",
+			})
+		default:
+			c.JSON(200, aiModel)
+		}
 	} else {
 		openAIError := dto.OpenAIError{
 			Message: fmt.Sprintf("The model '%s' does not exist", modelId),

+ 24 - 0
dto/pricing.go

@@ -2,6 +2,7 @@ package dto
 
 import "one-api/constant"
 
+// 这里不好动就不动了,本来想独立出来的(
 type OpenAIModels struct {
 	Id                     string                  `json:"id"`
 	Object                 string                  `json:"object"`
@@ -9,3 +10,26 @@ type OpenAIModels struct {
 	OwnedBy                string                  `json:"owned_by"`
 	SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
 }
+
+type AnthropicModel struct {
+	ID          string `json:"id"`
+	CreatedAt   string `json:"created_at"`
+	DisplayName string `json:"display_name"`
+	Type        string `json:"type"`
+}
+
+type GeminiModel struct {
+	Name                       interface{}   `json:"name"`
+	BaseModelId                interface{}   `json:"baseModelId"`
+	Version                    interface{}   `json:"version"`
+	DisplayName                interface{}   `json:"displayName"`
+	Description                interface{}   `json:"description"`
+	InputTokenLimit            interface{}   `json:"inputTokenLimit"`
+	OutputTokenLimit           interface{}   `json:"outputTokenLimit"`
+	SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"`
+	Thinking                   interface{}   `json:"thinking"`
+	Temperature                interface{}   `json:"temperature"`
+	MaxTemperature             interface{}   `json:"maxTemperature"`
+	TopP                       interface{}   `json:"topP"`
+	TopK                       interface{}   `json:"topK"`
+}

+ 9 - 7
middleware/auth.go

@@ -192,16 +192,18 @@ func TokenAuth() func(c *gin.Context) {
 			}
 			c.Request.Header.Set("Authorization", "Bearer "+key)
 		}
+		anthropicKey := c.Request.Header.Get("x-api-key")
 		// 检查path包含/v1/messages
-		if strings.Contains(c.Request.URL.Path, "/v1/messages") {
-			// 从x-api-key中获取key
-			key := c.Request.Header.Get("x-api-key")
-			if key != "" {
-				c.Request.Header.Set("Authorization", "Bearer "+key)
-			}
+		// 或者是否 x-api-key 不为空且存在anthropic-version
+		// 谁知道有多少不符合规范没写anthropic-version的
+		// 所以就这样随它去吧(
+		if strings.Contains(c.Request.URL.Path, "/v1/messages") || (anthropicKey != "" && c.Request.Header.Get("anthropic-version") != "") {
+			c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
 		}
 		// gemini api 从query中获取key
-		if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
+		if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models") ||
+			strings.HasPrefix(c.Request.URL.Path, "/v1beta/openai/models") ||
+			strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
 			skKey := c.Query("key")
 			if skKey != "" {
 				c.Request.Header.Set("Authorization", "Bearer "+skKey)

+ 38 - 4
router/relay-router.go

@@ -1,11 +1,11 @@
 package router
 
 import (
+	"github.com/gin-gonic/gin"
+	"one-api/constant"
 	"one-api/controller"
 	"one-api/middleware"
 	"one-api/relay"
-
-	"github.com/gin-gonic/gin"
 )
 
 func SetRelayRouter(router *gin.Engine) {
@@ -16,9 +16,43 @@ func SetRelayRouter(router *gin.Engine) {
 	modelsRouter := router.Group("/v1/models")
 	modelsRouter.Use(middleware.TokenAuth())
 	{
-		modelsRouter.GET("", controller.ListModels)
-		modelsRouter.GET("/:model", controller.RetrieveModel)
+		modelsRouter.GET("", func(c *gin.Context) {
+			switch {
+			case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "":
+				controller.ListModels(c, constant.ChannelTypeAnthropic)
+			case c.GetHeader("x-goog-api-key") != "" || c.Query("key") != "": // 单独的适配
+				controller.RetrieveModel(c, constant.ChannelTypeGemini)
+			default:
+				controller.ListModels(c, constant.ChannelTypeOpenAI)
+			}
+		})
+
+		modelsRouter.GET("/:model", func(c *gin.Context) {
+			switch {
+			case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "":
+				controller.RetrieveModel(c, constant.ChannelTypeAnthropic)
+			default:
+				controller.RetrieveModel(c, constant.ChannelTypeOpenAI)
+			}
+		})
 	}
+
+	geminiRouter := router.Group("/v1beta/models")
+	geminiRouter.Use(middleware.TokenAuth())
+	{
+		geminiRouter.GET("", func(c *gin.Context) {
+			controller.ListModels(c, constant.ChannelTypeGemini)
+		})
+	}
+
+	geminiCompatibleRouter := router.Group("/v1beta/openai/models")
+	geminiCompatibleRouter.Use(middleware.TokenAuth())
+	{
+		geminiCompatibleRouter.GET("", func(c *gin.Context) {
+			controller.ListModels(c, constant.ChannelTypeOpenAI)
+		})
+	}
+
 	playgroundRouter := router.Group("/pg")
 	playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
 	{