|
|
@@ -17,10 +17,14 @@ import (
|
|
|
"github.com/labring/aiproxy/core/common/config"
|
|
|
"github.com/labring/aiproxy/core/common/consume"
|
|
|
"github.com/labring/aiproxy/core/common/notify"
|
|
|
+ "github.com/labring/aiproxy/core/common/reqlimit"
|
|
|
"github.com/labring/aiproxy/core/common/trylock"
|
|
|
"github.com/labring/aiproxy/core/middleware"
|
|
|
"github.com/labring/aiproxy/core/model"
|
|
|
"github.com/labring/aiproxy/core/monitor"
|
|
|
+ "github.com/labring/aiproxy/core/relay/adaptor"
|
|
|
+ "github.com/labring/aiproxy/core/relay/adaptor/openai"
|
|
|
+ "github.com/labring/aiproxy/core/relay/adaptors"
|
|
|
"github.com/labring/aiproxy/core/relay/controller"
|
|
|
"github.com/labring/aiproxy/core/relay/meta"
|
|
|
"github.com/labring/aiproxy/core/relay/mode"
|
|
|
@@ -31,7 +35,7 @@ import (
|
|
|
// https://platform.openai.com/docs/api-reference/chat
|
|
|
|
|
|
type (
|
|
|
- RelayHandler func(*meta.Meta, *gin.Context) *controller.HandleResult
|
|
|
+ RelayHandler func(*gin.Context, *meta.Meta) *controller.HandleResult
|
|
|
GetRequestUsage func(*gin.Context, *model.ModelConfig) (model.Usage, error)
|
|
|
GetRequestPrice func(*gin.Context, *model.ModelConfig) (model.Price, error)
|
|
|
)
|
|
|
@@ -42,10 +46,137 @@ type RelayController struct {
|
|
|
Handler RelayHandler
|
|
|
}
|
|
|
|
|
|
-func relayHandler(meta *meta.Meta, c *gin.Context) *controller.HandleResult {
|
|
|
// ErrInvalidChannelTypeCode is the error code attached to the relay error
// returned when no adaptor is registered for a channel's type. It is a plain
// string code (not an error value) and is declared as a constant so it cannot
// be reassigned at runtime.
const ErrInvalidChannelTypeCode = "invalid_channel_type"
|
|
|
+
|
|
|
+type warpAdaptor struct {
|
|
|
+ adaptor.Adaptor
|
|
|
+}
|
|
|
+
|
|
|
// Meta keys under which the per-channel, per-model counters returned by the
// reqlimit push calls are stored on the request meta.
const (
	MetaChannelModelKeyRPM = "channel_model_rpm"
	MetaChannelModelKeyRPS = "channel_model_rps"
	MetaChannelModelKeyTPM = "channel_model_tpm"
	MetaChannelModelKeyTPS = "channel_model_tps"
)

// Meta keys under which the per-group, per-model, per-token-name token
// counters are stored on the request meta.
const (
	MetaGroupModelTokennameTPM = "group_model_tokenname_tpm"
	MetaGroupModelTokennameTPS = "group_model_tokenname_tps"
)
|
|
|
+
|
|
|
+func getChannelModelRequestRate(meta *meta.Meta) model.RequestRate {
|
|
|
+ rate := model.RequestRate{}
|
|
|
+
|
|
|
+ if rpm, ok := meta.Get(MetaChannelModelKeyRPM); ok {
|
|
|
+ rate.RPM, _ = rpm.(int64)
|
|
|
+ rate.RPS = meta.GetInt64(MetaChannelModelKeyRPS)
|
|
|
+ } else {
|
|
|
+ rpm, rps := reqlimit.GetChannelModelRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
|
|
|
+ rate.RPM = rpm
|
|
|
+ rate.RPS = rps
|
|
|
+ }
|
|
|
+
|
|
|
+ if tpm, ok := meta.Get(MetaChannelModelKeyTPM); ok {
|
|
|
+ rate.TPM, _ = tpm.(int64)
|
|
|
+ rate.TPS = meta.GetInt64(MetaChannelModelKeyTPS)
|
|
|
+ } else {
|
|
|
+ tpm, tps := reqlimit.GetChannelModelTokensRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
|
|
|
+ rate.TPM = tpm
|
|
|
+ rate.TPS = tps
|
|
|
+ }
|
|
|
+
|
|
|
+ return rate
|
|
|
+}
|
|
|
+
|
|
|
+func getGroupModelTokenRequestRate(c *gin.Context, meta *meta.Meta) model.RequestRate {
|
|
|
+ r := model.RequestRate{
|
|
|
+ RPM: middleware.GetGroupModelTokenRPM(c),
|
|
|
+ RPS: middleware.GetGroupModelTokenRPS(c),
|
|
|
+ TPM: middleware.GetGroupModelTokenTPM(c),
|
|
|
+ TPS: middleware.GetGroupModelTokenTPS(c),
|
|
|
+ }
|
|
|
+
|
|
|
+ if tpm, ok := meta.Get(MetaGroupModelTokennameTPM); ok {
|
|
|
+ r.TPM, _ = tpm.(int64)
|
|
|
+ r.TPS = meta.GetInt64(MetaGroupModelTokennameTPS)
|
|
|
+ }
|
|
|
+
|
|
|
+ return r
|
|
|
+}
|
|
|
+
|
|
|
+func (w *warpAdaptor) DoRequest(meta *meta.Meta, c *gin.Context, req *http.Request) (*http.Response, error) {
|
|
|
+ count, overLimitCount, secondCount := reqlimit.PushChannelModelRequest(
|
|
|
+ context.Background(),
|
|
|
+ strconv.Itoa(meta.Channel.ID),
|
|
|
+ meta.OriginModel,
|
|
|
+ )
|
|
|
+ log := middleware.GetLogger(c)
|
|
|
+ meta.Set(MetaChannelModelKeyRPM, count+overLimitCount)
|
|
|
+ meta.Set(MetaChannelModelKeyRPS, secondCount)
|
|
|
+ log.Data["ch_rpm"] = count + overLimitCount
|
|
|
+ log.Data["ch_rps"] = secondCount
|
|
|
+ return w.Adaptor.DoRequest(meta, c, req)
|
|
|
+}
|
|
|
+
|
|
|
+func (w *warpAdaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage, *relaymodel.ErrorWithStatusCode) {
|
|
|
+ usage, relayErr := w.Adaptor.DoResponse(meta, c, resp)
|
|
|
+ if usage == nil {
|
|
|
+ return nil, relayErr
|
|
|
+ }
|
|
|
+
|
|
|
+ count, overLimitCount, secondCount := reqlimit.PushChannelModelTokensRequest(
|
|
|
+ context.Background(),
|
|
|
+ strconv.Itoa(meta.Channel.ID),
|
|
|
+ meta.OriginModel,
|
|
|
+ int64(usage.TotalTokens),
|
|
|
+ )
|
|
|
+ log := middleware.GetLogger(c)
|
|
|
+ meta.Set(MetaChannelModelKeyTPM, count+overLimitCount)
|
|
|
+ meta.Set(MetaChannelModelKeyTPS, secondCount)
|
|
|
+ log.Data["ch_tpm"] = count + overLimitCount
|
|
|
+ log.Data["ch_tps"] = secondCount
|
|
|
+
|
|
|
+ count, overLimitCount, secondCount = reqlimit.PushGroupModelTokensRequest(
|
|
|
+ context.Background(),
|
|
|
+ meta.Group.ID,
|
|
|
+ meta.OriginModel,
|
|
|
+ meta.ModelConfig.TPM,
|
|
|
+ int64(usage.TotalTokens),
|
|
|
+ )
|
|
|
+ if meta.Group.Status != model.GroupStatusInternal {
|
|
|
+ log.Data["group_tpm"] = count + overLimitCount
|
|
|
+ log.Data["group_tps"] = secondCount
|
|
|
+ }
|
|
|
+
|
|
|
+ count, overLimitCount, secondCount = reqlimit.PushGroupModelTokennameTokensRequest(
|
|
|
+ context.Background(),
|
|
|
+ meta.Group.ID,
|
|
|
+ meta.OriginModel,
|
|
|
+ meta.Token.Name,
|
|
|
+ int64(usage.TotalTokens),
|
|
|
+ )
|
|
|
+ meta.Set(MetaGroupModelTokennameTPM, count+overLimitCount)
|
|
|
+ meta.Set(MetaGroupModelTokennameTPS, secondCount)
|
|
|
+ // log.Data["tpm"] = count + overLimitCount
|
|
|
+ // log.Data["tps"] = secondCount
|
|
|
+
|
|
|
+ return usage, relayErr
|
|
|
+}
|
|
|
+
|
|
|
+func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
|
|
|
log := middleware.GetLogger(c)
|
|
|
middleware.SetLogFieldsFromMeta(meta, log.Data)
|
|
|
- return controller.Handle(meta, c)
|
|
|
+
|
|
|
+ adaptor, ok := adaptors.GetAdaptor(meta.Channel.Type)
|
|
|
+ if !ok {
|
|
|
+ return &controller.HandleResult{
|
|
|
+ Error: openai.ErrorWrapperWithMessage(
|
|
|
+ fmt.Sprintf("invalid channel type: %d", meta.Channel.Type),
|
|
|
+ ErrInvalidChannelTypeCode,
|
|
|
+ http.StatusInternalServerError,
|
|
|
+ ),
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return controller.Handle(&warpAdaptor{adaptor}, c, meta)
|
|
|
}
|
|
|
|
|
|
func relayController(m mode.Mode) RelayController {
|
|
|
@@ -87,8 +218,8 @@ func relayController(m mode.Mode) RelayController {
|
|
|
return c
|
|
|
}
|
|
|
|
|
|
-func RelayHelper(meta *meta.Meta, c *gin.Context, handel RelayHandler) (*controller.HandleResult, bool) {
|
|
|
- result := handel(meta, c)
|
|
|
+func RelayHelper(c *gin.Context, meta *meta.Meta, handel RelayHandler) (*controller.HandleResult, bool) {
|
|
|
+ result := handel(c, meta)
|
|
|
if result.Error == nil {
|
|
|
if _, _, err := monitor.AddRequest(
|
|
|
context.Background(),
|
|
|
@@ -165,20 +296,8 @@ func notifyChannelIssue(meta *meta.Meta, issueType string, titleSuffix string, e
|
|
|
notifyFunc = notify.Error
|
|
|
}
|
|
|
|
|
|
- now := time.Now()
|
|
|
- group := "*"
|
|
|
- rpm, rpmErr := model.GetRPM(group, now, "", meta.OriginModel, meta.Channel.ID)
|
|
|
- tpm, tpmErr := model.GetTPM(group, now, "", meta.OriginModel, meta.Channel.ID)
|
|
|
- if rpmErr != nil {
|
|
|
- message += fmt.Sprintf("\nrpm: %v", rpmErr)
|
|
|
- } else {
|
|
|
- message += fmt.Sprintf("\nrpm: %d", rpm)
|
|
|
- }
|
|
|
- if tpmErr != nil {
|
|
|
- message += fmt.Sprintf("\ntpm: %v", tpmErr)
|
|
|
- } else {
|
|
|
- message += fmt.Sprintf("\ntpm: %d", tpm)
|
|
|
- }
|
|
|
+ rate := getChannelModelRequestRate(meta)
|
|
|
+ message += fmt.Sprintf("\nrpm: %d\nrps: %d\ntpm: %d\ntps: %d", rate.RPM, rate.RPS, rate.TPM, rate.TPS)
|
|
|
}
|
|
|
|
|
|
notifyFunc(
|
|
|
@@ -348,7 +467,7 @@ func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
|
|
|
}
|
|
|
|
|
|
// First attempt
|
|
|
- result, retry := RelayHelper(meta, c, relayController.Handler)
|
|
|
+ result, retry := RelayHelper(c, meta, relayController.Handler)
|
|
|
|
|
|
retryTimes := int(config.GetRetryTimes())
|
|
|
if mc.RetryTimes > 0 {
|
|
|
@@ -436,6 +555,8 @@ func recordResult(
|
|
|
downstreamResult,
|
|
|
user,
|
|
|
metadata,
|
|
|
+ getChannelModelRequestRate(meta),
|
|
|
+ getGroupModelTokenRequestRate(c, meta),
|
|
|
)
|
|
|
}
|
|
|
|
|
|
@@ -603,7 +724,7 @@ func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayControlle
|
|
|
meta.WithRetryAt(time.Now()),
|
|
|
)
|
|
|
var retry bool
|
|
|
- state.result, retry = RelayHelper(state.meta, c, relayController)
|
|
|
+ state.result, retry = RelayHelper(c, state.meta, relayController)
|
|
|
|
|
|
done := handleRetryResult(c, retry, newChannel, state)
|
|
|
if done || i == state.retryTimes-1 {
|
|
|
@@ -693,7 +814,7 @@ var channelNoRetryStatusCodesMap = map[int]struct{}{
|
|
|
|
|
|
// 仅当是channel错误时,才需要记录,用户请求参数错误时,不需要记录
|
|
|
func shouldRetry(_ *gin.Context, relayErr relaymodel.ErrorWithStatusCode) bool {
|
|
|
- if relayErr.Error.Code == controller.ErrInvalidChannelTypeCode {
|
|
|
+ if relayErr.Error.Code == ErrInvalidChannelTypeCode {
|
|
|
return false
|
|
|
}
|
|
|
_, ok := channelNoRetryStatusCodesMap[relayErr.StatusCode]
|
|
|
@@ -708,7 +829,7 @@ var channelNoPermissionStatusCodesMap = map[int]struct{}{
|
|
|
}
|
|
|
|
|
|
func channelHasPermission(relayErr relaymodel.ErrorWithStatusCode) bool {
|
|
|
- if relayErr.Error.Code == controller.ErrInvalidChannelTypeCode {
|
|
|
+ if relayErr.Error.Code == ErrInvalidChannelTypeCode {
|
|
|
return false
|
|
|
}
|
|
|
_, ok := channelNoPermissionStatusCodesMap[relayErr.StatusCode]
|