channel.go 50 KB


  1. package controller
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net/http"
  6. "strconv"
  7. "strings"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/constant"
  10. "github.com/QuantumNous/new-api/dto"
  11. "github.com/QuantumNous/new-api/model"
  12. "github.com/QuantumNous/new-api/relay/channel/ollama"
  13. "github.com/QuantumNous/new-api/service"
  14. "github.com/gin-gonic/gin"
  15. )
  // OpenAIModel mirrors one entry of an OpenAI-compatible model list
  // (`GET .../models`). Metadata is not part of the OpenAI schema; the Ollama
  // path in this file fills it with size, digest, modified_at and details.
  type OpenAIModel struct {
  	ID       string         `json:"id"`
  	Object   string         `json:"object"`
  	Created  int64          `json:"created"`
  	OwnedBy  string         `json:"owned_by"`
  	Metadata map[string]any `json:"metadata,omitempty"`
  	// Permission is decoded when an upstream sends it; the handlers shown
  	// here never read it.
  	Permission []struct {
  		ID                 string `json:"id"`
  		Object             string `json:"object"`
  		Created            int64  `json:"created"`
  		AllowCreateEngine  bool   `json:"allow_create_engine"`
  		AllowSampling      bool   `json:"allow_sampling"`
  		AllowLogprobs      bool   `json:"allow_logprobs"`
  		AllowSearchIndices bool   `json:"allow_search_indices"`
  		AllowView          bool   `json:"allow_view"`
  		AllowFineTuning    bool   `json:"allow_fine_tuning"`
  		Organization       string `json:"organization"`
  		Group              string `json:"group"`
  		IsBlocking         bool   `json:"is_blocking"`
  	} `json:"permission"`
  	Root   string `json:"root"`
  	Parent string `json:"parent"`
  }

  // OpenAIModelsResponse is the envelope decoded from an upstream model list:
  // Data holds the models and Success mirrors an optional "success" flag.
  type OpenAIModelsResponse struct {
  	Data    []OpenAIModel `json:"data"`
  	Success bool          `json:"success"`
  }
  43. func parseStatusFilter(statusParam string) int {
  44. switch strings.ToLower(statusParam) {
  45. case "enabled", "1":
  46. return common.ChannelStatusEnabled
  47. case "disabled", "0":
  48. return 0
  49. default:
  50. return -1
  51. }
  52. }
  53. func clearChannelInfo(channel *model.Channel) {
  54. if channel.ChannelInfo.IsMultiKey {
  55. channel.ChannelInfo.MultiKeyDisabledReason = nil
  56. channel.ChannelInfo.MultiKeyDisabledTime = nil
  57. }
  58. }
  59. func GetAllChannels(c *gin.Context) {
  60. pageInfo := common.GetPageQuery(c)
  61. channelData := make([]*model.Channel, 0)
  62. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  63. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  64. statusParam := c.Query("status")
  65. // statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
  66. statusFilter := parseStatusFilter(statusParam)
  67. // type filter
  68. typeStr := c.Query("type")
  69. typeFilter := -1
  70. if typeStr != "" {
  71. if t, err := strconv.Atoi(typeStr); err == nil {
  72. typeFilter = t
  73. }
  74. }
  75. var total int64
  76. if enableTagMode {
  77. tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
  78. if err != nil {
  79. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  80. return
  81. }
  82. for _, tag := range tags {
  83. if tag == nil || *tag == "" {
  84. continue
  85. }
  86. tagChannels, err := model.GetChannelsByTag(*tag, idSort, false)
  87. if err != nil {
  88. continue
  89. }
  90. filtered := make([]*model.Channel, 0)
  91. for _, ch := range tagChannels {
  92. if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
  93. continue
  94. }
  95. if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
  96. continue
  97. }
  98. if typeFilter >= 0 && ch.Type != typeFilter {
  99. continue
  100. }
  101. filtered = append(filtered, ch)
  102. }
  103. channelData = append(channelData, filtered...)
  104. }
  105. total, _ = model.CountAllTags()
  106. } else {
  107. baseQuery := model.DB.Model(&model.Channel{})
  108. if typeFilter >= 0 {
  109. baseQuery = baseQuery.Where("type = ?", typeFilter)
  110. }
  111. if statusFilter == common.ChannelStatusEnabled {
  112. baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
  113. } else if statusFilter == 0 {
  114. baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
  115. }
  116. baseQuery.Count(&total)
  117. order := "priority desc"
  118. if idSort {
  119. order = "id desc"
  120. }
  121. err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
  122. if err != nil {
  123. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  124. return
  125. }
  126. }
  127. for _, datum := range channelData {
  128. clearChannelInfo(datum)
  129. }
  130. countQuery := model.DB.Model(&model.Channel{})
  131. if statusFilter == common.ChannelStatusEnabled {
  132. countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
  133. } else if statusFilter == 0 {
  134. countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
  135. }
  136. var results []struct {
  137. Type int64
  138. Count int64
  139. }
  140. _ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
  141. typeCounts := make(map[int64]int64)
  142. for _, r := range results {
  143. typeCounts[r.Type] = r.Count
  144. }
  145. common.ApiSuccess(c, gin.H{
  146. "items": channelData,
  147. "total": total,
  148. "page": pageInfo.GetPage(),
  149. "page_size": pageInfo.GetPageSize(),
  150. "type_counts": typeCounts,
  151. })
  152. return
  153. }
  154. func buildFetchModelsHeaders(channel *model.Channel, key string) (http.Header, error) {
  155. var headers http.Header
  156. switch channel.Type {
  157. case constant.ChannelTypeAnthropic:
  158. headers = GetClaudeAuthHeader(key)
  159. default:
  160. headers = GetAuthHeader(key)
  161. }
  162. headerOverride := channel.GetHeaderOverride()
  163. for k, v := range headerOverride {
  164. str, ok := v.(string)
  165. if !ok {
  166. return nil, fmt.Errorf("invalid header override for key %s", k)
  167. }
  168. if strings.Contains(str, "{api_key}") {
  169. str = strings.ReplaceAll(str, "{api_key}", key)
  170. }
  171. headers.Set(k, str)
  172. }
  173. return headers, nil
  174. }
  175. func FetchUpstreamModels(c *gin.Context) {
  176. id, err := strconv.Atoi(c.Param("id"))
  177. if err != nil {
  178. common.ApiError(c, err)
  179. return
  180. }
  181. channel, err := model.GetChannelById(id, true)
  182. if err != nil {
  183. common.ApiError(c, err)
  184. return
  185. }
  186. baseURL := constant.ChannelBaseURLs[channel.Type]
  187. if channel.GetBaseURL() != "" {
  188. baseURL = channel.GetBaseURL()
  189. }
  190. // 对于 Ollama 渠道,使用特殊处理
  191. if channel.Type == constant.ChannelTypeOllama {
  192. key := strings.Split(channel.Key, "\n")[0]
  193. models, err := ollama.FetchOllamaModels(baseURL, key)
  194. if err != nil {
  195. c.JSON(http.StatusOK, gin.H{
  196. "success": false,
  197. "message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()),
  198. })
  199. return
  200. }
  201. result := OpenAIModelsResponse{
  202. Data: make([]OpenAIModel, 0, len(models)),
  203. }
  204. for _, modelInfo := range models {
  205. metadata := map[string]any{}
  206. if modelInfo.Size > 0 {
  207. metadata["size"] = modelInfo.Size
  208. }
  209. if modelInfo.Digest != "" {
  210. metadata["digest"] = modelInfo.Digest
  211. }
  212. if modelInfo.ModifiedAt != "" {
  213. metadata["modified_at"] = modelInfo.ModifiedAt
  214. }
  215. details := modelInfo.Details
  216. if details.ParentModel != "" || details.Format != "" || details.Family != "" || len(details.Families) > 0 || details.ParameterSize != "" || details.QuantizationLevel != "" {
  217. metadata["details"] = modelInfo.Details
  218. }
  219. if len(metadata) == 0 {
  220. metadata = nil
  221. }
  222. result.Data = append(result.Data, OpenAIModel{
  223. ID: modelInfo.Name,
  224. Object: "model",
  225. Created: 0,
  226. OwnedBy: "ollama",
  227. Metadata: metadata,
  228. })
  229. }
  230. c.JSON(http.StatusOK, gin.H{
  231. "success": true,
  232. "data": result.Data,
  233. })
  234. return
  235. }
  236. var url string
  237. switch channel.Type {
  238. case constant.ChannelTypeGemini:
  239. // curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
  240. url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
  241. case constant.ChannelTypeAli:
  242. url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
  243. case constant.ChannelTypeZhipu_v4:
  244. if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  245. url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL)
  246. } else {
  247. url = fmt.Sprintf("%s/api/paas/v4/models", baseURL)
  248. }
  249. case constant.ChannelTypeVolcEngine:
  250. if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  251. url = fmt.Sprintf("%s/v1/models", plan.OpenAIBaseURL)
  252. } else {
  253. url = fmt.Sprintf("%s/v1/models", baseURL)
  254. }
  255. case constant.ChannelTypeMoonshot:
  256. if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  257. url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL)
  258. } else {
  259. url = fmt.Sprintf("%s/v1/models", baseURL)
  260. }
  261. default:
  262. url = fmt.Sprintf("%s/v1/models", baseURL)
  263. }
  264. // 获取用于请求的可用密钥(多密钥渠道优先使用启用状态的密钥)
  265. key, _, apiErr := channel.GetNextEnabledKey()
  266. if apiErr != nil {
  267. c.JSON(http.StatusOK, gin.H{
  268. "success": false,
  269. "message": fmt.Sprintf("获取渠道密钥失败: %s", apiErr.Error()),
  270. })
  271. return
  272. }
  273. key = strings.TrimSpace(key)
  274. headers, err := buildFetchModelsHeaders(channel, key)
  275. if err != nil {
  276. common.ApiError(c, err)
  277. return
  278. }
  279. body, err := GetResponseBody("GET", url, channel, headers)
  280. if err != nil {
  281. common.ApiError(c, err)
  282. return
  283. }
  284. var result OpenAIModelsResponse
  285. if err = json.Unmarshal(body, &result); err != nil {
  286. c.JSON(http.StatusOK, gin.H{
  287. "success": false,
  288. "message": fmt.Sprintf("解析响应失败: %s", err.Error()),
  289. })
  290. return
  291. }
  292. var ids []string
  293. for _, model := range result.Data {
  294. id := model.ID
  295. if channel.Type == constant.ChannelTypeGemini {
  296. id = strings.TrimPrefix(id, "models/")
  297. }
  298. ids = append(ids, id)
  299. }
  300. c.JSON(http.StatusOK, gin.H{
  301. "success": true,
  302. "message": "",
  303. "data": ids,
  304. })
  305. }
  306. func FixChannelsAbilities(c *gin.Context) {
  307. success, fails, err := model.FixAbility()
  308. if err != nil {
  309. common.ApiError(c, err)
  310. return
  311. }
  312. c.JSON(http.StatusOK, gin.H{
  313. "success": true,
  314. "message": "",
  315. "data": gin.H{
  316. "success": success,
  317. "fails": fails,
  318. },
  319. })
  320. }
  321. func SearchChannels(c *gin.Context) {
  322. keyword := c.Query("keyword")
  323. group := c.Query("group")
  324. modelKeyword := c.Query("model")
  325. statusParam := c.Query("status")
  326. statusFilter := parseStatusFilter(statusParam)
  327. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  328. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  329. channelData := make([]*model.Channel, 0)
  330. if enableTagMode {
  331. tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
  332. if err != nil {
  333. c.JSON(http.StatusOK, gin.H{
  334. "success": false,
  335. "message": err.Error(),
  336. })
  337. return
  338. }
  339. for _, tag := range tags {
  340. if tag != nil && *tag != "" {
  341. tagChannel, err := model.GetChannelsByTag(*tag, idSort, false)
  342. if err == nil {
  343. channelData = append(channelData, tagChannel...)
  344. }
  345. }
  346. }
  347. } else {
  348. channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
  349. if err != nil {
  350. c.JSON(http.StatusOK, gin.H{
  351. "success": false,
  352. "message": err.Error(),
  353. })
  354. return
  355. }
  356. channelData = channels
  357. }
  358. if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
  359. filtered := make([]*model.Channel, 0, len(channelData))
  360. for _, ch := range channelData {
  361. if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
  362. continue
  363. }
  364. if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
  365. continue
  366. }
  367. filtered = append(filtered, ch)
  368. }
  369. channelData = filtered
  370. }
  371. // calculate type counts for search results
  372. typeCounts := make(map[int64]int64)
  373. for _, channel := range channelData {
  374. typeCounts[int64(channel.Type)]++
  375. }
  376. typeParam := c.Query("type")
  377. typeFilter := -1
  378. if typeParam != "" {
  379. if tp, err := strconv.Atoi(typeParam); err == nil {
  380. typeFilter = tp
  381. }
  382. }
  383. if typeFilter >= 0 {
  384. filtered := make([]*model.Channel, 0, len(channelData))
  385. for _, ch := range channelData {
  386. if ch.Type == typeFilter {
  387. filtered = append(filtered, ch)
  388. }
  389. }
  390. channelData = filtered
  391. }
  392. page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
  393. pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
  394. if page < 1 {
  395. page = 1
  396. }
  397. if pageSize <= 0 {
  398. pageSize = 20
  399. }
  400. total := len(channelData)
  401. startIdx := (page - 1) * pageSize
  402. if startIdx > total {
  403. startIdx = total
  404. }
  405. endIdx := startIdx + pageSize
  406. if endIdx > total {
  407. endIdx = total
  408. }
  409. pagedData := channelData[startIdx:endIdx]
  410. for _, datum := range pagedData {
  411. clearChannelInfo(datum)
  412. }
  413. c.JSON(http.StatusOK, gin.H{
  414. "success": true,
  415. "message": "",
  416. "data": gin.H{
  417. "items": pagedData,
  418. "total": total,
  419. "type_counts": typeCounts,
  420. },
  421. })
  422. return
  423. }
  424. func GetChannel(c *gin.Context) {
  425. id, err := strconv.Atoi(c.Param("id"))
  426. if err != nil {
  427. common.ApiError(c, err)
  428. return
  429. }
  430. channel, err := model.GetChannelById(id, false)
  431. if err != nil {
  432. common.ApiError(c, err)
  433. return
  434. }
  435. if channel != nil {
  436. clearChannelInfo(channel)
  437. }
  438. c.JSON(http.StatusOK, gin.H{
  439. "success": true,
  440. "message": "",
  441. "data": channel,
  442. })
  443. return
  444. }
  445. // GetChannelKey 获取渠道密钥(需要通过安全验证中间件)
  446. // 此函数依赖 SecureVerificationRequired 中间件,确保用户已通过安全验证
  447. func GetChannelKey(c *gin.Context) {
  448. userId := c.GetInt("id")
  449. channelId, err := strconv.Atoi(c.Param("id"))
  450. if err != nil {
  451. common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err))
  452. return
  453. }
  454. // 获取渠道信息(包含密钥)
  455. channel, err := model.GetChannelById(channelId, true)
  456. if err != nil {
  457. common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err))
  458. return
  459. }
  460. if channel == nil {
  461. common.ApiError(c, fmt.Errorf("渠道不存在"))
  462. return
  463. }
  464. // 记录操作日志
  465. model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId))
  466. // 返回渠道密钥
  467. c.JSON(http.StatusOK, gin.H{
  468. "success": true,
  469. "message": "获取成功",
  470. "data": map[string]interface{}{
  471. "key": channel.Key,
  472. },
  473. })
  474. }
  475. // validateTwoFactorAuth 统一的2FA验证函数
  476. func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool {
  477. // 尝试验证TOTP
  478. if cleanCode, err := common.ValidateNumericCode(code); err == nil {
  479. if isValid, _ := twoFA.ValidateTOTPAndUpdateUsage(cleanCode); isValid {
  480. return true
  481. }
  482. }
  483. // 尝试验证备用码
  484. if isValid, err := twoFA.ValidateBackupCodeAndUpdateUsage(code); err == nil && isValid {
  485. return true
  486. }
  487. return false
  488. }
  489. // validateChannel 通用的渠道校验函数
  490. func validateChannel(channel *model.Channel, isAdd bool) error {
  491. // 校验 channel settings
  492. if err := channel.ValidateSettings(); err != nil {
  493. return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error())
  494. }
  495. // 如果是添加操作,检查 channel 和 key 是否为空
  496. if isAdd {
  497. if channel == nil || channel.Key == "" {
  498. return fmt.Errorf("channel cannot be empty")
  499. }
  500. // 检查模型名称长度是否超过 255
  501. for _, m := range channel.GetModels() {
  502. if len(m) > 255 {
  503. return fmt.Errorf("模型名称过长: %s", m)
  504. }
  505. }
  506. }
  507. // VertexAI 特殊校验
  508. if channel.Type == constant.ChannelTypeVertexAi {
  509. if channel.Other == "" {
  510. return fmt.Errorf("部署地区不能为空")
  511. }
  512. regionMap, err := common.StrToMap(channel.Other)
  513. if err != nil {
  514. return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
  515. }
  516. if regionMap["default"] == nil {
  517. return fmt.Errorf("部署地区必须包含default字段")
  518. }
  519. }
  520. return nil
  521. }
  522. type AddChannelRequest struct {
  523. Mode string `json:"mode"`
  524. MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
  525. BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"`
  526. Channel *model.Channel `json:"channel"`
  527. }
  528. func getVertexArrayKeys(keys string) ([]string, error) {
  529. if keys == "" {
  530. return nil, nil
  531. }
  532. var keyArray []interface{}
  533. err := common.Unmarshal([]byte(keys), &keyArray)
  534. if err != nil {
  535. return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
  536. }
  537. cleanKeys := make([]string, 0, len(keyArray))
  538. for _, key := range keyArray {
  539. var keyStr string
  540. switch v := key.(type) {
  541. case string:
  542. keyStr = strings.TrimSpace(v)
  543. default:
  544. bytes, err := json.Marshal(v)
  545. if err != nil {
  546. return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
  547. }
  548. keyStr = string(bytes)
  549. }
  550. if keyStr != "" {
  551. cleanKeys = append(cleanKeys, keyStr)
  552. }
  553. }
  554. if len(cleanKeys) == 0 {
  555. return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
  556. }
  557. return cleanKeys, nil
  558. }
  559. func AddChannel(c *gin.Context) {
  560. addChannelRequest := AddChannelRequest{}
  561. err := c.ShouldBindJSON(&addChannelRequest)
  562. if err != nil {
  563. common.ApiError(c, err)
  564. return
  565. }
  566. // 使用统一的校验函数
  567. if err := validateChannel(addChannelRequest.Channel, true); err != nil {
  568. c.JSON(http.StatusOK, gin.H{
  569. "success": false,
  570. "message": err.Error(),
  571. })
  572. return
  573. }
  574. addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
  575. keys := make([]string, 0)
  576. switch addChannelRequest.Mode {
  577. case "multi_to_single":
  578. addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
  579. addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
  580. if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
  581. array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
  582. if err != nil {
  583. c.JSON(http.StatusOK, gin.H{
  584. "success": false,
  585. "message": err.Error(),
  586. })
  587. return
  588. }
  589. addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
  590. addChannelRequest.Channel.Key = strings.Join(array, "\n")
  591. } else {
  592. cleanKeys := make([]string, 0)
  593. for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
  594. if key == "" {
  595. continue
  596. }
  597. key = strings.TrimSpace(key)
  598. cleanKeys = append(cleanKeys, key)
  599. }
  600. addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
  601. addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
  602. }
  603. keys = []string{addChannelRequest.Channel.Key}
  604. case "batch":
  605. if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
  606. // multi json
  607. keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
  608. if err != nil {
  609. c.JSON(http.StatusOK, gin.H{
  610. "success": false,
  611. "message": err.Error(),
  612. })
  613. return
  614. }
  615. } else {
  616. keys = strings.Split(addChannelRequest.Channel.Key, "\n")
  617. }
  618. case "single":
  619. keys = []string{addChannelRequest.Channel.Key}
  620. default:
  621. c.JSON(http.StatusOK, gin.H{
  622. "success": false,
  623. "message": "不支持的添加模式",
  624. })
  625. return
  626. }
  627. channels := make([]model.Channel, 0, len(keys))
  628. for _, key := range keys {
  629. if key == "" {
  630. continue
  631. }
  632. localChannel := addChannelRequest.Channel
  633. localChannel.Key = key
  634. if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 {
  635. keyPrefix := localChannel.Key
  636. if len(localChannel.Key) > 8 {
  637. keyPrefix = localChannel.Key[:8]
  638. }
  639. localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix)
  640. }
  641. channels = append(channels, *localChannel)
  642. }
  643. err = model.BatchInsertChannels(channels)
  644. if err != nil {
  645. common.ApiError(c, err)
  646. return
  647. }
  648. service.ResetProxyClientCache()
  649. c.JSON(http.StatusOK, gin.H{
  650. "success": true,
  651. "message": "",
  652. })
  653. return
  654. }
  655. func DeleteChannel(c *gin.Context) {
  656. id, _ := strconv.Atoi(c.Param("id"))
  657. channel := model.Channel{Id: id}
  658. err := channel.Delete()
  659. if err != nil {
  660. common.ApiError(c, err)
  661. return
  662. }
  663. model.InitChannelCache()
  664. c.JSON(http.StatusOK, gin.H{
  665. "success": true,
  666. "message": "",
  667. })
  668. return
  669. }
  670. func DeleteDisabledChannel(c *gin.Context) {
  671. rows, err := model.DeleteDisabledChannel()
  672. if err != nil {
  673. common.ApiError(c, err)
  674. return
  675. }
  676. model.InitChannelCache()
  677. c.JSON(http.StatusOK, gin.H{
  678. "success": true,
  679. "message": "",
  680. "data": rows,
  681. })
  682. return
  683. }
  // ChannelTag is the JSON body for the tag-wide channel endpoints. Tag selects
  // the channels. The pointer fields are optional inputs to EditTagChannels;
  // presumably nil means "leave unchanged" (handled by model.EditChannelByTag,
  // TODO confirm).
  type ChannelTag struct {
  	Tag            string  `json:"tag"`
  	NewTag         *string `json:"new_tag"`
  	Priority       *int64  `json:"priority"`
  	Weight         *uint   `json:"weight"`
  	ModelMapping   *string `json:"model_mapping"`
  	Models         *string `json:"models"`
  	Groups         *string `json:"groups"`
  	ParamOverride  *string `json:"param_override"`
  	HeaderOverride *string `json:"header_override"`
  }
  695. func DisableTagChannels(c *gin.Context) {
  696. channelTag := ChannelTag{}
  697. err := c.ShouldBindJSON(&channelTag)
  698. if err != nil || channelTag.Tag == "" {
  699. c.JSON(http.StatusOK, gin.H{
  700. "success": false,
  701. "message": "参数错误",
  702. })
  703. return
  704. }
  705. err = model.DisableChannelByTag(channelTag.Tag)
  706. if err != nil {
  707. common.ApiError(c, err)
  708. return
  709. }
  710. model.InitChannelCache()
  711. c.JSON(http.StatusOK, gin.H{
  712. "success": true,
  713. "message": "",
  714. })
  715. return
  716. }
  717. func EnableTagChannels(c *gin.Context) {
  718. channelTag := ChannelTag{}
  719. err := c.ShouldBindJSON(&channelTag)
  720. if err != nil || channelTag.Tag == "" {
  721. c.JSON(http.StatusOK, gin.H{
  722. "success": false,
  723. "message": "参数错误",
  724. })
  725. return
  726. }
  727. err = model.EnableChannelByTag(channelTag.Tag)
  728. if err != nil {
  729. common.ApiError(c, err)
  730. return
  731. }
  732. model.InitChannelCache()
  733. c.JSON(http.StatusOK, gin.H{
  734. "success": true,
  735. "message": "",
  736. })
  737. return
  738. }
  739. func EditTagChannels(c *gin.Context) {
  740. channelTag := ChannelTag{}
  741. err := c.ShouldBindJSON(&channelTag)
  742. if err != nil {
  743. c.JSON(http.StatusOK, gin.H{
  744. "success": false,
  745. "message": "参数错误",
  746. })
  747. return
  748. }
  749. if channelTag.Tag == "" {
  750. c.JSON(http.StatusOK, gin.H{
  751. "success": false,
  752. "message": "tag不能为空",
  753. })
  754. return
  755. }
  756. if channelTag.ParamOverride != nil {
  757. trimmed := strings.TrimSpace(*channelTag.ParamOverride)
  758. if trimmed != "" && !json.Valid([]byte(trimmed)) {
  759. c.JSON(http.StatusOK, gin.H{
  760. "success": false,
  761. "message": "参数覆盖必须是合法的 JSON 格式",
  762. })
  763. return
  764. }
  765. channelTag.ParamOverride = common.GetPointer[string](trimmed)
  766. }
  767. if channelTag.HeaderOverride != nil {
  768. trimmed := strings.TrimSpace(*channelTag.HeaderOverride)
  769. if trimmed != "" && !json.Valid([]byte(trimmed)) {
  770. c.JSON(http.StatusOK, gin.H{
  771. "success": false,
  772. "message": "请求头覆盖必须是合法的 JSON 格式",
  773. })
  774. return
  775. }
  776. channelTag.HeaderOverride = common.GetPointer[string](trimmed)
  777. }
  778. err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight, channelTag.ParamOverride, channelTag.HeaderOverride)
  779. if err != nil {
  780. common.ApiError(c, err)
  781. return
  782. }
  783. model.InitChannelCache()
  784. c.JSON(http.StatusOK, gin.H{
  785. "success": true,
  786. "message": "",
  787. })
  788. return
  789. }
  // ChannelBatch is the JSON body for batch channel operations. Ids lists the
  // target channels. Tag is the value BatchSetChannelTag passes to
  // model.BatchSetChannelTag (it may be nil).
  type ChannelBatch struct {
  	Ids []int   `json:"ids"`
  	Tag *string `json:"tag"`
  }
  794. func DeleteChannelBatch(c *gin.Context) {
  795. channelBatch := ChannelBatch{}
  796. err := c.ShouldBindJSON(&channelBatch)
  797. if err != nil || len(channelBatch.Ids) == 0 {
  798. c.JSON(http.StatusOK, gin.H{
  799. "success": false,
  800. "message": "参数错误",
  801. })
  802. return
  803. }
  804. err = model.BatchDeleteChannels(channelBatch.Ids)
  805. if err != nil {
  806. common.ApiError(c, err)
  807. return
  808. }
  809. model.InitChannelCache()
  810. c.JSON(http.StatusOK, gin.H{
  811. "success": true,
  812. "message": "",
  813. "data": len(channelBatch.Ids),
  814. })
  815. return
  816. }
  817. type PatchChannel struct {
  818. model.Channel
  819. MultiKeyMode *string `json:"multi_key_mode"`
  820. KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
  821. }
  822. func UpdateChannel(c *gin.Context) {
  823. channel := PatchChannel{}
  824. err := c.ShouldBindJSON(&channel)
  825. if err != nil {
  826. common.ApiError(c, err)
  827. return
  828. }
  829. // 使用统一的校验函数
  830. if err := validateChannel(&channel.Channel, false); err != nil {
  831. c.JSON(http.StatusOK, gin.H{
  832. "success": false,
  833. "message": err.Error(),
  834. })
  835. return
  836. }
  837. // Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
  838. originChannel, err := model.GetChannelById(channel.Id, true)
  839. if err != nil {
  840. c.JSON(http.StatusOK, gin.H{
  841. "success": false,
  842. "message": err.Error(),
  843. })
  844. return
  845. }
  846. // Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
  847. channel.ChannelInfo = originChannel.ChannelInfo
  848. // If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
  849. if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
  850. channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
  851. }
  852. // 处理多key模式下的密钥追加/覆盖逻辑
  853. if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
  854. switch *channel.KeyMode {
  855. case "append":
  856. // 追加模式:将新密钥添加到现有密钥列表
  857. if originChannel.Key != "" {
  858. var newKeys []string
  859. var existingKeys []string
  860. // 解析现有密钥
  861. if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
  862. // JSON数组格式
  863. var arr []json.RawMessage
  864. if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
  865. existingKeys = make([]string, len(arr))
  866. for i, v := range arr {
  867. existingKeys[i] = string(v)
  868. }
  869. }
  870. } else {
  871. // 换行分隔格式
  872. existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
  873. }
  874. // 处理 Vertex AI 的特殊情况
  875. if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
  876. // 尝试解析新密钥为JSON数组
  877. if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
  878. array, err := getVertexArrayKeys(channel.Key)
  879. if err != nil {
  880. c.JSON(http.StatusOK, gin.H{
  881. "success": false,
  882. "message": "追加密钥解析失败: " + err.Error(),
  883. })
  884. return
  885. }
  886. newKeys = array
  887. } else {
  888. // 单个JSON密钥
  889. newKeys = []string{channel.Key}
  890. }
  891. // 合并密钥
  892. allKeys := append(existingKeys, newKeys...)
  893. channel.Key = strings.Join(allKeys, "\n")
  894. } else {
  895. // 普通渠道的处理
  896. inputKeys := strings.Split(channel.Key, "\n")
  897. for _, key := range inputKeys {
  898. key = strings.TrimSpace(key)
  899. if key != "" {
  900. newKeys = append(newKeys, key)
  901. }
  902. }
  903. // 合并密钥
  904. allKeys := append(existingKeys, newKeys...)
  905. channel.Key = strings.Join(allKeys, "\n")
  906. }
  907. }
  908. case "replace":
  909. // 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
  910. }
  911. }
  912. err = channel.Update()
  913. if err != nil {
  914. common.ApiError(c, err)
  915. return
  916. }
  917. model.InitChannelCache()
  918. service.ResetProxyClientCache()
  919. channel.Key = ""
  920. clearChannelInfo(&channel.Channel)
  921. c.JSON(http.StatusOK, gin.H{
  922. "success": true,
  923. "message": "",
  924. "data": channel,
  925. })
  926. return
  927. }
  928. func FetchModels(c *gin.Context) {
  929. var req struct {
  930. BaseURL string `json:"base_url"`
  931. Type int `json:"type"`
  932. Key string `json:"key"`
  933. }
  934. if err := c.ShouldBindJSON(&req); err != nil {
  935. c.JSON(http.StatusBadRequest, gin.H{
  936. "success": false,
  937. "message": "Invalid request",
  938. })
  939. return
  940. }
  941. baseURL := req.BaseURL
  942. if baseURL == "" {
  943. baseURL = constant.ChannelBaseURLs[req.Type]
  944. }
  945. // remove line breaks and extra spaces.
  946. key := strings.TrimSpace(req.Key)
  947. key = strings.Split(key, "\n")[0]
  948. if req.Type == constant.ChannelTypeOllama {
  949. models, err := ollama.FetchOllamaModels(baseURL, key)
  950. if err != nil {
  951. c.JSON(http.StatusOK, gin.H{
  952. "success": false,
  953. "message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()),
  954. })
  955. return
  956. }
  957. names := make([]string, 0, len(models))
  958. for _, modelInfo := range models {
  959. names = append(names, modelInfo.Name)
  960. }
  961. c.JSON(http.StatusOK, gin.H{
  962. "success": true,
  963. "data": names,
  964. })
  965. return
  966. }
  967. client := &http.Client{}
  968. url := fmt.Sprintf("%s/v1/models", baseURL)
  969. request, err := http.NewRequest("GET", url, nil)
  970. if err != nil {
  971. c.JSON(http.StatusInternalServerError, gin.H{
  972. "success": false,
  973. "message": err.Error(),
  974. })
  975. return
  976. }
  977. request.Header.Set("Authorization", "Bearer "+key)
  978. response, err := client.Do(request)
  979. if err != nil {
  980. c.JSON(http.StatusInternalServerError, gin.H{
  981. "success": false,
  982. "message": err.Error(),
  983. })
  984. return
  985. }
  986. //check status code
  987. if response.StatusCode != http.StatusOK {
  988. c.JSON(http.StatusInternalServerError, gin.H{
  989. "success": false,
  990. "message": "Failed to fetch models",
  991. })
  992. return
  993. }
  994. defer response.Body.Close()
  995. var result struct {
  996. Data []struct {
  997. ID string `json:"id"`
  998. } `json:"data"`
  999. }
  1000. if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
  1001. c.JSON(http.StatusInternalServerError, gin.H{
  1002. "success": false,
  1003. "message": err.Error(),
  1004. })
  1005. return
  1006. }
  1007. var models []string
  1008. for _, model := range result.Data {
  1009. models = append(models, model.ID)
  1010. }
  1011. c.JSON(http.StatusOK, gin.H{
  1012. "success": true,
  1013. "data": models,
  1014. })
  1015. }
  1016. func BatchSetChannelTag(c *gin.Context) {
  1017. channelBatch := ChannelBatch{}
  1018. err := c.ShouldBindJSON(&channelBatch)
  1019. if err != nil || len(channelBatch.Ids) == 0 {
  1020. c.JSON(http.StatusOK, gin.H{
  1021. "success": false,
  1022. "message": "参数错误",
  1023. })
  1024. return
  1025. }
  1026. err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
  1027. if err != nil {
  1028. common.ApiError(c, err)
  1029. return
  1030. }
  1031. model.InitChannelCache()
  1032. c.JSON(http.StatusOK, gin.H{
  1033. "success": true,
  1034. "message": "",
  1035. "data": len(channelBatch.Ids),
  1036. })
  1037. return
  1038. }
  1039. func GetTagModels(c *gin.Context) {
  1040. tag := c.Query("tag")
  1041. if tag == "" {
  1042. c.JSON(http.StatusBadRequest, gin.H{
  1043. "success": false,
  1044. "message": "tag不能为空",
  1045. })
  1046. return
  1047. }
  1048. channels, err := model.GetChannelsByTag(tag, false, false) // idSort=false, selectAll=false
  1049. if err != nil {
  1050. c.JSON(http.StatusInternalServerError, gin.H{
  1051. "success": false,
  1052. "message": err.Error(),
  1053. })
  1054. return
  1055. }
  1056. var longestModels string
  1057. maxLength := 0
  1058. // Find the longest models string among all channels with the given tag
  1059. for _, channel := range channels {
  1060. if channel.Models != "" {
  1061. currentModels := strings.Split(channel.Models, ",")
  1062. if len(currentModels) > maxLength {
  1063. maxLength = len(currentModels)
  1064. longestModels = channel.Models
  1065. }
  1066. }
  1067. }
  1068. c.JSON(http.StatusOK, gin.H{
  1069. "success": true,
  1070. "message": "",
  1071. "data": longestModels,
  1072. })
  1073. return
  1074. }
  1075. // CopyChannel handles cloning an existing channel with its key.
  1076. // POST /api/channel/copy/:id
  1077. // Optional query params:
  1078. //
  1079. // suffix - string appended to the original name (default "_复制")
  1080. // reset_balance - bool, when true will reset balance & used_quota to 0 (default true)
  1081. func CopyChannel(c *gin.Context) {
  1082. id, err := strconv.Atoi(c.Param("id"))
  1083. if err != nil {
  1084. c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
  1085. return
  1086. }
  1087. suffix := c.DefaultQuery("suffix", "_复制")
  1088. resetBalance := true
  1089. if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
  1090. if v, err := strconv.ParseBool(rbStr); err == nil {
  1091. resetBalance = v
  1092. }
  1093. }
  1094. // fetch original channel with key
  1095. origin, err := model.GetChannelById(id, true)
  1096. if err != nil {
  1097. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  1098. return
  1099. }
  1100. // clone channel
  1101. clone := *origin // shallow copy is sufficient as we will overwrite primitives
  1102. clone.Id = 0 // let DB auto-generate
  1103. clone.CreatedTime = common.GetTimestamp()
  1104. clone.Name = origin.Name + suffix
  1105. clone.TestTime = 0
  1106. clone.ResponseTime = 0
  1107. if resetBalance {
  1108. clone.Balance = 0
  1109. clone.UsedQuota = 0
  1110. }
  1111. // insert
  1112. if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
  1113. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  1114. return
  1115. }
  1116. model.InitChannelCache()
  1117. // success
  1118. c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
  1119. }
  1120. // MultiKeyManageRequest represents the request for multi-key management operations
  1121. type MultiKeyManageRequest struct {
  1122. ChannelId int `json:"channel_id"`
  1123. Action string `json:"action"` // "disable_key", "enable_key", "delete_key", "delete_disabled_keys", "get_key_status"
  1124. KeyIndex *int `json:"key_index,omitempty"` // for disable_key, enable_key, and delete_key actions
  1125. Page int `json:"page,omitempty"` // for get_key_status pagination
  1126. PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
  1127. Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
  1128. }
