Просмотр исходного кода

feat: add dashboard quota api (#386)

* feat: add dashboard quota api

* fix: ci lint

* fix: need check token quota limit
zijiren 2 месяцев назад
Родитель
Сommit
043da28754

+ 11 - 0
core/common/balance/balance.go

@@ -11,12 +11,19 @@ type GroupBalance interface {
 		ctx context.Context,
 		group model.GroupCache,
 	) (float64, PostGroupConsumer, error)
+
+	GetGroupQuota(ctx context.Context, group model.GroupCache) (*GroupQuota, error)
 }
 
 type PostGroupConsumer interface {
 	PostGroupConsume(ctx context.Context, tokenName string, usage float64) (float64, error)
 }
 
+type GroupQuota struct {
+	Total  float64 `json:"total"`
+	Remain float64 `json:"remain"`
+}
+
 var (
 	mock    GroupBalance = NewMockGroupBalance()
 	Default              = mock
@@ -35,3 +42,7 @@ func GetGroupRemainBalance(
 ) (float64, PostGroupConsumer, error) {
 	return Default.GetGroupRemainBalance(ctx, group)
 }
+
+func GetGroupQuota(ctx context.Context, group model.GroupCache) (*GroupQuota, error) {
+	return Default.GetGroupQuota(ctx, group)
+}

+ 10 - 0
core/common/balance/mock.go

@@ -25,6 +25,16 @@ func (q *MockGroupBalance) GetGroupRemainBalance(
 	return mockBalance, q, nil
 }
 
+func (q *MockGroupBalance) GetGroupQuota(
+	_ context.Context,
+	_ model.GroupCache,
+) (*GroupQuota, error) {
+	return &GroupQuota{
+		Total:  mockBalance,
+		Remain: mockBalance,
+	}, nil
+}
+
 func (q *MockGroupBalance) PostGroupConsume(
 	_ context.Context,
 	_ string,

+ 47 - 13
core/common/balance/sealos.go

@@ -89,9 +89,12 @@ func newSealosToken(key string) (string, error) {
 }
 
 type sealosGetGroupBalanceResp struct {
-	UserUID string `json:"userUID"`
-	Error   string `json:"error"`
-	Balance int64  `json:"balance"`
+	UserUID               string `json:"userUID"`
+	Error                 string `json:"error"`
+	Balance               int64  `json:"balance"`
+	WorkspaceSubscription bool   `json:"workspaceSubscription"`
+	TotalAIQuota          int64  `json:"totalAIQuota"`
+	RemainAIQuota         int64  `json:"remainAIQuota"`
 }
 
 type sealosPostGroupConsumeReq struct {
@@ -295,22 +298,28 @@ func (s *Sealos) getGroupRemainBalance(ctx context.Context, group string) (int64
 		log.Errorf("get group (%s) balance cache failed: %s", group, err)
 	}
 
-	balance, userUID, err := s.fetchBalanceFromAPI(ctx, group)
+	balance, err := s.fetchBalanceFromAPI(ctx, group)
 	if err != nil {
 		return 0, "", err
 	}
 
-	if err := cacheSetGroupBalance(ctx, group, balance, userUID); err != nil {
+	if err := cacheSetGroupBalance(ctx, group, balance.balance, balance.userUID); err != nil {
 		log.Errorf("set group (%s) balance cache failed: %s", group, err)
 	}
 
-	return balance, userUID, nil
+	return balance.balance, balance.userUID, nil
+}
+
+type sealosBalanceResoult struct {
+	quota   int64
+	balance int64
+	userUID string
 }
 
 func (s *Sealos) fetchBalanceFromAPI(
 	ctx context.Context,
 	group string,
-) (balance int64, userUID string, err error) {
+) (*sealosBalanceResoult, error) {
 	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
 	defer cancel()
 
@@ -325,35 +334,60 @@ func (s *Sealos) fetchBalanceFromAPI(
 		nil,
 	)
 	if err != nil {
-		return 0, "", err
+		return nil, err
 	}
 
 	req.Header.Set("Authorization", "Bearer "+jwtToken)
 
 	resp, err := sealosHTTPClient.Do(req)
 	if err != nil {
-		return 0, "", err
+		return nil, err
 	}
 	defer resp.Body.Close()
 
 	var sealosResp sealosGetGroupBalanceResp
 	if err := sonic.ConfigDefault.NewDecoder(resp.Body).Decode(&sealosResp); err != nil {
-		return 0, "", err
+		return nil, err
 	}
 
 	if sealosResp.Error != "" {
-		return 0, "", errors.New(sealosResp.Error)
+		return nil, errors.New(sealosResp.Error)
 	}
 
 	if resp.StatusCode != http.StatusOK {
-		return 0, "", fmt.Errorf(
+		return nil, fmt.Errorf(
 			"get group (%s) balance failed with status code %d",
 			group,
 			resp.StatusCode,
 		)
 	}
 
-	return sealosResp.Balance, sealosResp.UserUID, nil
+	if sealosResp.WorkspaceSubscription {
+		return &sealosBalanceResoult{
+			quota:   sealosResp.TotalAIQuota,
+			balance: sealosResp.RemainAIQuota,
+			userUID: sealosResp.UserUID,
+		}, nil
+	}
+
+	return &sealosBalanceResoult{
+		quota:   sealosResp.Balance,
+		balance: sealosResp.Balance,
+		userUID: sealosResp.UserUID,
+	}, nil
+}
+
+// GetGroupQuota implements GroupBalance.
+func (s *Sealos) GetGroupQuota(ctx context.Context, group model.GroupCache) (*GroupQuota, error) {
+	balance, err := s.fetchBalanceFromAPI(ctx, group.ID)
+	if err != nil {
+		return nil, err
+	}
+
+	return &GroupQuota{
+		Total:  decimal.NewFromInt(balance.quota).Div(decimalBalancePrecision).InexactFloat64(),
+		Remain: decimal.NewFromInt(balance.balance).Div(decimalBalancePrecision).InexactFloat64(),
+	}, nil
 }
 
 type SealosPostGroupConsumer struct {

+ 55 - 6
core/controller/relay-dashboard.go

@@ -9,9 +9,45 @@ import (
 	"github.com/labring/aiproxy/core/common/balance"
 	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/relay/adaptor/openai"
+	"github.com/shopspring/decimal"
 	log "github.com/sirupsen/logrus"
 )
 
+// GetQuota godoc
+//
+//	@Summary		Get quota
+//	@Description	Get quota
+//	@Tags			relay
+//	@Produce		json
+//	@Security		ApiKeyAuth
+//	@Success		200	{object}	balance.GroupQuota
+//	@Router			/v1/dashboard/billing/quota [get]
+func GetQuota(c *gin.Context) {
+	group := middleware.GetGroup(c)
+
+	groupQuota, err := balance.GetGroupQuota(c.Request.Context(), group)
+	if err != nil {
+		log.Errorf("get group (%s) balance failed: %s", group.ID, err)
+		middleware.ErrorResponse(
+			c,
+			http.StatusInternalServerError,
+			fmt.Sprintf("get group (%s) balance failed", group.ID),
+		)
+
+		return
+	}
+
+	token := middleware.GetToken(c)
+	if token.Quota > 0 {
+		groupQuota.Total = min(groupQuota.Total, token.Quota)
+		groupQuota.Remain = min(groupQuota.Remain, decimal.NewFromFloat(token.Quota).
+			Sub(decimal.NewFromFloat(token.UsedAmount)).
+			InexactFloat64())
+	}
+
+	c.JSON(http.StatusOK, groupQuota)
+}
+
 // GetSubscription godoc
 //
 //	@Summary		Get subscription
@@ -24,7 +60,7 @@ import (
 func GetSubscription(c *gin.Context) {
 	group := middleware.GetGroup(c)
 
-	b, _, err := balance.GetGroupRemainBalance(c.Request.Context(), group)
+	groupQuota, err := balance.GetGroupQuota(c.Request.Context(), group)
 	if err != nil {
 		if errors.Is(err, balance.ErrNoRealNameUsedAmountLimit) {
 			middleware.ErrorResponse(c, http.StatusForbidden, err.Error())
@@ -45,13 +81,19 @@ func GetSubscription(c *gin.Context) {
 
 	quota := token.Quota
 	if quota <= 0 {
-		quota = b
+		quota = groupQuota.Total
+	} else {
+		quota = min(quota, groupQuota.Total)
 	}
 
+	hlimit := decimal.NewFromFloat(quota).
+		Add(decimal.NewFromFloat(token.UsedAmount)).
+		InexactFloat64()
+
 	c.JSON(http.StatusOK, openai.SubscriptionResponse{
-		HardLimitUSD:       quota + token.UsedAmount,
-		SoftLimitUSD:       b,
-		SystemHardLimitUSD: quota + token.UsedAmount,
+		HardLimitUSD:       hlimit,
+		SoftLimitUSD:       groupQuota.Remain,
+		SystemHardLimitUSD: hlimit,
 	})
 }
 
@@ -66,5 +108,12 @@ func GetSubscription(c *gin.Context) {
 //	@Router			/v1/dashboard/billing/usage [get]
 func GetUsage(c *gin.Context) {
 	token := middleware.GetToken(c)
-	c.JSON(http.StatusOK, openai.UsageResponse{TotalUsage: token.UsedAmount * 100})
+	c.JSON(
+		http.StatusOK,
+		openai.UsageResponse{
+			TotalUsage: decimal.NewFromFloat(token.UsedAmount).
+				Mul(decimal.NewFromFloat(100)).
+				InexactFloat64(),
+		},
+	)
 }

+ 36 - 0
core/docs/docs.go

@@ -7397,6 +7397,31 @@ const docTemplate = `{
                 }
             }
         },
+        "/v1/dashboard/billing/quota": {
+            "get": {
+                "security": [
+                    {
+                        "ApiKeyAuth": []
+                    }
+                ],
+                "description": "Get quota",
+                "produces": [
+                    "application/json"
+                ],
+                "tags": [
+                    "relay"
+                ],
+                "summary": "Get quota",
+                "responses": {
+                    "200": {
+                        "description": "OK",
+                        "schema": {
+                            "$ref": "#/definitions/balance.GroupQuota"
+                        }
+                    }
+                }
+            }
+        },
         "/v1/dashboard/billing/subscription": {
             "get": {
                 "security": [
@@ -8424,6 +8449,17 @@ const docTemplate = `{
                 }
             }
         },
+        "balance.GroupQuota": {
+            "type": "object",
+            "properties": {
+                "remain": {
+                    "type": "number"
+                },
+                "total": {
+                    "type": "number"
+                }
+            }
+        },
         "controller.AddChannelRequest": {
             "type": "object",
             "properties": {

+ 36 - 0
core/docs/swagger.json

@@ -7388,6 +7388,31 @@
                 }
             }
         },
+        "/v1/dashboard/billing/quota": {
+            "get": {
+                "security": [
+                    {
+                        "ApiKeyAuth": []
+                    }
+                ],
+                "description": "Get quota",
+                "produces": [
+                    "application/json"
+                ],
+                "tags": [
+                    "relay"
+                ],
+                "summary": "Get quota",
+                "responses": {
+                    "200": {
+                        "description": "OK",
+                        "schema": {
+                            "$ref": "#/definitions/balance.GroupQuota"
+                        }
+                    }
+                }
+            }
+        },
         "/v1/dashboard/billing/subscription": {
             "get": {
                 "security": [
@@ -8415,6 +8440,17 @@
                 }
             }
         },
+        "balance.GroupQuota": {
+            "type": "object",
+            "properties": {
+                "remain": {
+                    "type": "number"
+                },
+                "total": {
+                    "type": "number"
+                }
+            }
+        },
         "controller.AddChannelRequest": {
             "type": "object",
             "properties": {

+ 28 - 6
core/docs/swagger.yaml

@@ -42,6 +42,13 @@ definitions:
       name:
         type: string
     type: object
+  balance.GroupQuota:
+    properties:
+      remain:
+        type: number
+      total:
+        type: number
+    type: object
   controller.AddChannelRequest:
     properties:
       base_url:
@@ -1466,6 +1473,8 @@ definitions:
         type: integer
       rpm:
         type: integer
+      status_5xx_count:
+        type: integer
       status_400_count:
         type: integer
       status_429_count:
@@ -1474,8 +1483,6 @@ definitions:
         type: integer
       status_500_count:
         type: integer
-      status_5xx_count:
-        type: integer
       token_names:
         items:
           type: string
@@ -2412,16 +2419,16 @@ definitions:
         type: integer
       retry_count:
         type: integer
+      status_400_count:
+        type: integer
+      status_429_count:
+        type: integer
       status_4xx_count:
         type: integer
       status_500_count:
         type: integer
       status_5xx_count:
         type: integer
-      status_400_count:
-        type: integer
-      status_429_count:
-        type: integer
       timestamp:
         type: integer
       token_name:
@@ -7176,6 +7183,21 @@ paths:
       summary: Completions
       tags:
       - relay
+  /v1/dashboard/billing/quota:
+    get:
+      description: Get quota
+      produces:
+      - application/json
+      responses:
+        "200":
+          description: OK
+          schema:
+            $ref: '#/definitions/balance.GroupQuota'
+      security:
+      - ApiKeyAuth: []
+      summary: Get quota
+      tags:
+      - relay
   /v1/dashboard/billing/subscription:
     get:
       description: Get subscription

+ 1 - 0
core/router/relay.go

@@ -21,6 +21,7 @@ func SetRelayRouter(router *gin.Engine) {
 	{
 		dashboardRouter.GET("/billing/subscription", controller.GetSubscription)
 		dashboardRouter.GET("/billing/usage", controller.GetUsage)
+		dashboardRouter.GET("/billing/quota", controller.GetQuota)
 	}
 
 	relayRouter := v1Router.Group("")