channel.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. package controller
  2. import (
  3. "errors"
  4. "fmt"
  5. "maps"
  6. "net/http"
  7. "slices"
  8. "strconv"
  9. "strings"
  10. "github.com/bytedance/sonic/ast"
  11. "github.com/gin-gonic/gin"
  12. "github.com/labring/aiproxy/core/controller/utils"
  13. "github.com/labring/aiproxy/core/middleware"
  14. "github.com/labring/aiproxy/core/model"
  15. "github.com/labring/aiproxy/core/monitor"
  16. "github.com/labring/aiproxy/core/relay/adaptor"
  17. "github.com/labring/aiproxy/core/relay/adaptors"
  18. log "github.com/sirupsen/logrus"
  19. )
  20. // ChannelTypeMetas godoc
  21. //
  22. // @Summary Get channel type metadata
  23. // @Description Returns metadata for all channel types
  24. // @Tags channels
  25. // @Produce json
  26. // @Security ApiKeyAuth
  27. // @Success 200 {object} middleware.APIResponse{data=map[int]adaptors.AdaptorMeta}
  28. // @Router /api/channels/type_metas [get]
  29. func ChannelTypeMetas(c *gin.Context) {
  30. middleware.SuccessResponse(c, adaptors.ChannelMetas)
  31. }
  32. // GetChannels godoc
  33. //
  34. // @Summary Get channels with pagination
  35. // @Description Returns a paginated list of channels with optional filters
  36. // @Tags channels
  37. // @Produce json
  38. // @Security ApiKeyAuth
  39. // @Param page query int false "Page number"
  40. // @Param per_page query int false "Items per page"
  41. // @Param id query int false "Filter by id"
  42. // @Param name query string false "Filter by name"
  43. // @Param key query string false "Filter by key"
  44. // @Param channel_type query int false "Filter by channel type"
  45. // @Param base_url query string false "Filter by base URL"
  46. // @Param order query string false "Order by field"
  47. // @Success 200 {object} middleware.APIResponse{data=map[string]any{channels=[]model.Channel,total=int}}
  48. // @Router /api/channels/ [get]
  49. func GetChannels(c *gin.Context) {
  50. page, perPage := utils.ParsePageParams(c)
  51. id, _ := strconv.Atoi(c.Query("id"))
  52. name := c.Query("name")
  53. key := c.Query("key")
  54. channelType, _ := strconv.Atoi(c.Query("channel_type"))
  55. baseURL := c.Query("base_url")
  56. order := c.Query("order")
  57. channels, total, err := model.GetChannels(
  58. page,
  59. perPage,
  60. id,
  61. name,
  62. key,
  63. channelType,
  64. baseURL,
  65. order,
  66. )
  67. if err != nil {
  68. middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
  69. return
  70. }
  71. middleware.SuccessResponse(c, gin.H{
  72. "channels": channels,
  73. "total": total,
  74. })
  75. }
  76. // GetAllChannels godoc
  77. //
  78. // @Summary Get all channels
  79. // @Description Returns a list of all channels without pagination
  80. // @Tags channels
  81. // @Produce json
  82. // @Security ApiKeyAuth
  83. // @Success 200 {object} middleware.APIResponse{data=[]model.Channel}
  84. // @Router /api/channels/all [get]
  85. func GetAllChannels(c *gin.Context) {
  86. channels, err := model.GetAllChannels()
  87. if err != nil {
  88. middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
  89. return
  90. }
  91. middleware.SuccessResponse(c, channels)
  92. }
  93. // AddChannels godoc
  94. //
  95. // @Summary Add multiple channels
  96. // @Description Adds multiple channels in a batch operation
  97. // @Tags channels
  98. // @Accept json
  99. // @Produce json
  100. // @Security ApiKeyAuth
  101. // @Param channels body []AddChannelRequest true "Channel information"
  102. // @Success 200 {object} middleware.APIResponse
  103. // @Router /api/channels/ [post]
  104. func AddChannels(c *gin.Context) {
  105. channels := make([]*AddChannelRequest, 0)
  106. err := c.ShouldBindJSON(&channels)
  107. if err != nil {
  108. middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
  109. return
  110. }
  111. _channels := make([]*model.Channel, 0, len(channels))
  112. for _, channel := range channels {
  113. channels, err := channel.ToChannels()
  114. if err != nil {
  115. middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
  116. return
  117. }
  118. _channels = append(_channels, channels...)
  119. }
  120. err = model.BatchInsertChannels(_channels)
  121. if err != nil {
  122. middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
  123. return
  124. }
  125. middleware.SuccessResponse(c, nil)
  126. }
  127. // SearchChannels godoc
  128. //
  129. // @Summary Search channels
  130. // @Description Search channels with keyword and optional filters
  131. // @Tags channels
  132. // @Produce json
  133. // @Security ApiKeyAuth
  134. // @Param keyword query string true "Search keyword"
  135. // @Param page query int false "Page number"
  136. // @Param per_page query int false "Items per page"
  137. // @Param id query int false "Filter by id"
  138. // @Param name query string false "Filter by name"
  139. // @Param key query string false "Filter by key"
  140. // @Param channel_type query int false "Filter by channel type"
  141. // @Param base_url query string false "Filter by base URL"
  142. // @Param order query string false "Order by field"
  143. // @Success 200 {object} middleware.APIResponse{data=map[string]any{channels=[]model.Channel,total=int}}
  144. // @Router /api/channels/search [get]
  145. func SearchChannels(c *gin.Context) {
  146. keyword := c.Query("keyword")
  147. page, perPage := utils.ParsePageParams(c)
  148. id, _ := strconv.Atoi(c.Query("id"))
  149. name := c.Query("name")
  150. key := c.Query("key")
  151. channelType, _ := strconv.Atoi(c.Query("channel_type"))
  152. baseURL := c.Query("base_url")
  153. order := c.Query("order")
  154. channels, total, err := model.SearchChannels(
  155. keyword,
  156. page,
  157. perPage,
  158. id,
  159. name,
  160. key,
  161. channelType,
  162. baseURL,
  163. order,
  164. )
  165. if err != nil {
  166. middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
  167. return
  168. }
  169. middleware.SuccessResponse(c, gin.H{
  170. "channels": channels,
  171. "total": total,
  172. })
  173. }
  174. // GetChannel godoc
  175. //
  176. // @Summary Get a channel by ID
  177. // @Description Returns detailed information about a specific channel
  178. // @Tags channel
  179. // @Produce json
  180. // @Security ApiKeyAuth
  181. // @Param id path int true "Channel ID"
  182. // @Success 200 {object} middleware.APIResponse{data=model.Channel}
  183. // @Router /api/channel/{id} [get]
  184. func GetChannel(c *gin.Context) {
  185. id, err := strconv.Atoi(c.Param("id"))
  186. if err != nil {
  187. middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
  188. return
  189. }
  190. channel, err := model.GetChannelByID(id)
  191. if err != nil {
  192. middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
  193. return
  194. }
  195. middleware.SuccessResponse(c, channel)
  196. }
  197. // AddChannelRequest represents the request body for adding a channel
  198. type AddChannelRequest struct {
  199. ModelMapping map[string]string `json:"model_mapping"`
  200. Config *model.ChannelConfig `json:"config"`
  201. Name string `json:"name"`
  202. Key string `json:"key"`
  203. BaseURL string `json:"base_url"`
  204. Models []string `json:"models"`
  205. Type model.ChannelType `json:"type"`
  206. Priority int32 `json:"priority"`
  207. Status int `json:"status"`
  208. Sets []string `json:"sets"`
  209. }
  210. func (r *AddChannelRequest) ToChannel() (*model.Channel, error) {
  211. a, ok := adaptors.GetAdaptor(r.Type)
  212. if !ok {
  213. return nil, fmt.Errorf("invalid channel type: %d", r.Type)
  214. }
  215. metadata := a.Metadata()
  216. if validator := adaptors.GetKeyValidator(a); validator != nil {
  217. err := validator.ValidateKey(r.Key)
  218. if err != nil {
  219. keyHelp := metadata.KeyHelp
  220. if keyHelp == "" {
  221. return nil, fmt.Errorf(
  222. "%s [%s(%d)] invalid key: %w",
  223. r.Name,
  224. r.Type.String(),
  225. r.Type,
  226. err,
  227. )
  228. }
  229. return nil, fmt.Errorf(
  230. "%s [%s(%d)] invalid key: %w, %s",
  231. r.Name,
  232. r.Type.String(),
  233. r.Type,
  234. err,
  235. keyHelp,
  236. )
  237. }
  238. }
  239. if r.Config != nil {
  240. for key, template := range metadata.Config {
  241. v, err := r.Config.Get(key)
  242. if err != nil {
  243. if errors.Is(err, ast.ErrNotExist) {
  244. if template.Required {
  245. return nil, fmt.Errorf("config %s is required: %w", key, err)
  246. }
  247. continue
  248. }
  249. return nil, fmt.Errorf("config %s is invalid: %w", key, err)
  250. }
  251. if !v.Exists() {
  252. if template.Required {
  253. return nil, fmt.Errorf("config %s is required: %w", key, err)
  254. }
  255. continue
  256. }
  257. if template.Validator != nil {
  258. i, err := v.Interface()
  259. if err != nil {
  260. return nil, fmt.Errorf("config %s is invalid: %w", key, err)
  261. }
  262. err = adaptor.ValidateConfigTemplateValue(template, i)
  263. if err != nil {
  264. return nil, fmt.Errorf("config %s is invalid: %w", key, err)
  265. }
  266. }
  267. }
  268. }
  269. return &model.Channel{
  270. Type: r.Type,
  271. Name: r.Name,
  272. Key: r.Key,
  273. BaseURL: r.BaseURL,
  274. Models: slices.Clone(r.Models),
  275. ModelMapping: maps.Clone(r.ModelMapping),
  276. Priority: r.Priority,
  277. Status: r.Status,
  278. Config: r.Config,
  279. Sets: slices.Clone(r.Sets),
  280. }, nil
  281. }
  282. func (r *AddChannelRequest) ToChannels() ([]*model.Channel, error) {
  283. keys := strings.Split(r.Key, "\n")
  284. channels := make([]*model.Channel, 0, len(keys))
  285. for _, key := range keys {
  286. if key == "" {
  287. continue
  288. }
  289. c, err := r.ToChannel()
  290. if err != nil {
  291. return nil, err
  292. }
  293. c.Key = key
  294. channels = append(channels, c)
  295. }
  296. if len(channels) == 0 {
  297. ch, err := r.ToChannel()
  298. if err != nil {
  299. return nil, err
  300. }
  301. return []*model.Channel{ch}, nil
  302. }
  303. return channels, nil
  304. }
  305. // AddChannel godoc
  306. //
  307. // @Summary Add a single channel
  308. // @Description Adds a new channel to the system
  309. // @Tags channel
  310. // @Accept json
  311. // @Produce json
  312. // @Security ApiKeyAuth
  313. // @Param channel body AddChannelRequest true "Channel information"
  314. // @Success 200 {object} middleware.APIResponse
  315. // @Router /api/channel/ [post]
  316. func AddChannel(c *gin.Context) {
  317. channel := AddChannelRequest{}
  318. err := c.ShouldBindJSON(&channel)
  319. if err != nil {
  320. middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
  321. return
  322. }
  323. channels, err := channel.ToChannels()
  324. if err != nil {
  325. middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
  326. return
  327. }
  328. err = model.BatchInsertChannels(channels)
  329. if err != nil {
  330. middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
  331. return
  332. }
  333. middleware.SuccessResponse(c, nil)
  334. }
  335. // DeleteChannel godoc
  336. //
  337. // @Summary Delete a channel
  338. // @Description Deletes a channel by its ID
  339. // @Tags channel
  340. // @Produce json
  341. // @Security ApiKeyAuth
  342. // @Param id path int true "Channel ID"
  343. // @Success 200 {object} middleware.APIResponse
  344. // @Router /api/channel/{id} [delete]
  345. func DeleteChannel(c *gin.Context) {
  346. id, _ := strconv.Atoi(c.Param("id"))
  347. err := model.DeleteChannelByID(id)
  348. if err != nil {
  349. middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
  350. return
  351. }
  352. middleware.SuccessResponse(c, nil)
  353. }
  354. // DeleteChannels godoc
  355. //
  356. // @Summary Delete multiple channels
  357. // @Description Deletes multiple channels by their IDs
  358. // @Tags channels
  359. // @Accept json
  360. // @Produce json
  361. // @Security ApiKeyAuth
  362. // @Param ids body []int true "Channel IDs"
  363. // @Success 200 {object} middleware.APIResponse
  364. // @Router /api/channels/batch_delete [post]
  365. func DeleteChannels(c *gin.Context) {
  366. ids := []int{}
  367. err := c.ShouldBindJSON(&ids)
  368. if err != nil {
  369. middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
  370. return
  371. }
  372. err = model.DeleteChannelsByIDs(ids)
  373. if err != nil {
  374. middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
  375. return
  376. }
  377. middleware.SuccessResponse(c, nil)
  378. }
  379. // UpdateChannel godoc
  380. //
  381. // @Summary Update a channel
  382. // @Description Updates an existing channel by its ID
  383. // @Tags channel
  384. // @Accept json
  385. // @Produce json
  386. // @Security ApiKeyAuth
  387. // @Param id path int true "Channel ID"
  388. // @Param channel body AddChannelRequest true "Updated channel information"
  389. // @Success 200 {object} middleware.APIResponse{data=model.Channel}
  390. // @Router /api/channel/{id} [put]
  391. func UpdateChannel(c *gin.Context) {
  392. idStr := c.Param("id")
  393. if idStr == "" {
  394. middleware.ErrorResponse(c, http.StatusBadRequest, "id is required")
  395. return
  396. }
  397. id, err := strconv.Atoi(idStr)
  398. if err != nil {
  399. middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
  400. return
  401. }
  402. channel := AddChannelRequest{}
  403. err = c.ShouldBindJSON(&channel)
  404. if err != nil {
  405. middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
  406. return
  407. }
  408. ch, err := channel.ToChannel()
  409. if err != nil {
  410. middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
  411. return
  412. }
  413. ch.ID = id
  414. err = model.UpdateChannel(ch)
  415. if err != nil {
  416. middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
  417. return
  418. }
  419. err = monitor.ClearChannelAllModelErrors(c.Request.Context(), id)
  420. if err != nil {
  421. log.Errorf("failed to clear channel all model errors: %+v", err)
  422. }
  423. middleware.SuccessResponse(c, ch)
  424. }
  425. // UpdateChannelStatusRequest represents the request body for updating a channel's status
  426. type UpdateChannelStatusRequest struct {
  427. Status int `json:"status"`
  428. }
  429. // UpdateChannelStatus godoc
  430. //
  431. // @Summary Update channel status
  432. // @Description Updates the status of a channel by its ID
  433. // @Tags channel
  434. // @Accept json
  435. // @Produce json
  436. // @Security ApiKeyAuth
  437. // @Param id path int true "Channel ID"
  438. // @Param status body UpdateChannelStatusRequest true "Status information"
  439. // @Success 200 {object} middleware.APIResponse
  440. // @Router /api/channel/{id}/status [post]
  441. func UpdateChannelStatus(c *gin.Context) {
  442. id, _ := strconv.Atoi(c.Param("id"))
  443. status := UpdateChannelStatusRequest{}
  444. err := c.ShouldBindJSON(&status)
  445. if err != nil {
  446. middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
  447. return
  448. }
  449. err = model.UpdateChannelStatusByID(id, status.Status)
  450. if err != nil {
  451. middleware.ErrorResponse(c, http.StatusInternalServerError, err.Error())
  452. return
  453. }
  454. err = monitor.ClearChannelAllModelErrors(c.Request.Context(), id)
  455. if err != nil {
  456. log.Errorf("failed to clear channel all model errors: %+v", err)
  457. }
  458. middleware.SuccessResponse(c, nil)
  459. }