Jelajahi Sumber

feat: support channel remain quota query (close #79)

JustSong 2 tahun lalu
induk
melakukan
171b818504
4 mengubah file dengan 244 tambahan dan 19 penghapusan
  1. 158 0
      controller/channel-billing.go
  2. 23 11
      model/channel.go
  3. 2 0
      router/api-router.go
  4. 61 8
      web/src/components/ChannelsTable.js

+ 158 - 0
controller/channel-billing.go

@@ -0,0 +1,158 @@
+package controller
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+	"strconv"
+	"time"
+)
+
+type OpenAISubscriptionResponse struct {
+	HasPaymentMethod bool    `json:"has_payment_method"`
+	HardLimitUSD     float64 `json:"hard_limit_usd"`
+}
+
+type OpenAIUsageResponse struct {
+	TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
+}
+
+func updateChannelBalance(channel *model.Channel) (float64, error) {
+	baseURL := common.ChannelBaseURLs[channel.Type]
+	switch channel.Type {
+	case common.ChannelTypeAzure:
+		return 0, errors.New("尚未实现")
+	}
+	url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
+
+	client := &http.Client{}
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		return 0, err
+	}
+	auth := fmt.Sprintf("Bearer %s", channel.Key)
+	req.Header.Add("Authorization", auth)
+	res, err := client.Do(req)
+	if err != nil {
+		return 0, err
+	}
+	body, err := io.ReadAll(res.Body)
+	if err != nil {
+		return 0, err
+	}
+	err = res.Body.Close()
+	if err != nil {
+		return 0, err
+	}
+	subscription := OpenAISubscriptionResponse{}
+	err = json.Unmarshal(body, &subscription)
+	if err != nil {
+		return 0, err
+	}
+	now := time.Now()
+	startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
+	//endDate := now.Format("2006-01-02")
+	url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, "2023-06-01")
+	req, err = http.NewRequest("GET", url, nil)
+	if err != nil {
+		return 0, err
+	}
+	req.Header.Add("Authorization", auth)
+	res, err = client.Do(req)
+	if err != nil {
+		return 0, err
+	}
+	body, err = io.ReadAll(res.Body)
+	if err != nil {
+		return 0, err
+	}
+	err = res.Body.Close()
+	if err != nil {
+		return 0, err
+	}
+	usage := OpenAIUsageResponse{}
+	err = json.Unmarshal(body, &usage)
+	if err != nil {
+		return 0, err
+	}
+	balance := subscription.HardLimitUSD - usage.TotalUsage/100
+	channel.UpdateBalance(balance)
+	return balance, nil
+}
+
+func UpdateChannelBalance(c *gin.Context) {
+	id, err := strconv.Atoi(c.Param("id"))
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	channel, err := model.GetChannelById(id, true)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	balance, err := updateChannelBalance(channel)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"balance": balance,
+	})
+	return
+}
+
+func updateAllChannelsBalance() error {
+	channels, err := model.GetAllChannels(0, 0, true)
+	if err != nil {
+		return err
+	}
+	for _, channel := range channels {
+		if channel.Status != common.ChannelStatusEnabled {
+			continue
+		}
+		balance, err := updateChannelBalance(channel)
+		if err != nil {
+			continue
+		} else {
+			// err is nil & balance <= 0 means quota is used up
+			if balance <= 0 {
+				disableChannel(channel.Id, channel.Name, "余额不足")
+			}
+		}
+	}
+	return nil
+}
+
+func UpdateAllChannelsBalance(c *gin.Context) {
+	// TODO: make it async
+	err := updateAllChannelsBalance()
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}

+ 23 - 11
model/channel.go

@@ -6,17 +6,19 @@ import (
 )
 
 type Channel struct {
-	Id           int    `json:"id"`
-	Type         int    `json:"type" gorm:"default:0"`
-	Key          string `json:"key" gorm:"not null"`
-	Status       int    `json:"status" gorm:"default:1"`
-	Name         string `json:"name" gorm:"index"`
-	Weight       int    `json:"weight"`
-	CreatedTime  int64  `json:"created_time" gorm:"bigint"`
-	TestTime     int64  `json:"test_time" gorm:"bigint"`
-	ResponseTime int    `json:"response_time"` // in milliseconds
-	BaseURL      string `json:"base_url" gorm:"column:base_url"`
-	Other        string `json:"other"`
+	Id                 int     `json:"id"`
+	Type               int     `json:"type" gorm:"default:0"`
+	Key                string  `json:"key" gorm:"not null"`
+	Status             int     `json:"status" gorm:"default:1"`
+	Name               string  `json:"name" gorm:"index"`
+	Weight             int     `json:"weight"`
+	CreatedTime        int64   `json:"created_time" gorm:"bigint"`
+	TestTime           int64   `json:"test_time" gorm:"bigint"`
+	ResponseTime       int     `json:"response_time"` // in milliseconds
+	BaseURL            string  `json:"base_url" gorm:"column:base_url"`
+	Other              string  `json:"other"`
+	Balance            float64 `json:"balance"` // in USD
+	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
 }
 
 func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -86,6 +88,16 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
 	}
 }
 