// MultiKeyStatusResponse is the "data" payload returned by the
// get_key_status action of ManageMultiKeys.
type MultiKeyStatusResponse struct {
	// Keys holds the requested page of (optionally status-filtered) keys.
	Keys []KeyStatus `json:"keys"`
	// Total is the number of keys remaining after the status filter.
	Total int `json:"total"`
	// Page is the returned page number, clamped to [1, TotalPages].
	Page int `json:"page"`
	// PageSize is the effective page size (50 when not supplied or <= 0).
	PageSize int `json:"page_size"`
	// TotalPages is the page count for Total keys; it is at least 1.
	TotalPages int `json:"total_pages"`
	// Statistics over ALL keys of the channel, unaffected by the status filter.
	EnabledCount        int `json:"enabled_count"`
	ManualDisabledCount int `json:"manual_disabled_count"`
	AutoDisabledCount   int `json:"auto_disabled_count"`
}
  1141. type KeyStatus struct {
  1142. Index int `json:"index"`
  1143. Status int `json:"status"` // 1: enabled, 2: disabled
  1144. DisabledTime int64 `json:"disabled_time,omitempty"`
  1145. Reason string `json:"reason,omitempty"`
  1146. KeyPreview string `json:"key_preview"` // first 10 chars of key for identification
  1147. }
  1148. // ManageMultiKeys handles multi-key management operations
  1149. func ManageMultiKeys(c *gin.Context) {
  1150. request := MultiKeyManageRequest{}
  1151. err := c.ShouldBindJSON(&request)
  1152. if err != nil {
  1153. common.ApiError(c, err)
  1154. return
  1155. }
  1156. channel, err := model.GetChannelById(request.ChannelId, true)
  1157. if err != nil {
  1158. c.JSON(http.StatusOK, gin.H{
  1159. "success": false,
  1160. "message": "渠道不存在",
  1161. })
  1162. return
  1163. }
  1164. if !channel.ChannelInfo.IsMultiKey {
  1165. c.JSON(http.StatusOK, gin.H{
  1166. "success": false,
  1167. "message": "该渠道不是多密钥模式",
  1168. })
  1169. return
  1170. }
  1171. lock := model.GetChannelPollingLock(channel.Id)
  1172. lock.Lock()
  1173. defer lock.Unlock()
  1174. switch request.Action {
  1175. case "get_key_status":
  1176. keys := channel.GetKeys()
  1177. // Default pagination parameters
  1178. page := request.Page
  1179. pageSize := request.PageSize
  1180. if page <= 0 {
  1181. page = 1
  1182. }
  1183. if pageSize <= 0 {
  1184. pageSize = 50 // Default page size
  1185. }
  1186. // Statistics for all keys (unchanged by filtering)
  1187. var enabledCount, manualDisabledCount, autoDisabledCount int
  1188. // Build all key status data first
  1189. var allKeyStatusList []KeyStatus
  1190. for i, key := range keys {
  1191. status := 1 // default enabled
  1192. var disabledTime int64
  1193. var reason string
  1194. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1195. if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
  1196. status = s
  1197. }
  1198. }
  1199. // Count for statistics (all keys)
  1200. switch status {
  1201. case 1:
  1202. enabledCount++
  1203. case 2:
  1204. manualDisabledCount++
  1205. case 3:
  1206. autoDisabledCount++
  1207. }
  1208. if status != 1 {
  1209. if channel.ChannelInfo.MultiKeyDisabledTime != nil {
  1210. disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
  1211. }
  1212. if channel.ChannelInfo.MultiKeyDisabledReason != nil {
  1213. reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
  1214. }
  1215. }
  1216. // Create key preview (first 10 chars)
  1217. keyPreview := key
  1218. if len(key) > 10 {
  1219. keyPreview = key[:10] + "..."
  1220. }
  1221. allKeyStatusList = append(allKeyStatusList, KeyStatus{
  1222. Index: i,
  1223. Status: status,
  1224. DisabledTime: disabledTime,
  1225. Reason: reason,
  1226. KeyPreview: keyPreview,
  1227. })
  1228. }
  1229. // Apply status filter if specified
  1230. var filteredKeyStatusList []KeyStatus
  1231. if request.Status != nil {
  1232. for _, keyStatus := range allKeyStatusList {
  1233. if keyStatus.Status == *request.Status {
  1234. filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
  1235. }
  1236. }
  1237. } else {
  1238. filteredKeyStatusList = allKeyStatusList
  1239. }
  1240. // Calculate pagination based on filtered results
  1241. filteredTotal := len(filteredKeyStatusList)
  1242. totalPages := (filteredTotal + pageSize - 1) / pageSize
  1243. if totalPages == 0 {
  1244. totalPages = 1
  1245. }
  1246. if page > totalPages {
  1247. page = totalPages
  1248. }
  1249. // Calculate range for current page
  1250. start := (page - 1) * pageSize
  1251. end := start + pageSize
  1252. if end > filteredTotal {
  1253. end = filteredTotal
  1254. }
  1255. // Get the page data
  1256. var pageKeyStatusList []KeyStatus
  1257. if start < filteredTotal {
  1258. pageKeyStatusList = filteredKeyStatusList[start:end]
  1259. }
  1260. c.JSON(http.StatusOK, gin.H{
  1261. "success": true,
  1262. "message": "",
  1263. "data": MultiKeyStatusResponse{
  1264. Keys: pageKeyStatusList,
  1265. Total: filteredTotal, // Total of filtered results
  1266. Page: page,
  1267. PageSize: pageSize,
  1268. TotalPages: totalPages,
  1269. EnabledCount: enabledCount, // Overall statistics
  1270. ManualDisabledCount: manualDisabledCount, // Overall statistics
  1271. AutoDisabledCount: autoDisabledCount, // Overall statistics
  1272. },
  1273. })
  1274. return
  1275. case "disable_key":
  1276. if request.KeyIndex == nil {
  1277. c.JSON(http.StatusOK, gin.H{
  1278. "success": false,
  1279. "message": "未指定要禁用的密钥索引",
  1280. })
  1281. return
  1282. }
  1283. keyIndex := *request.KeyIndex
  1284. if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
  1285. c.JSON(http.StatusOK, gin.H{
  1286. "success": false,
  1287. "message": "密钥索引超出范围",
  1288. })
  1289. return
  1290. }
  1291. if channel.ChannelInfo.MultiKeyStatusList == nil {
  1292. channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
  1293. }
  1294. if channel.ChannelInfo.MultiKeyDisabledTime == nil {
  1295. channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
  1296. }
  1297. if channel.ChannelInfo.MultiKeyDisabledReason == nil {
  1298. channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
  1299. }
  1300. channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
  1301. err = channel.Update()
  1302. if err != nil {
  1303. common.ApiError(c, err)
  1304. return
  1305. }
  1306. model.InitChannelCache()
  1307. c.JSON(http.StatusOK, gin.H{
  1308. "success": true,
  1309. "message": "密钥已禁用",
  1310. })
  1311. return
  1312. case "enable_key":
  1313. if request.KeyIndex == nil {
  1314. c.JSON(http.StatusOK, gin.H{
  1315. "success": false,
  1316. "message": "未指定要启用的密钥索引",
  1317. })
  1318. return
  1319. }
  1320. keyIndex := *request.KeyIndex
  1321. if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
  1322. c.JSON(http.StatusOK, gin.H{
  1323. "success": false,
  1324. "message": "密钥索引超出范围",
  1325. })
  1326. return
  1327. }
  1328. // 从状态列表中删除该密钥的记录,使其回到默认启用状态
  1329. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1330. delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
  1331. }
  1332. if channel.ChannelInfo.MultiKeyDisabledTime != nil {
  1333. delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
  1334. }
  1335. if channel.ChannelInfo.MultiKeyDisabledReason != nil {
  1336. delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
  1337. }
  1338. err = channel.Update()
  1339. if err != nil {
  1340. common.ApiError(c, err)
  1341. return
  1342. }
  1343. model.InitChannelCache()
  1344. c.JSON(http.StatusOK, gin.H{
  1345. "success": true,
  1346. "message": "密钥已启用",
  1347. })
  1348. return
  1349. case "enable_all_keys":
  1350. // 清空所有禁用状态,使所有密钥回到默认启用状态
  1351. var enabledCount int
  1352. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1353. enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
  1354. }
  1355. channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
  1356. channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
  1357. channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
  1358. err = channel.Update()
  1359. if err != nil {
  1360. common.ApiError(c, err)
  1361. return
  1362. }
  1363. model.InitChannelCache()
  1364. c.JSON(http.StatusOK, gin.H{
  1365. "success": true,
  1366. "message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
  1367. })
  1368. return
  1369. case "disable_all_keys":
  1370. // 禁用所有启用的密钥
  1371. if channel.ChannelInfo.MultiKeyStatusList == nil {
  1372. channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
  1373. }
  1374. if channel.ChannelInfo.MultiKeyDisabledTime == nil {
  1375. channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
  1376. }
  1377. if channel.ChannelInfo.MultiKeyDisabledReason == nil {
  1378. channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
  1379. }
  1380. var disabledCount int
  1381. for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
  1382. status := 1 // default enabled
  1383. if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
  1384. status = s
  1385. }
  1386. // 只禁用当前启用的密钥
  1387. if status == 1 {
  1388. channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
  1389. disabledCount++
  1390. }
  1391. }
  1392. if disabledCount == 0 {
  1393. c.JSON(http.StatusOK, gin.H{
  1394. "success": false,
  1395. "message": "没有可禁用的密钥",
  1396. })
  1397. return
  1398. }
  1399. err = channel.Update()
  1400. if err != nil {
  1401. common.ApiError(c, err)
  1402. return
  1403. }
  1404. model.InitChannelCache()
  1405. c.JSON(http.StatusOK, gin.H{
  1406. "success": true,
  1407. "message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
  1408. })
  1409. return
  1410. case "delete_key":
  1411. if request.KeyIndex == nil {
  1412. c.JSON(http.StatusOK, gin.H{
  1413. "success": false,
  1414. "message": "未指定要删除的密钥索引",
  1415. })
  1416. return
  1417. }
  1418. keyIndex := *request.KeyIndex
  1419. if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
  1420. c.JSON(http.StatusOK, gin.H{
  1421. "success": false,
  1422. "message": "密钥索引超出范围",
  1423. })
  1424. return
  1425. }
  1426. keys := channel.GetKeys()
  1427. var remainingKeys []string
  1428. var newStatusList = make(map[int]int)
  1429. var newDisabledTime = make(map[int]int64)
  1430. var newDisabledReason = make(map[int]string)
  1431. newIndex := 0
  1432. for i, key := range keys {
  1433. // 跳过要删除的密钥
  1434. if i == keyIndex {
  1435. continue
  1436. }
  1437. remainingKeys = append(remainingKeys, key)
  1438. // 保留其他密钥的状态信息,重新索引
  1439. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1440. if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 {
  1441. newStatusList[newIndex] = status
  1442. }
  1443. }
  1444. if channel.ChannelInfo.MultiKeyDisabledTime != nil {
  1445. if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
  1446. newDisabledTime[newIndex] = t
  1447. }
  1448. }
  1449. if channel.ChannelInfo.MultiKeyDisabledReason != nil {
  1450. if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
  1451. newDisabledReason[newIndex] = r
  1452. }
  1453. }
  1454. newIndex++
  1455. }
  1456. if len(remainingKeys) == 0 {
  1457. c.JSON(http.StatusOK, gin.H{
  1458. "success": false,
  1459. "message": "不能删除最后一个密钥",
  1460. })
  1461. return
  1462. }
  1463. // Update channel with remaining keys
  1464. channel.Key = strings.Join(remainingKeys, "\n")
  1465. channel.ChannelInfo.MultiKeySize = len(remainingKeys)
  1466. channel.ChannelInfo.MultiKeyStatusList = newStatusList
  1467. channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
  1468. channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
  1469. err = channel.Update()
  1470. if err != nil {
  1471. common.ApiError(c, err)
  1472. return
  1473. }
  1474. model.InitChannelCache()
  1475. c.JSON(http.StatusOK, gin.H{
  1476. "success": true,
  1477. "message": "密钥已删除",
  1478. })
  1479. return
  1480. case "delete_disabled_keys":
  1481. keys := channel.GetKeys()
  1482. var remainingKeys []string
  1483. var deletedCount int
  1484. var newStatusList = make(map[int]int)
  1485. var newDisabledTime = make(map[int]int64)
  1486. var newDisabledReason = make(map[int]string)
  1487. newIndex := 0
  1488. for i, key := range keys {
  1489. status := 1 // default enabled
  1490. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1491. if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
  1492. status = s
  1493. }
  1494. }
  1495. // 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥
  1496. if status == 3 {
  1497. deletedCount++
  1498. } else {
  1499. remainingKeys = append(remainingKeys, key)
  1500. // 保留非自动禁用密钥的状态信息,重新索引
  1501. if status != 1 {
  1502. newStatusList[newIndex] = status
  1503. if channel.ChannelInfo.MultiKeyDisabledTime != nil {
  1504. if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
  1505. newDisabledTime[newIndex] = t
  1506. }
  1507. }
  1508. if channel.ChannelInfo.MultiKeyDisabledReason != nil {
  1509. if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
  1510. newDisabledReason[newIndex] = r
  1511. }
  1512. }
  1513. }
  1514. newIndex++
  1515. }
  1516. }
  1517. if deletedCount == 0 {
  1518. c.JSON(http.StatusOK, gin.H{
  1519. "success": false,
  1520. "message": "没有需要删除的自动禁用密钥",
  1521. })
  1522. return
  1523. }
  1524. // Update channel with remaining keys
  1525. channel.Key = strings.Join(remainingKeys, "\n")
  1526. channel.ChannelInfo.MultiKeySize = len(remainingKeys)
  1527. channel.ChannelInfo.MultiKeyStatusList = newStatusList
  1528. channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
  1529. channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
  1530. err = channel.Update()
  1531. if err != nil {
  1532. common.ApiError(c, err)
  1533. return
  1534. }
  1535. model.InitChannelCache()
  1536. c.JSON(http.StatusOK, gin.H{
  1537. "success": true,
  1538. "message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
  1539. "data": deletedCount,
  1540. })
  1541. return
  1542. default:
  1543. c.JSON(http.StatusOK, gin.H{
  1544. "success": false,
  1545. "message": "不支持的操作",
  1546. })
  1547. return
  1548. }
  1549. }
  1550. // OllamaPullModel 拉取 Ollama 模型
  1551. func OllamaPullModel(c *gin.Context) {
  1552. var req struct {
  1553. ChannelID int `json:"channel_id"`
  1554. ModelName string `json:"model_name"`
  1555. }
  1556. if err := c.ShouldBindJSON(&req); err != nil {
  1557. c.JSON(http.StatusBadRequest, gin.H{
  1558. "success": false,
  1559. "message": "Invalid request parameters",
  1560. })
  1561. return
  1562. }
  1563. if req.ChannelID == 0 || req.ModelName == "" {
  1564. c.JSON(http.StatusBadRequest, gin.H{
  1565. "success": false,
  1566. "message": "Channel ID and model name are required",
  1567. })
  1568. return
  1569. }
  1570. // 获取渠道信息
  1571. channel, err := model.GetChannelById(req.ChannelID, true)
  1572. if err != nil {
  1573. c.JSON(http.StatusNotFound, gin.H{
  1574. "success": false,
  1575. "message": "Channel not found",
  1576. })
  1577. return
  1578. }
  1579. // 检查是否是 Ollama 渠道
  1580. if channel.Type != constant.ChannelTypeOllama {
  1581. c.JSON(http.StatusBadRequest, gin.H{
  1582. "success": false,
  1583. "message": "This operation is only supported for Ollama channels",
  1584. })
  1585. return
  1586. }
  1587. baseURL := constant.ChannelBaseURLs[channel.Type]
  1588. if channel.GetBaseURL() != "" {
  1589. baseURL = channel.GetBaseURL()
  1590. }
  1591. key := strings.Split(channel.Key, "\n")[0]
  1592. err = ollama.PullOllamaModel(baseURL, key, req.ModelName)
  1593. if err != nil {
  1594. c.JSON(http.StatusInternalServerError, gin.H{
  1595. "success": false,
  1596. "message": fmt.Sprintf("Failed to pull model: %s", err.Error()),
  1597. })
  1598. return
  1599. }
  1600. c.JSON(http.StatusOK, gin.H{
  1601. "success": true,
  1602. "message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
  1603. })
  1604. }
  1605. // OllamaPullModelStream 流式拉取 Ollama 模型
  1606. func OllamaPullModelStream(c *gin.Context) {
  1607. var req struct {
  1608. ChannelID int `json:"channel_id"`
  1609. ModelName string `json:"model_name"`
  1610. }
  1611. if err := c.ShouldBindJSON(&req); err != nil {
  1612. c.JSON(http.StatusBadRequest, gin.H{
  1613. "success": false,
  1614. "message": "Invalid request parameters",
  1615. })
  1616. return
  1617. }
  1618. if req.ChannelID == 0 || req.ModelName == "" {
  1619. c.JSON(http.StatusBadRequest, gin.H{
  1620. "success": false,
  1621. "message": "Channel ID and model name are required",
  1622. })
  1623. return
  1624. }
  1625. // 获取渠道信息
  1626. channel, err := model.GetChannelById(req.ChannelID, true)
  1627. if err != nil {
  1628. c.JSON(http.StatusNotFound, gin.H{
  1629. "success": false,
  1630. "message": "Channel not found",
  1631. })
  1632. return
  1633. }
  1634. // 检查是否是 Ollama 渠道
  1635. if channel.Type != constant.ChannelTypeOllama {
  1636. c.JSON(http.StatusBadRequest, gin.H{
  1637. "success": false,
  1638. "message": "This operation is only supported for Ollama channels",
  1639. })
  1640. return
  1641. }
  1642. baseURL := constant.ChannelBaseURLs[channel.Type]
  1643. if channel.GetBaseURL() != "" {
  1644. baseURL = channel.GetBaseURL()
  1645. }
  1646. // 设置 SSE 头部
  1647. c.Header("Content-Type", "text/event-stream")
  1648. c.Header("Cache-Control", "no-cache")
  1649. c.Header("Connection", "keep-alive")
  1650. c.Header("Access-Control-Allow-Origin", "*")
  1651. key := strings.Split(channel.Key, "\n")[0]
  1652. // 创建进度回调函数
  1653. progressCallback := func(progress ollama.OllamaPullResponse) {
  1654. data, _ := json.Marshal(progress)
  1655. fmt.Fprintf(c.Writer, "data: %s\n\n", string(data))
  1656. c.Writer.Flush()
  1657. }
  1658. // 执行拉取
  1659. err = ollama.PullOllamaModelStream(baseURL, key, req.ModelName, progressCallback)
  1660. if err != nil {
  1661. errorData, _ := json.Marshal(gin.H{
  1662. "error": err.Error(),
  1663. })
  1664. fmt.Fprintf(c.Writer, "data: %s\n\n", string(errorData))
  1665. } else {
  1666. successData, _ := json.Marshal(gin.H{
  1667. "message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
  1668. })
  1669. fmt.Fprintf(c.Writer, "data: %s\n\n", string(successData))
  1670. }
  1671. // 发送结束标志
  1672. fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
  1673. c.Writer.Flush()
  1674. }
  1675. // OllamaDeleteModel 删除 Ollama 模型
  1676. func OllamaDeleteModel(c *gin.Context) {
  1677. var req struct {
  1678. ChannelID int `json:"channel_id"`
  1679. ModelName string `json:"model_name"`
  1680. }
  1681. if err := c.ShouldBindJSON(&req); err != nil {
  1682. c.JSON(http.StatusBadRequest, gin.H{
  1683. "success": false,
  1684. "message": "Invalid request parameters",
  1685. })
  1686. return
  1687. }
  1688. if req.ChannelID == 0 || req.ModelName == "" {
  1689. c.JSON(http.StatusBadRequest, gin.H{
  1690. "success": false,
  1691. "message": "Channel ID and model name are required",
  1692. })
  1693. return
  1694. }
  1695. // 获取渠道信息
  1696. channel, err := model.GetChannelById(req.ChannelID, true)
  1697. if err != nil {
  1698. c.JSON(http.StatusNotFound, gin.H{
  1699. "success": false,
  1700. "message": "Channel not found",
  1701. })
  1702. return
  1703. }
  1704. // 检查是否是 Ollama 渠道
  1705. if channel.Type != constant.ChannelTypeOllama {
  1706. c.JSON(http.StatusBadRequest, gin.H{
  1707. "success": false,
  1708. "message": "This operation is only supported for Ollama channels",
  1709. })
  1710. return
  1711. }
  1712. baseURL := constant.ChannelBaseURLs[channel.Type]
  1713. if channel.GetBaseURL() != "" {
  1714. baseURL = channel.GetBaseURL()
  1715. }
  1716. key := strings.Split(channel.Key, "\n")[0]
  1717. err = ollama.DeleteOllamaModel(baseURL, key, req.ModelName)
  1718. if err != nil {
  1719. c.JSON(http.StatusInternalServerError, gin.H{
  1720. "success": false,
  1721. "message": fmt.Sprintf("Failed to delete model: %s", err.Error()),
  1722. })
  1723. return
  1724. }
  1725. c.JSON(http.StatusOK, gin.H{
  1726. "success": true,
  1727. "message": fmt.Sprintf("Model %s deleted successfully", req.ModelName),
  1728. })
  1729. }
  1730. // OllamaVersion 获取 Ollama 服务版本信息
  1731. func OllamaVersion(c *gin.Context) {
  1732. id, err := strconv.Atoi(c.Param("id"))
  1733. if err != nil {
  1734. c.JSON(http.StatusBadRequest, gin.H{
  1735. "success": false,
  1736. "message": "Invalid channel id",
  1737. })
  1738. return
  1739. }
  1740. channel, err := model.GetChannelById(id, true)
  1741. if err != nil {
  1742. c.JSON(http.StatusNotFound, gin.H{
  1743. "success": false,
  1744. "message": "Channel not found",
  1745. })
  1746. return
  1747. }
  1748. if channel.Type != constant.ChannelTypeOllama {
  1749. c.JSON(http.StatusBadRequest, gin.H{
  1750. "success": false,
  1751. "message": "This operation is only supported for Ollama channels",
  1752. })
  1753. return
  1754. }
  1755. baseURL := constant.ChannelBaseURLs[channel.Type]
  1756. if channel.GetBaseURL() != "" {
  1757. baseURL = channel.GetBaseURL()
  1758. }
  1759. key := strings.Split(channel.Key, "\n")[0]
  1760. version, err := ollama.FetchOllamaVersion(baseURL, key)
  1761. if err != nil {
  1762. c.JSON(http.StatusOK, gin.H{
  1763. "success": false,
  1764. "message": fmt.Sprintf("获取Ollama版本失败: %s", err.Error()),
  1765. })
  1766. return
  1767. }
  1768. c.JSON(http.StatusOK, gin.H{
  1769. "success": true,
  1770. "data": gin.H{
  1771. "version": version,
  1772. },
  1773. })
  1774. }