|
|
@@ -3,8 +3,10 @@ package controller
|
|
|
import (
|
|
|
"encoding/json"
|
|
|
"strconv"
|
|
|
+ "strings"
|
|
|
|
|
|
"one-api/common"
|
|
|
+ "one-api/constant"
|
|
|
"one-api/model"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
@@ -162,17 +164,105 @@ func DeleteModelMeta(c *gin.Context) {
|
|
|
|
|
|
// 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups
|
|
|
func fillModelExtra(m *model.Model) {
|
|
|
- if m.Endpoints == "" {
|
|
|
- eps := model.GetModelSupportEndpointTypes(m.ModelName)
|
|
|
+ // 若为精确匹配,保持原有逻辑
|
|
|
+ if m.NameRule == model.NameRuleExact {
|
|
|
+ if m.Endpoints == "" {
|
|
|
+ eps := model.GetModelSupportEndpointTypes(m.ModelName)
|
|
|
+ if b, err := json.Marshal(eps); err == nil {
|
|
|
+ m.Endpoints = string(b)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if channels, err := model.GetBoundChannels(m.ModelName); err == nil {
|
|
|
+ m.BoundChannels = channels
|
|
|
+ }
|
|
|
+ m.EnableGroups = model.GetModelEnableGroups(m.ModelName)
|
|
|
+ m.QuotaType = model.GetModelQuotaType(m.ModelName)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 非精确匹配:计算并集
|
|
|
+ pricings := model.GetPricing()
|
|
|
+
|
|
|
+ // 端点去重集合
|
|
|
+ endpointSet := make(map[constant.EndpointType]struct{})
|
|
|
+ // 已绑定渠道去重集合
|
|
|
+ channelSet := make(map[string]model.BoundChannel)
|
|
|
+ // 分组去重集合
|
|
|
+ groupSet := make(map[string]struct{})
|
|
|
+ // 计费类型(若有任意模型为 1,则返回 1)
|
|
|
+ quotaTypeSet := make(map[int]struct{})
|
|
|
+
|
|
|
+ for _, p := range pricings {
|
|
|
+ var matched bool
|
|
|
+ switch m.NameRule {
|
|
|
+ case model.NameRulePrefix:
|
|
|
+ matched = strings.HasPrefix(p.ModelName, m.ModelName)
|
|
|
+ case model.NameRuleSuffix:
|
|
|
+ matched = strings.HasSuffix(p.ModelName, m.ModelName)
|
|
|
+ case model.NameRuleContains:
|
|
|
+ matched = strings.Contains(p.ModelName, m.ModelName)
|
|
|
+ }
|
|
|
+ if !matched {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ // 收集端点
|
|
|
+ for _, et := range p.SupportedEndpointTypes {
|
|
|
+ endpointSet[et] = struct{}{}
|
|
|
+ }
|
|
|
+
|
|
|
+ // 收集分组
|
|
|
+ for _, g := range p.EnableGroup {
|
|
|
+ groupSet[g] = struct{}{}
|
|
|
+ }
|
|
|
+
|
|
|
+ // 收集计费类型
|
|
|
+ quotaTypeSet[p.QuotaType] = struct{}{}
|
|
|
+
|
|
|
+ // 收集渠道
|
|
|
+ if channels, err := model.GetBoundChannels(p.ModelName); err == nil {
|
|
|
+ for _, ch := range channels {
|
|
|
+ key := ch.Name + "_" + strconv.Itoa(ch.Type)
|
|
|
+ channelSet[key] = ch
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 序列化端点
|
|
|
+ if len(endpointSet) > 0 && m.Endpoints == "" {
|
|
|
+ eps := make([]constant.EndpointType, 0, len(endpointSet))
|
|
|
+ for et := range endpointSet {
|
|
|
+ eps = append(eps, et)
|
|
|
+ }
|
|
|
if b, err := json.Marshal(eps); err == nil {
|
|
|
m.Endpoints = string(b)
|
|
|
}
|
|
|
}
|
|
|
- if channels, err := model.GetBoundChannels(m.ModelName); err == nil {
|
|
|
+
|
|
|
+ // 序列化渠道
|
|
|
+ if len(channelSet) > 0 {
|
|
|
+ channels := make([]model.BoundChannel, 0, len(channelSet))
|
|
|
+ for _, ch := range channelSet {
|
|
|
+ channels = append(channels, ch)
|
|
|
+ }
|
|
|
m.BoundChannels = channels
|
|
|
}
|
|
|
- // 填充启用分组
|
|
|
- m.EnableGroups = model.GetModelEnableGroups(m.ModelName)
|
|
|
- // 填充计费类型
|
|
|
- m.QuotaType = model.GetModelQuotaType(m.ModelName)
|
|
|
+
|
|
|
+ // 序列化分组
|
|
|
+ if len(groupSet) > 0 {
|
|
|
+ groups := make([]string, 0, len(groupSet))
|
|
|
+ for g := range groupSet {
|
|
|
+ groups = append(groups, g)
|
|
|
+ }
|
|
|
+ m.EnableGroups = groups
|
|
|
+ }
|
|
|
+
|
|
|
+ // 确定计费类型:仅当所有匹配模型计费类型一致时才返回该类型,否则返回 -1 表示未知/不确定
|
|
|
+ if len(quotaTypeSet) == 1 {
|
|
|
+ for k := range quotaTypeSet {
|
|
|
+ m.QuotaType = k
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ m.QuotaType = -1
|
|
|
+ }
|
|
|
}
|