channel-test.go 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math"
  9. "net/http"
  10. "net/http/httptest"
  11. "net/url"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "time"
  16. "github.com/QuantumNous/new-api/common"
  17. "github.com/QuantumNous/new-api/constant"
  18. "github.com/QuantumNous/new-api/dto"
  19. "github.com/QuantumNous/new-api/middleware"
  20. "github.com/QuantumNous/new-api/model"
  21. "github.com/QuantumNous/new-api/relay"
  22. relaycommon "github.com/QuantumNous/new-api/relay/common"
  23. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  24. "github.com/QuantumNous/new-api/relay/helper"
  25. "github.com/QuantumNous/new-api/service"
  26. "github.com/QuantumNous/new-api/setting/operation_setting"
  27. "github.com/QuantumNous/new-api/types"
  28. "github.com/bytedance/gopkg/util/gopool"
  29. "github.com/samber/lo"
  30. "github.com/gin-gonic/gin"
  31. )
  32. type testResult struct {
  33. context *gin.Context
  34. localErr error
  35. newAPIError *types.NewAPIError
  36. }
  37. // testChannel executes a test request against the given channel using the provided testModel and optional endpointType,
  38. // and returns a testResult containing the test context and any encountered error information.
  39. // It selects or derives a model when testModel is empty, auto-detects the request endpoint (chat, responses, embeddings, images, rerank) when endpointType is not specified,
  40. // converts and relays the request to the upstream adapter, and parses the upstream response to collect usage and pricing information.
  41. // On upstream responses that indicate the chat/completions `messages` parameter is unsupported and endpointType was not specified, it will retry the test using the Responses API.
  42. // The function records consumption logs and returns a testResult with a populated context on success, or with localErr/newAPIError set on failure;
  43. // for channel types that are not supported for testing it returns a localErr explaining that the channel test is not supported.
  44. func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
  45. tik := time.Now()
  46. var unsupportedTestChannelTypes = []int{
  47. constant.ChannelTypeMidjourney,
  48. constant.ChannelTypeMidjourneyPlus,
  49. constant.ChannelTypeSunoAPI,
  50. constant.ChannelTypeKling,
  51. constant.ChannelTypeJimeng,
  52. constant.ChannelTypeDoubaoVideo,
  53. constant.ChannelTypeVidu,
  54. }
  55. if lo.Contains(unsupportedTestChannelTypes, channel.Type) {
  56. channelTypeName := constant.GetChannelTypeName(channel.Type)
  57. return testResult{
  58. localErr: fmt.Errorf("%s channel test is not supported", channelTypeName),
  59. }
  60. }
  61. w := httptest.NewRecorder()
  62. c, _ := gin.CreateTestContext(w)
  63. testModel = strings.TrimSpace(testModel)
  64. if testModel == "" {
  65. if channel.TestModel != nil && *channel.TestModel != "" {
  66. testModel = strings.TrimSpace(*channel.TestModel)
  67. } else {
  68. models := channel.GetModels()
  69. if len(models) > 0 {
  70. testModel = strings.TrimSpace(models[0])
  71. }
  72. if testModel == "" {
  73. testModel = "gpt-4o-mini"
  74. }
  75. }
  76. }
  77. originTestModel := testModel
  78. requestPath := "/v1/chat/completions"
  79. // 如果指定了端点类型,使用指定的端点类型
  80. if endpointType != "" {
  81. if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok {
  82. requestPath = endpointInfo.Path
  83. }
  84. } else {
  85. // 如果没有指定端点类型,使用原有的自动检测逻辑
  86. if common.IsOpenAIResponseOnlyModel(testModel) {
  87. requestPath = "/v1/responses"
  88. }
  89. // 先判断是否为 Embedding 模型
  90. if strings.Contains(strings.ToLower(testModel), "embedding") ||
  91. strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
  92. strings.Contains(testModel, "bge-") || // bge 系列模型
  93. strings.Contains(testModel, "embed") ||
  94. channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
  95. requestPath = "/v1/embeddings" // 修改请求路径
  96. }
  97. // VolcEngine 图像生成模型
  98. if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
  99. requestPath = "/v1/images/generations"
  100. }
  101. }
  102. c.Request = &http.Request{
  103. Method: "POST",
  104. URL: &url.URL{Path: requestPath}, // 使用动态路径
  105. Body: nil,
  106. Header: make(http.Header),
  107. }
  108. cache, err := model.GetUserCache(1)
  109. if err != nil {
  110. return testResult{
  111. localErr: err,
  112. newAPIError: nil,
  113. }
  114. }
  115. cache.WriteContext(c)
  116. //c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
  117. c.Request.Header.Set("Content-Type", "application/json")
  118. c.Set("channel", channel.Type)
  119. c.Set("base_url", channel.GetBaseURL())
  120. group, _ := model.GetUserGroup(1, false)
  121. c.Set("group", group)
  122. newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
  123. if newAPIError != nil {
  124. return testResult{
  125. context: c,
  126. localErr: newAPIError,
  127. newAPIError: newAPIError,
  128. }
  129. }
  130. // Determine relay format based on endpoint type or request path
  131. var relayFormat types.RelayFormat
  132. if endpointType != "" {
  133. // 根据指定的端点类型设置 relayFormat
  134. switch constant.EndpointType(endpointType) {
  135. case constant.EndpointTypeOpenAI:
  136. relayFormat = types.RelayFormatOpenAI
  137. case constant.EndpointTypeOpenAIResponse:
  138. relayFormat = types.RelayFormatOpenAIResponses
  139. case constant.EndpointTypeAnthropic:
  140. relayFormat = types.RelayFormatClaude
  141. case constant.EndpointTypeGemini:
  142. relayFormat = types.RelayFormatGemini
  143. case constant.EndpointTypeJinaRerank:
  144. relayFormat = types.RelayFormatRerank
  145. case constant.EndpointTypeImageGeneration:
  146. relayFormat = types.RelayFormatOpenAIImage
  147. case constant.EndpointTypeEmbeddings:
  148. relayFormat = types.RelayFormatEmbedding
  149. default:
  150. relayFormat = types.RelayFormatOpenAI
  151. }
  152. } else {
  153. // 根据请求路径自动检测
  154. relayFormat = types.RelayFormatOpenAI
  155. if c.Request.URL.Path == "/v1/embeddings" {
  156. relayFormat = types.RelayFormatEmbedding
  157. }
  158. if c.Request.URL.Path == "/v1/images/generations" {
  159. relayFormat = types.RelayFormatOpenAIImage
  160. }
  161. if c.Request.URL.Path == "/v1/messages" {
  162. relayFormat = types.RelayFormatClaude
  163. }
  164. if strings.Contains(c.Request.URL.Path, "/v1beta/models") {
  165. relayFormat = types.RelayFormatGemini
  166. }
  167. if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" {
  168. relayFormat = types.RelayFormatRerank
  169. }
  170. if c.Request.URL.Path == "/v1/responses" {
  171. relayFormat = types.RelayFormatOpenAIResponses
  172. }
  173. }
  174. request := buildTestRequest(testModel, endpointType)
  175. info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
  176. if err != nil {
  177. return testResult{
  178. context: c,
  179. localErr: err,
  180. newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
  181. }
  182. }
  183. info.InitChannelMeta(c)
  184. err = helper.ModelMappedHelper(c, info, request)
  185. if err != nil {
  186. return testResult{
  187. context: c,
  188. localErr: err,
  189. newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
  190. }
  191. }
  192. testModel = info.UpstreamModelName
  193. // 更新请求中的模型名称
  194. request.SetModelName(testModel)
  195. apiType, _ := common.ChannelType2APIType(channel.Type)
  196. adaptor := relay.GetAdaptor(apiType)
  197. if adaptor == nil {
  198. return testResult{
  199. context: c,
  200. localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
  201. newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
  202. }
  203. }
  204. //// 创建一个用于日志的 info 副本,移除 ApiKey
  205. //logInfo := info
  206. //logInfo.ApiKey = ""
  207. common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
  208. priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
  209. if err != nil {
  210. return testResult{
  211. context: c,
  212. localErr: err,
  213. newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
  214. }
  215. }
  216. adaptor.Init(info)
  217. var convertedRequest any
  218. // 根据 RelayMode 选择正确的转换函数
  219. switch info.RelayMode {
  220. case relayconstant.RelayModeEmbeddings:
  221. // Embedding 请求 - request 已经是正确的类型
  222. if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok {
  223. convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq)
  224. } else {
  225. return testResult{
  226. context: c,
  227. localErr: errors.New("invalid embedding request type"),
  228. newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed),
  229. }
  230. }
  231. case relayconstant.RelayModeImagesGenerations:
  232. // 图像生成请求 - request 已经是正确的类型
  233. if imageReq, ok := request.(*dto.ImageRequest); ok {
  234. convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq)
  235. } else {
  236. return testResult{
  237. context: c,
  238. localErr: errors.New("invalid image request type"),
  239. newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed),
  240. }
  241. }
  242. case relayconstant.RelayModeRerank:
  243. // Rerank 请求 - request 已经是正确的类型
  244. if rerankReq, ok := request.(*dto.RerankRequest); ok {
  245. convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq)
  246. } else {
  247. return testResult{
  248. context: c,
  249. localErr: errors.New("invalid rerank request type"),
  250. newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed),
  251. }
  252. }
  253. case relayconstant.RelayModeResponses:
  254. // Response 请求 - request 已经是正确的类型
  255. if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok {
  256. convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq)
  257. } else {
  258. return testResult{
  259. context: c,
  260. localErr: errors.New("invalid response request type"),
  261. newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed),
  262. }
  263. }
  264. default:
  265. // Chat/Completion 等其他请求类型
  266. if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok {
  267. convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq)
  268. } else {
  269. return testResult{
  270. context: c,
  271. localErr: errors.New("invalid general request type"),
  272. newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed),
  273. }
  274. }
  275. }
  276. if err != nil {
  277. return testResult{
  278. context: c,
  279. localErr: err,
  280. newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
  281. }
  282. }
  283. jsonData, err := json.Marshal(convertedRequest)
  284. if err != nil {
  285. return testResult{
  286. context: c,
  287. localErr: err,
  288. newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
  289. }
  290. }
  291. requestBody := bytes.NewBuffer(jsonData)
  292. c.Request.Body = io.NopCloser(requestBody)
  293. resp, err := adaptor.DoRequest(c, info, requestBody)
  294. if err != nil {
  295. return testResult{
  296. context: c,
  297. localErr: err,
  298. newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
  299. }
  300. }
  301. var httpResp *http.Response
  302. if resp != nil {
  303. httpResp = resp.(*http.Response)
  304. if httpResp.StatusCode != http.StatusOK {
  305. err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
  306. // 自动检测模式下,如果上游不支持 chat.completions 的 messages 参数,尝试切换到 Responses API 再测一次。
  307. if endpointType == "" && requestPath == "/v1/chat/completions" && err != nil {
  308. lowerErr := strings.ToLower(err.Error())
  309. if strings.Contains(lowerErr, "unsupported parameter") && strings.Contains(lowerErr, "messages") {
  310. return testChannel(channel, originTestModel, string(constant.EndpointTypeOpenAIResponse))
  311. }
  312. }
  313. return testResult{
  314. context: c,
  315. localErr: err,
  316. newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
  317. }
  318. }
  319. }
  320. usageA, respErr := adaptor.DoResponse(c, httpResp, info)
  321. if respErr != nil {
  322. return testResult{
  323. context: c,
  324. localErr: respErr,
  325. newAPIError: respErr,
  326. }
  327. }
  328. if usageA == nil {
  329. return testResult{
  330. context: c,
  331. localErr: errors.New("usage is nil"),
  332. newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
  333. }
  334. }
  335. usage := usageA.(*dto.Usage)
  336. result := w.Result()
  337. respBody, err := io.ReadAll(result.Body)
  338. if err != nil {
  339. return testResult{
  340. context: c,
  341. localErr: err,
  342. newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
  343. }
  344. }
  345. info.SetEstimatePromptTokens(usage.PromptTokens)
  346. quota := 0
  347. if !priceData.UsePrice {
  348. quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
  349. quota = int(math.Round(float64(quota) * priceData.ModelRatio))
  350. if priceData.ModelRatio != 0 && quota <= 0 {
  351. quota = 1
  352. }
  353. } else {
  354. quota = int(priceData.ModelPrice * common.QuotaPerUnit)
  355. }
  356. tok := time.Now()
  357. milliseconds := tok.Sub(tik).Milliseconds()
  358. consumedTime := float64(milliseconds) / 1000.0
  359. other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
  360. usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
  361. model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
  362. ChannelId: channel.Id,
  363. PromptTokens: usage.PromptTokens,
  364. CompletionTokens: usage.CompletionTokens,
  365. ModelName: info.OriginModelName,
  366. TokenName: "模型测试",
  367. Quota: quota,
  368. Content: "模型测试",
  369. UseTimeSeconds: int(consumedTime),
  370. IsStream: info.IsStream,
  371. Group: info.UsingGroup,
  372. Other: other,
  373. })
  374. common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
  375. return testResult{
  376. context: c,
  377. localErr: nil,
  378. newAPIError: nil,
  379. }
  380. }
  381. // for embedding models, and otherwise a chat/completion request with model-specific token limit heuristics.
  382. func buildTestRequest(model string, endpointType string) dto.Request {
  383. // 根据端点类型构建不同的测试请求
  384. if endpointType != "" {
  385. switch constant.EndpointType(endpointType) {
  386. case constant.EndpointTypeEmbeddings:
  387. // 返回 EmbeddingRequest
  388. return &dto.EmbeddingRequest{
  389. Model: model,
  390. Input: []any{"hello world"},
  391. }
  392. case constant.EndpointTypeImageGeneration:
  393. // 返回 ImageRequest
  394. return &dto.ImageRequest{
  395. Model: model,
  396. Prompt: "a cute cat",
  397. N: 1,
  398. Size: "1024x1024",
  399. }
  400. case constant.EndpointTypeJinaRerank:
  401. // 返回 RerankRequest
  402. return &dto.RerankRequest{
  403. Model: model,
  404. Query: "What is Deep Learning?",
  405. Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
  406. TopN: 2,
  407. }
  408. case constant.EndpointTypeOpenAIResponse:
  409. // 返回 OpenAIResponsesRequest
  410. maxOutputTokens := uint(10)
  411. return &dto.OpenAIResponsesRequest{
  412. Model: model,
  413. Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
  414. MaxOutputTokens: maxOutputTokens,
  415. Stream: true,
  416. }
  417. case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
  418. // 返回 GeneralOpenAIRequest
  419. maxTokens := uint(10)
  420. if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
  421. maxTokens = 3000
  422. }
  423. return &dto.GeneralOpenAIRequest{
  424. Model: model,
  425. Stream: false,
  426. Messages: []dto.Message{
  427. {
  428. Role: "user",
  429. Content: "hi",
  430. },
  431. },
  432. MaxTokens: maxTokens,
  433. }
  434. }
  435. }
  436. // 自动检测逻辑(保持原有行为)
  437. if common.IsOpenAIResponseOnlyModel(model) {
  438. maxOutputTokens := uint(10)
  439. return &dto.OpenAIResponsesRequest{
  440. Model: model,
  441. Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
  442. MaxOutputTokens: maxOutputTokens,
  443. Stream: true,
  444. }
  445. }
  446. // 先判断是否为 Embedding 模型
  447. if strings.Contains(strings.ToLower(model), "embedding") ||
  448. strings.HasPrefix(model, "m3e") ||
  449. strings.Contains(model, "bge-") {
  450. // 返回 EmbeddingRequest
  451. return &dto.EmbeddingRequest{
  452. Model: model,
  453. Input: []any{"hello world"},
  454. }
  455. }
  456. // Chat/Completion 请求 - 返回 GeneralOpenAIRequest
  457. testRequest := &dto.GeneralOpenAIRequest{
  458. Model: model,
  459. Stream: false,
  460. Messages: []dto.Message{
  461. {
  462. Role: "user",
  463. Content: "hi",
  464. },
  465. },
  466. }
  467. if strings.HasPrefix(model, "o") {
  468. testRequest.MaxCompletionTokens = 10
  469. } else if strings.Contains(model, "thinking") {
  470. if !strings.Contains(model, "claude") {
  471. testRequest.MaxTokens = 50
  472. }
  473. } else if strings.Contains(model, "gemini") {
  474. testRequest.MaxTokens = 3000
  475. } else {
  476. testRequest.MaxTokens = 10
  477. }
  478. return testRequest
  479. }
  480. func TestChannel(c *gin.Context) {
  481. channelId, err := strconv.Atoi(c.Param("id"))
  482. if err != nil {
  483. common.ApiError(c, err)
  484. return
  485. }
  486. channel, err := model.CacheGetChannel(channelId)
  487. if err != nil {
  488. channel, err = model.GetChannelById(channelId, true)
  489. if err != nil {
  490. common.ApiError(c, err)
  491. return
  492. }
  493. }
  494. //defer func() {
  495. // if channel.ChannelInfo.IsMultiKey {
  496. // go func() { _ = channel.SaveChannelInfo() }()
  497. // }
  498. //}()
  499. testModel := c.Query("model")
  500. endpointType := c.Query("endpoint_type")
  501. tik := time.Now()
  502. result := testChannel(channel, testModel, endpointType)
  503. if result.localErr != nil {
  504. c.JSON(http.StatusOK, gin.H{
  505. "success": false,
  506. "message": result.localErr.Error(),
  507. "time": 0.0,
  508. })
  509. return
  510. }
  511. tok := time.Now()
  512. milliseconds := tok.Sub(tik).Milliseconds()
  513. go channel.UpdateResponseTime(milliseconds)
  514. consumedTime := float64(milliseconds) / 1000.0
  515. if result.newAPIError != nil {
  516. c.JSON(http.StatusOK, gin.H{
  517. "success": false,
  518. "message": result.newAPIError.Error(),
  519. "time": consumedTime,
  520. })
  521. return
  522. }
  523. c.JSON(http.StatusOK, gin.H{
  524. "success": true,
  525. "message": "",
  526. "time": consumedTime,
  527. })
  528. }
  529. var testAllChannelsLock sync.Mutex
  530. var testAllChannelsRunning bool = false
  531. func testAllChannels(notify bool) error {
  532. testAllChannelsLock.Lock()
  533. if testAllChannelsRunning {
  534. testAllChannelsLock.Unlock()
  535. return errors.New("测试已在运行中")
  536. }
  537. testAllChannelsRunning = true
  538. testAllChannelsLock.Unlock()
  539. channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
  540. if getChannelErr != nil {
  541. return getChannelErr
  542. }
  543. var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
  544. if disableThreshold == 0 {
  545. disableThreshold = 10000000 // a impossible value
  546. }
  547. gopool.Go(func() {
  548. // 使用 defer 确保无论如何都会重置运行状态,防止死锁
  549. defer func() {
  550. testAllChannelsLock.Lock()
  551. testAllChannelsRunning = false
  552. testAllChannelsLock.Unlock()
  553. }()
  554. for _, channel := range channels {
  555. isChannelEnabled := channel.Status == common.ChannelStatusEnabled
  556. tik := time.Now()
  557. result := testChannel(channel, "", "")
  558. tok := time.Now()
  559. milliseconds := tok.Sub(tik).Milliseconds()
  560. shouldBanChannel := false
  561. newAPIError := result.newAPIError
  562. // request error disables the channel
  563. if newAPIError != nil {
  564. shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
  565. }
  566. // 当错误检查通过,才检查响应时间
  567. if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
  568. if milliseconds > disableThreshold {
  569. err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
  570. newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
  571. shouldBanChannel = true
  572. }
  573. }
  574. // disable channel
  575. if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
  576. processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
  577. }
  578. // enable channel
  579. if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
  580. service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
  581. }
  582. channel.UpdateResponseTime(milliseconds)
  583. time.Sleep(common.RequestInterval)
  584. }
  585. if notify {
  586. service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
  587. }
  588. })
  589. return nil
  590. }
  591. func TestAllChannels(c *gin.Context) {
  592. err := testAllChannels(true)
  593. if err != nil {
  594. common.ApiError(c, err)
  595. return
  596. }
  597. c.JSON(http.StatusOK, gin.H{
  598. "success": true,
  599. "message": "",
  600. })
  601. }
  602. var autoTestChannelsOnce sync.Once
  603. func AutomaticallyTestChannels() {
  604. // 只在Master节点定时测试渠道
  605. if !common.IsMasterNode {
  606. return
  607. }
  608. autoTestChannelsOnce.Do(func() {
  609. for {
  610. if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
  611. time.Sleep(1 * time.Minute)
  612. continue
  613. }
  614. for {
  615. frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes
  616. time.Sleep(time.Duration(int(math.Round(frequency))) * time.Minute)
  617. common.SysLog(fmt.Sprintf("automatically test channels with interval %f minutes", frequency))
  618. common.SysLog("automatically testing all channels")
  619. _ = testAllChannels(false)
  620. common.SysLog("automatically channel test finished")
  621. if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
  622. break
  623. }
  624. }
  625. }
  626. })
  627. }