channel-test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. package controller
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "math/rand/v2"
  8. "net/http"
  9. "net/http/httptest"
  10. "net/url"
  11. "slices"
  12. "strconv"
  13. "sync"
  14. "sync/atomic"
  15. "time"
  16. "github.com/gin-gonic/gin"
  17. "github.com/labring/aiproxy/core/common/notify"
  18. "github.com/labring/aiproxy/core/common/render"
  19. "github.com/labring/aiproxy/core/common/trylock"
  20. "github.com/labring/aiproxy/core/middleware"
  21. "github.com/labring/aiproxy/core/model"
  22. "github.com/labring/aiproxy/core/monitor"
  23. "github.com/labring/aiproxy/core/relay/channeltype"
  24. "github.com/labring/aiproxy/core/relay/meta"
  25. "github.com/labring/aiproxy/core/relay/mode"
  26. "github.com/labring/aiproxy/core/relay/utils"
  27. log "github.com/sirupsen/logrus"
  28. )
  29. const channelTestRequestID = "channel-test"
  30. var (
  31. modelConfigCache map[string]*model.ModelConfig = make(map[string]*model.ModelConfig)
  32. modelConfigCacheOnce sync.Once
  33. )
  34. func guessModelConfig(model string) *model.ModelConfig {
  35. modelConfigCacheOnce.Do(func() {
  36. for _, c := range channeltype.ChannelAdaptor {
  37. for _, m := range c.GetModelList() {
  38. if _, ok := modelConfigCache[m.Model]; !ok {
  39. modelConfigCache[m.Model] = m
  40. }
  41. }
  42. }
  43. })
  44. if cachedConfig, ok := modelConfigCache[model]; ok {
  45. return cachedConfig
  46. }
  47. return nil
  48. }
  49. // testSingleModel tests a single model in the channel
  50. func testSingleModel(mc *model.ModelCaches, channel *model.Channel, modelName string) (*model.ChannelTest, error) {
  51. modelConfig, ok := mc.ModelConfig.GetModelConfig(modelName)
  52. if !ok {
  53. return nil, errors.New(modelName + " model config not found")
  54. }
  55. if modelConfig.Type == mode.Unknown {
  56. newModelConfig := guessModelConfig(modelName)
  57. if newModelConfig != nil {
  58. modelConfig = newModelConfig
  59. }
  60. }
  61. if modelConfig.ExcludeFromTests {
  62. return &model.ChannelTest{
  63. TestAt: time.Now(),
  64. Model: modelName,
  65. ActualModel: modelName,
  66. Success: true,
  67. Code: http.StatusOK,
  68. Mode: modelConfig.Type,
  69. ChannelName: channel.Name,
  70. ChannelType: channel.Type,
  71. ChannelID: channel.ID,
  72. }, nil
  73. }
  74. body, m, err := utils.BuildRequest(modelConfig)
  75. if err != nil {
  76. return nil, err
  77. }
  78. w := httptest.NewRecorder()
  79. newc, _ := gin.CreateTestContext(w)
  80. newc.Request = &http.Request{
  81. URL: &url.URL{},
  82. Body: io.NopCloser(body),
  83. Header: make(http.Header),
  84. }
  85. middleware.SetRequestID(newc, channelTestRequestID)
  86. meta := meta.NewMeta(
  87. channel,
  88. m,
  89. modelName,
  90. modelConfig,
  91. meta.WithRequestID(channelTestRequestID),
  92. )
  93. result := relayHandler(meta, newc)
  94. success := result.Error == nil
  95. var respStr string
  96. var code int
  97. if success {
  98. switch meta.Mode {
  99. case mode.AudioSpeech,
  100. mode.ImagesGenerations:
  101. respStr = ""
  102. default:
  103. respStr = w.Body.String()
  104. }
  105. code = w.Code
  106. } else {
  107. respStr = result.Error.JSONOrEmpty()
  108. code = result.Error.StatusCode
  109. }
  110. return channel.UpdateModelTest(
  111. meta.RequestAt,
  112. meta.OriginModel,
  113. meta.ActualModel,
  114. meta.Mode,
  115. time.Since(meta.RequestAt).Seconds(),
  116. success,
  117. respStr,
  118. code,
  119. )
  120. }
  121. // TestChannel godoc
  122. //
  123. // @Summary Test channel model
  124. // @Description Tests a single model in the channel
  125. // @Tags channel
  126. // @Produce json
  127. // @Security ApiKeyAuth
  128. // @Param id path int true "Channel ID"
  129. // @Param model path string true "Model name"
  130. // @Success 200 {object} middleware.APIResponse{data=model.ChannelTest}
  131. // @Router /api/channel/{id}/{model} [get]
  132. //
  133. //nolint:goconst
  134. func TestChannel(c *gin.Context) {
  135. id, err := strconv.Atoi(c.Param("id"))
  136. if err != nil {
  137. c.JSON(http.StatusOK, middleware.APIResponse{
  138. Success: false,
  139. Message: err.Error(),
  140. })
  141. return
  142. }
  143. modelName := c.Param("model")
  144. if modelName == "" {
  145. c.JSON(http.StatusOK, middleware.APIResponse{
  146. Success: false,
  147. Message: "model is required",
  148. })
  149. return
  150. }
  151. channel, err := model.LoadChannelByID(id)
  152. if err != nil {
  153. c.JSON(http.StatusOK, middleware.APIResponse{
  154. Success: false,
  155. Message: "channel not found",
  156. })
  157. return
  158. }
  159. if !slices.Contains(channel.Models, modelName) {
  160. c.JSON(http.StatusOK, middleware.APIResponse{
  161. Success: false,
  162. Message: "model not supported by channel",
  163. })
  164. return
  165. }
  166. ct, err := testSingleModel(model.LoadModelCaches(), channel, modelName)
  167. if err != nil {
  168. log.Errorf("failed to test channel %s(%d) model %s: %s", channel.Name, channel.ID, modelName, err.Error())
  169. c.JSON(http.StatusOK, middleware.APIResponse{
  170. Success: false,
  171. Message: fmt.Sprintf("failed to test channel %s(%d) model %s: %s", channel.Name, channel.ID, modelName, err.Error()),
  172. })
  173. return
  174. }
  175. if c.Query("success_body") != "true" && ct.Success {
  176. ct.Response = ""
  177. }
  178. c.JSON(http.StatusOK, middleware.APIResponse{
  179. Success: true,
  180. Data: ct,
  181. })
  182. }
  183. type TestResult struct {
  184. Data *model.ChannelTest `json:"data,omitempty"`
  185. Message string `json:"message,omitempty"`
  186. Success bool `json:"success"`
  187. }
  188. func processTestResult(mc *model.ModelCaches, channel *model.Channel, modelName string, returnSuccess bool, successResponseBody bool) *TestResult {
  189. ct, err := testSingleModel(mc, channel, modelName)
  190. e := &utils.UnsupportedModelTypeError{}
  191. if errors.As(err, &e) {
  192. log.Errorf("model %s not supported test: %s", modelName, err.Error())
  193. return nil
  194. }
  195. result := &TestResult{
  196. Success: err == nil,
  197. }
  198. if err != nil {
  199. result.Message = fmt.Sprintf("failed to test channel %s(%d) model %s: %s", channel.Name, channel.ID, modelName, err.Error())
  200. return result
  201. }
  202. if !ct.Success {
  203. result.Data = ct
  204. return result
  205. }
  206. if !returnSuccess {
  207. return nil
  208. }
  209. if !successResponseBody {
  210. ct.Response = ""
  211. }
  212. result.Data = ct
  213. return result
  214. }
  215. // TestChannelModels godoc
  216. //
  217. // @Summary Test channel models
  218. // @Description Tests all models in the channel
  219. // @Tags channel
  220. // @Produce json
  221. // @Security ApiKeyAuth
  222. // @Param id path int true "Channel ID"
  223. // @Param return_success query bool false "Return success"
  224. // @Param success_body query bool false "Success body"
  225. // @Param stream query bool false "Stream"
  226. // @Success 200 {object} middleware.APIResponse{data=[]TestResult}
  227. // @Router /api/channel/{id}/models [get]
  228. func TestChannelModels(c *gin.Context) {
  229. id, err := strconv.Atoi(c.Param("id"))
  230. if err != nil {
  231. c.JSON(http.StatusOK, middleware.APIResponse{
  232. Success: false,
  233. Message: err.Error(),
  234. })
  235. return
  236. }
  237. channel, err := model.LoadChannelByID(id)
  238. if err != nil {
  239. c.JSON(http.StatusOK, middleware.APIResponse{
  240. Success: false,
  241. Message: "channel not found",
  242. })
  243. return
  244. }
  245. returnSuccess := c.Query("return_success") == "true"
  246. successResponseBody := c.Query("success_body") == "true"
  247. isStream := c.Query("stream") == "true"
  248. results := make([]*TestResult, 0)
  249. resultsMutex := sync.Mutex{}
  250. hasError := atomic.Bool{}
  251. var wg sync.WaitGroup
  252. semaphore := make(chan struct{}, 5)
  253. models := slices.Clone(channel.Models)
  254. rand.Shuffle(len(models), func(i, j int) {
  255. models[i], models[j] = models[j], models[i]
  256. })
  257. mc := model.LoadModelCaches()
  258. for _, modelName := range models {
  259. wg.Add(1)
  260. semaphore <- struct{}{}
  261. go func(model string) {
  262. defer wg.Done()
  263. defer func() { <-semaphore }()
  264. result := processTestResult(mc, channel, model, returnSuccess, successResponseBody)
  265. if result == nil {
  266. return
  267. }
  268. if !result.Success || (result.Data != nil && !result.Data.Success) {
  269. hasError.Store(true)
  270. }
  271. resultsMutex.Lock()
  272. if isStream {
  273. err := render.ObjectData(c, result)
  274. if err != nil {
  275. log.Errorf("failed to render result: %s", err.Error())
  276. }
  277. } else {
  278. results = append(results, result)
  279. }
  280. resultsMutex.Unlock()
  281. }(modelName)
  282. }
  283. wg.Wait()
  284. if !hasError.Load() {
  285. err := model.ClearLastTestErrorAt(channel.ID)
  286. if err != nil {
  287. log.Errorf("failed to clear last test error at for channel %s(%d): %s", channel.Name, channel.ID, err.Error())
  288. }
  289. }
  290. if !isStream {
  291. c.JSON(http.StatusOK, middleware.APIResponse{
  292. Success: true,
  293. Data: results,
  294. })
  295. }
  296. }
  297. // TestAllChannels godoc
  298. //
  299. // @Summary Test all channels
  300. // @Description Tests all channels
  301. // @Tags channel
  302. // @Produce json
  303. // @Security ApiKeyAuth
  304. // @Param test_disabled query bool false "Test disabled"
  305. // @Param return_success query bool false "Return success"
  306. // @Param success_body query bool false "Success body"
  307. // @Param stream query bool false "Stream"
  308. // @Success 200 {object} middleware.APIResponse{data=[]TestResult}
  309. //
  310. // @Router /api/channels/test [get]
  311. func TestAllChannels(c *gin.Context) {
  312. testDisabled := c.Query("test_disabled") == "true"
  313. var channels []*model.Channel
  314. var err error
  315. if testDisabled {
  316. channels, err = model.LoadChannels()
  317. } else {
  318. channels, err = model.LoadEnabledChannels()
  319. }
  320. if err != nil {
  321. c.JSON(http.StatusOK, middleware.APIResponse{
  322. Success: false,
  323. Message: err.Error(),
  324. })
  325. return
  326. }
  327. returnSuccess := c.Query("return_success") == "true"
  328. successResponseBody := c.Query("success_body") == "true"
  329. isStream := c.Query("stream") == "true"
  330. results := make([]*TestResult, 0)
  331. resultsMutex := sync.Mutex{}
  332. hasErrorMap := make(map[int]*atomic.Bool)
  333. var wg sync.WaitGroup
  334. semaphore := make(chan struct{}, 5)
  335. newChannels := slices.Clone(channels)
  336. rand.Shuffle(len(newChannels), func(i, j int) {
  337. newChannels[i], newChannels[j] = newChannels[j], newChannels[i]
  338. })
  339. mc := model.LoadModelCaches()
  340. for _, channel := range newChannels {
  341. channelHasError := &atomic.Bool{}
  342. hasErrorMap[channel.ID] = channelHasError
  343. models := slices.Clone(channel.Models)
  344. rand.Shuffle(len(models), func(i, j int) {
  345. models[i], models[j] = models[j], models[i]
  346. })
  347. for _, modelName := range models {
  348. wg.Add(1)
  349. semaphore <- struct{}{}
  350. go func(model string, ch *model.Channel, hasError *atomic.Bool) {
  351. defer wg.Done()
  352. defer func() { <-semaphore }()
  353. result := processTestResult(mc, ch, model, returnSuccess, successResponseBody)
  354. if result == nil {
  355. return
  356. }
  357. if !result.Success || (result.Data != nil && !result.Data.Success) {
  358. hasError.Store(true)
  359. }
  360. resultsMutex.Lock()
  361. if isStream {
  362. err := render.ObjectData(c, result)
  363. if err != nil {
  364. log.Errorf("failed to render result: %s", err.Error())
  365. }
  366. } else {
  367. results = append(results, result)
  368. }
  369. resultsMutex.Unlock()
  370. }(modelName, channel, channelHasError)
  371. }
  372. }
  373. wg.Wait()
  374. for id, hasError := range hasErrorMap {
  375. if !hasError.Load() {
  376. err := model.ClearLastTestErrorAt(id)
  377. if err != nil {
  378. log.Errorf("failed to clear last test error at for channel %d: %s", id, err.Error())
  379. }
  380. }
  381. }
  382. if !isStream {
  383. c.JSON(http.StatusOK, middleware.APIResponse{
  384. Success: true,
  385. Data: results,
  386. })
  387. }
  388. }
  389. func tryTestChannel(channelID int, modelName string) bool {
  390. return trylock.Lock(fmt.Sprintf("channel_test_lock:%d:%s", channelID, modelName), 30*time.Second)
  391. }
  392. func AutoTestBannedModels() {
  393. log := log.WithFields(log.Fields{
  394. "auto_test_banned_models": "true",
  395. })
  396. channels, err := monitor.GetAllBannedModelChannels(context.Background())
  397. if err != nil {
  398. log.Errorf("failed to get banned channels: %s", err.Error())
  399. return
  400. }
  401. if len(channels) == 0 {
  402. return
  403. }
  404. mc := model.LoadModelCaches()
  405. for modelName, ids := range channels {
  406. for _, id := range ids {
  407. if !tryTestChannel(int(id), modelName) {
  408. continue
  409. }
  410. channel, err := model.LoadChannelByID(int(id))
  411. if err != nil {
  412. log.Errorf("failed to get channel by model %s: %s", modelName, err.Error())
  413. continue
  414. }
  415. result, err := testSingleModel(mc, channel, modelName)
  416. if err != nil {
  417. notify.Error(fmt.Sprintf("channel %s (type: %d, id: %d) model %s test failed", channel.Name, channel.Type, channel.ID, modelName), err.Error())
  418. continue
  419. }
  420. if result.Success {
  421. notify.Info(fmt.Sprintf("channel %s (type: %d, id: %d) model %s test success", channel.Name, channel.Type, channel.ID, modelName), "unban it")
  422. err = monitor.ClearChannelModelErrors(context.Background(), modelName, channel.ID)
  423. if err != nil {
  424. log.Errorf("clear channel errors failed: %+v", err)
  425. }
  426. } else {
  427. notify.Error(fmt.Sprintf("channel %s (type: %d, id: %d) model %s test failed", channel.Name, channel.Type, channel.ID, modelName),
  428. fmt.Sprintf("code: %d, response: %s", result.Code, result.Response))
  429. }
  430. }
  431. }
  432. }