|
|
@@ -11,6 +11,7 @@ import (
|
|
|
"one-api/model"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
+ "sync"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
@@ -19,7 +20,7 @@ func GetAllChannels(c *gin.Context) {
|
|
|
if p < 0 {
|
|
|
p = 0
|
|
|
}
|
|
|
- channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage)
|
|
|
+ channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
|
|
|
if err != nil {
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
|
"success": false,
|
|
|
@@ -206,6 +207,19 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+func buildTestRequest(c *gin.Context) *ChatRequest {
|
|
|
+ model_ := c.Query("model")
|
|
|
+ testRequest := &ChatRequest{
|
|
|
+ Model: model_,
|
|
|
+ }
|
|
|
+ testMessage := Message{
|
|
|
+ Role: "user",
|
|
|
+ Content: "echo hi",
|
|
|
+ }
|
|
|
+ testRequest.Messages = append(testRequest.Messages, testMessage)
|
|
|
+ return testRequest
|
|
|
+}
|
|
|
+
|
|
|
func TestChannel(c *gin.Context) {
|
|
|
id, err := strconv.Atoi(c.Param("id"))
|
|
|
if err != nil {
|
|
|
@@ -223,17 +237,9 @@ func TestChannel(c *gin.Context) {
|
|
|
})
|
|
|
return
|
|
|
}
|
|
|
- model_ := c.Query("model")
|
|
|
- chatRequest := &ChatRequest{
|
|
|
- Model: model_,
|
|
|
- }
|
|
|
- testMessage := Message{
|
|
|
- Role: "user",
|
|
|
- Content: "echo hi",
|
|
|
- }
|
|
|
- chatRequest.Messages = append(chatRequest.Messages, testMessage)
|
|
|
+ testRequest := buildTestRequest(c)
|
|
|
tik := time.Now()
|
|
|
- err = testChannel(channel, chatRequest)
|
|
|
+ err = testChannel(channel, testRequest)
|
|
|
tok := time.Now()
|
|
|
milliseconds := tok.Sub(tik).Milliseconds()
|
|
|
go channel.UpdateResponseTime(milliseconds)
|
|
|
@@ -253,3 +259,70 @@ func TestChannel(c *gin.Context) {
|
|
|
})
|
|
|
return
|
|
|
}
|
|
|
+
|
|
|
+var testAllChannelsLock sync.Mutex
|
|
|
+
|
|
|
+func testAllChannels(c *gin.Context) error {
|
|
|
+ ok := testAllChannelsLock.TryLock()
|
|
|
+ if !ok {
|
|
|
+ return errors.New("测试已在运行")
|
|
|
+ }
|
|
|
+ defer testAllChannelsLock.Unlock()
|
|
|
+ channels, err := model.GetAllChannels(0, 0, true)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": err.Error(),
|
|
|
+ })
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ testRequest := buildTestRequest(c)
|
|
|
+ var disableThreshold int64 = 5000 // TODO: make it configurable
|
|
|
+ email := model.GetRootUserEmail()
|
|
|
+ go func() {
|
|
|
+ for _, channel := range channels {
|
|
|
+ if channel.Status != common.ChannelStatusEnabled {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ tik := time.Now()
|
|
|
+ err := testChannel(channel, testRequest)
|
|
|
+ tok := time.Now()
|
|
|
+ milliseconds := tok.Sub(tik).Milliseconds()
|
|
|
+ if err != nil || milliseconds > disableThreshold {
|
|
|
+ if milliseconds > disableThreshold {
|
|
|
+ err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
|
|
+ }
|
|
|
+ // disable & notify
|
|
|
+ channel.UpdateStatus(common.ChannelStatusDisabled)
|
|
|
+ subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channel.Name, channel.Id)
|
|
|
+ content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channel.Name, channel.Id, err.Error())
|
|
|
+ err = common.SendEmail(subject, email, content)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ channel.UpdateResponseTime(milliseconds)
|
|
|
+ }
|
|
|
+ err := common.SendEmail("通道测试完成", email, "通道测试完成")
|
|
|
+ if err != nil {
|
|
|
+ common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func TestAllChannels(c *gin.Context) {
|
|
|
+ err := testAllChannels(c)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": err.Error(),
|
|
|
+ })
|
|
|
+ return
|
|
|
+ }
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": true,
|
|
|
+ "message": "",
|
|
|
+ })
|
|
|
+ return
|
|
|
+}
|