Browse Source

feat: support automatic channel testing & balance updates (close #11, close #59)

JustSong 2 years ago
parent
commit
4463224f04
5 changed files with 57 additions and 15 deletions
  1. 6 0
      README.md
  2. 4 0
      common/constants.go
  3. 10 0
      controller/channel-billing.go
  4. 22 15
      controller/channel-test.go
  5. 15 0
      main.go

+ 6 - 0
README.md

@@ -250,6 +250,12 @@ graph LR
    + 例子:`SYNC_FREQUENCY=60`
 6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
    + 例子:`NODE_TYPE=slave`
+7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
+   + 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
+8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
+   + 例子:`CHANNEL_TEST_FREQUENCY=1440`
+9. `REQUEST_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+   + 例子:`POLLING_INTERVAL=5`
 
 ### 命令行参数
 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。

+ 4 - 0
common/constants.go

@@ -2,6 +2,7 @@ package common
 
 import (
 	"os"
+	"strconv"
 	"sync"
 	"time"
 
@@ -70,6 +71,9 @@ var RootUserEmail = ""
 
 var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
 
+var requestInterval, _ = strconv.Atoi(os.Getenv("REQUEST_INTERVAL"))
+var RequestInterval = time.Duration(requestInterval) * time.Second
+
 const (
 	RoleGuestUser  = 0
 	RoleCommonUser = 1

+ 10 - 0
controller/channel-billing.go

@@ -257,6 +257,7 @@ func updateAllChannelsBalance() error {
 				disableChannel(channel.Id, channel.Name, "余额不足")
 			}
 		}
+		time.Sleep(common.RequestInterval)
 	}
 	return nil
 }
@@ -277,3 +278,12 @@ func UpdateAllChannelsBalance(c *gin.Context) {
 	})
 	return
 }
+
+func AutomaticallyUpdateChannels(frequency int) {
+	for {
+		time.Sleep(time.Duration(frequency) * time.Minute)
+		common.SysLog("updating all channels")
+		_ = updateAllChannelsBalance()
+		common.SysLog("channels update done")
+	}
+}

+ 22 - 15
controller/channel-test.go

@@ -62,10 +62,9 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
 	return nil
 }
 
-func buildTestRequest(c *gin.Context) *ChatRequest {
-	model_ := c.Query("model")
+func buildTestRequest() *ChatRequest {
 	testRequest := &ChatRequest{
-		Model:     model_,
+		Model:     "", // this will be set later
 		MaxTokens: 1,
 	}
 	testMessage := Message{
@@ -93,7 +92,7 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
-	testRequest := buildTestRequest(c)
+	testRequest := buildTestRequest()
 	tik := time.Now()
 	err = testChannel(channel, *testRequest)
 	tok := time.Now()
@@ -133,7 +132,7 @@ func disableChannel(channelId int, channelName string, reason string) {
 	}
 }
 
-func testAllChannels(c *gin.Context) error {
+func testAllChannels(notify bool) error {
 	if common.RootUserEmail == "" {
 		common.RootUserEmail = model.GetRootUserEmail()
 	}
@@ -146,13 +145,9 @@ func testAllChannels(c *gin.Context) error {
 	testAllChannelsLock.Unlock()
 	channels, err := model.GetAllChannels(0, 0, true)
 	if err != nil {
-		c.JSON(http.StatusOK, gin.H{
-			"success": false,
-			"message": err.Error(),
-		})
 		return err
 	}
-	testRequest := buildTestRequest(c)
+	testRequest := buildTestRequest()
 	var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
 	if disableThreshold == 0 {
 		disableThreshold = 10000000 // a impossible value
@@ -173,20 +168,23 @@ func testAllChannels(c *gin.Context) error {
 				disableChannel(channel.Id, channel.Name, err.Error())
 			}
 			channel.UpdateResponseTime(milliseconds)
-		}
-		err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
-		if err != nil {
-			common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+			time.Sleep(common.RequestInterval)
 		}
 		testAllChannelsLock.Lock()
 		testAllChannelsRunning = false
 		testAllChannelsLock.Unlock()
+		if notify {
+			err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
+			if err != nil {
+				common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+			}
+		}
 	}()
 	return nil
 }
 
 func TestAllChannels(c *gin.Context) {
-	err := testAllChannels(c)
+	err := testAllChannels(true)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -200,3 +198,12 @@ func TestAllChannels(c *gin.Context) {
 	})
 	return
 }
+
+func AutomaticallyTestChannels(frequency int) {
+	for {
+		time.Sleep(time.Duration(frequency) * time.Minute)
+		common.SysLog("testing all channels")
+		_ = testAllChannels(false)
+		common.SysLog("channel test finished")
+	}
+}

+ 15 - 0
main.go

@@ -7,6 +7,7 @@ import (
 	"github.com/gin-contrib/sessions/redis"
 	"github.com/gin-gonic/gin"
 	"one-api/common"
+	"one-api/controller"
 	"one-api/middleware"
 	"one-api/model"
 	"one-api/router"
@@ -59,6 +60,20 @@ func main() {
 			go model.SyncChannelCache(frequency)
 		}
 	}
+	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
+		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
+		if err != nil {
+			common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
+		}
+		go controller.AutomaticallyUpdateChannels(frequency)
+	}
+	if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
+		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
+		if err != nil {
+			common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
+		}
+		go controller.AutomaticallyTestChannels(frequency)
+	}
 
 	// Initialize HTTP server
 	server := gin.Default()