Преглед изворни кода

feat: support o1 channel test

CalciumIon пре 1 година
родитељ
комит
cb73889353
1 измењених фајлова са 17 додато и 16 уклоњено
  1. 17 16
      controller/channel-test.go

+ 17 - 16
controller/channel-test.go

@@ -20,6 +20,7 @@ import (
 	"one-api/relay/constant"
 	"one-api/service"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
 
@@ -81,8 +82,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
 	}
 
-	request := buildTestRequest()
-	request.Model = testModel
+	request := buildTestRequest(testModel)
 	meta.UpstreamModelName = testModel
 	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
 
@@ -141,17 +141,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	return nil, nil
 }
 
-func buildTestRequest() *dto.GeneralOpenAIRequest {
+func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
 	testRequest := &dto.GeneralOpenAIRequest{
-		Model:     "", // this will be set later
-		MaxTokens: 1,
-		Stream:    false,
+		Model:  "", // this will be set later
+		Stream: false,
+	}
+	if strings.HasPrefix(model, "o1-") {
+		testRequest.MaxCompletionTokens = 1
+	} else {
+		testRequest.MaxTokens = 1
 	}
 	content, _ := json.Marshal("hi")
 	testMessage := dto.Message{
 		Role:    "user",
 		Content: content,
 	}
+	testRequest.Model = model
 	testRequest.Messages = append(testRequest.Messages, testMessage)
 	return testRequest
 }
@@ -226,26 +231,22 @@ func testAllChannels(notify bool) error {
 			tok := time.Now()
 			milliseconds := tok.Sub(tik).Milliseconds()
 
-			ban := false
-			if milliseconds > disableThreshold {
-				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
-				ban = true
-			}
+			shouldBanChannel := false
 
 			// request error disables the channel
 			if openaiWithStatusErr != nil {
 				oaiErr := openaiWithStatusErr.Error
 				err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
-				ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
+				shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
 			}
 
-			// parse *int to bool
-			if !channel.GetAutoBan() {
-				ban = false
+			if milliseconds > disableThreshold {
+				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+				shouldBanChannel = true
 			}
 
 			// disable channel
-			if ban && isChannelEnabled {
+			if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
 				service.DisableChannel(channel.Id, channel.Name, err.Error())
 			}