relay-channel.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. package controller
import (
	"context"
	"errors"
	"fmt"
	"math"
	"math/rand/v2"
	"strconv"

	"github.com/gin-gonic/gin"
	"github.com/labring/aiproxy/core/common"
	"github.com/labring/aiproxy/core/middleware"
	"github.com/labring/aiproxy/core/model"
	"github.com/labring/aiproxy/core/monitor"
	"github.com/labring/aiproxy/core/relay/adaptors"
	"github.com/labring/aiproxy/core/relay/mode"
)
  16. const (
  17. AIProxyChannelHeader = "Aiproxy-Channel"
  18. // maxRetryErrorRate is the maximum error rate threshold for channel retry selection
  19. // Channels with error rate higher than this will be filtered out during retry
  20. maxRetryErrorRate = 0.75
  21. )
  22. func GetChannelFromHeader(
  23. header string,
  24. mc *model.ModelCaches,
  25. availableSet []string,
  26. model string,
  27. m mode.Mode,
  28. ) (*model.Channel, error) {
  29. channelIDInt, err := strconv.ParseInt(header, 10, 64)
  30. if err != nil {
  31. return nil, err
  32. }
  33. for _, set := range availableSet {
  34. enabledChannels := mc.EnabledModel2ChannelsBySet[set][model]
  35. if len(enabledChannels) > 0 {
  36. for _, channel := range enabledChannels {
  37. if int64(channel.ID) == channelIDInt {
  38. a, ok := adaptors.GetAdaptor(channel.Type)
  39. if !ok {
  40. return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
  41. }
  42. if !a.SupportMode(m) {
  43. return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
  44. }
  45. return channel, nil
  46. }
  47. }
  48. }
  49. disabledChannels := mc.DisabledModel2ChannelsBySet[set][model]
  50. if len(disabledChannels) > 0 {
  51. for _, channel := range disabledChannels {
  52. if int64(channel.ID) == channelIDInt {
  53. a, ok := adaptors.GetAdaptor(channel.Type)
  54. if !ok {
  55. return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
  56. }
  57. if !a.SupportMode(m) {
  58. return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
  59. }
  60. return channel, nil
  61. }
  62. }
  63. }
  64. }
  65. return nil, fmt.Errorf("channel %d not found for model `%s`", channelIDInt, model)
  66. }
  67. func needPinChannel(m mode.Mode) bool {
  68. switch m {
  69. case mode.VideoGenerationsGetJobs,
  70. mode.VideoGenerationsContent,
  71. mode.ResponsesGet,
  72. mode.ResponsesDelete,
  73. mode.ResponsesCancel,
  74. mode.ResponsesInputItems:
  75. return true
  76. default:
  77. return false
  78. }
  79. }
  80. func GetChannelFromRequest(
  81. c *gin.Context,
  82. mc *model.ModelCaches,
  83. availableSet []string,
  84. modelName string,
  85. m mode.Mode,
  86. ) (*model.Channel, error) {
  87. channelID := middleware.GetChannelID(c)
  88. if channelID == 0 {
  89. if needPinChannel(m) {
  90. return nil, fmt.Errorf("%s need pinned channel", m)
  91. }
  92. return nil, nil
  93. }
  94. for _, set := range availableSet {
  95. enabledChannels := mc.EnabledModel2ChannelsBySet[set][modelName]
  96. if len(enabledChannels) > 0 {
  97. for _, channel := range enabledChannels {
  98. if channel.ID == channelID {
  99. a, ok := adaptors.GetAdaptor(channel.Type)
  100. if !ok {
  101. return nil, fmt.Errorf(
  102. "adaptor not found for pinned channel %d",
  103. channel.ID,
  104. )
  105. }
  106. if !a.SupportMode(m) {
  107. return nil, fmt.Errorf(
  108. "pinned channel %d not supported by adaptor",
  109. channel.ID,
  110. )
  111. }
  112. return channel, nil
  113. }
  114. }
  115. }
  116. }
  117. return nil, fmt.Errorf("pinned channel %d not found for model `%s`", channelID, modelName)
  118. }
  119. var (
  120. ErrChannelsNotFound = errors.New("channels not found")
  121. ErrChannelsExhausted = errors.New("channels exhausted")
  122. )
  123. func getRandomChannel(
  124. mc *model.ModelCaches,
  125. availableSet []string,
  126. modelName string,
  127. mode mode.Mode,
  128. errorRates map[int64]float64,
  129. maxErrorRate float64,
  130. ignoreChannelMap ...map[int64]struct{},
  131. ) (*model.Channel, []*model.Channel, error) {
  132. channelMap := make(map[int]*model.Channel)
  133. if len(availableSet) != 0 {
  134. for _, set := range availableSet {
  135. channels := mc.EnabledModel2ChannelsBySet[set][modelName]
  136. for _, channel := range channels {
  137. a, ok := adaptors.GetAdaptor(channel.Type)
  138. if !ok {
  139. continue
  140. }
  141. if !a.SupportMode(mode) {
  142. continue
  143. }
  144. channelMap[channel.ID] = channel
  145. }
  146. }
  147. } else {
  148. for _, sets := range mc.EnabledModel2ChannelsBySet {
  149. for _, channel := range sets[modelName] {
  150. a, ok := adaptors.GetAdaptor(channel.Type)
  151. if !ok {
  152. continue
  153. }
  154. if !a.SupportMode(mode) {
  155. continue
  156. }
  157. channelMap[channel.ID] = channel
  158. }
  159. }
  160. }
  161. migratedChannels := make([]*model.Channel, 0, len(channelMap))
  162. for _, channel := range channelMap {
  163. migratedChannels = append(migratedChannels, channel)
  164. }
  165. channel, err := ignoreChannel(
  166. migratedChannels,
  167. mode,
  168. errorRates,
  169. maxErrorRate,
  170. ignoreChannelMap...,
  171. )
  172. return channel, migratedChannels, err
  173. }
  174. func getPriority(channel *model.Channel, errorRate float64) int32 {
  175. priority := channel.GetPriority()
  176. if errorRate > 1 {
  177. errorRate = 1
  178. } else if errorRate < 0.1 {
  179. errorRate = 0.1
  180. }
  181. return int32(float64(priority) / errorRate)
  182. }
  183. func ignoreChannel(
  184. channels []*model.Channel,
  185. mode mode.Mode,
  186. errorRates map[int64]float64,
  187. maxErrorRate float64,
  188. ignoreChannelIDs ...map[int64]struct{},
  189. ) (*model.Channel, error) {
  190. if len(channels) == 0 {
  191. return nil, ErrChannelsNotFound
  192. }
  193. channels = filterChannels(channels, mode, errorRates, maxErrorRate, ignoreChannelIDs...)
  194. if len(channels) == 0 {
  195. return nil, ErrChannelsExhausted
  196. }
  197. if len(channels) == 1 {
  198. return channels[0], nil
  199. }
  200. var totalWeight int32
  201. cachedPrioritys := make([]int32, len(channels))
  202. for i, ch := range channels {
  203. priority := getPriority(ch, errorRates[int64(ch.ID)])
  204. totalWeight += priority
  205. cachedPrioritys[i] = priority
  206. }
  207. if totalWeight == 0 {
  208. return channels[rand.IntN(len(channels))], nil
  209. }
  210. r := rand.Int32N(totalWeight)
  211. for i, ch := range channels {
  212. r -= cachedPrioritys[i]
  213. if r < 0 {
  214. return ch, nil
  215. }
  216. }
  217. return channels[rand.IntN(len(channels))], nil
  218. }
  219. func getChannelWithFallback(
  220. cache *model.ModelCaches,
  221. availableSet []string,
  222. modelName string,
  223. mode mode.Mode,
  224. errorRates map[int64]float64,
  225. ignoreChannelIDs map[int64]struct{},
  226. ) (*model.Channel, []*model.Channel, error) {
  227. channel, migratedChannels, err := getRandomChannel(
  228. cache,
  229. availableSet,
  230. modelName,
  231. mode,
  232. errorRates,
  233. maxRetryErrorRate,
  234. ignoreChannelIDs,
  235. )
  236. if err == nil {
  237. return channel, migratedChannels, nil
  238. }
  239. if !errors.Is(err, ErrChannelsExhausted) {
  240. return nil, migratedChannels, err
  241. }
  242. return getRandomChannel(
  243. cache,
  244. availableSet,
  245. modelName,
  246. mode,
  247. errorRates,
  248. 0,
  249. )
  250. }
