channel-test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. package controller
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "math/rand/v2"
  8. "net/http"
  9. "net/http/httptest"
  10. "net/url"
  11. "slices"
  12. "strconv"
  13. "sync"
  14. "sync/atomic"
  15. "time"
  16. "github.com/gin-gonic/gin"
  17. "github.com/labring/aiproxy/common"
  18. "github.com/labring/aiproxy/common/notify"
  19. "github.com/labring/aiproxy/common/render"
  20. "github.com/labring/aiproxy/common/trylock"
  21. "github.com/labring/aiproxy/middleware"
  22. "github.com/labring/aiproxy/model"
  23. "github.com/labring/aiproxy/monitor"
  24. "github.com/labring/aiproxy/relay/channeltype"
  25. "github.com/labring/aiproxy/relay/meta"
  26. "github.com/labring/aiproxy/relay/relaymode"
  27. "github.com/labring/aiproxy/relay/utils"
  28. log "github.com/sirupsen/logrus"
  29. )
  30. const channelTestRequestID = "channel-test"
  31. var (
  32. modelTypeCache map[string]relaymode.Mode = make(map[string]relaymode.Mode)
  33. modelTypeCacheOnce sync.Once
  34. )
  35. func guessModelType(model string) relaymode.Mode {
  36. modelTypeCacheOnce.Do(func() {
  37. for _, c := range channeltype.ChannelAdaptor {
  38. for _, m := range c.GetModelList() {
  39. if _, ok := modelTypeCache[m.Model]; !ok {
  40. modelTypeCache[m.Model] = m.Type
  41. }
  42. }
  43. }
  44. })
  45. if cachedType, ok := modelTypeCache[model]; ok {
  46. return cachedType
  47. }
  48. return relaymode.Unknown
  49. }
  50. // testSingleModel tests a single model in the channel
  51. func testSingleModel(mc *model.ModelCaches, channel *model.Channel, modelName string) (*model.ChannelTest, error) {
  52. modelConfig, ok := mc.ModelConfig.GetModelConfig(modelName)
  53. if !ok {
  54. return nil, errors.New(modelName + " model config not found")
  55. }
  56. if modelConfig.Type == relaymode.Unknown {
  57. newModelConfig := *modelConfig
  58. newModelConfig.Type = guessModelType(modelName)
  59. modelConfig = &newModelConfig
  60. }
  61. body, mode, err := utils.BuildRequest(modelConfig)
  62. if err != nil {
  63. return nil, err
  64. }
  65. w := httptest.NewRecorder()
  66. newc, _ := gin.CreateTestContext(w)
  67. newc.Request = &http.Request{
  68. URL: &url.URL{},
  69. Body: io.NopCloser(body),
  70. Header: make(http.Header),
  71. }
  72. middleware.SetRequestID(newc, channelTestRequestID)
  73. meta := meta.NewMeta(
  74. channel,
  75. mode,
  76. modelName,
  77. modelConfig,
  78. meta.WithRequestID(channelTestRequestID),
  79. meta.WithChannelTest(true),
  80. )
  81. relayController, ok := relayController(mode)
  82. if !ok {
  83. return nil, fmt.Errorf("relay mode %d not implemented", mode)
  84. }
  85. bizErr := relayController(meta, newc)
  86. success := bizErr == nil
  87. var respStr string
  88. var code int
  89. if success {
  90. switch meta.Mode {
  91. case relaymode.AudioSpeech,
  92. relaymode.ImagesGenerations:
  93. respStr = ""
  94. default:
  95. respStr = w.Body.String()
  96. }
  97. code = w.Code
  98. } else {
  99. respStr = bizErr.Error.JSONOrEmpty()
  100. code = bizErr.StatusCode
  101. }
  102. return channel.UpdateModelTest(
  103. meta.RequestAt,
  104. meta.OriginModel,
  105. meta.ActualModel,
  106. meta.Mode,
  107. time.Since(meta.RequestAt).Seconds(),
  108. success,
  109. respStr,
  110. code,
  111. )
  112. }
  113. //nolint:goconst
  114. func TestChannel(c *gin.Context) {
  115. id, err := strconv.Atoi(c.Param("id"))
  116. if err != nil {
  117. c.JSON(http.StatusOK, middleware.APIResponse{
  118. Success: false,
  119. Message: err.Error(),
  120. })
  121. return
  122. }
  123. modelName := c.Param("model")
  124. if modelName == "" {
  125. c.JSON(http.StatusOK, middleware.APIResponse{
  126. Success: false,
  127. Message: "model is required",
  128. })
  129. return
  130. }
  131. channel, err := model.LoadChannelByID(id)
  132. if err != nil {
  133. c.JSON(http.StatusOK, middleware.APIResponse{
  134. Success: false,
  135. Message: "channel not found",
  136. })
  137. return
  138. }
  139. if !slices.Contains(channel.Models, modelName) {
  140. c.JSON(http.StatusOK, middleware.APIResponse{
  141. Success: false,
  142. Message: "model not supported by channel",
  143. })
  144. return
  145. }
  146. ct, err := testSingleModel(model.LoadModelCaches(), channel, modelName)
  147. if err != nil {
  148. log.Errorf("failed to test channel %s(%d) model %s: %s", channel.Name, channel.ID, modelName, err.Error())
  149. c.JSON(http.StatusOK, middleware.APIResponse{
  150. Success: false,
  151. Message: fmt.Sprintf("failed to test channel %s(%d) model %s: %s", channel.Name, channel.ID, modelName, err.Error()),
  152. })
  153. return
  154. }
  155. if c.Query("success_body") != "true" && ct.Success {
  156. ct.Response = ""
  157. }
  158. c.JSON(http.StatusOK, middleware.APIResponse{
  159. Success: true,
  160. Data: ct,
  161. })
  162. }
  163. type testResult struct {
  164. Data *model.ChannelTest `json:"data,omitempty"`
  165. Message string `json:"message,omitempty"`
  166. Success bool `json:"success"`
  167. }
  168. func processTestResult(mc *model.ModelCaches, channel *model.Channel, modelName string, returnSuccess bool, successResponseBody bool) *testResult {
  169. ct, err := testSingleModel(mc, channel, modelName)
  170. e := &utils.UnsupportedModelTypeError{}
  171. if errors.As(err, &e) {
  172. log.Errorf("model %s not supported test: %s", modelName, err.Error())
  173. return nil
  174. }
  175. result := &testResult{
  176. Success: err == nil,
  177. }
  178. if err != nil {
  179. result.Message = fmt.Sprintf("failed to test channel %s(%d) model %s: %s", channel.Name, channel.ID, modelName, err.Error())
  180. return result
  181. }
  182. if !ct.Success {
  183. result.Data = ct
  184. return result
  185. }
  186. if !returnSuccess {
  187. return nil
  188. }
  189. if !successResponseBody {
  190. ct.Response = ""
  191. }
  192. result.Data = ct
  193. return result
  194. }
  195. func TestChannelModels(c *gin.Context) {
  196. id, err := strconv.Atoi(c.Param("id"))
  197. if err != nil {
  198. c.JSON(http.StatusOK, middleware.APIResponse{
  199. Success: false,
  200. Message: err.Error(),
  201. })
  202. return
  203. }
  204. channel, err := model.LoadChannelByID(id)
  205. if err != nil {
  206. c.JSON(http.StatusOK, middleware.APIResponse{
  207. Success: false,
  208. Message: "channel not found",
  209. })
  210. return
  211. }
  212. returnSuccess := c.Query("return_success") == "true"
  213. successResponseBody := c.Query("success_body") == "true"
  214. isStream := c.Query("stream") == "true"
  215. if isStream {
  216. common.SetEventStreamHeaders(c)
  217. }
  218. results := make([]*testResult, 0)
  219. resultsMutex := sync.Mutex{}
  220. hasError := atomic.Bool{}
  221. var wg sync.WaitGroup
  222. semaphore := make(chan struct{}, 5)
  223. models := slices.Clone(channel.Models)
  224. rand.Shuffle(len(models), func(i, j int) {
  225. models[i], models[j] = models[j], models[i]
  226. })
  227. mc := model.LoadModelCaches()
  228. for _, modelName := range models {
  229. wg.Add(1)
  230. semaphore <- struct{}{}
  231. go func(model string) {
  232. defer wg.Done()
  233. defer func() { <-semaphore }()
  234. result := processTestResult(mc, channel, model, returnSuccess, successResponseBody)
  235. if result == nil {
  236. return
  237. }
  238. if !result.Success || (result.Data != nil && !result.Data.Success) {
  239. hasError.Store(true)
  240. }
  241. resultsMutex.Lock()
  242. if isStream {
  243. err := render.ObjectData(c, result)
  244. if err != nil {
  245. log.Errorf("failed to render result: %s", err.Error())
  246. }
  247. } else {
  248. results = append(results, result)
  249. }
  250. resultsMutex.Unlock()
  251. }(modelName)
  252. }
  253. wg.Wait()
  254. if !hasError.Load() {
  255. err := model.ClearLastTestErrorAt(channel.ID)
  256. if err != nil {
  257. log.Errorf("failed to clear last test error at for channel %s(%d): %s", channel.Name, channel.ID, err.Error())
  258. }
  259. }
  260. if !isStream {
  261. c.JSON(http.StatusOK, middleware.APIResponse{
  262. Success: true,
  263. Data: results,
  264. })
  265. }
  266. }
  267. func TestAllChannels(c *gin.Context) {
  268. testDisabled := c.Query("test_disabled") == "true"
  269. var channels []*model.Channel
  270. var err error
  271. if testDisabled {
  272. channels, err = model.LoadChannels()
  273. } else {
  274. channels, err = model.LoadEnabledChannels()
  275. }
  276. if err != nil {
  277. c.JSON(http.StatusOK, middleware.APIResponse{
  278. Success: false,
  279. Message: err.Error(),
  280. })
  281. return
  282. }
  283. returnSuccess := c.Query("return_success") == "true"
  284. successResponseBody := c.Query("success_body") == "true"
  285. isStream := c.Query("stream") == "true"
  286. if isStream {
  287. common.SetEventStreamHeaders(c)
  288. }
  289. results := make([]*testResult, 0)
  290. resultsMutex := sync.Mutex{}
  291. hasErrorMap := make(map[int]*atomic.Bool)
  292. var wg sync.WaitGroup
  293. semaphore := make(chan struct{}, 5)
  294. newChannels := slices.Clone(channels)
  295. rand.Shuffle(len(newChannels), func(i, j int) {
  296. newChannels[i], newChannels[j] = newChannels[j], newChannels[i]
  297. })
  298. mc := model.LoadModelCaches()
  299. for _, channel := range newChannels {
  300. channelHasError := &atomic.Bool{}
  301. hasErrorMap[channel.ID] = channelHasError
  302. models := slices.Clone(channel.Models)
  303. rand.Shuffle(len(models), func(i, j int) {
  304. models[i], models[j] = models[j], models[i]
  305. })
  306. for _, modelName := range models {
  307. wg.Add(1)
  308. semaphore <- struct{}{}
  309. go func(model string, ch *model.Channel, hasError *atomic.Bool) {
  310. defer wg.Done()
  311. defer func() { <-semaphore }()
  312. result := processTestResult(mc, ch, model, returnSuccess, successResponseBody)
  313. if result == nil {
  314. return
  315. }
  316. if !result.Success || (result.Data != nil && !result.Data.Success) {
  317. hasError.Store(true)
  318. }
  319. resultsMutex.Lock()
  320. if isStream {
  321. err := render.ObjectData(c, result)
  322. if err != nil {
  323. log.Errorf("failed to render result: %s", err.Error())
  324. }
  325. } else {
  326. results = append(results, result)
  327. }
  328. resultsMutex.Unlock()
  329. }(modelName, channel, channelHasError)
  330. }
  331. }
  332. wg.Wait()
  333. for id, hasError := range hasErrorMap {
  334. if !hasError.Load() {
  335. err := model.ClearLastTestErrorAt(id)
  336. if err != nil {
  337. log.Errorf("failed to clear last test error at for channel %d: %s", id, err.Error())
  338. }
  339. }
  340. }
  341. if !isStream {
  342. c.JSON(http.StatusOK, middleware.APIResponse{
  343. Success: true,
  344. Data: results,
  345. })
  346. }
  347. }
  348. func tryTestChannel(channelID int, modelName string) bool {
  349. return trylock.Lock(fmt.Sprintf("channel_test_lock:%d:%s", channelID, modelName), 30*time.Second)
  350. }
  351. func AutoTestBannedModels() {
  352. log := log.WithFields(log.Fields{
  353. "auto_test_banned_models": "true",
  354. })
  355. channels, err := monitor.GetAllBannedModelChannels(context.Background())
  356. if err != nil {
  357. log.Errorf("failed to get banned channels: %s", err.Error())
  358. return
  359. }
  360. if len(channels) == 0 {
  361. return
  362. }
  363. mc := model.LoadModelCaches()
  364. for modelName, ids := range channels {
  365. for _, id := range ids {
  366. if !tryTestChannel(int(id), modelName) {
  367. continue
  368. }
  369. channel, err := model.LoadChannelByID(int(id))
  370. if err != nil {
  371. log.Errorf("failed to get channel by model %s: %s", modelName, err.Error())
  372. continue
  373. }
  374. result, err := testSingleModel(mc, channel, modelName)
  375. if err != nil {
  376. notify.Error(fmt.Sprintf("channel[%d] %s(%d) model %s test failed", channel.Type, channel.Name, channel.ID, modelName), err.Error())
  377. continue
  378. }
  379. if result.Success {
  380. notify.Info(fmt.Sprintf("channel[%d] %s(%d) model %s test success", channel.Type, channel.Name, channel.ID, modelName), "unban it")
  381. err = monitor.ClearChannelModelErrors(context.Background(), modelName, channel.ID)
  382. if err != nil {
  383. log.Errorf("clear channel errors failed: %+v", err)
  384. }
  385. } else {
  386. notify.Error(fmt.Sprintf("channel[%d] %s(%d) model %s test failed", channel.Type, channel.Name, channel.ID, modelName),
  387. fmt.Sprintf("code: %d, response: %s", result.Code, result.Response))
  388. }
  389. }
  390. }
  391. }