Просмотр исходного кода

feat: able to test all enabled channels (#59)

JustSong 2 лет назад
Родитель
Сommit
d267211ee7
5 измененных файлов с 116 добавлено и 13 удалено
  1. 84 11
      controller/channel.go
  2. 13 2
      model/channel.go
  3. 5 0
      model/user.go
  4. 1 0
      router/api-router.go
  5. 13 0
      web/src/components/ChannelsTable.js

+ 84 - 11
controller/channel.go

@@ -11,6 +11,7 @@ import (
 	"one-api/model"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 )
 
@@ -19,7 +20,7 @@ func GetAllChannels(c *gin.Context) {
 	if p < 0 {
 		p = 0
 	}
-	channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage)
+	channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -206,6 +207,19 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
 	return nil
 }
 
+func buildTestRequest(c *gin.Context) *ChatRequest {
+	model_ := c.Query("model")
+	testRequest := &ChatRequest{
+		Model: model_,
+	}
+	testMessage := Message{
+		Role:    "user",
+		Content: "echo hi",
+	}
+	testRequest.Messages = append(testRequest.Messages, testMessage)
+	return testRequest
+}
+
 func TestChannel(c *gin.Context) {
 	id, err := strconv.Atoi(c.Param("id"))
 	if err != nil {
@@ -223,17 +237,9 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
-	model_ := c.Query("model")
-	chatRequest := &ChatRequest{
-		Model: model_,
-	}
-	testMessage := Message{
-		Role:    "user",
-		Content: "echo hi",
-	}
-	chatRequest.Messages = append(chatRequest.Messages, testMessage)
+	testRequest := buildTestRequest(c)
 	tik := time.Now()
-	err = testChannel(channel, chatRequest)
+	err = testChannel(channel, testRequest)
 	tok := time.Now()
 	milliseconds := tok.Sub(tik).Milliseconds()
 	go channel.UpdateResponseTime(milliseconds)
@@ -253,3 +259,70 @@ func TestChannel(c *gin.Context) {
 	})
 	return
 }
+
+var testAllChannelsLock sync.Mutex
+
+func testAllChannels(c *gin.Context) error {
+	ok := testAllChannelsLock.TryLock()
+	if !ok {
+		return errors.New("测试已在运行")
+	}
+	defer testAllChannelsLock.Unlock()
+	channels, err := model.GetAllChannels(0, 0, true)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return err
+	}
+	testRequest := buildTestRequest(c)
+	var disableThreshold int64 = 5000 // TODO: make it configurable
+	email := model.GetRootUserEmail()
+	go func() {
+		for _, channel := range channels {
+			if channel.Status != common.ChannelStatusEnabled {
+				continue
+			}
+			tik := time.Now()
+			err := testChannel(channel, testRequest)
+			tok := time.Now()
+			milliseconds := tok.Sub(tik).Milliseconds()
+			if err != nil || milliseconds > disableThreshold {
+				if milliseconds > disableThreshold {
+					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+				}
+				// disable & notify
+				channel.UpdateStatus(common.ChannelStatusDisabled)
+				subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channel.Name, channel.Id)
+				content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channel.Name, channel.Id, err.Error())
+				err = common.SendEmail(subject, email, content)
+				if err != nil {
+					common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
+				}
+			}
+			channel.UpdateResponseTime(milliseconds)
+		}
+		err := common.SendEmail("通道测试完成", email, "通道测试完成")
+		if err != nil {
+			common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
+		}
+	}()
+	return nil
+}
+
+func TestAllChannels(c *gin.Context) {
+	err := testAllChannels(c)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}

+ 13 - 2
model/channel.go

@@ -19,10 +19,14 @@ type Channel struct {
 	Other        string `json:"other"`
 }
 
-func GetAllChannels(startIdx int, num int) ([]*Channel, error) {
+func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
 	var channels []*Channel
 	var err error
-	err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
+	if selectAll {
+		err = DB.Order("id desc").Find(&channels).Error
+	} else {
+		err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
+	}
 	return channels, err
 }
 
@@ -82,6 +86,13 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
 	}
 }
 
+func (channel *Channel) UpdateStatus(status int) {
+	err := DB.Model(channel).Update("status", status).Error
+	if err != nil {
+		common.SysError("failed to update response time: " + err.Error())
+	}
+}
+
 func (channel *Channel) Delete() error {
 	var err error
 	err = DB.Delete(channel).Error

+ 5 - 0
model/user.go

@@ -234,3 +234,8 @@ func DecreaseUserQuota(id int, quota int) (err error) {
 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
 	return err
 }
+
+func GetRootUserEmail() (email string) {
+	DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
+	return email
+}

+ 1 - 0
router/api-router.go

@@ -63,6 +63,7 @@ func SetApiRouter(router *gin.Engine) {
 			channelRoute.GET("/", controller.GetAllChannels)
 			channelRoute.GET("/search", controller.SearchChannels)
 			channelRoute.GET("/:id", controller.GetChannel)
+			channelRoute.GET("/test", controller.TestAllChannels)
 			channelRoute.GET("/test/:id", controller.TestChannel)
 			channelRoute.POST("/", controller.AddChannel)
 			channelRoute.PUT("/", controller.UpdateChannel)

+ 13 - 0
web/src/components/ChannelsTable.js

@@ -170,6 +170,16 @@ const ChannelsTable = () => {
     }
   };
 
+  const testAllChannels = async () => {
+    const res = await API.get(`/api/channel/test`);
+    const { success, message } = res.data;
+    if (success) {
+      showSuccess("已成功开始测试所有已启用通道,请刷新页面查看结果。");
+    } else {
+      showError(message);
+    }
+  }
+
   const handleKeywordChange = async (e, { value }) => {
     setSearchKeyword(value.trim());
   };
@@ -335,6 +345,9 @@ const ChannelsTable = () => {
               <Button size='small' as={Link} to='/channel/add' loading={loading}>
                 添加新的渠道
               </Button>
+              <Button size='small' loading={loading} onClick={testAllChannels}>
+                测试所有已启用通道
+              </Button>
               <Pagination
                 floated='right'
                 activePage={activePage}