Przeglądaj źródła

Merge branch 'songquanpeng:main' into main

Calcium-Ion 2 lat temu
rodzic
commit
26f9d25860

+ 6 - 6
controller/channel-billing.go

@@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
 }
 
 func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
-	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL)
+	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 
 	if err != nil {
@@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
 
 func updateChannelBalance(channel *model.Channel) (float64, error) {
 	baseURL := common.ChannelBaseURLs[channel.Type]
-	if channel.BaseURL == "" {
-		channel.BaseURL = baseURL
+	if channel.GetBaseURL() == "" {
+		channel.BaseURL = &baseURL
 	}
 	switch channel.Type {
 	case common.ChannelTypeOpenAI:
-		if channel.BaseURL != "" {
-			baseURL = channel.BaseURL
+		if channel.GetBaseURL() != "" {
+			baseURL = channel.GetBaseURL()
 		}
 	case common.ChannelTypeAzure:
 		return 0, errors.New("尚未实现")
 	case common.ChannelTypeCustom:
-		baseURL = channel.BaseURL
+		baseURL = channel.GetBaseURL()
 	case common.ChannelTypeCloseAI:
 		return updateChannelCloseAIBalance(channel)
 	case common.ChannelTypeOpenAISB:

+ 3 - 3
controller/channel-test.go

@@ -42,10 +42,10 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 	}
 	requestURL := common.ChannelBaseURLs[channel.Type]
 	if channel.Type == common.ChannelTypeAzure {
-		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
+		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
 	} else {
-		if channel.BaseURL != "" {
-			requestURL = channel.BaseURL
+		if channel.GetBaseURL() != "" {
+			requestURL = channel.GetBaseURL()
 		}
 		requestURL += "/v1/chat/completions"
 	}

+ 2 - 2
middleware/distributor.go

@@ -94,9 +94,9 @@ func Distribute() func(c *gin.Context) {
 		c.Set("channel", channel.Type)
 		c.Set("channel_id", channel.Id)
 		c.Set("channel_name", channel.Name)
-		c.Set("model_mapping", channel.ModelMapping)
+		c.Set("model_mapping", channel.GetModelMapping())
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
-		c.Set("base_url", channel.BaseURL)
+		c.Set("base_url", channel.GetBaseURL())
 		switch channel.Type {
 		case common.ChannelTypeAzure:
 			c.Set("api_version", channel.Other)

+ 5 - 3
model/ability.go

@@ -10,16 +10,18 @@ type Ability struct {
 	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"`
 	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
 	Enabled   bool   `json:"enabled"`
-	Priority  int64  `json:"priority" gorm:"bigint;default:0"`
+	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"`
 }
 
 func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 	ability := Ability{}
 	var err error = nil
 	if common.UsingSQLite {
-		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error
+		maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
+		err = DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery).Order("RANDOM()").Limit(1).First(&ability).Error
 	} else {
-		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error
+		maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("group = ? and model = ? and enabled = 1", group, model)
+		err = DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery).Order("RAND()").Limit(1).First(&ability).Error
 	}
 	if err != nil {
 		return nil, err

+ 10 - 4
model/cache.go

@@ -165,7 +165,7 @@ func InitChannelCache() {
 	for group, model2channels := range newGroup2model2channels {
 		for model, channels := range model2channels {
 			sort.Slice(channels, func(i, j int) bool {
-				return channels[i].Priority > channels[j].Priority
+				return channels[i].GetPriority() > channels[j].GetPriority()
 			})
 			newGroup2model2channels[group][model] = channels
 		}
@@ -195,11 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 	if len(channels) == 0 {
 		return nil, errors.New("channel not found")
 	}
+	endIdx := len(channels)
 	// choose by priority
 	firstChannel := channels[0]
-	if firstChannel.Priority > 0 {
-		return firstChannel, nil
+	if firstChannel.GetPriority() > 0 {
+		for i := range channels {
+			if channels[i].GetPriority() != firstChannel.GetPriority() {
+				endIdx = i
+				break
+			}
+		}
 	}
-	idx := rand.Intn(len(channels))
+	idx := rand.Intn(endIdx)
 	return channels[idx], nil
 }

+ 24 - 3
model/channel.go

@@ -16,15 +16,15 @@ type Channel struct {
 	CreatedTime        int64   `json:"created_time" gorm:"bigint"`
 	TestTime           int64   `json:"test_time" gorm:"bigint"`
 	ResponseTime       int     `json:"response_time"` // in milliseconds
-	BaseURL            string  `json:"base_url" gorm:"column:base_url"`
+	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"`
 	Other              string  `json:"other"`
 	Balance            float64 `json:"balance"` // in USD
 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
 	Models             string  `json:"models"`
 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"`
 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
-	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
-	Priority           int64   `json:"priority" gorm:"bigint;default:0"`
+	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
+	Priority           *int64  `json:"priority" gorm:"bigint;default:0"`
 }
 
 func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -80,6 +80,27 @@ func BatchInsertChannels(channels []Channel) error {
 	return nil
 }
 
+func (channel *Channel) GetPriority() int64 {
+	if channel.Priority == nil {
+		return 0
+	}
+	return *channel.Priority
+}
+
+func (channel *Channel) GetBaseURL() string {
+	if channel.BaseURL == nil {
+		return ""
+	}
+	return *channel.BaseURL
+}
+
+func (channel *Channel) GetModelMapping() string {
+	if channel.ModelMapping == nil {
+		return ""
+	}
+	return *channel.ModelMapping
+}
+
 func (channel *Channel) Insert() error {
 	var err error
 	err = DB.Create(channel).Error

+ 1 - 4
web/src/pages/Channel/EditChannel.js

@@ -175,7 +175,7 @@ const EditChannel = () => {
       return;
     }
     let localInputs = inputs;
-    if (localInputs.base_url.endsWith('/')) {
+    if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
     }
     if (localInputs.type === 3 && localInputs.other === '') {
@@ -184,9 +184,6 @@ const EditChannel = () => {
     if (localInputs.type === 18 && localInputs.other === '') {
       localInputs.other = 'v2.1';
     }
-    if (localInputs.model_mapping === '') {
-      localInputs.model_mapping = '{}';
-    }
     let res;
     localInputs.models = localInputs.models.join(',');
     localInputs.group = localInputs.groups.join(',');