relay-channel.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. package controller
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "math/rand/v2"
  7. "strconv"
  8. "github.com/gin-gonic/gin"
  9. "github.com/labring/aiproxy/core/common"
  10. "github.com/labring/aiproxy/core/middleware"
  11. "github.com/labring/aiproxy/core/model"
  12. "github.com/labring/aiproxy/core/monitor"
  13. "github.com/labring/aiproxy/core/relay/adaptors"
  14. "github.com/labring/aiproxy/core/relay/mode"
  15. )
  16. const (
  17. AIProxyChannelHeader = "Aiproxy-Channel"
  18. )
  19. func GetChannelFromHeader(
  20. header string,
  21. mc *model.ModelCaches,
  22. availableSet []string,
  23. model string,
  24. m mode.Mode,
  25. ) (*model.Channel, error) {
  26. channelIDInt, err := strconv.ParseInt(header, 10, 64)
  27. if err != nil {
  28. return nil, err
  29. }
  30. for _, set := range availableSet {
  31. enabledChannels := mc.EnabledModel2ChannelsBySet[set][model]
  32. if len(enabledChannels) > 0 {
  33. for _, channel := range enabledChannels {
  34. if int64(channel.ID) == channelIDInt {
  35. a, ok := adaptors.GetAdaptor(channel.Type)
  36. if !ok {
  37. return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
  38. }
  39. if !a.SupportMode(m) {
  40. return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
  41. }
  42. return channel, nil
  43. }
  44. }
  45. }
  46. disabledChannels := mc.DisabledModel2ChannelsBySet[set][model]
  47. if len(disabledChannels) > 0 {
  48. for _, channel := range disabledChannels {
  49. if int64(channel.ID) == channelIDInt {
  50. a, ok := adaptors.GetAdaptor(channel.Type)
  51. if !ok {
  52. return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
  53. }
  54. if !a.SupportMode(m) {
  55. return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
  56. }
  57. return channel, nil
  58. }
  59. }
  60. }
  61. }
  62. return nil, fmt.Errorf("channel %d not found for model `%s`", channelIDInt, model)
  63. }
  64. func needPinChannel(m mode.Mode) bool {
  65. switch m {
  66. case mode.VideoGenerationsGetJobs,
  67. mode.VideoGenerationsContent,
  68. mode.ResponsesGet,
  69. mode.ResponsesDelete,
  70. mode.ResponsesCancel,
  71. mode.ResponsesInputItems:
  72. return true
  73. default:
  74. return false
  75. }
  76. }
  77. func GetChannelFromRequest(
  78. c *gin.Context,
  79. mc *model.ModelCaches,
  80. availableSet []string,
  81. modelName string,
  82. m mode.Mode,
  83. ) (*model.Channel, error) {
  84. channelID := middleware.GetChannelID(c)
  85. if channelID == 0 {
  86. if needPinChannel(m) {
  87. return nil, fmt.Errorf("%s need pinned channel", m)
  88. }
  89. return nil, nil
  90. }
  91. for _, set := range availableSet {
  92. enabledChannels := mc.EnabledModel2ChannelsBySet[set][modelName]
  93. if len(enabledChannels) > 0 {
  94. for _, channel := range enabledChannels {
  95. if channel.ID == channelID {
  96. a, ok := adaptors.GetAdaptor(channel.Type)
  97. if !ok {
  98. return nil, fmt.Errorf(
  99. "adaptor not found for pinned channel %d",
  100. channel.ID,
  101. )
  102. }
  103. if !a.SupportMode(m) {
  104. return nil, fmt.Errorf(
  105. "pinned channel %d not supported by adaptor",
  106. channel.ID,
  107. )
  108. }
  109. return channel, nil
  110. }
  111. }
  112. }
  113. }
  114. return nil, fmt.Errorf("pinned channel %d not found for model `%s`", channelID, modelName)
  115. }
  116. var (
  117. ErrChannelsNotFound = errors.New("channels not found")
  118. ErrChannelsExhausted = errors.New("channels exhausted")
  119. )
  120. func GetRandomChannel(
  121. mc *model.ModelCaches,
  122. availableSet []string,
  123. modelName string,
  124. mode mode.Mode,
  125. errorRates map[int64]float64,
  126. ignoreChannelMap map[int64]struct{},
  127. ) (*model.Channel, []*model.Channel, error) {
  128. channelMap := make(map[int]*model.Channel)
  129. if len(availableSet) != 0 {
  130. for _, set := range availableSet {
  131. for _, channel := range mc.EnabledModel2ChannelsBySet[set][modelName] {
  132. a, ok := adaptors.GetAdaptor(channel.Type)
  133. if !ok {
  134. continue
  135. }
  136. if !a.SupportMode(mode) {
  137. continue
  138. }
  139. channelMap[channel.ID] = channel
  140. }
  141. }
  142. } else {
  143. for _, sets := range mc.EnabledModel2ChannelsBySet {
  144. for _, channel := range sets[modelName] {
  145. a, ok := adaptors.GetAdaptor(channel.Type)
  146. if !ok {
  147. continue
  148. }
  149. if !a.SupportMode(mode) {
  150. continue
  151. }
  152. channelMap[channel.ID] = channel
  153. }
  154. }
  155. }
  156. migratedChannels := make([]*model.Channel, 0, len(channelMap))
  157. for _, channel := range channelMap {
  158. migratedChannels = append(migratedChannels, channel)
  159. }
  160. channel, err := ignoreChannel(migratedChannels, mode, errorRates, ignoreChannelMap, nil)
  161. return channel, migratedChannels, err
  162. }
  163. func getPriority(channel *model.Channel, errorRate float64) int32 {
  164. priority := channel.GetPriority()
  165. if errorRate > 1 {
  166. errorRate = 1
  167. } else if errorRate < 0.1 {
  168. errorRate = 0.1
  169. }
  170. return int32(float64(priority) / errorRate)
  171. }
  172. func ignoreChannel(
  173. channels []*model.Channel,
  174. mode mode.Mode,
  175. errorRates map[int64]float64,
  176. ignoreChannelIDs ...map[int64]struct{},
  177. ) (*model.Channel, error) {
  178. if len(channels) == 0 {
  179. return nil, ErrChannelsNotFound
  180. }
  181. channels = filterChannels(channels, mode, ignoreChannelIDs...)
  182. if len(channels) == 0 {
  183. return nil, ErrChannelsExhausted
  184. }
  185. if len(channels) == 1 {
  186. return channels[0], nil
  187. }
  188. var totalWeight int32
  189. cachedPrioritys := make([]int32, len(channels))
  190. for i, ch := range channels {
  191. priority := getPriority(ch, errorRates[int64(ch.ID)])
  192. totalWeight += priority
  193. cachedPrioritys[i] = priority
  194. }
  195. if totalWeight == 0 {
  196. return channels[rand.IntN(len(channels))], nil
  197. }
  198. r := rand.Int32N(totalWeight)
  199. for i, ch := range channels {
  200. r -= cachedPrioritys[i]
  201. if r < 0 {
  202. return ch, nil
  203. }
  204. }
  205. return channels[rand.IntN(len(channels))], nil
  206. }
  207. func getChannelWithFallback(
  208. cache *model.ModelCaches,
  209. availableSet []string,
  210. modelName string,
  211. mode mode.Mode,
  212. errorRates map[int64]float64,
  213. ignoreChannelIDs map[int64]struct{},
  214. ) (*model.Channel, []*model.Channel, error) {
  215. channel, migratedChannels, err := GetRandomChannel(
  216. cache,
  217. availableSet,
  218. modelName,
  219. mode,
  220. errorRates,
  221. ignoreChannelIDs)
  222. if err == nil {
  223. return channel, migratedChannels, nil
  224. }
  225. if !errors.Is(err, ErrChannelsExhausted) {
  226. return nil, migratedChannels, err
  227. }
  228. channel, migratedChannels, err = GetRandomChannel(
  229. cache,
  230. availableSet,
  231. modelName,
  232. mode,
  233. errorRates,
  234. nil,
  235. )
  236. return channel, migratedChannels, err
  237. }
  238. type initialChannel struct {
  239. channel *model.Channel
  240. designatedChannel bool
  241. ignoreChannelIDs map[int64]struct{}
  242. errorRates map[int64]float64
  243. migratedChannels []*model.Channel
  244. }
  245. func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialChannel, error) {
  246. log := common.GetLogger(c)
  247. group := middleware.GetGroup(c)
  248. availableSet := group.GetAvailableSets()
  249. if channelHeader := c.Request.Header.Get(AIProxyChannelHeader); channelHeader != "" {
  250. if group.Status != model.GroupStatusInternal {
  251. return nil, errors.New("channel header is not allowed in non-internal group")
  252. }
  253. channel, err := GetChannelFromHeader(
  254. channelHeader,
  255. middleware.GetModelCaches(c),
  256. availableSet,
  257. modelName,
  258. m,
  259. )
  260. if err != nil {
  261. return nil, err
  262. }
  263. log.Data["designated_channel"] = "true"
  264. return &initialChannel{channel: channel, designatedChannel: true}, nil
  265. }
  266. channel, err := GetChannelFromRequest(
  267. c,
  268. middleware.GetModelCaches(c),
  269. availableSet,
  270. modelName,
  271. m,
  272. )
  273. if err != nil {
  274. return nil, err
  275. }
  276. if channel != nil {
  277. return &initialChannel{channel: channel, designatedChannel: true}, nil
  278. }
  279. mc := middleware.GetModelCaches(c)
  280. ignoreChannelIDs, err := monitor.GetBannedChannelsMapWithModel(c.Request.Context(), modelName)
  281. if err != nil {
  282. log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
  283. }
  284. log.Debugf("%s model banned channels: %+v", modelName, ignoreChannelIDs)
  285. errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
  286. if err != nil {
  287. log.Errorf("get channel model error rates failed: %+v", err)
  288. }
  289. channel, migratedChannels, err := getChannelWithFallback(
  290. mc,
  291. availableSet,
  292. modelName,
  293. m,
  294. errorRates,
  295. ignoreChannelIDs,
  296. )
  297. if err != nil {
  298. return nil, err
  299. }
  300. return &initialChannel{
  301. channel: channel,
  302. ignoreChannelIDs: ignoreChannelIDs,
  303. errorRates: errorRates,
  304. migratedChannels: migratedChannels,
  305. }, nil
  306. }
  307. func getWebSearchChannel(
  308. ctx context.Context,
  309. mc *model.ModelCaches,
  310. modelName string,
  311. ) (*model.Channel, error) {
  312. ignoreChannelIDs, _ := monitor.GetBannedChannelsMapWithModel(ctx, modelName)
  313. errorRates, _ := monitor.GetModelChannelErrorRate(ctx, modelName)
  314. channel, _, err := getChannelWithFallback(
  315. mc,
  316. nil,
  317. modelName,
  318. mode.ChatCompletions,
  319. errorRates,
  320. ignoreChannelIDs)
  321. if err != nil {
  322. return nil, err
  323. }
  324. return channel, nil
  325. }
  326. func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.Channel, error) {
  327. if state.exhausted {
  328. if state.lastHasPermissionChannel == nil {
  329. return nil, ErrChannelsExhausted
  330. }
  331. return state.lastHasPermissionChannel, nil
  332. }
  333. // For the last retry, filter out all previously failed channels if there are other options
  334. if currentRetry == totalRetries-1 && len(state.failedChannelIDs) > 0 {
  335. // Check if there are channels available after filtering out failed channels
  336. newChannel, err := ignoreChannel(
  337. state.migratedChannels,
  338. state.meta.Mode,
  339. state.errorRates,
  340. state.ignoreChannelIDs,
  341. state.failedChannelIDs,
  342. )
  343. if err == nil {
  344. return newChannel, nil
  345. }
  346. // If no channels available after filtering, fall back to not using failed channels filter
  347. }
  348. newChannel, err := ignoreChannel(
  349. state.migratedChannels,
  350. state.meta.Mode,
  351. state.errorRates,
  352. state.ignoreChannelIDs,
  353. )
  354. if err != nil {
  355. if !errors.Is(err, ErrChannelsExhausted) || state.lastHasPermissionChannel == nil {
  356. return nil, err
  357. }
  358. state.exhausted = true
  359. return state.lastHasPermissionChannel, nil
  360. }
  361. return newChannel, nil
  362. }
  363. func filterChannels(
  364. channels []*model.Channel,
  365. mode mode.Mode,
  366. ignoreChannel ...map[int64]struct{},
  367. ) []*model.Channel {
  368. filtered := make([]*model.Channel, 0)
  369. for _, channel := range channels {
  370. if channel.Status != model.ChannelStatusEnabled {
  371. continue
  372. }
  373. a, ok := adaptors.GetAdaptor(channel.Type)
  374. if !ok {
  375. continue
  376. }
  377. if !a.SupportMode(mode) {
  378. continue
  379. }
  380. chid := int64(channel.ID)
  381. needIgnore := false
  382. for _, ignores := range ignoreChannel {
  383. if ignores == nil {
  384. continue
  385. }
  386. _, needIgnore = ignores[chid]
  387. if needIgnore {
  388. break
  389. }
  390. }
  391. if needIgnore {
  392. continue
  393. }
  394. filtered = append(filtered, channel)
  395. }
  396. return filtered
  397. }