|
|
@@ -20,6 +20,7 @@ import (
|
|
|
"one-api/relay/constant"
|
|
|
"one-api/service"
|
|
|
"strconv"
|
|
|
+ "strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
|
@@ -81,8 +82,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|
|
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
|
|
}
|
|
|
|
|
|
- request := buildTestRequest()
|
|
|
- request.Model = testModel
|
|
|
+ request := buildTestRequest(testModel)
|
|
|
meta.UpstreamModelName = testModel
|
|
|
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
|
|
|
|
|
|
@@ -141,17 +141,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
-func buildTestRequest() *dto.GeneralOpenAIRequest {
|
|
|
+func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
|
|
testRequest := &dto.GeneralOpenAIRequest{
|
|
|
- Model: "", // this will be set later
|
|
|
- MaxTokens: 1,
|
|
|
- Stream: false,
|
|
|
+ Model: "", // this will be set later
|
|
|
+ Stream: false,
|
|
|
+ }
|
|
|
+ if strings.HasPrefix(model, "o1-") {
|
|
|
+ testRequest.MaxCompletionTokens = 1
|
|
|
+ } else {
|
|
|
+ testRequest.MaxTokens = 1
|
|
|
}
|
|
|
content, _ := json.Marshal("hi")
|
|
|
testMessage := dto.Message{
|
|
|
Role: "user",
|
|
|
Content: content,
|
|
|
}
|
|
|
+ testRequest.Model = model
|
|
|
testRequest.Messages = append(testRequest.Messages, testMessage)
|
|
|
return testRequest
|
|
|
}
|
|
|
@@ -226,26 +231,22 @@ func testAllChannels(notify bool) error {
|
|
|
tok := time.Now()
|
|
|
milliseconds := tok.Sub(tik).Milliseconds()
|
|
|
|
|
|
- ban := false
|
|
|
- if milliseconds > disableThreshold {
|
|
|
- err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
|
|
- ban = true
|
|
|
- }
|
|
|
+ shouldBanChannel := false
|
|
|
|
|
|
// request error disables the channel
|
|
|
if openaiWithStatusErr != nil {
|
|
|
oaiErr := openaiWithStatusErr.Error
|
|
|
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
|
|
|
- ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
|
|
|
+ shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
|
|
|
}
|
|
|
|
|
|
- // parse *int to bool
|
|
|
- if !channel.GetAutoBan() {
|
|
|
- ban = false
|
|
|
+ if milliseconds > disableThreshold {
|
|
|
+ err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
|
|
+ shouldBanChannel = true
|
|
|
}
|
|
|
|
|
|
// disable channel
|
|
|
- if ban && isChannelEnabled {
|
|
|
+ if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
|
|
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
|
|
}
|
|
|
|