+func (channel *Channel) UpdateBalance(balance float64) {
+	err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
+		BalanceUpdatedTime: common.GetTimestamp(),
+		Balance:            balance,
+	}).Error
+	if err != nil {
+		common.SysError("failed to update balance: " + err.Error())
+	}
+}
+
 func (channel *Channel) Delete() error {
 	var err error
 	err = DB.Delete(channel).Error

+ 2 - 0
router/api-router.go

@@ -66,6 +66,8 @@ func SetApiRouter(router *gin.Engine) {
 			channelRoute.GET("/:id", controller.GetChannel)
 			channelRoute.GET("/test", controller.TestAllChannels)
 			channelRoute.GET("/test/:id", controller.TestChannel)
+			channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance)
+			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
 			channelRoute.POST("/", controller.AddChannel)
 			channelRoute.PUT("/", controller.UpdateChannel)
 			channelRoute.DELETE("/:id", controller.DeleteChannel)

+ 61 - 8
web/src/components/ChannelsTable.js

@@ -32,6 +32,7 @@ const ChannelsTable = () => {
   const [activePage, setActivePage] = useState(1);
   const [searchKeyword, setSearchKeyword] = useState('');
   const [searching, setSearching] = useState(false);
+  const [updatingBalance, setUpdatingBalance] = useState(false);
 
   const loadChannels = async (startIdx) => {
     const res = await API.get(`/api/channel/?p=${startIdx}`);
@@ -63,7 +64,7 @@ const ChannelsTable = () => {
   const refresh = async () => {
     setLoading(true);
     await loadChannels(0);
-  }
+  };
 
   useEffect(() => {
     loadChannels(0)
@@ -127,7 +128,7 @@ const ChannelsTable = () => {
 
   const renderResponseTime = (responseTime) => {
     let time = responseTime / 1000;
-    time = time.toFixed(2) + " 秒";
+    time = time.toFixed(2) + ' 秒';
     if (responseTime === 0) {
       return <Label basic color='grey'>未测试</Label>;
     } else if (responseTime <= 1000) {
@@ -179,11 +180,38 @@ const ChannelsTable = () => {
     const res = await API.get(`/api/channel/test`);
     const { success, message } = res.data;
     if (success) {
-      showInfo("已成功开始测试所有已启用通道,请刷新页面查看结果。");
+      showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。');
     } else {
       showError(message);
     }
-  }
+  };
+
+  const updateChannelBalance = async (id, name, idx) => {
+    const res = await API.get(`/api/channel/update_balance/${id}/`);
+    const { success, message, balance } = res.data;
+    if (success) {
+      let newChannels = [...channels];
+      let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
+      newChannels[realIdx].balance = balance;
+      newChannels[realIdx].balance_updated_time = Date.now() / 1000;
+      setChannels(newChannels);
+      showInfo(`通道 ${name} 余额更新成功!`);
+    } else {
+      showError(message);
+    }
+  };
+
+  const updateAllChannelsBalance = async () => {
+    setUpdatingBalance(true);
+    const res = await API.get(`/api/channel/update_balance`);
+    const { success, message } = res.data;
+    if (success) {
+      showInfo('已更新完毕所有已启用通道余额!');
+    } else {
+      showError(message);
+    }
+    setUpdatingBalance(false);
+  };
 
   const handleKeywordChange = async (e, { value }) => {
     setSearchKeyword(value.trim());
@@ -263,10 +291,10 @@ const ChannelsTable = () => {
             <Table.HeaderCell
               style={{ cursor: 'pointer' }}
               onClick={() => {
-                sortChannel('test_time');
+                sortChannel('balance');
               }}
             >
-              测试时间
+              余额
             </Table.HeaderCell>
             <Table.HeaderCell>操作</Table.HeaderCell>
           </Table.Row>
@@ -286,8 +314,22 @@ const ChannelsTable = () => {
                   <Table.Cell>{channel.name ? channel.name : '无'}</Table.Cell>
                   <Table.Cell>{renderType(channel.type)}</Table.Cell>
                   <Table.Cell>{renderStatus(channel.status)}</Table.Cell>
-                  <Table.Cell>{renderResponseTime(channel.response_time)}</Table.Cell>
-                  <Table.Cell>{channel.test_time ? renderTimestamp(channel.test_time) : "未测试"}</Table.Cell>
+                  <Table.Cell>
+                    <Popup
+                      content={channel.test_time ? renderTimestamp(channel.test_time) : '未测试'}
+                      key={channel.id}
+                      trigger={renderResponseTime(channel.response_time)}
+                      basic
+                    />
+                  </Table.Cell>
+                  <Table.Cell>
+                    <Popup
+                      content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}
+                      key={channel.id}
+                      trigger={<span>${channel.balance.toFixed(2)}</span>}
+                      basic
+                    />
+                  </Table.Cell>
                   <Table.Cell>
                     <div>
                       <Button
@@ -299,6 +341,16 @@ const ChannelsTable = () => {
                       >
                         测试
                       </Button>
+                      <Button
+                        size={'small'}
+                        positive
+                        loading={updatingBalance}
+                        onClick={() => {
+                          updateChannelBalance(channel.id, channel.name, idx);
+                        }}
+                      >
+                        更新余额
+                      </Button>
                       <Popup
                         trigger={
                           <Button size='small' negative>
@@ -353,6 +405,7 @@ const ChannelsTable = () => {
               <Button size='small' loading={loading} onClick={testAllChannels}>
                 测试所有已启用通道
               </Button>
+              <Button size='small' onClick={updateAllChannelsBalance} loading={updatingBalance}>更新所有已启用通道余额</Button>
               <Pagination
                 floated='right'
                 activePage={activePage}