Răsfoiți Sursa

fix: fix unable to set zero value for base url & model mapping

JustSong 2 ani în urmă
părinte
comite
159b9e3369

+ 6 - 6
controller/channel-billing.go

@@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
 }
 
 func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
-	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL)
+	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 
 	if err != nil {
@@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
 
 func updateChannelBalance(channel *model.Channel) (float64, error) {
 	baseURL := common.ChannelBaseURLs[channel.Type]
-	if channel.BaseURL == "" {
-		channel.BaseURL = baseURL
+	if channel.GetBaseURL() == "" {
+		channel.BaseURL = &baseURL
 	}
 	switch channel.Type {
 	case common.ChannelTypeOpenAI:
-		if channel.BaseURL != "" {
-			baseURL = channel.BaseURL
+		if channel.GetBaseURL() != "" {
+			baseURL = channel.GetBaseURL()
 		}
 	case common.ChannelTypeAzure:
 		return 0, errors.New("尚未实现")
 	case common.ChannelTypeCustom:
-		baseURL = channel.BaseURL
+		baseURL = channel.GetBaseURL()
 	case common.ChannelTypeCloseAI:
 		return updateChannelCloseAIBalance(channel)
 	case common.ChannelTypeOpenAISB:

+ 3 - 3
controller/channel-test.go

@@ -42,10 +42,10 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 	}
 	requestURL := common.ChannelBaseURLs[channel.Type]
 	if channel.Type == common.ChannelTypeAzure {
-		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
+		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
 	} else {
-		if channel.BaseURL != "" {
-			requestURL = channel.BaseURL
+		if channel.GetBaseURL() != "" {
+			requestURL = channel.GetBaseURL()
 		}
 		requestURL += "/v1/chat/completions"
 	}

+ 2 - 2
middleware/distributor.go

@@ -82,9 +82,9 @@ func Distribute() func(c *gin.Context) {
 		c.Set("channel", channel.Type)
 		c.Set("channel_id", channel.Id)
 		c.Set("channel_name", channel.Name)
-		c.Set("model_mapping", channel.ModelMapping)
+		c.Set("model_mapping", channel.GetModelMapping())
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
-		c.Set("base_url", channel.BaseURL)
+		c.Set("base_url", channel.GetBaseURL())
 		switch channel.Type {
 		case common.ChannelTypeAzure:
 			c.Set("api_version", channel.Other)

+ 17 - 3
model/channel.go

@@ -15,14 +15,14 @@ type Channel struct {
 	CreatedTime        int64   `json:"created_time" gorm:"bigint"`
 	TestTime           int64   `json:"test_time" gorm:"bigint"`
 	ResponseTime       int     `json:"response_time"` // in milliseconds
-	BaseURL            string  `json:"base_url" gorm:"column:base_url"`
+	BaseURL            *string `json:"base_url" gorm:"column:base_url"`
 	Other              string  `json:"other"`
 	Balance            float64 `json:"balance"` // in USD
 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
 	Models             string  `json:"models"`
 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"`
 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
-	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
+	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"`
 }
 
@@ -80,12 +80,26 @@ func BatchInsertChannels(channels []Channel) error {
 }
 
 func (channel *Channel) GetPriority() int64 {
-	if channel == nil {
+	if channel.Priority == nil {
 		return 0
 	}
 	return *channel.Priority
 }
 
+func (channel *Channel) GetBaseURL() string {
+	if channel.BaseURL == nil {
+		return ""
+	}
+	return *channel.BaseURL
+}
+
+func (channel *Channel) GetModelMapping() string {
+	if channel.ModelMapping == nil {
+		return ""
+	}
+	return *channel.ModelMapping
+}
+
 func (channel *Channel) Insert() error {
 	var err error
 	err = DB.Create(channel).Error

+ 0 - 3
web/src/pages/Channel/EditChannel.js

@@ -183,9 +183,6 @@ const EditChannel = () => {
     if (localInputs.type === 18 && localInputs.other === '') {
       localInputs.other = 'v2.1';
     }
-    if (localInputs.model_mapping === '') {
-      localInputs.model_mapping = '{}';
-    }
     let res;
     localInputs.models = localInputs.models.join(',');
     localInputs.group = localInputs.groups.join(',');