channel_import.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package controller
  import (
	"fmt"
	"net/http"
	"strconv"
	"strings"

	"one-api/common"
	"one-api/model"

	"github.com/gin-gonic/gin"
	"github.com/xuri/excelize/v2"
)
  12. func ImportChannels(c *gin.Context) {
  13. // 从请求中获取上传的文件
  14. file, err := c.FormFile("file")
  15. if err != nil {
  16. common.ApiError(c, fmt.Errorf("failed to get file: %w", err))
  17. return
  18. }
  19. // 打开文件
  20. src, err := file.Open()
  21. if err != nil {
  22. common.ApiError(c, fmt.Errorf("failed to open file: %w", err))
  23. return
  24. }
  25. defer src.Close()
  26. // 读取Excel文件
  27. f, err := excelize.OpenReader(src)
  28. if err != nil {
  29. common.ApiError(c, fmt.Errorf("failed to open Excel file: %w", err))
  30. return
  31. }
  32. defer func() {
  33. if err := f.Close(); err != nil {
  34. common.SysError("Error closing Excel file: " + err.Error())
  35. }
  36. }()
  37. // 获取第一个工作表
  38. sheetName := f.GetSheetName(0)
  39. if sheetName == "" {
  40. common.ApiError(c, fmt.Errorf("no sheets found in Excel file"))
  41. return
  42. }
  43. // 读取所有行
  44. rows, err := f.GetRows(sheetName)
  45. if err != nil {
  46. common.ApiError(c, fmt.Errorf("failed to read rows: %w", err))
  47. return
  48. }
  49. // 检查是否有数据
  50. if len(rows) <= 1 {
  51. common.ApiError(c, fmt.Errorf("no data found in Excel file"))
  52. return
  53. }
  54. // 解析表头
  55. headers := rows[0]
  56. headerMap := make(map[string]int)
  57. for i, header := range headers {
  58. headerMap[header] = i
  59. }
  60. // 解析数据行
  61. channels := make([]model.Channel, 0, len(rows)-1)
  62. for i := 1; i < len(rows); i++ {
  63. row := rows[i]
  64. if len(row) == 0 {
  65. continue
  66. }
  67. // 创建渠道对象
  68. channel := model.Channel{}
  69. // 根据表头映射填充数据
  70. for header, colIndex := range headerMap {
  71. if colIndex >= len(row) {
  72. continue
  73. }
  74. value := row[colIndex]
  75. switch header {
  76. case "ID":
  77. // ID由数据库自动生成,不需要设置
  78. case "名称":
  79. channel.Name = value
  80. case "类型":
  81. if v, err := strconv.Atoi(value); err == nil {
  82. channel.Type = v
  83. }
  84. case "状态":
  85. if v, err := strconv.Atoi(value); err == nil {
  86. channel.Status = v
  87. }
  88. case "密钥":
  89. channel.Key = value
  90. case "组织":
  91. if value != "" {
  92. channel.OpenAIOrganization = &value
  93. }
  94. case "测试模型":
  95. if value != "" {
  96. channel.TestModel = &value
  97. }
  98. case "权重":
  99. if value != "" {
  100. if v, err := strconv.ParseUint(value, 10, 32); err == nil {
  101. weight := uint(v)
  102. channel.Weight = &weight
  103. }
  104. }
  105. case "创建时间":
  106. if v, err := strconv.ParseInt(value, 10, 64); err == nil {
  107. channel.CreatedTime = v
  108. }
  109. case "测试时间":
  110. if v, err := strconv.ParseInt(value, 10, 64); err == nil {
  111. channel.TestTime = v
  112. }
  113. case "响应时间":
  114. if v, err := strconv.Atoi(value); err == nil {
  115. channel.ResponseTime = v
  116. }
  117. case "基础URL":
  118. if value != "" {
  119. channel.BaseURL = &value
  120. }
  121. case "其他":
  122. channel.Other = value
  123. case "余额":
  124. if v, err := strconv.ParseFloat(value, 64); err == nil {
  125. channel.Balance = v
  126. }
  127. case "余额更新时间":
  128. if v, err := strconv.ParseInt(value, 10, 64); err == nil {
  129. channel.BalanceUpdatedTime = v
  130. }
  131. case "模型":
  132. channel.Models = value
  133. case "分组":
  134. channel.Group = value
  135. case "已用配额":
  136. if v, err := strconv.ParseInt(value, 10, 64); err == nil {
  137. channel.UsedQuota = v
  138. }
  139. case "模型映射":
  140. if value != "" {
  141. channel.ModelMapping = &value
  142. }
  143. case "状态码映射":
  144. if value != "" {
  145. channel.StatusCodeMapping = &value
  146. }
  147. case "优先级":
  148. if value != "" {
  149. if v, err := strconv.ParseInt(value, 10, 64); err == nil {
  150. channel.Priority = &v
  151. }
  152. }
  153. case "自动禁用":
  154. if value != "" {
  155. if v, err := strconv.Atoi(value); err == nil {
  156. channel.AutoBan = &v
  157. }
  158. }
  159. case "标签":
  160. if value != "" {
  161. channel.Tag = &value
  162. }
  163. case "额外设置":
  164. if value != "" {
  165. channel.Setting = &value
  166. }
  167. case "参数覆盖":
  168. if value != "" {
  169. channel.ParamOverride = &value
  170. }
  171. }
  172. }
  173. // 设置默认值
  174. if channel.CreatedTime == 0 {
  175. channel.CreatedTime = common.GetTimestamp()
  176. }
  177. // 如果没有设置状态,默认为启用
  178. if channel.Status == 0 {
  179. channel.Status = 1
  180. }
  181. channels = append(channels, channel)
  182. }
  183. // 批量插入渠道
  184. err = model.BatchInsertChannels(channels)
  185. if err != nil {
  186. common.ApiError(c, fmt.Errorf("failed to insert channels: %w", err))
  187. return
  188. }
  189. // 初始化渠道缓存
  190. model.InitChannelCache()
  191. // 返回成功响应
  192. c.JSON(http.StatusOK, gin.H{
  193. "success": true,
  194. "message": fmt.Sprintf("成功导入 %d 个渠道", len(channels)),
  195. })
  196. }