// initialChannel is what getInitialChannel returns: the channel chosen for the
// first attempt, plus the selection context gathered while choosing it.
type initialChannel struct {
	// channel is the channel selected for the first attempt.
	channel *model.Channel
	// designatedChannel is true when the channel came from the channel header
	// or from a channel pinned on the request, not from random selection.
	designatedChannel bool
	// ignoreChannelIDs holds the auto-banned channel IDs for the model as
	// reported by monitor.GetBannedChannelsMapWithModel. Populated only for
	// randomly selected channels.
	ignoreChannelIDs map[int64]struct{}
	// errorRates maps channel ID to the error rate reported by
	// monitor.GetModelChannelErrorRate. Populated only for randomly selected
	// channels.
	errorRates map[int64]float64
	// migratedChannels is the candidate list the random pick was drawn from.
	// Populated only for randomly selected channels; presumably reused by the
	// retry logic — confirm against callers.
	migratedChannels []*model.Channel
}
  258. func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialChannel, error) {
  259. log := common.GetLogger(c)
  260. group := middleware.GetGroup(c)
  261. availableSet := group.GetAvailableSets()
  262. if channelHeader := c.Request.Header.Get(AIProxyChannelHeader); channelHeader != "" {
  263. if group.Status != model.GroupStatusInternal {
  264. return nil, errors.New("channel header is not allowed in non-internal group")
  265. }
  266. channel, err := GetChannelFromHeader(
  267. channelHeader,
  268. middleware.GetModelCaches(c),
  269. availableSet,
  270. modelName,
  271. m,
  272. )
  273. if err != nil {
  274. return nil, err
  275. }
  276. log.Data["designated_channel"] = "true"
  277. return &initialChannel{channel: channel, designatedChannel: true}, nil
  278. }
  279. channel, err := GetChannelFromRequest(
  280. c,
  281. middleware.GetModelCaches(c),
  282. availableSet,
  283. modelName,
  284. m,
  285. )
  286. if err != nil {
  287. return nil, err
  288. }
  289. if channel != nil {
  290. return &initialChannel{channel: channel, designatedChannel: true}, nil
  291. }
  292. mc := middleware.GetModelCaches(c)
  293. ignoreChannelIDs, err := monitor.GetBannedChannelsMapWithModel(c.Request.Context(), modelName)
  294. if err != nil {
  295. log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
  296. }
  297. log.Debugf("%s model banned channels: %+v", modelName, ignoreChannelIDs)
  298. errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
  299. if err != nil {
  300. log.Errorf("get channel model error rates failed: %+v", err)
  301. }
  302. channel, migratedChannels, err := getChannelWithFallback(
  303. mc,
  304. availableSet,
  305. modelName,
  306. m,
  307. errorRates,
  308. ignoreChannelIDs,
  309. )
  310. if err != nil {
  311. return nil, err
  312. }
  313. return &initialChannel{
  314. channel: channel,
  315. ignoreChannelIDs: ignoreChannelIDs,
  316. errorRates: errorRates,
  317. migratedChannels: migratedChannels,
  318. }, nil
  319. }
  320. func getWebSearchChannel(
  321. ctx context.Context,
  322. mc *model.ModelCaches,
  323. modelName string,
  324. ) (*model.Channel, error) {
  325. ignoreChannelIDs, _ := monitor.GetBannedChannelsMapWithModel(ctx, modelName)
  326. errorRates, _ := monitor.GetModelChannelErrorRate(ctx, modelName)
  327. channel, _, err := getChannelWithFallback(
  328. mc,
  329. nil,
  330. modelName,
  331. mode.ChatCompletions,
  332. errorRates,
  333. ignoreChannelIDs)
  334. if err != nil {
  335. return nil, err
  336. }
  337. return channel, nil
  338. }
  339. func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.Channel, error) {
  340. if state.exhausted {
  341. if state.lastHasPermissionChannel == nil {
  342. return nil, ErrChannelsExhausted
  343. }
  344. // Check if lastHasPermissionChannel has high error rate
  345. // If so, return exhausted to prevent retrying with a bad channel
  346. channelID := int64(state.lastHasPermissionChannel.ID)
  347. if errorRate, ok := state.errorRates[channelID]; ok && errorRate > maxRetryErrorRate {
  348. return nil, ErrChannelsExhausted
  349. }
  350. return state.lastHasPermissionChannel, nil
  351. }
  352. // For the last retry, filter out all previously failed channels if there are other options
  353. if currentRetry == totalRetries-1 && len(state.failedChannelIDs) > 0 {
  354. // Check if there are channels available after filtering out failed channels
  355. newChannel, err := ignoreChannel(
  356. state.migratedChannels,
  357. state.meta.Mode,
  358. state.errorRates,
  359. maxRetryErrorRate,
  360. state.ignoreChannelIDs,
  361. state.failedChannelIDs,
  362. )
  363. if err == nil {
  364. return newChannel, nil
  365. }
  366. // If no channels available after filtering, fall back to not using failed channels filter
  367. }
  368. newChannel, err := ignoreChannel(
  369. state.migratedChannels,
  370. state.meta.Mode,
  371. state.errorRates,
  372. maxRetryErrorRate,
  373. state.ignoreChannelIDs,
  374. )
  375. if err != nil {
  376. if !errors.Is(err, ErrChannelsExhausted) || state.lastHasPermissionChannel == nil {
  377. return nil, err
  378. }
  379. // Check if lastHasPermissionChannel has high error rate before using it
  380. channelID := int64(state.lastHasPermissionChannel.ID)
  381. if errorRate, ok := state.errorRates[channelID]; ok && errorRate > maxRetryErrorRate {
  382. return nil, ErrChannelsExhausted
  383. }
  384. state.exhausted = true
  385. return state.lastHasPermissionChannel, nil
  386. }
  387. return newChannel, nil
  388. }
  389. func filterChannels(
  390. channels []*model.Channel,
  391. mode mode.Mode,
  392. errorRates map[int64]float64,
  393. maxErrorRate float64,
  394. ignoreChannel ...map[int64]struct{},
  395. ) []*model.Channel {
  396. filtered := make([]*model.Channel, 0)
  397. for _, channel := range channels {
  398. if channel.Status != model.ChannelStatusEnabled {
  399. continue
  400. }
  401. a, ok := adaptors.GetAdaptor(channel.Type)
  402. if !ok {
  403. continue
  404. }
  405. if !a.SupportMode(mode) {
  406. continue
  407. }
  408. chid := int64(channel.ID)
  409. if maxErrorRate != 0 {
  410. // Filter out channels with error rate higher than threshold
  411. // This avoids amplifying attacks and retrying with bad channels
  412. if errorRate, ok := errorRates[chid]; ok && errorRate > maxErrorRate {
  413. continue
  414. }
  415. }
  416. needIgnore := false
  417. for _, ignores := range ignoreChannel {
  418. if ignores == nil {
  419. continue
  420. }
  421. _, needIgnore = ignores[chid]
  422. if needIgnore {
  423. break
  424. }
  425. }
  426. if needIgnore {
  427. continue
  428. }
  429. filtered = append(filtered, channel)
  430. }
  431. return filtered
  432. }