| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216 |
- package controller
import (
	"fmt"
	"net/http"
	"strconv"
	"strings"

	"one-api/common"
	"one-api/model"

	"github.com/gin-gonic/gin"
	"github.com/xuri/excelize/v2"
)
- func ImportChannels(c *gin.Context) {
- // 从请求中获取上传的文件
- file, err := c.FormFile("file")
- if err != nil {
- common.ApiError(c, fmt.Errorf("failed to get file: %w", err))
- return
- }
- // 打开文件
- src, err := file.Open()
- if err != nil {
- common.ApiError(c, fmt.Errorf("failed to open file: %w", err))
- return
- }
- defer src.Close()
- // 读取Excel文件
- f, err := excelize.OpenReader(src)
- if err != nil {
- common.ApiError(c, fmt.Errorf("failed to open Excel file: %w", err))
- return
- }
- defer func() {
- if err := f.Close(); err != nil {
- common.SysError("Error closing Excel file: " + err.Error())
- }
- }()
- // 获取第一个工作表
- sheetName := f.GetSheetName(0)
- if sheetName == "" {
- common.ApiError(c, fmt.Errorf("no sheets found in Excel file"))
- return
- }
- // 读取所有行
- rows, err := f.GetRows(sheetName)
- if err != nil {
- common.ApiError(c, fmt.Errorf("failed to read rows: %w", err))
- return
- }
- // 检查是否有数据
- if len(rows) <= 1 {
- common.ApiError(c, fmt.Errorf("no data found in Excel file"))
- return
- }
- // 解析表头
- headers := rows[0]
- headerMap := make(map[string]int)
- for i, header := range headers {
- headerMap[header] = i
- }
- // 解析数据行
- channels := make([]model.Channel, 0, len(rows)-1)
- for i := 1; i < len(rows); i++ {
- row := rows[i]
- if len(row) == 0 {
- continue
- }
- // 创建渠道对象
- channel := model.Channel{}
- // 根据表头映射填充数据
- for header, colIndex := range headerMap {
- if colIndex >= len(row) {
- continue
- }
- value := row[colIndex]
- switch header {
- case "ID":
- // ID由数据库自动生成,不需要设置
- case "名称":
- channel.Name = value
- case "类型":
- if v, err := strconv.Atoi(value); err == nil {
- channel.Type = v
- }
- case "状态":
- if v, err := strconv.Atoi(value); err == nil {
- channel.Status = v
- }
- case "密钥":
- channel.Key = value
- case "组织":
- if value != "" {
- channel.OpenAIOrganization = &value
- }
- case "测试模型":
- if value != "" {
- channel.TestModel = &value
- }
- case "权重":
- if value != "" {
- if v, err := strconv.ParseUint(value, 10, 32); err == nil {
- weight := uint(v)
- channel.Weight = &weight
- }
- }
- case "创建时间":
- if v, err := strconv.ParseInt(value, 10, 64); err == nil {
- channel.CreatedTime = v
- }
- case "测试时间":
- if v, err := strconv.ParseInt(value, 10, 64); err == nil {
- channel.TestTime = v
- }
- case "响应时间":
- if v, err := strconv.Atoi(value); err == nil {
- channel.ResponseTime = v
- }
- case "基础URL":
- if value != "" {
- channel.BaseURL = &value
- }
- case "其他":
- channel.Other = value
- case "余额":
- if v, err := strconv.ParseFloat(value, 64); err == nil {
- channel.Balance = v
- }
- case "余额更新时间":
- if v, err := strconv.ParseInt(value, 10, 64); err == nil {
- channel.BalanceUpdatedTime = v
- }
- case "模型":
- channel.Models = value
- case "分组":
- channel.Group = value
- case "已用配额":
- if v, err := strconv.ParseInt(value, 10, 64); err == nil {
- channel.UsedQuota = v
- }
- case "模型映射":
- if value != "" {
- channel.ModelMapping = &value
- }
- case "状态码映射":
- if value != "" {
- channel.StatusCodeMapping = &value
- }
- case "优先级":
- if value != "" {
- if v, err := strconv.ParseInt(value, 10, 64); err == nil {
- channel.Priority = &v
- }
- }
- case "自动禁用":
- if value != "" {
- if v, err := strconv.Atoi(value); err == nil {
- channel.AutoBan = &v
- }
- }
- case "标签":
- if value != "" {
- channel.Tag = &value
- }
- case "额外设置":
- if value != "" {
- channel.Setting = &value
- }
- case "参数覆盖":
- if value != "" {
- channel.ParamOverride = &value
- }
- }
- }
- // 设置默认值
- if channel.CreatedTime == 0 {
- channel.CreatedTime = common.GetTimestamp()
- }
- // 如果没有设置状态,默认为启用
- if channel.Status == 0 {
- channel.Status = 1
- }
- channels = append(channels, channel)
- }
- // 批量插入渠道
- err = model.BatchInsertChannels(channels)
- if err != nil {
- common.ApiError(c, fmt.Errorf("failed to insert channels: %w", err))
- return
- }
- // 初始化渠道缓存
- model.InitChannelCache()
- // 返回成功响应
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": fmt.Sprintf("成功导入 %d 个渠道", len(channels)),
- })
- }
|