
chore: update group model reqeuest record function (#202)

zijiren 7 months ago
Parent
Commit
f6c0dedfd3
2 changed files with 89 additions and 78 deletions
  1. core/controller/relay-controller.go (+49 −62)
  2. core/middleware/distributor.go (+40 −16)

+ 49 - 62
core/controller/relay-controller.go

@@ -57,12 +57,9 @@ const (
 	MetaChannelModelKeyRPS = "channel_model_rps"
 	MetaChannelModelKeyTPM = "channel_model_tpm"
 	MetaChannelModelKeyTPS = "channel_model_tps"
-
-	MetaGroupModelTokennameTPM = "group_model_tokenname_tpm"
-	MetaGroupModelTokennameTPS = "group_model_tokenname_tps"
 )
 
-func getChannelModelRequestRate(meta *meta.Meta) model.RequestRate {
+func getChannelModelRequestRate(c *gin.Context, meta *meta.Meta) model.RequestRate {
 	rate := model.RequestRate{}
 
 	if rpm, ok := meta.Get(MetaChannelModelKeyRPM); ok {
@@ -72,6 +69,7 @@ func getChannelModelRequestRate(meta *meta.Meta) model.RequestRate {
 		rpm, rps := reqlimit.GetChannelModelRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
 		rate.RPM = rpm
 		rate.RPS = rps
+		updateChannelModelRequestRate(c, meta, rpm, rps)
 	}
 
 	if tpm, ok := meta.Get(MetaChannelModelKeyTPM); ok {
@@ -81,25 +79,26 @@ func getChannelModelRequestRate(meta *meta.Meta) model.RequestRate {
 		tpm, tps := reqlimit.GetChannelModelTokensRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
 		rate.TPM = tpm
 		rate.TPS = tps
+		updateChannelModelTokensRequestRate(c, meta, tpm, tps)
 	}
 
 	return rate
 }
 
-func getGroupModelTokenRequestRate(c *gin.Context, meta *meta.Meta) model.RequestRate {
-	r := model.RequestRate{
-		RPM: middleware.GetGroupModelTokenRPM(c),
-		RPS: middleware.GetGroupModelTokenRPS(c),
-		TPM: middleware.GetGroupModelTokenTPM(c),
-		TPS: middleware.GetGroupModelTokenTPS(c),
-	}
-
-	if tpm, ok := meta.Get(MetaGroupModelTokennameTPM); ok {
-		r.TPM, _ = tpm.(int64)
-		r.TPS = meta.GetInt64(MetaGroupModelTokennameTPS)
-	}
+func updateChannelModelRequestRate(c *gin.Context, meta *meta.Meta, rpm, rps int64) {
+	meta.Set(MetaChannelModelKeyRPM, rpm)
+	meta.Set(MetaChannelModelKeyRPS, rps)
+	log := middleware.GetLogger(c)
+	log.Data["ch_rpm"] = rpm
+	log.Data["ch_rps"] = rps
+}
 
-	return r
+func updateChannelModelTokensRequestRate(c *gin.Context, meta *meta.Meta, tpm, tps int64) {
+	meta.Set(MetaChannelModelKeyTPM, tpm)
+	meta.Set(MetaChannelModelKeyTPS, tps)
+	log := middleware.GetLogger(c)
+	log.Data["ch_tpm"] = tpm
+	log.Data["ch_tps"] = tps
 }
 
 func (w *warpAdaptor) DoRequest(meta *meta.Meta, c *gin.Context, req *http.Request) (*http.Response, error) {
@@ -108,11 +107,7 @@ func (w *warpAdaptor) DoRequest(meta *meta.Meta, c *gin.Context, req *http.Reque
 		strconv.Itoa(meta.Channel.ID),
 		meta.OriginModel,
 	)
-	log := middleware.GetLogger(c)
-	meta.Set(MetaChannelModelKeyRPM, count+overLimitCount)
-	meta.Set(MetaChannelModelKeyRPS, secondCount)
-	log.Data["ch_rpm"] = count + overLimitCount
-	log.Data["ch_rps"] = secondCount
+	updateChannelModelRequestRate(c, meta, count+overLimitCount, secondCount)
 	return w.Adaptor.DoRequest(meta, c, req)
 }
 
@@ -122,41 +117,33 @@ func (w *warpAdaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Res
 		return nil, relayErr
 	}
 
-	count, overLimitCount, secondCount := reqlimit.PushChannelModelTokensRequest(
-		context.Background(),
-		strconv.Itoa(meta.Channel.ID),
-		meta.OriginModel,
-		int64(usage.TotalTokens),
-	)
-	log := middleware.GetLogger(c)
-	meta.Set(MetaChannelModelKeyTPM, count+overLimitCount)
-	meta.Set(MetaChannelModelKeyTPS, secondCount)
-	log.Data["ch_tpm"] = count + overLimitCount
-	log.Data["ch_tps"] = secondCount
+	if usage.TotalTokens > 0 {
+		count, overLimitCount, secondCount := reqlimit.PushChannelModelTokensRequest(
+			context.Background(),
+			strconv.Itoa(meta.Channel.ID),
+			meta.OriginModel,
+			int64(usage.TotalTokens),
+		)
+		updateChannelModelTokensRequestRate(c, meta, count+overLimitCount, secondCount)
 
-	count, overLimitCount, secondCount = reqlimit.PushGroupModelTokensRequest(
-		context.Background(),
-		meta.Group.ID,
-		meta.OriginModel,
-		meta.ModelConfig.TPM,
-		int64(usage.TotalTokens),
-	)
-	if meta.Group.Status != model.GroupStatusInternal {
-		log.Data["group_tpm"] = count + overLimitCount
-		log.Data["group_tps"] = secondCount
-	}
+		count, overLimitCount, secondCount = reqlimit.PushGroupModelTokensRequest(
+			context.Background(),
+			meta.Group.ID,
+			meta.OriginModel,
+			meta.ModelConfig.TPM,
+			int64(usage.TotalTokens),
+		)
+		middleware.UpdateGroupModelTokensRequest(c, meta.Group, count+overLimitCount, secondCount)
 
-	count, overLimitCount, secondCount = reqlimit.PushGroupModelTokennameTokensRequest(
-		context.Background(),
-		meta.Group.ID,
-		meta.OriginModel,
-		meta.Token.Name,
-		int64(usage.TotalTokens),
-	)
-	meta.Set(MetaGroupModelTokennameTPM, count+overLimitCount)
-	meta.Set(MetaGroupModelTokennameTPS, secondCount)
-	// log.Data["tpm"] = count + overLimitCount
-	// log.Data["tps"] = secondCount
+		count, overLimitCount, secondCount = reqlimit.PushGroupModelTokennameTokensRequest(
+			context.Background(),
+			meta.Group.ID,
+			meta.OriginModel,
+			meta.Token.Name,
+			int64(usage.TotalTokens),
+		)
+		middleware.UpdateGroupModelTokennameTokensRequest(c, count+overLimitCount, secondCount)
+	}
 
 	return usage, relayErr
 }
@@ -247,17 +234,17 @@ func RelayHelper(c *gin.Context, meta *meta.Meta, handel RelayHandler) (*control
 		}
 		switch {
 		case banExecution:
-			notifyChannelIssue(meta, "autoBanned", "Auto Banned", *result.Error)
+			notifyChannelIssue(c, meta, "autoBanned", "Auto Banned", *result.Error)
 		case beyondThreshold:
-			notifyChannelIssue(meta, "beyondThreshold", "Error Rate Beyond Threshold", *result.Error)
+			notifyChannelIssue(c, meta, "beyondThreshold", "Error Rate Beyond Threshold", *result.Error)
 		case !hasPermission:
-			notifyChannelIssue(meta, "channelHasPermission", "No Permission", *result.Error)
+			notifyChannelIssue(c, meta, "channelHasPermission", "No Permission", *result.Error)
 		}
 	}
 	return result, shouldRetry
 }
 
-func notifyChannelIssue(meta *meta.Meta, issueType string, titleSuffix string, err relaymodel.ErrorWithStatusCode) {
+func notifyChannelIssue(c *gin.Context, meta *meta.Meta, issueType string, titleSuffix string, err relaymodel.ErrorWithStatusCode) {
 	var notifyFunc func(title string, message string)
 
 	lockKey := fmt.Sprintf("%s:%d:%s", issueType, meta.Channel.ID, meta.OriginModel)
@@ -296,7 +283,7 @@ func notifyChannelIssue(meta *meta.Meta, issueType string, titleSuffix string, e
 			notifyFunc = notify.Error
 		}
 
-		rate := getChannelModelRequestRate(meta)
+		rate := getChannelModelRequestRate(c, meta)
 		message += fmt.Sprintf("\nrpm: %d\nrps: %d\ntpm: %d\ntps: %d", rate.RPM, rate.RPS, rate.TPM, rate.TPS)
 	}
 
@@ -555,8 +542,8 @@ func recordResult(
 		downstreamResult,
 		user,
 		metadata,
-		getChannelModelRequestRate(meta),
-		getGroupModelTokenRequestRate(c, meta),
+		getChannelModelRequestRate(c, meta),
+		middleware.GetGroupModelTokenRequestRate(c),
 	)
 }
 

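For context, here is a minimal standalone sketch (not part of the commit) of the pattern this file's refactor introduces: the new updateChannelModelRequestRate / updateChannelModelTokensRequestRate helpers record the observed channel rates on the request meta and mirror the same values into the log fields, so DoRequest, DoResponse, and getChannelModelRequestRate all go through one path instead of repeating the meta.Set + log.Data lines. The metaStore and logData types below are simplified stand-ins for the project's meta.Meta and middleware logger, not the real APIs.

package main

import "fmt"

// Simplified stand-ins for the project's request meta store and structured
// logger (the real types live in core/meta and core/middleware).
type metaStore map[string]int64
type logData map[string]int64

// updateChannelModelRequestRate mirrors the helper added by this commit: it
// stores the observed channel RPM/RPS on the request meta and mirrors the
// same values into the log fields, replacing the duplicated
// meta.Set + log.Data block that previously lived in DoRequest.
func updateChannelModelRequestRate(m metaStore, log logData, rpm, rps int64) {
	m["channel_model_rpm"] = rpm
	m["channel_model_rps"] = rps
	log["ch_rpm"] = rpm
	log["ch_rps"] = rps
}

func main() {
	m, log := metaStore{}, logData{}
	// Both DoRequest and getChannelModelRequestRate can now reuse the same
	// helper instead of setting the four values by hand.
	updateChannelModelRequestRate(m, log, 120, 2)
	fmt.Println(m, log)
}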
+ 40 - 16
core/middleware/distributor.go

@@ -96,22 +96,52 @@ func setTpmHeaders(c *gin.Context, tpm int64, remainingRequests int64) {
 	c.Header(XRateLimitResetTokens, "1m0s")
 }
 
+func UpdateGroupModelRequest(c *gin.Context, group *model.GroupCache, rpm, rps int64) {
+	if group.Status == model.GroupStatusInternal {
+		return
+	}
+
+	log := GetLogger(c)
+	log.Data["group_rpm"] = strconv.FormatInt(rpm, 10)
+	log.Data["group_rps"] = strconv.FormatInt(rps, 10)
+}
+
+func UpdateGroupModelTokensRequest(c *gin.Context, group *model.GroupCache, tpm, tps int64) {
+	if group.Status == model.GroupStatusInternal {
+		return
+	}
+
+	log := GetLogger(c)
+	log.Data["group_tpm"] = strconv.FormatInt(tpm, 10)
+	log.Data["group_tps"] = strconv.FormatInt(tps, 10)
+}
+
+func UpdateGroupModelTokennameRequest(c *gin.Context, rpm, rps int64) {
+	c.Set(GroupModelTokenRPM, rpm)
+	c.Set(GroupModelTokenRPS, rps)
+	// log := GetLogger(c)
+	// log.Data["rpm"] = strconv.FormatInt(rpm, 10)
+	// log.Data["rps"] = strconv.FormatInt(rps, 10)
+}
+
+func UpdateGroupModelTokennameTokensRequest(c *gin.Context, tpm, tps int64) {
+	c.Set(GroupModelTokenTPM, tpm)
+	c.Set(GroupModelTokenTPS, tps)
+	// log := GetLogger(c)
+	// log.Data["tpm"] = strconv.FormatInt(tpm, 10)
+	// log.Data["tps"] = strconv.FormatInt(tps, 10)
+}
+
 func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, mc *model.ModelConfig, tokenName string) error {
 	log := GetLogger(c)
 
 	adjustedModelConfig := GetGroupAdjustedModelConfig(group, *mc)
 
 	groupModelCount, groupModelOverLimitCount, groupModelSecondCount := reqlimit.PushGroupModelRequest(c.Request.Context(), group.ID, mc.Model, adjustedModelConfig.RPM)
-	if group.Status != model.GroupStatusInternal {
-		log.Data["group_rpm"] = strconv.FormatInt(groupModelCount+groupModelOverLimitCount, 10)
-		log.Data["group_rps"] = strconv.FormatInt(groupModelSecondCount, 10)
-	}
+	UpdateGroupModelRequest(c, group, groupModelCount+groupModelOverLimitCount, groupModelSecondCount)
 
 	groupModelTokenCount, groupModelTokenOverLimitCount, groupModelTokenSecondCount := reqlimit.PushGroupModelTokennameRequest(c.Request.Context(), group.ID, mc.Model, tokenName)
-	c.Set(GroupModelTokenRPM, groupModelTokenCount+groupModelTokenOverLimitCount)
-	c.Set(GroupModelTokenRPS, groupModelTokenSecondCount)
-	// log.Data["rpm"] = strconv.FormatInt(groupModelTokenCount+groupModelTokenOverLimitCount, 10)
-	// log.Data["rps"] = strconv.FormatInt(groupModelTokenSecondCount, 10)
+	UpdateGroupModelTokennameRequest(c, groupModelTokenCount+groupModelTokenOverLimitCount, groupModelTokenSecondCount)
 
 	if group.Status != model.GroupStatusInternal &&
 		adjustedModelConfig.RPM > 0 {
@@ -124,16 +154,10 @@ func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, mc *model
 	}
 
 	groupModelCountTPM, groupModelCountTPS := reqlimit.GetGroupModelTokensRequest(c.Request.Context(), group.ID, mc.Model)
-	if group.Status != model.GroupStatusInternal {
-		log.Data["group_tpm"] = strconv.FormatInt(groupModelCountTPM, 10)
-		log.Data["group_tps"] = strconv.FormatInt(groupModelCountTPS, 10)
-	}
+	UpdateGroupModelTokensRequest(c, group, groupModelCountTPM, groupModelCountTPS)
 
 	groupModelTokenCountTPM, groupModelTokenCountTPS := reqlimit.GetGroupModelTokennameTokensRequest(c.Request.Context(), group.ID, mc.Model, tokenName)
-	c.Set(GroupModelTokenTPM, groupModelTokenCountTPM)
-	c.Set(GroupModelTokenTPS, groupModelTokenCountTPS)
-	// log.Data["tpm"] = strconv.FormatInt(groupModelTokenCountTPM, 10)
-	// log.Data["tps"] = strconv.FormatInt(groupModelTokenCountTPS, 10)
+	UpdateGroupModelTokennameTokensRequest(c, groupModelTokenCountTPM, groupModelTokenCountTPS)
 
 	if group.Status != model.GroupStatusInternal &&
 		adjustedModelConfig.TPM > 0 {
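
Taken together, the two files move the group/token rate bookkeeping into named helpers: distributor.go writes the per-token values onto the gin context, and relay-controller.go later reads them back through middleware.GetGroupModelTokenRequestRate when recording the result. Below is a minimal standalone sketch of that round trip with simplified stand-in types; the body of getGroupModelTokenRequestRate shown here is an assumption, since the real implementation is not part of this diff.

package main

import "fmt"

// requestRate mirrors the RPM/RPS/TPM/TPS fields of model.RequestRate used
// in recordResult.
type requestRate struct {
	RPM, RPS, TPM, TPS int64
}

// ctx is a simplified stand-in for gin.Context key/value storage; the real
// code calls c.Set with the GroupModelToken* keys.
type ctx map[string]int64

// These two functions mirror the helpers added to distributor.go: the
// middleware stashes the per-token rates on the request context during the
// rate-limit check and again after the response's token usage is pushed.
func updateGroupModelTokennameRequest(c ctx, rpm, rps int64) {
	c["GroupModelTokenRPM"] = rpm
	c["GroupModelTokenRPS"] = rps
}

func updateGroupModelTokennameTokensRequest(c ctx, tpm, tps int64) {
	c["GroupModelTokenTPM"] = tpm
	c["GroupModelTokenTPS"] = tps
}

// getGroupModelTokenRequestRate is a hypothetical reconstruction of
// middleware.GetGroupModelTokenRequestRate, which relay-controller.go now
// calls from recordResult; it plausibly just assembles the four stored
// values into a RequestRate.
func getGroupModelTokenRequestRate(c ctx) requestRate {
	return requestRate{
		RPM: c["GroupModelTokenRPM"],
		RPS: c["GroupModelTokenRPS"],
		TPM: c["GroupModelTokenTPM"],
		TPS: c["GroupModelTokenTPS"],
	}
}

func main() {
	c := ctx{}
	updateGroupModelTokennameRequest(c, 30, 1)
	updateGroupModelTokennameTokensRequest(c, 4500, 80)
	fmt.Printf("%+v\n", getGroupModelTokenRequestRate(c))
}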