Browse Source

fix: lint

zijiren233 9 months ago
parent
commit
d22b0cb69e

+ 5 - 5
.github/workflows/release.yml

@@ -1,4 +1,4 @@
-name: release
+name: Release
 
 on:
   push:
@@ -87,8 +87,8 @@ jobs:
           files: |
             aiproxy-${{ matrix.targets.GOOS }}-${{ matrix.targets.GOARCH }}${{ matrix.targets.EXT }}
 
-  build-docker:
-    name: Release Docker
+  build-docker-images:
+    name: Build Docker Images
     strategy:
       matrix:
         include:
@@ -149,9 +149,9 @@ jobs:
           if-no-files-found: error
           retention-days: 1
 
-  release-docker:
+  release-docker-images:
     name: Push Docker Images
-    needs: build-docker
+    needs: build-docker-images
     runs-on: ubuntu-24.04
     if: ${{ github.event_name != 'pull_request' }}
     steps:

+ 6 - 4
common/rpmlimit/rate-limit.go

@@ -76,16 +76,18 @@ func GetRPM(ctx context.Context, group, model string) (int64, error) {
 
 	var pattern string
 	var overLimitPattern string
-	if group == "" && model == "" {
+
+	switch {
+	case group == "" && model == "":
 		pattern = "group_model_rpm:*:*"
 		overLimitPattern = "over_limit_rpm:*:*"
-	} else if group == "" {
+	case group == "":
 		pattern = "group_model_rpm:*:" + model
 		overLimitPattern = "over_limit_rpm:*:" + model
-	} else if model == "" {
+	case model == "":
 		pattern = fmt.Sprintf("group_model_rpm:%s:*", group)
 		overLimitPattern = fmt.Sprintf("over_limit_rpm:%s:*", group)
-	} else {
+	default:
 		pattern = fmt.Sprintf("group_model_rpm:%s:%s", group, model)
 		overLimitPattern = fmt.Sprintf("over_limit_rpm:%s:%s", group, model)
 	}

+ 1 - 2
common/splitter/splitter.go

@@ -89,8 +89,7 @@ func (s *Splitter) processSeekTail() ([]byte, []byte) {
 	tailLen := s.tailLen
 	kmpNext := s.kmpNext
 
-	var i int
-	for i = 0; i < len(data); i++ {
+	for i := range data {
 		for j > 0 && data[i] != tail[j] {
 			j = kmpNext[j-1]
 		}

+ 6 - 3
controller/import.go

@@ -172,14 +172,17 @@ func ImportChannelFromOneAPI(c *gin.Context) {
 
 	var db *gorm.DB
 	var err error
-	if strings.HasPrefix(req.DSN, "mysql") {
+
+	switch {
+	case strings.HasPrefix(req.DSN, "mysql"):
 		db, err = model.OpenMySQL(req.DSN)
-	} else if strings.HasPrefix(req.DSN, "postgres") {
+	case strings.HasPrefix(req.DSN, "postgres"):
 		db, err = model.OpenPostgreSQL(req.DSN)
-	} else {
+	default:
 		middleware.ErrorResponse(c, http.StatusBadRequest, "invalid dsn, only mysql and postgres are supported")
 		return
 	}
+
 	if err != nil {
 		middleware.ErrorResponse(c, http.StatusBadRequest, err.Error())
 		return

+ 5 - 4
controller/relay.go

@@ -76,7 +76,7 @@ func RelayHelper(meta *meta.Meta, c *gin.Context, relayController RelayControlle
 	}
 	if shouldErrorMonitor(relayErr.StatusCode) {
 		hasPermission := channelHasPermission(relayErr.StatusCode)
-		beyondThreshold, autoBanned, err := monitor.AddRequest(
+		beyondThreshold, banExecution, err := monitor.AddRequest(
 			context.Background(),
 			meta.OriginModel,
 			int64(meta.Channel.ID),
@@ -86,7 +86,8 @@ func RelayHelper(meta *meta.Meta, c *gin.Context, relayController RelayControlle
 		if err != nil {
 			log.Errorf("add request failed: %+v", err)
 		}
-		if autoBanned {
+		switch {
+		case banExecution:
 			notify.ErrorThrottle(
 				fmt.Sprintf("autoBanned:%d:%s", meta.Channel.ID, meta.OriginModel),
 				time.Minute,
@@ -94,7 +95,7 @@ func RelayHelper(meta *meta.Meta, c *gin.Context, relayController RelayControlle
 					meta.Channel.Type, meta.Channel.Name, meta.Channel.ID, meta.OriginModel),
 				relayErr.JSONOrEmpty(),
 			)
-		} else if beyondThreshold {
+		case beyondThreshold:
 			notify.WarnThrottle(
 				fmt.Sprintf("beyondThreshold:%d:%s", meta.Channel.ID, meta.OriginModel),
 				time.Minute,
@@ -102,7 +103,7 @@ func RelayHelper(meta *meta.Meta, c *gin.Context, relayController RelayControlle
 					meta.Channel.Type, meta.Channel.Name, meta.Channel.ID, meta.OriginModel),
 				relayErr.JSONOrEmpty(),
 			)
-		} else if !hasPermission {
+		case !hasPermission:
 			notify.ErrorThrottle(
 				fmt.Sprintf("channelHasPermission:%d:%s", meta.Channel.ID, meta.OriginModel),
 				time.Minute,

+ 3 - 1
model/cache.go

@@ -407,12 +407,14 @@ func CacheGetGroupModelTPM(id string, model string) (int64, error) {
 	return tpm, nil
 }
 
+//nolint:revive
 type ModelConfigCache interface {
 	GetModelConfig(model string) (*ModelConfig, bool)
 }
 
 // read-only cache
-
+//
+//nolint:revive
 type ModelCaches struct {
 	ModelConfig                     ModelConfigCache
 	EnabledModel2channels           map[string][]*Channel

+ 2 - 0
model/configkey.go

@@ -2,6 +2,7 @@ package model
 
 import "reflect"
 
+//nolint:revive
 type ModelConfigKey string
 
 const (
@@ -14,6 +15,7 @@ const (
 	ModelConfigSupportVoicesKey    ModelConfigKey = "support_voices"
 )
 
+//nolint:revive
 type ModelConfigOption func(config map[ModelConfigKey]any)
 
 func WithModelConfigMaxContextTokens(maxContextTokens int) ModelConfigOption {

+ 17 - 12
model/log.go

@@ -962,11 +962,12 @@ func getChartData(group string, start, end time.Time, tokenName, modelName strin
 		query = query.Where("group_id = ?", group)
 	}
 
-	if !start.IsZero() && !end.IsZero() {
+	switch {
+	case !start.IsZero() && !end.IsZero():
 		query = query.Where("request_at BETWEEN ? AND ?", start, end)
-	} else if !start.IsZero() {
+	case !start.IsZero():
 		query = query.Where("request_at >= ?", start)
-	} else if !end.IsZero() {
+	case !end.IsZero():
 		query = query.Where("request_at <= ?", end)
 	}
 
@@ -1003,11 +1004,12 @@ func getLogDistinctValues[T cmp.Ordered](field string, group string, start, end
 		query = query.Where("group_id = ?", group)
 	}
 
-	if !start.IsZero() && !end.IsZero() {
+	switch {
+	case !start.IsZero() && !end.IsZero():
 		query = query.Where("request_at BETWEEN ? AND ?", start, end)
-	} else if !start.IsZero() {
+	case !start.IsZero():
 		query = query.Where("request_at >= ?", start)
-	} else if !end.IsZero() {
+	case !end.IsZero():
 		query = query.Where("request_at <= ?", end)
 	}
 
@@ -1030,11 +1032,12 @@ func getLogGroupByValues[T cmp.Ordered](field string, group string, start, end t
 		query = query.Where("group_id = ?", group)
 	}
 
-	if !start.IsZero() && !end.IsZero() {
+	switch {
+	case !start.IsZero() && !end.IsZero():
 		query = query.Where("request_at BETWEEN ? AND ?", start, end)
-	} else if !start.IsZero() {
+	case !start.IsZero():
 		query = query.Where("request_at >= ?", start)
-	} else if !end.IsZero() {
+	case !end.IsZero():
 		query = query.Where("request_at <= ?", end)
 	}
 
@@ -1265,6 +1268,7 @@ func GetGroupModelTPM(group string, model string) (int64, error) {
 	return tpm, err
 }
 
+//nolint:revive
 type ModelCostRank struct {
 	Model      string  `json:"model"`
 	UsedAmount float64 `json:"used_amount"`
@@ -1283,11 +1287,12 @@ func GetModelCostRank(group string, start, end time.Time) ([]*ModelCostRank, err
 		query = query.Where("group_id = ?", group)
 	}
 
-	if !start.IsZero() && !end.IsZero() {
+	switch {
+	case !start.IsZero() && !end.IsZero():
 		query = query.Where("request_at BETWEEN ? AND ?", start, end)
-	} else if !start.IsZero() {
+	case !start.IsZero():
 		query = query.Where("request_at >= ?", start)
-	} else if !end.IsZero() {
+	case !end.IsZero():
 		query = query.Where("request_at <= ?", end)
 	}
 

+ 1 - 0
model/owner.go

@@ -1,5 +1,6 @@
 package model
 
+//nolint:revive
 type ModelOwner string
 
 const (

+ 3 - 3
model/utils.go

@@ -211,7 +211,7 @@ func BatchRecordConsume(
 				Add(decimal.NewFromFloat(batchData.Groups[group].Amount)).
 				InexactFloat64()
 		}
-		batchData.Groups[group].Count += 1
+		batchData.Groups[group].Count++
 	}
 
 	if tokenID > 0 {
@@ -224,7 +224,7 @@ func BatchRecordConsume(
 				Add(decimal.NewFromFloat(batchData.Tokens[tokenID].Amount)).
 				InexactFloat64()
 		}
-		batchData.Tokens[tokenID].Count += 1
+		batchData.Tokens[tokenID].Count++
 	}
 
 	if channelID > 0 {
@@ -237,7 +237,7 @@ func BatchRecordConsume(
 				Add(decimal.NewFromFloat(batchData.Channels[channelID].Amount)).
 				InexactFloat64()
 		}
-		batchData.Channels[channelID].Count += 1
+		batchData.Channels[channelID].Count++
 	}
 
 	return err

+ 12 - 12
monitor/memmodel.go

@@ -134,7 +134,7 @@ func (m *MemModelMonitor) checkAndBan(now time.Time, channel *ChannelStats, tryB
 		return false, true
 	}
 
-	req, err := channel.timeWindows.GetStats(maxSliceCount)
+	req, err := channel.timeWindows.GetStats()
 	if req < minRequestCount {
 		return false, false
 	}
@@ -150,14 +150,14 @@ func (m *MemModelMonitor) checkAndBan(now time.Time, channel *ChannelStats, tryB
 }
 
 func getErrorRateFromStats(stats *TimeWindowStats) float64 {
-	req, err := stats.GetStats(maxSliceCount)
+	req, err := stats.GetStats()
 	if req < minRequestCount {
 		return 0
 	}
 	return float64(err) / float64(req)
 }
 
-func (m *MemModelMonitor) GetModelsErrorRate(ctx context.Context) (map[string]float64, error) {
+func (m *MemModelMonitor) GetModelsErrorRate(_ context.Context) (map[string]float64, error) {
 	m.mu.RLock()
 	defer m.mu.RUnlock()
 
@@ -168,7 +168,7 @@ func (m *MemModelMonitor) GetModelsErrorRate(ctx context.Context) (map[string]fl
 	return result, nil
 }
 
-func (m *MemModelMonitor) GetModelChannelErrorRate(ctx context.Context, model string) (map[int64]float64, error) {
+func (m *MemModelMonitor) GetModelChannelErrorRate(_ context.Context, model string) (map[int64]float64, error) {
 	m.mu.RLock()
 	defer m.mu.RUnlock()
 
@@ -181,7 +181,7 @@ func (m *MemModelMonitor) GetModelChannelErrorRate(ctx context.Context, model st
 	return result, nil
 }
 
-func (m *MemModelMonitor) GetChannelModelErrorRates(ctx context.Context, channelID int64) (map[string]float64, error) {
+func (m *MemModelMonitor) GetChannelModelErrorRates(_ context.Context, channelID int64) (map[string]float64, error) {
 	m.mu.RLock()
 	defer m.mu.RUnlock()
 
@@ -194,7 +194,7 @@ func (m *MemModelMonitor) GetChannelModelErrorRates(ctx context.Context, channel
 	return result, nil
 }
 
-func (m *MemModelMonitor) GetAllChannelModelErrorRates(ctx context.Context) (map[int64]map[string]float64, error) {
+func (m *MemModelMonitor) GetAllChannelModelErrorRates(_ context.Context) (map[int64]map[string]float64, error) {
 	m.mu.RLock()
 	defer m.mu.RUnlock()
 
@@ -210,7 +210,7 @@ func (m *MemModelMonitor) GetAllChannelModelErrorRates(ctx context.Context) (map
 	return result, nil
 }
 
-func (m *MemModelMonitor) GetBannedChannelsWithModel(ctx context.Context, model string) ([]int64, error) {
+func (m *MemModelMonitor) GetBannedChannelsWithModel(_ context.Context, model string) ([]int64, error) {
 	m.mu.RLock()
 	defer m.mu.RUnlock()
 
@@ -228,7 +228,7 @@ func (m *MemModelMonitor) GetBannedChannelsWithModel(ctx context.Context, model
 	return banned, nil
 }
 
-func (m *MemModelMonitor) GetAllBannedModelChannels(ctx context.Context) (map[string][]int64, error) {
+func (m *MemModelMonitor) GetAllBannedModelChannels(_ context.Context) (map[string][]int64, error) {
 	m.mu.RLock()
 	defer m.mu.RUnlock()
 
@@ -250,7 +250,7 @@ func (m *MemModelMonitor) GetAllBannedModelChannels(ctx context.Context) (map[st
 	return result, nil
 }
 
-func (m *MemModelMonitor) ClearChannelModelErrors(ctx context.Context, model string, channelID int) error {
+func (m *MemModelMonitor) ClearChannelModelErrors(_ context.Context, model string, channelID int) error {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
@@ -260,7 +260,7 @@ func (m *MemModelMonitor) ClearChannelModelErrors(ctx context.Context, model str
 	return nil
 }
 
-func (m *MemModelMonitor) ClearChannelAllModelErrors(ctx context.Context, channelID int) error {
+func (m *MemModelMonitor) ClearChannelAllModelErrors(_ context.Context, channelID int) error {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
@@ -270,7 +270,7 @@ func (m *MemModelMonitor) ClearChannelAllModelErrors(ctx context.Context, channe
 	return nil
 }
 
-func (m *MemModelMonitor) ClearAllModelErrors(ctx context.Context) error {
+func (m *MemModelMonitor) ClearAllModelErrors(_ context.Context) error {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
@@ -317,7 +317,7 @@ func (t *TimeWindowStats) AddRequest(now time.Time, isError bool) {
 	}
 }
 
-func (t *TimeWindowStats) GetStats(maxSlice int) (totalReq, totalErr int) {
+func (t *TimeWindowStats) GetStats() (totalReq, totalErr int) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
 

+ 1 - 1
relay/adaptor/doc2x/adaptor.go

@@ -56,7 +56,7 @@ func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Respons
 	}
 }
 
-func (a *Adaptor) SetupRequestHeader(meta *meta.Meta, c *gin.Context, req *http.Request) error {
+func (a *Adaptor) SetupRequestHeader(meta *meta.Meta, _ *gin.Context, req *http.Request) error {
 	req.Header.Set("Authorization", "Bearer "+meta.Channel.Key)
 	return nil
 }

+ 3 - 0
relay/adaptor/doc2x/html2md_test.go

@@ -8,6 +8,8 @@ import (
 )
 
 func TestHTMLTable2Md(t *testing.T) {
+	t.Parallel()
+
 	tables := []struct {
 		name     string
 		html     string
@@ -53,6 +55,7 @@ func TestHTMLTable2Md(t *testing.T) {
 var htmlImage = `<img src="https://cdn.noedgeai.com/01956426-b164-730d-a1fe-8be8972145d6_0.jpg?x=258&y=694&w=1132&h=826"/>`
 
 func TestInlineMdImage(t *testing.T) {
+	t.Parallel()
 	result := doc2x.InlineMdImage(context.Background(), htmlImage)
 	t.Log(result)
 }

+ 4 - 10
relay/adaptor/doc2x/pdf.go

@@ -306,14 +306,14 @@ func inferMimeType(u string) string {
 	}
 }
 
-func handleConvertPdfToMd(ctx context.Context, str string) (string, error) {
+func handleConvertPdfToMd(ctx context.Context, str string) string {
 	result := InlineMdImage(ctx, str)
 	result = HTMLTable2Md(result)
 
 	result = mediaCommentRegex.ReplaceAllString(result, "")
 	result = footnoteCommentRegex.ReplaceAllString(result, "")
 
-	return result, nil
+	return result
 }
 
 func handleParsePdfResponse(meta *meta.Meta, c *gin.Context, response *StatusResponseDataResult) (*relaymodel.Usage, *relaymodel.ErrorWithStatusCode) {
@@ -328,10 +328,7 @@ func handleParsePdfResponse(meta *meta.Meta, c *gin.Context, response *StatusRes
 	switch meta.GetString("response_format") {
 	case "list":
 		for i, md := range mds {
-			result, err := handleConvertPdfToMd(c.Request.Context(), md)
-			if err != nil {
-				return nil, openai.ErrorWrapperWithMessage("convert pdf to md failed: "+err.Error(), "convert_pdf_to_md_failed", http.StatusInternalServerError)
-			}
+			result := handleConvertPdfToMd(c.Request.Context(), md)
 			mds[i] = result
 		}
 		c.JSON(http.StatusOK, relaymodel.ParsePdfListResponse{
@@ -343,10 +340,7 @@ func handleParsePdfResponse(meta *meta.Meta, c *gin.Context, response *StatusRes
 		for _, md := range mds {
 			builder.WriteString(md)
 		}
-		result, err := handleConvertPdfToMd(c.Request.Context(), builder.String())
-		if err != nil {
-			return nil, openai.ErrorWrapperWithMessage("convert pdf to md failed: "+err.Error(), "convert_pdf_to_md_failed", http.StatusInternalServerError)
-		}
+		result := handleConvertPdfToMd(c.Request.Context(), builder.String())
 		c.JSON(http.StatusOK, relaymodel.ParsePdfResponse{
 			Pages:    pages,
 			Markdown: result,

+ 1 - 1
relay/adaptor/doubao/main.go

@@ -86,6 +86,6 @@ func (a *Adaptor) GetChannelName() string {
 	return "doubao"
 }
 
-func (a *Adaptor) GetBalance(channel *model.Channel) (float64, error) {
+func (a *Adaptor) GetBalance(_ *model.Channel) (float64, error) {
 	return 0, adaptor.ErrGetBalanceNotImplemented
 }

+ 4 - 3
relay/adaptor/gemini/main.go

@@ -158,7 +158,8 @@ func buildContents(ctx context.Context, textRequest *model.GeneralOpenAIRequest)
 			Parts: make([]Part, 0),
 		}
 
-		if message.Role == "assistant" && len(message.ToolCalls) > 0 {
+		switch {
+		case message.Role == "assistant" && len(message.ToolCalls) > 0:
 			for _, toolCall := range message.ToolCalls {
 				var args map[string]any
 				if toolCall.Function.Arguments != "" {
@@ -175,7 +176,7 @@ func buildContents(ctx context.Context, textRequest *model.GeneralOpenAIRequest)
 					},
 				})
 			}
-		} else if message.Role == "tool" && message.ToolCallID != "" {
+		case message.Role == "tool" && message.ToolCallID != "":
 			var contentMap map[string]any
 			if message.Content != nil {
 				switch content := message.Content.(type) {
@@ -201,7 +202,7 @@ func buildContents(ctx context.Context, textRequest *model.GeneralOpenAIRequest)
 					},
 				},
 			})
-		} else {
+		default:
 			openaiContent := message.ParseContent()
 			for _, part := range openaiContent {
 				if part.Type == model.ContentTypeImageURL {

+ 1 - 1
relay/adaptor/lingyiwanwu/adaptor.go

@@ -24,6 +24,6 @@ func (a *Adaptor) GetChannelName() string {
 	return "lingyiwanwu"
 }
 
-func (a *Adaptor) GetBalance(channel *model.Channel) (float64, error) {
+func (a *Adaptor) GetBalance(_ *model.Channel) (float64, error) {
 	return 0, adaptor.ErrGetBalanceNotImplemented
 }

+ 1 - 1
relay/adaptor/minimax/adaptor.go

@@ -79,6 +79,6 @@ func (a *Adaptor) GetChannelName() string {
 	return "minimax"
 }
 
-func (a *Adaptor) GetBalance(channel *model.Channel) (float64, error) {
+func (a *Adaptor) GetBalance(_ *model.Channel) (float64, error) {
 	return 0, adaptor.ErrGetBalanceNotImplemented
 }

+ 1 - 1
relay/adaptor/stepfun/adaptor.go

@@ -38,6 +38,6 @@ func (a *Adaptor) GetChannelName() string {
 	return "stepfun"
 }
 
-func (a *Adaptor) GetBalance(channel *model.Channel) (float64, error) {
+func (a *Adaptor) GetBalance(_ *model.Channel) (float64, error) {
 	return 0, adaptor.ErrGetBalanceNotImplemented
 }

+ 1 - 1
relay/adaptor/tencent/adaptor.go

@@ -26,6 +26,6 @@ func (a *Adaptor) GetChannelName() string {
 	return "tencent"
 }
 
-func (a *Adaptor) GetBalance(channel *model.Channel) (float64, error) {
+func (a *Adaptor) GetBalance(_ *model.Channel) (float64, error) {
 	return 0, adaptor.ErrGetBalanceNotImplemented
 }

+ 1 - 1
relay/adaptor/xunfei/adaptor.go

@@ -45,6 +45,6 @@ func (a *Adaptor) GetChannelName() string {
 	return "xunfei"
 }
 
-func (a *Adaptor) GetBalance(channel *model.Channel) (float64, error) {
+func (a *Adaptor) GetBalance(_ *model.Channel) (float64, error) {
 	return 0, adaptor.ErrGetBalanceNotImplemented
 }

+ 1 - 1
relay/adaptor/zhipu/adaptor.go

@@ -40,6 +40,6 @@ func (a *Adaptor) GetChannelName() string {
 	return "zhipu"
 }
 
-func (a *Adaptor) GetBalance(channel *model.Channel) (float64, error) {
+func (a *Adaptor) GetBalance(_ *model.Channel) (float64, error) {
 	return 0, adaptor.ErrGetBalanceNotImplemented
 }