channel.go 53 KB


  1. package controller
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/constant"
  12. "github.com/QuantumNous/new-api/dto"
  13. "github.com/QuantumNous/new-api/model"
  14. "github.com/QuantumNous/new-api/relay/channel/gemini"
  15. "github.com/QuantumNous/new-api/relay/channel/ollama"
  16. "github.com/QuantumNous/new-api/service"
  17. "github.com/gin-gonic/gin"
  18. )
  // OpenAIModel is one entry of an OpenAI-style model listing ({"data":[...]}),
  // both as decoded from upstream responses and as built for Ollama channels
  // in FetchUpstreamModels.
  type OpenAIModel struct {
  	ID      string `json:"id"`
  	Object  string `json:"object"`
  	Created int64  `json:"created"`
  	OwnedBy string `json:"owned_by"`
  	// Metadata carries extra per-model details (FetchUpstreamModels fills it
  	// with Ollama size/digest/modified_at/details); it is omitted from JSON
  	// when nil or empty.
  	Metadata map[string]any `json:"metadata,omitempty"`
  	// Permission keeps its anonymous element type so that existing literals
  	// and assignments of this shape stay compatible.
  	Permission []struct {
  		ID                 string `json:"id"`
  		Object             string `json:"object"`
  		Created            int64  `json:"created"`
  		AllowCreateEngine  bool   `json:"allow_create_engine"`
  		AllowSampling      bool   `json:"allow_sampling"`
  		AllowLogprobs      bool   `json:"allow_logprobs"`
  		AllowSearchIndices bool   `json:"allow_search_indices"`
  		AllowView          bool   `json:"allow_view"`
  		AllowFineTuning    bool   `json:"allow_fine_tuning"`
  		Organization       string `json:"organization"`
  		Group              string `json:"group"`
  		IsBlocking         bool   `json:"is_blocking"`
  	} `json:"permission"`
  	Root   string `json:"root"`
  	Parent string `json:"parent"`
  }

  // OpenAIModelsResponse is the envelope of a model listing: the entries under
  // "data" plus a "success" flag.
  type OpenAIModelsResponse struct {
  	Data    []OpenAIModel `json:"data"`
  	Success bool          `json:"success"`
  }
  46. func parseStatusFilter(statusParam string) int {
  47. switch strings.ToLower(statusParam) {
  48. case "enabled", "1":
  49. return common.ChannelStatusEnabled
  50. case "disabled", "0":
  51. return 0
  52. default:
  53. return -1
  54. }
  55. }
  56. func clearChannelInfo(channel *model.Channel) {
  57. if channel.ChannelInfo.IsMultiKey {
  58. channel.ChannelInfo.MultiKeyDisabledReason = nil
  59. channel.ChannelInfo.MultiKeyDisabledTime = nil
  60. }
  61. }
  62. func GetAllChannels(c *gin.Context) {
  63. pageInfo := common.GetPageQuery(c)
  64. channelData := make([]*model.Channel, 0)
  65. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  66. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  67. statusParam := c.Query("status")
  68. // statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
  69. statusFilter := parseStatusFilter(statusParam)
  70. // type filter
  71. typeStr := c.Query("type")
  72. typeFilter := -1
  73. if typeStr != "" {
  74. if t, err := strconv.Atoi(typeStr); err == nil {
  75. typeFilter = t
  76. }
  77. }
  78. var total int64
  79. if enableTagMode {
  80. tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
  81. if err != nil {
  82. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  83. return
  84. }
  85. for _, tag := range tags {
  86. if tag == nil || *tag == "" {
  87. continue
  88. }
  89. tagChannels, err := model.GetChannelsByTag(*tag, idSort, false)
  90. if err != nil {
  91. continue
  92. }
  93. filtered := make([]*model.Channel, 0)
  94. for _, ch := range tagChannels {
  95. if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
  96. continue
  97. }
  98. if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
  99. continue
  100. }
  101. if typeFilter >= 0 && ch.Type != typeFilter {
  102. continue
  103. }
  104. filtered = append(filtered, ch)
  105. }
  106. channelData = append(channelData, filtered...)
  107. }
  108. total, _ = model.CountAllTags()
  109. } else {
  110. baseQuery := model.DB.Model(&model.Channel{})
  111. if typeFilter >= 0 {
  112. baseQuery = baseQuery.Where("type = ?", typeFilter)
  113. }
  114. if statusFilter == common.ChannelStatusEnabled {
  115. baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
  116. } else if statusFilter == 0 {
  117. baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
  118. }
  119. baseQuery.Count(&total)
  120. order := "priority desc"
  121. if idSort {
  122. order = "id desc"
  123. }
  124. err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
  125. if err != nil {
  126. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  127. return
  128. }
  129. }
  130. for _, datum := range channelData {
  131. clearChannelInfo(datum)
  132. }
  133. countQuery := model.DB.Model(&model.Channel{})
  134. if statusFilter == common.ChannelStatusEnabled {
  135. countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
  136. } else if statusFilter == 0 {
  137. countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
  138. }
  139. var results []struct {
  140. Type int64
  141. Count int64
  142. }
  143. _ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
  144. typeCounts := make(map[int64]int64)
  145. for _, r := range results {
  146. typeCounts[r.Type] = r.Count
  147. }
  148. common.ApiSuccess(c, gin.H{
  149. "items": channelData,
  150. "total": total,
  151. "page": pageInfo.GetPage(),
  152. "page_size": pageInfo.GetPageSize(),
  153. "type_counts": typeCounts,
  154. })
  155. return
  156. }
  157. func buildFetchModelsHeaders(channel *model.Channel, key string) (http.Header, error) {
  158. var headers http.Header
  159. switch channel.Type {
  160. case constant.ChannelTypeAnthropic:
  161. headers = GetClaudeAuthHeader(key)
  162. default:
  163. headers = GetAuthHeader(key)
  164. }
  165. headerOverride := channel.GetHeaderOverride()
  166. for k, v := range headerOverride {
  167. str, ok := v.(string)
  168. if !ok {
  169. return nil, fmt.Errorf("invalid header override for key %s", k)
  170. }
  171. if strings.Contains(str, "{api_key}") {
  172. str = strings.ReplaceAll(str, "{api_key}", key)
  173. }
  174. headers.Set(k, str)
  175. }
  176. return headers, nil
  177. }
  178. func FetchUpstreamModels(c *gin.Context) {
  179. id, err := strconv.Atoi(c.Param("id"))
  180. if err != nil {
  181. common.ApiError(c, err)
  182. return
  183. }
  184. channel, err := model.GetChannelById(id, true)
  185. if err != nil {
  186. common.ApiError(c, err)
  187. return
  188. }
  189. baseURL := constant.ChannelBaseURLs[channel.Type]
  190. if channel.GetBaseURL() != "" {
  191. baseURL = channel.GetBaseURL()
  192. }
  193. // 对于 Ollama 渠道,使用特殊处理
  194. if channel.Type == constant.ChannelTypeOllama {
  195. key := strings.Split(channel.Key, "\n")[0]
  196. models, err := ollama.FetchOllamaModels(baseURL, key)
  197. if err != nil {
  198. c.JSON(http.StatusOK, gin.H{
  199. "success": false,
  200. "message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()),
  201. })
  202. return
  203. }
  204. result := OpenAIModelsResponse{
  205. Data: make([]OpenAIModel, 0, len(models)),
  206. }
  207. for _, modelInfo := range models {
  208. metadata := map[string]any{}
  209. if modelInfo.Size > 0 {
  210. metadata["size"] = modelInfo.Size
  211. }
  212. if modelInfo.Digest != "" {
  213. metadata["digest"] = modelInfo.Digest
  214. }
  215. if modelInfo.ModifiedAt != "" {
  216. metadata["modified_at"] = modelInfo.ModifiedAt
  217. }
  218. details := modelInfo.Details
  219. if details.ParentModel != "" || details.Format != "" || details.Family != "" || len(details.Families) > 0 || details.ParameterSize != "" || details.QuantizationLevel != "" {
  220. metadata["details"] = modelInfo.Details
  221. }
  222. if len(metadata) == 0 {
  223. metadata = nil
  224. }
  225. result.Data = append(result.Data, OpenAIModel{
  226. ID: modelInfo.Name,
  227. Object: "model",
  228. Created: 0,
  229. OwnedBy: "ollama",
  230. Metadata: metadata,
  231. })
  232. }
  233. c.JSON(http.StatusOK, gin.H{
  234. "success": true,
  235. "data": result.Data,
  236. })
  237. return
  238. }
  239. // 对于 Gemini 渠道,使用特殊处理
  240. if channel.Type == constant.ChannelTypeGemini {
  241. // 获取用于请求的可用密钥(多密钥渠道优先使用启用状态的密钥)
  242. key, _, apiErr := channel.GetNextEnabledKey()
  243. if apiErr != nil {
  244. c.JSON(http.StatusOK, gin.H{
  245. "success": false,
  246. "message": fmt.Sprintf("获取渠道密钥失败: %s", apiErr.Error()),
  247. })
  248. return
  249. }
  250. key = strings.TrimSpace(key)
  251. models, err := gemini.FetchGeminiModels(baseURL, key, channel.GetSetting().Proxy)
  252. if err != nil {
  253. c.JSON(http.StatusOK, gin.H{
  254. "success": false,
  255. "message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()),
  256. })
  257. return
  258. }
  259. c.JSON(http.StatusOK, gin.H{
  260. "success": true,
  261. "message": "",
  262. "data": models,
  263. })
  264. return
  265. }
  266. var url string
  267. switch channel.Type {
  268. case constant.ChannelTypeAli:
  269. url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
  270. case constant.ChannelTypeZhipu_v4:
  271. if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  272. url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL)
  273. } else {
  274. url = fmt.Sprintf("%s/api/paas/v4/models", baseURL)
  275. }
  276. case constant.ChannelTypeVolcEngine:
  277. if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  278. url = fmt.Sprintf("%s/v1/models", plan.OpenAIBaseURL)
  279. } else {
  280. url = fmt.Sprintf("%s/v1/models", baseURL)
  281. }
  282. case constant.ChannelTypeMoonshot:
  283. if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
  284. url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL)
  285. } else {
  286. url = fmt.Sprintf("%s/v1/models", baseURL)
  287. }
  288. default:
  289. url = fmt.Sprintf("%s/v1/models", baseURL)
  290. }
  291. // 获取用于请求的可用密钥(多密钥渠道优先使用启用状态的密钥)
  292. key, _, apiErr := channel.GetNextEnabledKey()
  293. if apiErr != nil {
  294. c.JSON(http.StatusOK, gin.H{
  295. "success": false,
  296. "message": fmt.Sprintf("获取渠道密钥失败: %s", apiErr.Error()),
  297. })
  298. return
  299. }
  300. key = strings.TrimSpace(key)
  301. headers, err := buildFetchModelsHeaders(channel, key)
  302. if err != nil {
  303. common.ApiError(c, err)
  304. return
  305. }
  306. body, err := GetResponseBody("GET", url, channel, headers)
  307. if err != nil {
  308. common.ApiError(c, err)
  309. return
  310. }
  311. var result OpenAIModelsResponse
  312. if err = json.Unmarshal(body, &result); err != nil {
  313. c.JSON(http.StatusOK, gin.H{
  314. "success": false,
  315. "message": fmt.Sprintf("解析响应失败: %s", err.Error()),
  316. })
  317. return
  318. }
  319. var ids []string
  320. for _, model := range result.Data {
  321. id := model.ID
  322. if channel.Type == constant.ChannelTypeGemini {
  323. id = strings.TrimPrefix(id, "models/")
  324. }
  325. ids = append(ids, id)
  326. }
  327. c.JSON(http.StatusOK, gin.H{
  328. "success": true,
  329. "message": "",
  330. "data": ids,
  331. })
  332. }
  333. func FixChannelsAbilities(c *gin.Context) {
  334. success, fails, err := model.FixAbility()
  335. if err != nil {
  336. common.ApiError(c, err)
  337. return
  338. }
  339. c.JSON(http.StatusOK, gin.H{
  340. "success": true,
  341. "message": "",
  342. "data": gin.H{
  343. "success": success,
  344. "fails": fails,
  345. },
  346. })
  347. }
  348. func SearchChannels(c *gin.Context) {
  349. keyword := c.Query("keyword")
  350. group := c.Query("group")
  351. modelKeyword := c.Query("model")
  352. statusParam := c.Query("status")
  353. statusFilter := parseStatusFilter(statusParam)
  354. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  355. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  356. channelData := make([]*model.Channel, 0)
  357. if enableTagMode {
  358. tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
  359. if err != nil {
  360. c.JSON(http.StatusOK, gin.H{
  361. "success": false,
  362. "message": err.Error(),
  363. })
  364. return
  365. }
  366. for _, tag := range tags {
  367. if tag != nil && *tag != "" {
  368. tagChannel, err := model.GetChannelsByTag(*tag, idSort, false)
  369. if err == nil {
  370. channelData = append(channelData, tagChannel...)
  371. }
  372. }
  373. }
  374. } else {
  375. channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
  376. if err != nil {
  377. c.JSON(http.StatusOK, gin.H{
  378. "success": false,
  379. "message": err.Error(),
  380. })
  381. return
  382. }
  383. channelData = channels
  384. }
  385. if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
  386. filtered := make([]*model.Channel, 0, len(channelData))
  387. for _, ch := range channelData {
  388. if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
  389. continue
  390. }
  391. if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
  392. continue
  393. }
  394. filtered = append(filtered, ch)
  395. }
  396. channelData = filtered
  397. }
  398. // calculate type counts for search results
  399. typeCounts := make(map[int64]int64)
  400. for _, channel := range channelData {
  401. typeCounts[int64(channel.Type)]++
  402. }
  403. typeParam := c.Query("type")
  404. typeFilter := -1
  405. if typeParam != "" {
  406. if tp, err := strconv.Atoi(typeParam); err == nil {
  407. typeFilter = tp
  408. }
  409. }
  410. if typeFilter >= 0 {
  411. filtered := make([]*model.Channel, 0, len(channelData))
  412. for _, ch := range channelData {
  413. if ch.Type == typeFilter {
  414. filtered = append(filtered, ch)
  415. }
  416. }
  417. channelData = filtered
  418. }
  419. page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
  420. pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
  421. if page < 1 {
  422. page = 1
  423. }
  424. if pageSize <= 0 {
  425. pageSize = 20
  426. }
  427. total := len(channelData)
  428. startIdx := (page - 1) * pageSize
  429. if startIdx > total {
  430. startIdx = total
  431. }
  432. endIdx := startIdx + pageSize
  433. if endIdx > total {
  434. endIdx = total
  435. }
  436. pagedData := channelData[startIdx:endIdx]
  437. for _, datum := range pagedData {
  438. clearChannelInfo(datum)
  439. }
  440. c.JSON(http.StatusOK, gin.H{
  441. "success": true,
  442. "message": "",
  443. "data": gin.H{
  444. "items": pagedData,
  445. "total": total,
  446. "type_counts": typeCounts,
  447. },
  448. })
  449. return
  450. }
  451. func GetChannel(c *gin.Context) {
  452. id, err := strconv.Atoi(c.Param("id"))
  453. if err != nil {
  454. common.ApiError(c, err)
  455. return
  456. }
  457. channel, err := model.GetChannelById(id, false)
  458. if err != nil {
  459. common.ApiError(c, err)
  460. return
  461. }
  462. if channel != nil {
  463. clearChannelInfo(channel)
  464. }
  465. c.JSON(http.StatusOK, gin.H{
  466. "success": true,
  467. "message": "",
  468. "data": channel,
  469. })
  470. return
  471. }
  472. // GetChannelKey 获取渠道密钥(需要通过安全验证中间件)
  473. // 此函数依赖 SecureVerificationRequired 中间件,确保用户已通过安全验证
  474. func GetChannelKey(c *gin.Context) {
  475. userId := c.GetInt("id")
  476. channelId, err := strconv.Atoi(c.Param("id"))
  477. if err != nil {
  478. common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err))
  479. return
  480. }
  481. // 获取渠道信息(包含密钥)
  482. channel, err := model.GetChannelById(channelId, true)
  483. if err != nil {
  484. common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err))
  485. return
  486. }
  487. if channel == nil {
  488. common.ApiError(c, fmt.Errorf("渠道不存在"))
  489. return
  490. }
  491. // 记录操作日志
  492. model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId))
  493. // 返回渠道密钥
  494. c.JSON(http.StatusOK, gin.H{
  495. "success": true,
  496. "message": "获取成功",
  497. "data": map[string]interface{}{
  498. "key": channel.Key,
  499. },
  500. })
  501. }
  502. // validateTwoFactorAuth 统一的2FA验证函数
  503. func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool {
  504. // 尝试验证TOTP
  505. if cleanCode, err := common.ValidateNumericCode(code); err == nil {
  506. if isValid, _ := twoFA.ValidateTOTPAndUpdateUsage(cleanCode); isValid {
  507. return true
  508. }
  509. }
  510. // 尝试验证备用码
  511. if isValid, err := twoFA.ValidateBackupCodeAndUpdateUsage(code); err == nil && isValid {
  512. return true
  513. }
  514. return false
  515. }
  516. // validateChannel 通用的渠道校验函数
  517. func validateChannel(channel *model.Channel, isAdd bool) error {
  518. // 校验 channel settings
  519. if err := channel.ValidateSettings(); err != nil {
  520. return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error())
  521. }
  522. // 如果是添加操作,检查 channel 和 key 是否为空
  523. if isAdd {
  524. if channel == nil || channel.Key == "" {
  525. return fmt.Errorf("channel cannot be empty")
  526. }
  527. // 检查模型名称长度是否超过 255
  528. for _, m := range channel.GetModels() {
  529. if len(m) > 255 {
  530. return fmt.Errorf("模型名称过长: %s", m)
  531. }
  532. }
  533. }
  534. // VertexAI 特殊校验
  535. if channel.Type == constant.ChannelTypeVertexAi {
  536. if channel.Other == "" {
  537. return fmt.Errorf("部署地区不能为空")
  538. }
  539. regionMap, err := common.StrToMap(channel.Other)
  540. if err != nil {
  541. return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
  542. }
  543. if regionMap["default"] == nil {
  544. return fmt.Errorf("部署地区必须包含default字段")
  545. }
  546. }
  547. // Codex OAuth key validation (optional, only when JSON object is provided)
  548. if channel.Type == constant.ChannelTypeCodex {
  549. trimmedKey := strings.TrimSpace(channel.Key)
  550. if isAdd || trimmedKey != "" {
  551. if !strings.HasPrefix(trimmedKey, "{") {
  552. return fmt.Errorf("Codex key must be a valid JSON object")
  553. }
  554. var keyMap map[string]any
  555. if err := common.Unmarshal([]byte(trimmedKey), &keyMap); err != nil {
  556. return fmt.Errorf("Codex key must be a valid JSON object")
  557. }
  558. if v, ok := keyMap["access_token"]; !ok || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == "" {
  559. return fmt.Errorf("Codex key JSON must include access_token")
  560. }
  561. if v, ok := keyMap["account_id"]; !ok || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == "" {
  562. return fmt.Errorf("Codex key JSON must include account_id")
  563. }
  564. }
  565. }
  566. return nil
  567. }
  568. func RefreshCodexChannelCredential(c *gin.Context) {
  569. channelId, err := strconv.Atoi(c.Param("id"))
  570. if err != nil {
  571. common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
  572. return
  573. }
  574. ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
  575. defer cancel()
  576. oauthKey, ch, err := service.RefreshCodexChannelCredential(ctx, channelId, service.CodexCredentialRefreshOptions{ResetCaches: true})
  577. if err != nil {
  578. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  579. return
  580. }
  581. c.JSON(http.StatusOK, gin.H{
  582. "success": true,
  583. "message": "refreshed",
  584. "data": gin.H{
  585. "expires_at": oauthKey.Expired,
  586. "last_refresh": oauthKey.LastRefresh,
  587. "account_id": oauthKey.AccountID,
  588. "email": oauthKey.Email,
  589. "channel_id": ch.Id,
  590. "channel_type": ch.Type,
  591. "channel_name": ch.Name,
  592. },
  593. })
  594. }
  595. type AddChannelRequest struct {
  596. Mode string `json:"mode"`
  597. MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
  598. BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"`
  599. Channel *model.Channel `json:"channel"`
  600. }
  601. func getVertexArrayKeys(keys string) ([]string, error) {
  602. if keys == "" {
  603. return nil, nil
  604. }
  605. var keyArray []interface{}
  606. err := common.Unmarshal([]byte(keys), &keyArray)
  607. if err != nil {
  608. return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
  609. }
  610. cleanKeys := make([]string, 0, len(keyArray))
  611. for _, key := range keyArray {
  612. var keyStr string
  613. switch v := key.(type) {
  614. case string:
  615. keyStr = strings.TrimSpace(v)
  616. default:
  617. bytes, err := json.Marshal(v)
  618. if err != nil {
  619. return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
  620. }
  621. keyStr = string(bytes)
  622. }
  623. if keyStr != "" {
  624. cleanKeys = append(cleanKeys, keyStr)
  625. }
  626. }
  627. if len(cleanKeys) == 0 {
  628. return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
  629. }
  630. return cleanKeys, nil
  631. }
  632. func AddChannel(c *gin.Context) {
  633. addChannelRequest := AddChannelRequest{}
  634. err := c.ShouldBindJSON(&addChannelRequest)
  635. if err != nil {
  636. common.ApiError(c, err)
  637. return
  638. }
  639. // 使用统一的校验函数
  640. if err := validateChannel(addChannelRequest.Channel, true); err != nil {
  641. c.JSON(http.StatusOK, gin.H{
  642. "success": false,
  643. "message": err.Error(),
  644. })
  645. return
  646. }
  647. addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
  648. keys := make([]string, 0)
  649. switch addChannelRequest.Mode {
  650. case "multi_to_single":
  651. addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
  652. addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
  653. if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
  654. array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
  655. if err != nil {
  656. c.JSON(http.StatusOK, gin.H{
  657. "success": false,
  658. "message": err.Error(),
  659. })
  660. return
  661. }
  662. addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
  663. addChannelRequest.Channel.Key = strings.Join(array, "\n")
  664. } else {
  665. cleanKeys := make([]string, 0)
  666. for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
  667. if key == "" {
  668. continue
  669. }
  670. key = strings.TrimSpace(key)
  671. cleanKeys = append(cleanKeys, key)
  672. }
  673. addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
  674. addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
  675. }
  676. keys = []string{addChannelRequest.Channel.Key}
  677. case "batch":
  678. if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
  679. // multi json
  680. keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
  681. if err != nil {
  682. c.JSON(http.StatusOK, gin.H{
  683. "success": false,
  684. "message": err.Error(),
  685. })
  686. return
  687. }
  688. } else {
  689. keys = strings.Split(addChannelRequest.Channel.Key, "\n")
  690. }
  691. case "single":
  692. keys = []string{addChannelRequest.Channel.Key}
  693. default:
  694. c.JSON(http.StatusOK, gin.H{
  695. "success": false,
  696. "message": "不支持的添加模式",
  697. })
  698. return
  699. }
  700. channels := make([]model.Channel, 0, len(keys))
  701. for _, key := range keys {
  702. if key == "" {
  703. continue
  704. }
  705. localChannel := addChannelRequest.Channel
  706. localChannel.Key = key
  707. if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 {
  708. keyPrefix := localChannel.Key
  709. if len(localChannel.Key) > 8 {
  710. keyPrefix = localChannel.Key[:8]
  711. }
  712. localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix)
  713. }
  714. channels = append(channels, *localChannel)
  715. }
  716. err = model.BatchInsertChannels(channels)
  717. if err != nil {
  718. common.ApiError(c, err)
  719. return
  720. }
  721. service.ResetProxyClientCache()
  722. c.JSON(http.StatusOK, gin.H{
  723. "success": true,
  724. "message": "",
  725. })
  726. return
  727. }
  728. func DeleteChannel(c *gin.Context) {
  729. id, _ := strconv.Atoi(c.Param("id"))
  730. channel := model.Channel{Id: id}
  731. err := channel.Delete()
  732. if err != nil {
  733. common.ApiError(c, err)
  734. return
  735. }
  736. model.InitChannelCache()
  737. c.JSON(http.StatusOK, gin.H{
  738. "success": true,
  739. "message": "",
  740. })
  741. return
  742. }
  743. func DeleteDisabledChannel(c *gin.Context) {
  744. rows, err := model.DeleteDisabledChannel()
  745. if err != nil {
  746. common.ApiError(c, err)
  747. return
  748. }
  749. model.InitChannelCache()
  750. c.JSON(http.StatusOK, gin.H{
  751. "success": true,
  752. "message": "",
  753. "data": rows,
  754. })
  755. return
  756. }
  // ChannelTag is the request body of the tag-wide channel endpoints
  // (DisableTagChannels, EnableTagChannels, EditTagChannels).
  type ChannelTag struct {
  	// Tag selects the channels to operate on; every handler requires it.
  	Tag string `json:"tag"`
  	// The optional fields below are only read by EditTagChannels in this file,
  	// which passes them to model.EditChannelByTag after validating the two
  	// overrides as JSON. A nil pointer presumably means "leave unchanged";
  	// confirm against EditChannelByTag.
  	NewTag         *string `json:"new_tag"`
  	Priority       *int64  `json:"priority"`
  	Weight         *uint   `json:"weight"`
  	ModelMapping   *string `json:"model_mapping"`
  	Models         *string `json:"models"`
  	Groups         *string `json:"groups"`
  	ParamOverride  *string `json:"param_override"`
  	HeaderOverride *string `json:"header_override"`
  }
  768. func DisableTagChannels(c *gin.Context) {
  769. channelTag := ChannelTag{}
  770. err := c.ShouldBindJSON(&channelTag)
  771. if err != nil || channelTag.Tag == "" {
  772. c.JSON(http.StatusOK, gin.H{
  773. "success": false,
  774. "message": "参数错误",
  775. })
  776. return
  777. }
  778. err = model.DisableChannelByTag(channelTag.Tag)
  779. if err != nil {
  780. common.ApiError(c, err)
  781. return
  782. }
  783. model.InitChannelCache()
  784. c.JSON(http.StatusOK, gin.H{
  785. "success": true,
  786. "message": "",
  787. })
  788. return
  789. }
  790. func EnableTagChannels(c *gin.Context) {
  791. channelTag := ChannelTag{}
  792. err := c.ShouldBindJSON(&channelTag)
  793. if err != nil || channelTag.Tag == "" {
  794. c.JSON(http.StatusOK, gin.H{
  795. "success": false,
  796. "message": "参数错误",
  797. })
  798. return
  799. }
  800. err = model.EnableChannelByTag(channelTag.Tag)
  801. if err != nil {
  802. common.ApiError(c, err)
  803. return
  804. }
  805. model.InitChannelCache()
  806. c.JSON(http.StatusOK, gin.H{
  807. "success": true,
  808. "message": "",
  809. })
  810. return
  811. }
  812. func EditTagChannels(c *gin.Context) {
  813. channelTag := ChannelTag{}
  814. err := c.ShouldBindJSON(&channelTag)
  815. if err != nil {
  816. c.JSON(http.StatusOK, gin.H{
  817. "success": false,
  818. "message": "参数错误",
  819. })
  820. return
  821. }
  822. if channelTag.Tag == "" {
  823. c.JSON(http.StatusOK, gin.H{
  824. "success": false,
  825. "message": "tag不能为空",
  826. })
  827. return
  828. }
  829. if channelTag.ParamOverride != nil {
  830. trimmed := strings.TrimSpace(*channelTag.ParamOverride)
  831. if trimmed != "" && !json.Valid([]byte(trimmed)) {
  832. c.JSON(http.StatusOK, gin.H{
  833. "success": false,
  834. "message": "参数覆盖必须是合法的 JSON 格式",
  835. })
  836. return
  837. }
  838. channelTag.ParamOverride = common.GetPointer[string](trimmed)
  839. }
  840. if channelTag.HeaderOverride != nil {
  841. trimmed := strings.TrimSpace(*channelTag.HeaderOverride)
  842. if trimmed != "" && !json.Valid([]byte(trimmed)) {
  843. c.JSON(http.StatusOK, gin.H{
  844. "success": false,
  845. "message": "请求头覆盖必须是合法的 JSON 格式",
  846. })
  847. return
  848. }
  849. channelTag.HeaderOverride = common.GetPointer[string](trimmed)
  850. }
  851. err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight, channelTag.ParamOverride, channelTag.HeaderOverride)
  852. if err != nil {
  853. common.ApiError(c, err)
  854. return
  855. }
  856. model.InitChannelCache()
  857. c.JSON(http.StatusOK, gin.H{
  858. "success": true,
  859. "message": "",
  860. })
  861. return
  862. }
  // ChannelBatch is the request body of batch channel operations.
  type ChannelBatch struct {
  	// Ids lists the target channel ids.
  	Ids []int `json:"ids"`
  	// Tag is optional. DeleteChannelBatch does not read it; presumably a
  	// batch tag-setting handler elsewhere does.
  	Tag *string `json:"tag"`
  }
  867. func DeleteChannelBatch(c *gin.Context) {
  868. channelBatch := ChannelBatch{}
  869. err := c.ShouldBindJSON(&channelBatch)
  870. if err != nil || len(channelBatch.Ids) == 0 {
  871. c.JSON(http.StatusOK, gin.H{
  872. "success": false,
  873. "message": "参数错误",
  874. })
  875. return
  876. }
  877. err = model.BatchDeleteChannels(channelBatch.Ids)
  878. if err != nil {
  879. common.ApiError(c, err)
  880. return
  881. }
  882. model.InitChannelCache()
  883. c.JSON(http.StatusOK, gin.H{
  884. "success": true,
  885. "message": "",
  886. "data": len(channelBatch.Ids),
  887. })
  888. return
  889. }
  890. type PatchChannel struct {
  891. model.Channel
  892. MultiKeyMode *string `json:"multi_key_mode"`
  893. KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
  894. }
  895. func UpdateChannel(c *gin.Context) {
  896. channel := PatchChannel{}
  897. err := c.ShouldBindJSON(&channel)
  898. if err != nil {
  899. common.ApiError(c, err)
  900. return
  901. }
  902. // 使用统一的校验函数
  903. if err := validateChannel(&channel.Channel, false); err != nil {
  904. c.JSON(http.StatusOK, gin.H{
  905. "success": false,
  906. "message": err.Error(),
  907. })
  908. return
  909. }
  910. // Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
  911. originChannel, err := model.GetChannelById(channel.Id, true)
  912. if err != nil {
  913. c.JSON(http.StatusOK, gin.H{
  914. "success": false,
  915. "message": err.Error(),
  916. })
  917. return
  918. }
  919. // Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
  920. channel.ChannelInfo = originChannel.ChannelInfo
  921. // If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
  922. if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
  923. channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
  924. }
  925. // 处理多key模式下的密钥追加/覆盖逻辑
  926. if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
  927. switch *channel.KeyMode {
  928. case "append":
  929. // 追加模式:将新密钥添加到现有密钥列表
  930. if originChannel.Key != "" {
  931. var newKeys []string
  932. var existingKeys []string
  933. // 解析现有密钥
  934. if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
  935. // JSON数组格式
  936. var arr []json.RawMessage
  937. if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
  938. existingKeys = make([]string, len(arr))
  939. for i, v := range arr {
  940. existingKeys[i] = string(v)
  941. }
  942. }
  943. } else {
  944. // 换行分隔格式
  945. existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
  946. }
  947. // 处理 Vertex AI 的特殊情况
  948. if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
  949. // 尝试解析新密钥为JSON数组
  950. if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
  951. array, err := getVertexArrayKeys(channel.Key)
  952. if err != nil {
  953. c.JSON(http.StatusOK, gin.H{
  954. "success": false,
  955. "message": "追加密钥解析失败: " + err.Error(),
  956. })
  957. return
  958. }
  959. newKeys = array
  960. } else {
  961. // 单个JSON密钥
  962. newKeys = []string{channel.Key}
  963. }
  964. } else {
  965. // 普通渠道的处理
  966. inputKeys := strings.Split(channel.Key, "\n")
  967. for _, key := range inputKeys {
  968. key = strings.TrimSpace(key)
  969. if key != "" {
  970. newKeys = append(newKeys, key)
  971. }
  972. }
  973. }
  974. seen := make(map[string]struct{}, len(existingKeys)+len(newKeys))
  975. for _, key := range existingKeys {
  976. normalized := strings.TrimSpace(key)
  977. if normalized == "" {
  978. continue
  979. }
  980. seen[normalized] = struct{}{}
  981. }
  982. dedupedNewKeys := make([]string, 0, len(newKeys))
  983. for _, key := range newKeys {
  984. normalized := strings.TrimSpace(key)
  985. if normalized == "" {
  986. continue
  987. }
  988. if _, ok := seen[normalized]; ok {
  989. continue
  990. }
  991. seen[normalized] = struct{}{}
  992. dedupedNewKeys = append(dedupedNewKeys, normalized)
  993. }
  994. allKeys := append(existingKeys, dedupedNewKeys...)
  995. channel.Key = strings.Join(allKeys, "\n")
  996. }
  997. case "replace":
  998. // 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
  999. }
  1000. }
  1001. err = channel.Update()
  1002. if err != nil {
  1003. common.ApiError(c, err)
  1004. return
  1005. }
  1006. model.InitChannelCache()
  1007. service.ResetProxyClientCache()
  1008. channel.Key = ""
  1009. clearChannelInfo(&channel.Channel)
  1010. c.JSON(http.StatusOK, gin.H{
  1011. "success": true,
  1012. "message": "",
  1013. "data": channel,
  1014. })
  1015. return
  1016. }
  1017. func FetchModels(c *gin.Context) {
  1018. var req struct {
  1019. BaseURL string `json:"base_url"`
  1020. Type int `json:"type"`
  1021. Key string `json:"key"`
  1022. }
  1023. if err := c.ShouldBindJSON(&req); err != nil {
  1024. c.JSON(http.StatusBadRequest, gin.H{
  1025. "success": false,
  1026. "message": "Invalid request",
  1027. })
  1028. return
  1029. }
  1030. baseURL := req.BaseURL
  1031. if baseURL == "" {
  1032. baseURL = constant.ChannelBaseURLs[req.Type]
  1033. }
  1034. // remove line breaks and extra spaces.
  1035. key := strings.TrimSpace(req.Key)
  1036. key = strings.Split(key, "\n")[0]
  1037. if req.Type == constant.ChannelTypeOllama {
  1038. models, err := ollama.FetchOllamaModels(baseURL, key)
  1039. if err != nil {
  1040. c.JSON(http.StatusOK, gin.H{
  1041. "success": false,
  1042. "message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()),
  1043. })
  1044. return
  1045. }
  1046. names := make([]string, 0, len(models))
  1047. for _, modelInfo := range models {
  1048. names = append(names, modelInfo.Name)
  1049. }
  1050. c.JSON(http.StatusOK, gin.H{
  1051. "success": true,
  1052. "data": names,
  1053. })
  1054. return
  1055. }
  1056. if req.Type == constant.ChannelTypeGemini {
  1057. models, err := gemini.FetchGeminiModels(baseURL, key, "")
  1058. if err != nil {
  1059. c.JSON(http.StatusOK, gin.H{
  1060. "success": false,
  1061. "message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()),
  1062. })
  1063. return
  1064. }
  1065. c.JSON(http.StatusOK, gin.H{
  1066. "success": true,
  1067. "data": models,
  1068. })
  1069. return
  1070. }
  1071. client := &http.Client{}
  1072. url := fmt.Sprintf("%s/v1/models", baseURL)
  1073. request, err := http.NewRequest("GET", url, nil)
  1074. if err != nil {
  1075. c.JSON(http.StatusInternalServerError, gin.H{
  1076. "success": false,
  1077. "message": err.Error(),
  1078. })
  1079. return
  1080. }
  1081. request.Header.Set("Authorization", "Bearer "+key)
  1082. response, err := client.Do(request)
  1083. if err != nil {
  1084. c.JSON(http.StatusInternalServerError, gin.H{
  1085. "success": false,
  1086. "message": err.Error(),
  1087. })
  1088. return
  1089. }
  1090. //check status code
  1091. if response.StatusCode != http.StatusOK {
  1092. c.JSON(http.StatusInternalServerError, gin.H{
  1093. "success": false,
  1094. "message": "Failed to fetch models",
  1095. })
  1096. return
  1097. }
  1098. defer response.Body.Close()
  1099. var result struct {
  1100. Data []struct {
  1101. ID string `json:"id"`
  1102. } `json:"data"`
  1103. }
  1104. if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
  1105. c.JSON(http.StatusInternalServerError, gin.H{
  1106. "success": false,
  1107. "message": err.Error(),
  1108. })
  1109. return
  1110. }
  1111. var models []string
  1112. for _, model := range result.Data {
  1113. models = append(models, model.ID)
  1114. }
  1115. c.JSON(http.StatusOK, gin.H{
  1116. "success": true,
  1117. "data": models,
  1118. })
  1119. }
  1120. func BatchSetChannelTag(c *gin.Context) {
  1121. channelBatch := ChannelBatch{}
  1122. err := c.ShouldBindJSON(&channelBatch)
  1123. if err != nil || len(channelBatch.Ids) == 0 {
  1124. c.JSON(http.StatusOK, gin.H{
  1125. "success": false,
  1126. "message": "参数错误",
  1127. })
  1128. return
  1129. }
  1130. err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
  1131. if err != nil {
  1132. common.ApiError(c, err)
  1133. return
  1134. }
  1135. model.InitChannelCache()
  1136. c.JSON(http.StatusOK, gin.H{
  1137. "success": true,
  1138. "message": "",
  1139. "data": len(channelBatch.Ids),
  1140. })
  1141. return
  1142. }
  1143. func GetTagModels(c *gin.Context) {
  1144. tag := c.Query("tag")
  1145. if tag == "" {
  1146. c.JSON(http.StatusBadRequest, gin.H{
  1147. "success": false,
  1148. "message": "tag不能为空",
  1149. })
  1150. return
  1151. }
  1152. channels, err := model.GetChannelsByTag(tag, false, false) // idSort=false, selectAll=false
  1153. if err != nil {
  1154. c.JSON(http.StatusInternalServerError, gin.H{
  1155. "success": false,
  1156. "message": err.Error(),
  1157. })
  1158. return
  1159. }
  1160. var longestModels string
  1161. maxLength := 0
  1162. // Find the longest models string among all channels with the given tag
  1163. for _, channel := range channels {
  1164. if channel.Models != "" {
  1165. currentModels := strings.Split(channel.Models, ",")
  1166. if len(currentModels) > maxLength {
  1167. maxLength = len(currentModels)
  1168. longestModels = channel.Models
  1169. }
  1170. }
  1171. }
  1172. c.JSON(http.StatusOK, gin.H{
  1173. "success": true,
  1174. "message": "",
  1175. "data": longestModels,
  1176. })
  1177. return
  1178. }
  1179. // CopyChannel handles cloning an existing channel with its key.
  1180. // POST /api/channel/copy/:id
  1181. // Optional query params:
  1182. //
  1183. // suffix - string appended to the original name (default "_复制")
  1184. // reset_balance - bool, when true will reset balance & used_quota to 0 (default true)
  1185. func CopyChannel(c *gin.Context) {
  1186. id, err := strconv.Atoi(c.Param("id"))
  1187. if err != nil {
  1188. c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
  1189. return
  1190. }
  1191. suffix := c.DefaultQuery("suffix", "_复制")
  1192. resetBalance := true
  1193. if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
  1194. if v, err := strconv.ParseBool(rbStr); err == nil {
  1195. resetBalance = v
  1196. }
  1197. }
  1198. // fetch original channel with key
  1199. origin, err := model.GetChannelById(id, true)
  1200. if err != nil {
  1201. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  1202. return
  1203. }
  1204. // clone channel
  1205. clone := *origin // shallow copy is sufficient as we will overwrite primitives
  1206. clone.Id = 0 // let DB auto-generate
  1207. clone.CreatedTime = common.GetTimestamp()
  1208. clone.Name = origin.Name + suffix
  1209. clone.TestTime = 0
  1210. clone.ResponseTime = 0
  1211. if resetBalance {
  1212. clone.Balance = 0
  1213. clone.UsedQuota = 0
  1214. }
  1215. // insert
  1216. if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
  1217. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  1218. return
  1219. }
  1220. model.InitChannelCache()
  1221. // success
  1222. c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
  1223. }
  1224. // MultiKeyManageRequest represents the request for multi-key management operations
  1225. type MultiKeyManageRequest struct {
  1226. ChannelId int `json:"channel_id"`
  1227. Action string `json:"action"` // "disable_key", "enable_key", "delete_key", "delete_disabled_keys", "get_key_status"
  1228. KeyIndex *int `json:"key_index,omitempty"` // for disable_key, enable_key, and delete_key actions
  1229. Page int `json:"page,omitempty"` // for get_key_status pagination
  1230. PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
  1231. Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
  1232. }
  1233. // MultiKeyStatusResponse represents the response for key status query
  1234. type MultiKeyStatusResponse struct {
  1235. Keys []KeyStatus `json:"keys"`
  1236. Total int `json:"total"`
  1237. Page int `json:"page"`
  1238. PageSize int `json:"page_size"`
  1239. TotalPages int `json:"total_pages"`
  1240. // Statistics
  1241. EnabledCount int `json:"enabled_count"`
  1242. ManualDisabledCount int `json:"manual_disabled_count"`
  1243. AutoDisabledCount int `json:"auto_disabled_count"`
  1244. }
  1245. type KeyStatus struct {
  1246. Index int `json:"index"`
  1247. Status int `json:"status"` // 1: enabled, 2: disabled
  1248. DisabledTime int64 `json:"disabled_time,omitempty"`
  1249. Reason string `json:"reason,omitempty"`
  1250. KeyPreview string `json:"key_preview"` // first 10 chars of key for identification
  1251. }
  1252. // ManageMultiKeys handles multi-key management operations
  1253. func ManageMultiKeys(c *gin.Context) {
  1254. request := MultiKeyManageRequest{}
  1255. err := c.ShouldBindJSON(&request)
  1256. if err != nil {
  1257. common.ApiError(c, err)
  1258. return
  1259. }
  1260. channel, err := model.GetChannelById(request.ChannelId, true)
  1261. if err != nil {
  1262. c.JSON(http.StatusOK, gin.H{
  1263. "success": false,
  1264. "message": "渠道不存在",
  1265. })
  1266. return
  1267. }
  1268. if !channel.ChannelInfo.IsMultiKey {
  1269. c.JSON(http.StatusOK, gin.H{
  1270. "success": false,
  1271. "message": "该渠道不是多密钥模式",
  1272. })
  1273. return
  1274. }
  1275. lock := model.GetChannelPollingLock(channel.Id)
  1276. lock.Lock()
  1277. defer lock.Unlock()
  1278. switch request.Action {
  1279. case "get_key_status":
  1280. keys := channel.GetKeys()
  1281. // Default pagination parameters
  1282. page := request.Page
  1283. pageSize := request.PageSize
  1284. if page <= 0 {
  1285. page = 1
  1286. }
  1287. if pageSize <= 0 {
  1288. pageSize = 50 // Default page size
  1289. }
  1290. // Statistics for all keys (unchanged by filtering)
  1291. var enabledCount, manualDisabledCount, autoDisabledCount int
  1292. // Build all key status data first
  1293. var allKeyStatusList []KeyStatus
  1294. for i, key := range keys {
  1295. status := 1 // default enabled
  1296. var disabledTime int64
  1297. var reason string
  1298. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1299. if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
  1300. status = s
  1301. }
  1302. }
  1303. // Count for statistics (all keys)
  1304. switch status {
  1305. case 1:
  1306. enabledCount++
  1307. case 2:
  1308. manualDisabledCount++
  1309. case 3:
  1310. autoDisabledCount++
  1311. }
  1312. if status != 1 {
  1313. if channel.ChannelInfo.MultiKeyDisabledTime != nil {
  1314. disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
  1315. }
  1316. if channel.ChannelInfo.MultiKeyDisabledReason != nil {
  1317. reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
  1318. }
  1319. }
  1320. // Create key preview (first 10 chars)
  1321. keyPreview := key
  1322. if len(key) > 10 {
  1323. keyPreview = key[:10] + "..."
  1324. }
  1325. allKeyStatusList = append(allKeyStatusList, KeyStatus{
  1326. Index: i,
  1327. Status: status,
  1328. DisabledTime: disabledTime,
  1329. Reason: reason,
  1330. KeyPreview: keyPreview,
  1331. })
  1332. }
  1333. // Apply status filter if specified
  1334. var filteredKeyStatusList []KeyStatus
  1335. if request.Status != nil {
  1336. for _, keyStatus := range allKeyStatusList {
  1337. if keyStatus.Status == *request.Status {
  1338. filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
  1339. }
  1340. }
  1341. } else {
  1342. filteredKeyStatusList = allKeyStatusList
  1343. }
  1344. // Calculate pagination based on filtered results
  1345. filteredTotal := len(filteredKeyStatusList)
  1346. totalPages := (filteredTotal + pageSize - 1) / pageSize
  1347. if totalPages == 0 {
  1348. totalPages = 1
  1349. }
  1350. if page > totalPages {
  1351. page = totalPages
  1352. }
  1353. // Calculate range for current page
  1354. start := (page - 1) * pageSize
  1355. end := start + pageSize
  1356. if end > filteredTotal {
  1357. end = filteredTotal
  1358. }
  1359. // Get the page data
  1360. var pageKeyStatusList []KeyStatus
  1361. if start < filteredTotal {
  1362. pageKeyStatusList = filteredKeyStatusList[start:end]
  1363. }
  1364. c.JSON(http.StatusOK, gin.H{
  1365. "success": true,
  1366. "message": "",
  1367. "data": MultiKeyStatusResponse{
  1368. Keys: pageKeyStatusList,
  1369. Total: filteredTotal, // Total of filtered results
  1370. Page: page,
  1371. PageSize: pageSize,
  1372. TotalPages: totalPages,
  1373. EnabledCount: enabledCount, // Overall statistics
  1374. ManualDisabledCount: manualDisabledCount, // Overall statistics
  1375. AutoDisabledCount: autoDisabledCount, // Overall statistics
  1376. },
  1377. })
  1378. return
  1379. case "disable_key":
  1380. if request.KeyIndex == nil {
  1381. c.JSON(http.StatusOK, gin.H{
  1382. "success": false,
  1383. "message": "未指定要禁用的密钥索引",
  1384. })
  1385. return
  1386. }
  1387. keyIndex := *request.KeyIndex
  1388. if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
  1389. c.JSON(http.StatusOK, gin.H{
  1390. "success": false,
  1391. "message": "密钥索引超出范围",
  1392. })
  1393. return
  1394. }
  1395. if channel.ChannelInfo.MultiKeyStatusList == nil {
  1396. channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
  1397. }
  1398. if channel.ChannelInfo.MultiKeyDisabledTime == nil {
  1399. channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
  1400. }
  1401. if channel.ChannelInfo.MultiKeyDisabledReason == nil {
  1402. channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
  1403. }
  1404. channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
  1405. err = channel.Update()
  1406. if err != nil {
  1407. common.ApiError(c, err)
  1408. return
  1409. }
  1410. model.InitChannelCache()
  1411. c.JSON(http.StatusOK, gin.H{
  1412. "success": true,
  1413. "message": "密钥已禁用",
  1414. })
  1415. return
  1416. case "enable_key":
  1417. if request.KeyIndex == nil {
  1418. c.JSON(http.StatusOK, gin.H{
  1419. "success": false,
  1420. "message": "未指定要启用的密钥索引",
  1421. })
  1422. return
  1423. }
  1424. keyIndex := *request.KeyIndex
  1425. if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
  1426. c.JSON(http.StatusOK, gin.H{
  1427. "success": false,
  1428. "message": "密钥索引超出范围",
  1429. })
  1430. return
  1431. }
  1432. // 从状态列表中删除该密钥的记录,使其回到默认启用状态
  1433. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1434. delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
  1435. }
  1436. if channel.ChannelInfo.MultiKeyDisabledTime != nil {
  1437. delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
  1438. }
  1439. if channel.ChannelInfo.MultiKeyDisabledReason != nil {
  1440. delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
  1441. }
  1442. err = channel.Update()
  1443. if err != nil {
  1444. common.ApiError(c, err)
  1445. return
  1446. }
  1447. model.InitChannelCache()
  1448. c.JSON(http.StatusOK, gin.H{
  1449. "success": true,
  1450. "message": "密钥已启用",
  1451. })
  1452. return
  1453. case "enable_all_keys":
  1454. // 清空所有禁用状态,使所有密钥回到默认启用状态
  1455. var enabledCount int
  1456. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1457. enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
  1458. }
  1459. channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
  1460. channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
  1461. channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
  1462. err = channel.Update()
  1463. if err != nil {
  1464. common.ApiError(c, err)
  1465. return
  1466. }
  1467. model.InitChannelCache()
  1468. c.JSON(http.StatusOK, gin.H{
  1469. "success": true,
  1470. "message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
  1471. })
  1472. return
  1473. case "disable_all_keys":
  1474. // 禁用所有启用的密钥
  1475. if channel.ChannelInfo.MultiKeyStatusList == nil {
  1476. channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
  1477. }
  1478. if channel.ChannelInfo.MultiKeyDisabledTime == nil {
  1479. channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
  1480. }
  1481. if channel.ChannelInfo.MultiKeyDisabledReason == nil {
  1482. channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
  1483. }
  1484. var disabledCount int
  1485. for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
  1486. status := 1 // default enabled
  1487. if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
  1488. status = s
  1489. }
  1490. // 只禁用当前启用的密钥
  1491. if status == 1 {
  1492. channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
  1493. disabledCount++
  1494. }
  1495. }
  1496. if disabledCount == 0 {
  1497. c.JSON(http.StatusOK, gin.H{
  1498. "success": false,
  1499. "message": "没有可禁用的密钥",
  1500. })
  1501. return
  1502. }
  1503. err = channel.Update()
  1504. if err != nil {
  1505. common.ApiError(c, err)
  1506. return
  1507. }
  1508. model.InitChannelCache()
  1509. c.JSON(http.StatusOK, gin.H{
  1510. "success": true,
  1511. "message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
  1512. })
  1513. return
  1514. case "delete_key":
  1515. if request.KeyIndex == nil {
  1516. c.JSON(http.StatusOK, gin.H{
  1517. "success": false,
  1518. "message": "未指定要删除的密钥索引",
  1519. })
  1520. return
  1521. }
  1522. keyIndex := *request.KeyIndex
  1523. if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
  1524. c.JSON(http.StatusOK, gin.H{
  1525. "success": false,
  1526. "message": "密钥索引超出范围",
  1527. })
  1528. return
  1529. }
  1530. keys := channel.GetKeys()
  1531. var remainingKeys []string
  1532. var newStatusList = make(map[int]int)
  1533. var newDisabledTime = make(map[int]int64)
  1534. var newDisabledReason = make(map[int]string)
  1535. newIndex := 0
  1536. for i, key := range keys {
  1537. // 跳过要删除的密钥
  1538. if i == keyIndex {
  1539. continue
  1540. }
  1541. remainingKeys = append(remainingKeys, key)
  1542. // 保留其他密钥的状态信息,重新索引
  1543. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1544. if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 {
  1545. newStatusList[newIndex] = status
  1546. }
  1547. }
  1548. if channel.ChannelInfo.MultiKeyDisabledTime != nil {
  1549. if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
  1550. newDisabledTime[newIndex] = t
  1551. }
  1552. }
  1553. if channel.ChannelInfo.MultiKeyDisabledReason != nil {
  1554. if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
  1555. newDisabledReason[newIndex] = r
  1556. }
  1557. }
  1558. newIndex++
  1559. }
  1560. if len(remainingKeys) == 0 {
  1561. c.JSON(http.StatusOK, gin.H{
  1562. "success": false,
  1563. "message": "不能删除最后一个密钥",
  1564. })
  1565. return
  1566. }
  1567. // Update channel with remaining keys
  1568. channel.Key = strings.Join(remainingKeys, "\n")
  1569. channel.ChannelInfo.MultiKeySize = len(remainingKeys)
  1570. channel.ChannelInfo.MultiKeyStatusList = newStatusList
  1571. channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
  1572. channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
  1573. err = channel.Update()
  1574. if err != nil {
  1575. common.ApiError(c, err)
  1576. return
  1577. }
  1578. model.InitChannelCache()
  1579. c.JSON(http.StatusOK, gin.H{
  1580. "success": true,
  1581. "message": "密钥已删除",
  1582. })
  1583. return
  1584. case "delete_disabled_keys":
  1585. keys := channel.GetKeys()
  1586. var remainingKeys []string
  1587. var deletedCount int
  1588. var newStatusList = make(map[int]int)
  1589. var newDisabledTime = make(map[int]int64)
  1590. var newDisabledReason = make(map[int]string)
  1591. newIndex := 0
  1592. for i, key := range keys {
  1593. status := 1 // default enabled
  1594. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1595. if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
  1596. status = s
  1597. }
  1598. }
  1599. // 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥
  1600. if status == 3 {
  1601. deletedCount++
  1602. } else {
  1603. remainingKeys = append(remainingKeys, key)
  1604. // 保留非自动禁用密钥的状态信息,重新索引
  1605. if status != 1 {
  1606. newStatusList[newIndex] = status
  1607. if channel.ChannelInfo.MultiKeyDisabledTime != nil {
  1608. if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
  1609. newDisabledTime[newIndex] = t
  1610. }
  1611. }
  1612. if channel.ChannelInfo.MultiKeyDisabledReason != nil {
  1613. if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
  1614. newDisabledReason[newIndex] = r
  1615. }
  1616. }
  1617. }
  1618. newIndex++
  1619. }
  1620. }
  1621. if deletedCount == 0 {
  1622. c.JSON(http.StatusOK, gin.H{
  1623. "success": false,
  1624. "message": "没有需要删除的自动禁用密钥",
  1625. })
  1626. return
  1627. }
  1628. // Update channel with remaining keys
  1629. channel.Key = strings.Join(remainingKeys, "\n")
  1630. channel.ChannelInfo.MultiKeySize = len(remainingKeys)
  1631. channel.ChannelInfo.MultiKeyStatusList = newStatusList
  1632. channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
  1633. channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
  1634. err = channel.Update()
  1635. if err != nil {
  1636. common.ApiError(c, err)
  1637. return
  1638. }
  1639. model.InitChannelCache()
  1640. c.JSON(http.StatusOK, gin.H{
  1641. "success": true,
  1642. "message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
  1643. "data": deletedCount,
  1644. })
  1645. return
  1646. default:
  1647. c.JSON(http.StatusOK, gin.H{
  1648. "success": false,
  1649. "message": "不支持的操作",
  1650. })
  1651. return
  1652. }
  1653. }
  1654. // OllamaPullModel 拉取 Ollama 模型
  1655. func OllamaPullModel(c *gin.Context) {
  1656. var req struct {
  1657. ChannelID int `json:"channel_id"`
  1658. ModelName string `json:"model_name"`
  1659. }
  1660. if err := c.ShouldBindJSON(&req); err != nil {
  1661. c.JSON(http.StatusBadRequest, gin.H{
  1662. "success": false,
  1663. "message": "Invalid request parameters",
  1664. })
  1665. return
  1666. }
  1667. if req.ChannelID == 0 || req.ModelName == "" {
  1668. c.JSON(http.StatusBadRequest, gin.H{
  1669. "success": false,
  1670. "message": "Channel ID and model name are required",
  1671. })
  1672. return
  1673. }
  1674. // 获取渠道信息
  1675. channel, err := model.GetChannelById(req.ChannelID, true)
  1676. if err != nil {
  1677. c.JSON(http.StatusNotFound, gin.H{
  1678. "success": false,
  1679. "message": "Channel not found",
  1680. })
  1681. return
  1682. }
  1683. // 检查是否是 Ollama 渠道
  1684. if channel.Type != constant.ChannelTypeOllama {
  1685. c.JSON(http.StatusBadRequest, gin.H{
  1686. "success": false,
  1687. "message": "This operation is only supported for Ollama channels",
  1688. })
  1689. return
  1690. }
  1691. baseURL := constant.ChannelBaseURLs[channel.Type]
  1692. if channel.GetBaseURL() != "" {
  1693. baseURL = channel.GetBaseURL()
  1694. }
  1695. key := strings.Split(channel.Key, "\n")[0]
  1696. err = ollama.PullOllamaModel(baseURL, key, req.ModelName)
  1697. if err != nil {
  1698. c.JSON(http.StatusInternalServerError, gin.H{
  1699. "success": false,
  1700. "message": fmt.Sprintf("Failed to pull model: %s", err.Error()),
  1701. })
  1702. return
  1703. }
  1704. c.JSON(http.StatusOK, gin.H{
  1705. "success": true,
  1706. "message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
  1707. })
  1708. }
  1709. // OllamaPullModelStream 流式拉取 Ollama 模型
  1710. func OllamaPullModelStream(c *gin.Context) {
  1711. var req struct {
  1712. ChannelID int `json:"channel_id"`
  1713. ModelName string `json:"model_name"`
  1714. }
  1715. if err := c.ShouldBindJSON(&req); err != nil {
  1716. c.JSON(http.StatusBadRequest, gin.H{
  1717. "success": false,
  1718. "message": "Invalid request parameters",
  1719. })
  1720. return
  1721. }
  1722. if req.ChannelID == 0 || req.ModelName == "" {
  1723. c.JSON(http.StatusBadRequest, gin.H{
  1724. "success": false,
  1725. "message": "Channel ID and model name are required",
  1726. })
  1727. return
  1728. }
  1729. // 获取渠道信息
  1730. channel, err := model.GetChannelById(req.ChannelID, true)
  1731. if err != nil {
  1732. c.JSON(http.StatusNotFound, gin.H{
  1733. "success": false,
  1734. "message": "Channel not found",
  1735. })
  1736. return
  1737. }
  1738. // 检查是否是 Ollama 渠道
  1739. if channel.Type != constant.ChannelTypeOllama {
  1740. c.JSON(http.StatusBadRequest, gin.H{
  1741. "success": false,
  1742. "message": "This operation is only supported for Ollama channels",
  1743. })
  1744. return
  1745. }
  1746. baseURL := constant.ChannelBaseURLs[channel.Type]
  1747. if channel.GetBaseURL() != "" {
  1748. baseURL = channel.GetBaseURL()
  1749. }
  1750. // 设置 SSE 头部
  1751. c.Header("Content-Type", "text/event-stream")
  1752. c.Header("Cache-Control", "no-cache")
  1753. c.Header("Connection", "keep-alive")
  1754. c.Header("Access-Control-Allow-Origin", "*")
  1755. key := strings.Split(channel.Key, "\n")[0]
  1756. // 创建进度回调函数
  1757. progressCallback := func(progress ollama.OllamaPullResponse) {
  1758. data, _ := json.Marshal(progress)
  1759. fmt.Fprintf(c.Writer, "data: %s\n\n", string(data))
  1760. c.Writer.Flush()
  1761. }
  1762. // 执行拉取
  1763. err = ollama.PullOllamaModelStream(baseURL, key, req.ModelName, progressCallback)
  1764. if err != nil {
  1765. errorData, _ := json.Marshal(gin.H{
  1766. "error": err.Error(),
  1767. })
  1768. fmt.Fprintf(c.Writer, "data: %s\n\n", string(errorData))
  1769. } else {
  1770. successData, _ := json.Marshal(gin.H{
  1771. "message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
  1772. })
  1773. fmt.Fprintf(c.Writer, "data: %s\n\n", string(successData))
  1774. }
  1775. // 发送结束标志
  1776. fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
  1777. c.Writer.Flush()
  1778. }
  1779. // OllamaDeleteModel 删除 Ollama 模型
  1780. func OllamaDeleteModel(c *gin.Context) {
  1781. var req struct {
  1782. ChannelID int `json:"channel_id"`
  1783. ModelName string `json:"model_name"`
  1784. }
  1785. if err := c.ShouldBindJSON(&req); err != nil {
  1786. c.JSON(http.StatusBadRequest, gin.H{
  1787. "success": false,
  1788. "message": "Invalid request parameters",
  1789. })
  1790. return
  1791. }
  1792. if req.ChannelID == 0 || req.ModelName == "" {
  1793. c.JSON(http.StatusBadRequest, gin.H{
  1794. "success": false,
  1795. "message": "Channel ID and model name are required",
  1796. })
  1797. return
  1798. }
  1799. // 获取渠道信息
  1800. channel, err := model.GetChannelById(req.ChannelID, true)
  1801. if err != nil {
  1802. c.JSON(http.StatusNotFound, gin.H{
  1803. "success": false,
  1804. "message": "Channel not found",
  1805. })
  1806. return
  1807. }
  1808. // 检查是否是 Ollama 渠道
  1809. if channel.Type != constant.ChannelTypeOllama {
  1810. c.JSON(http.StatusBadRequest, gin.H{
  1811. "success": false,
  1812. "message": "This operation is only supported for Ollama channels",
  1813. })
  1814. return
  1815. }
  1816. baseURL := constant.ChannelBaseURLs[channel.Type]
  1817. if channel.GetBaseURL() != "" {
  1818. baseURL = channel.GetBaseURL()
  1819. }
  1820. key := strings.Split(channel.Key, "\n")[0]
  1821. err = ollama.DeleteOllamaModel(baseURL, key, req.ModelName)
  1822. if err != nil {
  1823. c.JSON(http.StatusInternalServerError, gin.H{
  1824. "success": false,
  1825. "message": fmt.Sprintf("Failed to delete model: %s", err.Error()),
  1826. })
  1827. return
  1828. }
  1829. c.JSON(http.StatusOK, gin.H{
  1830. "success": true,
  1831. "message": fmt.Sprintf("Model %s deleted successfully", req.ModelName),
  1832. })
  1833. }
  1834. // OllamaVersion 获取 Ollama 服务版本信息
  1835. func OllamaVersion(c *gin.Context) {
  1836. id, err := strconv.Atoi(c.Param("id"))
  1837. if err != nil {
  1838. c.JSON(http.StatusBadRequest, gin.H{
  1839. "success": false,
  1840. "message": "Invalid channel id",
  1841. })
  1842. return
  1843. }
  1844. channel, err := model.GetChannelById(id, true)
  1845. if err != nil {
  1846. c.JSON(http.StatusNotFound, gin.H{
  1847. "success": false,
  1848. "message": "Channel not found",
  1849. })
  1850. return
  1851. }
  1852. if channel.Type != constant.ChannelTypeOllama {
  1853. c.JSON(http.StatusBadRequest, gin.H{
  1854. "success": false,
  1855. "message": "This operation is only supported for Ollama channels",
  1856. })
  1857. return
  1858. }
  1859. baseURL := constant.ChannelBaseURLs[channel.Type]
  1860. if channel.GetBaseURL() != "" {
  1861. baseURL = channel.GetBaseURL()
  1862. }
  1863. key := strings.Split(channel.Key, "\n")[0]
  1864. version, err := ollama.FetchOllamaVersion(baseURL, key)
  1865. if err != nil {
  1866. c.JSON(http.StatusOK, gin.H{
  1867. "success": false,
  1868. "message": fmt.Sprintf("获取Ollama版本失败: %s", err.Error()),
  1869. })
  1870. return
  1871. }
  1872. c.JSON(http.StatusOK, gin.H{
  1873. "success": true,
  1874. "data": gin.H{
  1875. "version": version,
  1876. },
  1877. })
  1878. }