Browse Source

feat: support update AIProxy balance (#171)

* Add: support update AIProxy balance

* fix auth header

* chore: update balance renderer

---------

Co-authored-by: JustSong <[email protected]>
Joe 2 years ago
parent
commit
b7d71b4f0a
2 changed files with 60 additions and 12 deletions
  1. 47 7
      controller/channel-billing.go
  2. 13 5
      web/src/components/ChannelsTable.js

+ 47 - 7
controller/channel-billing.go

@@ -4,13 +4,14 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
 	"strconv"
 	"time"
+
+	"github.com/gin-gonic/gin"
 )
 
 // https://github.com/songquanpeng/one-api/issues/79
@@ -44,14 +45,31 @@ type OpenAISBUsageResponse struct {
 	} `json:"data"`
 }
 
-func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error) {
+type AIProxyUserOverviewResponse struct {
+	Success   bool   `json:"success"`
+	Message   string `json:"message"`
+	ErrorCode int    `json:"error_code"`
+	Data      struct {
+		TotalPoints float64 `json:"totalPoints"`
+	} `json:"data"`
+}
+
+// GetAuthHeader get auth header
+func GetAuthHeader(token string) http.Header {
+	h := http.Header{}
+	h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
+	return h
+}
+
+func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
 	client := &http.Client{}
 	req, err := http.NewRequest(method, url, nil)
 	if err != nil {
 		return nil, err
 	}
-	auth := fmt.Sprintf("Bearer %s", channel.Key)
-	req.Header.Add("Authorization", auth)
+	for k := range headers {
+		req.Header.Add(k, headers.Get(k))
+	}
 	res, err := client.Do(req)
 	if err != nil {
 		return nil, err
@@ -69,7 +87,7 @@ func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error)
 
 func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
 	url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
-	body, err := GetResponseBody("GET", url, channel)
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 	if err != nil {
 		return 0, err
 	}
@@ -89,6 +107,26 @@ func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
 	return balance, nil
 }
 
+func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
+	url := "https://aiproxy.io/api/report/getUserOverview"
+	headers := http.Header{}
+	headers.Add("Api-Key", channel.Key)
+	body, err := GetResponseBody("GET", url, channel, headers)
+	if err != nil {
+		return 0, err
+	}
+	response := AIProxyUserOverviewResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	if !response.Success {
+		return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
+	}
+	channel.UpdateBalance(response.Data.TotalPoints)
+	return response.Data.TotalPoints, nil
+}
+
 func updateChannelBalance(channel *model.Channel) (float64, error) {
 	baseURL := common.ChannelBaseURLs[channel.Type]
 	switch channel.Type {
@@ -102,12 +140,14 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 		baseURL = channel.BaseURL
 	case common.ChannelTypeOpenAISB:
 		return updateChannelOpenAISBBalance(channel)
+	case common.ChannelTypeAIProxy:
+		return updateChannelAIProxyBalance(channel)
 	default:
 		return 0, errors.New("尚未实现")
 	}
 	url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
 
-	body, err := GetResponseBody("GET", url, channel)
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 	if err != nil {
 		return 0, err
 	}
@@ -123,7 +163,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 		startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
 	}
 	url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
-	body, err = GetResponseBody("GET", url, channel)
+	body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 	if err != nil {
 		return 0, err
 	}

+ 13 - 5
web/src/components/ChannelsTable.js

@@ -4,7 +4,7 @@ import { Link } from 'react-router-dom';
 import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
 
 import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
-import { renderGroup } from '../helpers/render';
+import { renderGroup, renderNumber } from '../helpers/render';
 
 function renderTimestamp(timestamp) {
   return (
@@ -28,10 +28,17 @@ function renderType(type) {
 }
 
 function renderBalance(type, balance) {
-  if (type === 5) {
-    return <span>¥{(balance / 10000).toFixed(2)}</span>
+  switch (type) {
+    case 1: // OpenAI
+    case 8: // 自定义
+      return <span>${balance.toFixed(2)}</span>;
+    case 5: // OpenAI-SB
+      return <span>¥{(balance / 10000).toFixed(2)}</span>;
+    case 10: // AI Proxy
+      return <span>{renderNumber(balance)}</span>;
+    default:
+      return <span>不支持</span>;
   }
-  return <span>${balance.toFixed(2)}</span>
 }
 
 const ChannelsTable = () => {
@@ -422,7 +429,8 @@ const ChannelsTable = () => {
               <Button size='small' loading={loading} onClick={testAllChannels}>
                 测试所有已启用通道
               </Button>
-              <Button size='small' onClick={updateAllChannelsBalance} loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
+              <Button size='small' onClick={updateAllChannelsBalance}
+                      loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
               <Pagination
                 floated='right'
                 activePage={activePage}