channel.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622
  1. package controller
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/model"
  8. "strconv"
  9. "strings"
  10. "github.com/gin-gonic/gin"
  11. )
  12. type OpenAIModel struct {
  13. ID string `json:"id"`
  14. Object string `json:"object"`
  15. Created int64 `json:"created"`
  16. OwnedBy string `json:"owned_by"`
  17. Permission []struct {
  18. ID string `json:"id"`
  19. Object string `json:"object"`
  20. Created int64 `json:"created"`
  21. AllowCreateEngine bool `json:"allow_create_engine"`
  22. AllowSampling bool `json:"allow_sampling"`
  23. AllowLogprobs bool `json:"allow_logprobs"`
  24. AllowSearchIndices bool `json:"allow_search_indices"`
  25. AllowView bool `json:"allow_view"`
  26. AllowFineTuning bool `json:"allow_fine_tuning"`
  27. Organization string `json:"organization"`
  28. Group string `json:"group"`
  29. IsBlocking bool `json:"is_blocking"`
  30. } `json:"permission"`
  31. Root string `json:"root"`
  32. Parent string `json:"parent"`
  33. }
  34. type OpenAIModelsResponse struct {
  35. Data []OpenAIModel `json:"data"`
  36. Success bool `json:"success"`
  37. }
  38. func GetAllChannels(c *gin.Context) {
  39. p, _ := strconv.Atoi(c.Query("p"))
  40. pageSize, _ := strconv.Atoi(c.Query("page_size"))
  41. if p < 0 {
  42. p = 0
  43. }
  44. if pageSize < 0 {
  45. pageSize = common.ItemsPerPage
  46. }
  47. channelData := make([]*model.Channel, 0)
  48. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  49. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  50. if enableTagMode {
  51. tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
  52. if err != nil {
  53. c.JSON(http.StatusOK, gin.H{
  54. "success": false,
  55. "message": err.Error(),
  56. })
  57. return
  58. }
  59. for _, tag := range tags {
  60. if tag != nil && *tag != "" {
  61. tagChannel, err := model.GetChannelsByTag(*tag, idSort)
  62. if err == nil {
  63. channelData = append(channelData, tagChannel...)
  64. }
  65. }
  66. }
  67. } else {
  68. channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
  69. if err != nil {
  70. c.JSON(http.StatusOK, gin.H{
  71. "success": false,
  72. "message": err.Error(),
  73. })
  74. return
  75. }
  76. channelData = channels
  77. }
  78. c.JSON(http.StatusOK, gin.H{
  79. "success": true,
  80. "message": "",
  81. "data": channelData,
  82. })
  83. return
  84. }
  85. func FetchUpstreamModels(c *gin.Context) {
  86. id, err := strconv.Atoi(c.Param("id"))
  87. if err != nil {
  88. c.JSON(http.StatusOK, gin.H{
  89. "success": false,
  90. "message": err.Error(),
  91. })
  92. return
  93. }
  94. channel, err := model.GetChannelById(id, true)
  95. if err != nil {
  96. c.JSON(http.StatusOK, gin.H{
  97. "success": false,
  98. "message": err.Error(),
  99. })
  100. return
  101. }
  102. //if channel.Type != common.ChannelTypeOpenAI {
  103. // c.JSON(http.StatusOK, gin.H{
  104. // "success": false,
  105. // "message": "仅支持 OpenAI 类型渠道",
  106. // })
  107. // return
  108. //}
  109. baseURL := common.ChannelBaseURLs[channel.Type]
  110. if channel.GetBaseURL() != "" {
  111. baseURL = channel.GetBaseURL()
  112. }
  113. url := fmt.Sprintf("%s/v1/models", baseURL)
  114. if channel.Type == common.ChannelTypeGemini {
  115. url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
  116. }
  117. body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
  118. if err != nil {
  119. c.JSON(http.StatusOK, gin.H{
  120. "success": false,
  121. "message": err.Error(),
  122. })
  123. return
  124. }
  125. var result OpenAIModelsResponse
  126. if err = json.Unmarshal(body, &result); err != nil {
  127. c.JSON(http.StatusOK, gin.H{
  128. "success": false,
  129. "message": fmt.Sprintf("解析响应失败: %s", err.Error()),
  130. })
  131. return
  132. }
  133. var ids []string
  134. for _, model := range result.Data {
  135. id := model.ID
  136. if channel.Type == common.ChannelTypeGemini {
  137. id = strings.TrimPrefix(id, "models/")
  138. }
  139. ids = append(ids, id)
  140. }
  141. c.JSON(http.StatusOK, gin.H{
  142. "success": true,
  143. "message": "",
  144. "data": ids,
  145. })
  146. }
  147. func FixChannelsAbilities(c *gin.Context) {
  148. count, err := model.FixAbility()
  149. if err != nil {
  150. c.JSON(http.StatusOK, gin.H{
  151. "success": false,
  152. "message": err.Error(),
  153. })
  154. return
  155. }
  156. c.JSON(http.StatusOK, gin.H{
  157. "success": true,
  158. "message": "",
  159. "data": count,
  160. })
  161. }
  162. func SearchChannels(c *gin.Context) {
  163. keyword := c.Query("keyword")
  164. group := c.Query("group")
  165. modelKeyword := c.Query("model")
  166. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  167. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  168. channelData := make([]*model.Channel, 0)
  169. if enableTagMode {
  170. tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
  171. if err != nil {
  172. c.JSON(http.StatusOK, gin.H{
  173. "success": false,
  174. "message": err.Error(),
  175. })
  176. return
  177. }
  178. for _, tag := range tags {
  179. if tag != nil && *tag != "" {
  180. tagChannel, err := model.GetChannelsByTag(*tag, idSort)
  181. if err == nil {
  182. channelData = append(channelData, tagChannel...)
  183. }
  184. }
  185. }
  186. } else {
  187. channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
  188. if err != nil {
  189. c.JSON(http.StatusOK, gin.H{
  190. "success": false,
  191. "message": err.Error(),
  192. })
  193. return
  194. }
  195. channelData = channels
  196. }
  197. c.JSON(http.StatusOK, gin.H{
  198. "success": true,
  199. "message": "",
  200. "data": channelData,
  201. })
  202. return
  203. }
  204. func GetChannel(c *gin.Context) {
  205. id, err := strconv.Atoi(c.Param("id"))
  206. if err != nil {
  207. c.JSON(http.StatusOK, gin.H{
  208. "success": false,
  209. "message": err.Error(),
  210. })
  211. return
  212. }
  213. channel, err := model.GetChannelById(id, false)
  214. if err != nil {
  215. c.JSON(http.StatusOK, gin.H{
  216. "success": false,
  217. "message": err.Error(),
  218. })
  219. return
  220. }
  221. c.JSON(http.StatusOK, gin.H{
  222. "success": true,
  223. "message": "",
  224. "data": channel,
  225. })
  226. return
  227. }
  228. func AddChannel(c *gin.Context) {
  229. channel := model.Channel{}
  230. err := c.ShouldBindJSON(&channel)
  231. if err != nil {
  232. c.JSON(http.StatusOK, gin.H{
  233. "success": false,
  234. "message": err.Error(),
  235. })
  236. return
  237. }
  238. channel.CreatedTime = common.GetTimestamp()
  239. keys := strings.Split(channel.Key, "\n")
  240. if channel.Type == common.ChannelTypeVertexAi {
  241. if channel.Other == "" {
  242. c.JSON(http.StatusOK, gin.H{
  243. "success": false,
  244. "message": "部署地区不能为空",
  245. })
  246. return
  247. } else {
  248. if common.IsJsonStr(channel.Other) {
  249. // must have default
  250. regionMap := common.StrToMap(channel.Other)
  251. if regionMap["default"] == nil {
  252. c.JSON(http.StatusOK, gin.H{
  253. "success": false,
  254. "message": "部署地区必须包含default字段",
  255. })
  256. return
  257. }
  258. }
  259. }
  260. keys = []string{channel.Key}
  261. }
  262. channels := make([]model.Channel, 0, len(keys))
  263. for _, key := range keys {
  264. if key == "" {
  265. continue
  266. }
  267. localChannel := channel
  268. localChannel.Key = key
  269. // Validate the length of the model name
  270. models := strings.Split(localChannel.Models, ",")
  271. for _, model := range models {
  272. if len(model) > 255 {
  273. c.JSON(http.StatusOK, gin.H{
  274. "success": false,
  275. "message": fmt.Sprintf("模型名称过长: %s", model),
  276. })
  277. return
  278. }
  279. }
  280. channels = append(channels, localChannel)
  281. }
  282. err = model.BatchInsertChannels(channels)
  283. if err != nil {
  284. c.JSON(http.StatusOK, gin.H{
  285. "success": false,
  286. "message": err.Error(),
  287. })
  288. return
  289. }
  290. c.JSON(http.StatusOK, gin.H{
  291. "success": true,
  292. "message": "",
  293. })
  294. return
  295. }
  296. func DeleteChannel(c *gin.Context) {
  297. id, _ := strconv.Atoi(c.Param("id"))
  298. channel := model.Channel{Id: id}
  299. err := channel.Delete()
  300. if err != nil {
  301. c.JSON(http.StatusOK, gin.H{
  302. "success": false,
  303. "message": err.Error(),
  304. })
  305. return
  306. }
  307. c.JSON(http.StatusOK, gin.H{
  308. "success": true,
  309. "message": "",
  310. })
  311. return
  312. }
  313. func DeleteDisabledChannel(c *gin.Context) {
  314. rows, err := model.DeleteDisabledChannel()
  315. if err != nil {
  316. c.JSON(http.StatusOK, gin.H{
  317. "success": false,
  318. "message": err.Error(),
  319. })
  320. return
  321. }
  322. c.JSON(http.StatusOK, gin.H{
  323. "success": true,
  324. "message": "",
  325. "data": rows,
  326. })
  327. return
  328. }
  329. type ChannelTag struct {
  330. Tag string `json:"tag"`
  331. NewTag *string `json:"new_tag"`
  332. Priority *int64 `json:"priority"`
  333. Weight *uint `json:"weight"`
  334. ModelMapping *string `json:"model_mapping"`
  335. Models *string `json:"models"`
  336. Groups *string `json:"groups"`
  337. }
  338. func DisableTagChannels(c *gin.Context) {
  339. channelTag := ChannelTag{}
  340. err := c.ShouldBindJSON(&channelTag)
  341. if err != nil || channelTag.Tag == "" {
  342. c.JSON(http.StatusOK, gin.H{
  343. "success": false,
  344. "message": "参数错误",
  345. })
  346. return
  347. }
  348. err = model.DisableChannelByTag(channelTag.Tag)
  349. if err != nil {
  350. c.JSON(http.StatusOK, gin.H{
  351. "success": false,
  352. "message": err.Error(),
  353. })
  354. return
  355. }
  356. c.JSON(http.StatusOK, gin.H{
  357. "success": true,
  358. "message": "",
  359. })
  360. return
  361. }
  362. func EnableTagChannels(c *gin.Context) {
  363. channelTag := ChannelTag{}
  364. err := c.ShouldBindJSON(&channelTag)
  365. if err != nil || channelTag.Tag == "" {
  366. c.JSON(http.StatusOK, gin.H{
  367. "success": false,
  368. "message": "参数错误",
  369. })
  370. return
  371. }
  372. err = model.EnableChannelByTag(channelTag.Tag)
  373. if err != nil {
  374. c.JSON(http.StatusOK, gin.H{
  375. "success": false,
  376. "message": err.Error(),
  377. })
  378. return
  379. }
  380. c.JSON(http.StatusOK, gin.H{
  381. "success": true,
  382. "message": "",
  383. })
  384. return
  385. }
  386. func EditTagChannels(c *gin.Context) {
  387. channelTag := ChannelTag{}
  388. err := c.ShouldBindJSON(&channelTag)
  389. if err != nil {
  390. c.JSON(http.StatusOK, gin.H{
  391. "success": false,
  392. "message": "参数错误",
  393. })
  394. return
  395. }
  396. if channelTag.Tag == "" {
  397. c.JSON(http.StatusOK, gin.H{
  398. "success": false,
  399. "message": "tag不能为空",
  400. })
  401. return
  402. }
  403. err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
  404. if err != nil {
  405. c.JSON(http.StatusOK, gin.H{
  406. "success": false,
  407. "message": err.Error(),
  408. })
  409. return
  410. }
  411. c.JSON(http.StatusOK, gin.H{
  412. "success": true,
  413. "message": "",
  414. })
  415. return
  416. }
  417. type ChannelBatch struct {
  418. Ids []int `json:"ids"`
  419. Tag *string `json:"tag"`
  420. }
  421. func DeleteChannelBatch(c *gin.Context) {
  422. channelBatch := ChannelBatch{}
  423. err := c.ShouldBindJSON(&channelBatch)
  424. if err != nil || len(channelBatch.Ids) == 0 {
  425. c.JSON(http.StatusOK, gin.H{
  426. "success": false,
  427. "message": "参数错误",
  428. })
  429. return
  430. }
  431. err = model.BatchDeleteChannels(channelBatch.Ids)
  432. if err != nil {
  433. c.JSON(http.StatusOK, gin.H{
  434. "success": false,
  435. "message": err.Error(),
  436. })
  437. return
  438. }
  439. c.JSON(http.StatusOK, gin.H{
  440. "success": true,
  441. "message": "",
  442. "data": len(channelBatch.Ids),
  443. })
  444. return
  445. }
  446. func UpdateChannel(c *gin.Context) {
  447. channel := model.Channel{}
  448. err := c.ShouldBindJSON(&channel)
  449. if err != nil {
  450. c.JSON(http.StatusOK, gin.H{
  451. "success": false,
  452. "message": err.Error(),
  453. })
  454. return
  455. }
  456. if channel.Type == common.ChannelTypeVertexAi {
  457. if channel.Other == "" {
  458. c.JSON(http.StatusOK, gin.H{
  459. "success": false,
  460. "message": "部署地区不能为空",
  461. })
  462. return
  463. } else {
  464. if common.IsJsonStr(channel.Other) {
  465. // must have default
  466. regionMap := common.StrToMap(channel.Other)
  467. if regionMap["default"] == nil {
  468. c.JSON(http.StatusOK, gin.H{
  469. "success": false,
  470. "message": "部署地区必须包含default字段",
  471. })
  472. return
  473. }
  474. }
  475. }
  476. }
  477. err = channel.Update()
  478. if err != nil {
  479. c.JSON(http.StatusOK, gin.H{
  480. "success": false,
  481. "message": err.Error(),
  482. })
  483. return
  484. }
  485. c.JSON(http.StatusOK, gin.H{
  486. "success": true,
  487. "message": "",
  488. "data": channel,
  489. })
  490. return
  491. }
  492. func FetchModels(c *gin.Context) {
  493. var req struct {
  494. BaseURL string `json:"base_url"`
  495. Type int `json:"type"`
  496. Key string `json:"key"`
  497. }
  498. if err := c.ShouldBindJSON(&req); err != nil {
  499. c.JSON(http.StatusBadRequest, gin.H{
  500. "success": false,
  501. "message": "Invalid request",
  502. })
  503. return
  504. }
  505. baseURL := req.BaseURL
  506. if baseURL == "" {
  507. baseURL = common.ChannelBaseURLs[req.Type]
  508. }
  509. client := &http.Client{}
  510. url := fmt.Sprintf("%s/v1/models", baseURL)
  511. request, err := http.NewRequest("GET", url, nil)
  512. if err != nil {
  513. c.JSON(http.StatusInternalServerError, gin.H{
  514. "success": false,
  515. "message": err.Error(),
  516. })
  517. return
  518. }
  519. // remove line breaks and extra spaces.
  520. key := strings.TrimSpace(req.Key)
  521. // If the key contains a line break, only take the first part.
  522. key = strings.Split(key, "\n")[0]
  523. request.Header.Set("Authorization", "Bearer "+key)
  524. response, err := client.Do(request)
  525. if err != nil {
  526. c.JSON(http.StatusInternalServerError, gin.H{
  527. "success": false,
  528. "message": err.Error(),
  529. })
  530. return
  531. }
  532. //check status code
  533. if response.StatusCode != http.StatusOK {
  534. c.JSON(http.StatusInternalServerError, gin.H{
  535. "success": false,
  536. "message": "Failed to fetch models",
  537. })
  538. return
  539. }
  540. defer response.Body.Close()
  541. var result struct {
  542. Data []struct {
  543. ID string `json:"id"`
  544. } `json:"data"`
  545. }
  546. if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
  547. c.JSON(http.StatusInternalServerError, gin.H{
  548. "success": false,
  549. "message": err.Error(),
  550. })
  551. return
  552. }
  553. var models []string
  554. for _, model := range result.Data {
  555. models = append(models, model.ID)
  556. }
  557. c.JSON(http.StatusOK, gin.H{
  558. "success": true,
  559. "data": models,
  560. })
  561. }
  562. func BatchSetChannelTag(c *gin.Context) {
  563. channelBatch := ChannelBatch{}
  564. err := c.ShouldBindJSON(&channelBatch)
  565. if err != nil || len(channelBatch.Ids) == 0 {
  566. c.JSON(http.StatusOK, gin.H{
  567. "success": false,
  568. "message": "参数错误",
  569. })
  570. return
  571. }
  572. err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
  573. if err != nil {
  574. c.JSON(http.StatusOK, gin.H{
  575. "success": false,
  576. "message": err.Error(),
  577. })
  578. return
  579. }
  580. c.JSON(http.StatusOK, gin.H{
  581. "success": true,
  582. "message": "",
  583. "data": len(channelBatch.Ids),
  584. })
  585. return
  586. }