channel-test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. package controller
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "math/rand/v2"
  8. "net/http"
  9. "net/http/httptest"
  10. "net/url"
  11. "slices"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "sync/atomic"
  16. "time"
  17. "github.com/gin-gonic/gin"
  18. "github.com/labring/aiproxy/core/common/conv"
  19. "github.com/labring/aiproxy/core/common/notify"
  20. "github.com/labring/aiproxy/core/common/trylock"
  21. "github.com/labring/aiproxy/core/middleware"
  22. "github.com/labring/aiproxy/core/model"
  23. "github.com/labring/aiproxy/core/monitor"
  24. "github.com/labring/aiproxy/core/relay/adaptors"
  25. "github.com/labring/aiproxy/core/relay/meta"
  26. "github.com/labring/aiproxy/core/relay/mode"
  27. "github.com/labring/aiproxy/core/relay/render"
  28. "github.com/labring/aiproxy/core/relay/utils"
  29. log "github.com/sirupsen/logrus"
  30. )
  31. const channelTestRequestID = "channel-test"
  32. var (
  33. modelConfigCache map[string]model.ModelConfig = make(map[string]model.ModelConfig)
  34. modelConfigCacheOnce sync.Once
  35. )
  36. func guessModelConfig(modelName string) model.ModelConfig {
  37. modelConfigCacheOnce.Do(func() {
  38. for _, c := range adaptors.ChannelAdaptor {
  39. for _, m := range c.Metadata().Models {
  40. if _, ok := modelConfigCache[m.Model]; !ok {
  41. modelConfigCache[m.Model] = m
  42. }
  43. }
  44. }
  45. })
  46. if cachedConfig, ok := modelConfigCache[modelName]; ok {
  47. return cachedConfig
  48. }
  49. return model.ModelConfig{}
  50. }
  51. // testSingleModel tests a single model in the channel
  52. func testSingleModel(
  53. mc *model.ModelCaches,
  54. channel *model.Channel,
  55. modelName string,
  56. ) (*model.ChannelTest, error) {
  57. modelConfig, ok := mc.ModelConfig.GetModelConfig(modelName)
  58. if !ok {
  59. return nil, errors.New(modelName + " model config not found")
  60. }
  61. if modelConfig.Type == mode.Unknown {
  62. newModelConfig := guessModelConfig(modelName)
  63. if newModelConfig.Type != mode.Unknown {
  64. modelConfig = newModelConfig
  65. }
  66. }
  67. if modelConfig.Type != mode.Unknown {
  68. a, ok := adaptors.GetAdaptor(channel.Type)
  69. if !ok {
  70. return nil, errors.New("adaptor not found")
  71. }
  72. if !a.SupportMode(modelConfig.Type) {
  73. return nil, fmt.Errorf("%s not supported by adaptor", modelConfig.Type)
  74. }
  75. }
  76. if modelConfig.ExcludeFromTests {
  77. return &model.ChannelTest{
  78. TestAt: time.Now(),
  79. Model: modelName,
  80. ActualModel: modelName,
  81. Success: true,
  82. Code: http.StatusOK,
  83. Mode: modelConfig.Type,
  84. ChannelName: channel.Name,
  85. ChannelType: channel.Type,
  86. ChannelID: channel.ID,
  87. }, nil
  88. }
  89. body, m, err := utils.BuildRequest(modelConfig)
  90. if err != nil {
  91. return nil, err
  92. }
  93. w := httptest.NewRecorder()
  94. newc, _ := gin.CreateTestContext(w)
  95. newc.Request = &http.Request{
  96. URL: &url.URL{},
  97. Body: io.NopCloser(body),
  98. Header: make(http.Header),
  99. }
  100. middleware.SetRequestID(newc, channelTestRequestID)
  101. meta := meta.NewMeta(
  102. channel,
  103. m,
  104. modelName,
  105. modelConfig,
  106. meta.WithRequestID(channelTestRequestID),
  107. )
  108. result := relayHandler(newc, meta, mc)
  109. success := result.Error == nil
  110. var (
  111. respStr string
  112. code int
  113. )
  114. if success {
  115. switch meta.Mode {
  116. case mode.AudioSpeech,
  117. mode.ImagesGenerations:
  118. respStr = ""
  119. default:
  120. respStr = w.Body.String()
  121. }
  122. code = w.Code
  123. } else {
  124. respBody, _ := result.Error.MarshalJSON()
  125. respStr = conv.BytesToString(respBody)
  126. code = result.Error.StatusCode()
  127. }
  128. return channel.UpdateModelTest(
  129. meta.RequestAt,
  130. meta.OriginModel,
  131. meta.ActualModel,
  132. meta.Mode,
  133. time.Since(meta.RequestAt).Seconds(),
  134. success,
  135. respStr,
  136. code,
  137. )
  138. }
  139. // TestChannel godoc
  140. //
  141. // @Summary Test channel model
  142. // @Description Tests a single model in the channel
  143. // @Tags channel
  144. // @Produce json
  145. // @Security ApiKeyAuth
  146. // @Param id path int true "Channel ID"
  147. // @Param model path string true "Model name"
  148. // @Success 200 {object} middleware.APIResponse{data=model.ChannelTest}
  149. // @Router /api/channel/{id}/{model} [get]
  150. //
  151. //nolint:goconst
  152. func TestChannel(c *gin.Context) {
  153. id, err := strconv.Atoi(c.Param("id"))
  154. if err != nil {
  155. c.JSON(http.StatusOK, middleware.APIResponse{
  156. Success: false,
  157. Message: err.Error(),
  158. })
  159. return
  160. }
  161. modelName := strings.TrimPrefix(c.Param("model"), "/")
  162. if modelName == "" {
  163. c.JSON(http.StatusOK, middleware.APIResponse{
  164. Success: false,
  165. Message: "model is required",
  166. })
  167. return
  168. }
  169. channel, err := model.LoadChannelByID(id)
  170. if err != nil {
  171. c.JSON(http.StatusOK, middleware.APIResponse{
  172. Success: false,
  173. Message: "channel not found",
  174. })
  175. return
  176. }
  177. if !slices.Contains(channel.Models, modelName) {
  178. c.JSON(http.StatusOK, middleware.APIResponse{
  179. Success: false,
  180. Message: "model not supported by channel",
  181. })
  182. return
  183. }
  184. ct, err := testSingleModel(model.LoadModelCaches(), channel, modelName)
  185. if err != nil {
  186. log.Errorf(
  187. "failed to test channel %s(%d) model %s: %s",
  188. channel.Name,
  189. channel.ID,
  190. modelName,
  191. err.Error(),
  192. )
  193. c.JSON(http.StatusOK, middleware.APIResponse{
  194. Success: false,
  195. Message: fmt.Sprintf(
  196. "failed to test channel %s(%d) model %s: %s",
  197. channel.Name,
  198. channel.ID,
  199. modelName,
  200. err.Error(),
  201. ),
  202. })
  203. return
  204. }
  205. if c.Query("success_body") != "true" && ct.Success {
  206. ct.Response = ""
  207. }
  208. c.JSON(http.StatusOK, middleware.APIResponse{
  209. Success: true,
  210. Data: ct,
  211. })
  212. }
  213. type TestResult struct {
  214. Data *model.ChannelTest `json:"data,omitempty"`
  215. Message string `json:"message,omitempty"`
  216. Success bool `json:"success"`
  217. }
  218. func processTestResult(
  219. mc *model.ModelCaches,
  220. channel *model.Channel,
  221. modelName string,
  222. returnSuccess, successResponseBody bool,
  223. ) *TestResult {
  224. ct, err := testSingleModel(mc, channel, modelName)
  225. e := &utils.UnsupportedModelTypeError{}
  226. if errors.As(err, &e) {
  227. log.Errorf("model %s not supported test: %s", modelName, err.Error())
  228. return nil
  229. }
  230. result := &TestResult{
  231. Success: err == nil,
  232. }
  233. if err != nil {
  234. result.Message = fmt.Sprintf(
  235. "failed to test channel %s(%d) model %s: %s",
  236. channel.Name,
  237. channel.ID,
  238. modelName,
  239. err.Error(),
  240. )
  241. return result
  242. }
  243. if !ct.Success {
  244. result.Data = ct
  245. return result
  246. }
  247. if !returnSuccess {
  248. return nil
  249. }
  250. if !successResponseBody {
  251. ct.Response = ""
  252. }
  253. result.Data = ct
  254. return result
  255. }
  256. // TestChannelModels godoc
  257. //
  258. // @Summary Test channel models
  259. // @Description Tests all models in the channel
  260. // @Tags channel
  261. // @Produce json
  262. // @Security ApiKeyAuth
  263. // @Param id path int true "Channel ID"
  264. // @Param return_success query bool false "Return success"
  265. // @Param success_body query bool false "Success body"
  266. // @Param stream query bool false "Stream"
  267. // @Success 200 {object} middleware.APIResponse{data=[]TestResult}
  268. // @Router /api/channel/{id}/test [get]
  269. func TestChannelModels(c *gin.Context) {
  270. id, err := strconv.Atoi(c.Param("id"))
  271. if err != nil {
  272. c.JSON(http.StatusOK, middleware.APIResponse{
  273. Success: false,
  274. Message: err.Error(),
  275. })
  276. return
  277. }
  278. channel, err := model.LoadChannelByID(id)
  279. if err != nil {
  280. c.JSON(http.StatusOK, middleware.APIResponse{
  281. Success: false,
  282. Message: "channel not found",
  283. })
  284. return
  285. }
  286. returnSuccess := c.Query("return_success") == "true"
  287. successResponseBody := c.Query("success_body") == "true"
  288. isStream := c.Query("stream") == "true"
  289. results := make([]*TestResult, 0)
  290. resultsMutex := sync.Mutex{}
  291. hasError := atomic.Bool{}
  292. var wg sync.WaitGroup
  293. semaphore := make(chan struct{}, 5)
  294. models := slices.Clone(channel.Models)
  295. rand.Shuffle(len(models), func(i, j int) {
  296. models[i], models[j] = models[j], models[i]
  297. })
  298. mc := model.LoadModelCaches()
  299. for _, modelName := range models {
  300. wg.Add(1)
  301. semaphore <- struct{}{}
  302. go func(model string) {
  303. defer wg.Done()
  304. defer func() { <-semaphore }()
  305. result := processTestResult(mc, channel, model, returnSuccess, successResponseBody)
  306. if result == nil {
  307. return
  308. }
  309. if !result.Success || (result.Data != nil && !result.Data.Success) {
  310. hasError.Store(true)
  311. }
  312. resultsMutex.Lock()
  313. if isStream {
  314. err := render.OpenaiObjectData(c, result)
  315. if err != nil {
  316. log.Errorf("failed to render result: %s", err.Error())
  317. }
  318. } else {
  319. results = append(results, result)
  320. }
  321. resultsMutex.Unlock()
  322. }(modelName)
  323. }
  324. wg.Wait()
  325. if !hasError.Load() {
  326. err := model.ClearLastTestErrorAt(channel.ID)
  327. if err != nil {
  328. log.Errorf(
  329. "failed to clear last test error at for channel %s(%d): %s",
  330. channel.Name,
  331. channel.ID,
  332. err.Error(),
  333. )
  334. }
  335. }
  336. if !isStream {
  337. c.JSON(http.StatusOK, middleware.APIResponse{
  338. Success: true,
  339. Data: results,
  340. })
  341. }
  342. }
  343. // TestAllChannels godoc
  344. //
  345. // @Summary Test all channels
  346. // @Description Tests all channels
  347. // @Tags channel
  348. // @Produce json
  349. // @Security ApiKeyAuth
  350. // @Param test_disabled query bool false "Test disabled"
  351. // @Param return_success query bool false "Return success"
  352. // @Param success_body query bool false "Success body"
  353. // @Param stream query bool false "Stream"
  354. // @Success 200 {object} middleware.APIResponse{data=[]TestResult}
  355. //
  356. // @Router /api/channels/test [get]
  357. func TestAllChannels(c *gin.Context) {
  358. testDisabled := c.Query("test_disabled") == "true"
  359. var (
  360. channels []*model.Channel
  361. err error
  362. )
  363. if testDisabled {
  364. channels, err = model.LoadChannels()
  365. } else {
  366. channels, err = model.LoadEnabledChannels()
  367. }
  368. if err != nil {
  369. c.JSON(http.StatusOK, middleware.APIResponse{
  370. Success: false,
  371. Message: err.Error(),
  372. })
  373. return
  374. }
  375. returnSuccess := c.Query("return_success") == "true"
  376. successResponseBody := c.Query("success_body") == "true"
  377. isStream := c.Query("stream") == "true"
  378. results := make([]*TestResult, 0)
  379. resultsMutex := sync.Mutex{}
  380. hasErrorMap := make(map[int]*atomic.Bool)
  381. var wg sync.WaitGroup
  382. semaphore := make(chan struct{}, 5)
  383. newChannels := slices.Clone(channels)
  384. rand.Shuffle(len(newChannels), func(i, j int) {
  385. newChannels[i], newChannels[j] = newChannels[j], newChannels[i]
  386. })
  387. mc := model.LoadModelCaches()
  388. for _, channel := range newChannels {
  389. channelHasError := &atomic.Bool{}
  390. hasErrorMap[channel.ID] = channelHasError
  391. models := slices.Clone(channel.Models)
  392. rand.Shuffle(len(models), func(i, j int) {
  393. models[i], models[j] = models[j], models[i]
  394. })
  395. for _, modelName := range models {
  396. wg.Add(1)
  397. semaphore <- struct{}{}
  398. go func(model string, ch *model.Channel, hasError *atomic.Bool) {
  399. defer wg.Done()
  400. defer func() { <-semaphore }()
  401. result := processTestResult(mc, ch, model, returnSuccess, successResponseBody)
  402. if result == nil {
  403. return
  404. }
  405. if !result.Success || (result.Data != nil && !result.Data.Success) {
  406. hasError.Store(true)
  407. }
  408. resultsMutex.Lock()
  409. if isStream {
  410. err := render.OpenaiObjectData(c, result)
  411. if err != nil {
  412. log.Errorf("failed to render result: %s", err.Error())
  413. }
  414. } else {
  415. results = append(results, result)
  416. }
  417. resultsMutex.Unlock()
  418. }(modelName, channel, channelHasError)
  419. }
  420. }
  421. wg.Wait()
  422. for id, hasError := range hasErrorMap {
  423. if !hasError.Load() {
  424. err := model.ClearLastTestErrorAt(id)
  425. if err != nil {
  426. log.Errorf("failed to clear last test error at for channel %d: %s", id, err.Error())
  427. }
  428. }
  429. }
  430. if !isStream {
  431. c.JSON(http.StatusOK, middleware.APIResponse{
  432. Success: true,
  433. Data: results,
  434. })
  435. }
  436. }
  437. func tryTestChannel(channelID int, modelName string) bool {
  438. return trylock.Lock(
  439. fmt.Sprintf("channel_test_lock:%d:%s", channelID, modelName),
  440. 30*time.Second,
  441. )
  442. }
  443. func AutoTestBannedModels() {
  444. log := log.WithFields(log.Fields{
  445. "auto_test_banned_models": "true",
  446. })
  447. channels, err := monitor.GetAllBannedModelChannels(context.Background())
  448. if err != nil {
  449. log.Errorf("failed to get banned channels: %s", err.Error())
  450. return
  451. }
  452. if len(channels) == 0 {
  453. return
  454. }
  455. mc := model.LoadModelCaches()
  456. for modelName, ids := range channels {
  457. for _, id := range ids {
  458. if !tryTestChannel(int(id), modelName) {
  459. continue
  460. }
  461. channel, err := model.LoadChannelByID(int(id))
  462. if err != nil {
  463. log.Errorf("failed to get channel by model %s: %s", modelName, err.Error())
  464. continue
  465. }
  466. result, err := testSingleModel(mc, channel, modelName)
  467. if err != nil {
  468. notify.Error(
  469. fmt.Sprintf(
  470. "channel %s (type: %d, id: %d) model %s test failed",
  471. channel.Name,
  472. channel.Type,
  473. channel.ID,
  474. modelName,
  475. ),
  476. err.Error(),
  477. )
  478. continue
  479. }
  480. if result.Success {
  481. notify.Info(
  482. fmt.Sprintf(
  483. "channel %s (type: %d, id: %d) model %s test success",
  484. channel.Name,
  485. channel.Type,
  486. channel.ID,
  487. modelName,
  488. ),
  489. "unban it",
  490. )
  491. err = monitor.ClearChannelModelErrors(context.Background(), modelName, channel.ID)
  492. if err != nil {
  493. log.Errorf("clear channel errors failed: %+v", err)
  494. }
  495. } else {
  496. notify.Error(fmt.Sprintf("channel %s (type: %d, id: %d) model %s test failed", channel.Name, channel.Type, channel.ID, modelName),
  497. fmt.Sprintf("code: %d, response: %s", result.Code, result.Response))
  498. }
  499. }
  500. }
  501. }