channel.go 21 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925
  1. package controller
  import (
  	"encoding/json"
  	"fmt"
  	"net/http"
  	"one-api/common"
  	"one-api/constant"
  	"one-api/model"
  	"strconv"
  	"strings"
  	"time"

  	"github.com/gin-gonic/gin"
  )
  // OpenAIModel is a single entry of an OpenAI-compatible "list models" response.
  // The JSON tags mirror the wire format; within this file only ID is consumed.
  type OpenAIModel struct {
  	ID      string `json:"id"`
  	Object  string `json:"object"`
  	Created int64  `json:"created"`
  	OwnedBy string `json:"owned_by"`
  	// Permission is kept as an anonymous struct slice so the field's type is unchanged.
  	Permission []struct {
  		ID                 string `json:"id"`
  		Object             string `json:"object"`
  		Created            int64  `json:"created"`
  		AllowCreateEngine  bool   `json:"allow_create_engine"`
  		AllowSampling      bool   `json:"allow_sampling"`
  		AllowLogprobs      bool   `json:"allow_logprobs"`
  		AllowSearchIndices bool   `json:"allow_search_indices"`
  		AllowView          bool   `json:"allow_view"`
  		AllowFineTuning    bool   `json:"allow_fine_tuning"`
  		Organization       string `json:"organization"`
  		Group              string `json:"group"`
  		IsBlocking         bool   `json:"is_blocking"`
  	} `json:"permission"`
  	Root   string `json:"root"`
  	Parent string `json:"parent"`
  }

  // OpenAIModelsResponse is the envelope of a "list models" response: the model entries
  // under "data" plus a "success" flag.
  type OpenAIModelsResponse struct {
  	Data    []OpenAIModel `json:"data"`
  	Success bool          `json:"success"`
  }
  39. func parseStatusFilter(statusParam string) int {
  40. switch strings.ToLower(statusParam) {
  41. case "enabled", "1":
  42. return common.ChannelStatusEnabled
  43. case "disabled", "0":
  44. return 0
  45. default:
  46. return -1
  47. }
  48. }
  49. func GetAllChannels(c *gin.Context) {
  50. p, _ := strconv.Atoi(c.Query("p"))
  51. pageSize, _ := strconv.Atoi(c.Query("page_size"))
  52. if p < 1 {
  53. p = 1
  54. }
  55. if pageSize < 1 {
  56. pageSize = common.ItemsPerPage
  57. }
  58. channelData := make([]*model.Channel, 0)
  59. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  60. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  61. statusParam := c.Query("status")
  62. // statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
  63. statusFilter := parseStatusFilter(statusParam)
  64. // type filter
  65. typeStr := c.Query("type")
  66. typeFilter := -1
  67. if typeStr != "" {
  68. if t, err := strconv.Atoi(typeStr); err == nil {
  69. typeFilter = t
  70. }
  71. }
  72. var total int64
  73. if enableTagMode {
  74. tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
  75. if err != nil {
  76. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  77. return
  78. }
  79. for _, tag := range tags {
  80. if tag == nil || *tag == "" {
  81. continue
  82. }
  83. tagChannels, err := model.GetChannelsByTag(*tag, idSort)
  84. if err != nil {
  85. continue
  86. }
  87. filtered := make([]*model.Channel, 0)
  88. for _, ch := range tagChannels {
  89. if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
  90. continue
  91. }
  92. if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
  93. continue
  94. }
  95. if typeFilter >= 0 && ch.Type != typeFilter {
  96. continue
  97. }
  98. filtered = append(filtered, ch)
  99. }
  100. channelData = append(channelData, filtered...)
  101. }
  102. total, _ = model.CountAllTags()
  103. } else {
  104. baseQuery := model.DB.Model(&model.Channel{})
  105. if typeFilter >= 0 {
  106. baseQuery = baseQuery.Where("type = ?", typeFilter)
  107. }
  108. if statusFilter == common.ChannelStatusEnabled {
  109. baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
  110. } else if statusFilter == 0 {
  111. baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
  112. }
  113. baseQuery.Count(&total)
  114. order := "priority desc"
  115. if idSort {
  116. order = "id desc"
  117. }
  118. err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error
  119. if err != nil {
  120. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  121. return
  122. }
  123. }
  124. countQuery := model.DB.Model(&model.Channel{})
  125. if statusFilter == common.ChannelStatusEnabled {
  126. countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
  127. } else if statusFilter == 0 {
  128. countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
  129. }
  130. var results []struct {
  131. Type int64
  132. Count int64
  133. }
  134. _ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
  135. typeCounts := make(map[int64]int64)
  136. for _, r := range results {
  137. typeCounts[r.Type] = r.Count
  138. }
  139. c.JSON(http.StatusOK, gin.H{
  140. "success": true,
  141. "message": "",
  142. "data": gin.H{
  143. "items": channelData,
  144. "total": total,
  145. "page": p,
  146. "page_size": pageSize,
  147. "type_counts": typeCounts,
  148. },
  149. })
  150. return
  151. }
  152. func FetchUpstreamModels(c *gin.Context) {
  153. id, err := strconv.Atoi(c.Param("id"))
  154. if err != nil {
  155. c.JSON(http.StatusOK, gin.H{
  156. "success": false,
  157. "message": err.Error(),
  158. })
  159. return
  160. }
  161. channel, err := model.GetChannelById(id, true)
  162. if err != nil {
  163. c.JSON(http.StatusOK, gin.H{
  164. "success": false,
  165. "message": err.Error(),
  166. })
  167. return
  168. }
  169. baseURL := constant.ChannelBaseURLs[channel.Type]
  170. if channel.GetBaseURL() != "" {
  171. baseURL = channel.GetBaseURL()
  172. }
  173. url := fmt.Sprintf("%s/v1/models", baseURL)
  174. switch channel.Type {
  175. case constant.ChannelTypeGemini:
  176. url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
  177. case constant.ChannelTypeAli:
  178. url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
  179. }
  180. body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
  181. if err != nil {
  182. c.JSON(http.StatusOK, gin.H{
  183. "success": false,
  184. "message": err.Error(),
  185. })
  186. return
  187. }
  188. var result OpenAIModelsResponse
  189. if err = json.Unmarshal(body, &result); err != nil {
  190. c.JSON(http.StatusOK, gin.H{
  191. "success": false,
  192. "message": fmt.Sprintf("解析响应失败: %s", err.Error()),
  193. })
  194. return
  195. }
  196. var ids []string
  197. for _, model := range result.Data {
  198. id := model.ID
  199. if channel.Type == constant.ChannelTypeGemini {
  200. id = strings.TrimPrefix(id, "models/")
  201. }
  202. ids = append(ids, id)
  203. }
  204. c.JSON(http.StatusOK, gin.H{
  205. "success": true,
  206. "message": "",
  207. "data": ids,
  208. })
  209. }
  210. func FixChannelsAbilities(c *gin.Context) {
  211. success, fails, err := model.FixAbility()
  212. if err != nil {
  213. c.JSON(http.StatusOK, gin.H{
  214. "success": false,
  215. "message": err.Error(),
  216. })
  217. return
  218. }
  219. c.JSON(http.StatusOK, gin.H{
  220. "success": true,
  221. "message": "",
  222. "data": gin.H{
  223. "success": success,
  224. "fails": fails,
  225. },
  226. })
  227. }
  228. func SearchChannels(c *gin.Context) {
  229. keyword := c.Query("keyword")
  230. group := c.Query("group")
  231. modelKeyword := c.Query("model")
  232. statusParam := c.Query("status")
  233. statusFilter := parseStatusFilter(statusParam)
  234. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  235. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  236. channelData := make([]*model.Channel, 0)
  237. if enableTagMode {
  238. tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
  239. if err != nil {
  240. c.JSON(http.StatusOK, gin.H{
  241. "success": false,
  242. "message": err.Error(),
  243. })
  244. return
  245. }
  246. for _, tag := range tags {
  247. if tag != nil && *tag != "" {
  248. tagChannel, err := model.GetChannelsByTag(*tag, idSort)
  249. if err == nil {
  250. channelData = append(channelData, tagChannel...)
  251. }
  252. }
  253. }
  254. } else {
  255. channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
  256. if err != nil {
  257. c.JSON(http.StatusOK, gin.H{
  258. "success": false,
  259. "message": err.Error(),
  260. })
  261. return
  262. }
  263. channelData = channels
  264. }
  265. if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
  266. filtered := make([]*model.Channel, 0, len(channelData))
  267. for _, ch := range channelData {
  268. if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
  269. continue
  270. }
  271. if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
  272. continue
  273. }
  274. filtered = append(filtered, ch)
  275. }
  276. channelData = filtered
  277. }
  278. // calculate type counts for search results
  279. typeCounts := make(map[int64]int64)
  280. for _, channel := range channelData {
  281. typeCounts[int64(channel.Type)]++
  282. }
  283. typeParam := c.Query("type")
  284. typeFilter := -1
  285. if typeParam != "" {
  286. if tp, err := strconv.Atoi(typeParam); err == nil {
  287. typeFilter = tp
  288. }
  289. }
  290. if typeFilter >= 0 {
  291. filtered := make([]*model.Channel, 0, len(channelData))
  292. for _, ch := range channelData {
  293. if ch.Type == typeFilter {
  294. filtered = append(filtered, ch)
  295. }
  296. }
  297. channelData = filtered
  298. }
  299. page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
  300. pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
  301. if page < 1 {
  302. page = 1
  303. }
  304. if pageSize <= 0 {
  305. pageSize = 20
  306. }
  307. total := len(channelData)
  308. startIdx := (page - 1) * pageSize
  309. if startIdx > total {
  310. startIdx = total
  311. }
  312. endIdx := startIdx + pageSize
  313. if endIdx > total {
  314. endIdx = total
  315. }
  316. pagedData := channelData[startIdx:endIdx]
  317. c.JSON(http.StatusOK, gin.H{
  318. "success": true,
  319. "message": "",
  320. "data": gin.H{
  321. "items": pagedData,
  322. "total": total,
  323. "type_counts": typeCounts,
  324. },
  325. })
  326. return
  327. }
  328. func GetChannel(c *gin.Context) {
  329. id, err := strconv.Atoi(c.Param("id"))
  330. if err != nil {
  331. c.JSON(http.StatusOK, gin.H{
  332. "success": false,
  333. "message": err.Error(),
  334. })
  335. return
  336. }
  337. channel, err := model.GetChannelById(id, false)
  338. if err != nil {
  339. c.JSON(http.StatusOK, gin.H{
  340. "success": false,
  341. "message": err.Error(),
  342. })
  343. return
  344. }
  345. c.JSON(http.StatusOK, gin.H{
  346. "success": true,
  347. "message": "",
  348. "data": channel,
  349. })
  350. return
  351. }
  352. type AddChannelRequest struct {
  353. Mode string `json:"mode"`
  354. MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
  355. Channel *model.Channel `json:"channel"`
  356. }
  357. func getVertexArrayKeys(keys string) ([]string, error) {
  358. if keys == "" {
  359. return nil, nil
  360. }
  361. var keyArray []interface{}
  362. err := common.Unmarshal([]byte(keys), &keyArray)
  363. if err != nil {
  364. return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
  365. }
  366. cleanKeys := make([]string, 0, len(keyArray))
  367. for _, key := range keyArray {
  368. var keyStr string
  369. switch v := key.(type) {
  370. case string:
  371. keyStr = strings.TrimSpace(v)
  372. default:
  373. bytes, err := json.Marshal(v)
  374. if err != nil {
  375. return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
  376. }
  377. keyStr = string(bytes)
  378. }
  379. if keyStr != "" {
  380. cleanKeys = append(cleanKeys, keyStr)
  381. }
  382. }
  383. if len(cleanKeys) == 0 {
  384. return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
  385. }
  386. return cleanKeys, nil
  387. }
  388. func AddChannel(c *gin.Context) {
  389. addChannelRequest := AddChannelRequest{}
  390. err := c.ShouldBindJSON(&addChannelRequest)
  391. if err != nil {
  392. c.JSON(http.StatusOK, gin.H{
  393. "success": false,
  394. "message": err.Error(),
  395. })
  396. return
  397. }
  398. err = addChannelRequest.Channel.ValidateSettings()
  399. if err != nil {
  400. c.JSON(http.StatusOK, gin.H{
  401. "success": false,
  402. "message": "channel setting 格式错误:" + err.Error(),
  403. })
  404. return
  405. }
  406. if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" {
  407. c.JSON(http.StatusOK, gin.H{
  408. "success": false,
  409. "message": "channel cannot be empty",
  410. })
  411. return
  412. }
  413. // Validate the length of the model name
  414. for _, m := range addChannelRequest.Channel.GetModels() {
  415. if len(m) > 255 {
  416. c.JSON(http.StatusOK, gin.H{
  417. "success": false,
  418. "message": fmt.Sprintf("模型名称过长: %s", m),
  419. })
  420. return
  421. }
  422. }
  423. if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
  424. if addChannelRequest.Channel.Other == "" {
  425. c.JSON(http.StatusOK, gin.H{
  426. "success": false,
  427. "message": "部署地区不能为空",
  428. })
  429. return
  430. } else {
  431. regionMap, err := common.StrToMap(addChannelRequest.Channel.Other)
  432. if err != nil {
  433. c.JSON(http.StatusOK, gin.H{
  434. "success": false,
  435. "message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
  436. })
  437. return
  438. }
  439. if regionMap["default"] == nil {
  440. c.JSON(http.StatusOK, gin.H{
  441. "success": false,
  442. "message": "部署地区必须包含default字段",
  443. })
  444. return
  445. }
  446. }
  447. }
  448. addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
  449. keys := make([]string, 0)
  450. switch addChannelRequest.Mode {
  451. case "multi_to_single":
  452. addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
  453. addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
  454. if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
  455. array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
  456. if err != nil {
  457. c.JSON(http.StatusOK, gin.H{
  458. "success": false,
  459. "message": err.Error(),
  460. })
  461. return
  462. }
  463. addChannelRequest.Channel.Key = strings.Join(array, "\n")
  464. } else {
  465. cleanKeys := make([]string, 0)
  466. for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
  467. if key == "" {
  468. continue
  469. }
  470. key = strings.TrimSpace(key)
  471. cleanKeys = append(cleanKeys, key)
  472. }
  473. addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
  474. }
  475. keys = []string{addChannelRequest.Channel.Key}
  476. case "batch":
  477. if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
  478. // multi json
  479. keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
  480. if err != nil {
  481. c.JSON(http.StatusOK, gin.H{
  482. "success": false,
  483. "message": err.Error(),
  484. })
  485. return
  486. }
  487. } else {
  488. keys = strings.Split(addChannelRequest.Channel.Key, "\n")
  489. }
  490. case "single":
  491. keys = []string{addChannelRequest.Channel.Key}
  492. default:
  493. c.JSON(http.StatusOK, gin.H{
  494. "success": false,
  495. "message": "不支持的添加模式",
  496. })
  497. return
  498. }
  499. channels := make([]model.Channel, 0, len(keys))
  500. for _, key := range keys {
  501. if key == "" {
  502. continue
  503. }
  504. localChannel := addChannelRequest.Channel
  505. localChannel.Key = key
  506. channels = append(channels, *localChannel)
  507. }
  508. err = model.BatchInsertChannels(channels)
  509. if err != nil {
  510. c.JSON(http.StatusOK, gin.H{
  511. "success": false,
  512. "message": err.Error(),
  513. })
  514. return
  515. }
  516. c.JSON(http.StatusOK, gin.H{
  517. "success": true,
  518. "message": "",
  519. })
  520. return
  521. }
  522. func DeleteChannel(c *gin.Context) {
  523. id, _ := strconv.Atoi(c.Param("id"))
  524. channel := model.Channel{Id: id}
  525. err := channel.Delete()
  526. if err != nil {
  527. c.JSON(http.StatusOK, gin.H{
  528. "success": false,
  529. "message": err.Error(),
  530. })
  531. return
  532. }
  533. c.JSON(http.StatusOK, gin.H{
  534. "success": true,
  535. "message": "",
  536. })
  537. return
  538. }
  539. func DeleteDisabledChannel(c *gin.Context) {
  540. rows, err := model.DeleteDisabledChannel()
  541. if err != nil {
  542. c.JSON(http.StatusOK, gin.H{
  543. "success": false,
  544. "message": err.Error(),
  545. })
  546. return
  547. }
  548. c.JSON(http.StatusOK, gin.H{
  549. "success": true,
  550. "message": "",
  551. "data": rows,
  552. })
  553. return
  554. }
  // ChannelTag is the JSON body of the tag-level handlers.
  // Tag selects the channels. The pointer fields are optional and are passed through as
  // given (nil when omitted or null) to model.EditChannelByTag by EditTagChannels.
  type ChannelTag struct {
  	Tag          string  `json:"tag"`
  	NewTag       *string `json:"new_tag"`
  	Priority     *int64  `json:"priority"`
  	Weight       *uint   `json:"weight"`
  	ModelMapping *string `json:"model_mapping"`
  	Models       *string `json:"models"`
  	Groups       *string `json:"groups"`
  }
  564. func DisableTagChannels(c *gin.Context) {
  565. channelTag := ChannelTag{}
  566. err := c.ShouldBindJSON(&channelTag)
  567. if err != nil || channelTag.Tag == "" {
  568. c.JSON(http.StatusOK, gin.H{
  569. "success": false,
  570. "message": "参数错误",
  571. })
  572. return
  573. }
  574. err = model.DisableChannelByTag(channelTag.Tag)
  575. if err != nil {
  576. c.JSON(http.StatusOK, gin.H{
  577. "success": false,
  578. "message": err.Error(),
  579. })
  580. return
  581. }
  582. c.JSON(http.StatusOK, gin.H{
  583. "success": true,
  584. "message": "",
  585. })
  586. return
  587. }
  588. func EnableTagChannels(c *gin.Context) {
  589. channelTag := ChannelTag{}
  590. err := c.ShouldBindJSON(&channelTag)
  591. if err != nil || channelTag.Tag == "" {
  592. c.JSON(http.StatusOK, gin.H{
  593. "success": false,
  594. "message": "参数错误",
  595. })
  596. return
  597. }
  598. err = model.EnableChannelByTag(channelTag.Tag)
  599. if err != nil {
  600. c.JSON(http.StatusOK, gin.H{
  601. "success": false,
  602. "message": err.Error(),
  603. })
  604. return
  605. }
  606. c.JSON(http.StatusOK, gin.H{
  607. "success": true,
  608. "message": "",
  609. })
  610. return
  611. }
  612. func EditTagChannels(c *gin.Context) {
  613. channelTag := ChannelTag{}
  614. err := c.ShouldBindJSON(&channelTag)
  615. if err != nil {
  616. c.JSON(http.StatusOK, gin.H{
  617. "success": false,
  618. "message": "参数错误",
  619. })
  620. return
  621. }
  622. if channelTag.Tag == "" {
  623. c.JSON(http.StatusOK, gin.H{
  624. "success": false,
  625. "message": "tag不能为空",
  626. })
  627. return
  628. }
  629. err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
  630. if err != nil {
  631. c.JSON(http.StatusOK, gin.H{
  632. "success": false,
  633. "message": err.Error(),
  634. })
  635. return
  636. }
  637. c.JSON(http.StatusOK, gin.H{
  638. "success": true,
  639. "message": "",
  640. })
  641. return
  642. }
  // ChannelBatch is the JSON body of the batch handlers: the channel ids to act on, plus an
  // optional tag (nil when omitted or null) used by BatchSetChannelTag.
  type ChannelBatch struct {
  	Ids []int   `json:"ids"`
  	Tag *string `json:"tag"`
  }
  647. func DeleteChannelBatch(c *gin.Context) {
  648. channelBatch := ChannelBatch{}
  649. err := c.ShouldBindJSON(&channelBatch)
  650. if err != nil || len(channelBatch.Ids) == 0 {
  651. c.JSON(http.StatusOK, gin.H{
  652. "success": false,
  653. "message": "参数错误",
  654. })
  655. return
  656. }
  657. err = model.BatchDeleteChannels(channelBatch.Ids)
  658. if err != nil {
  659. c.JSON(http.StatusOK, gin.H{
  660. "success": false,
  661. "message": err.Error(),
  662. })
  663. return
  664. }
  665. c.JSON(http.StatusOK, gin.H{
  666. "success": true,
  667. "message": "",
  668. "data": len(channelBatch.Ids),
  669. })
  670. return
  671. }
  672. func UpdateChannel(c *gin.Context) {
  673. channel := model.Channel{}
  674. err := c.ShouldBindJSON(&channel)
  675. if err != nil {
  676. c.JSON(http.StatusOK, gin.H{
  677. "success": false,
  678. "message": err.Error(),
  679. })
  680. return
  681. }
  682. err = channel.ValidateSettings()
  683. if err != nil {
  684. c.JSON(http.StatusOK, gin.H{
  685. "success": false,
  686. "message": "channel setting 格式错误:" + err.Error(),
  687. })
  688. return
  689. }
  690. if channel.Type == constant.ChannelTypeVertexAi {
  691. if channel.Other == "" {
  692. c.JSON(http.StatusOK, gin.H{
  693. "success": false,
  694. "message": "部署地区不能为空",
  695. })
  696. return
  697. } else {
  698. regionMap, err := common.StrToMap(channel.Other)
  699. if err != nil {
  700. c.JSON(http.StatusOK, gin.H{
  701. "success": false,
  702. "message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
  703. })
  704. return
  705. }
  706. if regionMap["default"] == nil {
  707. c.JSON(http.StatusOK, gin.H{
  708. "success": false,
  709. "message": "部署地区必须包含default字段",
  710. })
  711. return
  712. }
  713. }
  714. }
  715. err = channel.Update()
  716. if err != nil {
  717. c.JSON(http.StatusOK, gin.H{
  718. "success": false,
  719. "message": err.Error(),
  720. })
  721. return
  722. }
  723. channel.Key = ""
  724. c.JSON(http.StatusOK, gin.H{
  725. "success": true,
  726. "message": "",
  727. "data": channel,
  728. })
  729. return
  730. }
  731. func FetchModels(c *gin.Context) {
  732. var req struct {
  733. BaseURL string `json:"base_url"`
  734. Type int `json:"type"`
  735. Key string `json:"key"`
  736. }
  737. if err := c.ShouldBindJSON(&req); err != nil {
  738. c.JSON(http.StatusBadRequest, gin.H{
  739. "success": false,
  740. "message": "Invalid request",
  741. })
  742. return
  743. }
  744. baseURL := req.BaseURL
  745. if baseURL == "" {
  746. baseURL = constant.ChannelBaseURLs[req.Type]
  747. }
  748. client := &http.Client{}
  749. url := fmt.Sprintf("%s/v1/models", baseURL)
  750. request, err := http.NewRequest("GET", url, nil)
  751. if err != nil {
  752. c.JSON(http.StatusInternalServerError, gin.H{
  753. "success": false,
  754. "message": err.Error(),
  755. })
  756. return
  757. }
  758. // remove line breaks and extra spaces.
  759. key := strings.TrimSpace(req.Key)
  760. // If the key contains a line break, only take the first part.
  761. key = strings.Split(key, "\n")[0]
  762. request.Header.Set("Authorization", "Bearer "+key)
  763. response, err := client.Do(request)
  764. if err != nil {
  765. c.JSON(http.StatusInternalServerError, gin.H{
  766. "success": false,
  767. "message": err.Error(),
  768. })
  769. return
  770. }
  771. //check status code
  772. if response.StatusCode != http.StatusOK {
  773. c.JSON(http.StatusInternalServerError, gin.H{
  774. "success": false,
  775. "message": "Failed to fetch models",
  776. })
  777. return
  778. }
  779. defer response.Body.Close()
  780. var result struct {
  781. Data []struct {
  782. ID string `json:"id"`
  783. } `json:"data"`
  784. }
  785. if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
  786. c.JSON(http.StatusInternalServerError, gin.H{
  787. "success": false,
  788. "message": err.Error(),
  789. })
  790. return
  791. }
  792. var models []string
  793. for _, model := range result.Data {
  794. models = append(models, model.ID)
  795. }
  796. c.JSON(http.StatusOK, gin.H{
  797. "success": true,
  798. "data": models,
  799. })
  800. }
  801. func BatchSetChannelTag(c *gin.Context) {
  802. channelBatch := ChannelBatch{}
  803. err := c.ShouldBindJSON(&channelBatch)
  804. if err != nil || len(channelBatch.Ids) == 0 {
  805. c.JSON(http.StatusOK, gin.H{
  806. "success": false,
  807. "message": "参数错误",
  808. })
  809. return
  810. }
  811. err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
  812. if err != nil {
  813. c.JSON(http.StatusOK, gin.H{
  814. "success": false,
  815. "message": err.Error(),
  816. })
  817. return
  818. }
  819. c.JSON(http.StatusOK, gin.H{
  820. "success": true,
  821. "message": "",
  822. "data": len(channelBatch.Ids),
  823. })
  824. return
  825. }
  826. func GetTagModels(c *gin.Context) {
  827. tag := c.Query("tag")
  828. if tag == "" {
  829. c.JSON(http.StatusBadRequest, gin.H{
  830. "success": false,
  831. "message": "tag不能为空",
  832. })
  833. return
  834. }
  835. channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
  836. if err != nil {
  837. c.JSON(http.StatusInternalServerError, gin.H{
  838. "success": false,
  839. "message": err.Error(),
  840. })
  841. return
  842. }
  843. var longestModels string
  844. maxLength := 0
  845. // Find the longest models string among all channels with the given tag
  846. for _, channel := range channels {
  847. if channel.Models != "" {
  848. currentModels := strings.Split(channel.Models, ",")
  849. if len(currentModels) > maxLength {
  850. maxLength = len(currentModels)
  851. longestModels = channel.Models
  852. }
  853. }
  854. }
  855. c.JSON(http.StatusOK, gin.H{
  856. "success": true,
  857. "message": "",
  858. "data": longestModels,
  859. })
  860. return
  861. }