Просмотр исходного кода

feat: add timeout plugin to control response header timeout (#338)

* feat: add timeout plugin to control response header timeout

* chore: remove channel monitor 429 notify

* fix: relay mode check error message

* chore: adjust the timeout period

* fix: warn notify period

* fix: ci lint

* fix: request_at needs to be set by the logger middleware

* chore: log format

* chore: remove do request error default notify case
zijiren 4 месяцев назад
Родитель
Commit
c2e492408f
38 измененных файлов с 451 добавлено и 159 удалено
  1. 10 5
      core/common/body.go
  2. 3 0
      core/common/consume/consume.go
  3. 2 0
      core/common/consume/record.go
  4. 37 2
      core/common/gin.go
  5. 2 0
      core/controller/relay-controller.go
  6. 28 18
      core/main.go
  7. 1 1
      core/mcpproxy/stateless-streamable.go
  8. 1 1
      core/middleware/distributor.go
  9. 14 8
      core/middleware/log.go
  10. 1 10
      core/middleware/reqid.go
  11. 4 1
      core/model/batch.go
  12. 10 24
      core/model/main.go
  13. 22 1
      core/model/modelconfig.go
  14. 21 5
      core/monitor/memmodel.go
  15. 28 9
      core/monitor/model.go
  16. 1 3
      core/relay/adaptor/ali/adaptor.go
  17. 1 1
      core/relay/adaptor/ali/error.go
  18. 2 2
      core/relay/adaptor/anthropic/adaptor.go
  19. 2 2
      core/relay/adaptor/baidu/adaptor.go
  20. 1 1
      core/relay/adaptor/baidu/token.go
  21. 2 2
      core/relay/adaptor/baiduv2/adaptor.go
  22. 2 2
      core/relay/adaptor/cohere/adaptor.go
  23. 2 2
      core/relay/adaptor/coze/adaptor.go
  24. 2 2
      core/relay/adaptor/doc2x/adaptor.go
  25. 2 2
      core/relay/adaptor/gemini/adaptor.go
  26. 2 2
      core/relay/adaptor/ollama/adaptor.go
  27. 2 2
      core/relay/adaptor/openai/adaptor.go
  28. 1 3
      core/relay/adaptor/openai/video.go
  29. 2 2
      core/relay/adaptor/text-embeddings-inference/adaptor.go
  30. 2 2
      core/relay/adaptor/vertexai/adaptor.go
  31. 19 22
      core/relay/controller/dohelper.go
  32. 2 0
      core/relay/meta/meta.go
  33. 26 16
      core/relay/plugin/monitor/monitor.go
  34. 1 1
      core/relay/plugin/streamfake/fake.go
  35. 125 0
      core/relay/plugin/timeout/timeout.go
  36. 2 1
      core/relay/render/claudeevent.go
  37. 64 3
      core/relay/utils/utils.go
  38. 2 1
      mcp-servers/server.go

+ 10 - 5
core/common/body.go

@@ -13,7 +13,7 @@ import (
 	"github.com/bytedance/sonic/ast"
 )
 
-type RequestBodyKey struct{}
+type requestBodyKey struct{}
 
 const (
 	MaxRequestBodySize  = 1024 * 1024 * 50 // 50MB
@@ -84,18 +84,23 @@ func GetRequestBody(req *http.Request) ([]byte, error) {
 
 func SetRequestBody(req *http.Request, body []byte) {
 	ctx := req.Context()
-	bufCtx := context.WithValue(ctx, RequestBodyKey{}, body)
+	bufCtx := context.WithValue(ctx, requestBodyKey{}, body)
 	*req = *req.WithContext(bufCtx)
 }
 
+func IsJSONContentType(ct string) bool {
+	return strings.HasSuffix(ct, "/json") ||
+		strings.Contains(ct, "/json;")
+}
+
 func GetRequestBodyReusable(req *http.Request) ([]byte, error) {
 	contentType := req.Header.Get("Content-Type")
-	if contentType == "application/x-www-form-urlencoded" ||
+	if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") ||
 		strings.HasPrefix(contentType, "multipart/form-data") {
 		return nil, nil
 	}
 
-	requestBody := req.Context().Value(RequestBodyKey{})
+	requestBody := req.Context().Value(requestBodyKey{})
 	if requestBody != nil {
 		b, _ := requestBody.([]byte)
 		return b, nil
@@ -115,7 +120,7 @@ func GetRequestBodyReusable(req *http.Request) ([]byte, error) {
 	}()
 
 	if req.ContentLength <= 0 ||
-		strings.HasPrefix(contentType, "application/json") {
+		IsJSONContentType(contentType) {
 		buf, err = io.ReadAll(LimitReader(req.Body, MaxRequestBodySize))
 		if err != nil {
 			if errors.Is(err, ErrLimitedReaderExceeded) {

+ 3 - 0
core/common/consume/consume.go

@@ -51,6 +51,7 @@ func AsyncConsume(
 
 	go Consume(
 		context.Background(),
+		time.Now(),
 		postGroupConsumer,
 		firstByteAt,
 		code,
@@ -69,6 +70,7 @@ func AsyncConsume(
 
 func Consume(
 	ctx context.Context,
+	now time.Time,
 	postGroupConsumer balance.PostGroupConsumer,
 	firstByteAt time.Time,
 	code int,
@@ -94,6 +96,7 @@ func Consume(
 	selectedModelPrice.ConditionalPrices = nil
 
 	err := recordConsume(
+		now,
 		meta,
 		code,
 		firstByteAt,

+ 2 - 0
core/common/consume/record.go

@@ -8,6 +8,7 @@ import (
 )
 
 func recordConsume(
+	now time.Time,
 	meta *meta.Meta,
 	code int,
 	firstByteAt time.Time,
@@ -23,6 +24,7 @@ func recordConsume(
 	metadata map[string]string,
 ) error {
 	return model.BatchRecordLogs(
+		now,
 		meta.RequestID,
 		meta.RequestAt,
 		meta.RetryAt,

+ 37 - 2
core/common/gin.go

@@ -1,8 +1,11 @@
 package common
 
 import (
+	"context"
 	"fmt"
+	"net/http"
 	"sync"
+	"time"
 
 	"github.com/gin-gonic/gin"
 	"github.com/sirupsen/logrus"
@@ -29,7 +32,14 @@ func PutLogFields(fields logrus.Fields) {
 }
 
 func GetLogger(c *gin.Context) *logrus.Entry {
-	if log, ok := c.Get("log"); ok {
+	return GetLoggerFromReq(c.Request)
+}
+
+type ginLoggerKey struct{}
+
+func GetLoggerFromReq(req *http.Request) *logrus.Entry {
+	ctx := req.Context()
+	if log := ctx.Value(ginLoggerKey{}); log != nil {
 		v, ok := log.(*logrus.Entry)
 		if !ok {
 			panic(fmt.Sprintf("log type error: %T, %v", v, v))
@@ -39,14 +49,39 @@ func GetLogger(c *gin.Context) *logrus.Entry {
 	}
 
 	entry := NewLogger()
-	c.Set("log", entry)
+	SetLogger(req, entry)
 
 	return entry
 }
 
+func SetLogger(req *http.Request, entry *logrus.Entry) {
+	newCtx := context.WithValue(req.Context(), ginLoggerKey{}, entry)
+	*req = *req.WithContext(newCtx)
+}
+
 func NewLogger() *logrus.Entry {
 	return &logrus.Entry{
 		Logger: logrus.StandardLogger(),
 		Data:   GetLogFields(),
 	}
 }
+
+func TruncateDuration(d time.Duration) time.Duration {
+	if d > time.Hour {
+		return d.Truncate(time.Minute)
+	}
+
+	if d > time.Minute {
+		return d.Truncate(time.Second)
+	}
+
+	if d > time.Second {
+		return d.Truncate(time.Millisecond)
+	}
+
+	if d > time.Millisecond {
+		return d.Truncate(time.Microsecond)
+	}
+
+	return d
+}

+ 2 - 0
core/controller/relay-controller.go

@@ -31,6 +31,7 @@ import (
 	monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
 	"github.com/labring/aiproxy/core/relay/plugin/streamfake"
 	"github.com/labring/aiproxy/core/relay/plugin/thinksplit"
+	"github.com/labring/aiproxy/core/relay/plugin/timeout"
 	websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
 )
 
@@ -86,6 +87,7 @@ func wrapPlugin(ctx context.Context, mc *model.ModelCaches, a adaptor.Adaptor) a
 		monitorplugin.NewGroupMonitorPlugin(),
 		cache.NewCachePlugin(common.RDB),
 		streamfake.NewStreamFakePlugin(),
+		timeout.NewTimeoutPlugin(),
 		websearch.NewWebSearchPlugin(func(modelName string) (*model.Channel, error) {
 			return getWebSearchChannel(ctx, mc, modelName)
 		}),

+ 28 - 18
core/main.go

@@ -112,10 +112,12 @@ func startSyncServices(ctx context.Context, wg *sync.WaitGroup) {
 func setupHTTPServer() (*http.Server, *gin.Engine) {
 	server := gin.New()
 
-	server.
-		Use(middleware.GinRecoveryHandler).
-		Use(middleware.NewLog(log.StandardLogger())).
-		Use(middleware.RequestIDMiddleware, middleware.CORS())
+	server.Use(
+		middleware.GinRecoveryHandler,
+		middleware.NewLog(log.StandardLogger()),
+		middleware.RequestIDMiddleware,
+		middleware.CORS(),
+	)
 	router.SetRouter(server)
 
 	listenEnv := os.Getenv("LISTEN")
@@ -131,8 +133,6 @@ func setupHTTPServer() (*http.Server, *gin.Engine) {
 }
 
 func autoTestBannedModels(ctx context.Context) {
-	log.Info("auto test banned models start")
-
 	ticker := time.NewTicker(time.Second * 30)
 	defer ticker.Stop()
 
@@ -147,8 +147,6 @@ func autoTestBannedModels(ctx context.Context) {
 }
 
 func detectIPGroupsTask(ctx context.Context) {
-	log.Info("detect IP groups start")
-
 	ticker := time.NewTicker(time.Minute)
 	defer ticker.Stop()
 
@@ -244,7 +242,6 @@ func detectIPGroups() {
 }
 
 func cleanLog(ctx context.Context) {
-	log.Info("clean log start")
 	// the interval should not be too large to avoid cleaning too much at once
 	ticker := time.NewTicker(time.Minute)
 	defer ticker.Stop()
@@ -310,6 +307,13 @@ func printLoadedEnvFiles() {
 	}
 }
 
+func listenAndServe(srv *http.Server) {
+	if err := srv.ListenAndServe(); err != nil &&
+		!errors.Is(err, http.ErrServerClosed) {
+		log.Fatal("failed to start HTTP server: " + err.Error())
+	}
+}
+
 // Swagger godoc
 //
 //	@title						AI Proxy Swagger API
@@ -346,19 +350,20 @@ func main() {
 
 	srv, _ := setupHTTPServer()
 
-	go func() {
-		log.Infof("server started on http://%s", srv.Addr)
-		log.Infof("swagger server started on http://%s/swagger/index.html", srv.Addr)
-
-		if err := srv.ListenAndServe(); err != nil &&
-			!errors.Is(err, http.ErrServerClosed) {
-			log.Fatal("failed to start HTTP server: " + err.Error())
-		}
-	}()
+	log.Info("auto test banned models task started")
 
 	go autoTestBannedModels(ctx)
+
+	log.Info("clean log task started")
+
 	go cleanLog(ctx)
+
+	log.Info("detect ip groups task started")
+
 	go detectIPGroupsTask(ctx)
+
+	log.Info("update channels balance task started")
+
 	go controller.UpdateChannelsBalance(time.Minute * 10)
 
 	batchProcessorCtx, batchProcessorCancel := context.WithCancel(context.Background())
@@ -367,6 +372,11 @@ func main() {
 
 	go model.StartBatchProcessorSummary(batchProcessorCtx, &wg)
 
+	log.Infof("server started on http://%s", srv.Addr)
+	log.Infof("swagger started on http://%s/swagger/index.html", srv.Addr)
+
+	go listenAndServe(srv)
+
 	<-ctx.Done()
 
 	shutdownSrvCtx, shutdownSrvCancel := context.WithTimeout(context.Background(), 600*time.Second)

+ 1 - 1
core/mcpproxy/stateless-streamable.go

@@ -52,7 +52,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
 
 	// Check content type
 	contentType := r.Header.Get("Content-Type")
-	if contentType != "application/json" {
+	if !common.IsJSONContentType(contentType) {
 		http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest)
 		return
 	}

+ 1 - 1
core/middleware/distributor.go

@@ -389,7 +389,7 @@ func distribute(c *gin.Context, mode mode.Mode) {
 			c,
 			http.StatusNotFound,
 			fmt.Sprintf(
-				"The model `%s` does not exist or you do not have access to it.",
+				"The model `%s` does not exist on this endpoint.",
 				requestModel,
 			),
 		)

+ 14 - 8
core/middleware/log.go

@@ -10,8 +10,19 @@ import (
 	"github.com/sirupsen/logrus"
 )
 
+func SetRequestAt(c *gin.Context, requestAt time.Time) {
+	c.Set(RequestAt, requestAt)
+}
+
+func GetRequestAt(c *gin.Context) time.Time {
+	return c.GetTime(RequestAt)
+}
+
 func NewLog(l *logrus.Logger) gin.HandlerFunc {
 	return func(c *gin.Context) {
+		start := time.Now()
+		SetRequestAt(c, start)
+
 		fields := common.GetLogFields()
 		defer func() {
 			common.PutLogFields(fields)
@@ -21,9 +32,8 @@ func NewLog(l *logrus.Logger) gin.HandlerFunc {
 			Logger: l,
 			Data:   fields,
 		}
-		c.Set("log", entry)
+		common.SetLogger(c.Request, entry)
 
-		start := time.Now()
 		path := c.Request.URL.Path
 		raw := c.Request.URL.RawQuery
 
@@ -74,13 +84,9 @@ func formatter(param gin.LogFormatterParams) string {
 		resetColor = param.ResetColor()
 	}
 
-	if param.Latency > time.Minute {
-		param.Latency = param.Latency.Truncate(time.Second)
-	}
-
-	return fmt.Sprintf("[GIN] |%s %3d %s| %13v | %15s |%s %-7s %s %#v\n%s",
+	return fmt.Sprintf("[GIN] |%s %3d %s| %10v | %15s |%s %-7s %s %#v\n%s",
 		statusColor, param.StatusCode, resetColor,
-		param.Latency,
+		common.TruncateDuration(param.Latency),
 		param.ClientIP,
 		methodColor, param.Method, resetColor,
 		param.Path,

+ 1 - 10
core/middleware/reqid.go

@@ -28,16 +28,7 @@ func GetRequestID(c *gin.Context) string {
 }
 
 func RequestIDMiddleware(c *gin.Context) {
-	now := time.Now()
+	now := GetRequestAt(c)
 	id := GenRequestID(now)
 	SetRequestID(c, id)
-	SetRequestAt(c, now)
-}
-
-func SetRequestAt(c *gin.Context, requestAt time.Time) {
-	c.Set(RequestAt, requestAt)
-}
-
-func GetRequestAt(c *gin.Context) time.Time {
-	return c.GetTime(RequestAt)
 }

+ 4 - 1
core/model/batch.go

@@ -295,6 +295,7 @@ func processSummaryMinuteUpdates(wg *sync.WaitGroup) {
 }
 
 func BatchRecordLogs(
+	now time.Time,
 	requestID string,
 	requestAt time.Time,
 	retryAt time.Time,
@@ -318,7 +319,9 @@ func BatchRecordLogs(
 	user string,
 	metadata map[string]string,
 ) (err error) {
-	now := time.Now()
+	if now.IsZero() {
+		now = time.Now()
+	}
 
 	if downstreamResult {
 		if config.GetLogStorageHours() >= 0 {

+ 10 - 24
core/model/main.go

@@ -165,47 +165,33 @@ func migrateDB() error {
 func InitLogDB() {
 	if os.Getenv("LOG_SQL_DSN") == "" {
 		LogDB = DB
+	} else {
+		log.Info("using log database for table logs")
 
-		if config.DisableAutoMigrateDB {
-			return
-		}
+		var err error
 
-		err := migrateLOGDB()
+		LogDB, err = chooseDB("LOG_SQL_DSN")
 		if err != nil {
-			log.Fatal("failed to migrate secondary database: " + err.Error())
+			log.Fatal("failed to initialize log database: " + err.Error())
 			return
 		}
 
-		log.Info("secondary database migrated")
-
-		return
+		setDBConns(LogDB)
 	}
 
-	log.Info("using secondary database for table logs")
-
-	var err error
-
-	LogDB, err = chooseDB("LOG_SQL_DSN")
-	if err != nil {
-		log.Fatal("failed to initialize secondary database: " + err.Error())
-		return
-	}
-
-	setDBConns(LogDB)
-
 	if config.DisableAutoMigrateDB {
 		return
 	}
 
-	log.Info("secondary database migration started")
+	log.Info("log database migration started")
 
-	err = migrateLOGDB()
+	err := migrateLOGDB()
 	if err != nil {
-		log.Fatal("failed to migrate secondary database: " + err.Error())
+		log.Fatal("failed to migrate log database: " + err.Error())
 		return
 	}
 
-	log.Info("secondary database migrated")
+	log.Info("log database migrated")
 }
 
 func migrateLOGDB() error {

+ 22 - 1
core/model/modelconfig.go

@@ -18,6 +18,11 @@ const (
 	PriceUnit = 1000
 )
 
+type TimeoutConfig struct {
+	RequestTimeout       int64 `json:"request_timeout,omitempty"`
+	StreamRequestTimeout int64 `json:"stream_request_timeout,omitempty"`
+}
+
 type ModelConfig struct {
 	CreatedAt        time.Time                  `gorm:"index;autoCreateTime"          json:"created_at"`
 	UpdatedAt        time.Time                  `gorm:"index;autoUpdateTime"          json:"updated_at"`
@@ -35,7 +40,8 @@ type ModelConfig struct {
 	ImagePrices     map[string]float64 `gorm:"serializer:fastjson;type:text" json:"image_prices,omitempty"`
 	Price           Price              `gorm:"embedded"                      json:"price,omitempty"`
 	RetryTimes      int64              `                                     json:"retry_times,omitempty"`
-	Timeout         int64              `                                     json:"timeout,omitempty"`
+	TimeoutConfig   TimeoutConfig      `gorm:"embedded"                      json:"timeout_config,omitempty"`
+	WarnErrorRate   float64            `                                     json:"warn_error_rate,omitempty"`
 	MaxErrorRate    float64            `                                     json:"max_error_rate,omitempty"`
 	ForceSaveDetail bool               `                                     json:"force_save_detail,omitempty"`
 }
@@ -58,6 +64,21 @@ func NewDefaultModelConfig(model string) ModelConfig {
 	}
 }
 
+func (c *ModelConfig) RequestTimeout() time.Duration {
+	return timeoutSecond(c.TimeoutConfig.RequestTimeout)
+}
+
+func (c *ModelConfig) StreamRequestTimeout() time.Duration {
+	return timeoutSecond(c.TimeoutConfig.StreamRequestTimeout)
+}
+
+func timeoutSecond(second int64) time.Duration {
+	if second == 0 {
+		return 0
+	}
+	return time.Duration(second) * time.Second
+}
+
 func (c *ModelConfig) LoadPluginConfig(pluginName string, config any) error {
 	if len(c.Plugin) == 0 {
 		return nil

+ 21 - 5
core/monitor/memmodel.go

@@ -96,8 +96,14 @@ func (m *MemModelMonitor) AddRequest(
 	model string,
 	channelID int64,
 	isError, tryBan bool,
+	warnErrorRate,
 	maxErrorRate float64,
 ) (beyondThreshold, banExecution bool) {
+	// Set default warning threshold if not specified
+	if warnErrorRate <= 0 {
+		warnErrorRate = DefaultWarnErrorRate
+	}
+
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
@@ -127,16 +133,18 @@ func (m *MemModelMonitor) AddRequest(
 	modelData.totalStats.AddRequest(now, isError)
 	channel.timeWindows.AddRequest(now, isError)
 
-	return m.checkAndBan(now, channel, tryBan, maxErrorRate)
+	return m.checkAndBan(now, channel, tryBan, warnErrorRate, maxErrorRate)
 }
 
 func (m *MemModelMonitor) checkAndBan(
 	now time.Time,
 	channel *ChannelStats,
 	tryBan bool,
+	warnErrorRate,
 	maxErrorRate float64,
 ) (beyondThreshold, banExecution bool) {
 	canBan := maxErrorRate > 0
+
 	if tryBan && canBan {
 		if channel.bannedUntil.After(now) {
 			return false, false
@@ -152,14 +160,22 @@ func (m *MemModelMonitor) checkAndBan(
 		return false, false
 	}
 
-	if float64(err)/float64(req) >= maxErrorRate {
-		if !canBan || channel.bannedUntil.After(now) {
-			return true, false
+	errorRate := float64(err) / float64(req)
+
+	// Check if error rate exceeds warning threshold
+	exceedsWarning := errorRate >= warnErrorRate
+
+	// Check if we should ban (only if maxErrorRate is set and exceeded)
+	if canBan && errorRate >= maxErrorRate {
+		if channel.bannedUntil.After(now) {
+			return true, false // Already banned
 		}
 
 		channel.bannedUntil = now.Add(banDuration)
 
-		return false, true
+		return false, true // Ban executed
+	} else if exceedsWarning {
+		return true, false // Beyond warning threshold but not banning
 	}
 
 	return false, false

+ 28 - 9
core/monitor/model.go

@@ -17,6 +17,9 @@ const (
 	statsKeySuffix        = ":stats"
 	modelTotalStatsSuffix = ":total_stats"
 	channelKeyPart        = ":channel:"
+
+	// Default warning threshold
+	DefaultWarnErrorRate = 0.3 // 30%
 )
 
 func modelKeyPrefix() string {
@@ -69,19 +72,28 @@ func GetModelsErrorRate(ctx context.Context) (map[string]float64, error) {
 }
 
 // AddRequest adds a request record and checks if channel should be banned
+// warnErrorRate: threshold for warning (default 30%)
+// maxErrorRate: threshold for banning (0 means no banning)
 func AddRequest(
 	ctx context.Context,
 	model string,
 	channelID int64,
 	isError, tryBan bool,
+	warnErrorRate,
 	maxErrorRate float64,
 ) (beyondThreshold, banExecution bool, err error) {
+	// Set default warning threshold if not specified
+	if warnErrorRate <= 0 {
+		warnErrorRate = DefaultWarnErrorRate
+	}
+
 	if !common.RedisEnabled {
 		beyondThreshold, banExecution = memModelMonitor.AddRequest(
 			model,
 			channelID,
 			isError,
 			tryBan,
+			warnErrorRate,
 			maxErrorRate,
 		)
 
@@ -104,6 +116,7 @@ func AddRequest(
 		channelID,
 		errorFlag,
 		now,
+		warnErrorRate,
 		maxErrorRate,
 		maxErrorRate > 0,
 		tryBan,
@@ -376,9 +389,10 @@ local model = KEYS[2]
 local channel_id = ARGV[1]
 local is_error = tonumber(ARGV[2])
 local now_ts = tonumber(ARGV[3])
-local max_error_rate = tonumber(ARGV[4])
-local can_ban = tonumber(ARGV[5])
-local try_ban = tonumber(ARGV[6])
+local warn_error_rate = tonumber(ARGV[4])
+local max_error_rate = tonumber(ARGV[5])
+local can_ban = tonumber(ARGV[6])
+local try_ban = tonumber(ARGV[7])
 
 local banned_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":banned"
 local stats_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":stats"
@@ -440,15 +454,20 @@ local function check_channel_error()
 		return 0
 	end
 
-	if (total_err / total_req) < max_error_rate then
-		return 0
-	else
-		if can_ban == 0 or already_banned then
-			return 3
+	local error_rate = total_err / total_req
+	
+	-- Check if we should ban (only if max_error_rate is set and exceeded)
+	if can_ban == 1 and error_rate >= max_error_rate then
+		if already_banned then
+			return 3  -- Beyond threshold but already banned
 		end
 		redis.call("SET", banned_key, 1)
 		redis.call("PEXPIRE", banned_key, banExpiry)
-		return 1
+		return 1  -- Ban executed
+	elseif error_rate >= warn_error_rate then
+		return 3  -- Beyond warning threshold but not banning
+	else
+		return 0  -- All good
 	end
 end
 

+ 1 - 3
core/relay/adaptor/ali/adaptor.go

@@ -172,10 +172,8 @@ func (a *Adaptor) DoRequest(
 		return TTSDoRequest(meta, req)
 	case mode.AudioTranscription:
 		return STTDoRequest(meta, req)
-	case mode.ChatCompletions:
-		fallthrough
 	default:
-		return utils.DoRequest(req)
+		return utils.DoRequest(req, meta.RequestTimeout)
 	}
 }
 

+ 1 - 1
core/relay/adaptor/ali/error.go

@@ -19,7 +19,7 @@ func ErrorHanlder(resp *http.Response) adaptor.Error {
 		statusCode = http.StatusServiceUnavailable
 		openAIError.Type = relaymodel.ErrorTypeUpstream
 	case "RequestTimeOut":
-		statusCode = http.StatusGatewayTimeout
+		statusCode = http.StatusRequestTimeout
 		openAIError.Type = relaymodel.ErrorTypeUpstream
 	}
 

+ 2 - 2
core/relay/adaptor/anthropic/adaptor.go

@@ -116,12 +116,12 @@ func (a *Adaptor) ConvertRequest(
 }
 
 func (a *Adaptor) DoRequest(
-	_ *meta.Meta,
+	meta *meta.Meta,
 	_ adaptor.Store,
 	_ *gin.Context,
 	req *http.Request,
 ) (*http.Response, error) {
-	return utils.DoRequest(req)
+	return utils.DoRequest(req, meta.RequestTimeout)
 }
 
 func (a *Adaptor) DoResponse(

+ 2 - 2
core/relay/adaptor/baidu/adaptor.go

@@ -130,12 +130,12 @@ func (a *Adaptor) ConvertRequest(
 }
 
 func (a *Adaptor) DoRequest(
-	_ *meta.Meta,
+	meta *meta.Meta,
 	_ adaptor.Store,
 	_ *gin.Context,
 	req *http.Request,
 ) (*http.Response, error) {
-	return utils.DoRequest(req)
+	return utils.DoRequest(req, meta.RequestTimeout)
 }
 
 func (a *Adaptor) DoResponse(

+ 1 - 1
core/relay/adaptor/baidu/token.go

@@ -70,7 +70,7 @@ func getBaiduAccessTokenHelper(ctx context.Context, apiKey string) (*AccessToken
 	req.Header.Add("Content-Type", "application/json")
 	req.Header.Add("Accept", "application/json")
 
-	res, err := utils.DoRequest(req)
+	res, err := utils.DoRequest(req, 0)
 	if err != nil {
 		return nil, err
 	}

+ 2 - 2
core/relay/adaptor/baiduv2/adaptor.go

@@ -109,12 +109,12 @@ func (a *Adaptor) ConvertRequest(
 }
 
 func (a *Adaptor) DoRequest(
-	_ *meta.Meta,
+	meta *meta.Meta,
 	_ adaptor.Store,
 	_ *gin.Context,
 	req *http.Request,
 ) (*http.Response, error) {
-	return utils.DoRequest(req)
+	return utils.DoRequest(req, meta.RequestTimeout)
 }
 
 func (a *Adaptor) DoResponse(

+ 2 - 2
core/relay/adaptor/cohere/adaptor.go

@@ -82,12 +82,12 @@ func (a *Adaptor) ConvertRequest(
 }
 
 func (a *Adaptor) DoRequest(
-	_ *meta.Meta,
+	meta *meta.Meta,
 	_ adaptor.Store,
 	_ *gin.Context,
 	req *http.Request,
 ) (*http.Response, error) {
-	return utils.DoRequest(req)
+	return utils.DoRequest(req, meta.RequestTimeout)
 }
 
 func (a *Adaptor) DoResponse(

+ 2 - 2
core/relay/adaptor/coze/adaptor.go

@@ -112,12 +112,12 @@ func (a *Adaptor) ConvertRequest(
 }
 
 func (a *Adaptor) DoRequest(
-	_ *meta.Meta,
+	meta *meta.Meta,
 	_ adaptor.Store,
 	_ *gin.Context,
 	req *http.Request,
 ) (*http.Response, error) {
-	return utils.DoRequest(req)
+	return utils.DoRequest(req, meta.RequestTimeout)
 }
 
 func (a *Adaptor) DoResponse(

+ 2 - 2
core/relay/adaptor/doc2x/adaptor.go

@@ -59,12 +59,12 @@ func (a *Adaptor) ConvertRequest(
 }
 
 func (a *Adaptor) DoRequest(
-	_ *meta.Meta,
+	meta *meta.Meta,
 	_ adaptor.Store,
 	_ *gin.Context,
 	req *http.Request,
 ) (*http.Response, error) {
-	return utils.DoRequest(req)
+	return utils.DoRequest(req, meta.RequestTimeout)
 }
 
 func (a *Adaptor) DoResponse(

+ 2 - 2
core/relay/adaptor/gemini/adaptor.go

@@ -86,12 +86,12 @@ func (a *Adaptor) ConvertRequest(
 }
 
 func (a *Adaptor) DoRequest(
-	_ *meta.Meta,
+	meta *meta.Meta,
 	_ adaptor.Store,
 	_ *gin.Context,
 	req *http.Request,
 ) (*http.Response, error) {
-	return utils.DoRequest(req)
+	return utils.DoRequest(req, meta.RequestTimeout)
 }
 
 func (a *Adaptor) DoResponse(

+ 2 - 2
core/relay/adaptor/ollama/adaptor.go

@@ -96,12 +96,12 @@ func (a *Adaptor) ConvertRequest(
 }
 
 func (a *Adaptor) DoRequest(
-	_ *meta.Meta,
+	meta *meta.Meta,
 	_ adaptor.Store,
 	_ *gin.Context,
 	req *http.Request,
 ) (*http.Response, error) {
-	return utils.DoRequest(req)
+	return utils.DoRequest(req, meta.RequestTimeout)
 }
 
 func (a *Adaptor) DoResponse(

+ 2 - 2
core/relay/adaptor/openai/adaptor.go

@@ -367,12 +367,12 @@ func DoResponse(
 const MetaResponseFormat = "response_format"
 
 func (a *Adaptor) DoRequest(
-	_ *meta.Meta,
+	meta *meta.Meta,
 	_ adaptor.Store,
 	_ *gin.Context,
 	req *http.Request,
 ) (*http.Response, error) {
-	return utils.DoRequest(req)
+	return utils.DoRequest(req, meta.RequestTimeout)
 }
 
 func (a *Adaptor) DoResponse(

+ 1 - 3
core/relay/adaptor/openai/video.go

@@ -80,7 +80,7 @@ func VideoHandler(
 		)
 	}
 
-	node, err := sonic.Get(responseBody)
+	idNode, err := sonic.GetWithOptions(responseBody, ast.SearchOptions{}, "id")
 	if err != nil {
 		return model.Usage{}, relaymodel.WrapperOpenAIVideoError(
 			err,
@@ -88,8 +88,6 @@ func VideoHandler(
 		)
 	}
 
-	idNode := node.Get("id")
-
 	id, err := idNode.String()
 	if err != nil {
 		return model.Usage{}, relaymodel.WrapperOpenAIVideoError(

+ 2 - 2
core/relay/adaptor/text-embeddings-inference/adaptor.go

@@ -96,12 +96,12 @@ func (a *Adaptor) ConvertRequest(
 }
 
 func (a *Adaptor) DoRequest(
-	_ *meta.Meta,
+	meta *meta.Meta,
 	_ adaptor.Store,
 	_ *gin.Context,
 	req *http.Request,
 ) (*http.Response, error) {
-	return utils.DoRequest(req)
+	return utils.DoRequest(req, meta.RequestTimeout)
 }
 
 func (a *Adaptor) DoResponse(

+ 2 - 2
core/relay/adaptor/vertexai/adaptor.go

@@ -143,10 +143,10 @@ func (a *Adaptor) SetupRequestHeader(
 }
 
 func (a *Adaptor) DoRequest(
-	_ *meta.Meta,
+	meta *meta.Meta,
 	_ adaptor.Store,
 	_ *gin.Context,
 	req *http.Request,
 ) (*http.Response, error) {
-	return utils.DoRequest(req)
+	return utils.DoRequest(req, meta.RequestTimeout)
 }

+ 19 - 22
core/relay/controller/dohelper.go

@@ -94,31 +94,18 @@ func DoHelper(
 ) {
 	detail := RequestDetail{}
 
-	// 1. Get request body
-	if err := getRequestBody(meta, c, &detail); err != nil {
+	if err := storeRequestBody(meta, c, &detail); err != nil {
 		return model.Usage{}, nil, err
 	}
 
 	// donot use c.Request.Context() because it will be canceled by the client
 	ctx := context.Background()
 
-	timeout := meta.ModelConfig.Timeout
-	if timeout <= 0 {
-		timeout = defaultTimeout
-	}
-
-	var cancel context.CancelFunc
-
-	ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
-	defer cancel()
-
-	// 2. Convert and prepare request
 	resp, err := prepareAndDoRequest(ctx, a, c, meta, store)
 	if err != nil {
 		return model.Usage{}, &detail, err
 	}
 
-	// 3. Handle error response
 	if resp == nil {
 		relayErr := relaymodel.WrapperErrorWithMessage(
 			meta.Mode,
@@ -135,25 +122,29 @@ func DoHelper(
 		defer resp.Body.Close()
 	}
 
-	// 4. Handle success response
 	usage, relayErr := handleResponse(a, c, meta, store, resp, &detail)
 	if relayErr != nil {
 		return model.Usage{}, &detail, relayErr
 	}
 
-	// 5. Update usage metrics
-	updateUsageMetrics(usage, common.GetLogger(c))
+	log := common.GetLogger(c)
+	updateUsageMetrics(usage, log)
+
+	if !detail.FirstByteAt.IsZero() {
+		ttfb := detail.FirstByteAt.Sub(meta.RequestAt)
+		log.Data["ttfb"] = common.TruncateDuration(ttfb).String()
+	}
 
 	return usage, &detail, nil
 }
 
-func getRequestBody(meta *meta.Meta, c *gin.Context, detail *RequestDetail) adaptor.Error {
+func storeRequestBody(meta *meta.Meta, c *gin.Context, detail *RequestDetail) adaptor.Error {
 	switch {
 	case meta.Mode == mode.AudioTranscription,
 		meta.Mode == mode.AudioTranslation,
 		meta.Mode == mode.ImagesEdits:
 		return nil
-	case !strings.Contains(c.GetHeader("Content-Type"), "/json"):
+	case !common.IsJSONContentType(c.GetHeader("Content-Type")):
 		return nil
 	default:
 		reqBody, err := common.GetRequestBodyReusable(c.Request)
@@ -171,8 +162,6 @@ func getRequestBody(meta *meta.Meta, c *gin.Context, detail *RequestDetail) adap
 	}
 }
 
-const defaultTimeout = 60 * 30 // 30 minutes
-
 func prepareAndDoRequest(
 	ctx context.Context,
 	a adaptor.Adaptor,
@@ -274,7 +263,7 @@ func doRequest(
 		if errors.Is(err, context.DeadlineExceeded) {
 			return nil, relaymodel.WrapperErrorWithMessage(
 				meta.Mode,
-				http.StatusGatewayTimeout,
+				http.StatusRequestTimeout,
 				"request timeout: "+err.Error(),
 			)
 		}
@@ -295,6 +284,14 @@ func doRequest(
 			)
 		}
 
+		if strings.Contains(err.Error(), "timeout awaiting response headers") {
+			return nil, relaymodel.WrapperErrorWithMessage(
+				meta.Mode,
+				http.StatusRequestTimeout,
+				"request timeout: "+err.Error(),
+			)
+		}
+
 		return nil, relaymodel.WrapperErrorWithMessage(
 			meta.Mode,
 			http.StatusInternalServerError,

+ 2 - 0
core/relay/meta/meta.go

@@ -33,6 +33,8 @@ type Meta struct {
 	ActualModel string
 	Mode        mode.Mode
 
+	RequestTimeout time.Duration
+
 	RequestUsage model.Usage
 
 	JobID        string

+ 26 - 16
core/relay/plugin/monitor/monitor.go

@@ -65,7 +65,7 @@ func getRequestDuration(meta *meta.Meta) time.Duration {
 		return 0
 	}
 
-	return time.Since(requestAtTime)
+	return common.TruncateDuration(time.Since(requestAtTime))
 }
 
 func (m *ChannelMonitor) DoRequest(
@@ -86,6 +86,11 @@ func (m *ChannelMonitor) DoRequest(
 	meta.Set("requestAt", requestAt)
 
 	resp, err := do.DoRequest(meta, store, c, req)
+
+	requestCost := common.TruncateDuration(time.Since(requestAt))
+	log := common.GetLogger(c)
+	log.Data["req_cost"] = requestCost.String()
+
 	if err == nil {
 		return resp, nil
 	}
@@ -96,20 +101,30 @@ func (m *ChannelMonitor) DoRequest(
 		int64(meta.Channel.ID),
 		true,
 		false,
+		meta.ModelConfig.WarnErrorRate,
 		meta.ModelConfig.MaxErrorRate,
 	)
 	if _err != nil {
-		common.GetLogger(c).
-			Errorf("add request failed: %+v", _err)
+		log.Errorf("add request failed: %+v", _err)
 	}
 
 	switch {
 	case banExecution:
-		notifyChannelRequestIssue(meta, "autoBanned", "Auto Banned", err)
+		notifyChannelRequestIssue(
+			meta,
+			"autoBanned",
+			"Auto Banned",
+			err,
+			requestCost,
+		)
 	case beyondThreshold:
-		notifyChannelRequestIssue(meta, "beyondThreshold", "Error Rate Beyond Threshold", err)
-	default:
-		notifyChannelRequestIssue(meta, "requestFailed", "Request Failed", err)
+		notifyChannelRequestIssue(
+			meta,
+			"beyondThreshold",
+			"Error Rate Beyond Threshold",
+			err,
+			requestCost,
+		)
 	}
 
 	return resp, err
@@ -119,6 +134,7 @@ func notifyChannelRequestIssue(
 	meta *meta.Meta,
 	issueType, titleSuffix string,
 	err error,
+	requestCost time.Duration,
 ) {
 	var notifyFunc func(title, message string)
 
@@ -150,7 +166,7 @@ func notifyChannelRequestIssue(
 		meta.Mode,
 		err.Error(),
 		meta.RequestID,
-		getRequestDuration(meta).String(),
+		requestCost.String(),
 	)
 
 	notifyFunc(
@@ -187,6 +203,7 @@ func (m *ChannelMonitor) DoResponse(
 			int64(meta.Channel.ID),
 			false,
 			false,
+			meta.ModelConfig.WarnErrorRate,
 			meta.ModelConfig.MaxErrorRate,
 		); err != nil {
 			log.Errorf("add request failed: %+v", err)
@@ -207,6 +224,7 @@ func (m *ChannelMonitor) DoResponse(
 		int64(meta.Channel.ID),
 		true,
 		!hasPermission,
+		meta.ModelConfig.WarnErrorRate,
 		meta.ModelConfig.MaxErrorRate,
 	)
 	if err != nil {
@@ -214,14 +232,6 @@ func (m *ChannelMonitor) DoResponse(
 	}
 
 	switch {
-	case relayErr.StatusCode() == http.StatusTooManyRequests:
-		notifyChannelResponseIssue(
-			c,
-			meta,
-			"requestRateLimitExceeded",
-			"Request Rate Limit Exceeded",
-			relayErr,
-		)
 	case banExecution:
 		notifyChannelResponseIssue(c, meta, "autoBanned", "Auto Banned", relayErr)
 	case beyondThreshold:

+ 1 - 1
core/relay/plugin/streamfake/fake.go

@@ -270,7 +270,7 @@ func (rw *fakeStreamResponseWriter) parseStreamingData(data []byte) error {
 			rw.finishReason = finishReason
 		}
 
-		logprobsContentNode := choiceNode.Get("logprobs").Get("content")
+		logprobsContentNode := choiceNode.GetByPath("logprobs", "content")
 		if err := logprobsContentNode.Check(); err == nil {
 			l, err := logprobsContentNode.Len()
 			if err != nil {

+ 125 - 0
core/relay/plugin/timeout/timeout.go

@@ -0,0 +1,125 @@
+package timeout
+
+import (
+	"net/http"
+	"time"
+
+	"github.com/bytedance/sonic"
+	"github.com/bytedance/sonic/ast"
+	"github.com/labring/aiproxy/core/common"
+	"github.com/labring/aiproxy/core/relay/adaptor"
+	"github.com/labring/aiproxy/core/relay/meta"
+	"github.com/labring/aiproxy/core/relay/mode"
+	"github.com/labring/aiproxy/core/relay/plugin"
+	"github.com/labring/aiproxy/core/relay/plugin/noop"
+)
+
+var _ plugin.Plugin = (*Timeout)(nil)
+
+type Timeout struct {
+	noop.Noop
+}
+
+func NewTimeoutPlugin() plugin.Plugin {
+	return &Timeout{}
+}
+
+func (t *Timeout) ConvertRequest(
+	meta *meta.Meta,
+	store adaptor.Store,
+	req *http.Request,
+	do adaptor.ConvertRequest,
+) (adaptor.ConvertResult, error) {
+	var stream bool
+	switch meta.Mode {
+	case mode.Embeddings:
+		meta.RequestTimeout = time.Second * 10
+	case mode.Moderations:
+		meta.RequestTimeout = time.Minute * 3
+	case mode.ImagesGenerations,
+		mode.ImagesEdits:
+		meta.RequestTimeout = time.Minute * 5
+	case mode.AudioTranscription,
+		mode.AudioTranslation:
+		meta.RequestTimeout = time.Minute * 3
+	case mode.Rerank:
+		meta.RequestTimeout = time.Second * 10
+	case mode.ParsePdf:
+		meta.RequestTimeout = time.Minute * 3
+	case mode.VideoGenerationsJobs,
+		mode.VideoGenerationsGetJobs,
+		mode.VideoGenerationsContent:
+		meta.RequestTimeout = time.Second * 30
+	case mode.ResponsesGet,
+		mode.ResponsesDelete,
+		mode.ResponsesCancel,
+		mode.ResponsesInputItems:
+		meta.RequestTimeout = time.Second * 30
+	case mode.ChatCompletions,
+		mode.Completions,
+		mode.Responses,
+		mode.Anthropic:
+		stream, _ = isStream(req)
+
+		inputTokens := meta.RequestUsage.InputTokens
+		if stream {
+			switch {
+			case inputTokens > 100*1024:
+				meta.RequestTimeout = time.Minute * 3
+			case inputTokens > 10*1024:
+				meta.RequestTimeout = time.Minute * 2
+			default:
+				meta.RequestTimeout = time.Minute
+			}
+		} else {
+			switch {
+			case inputTokens > 100*1024:
+				meta.RequestTimeout = time.Minute * 10
+			case inputTokens > 10*1024:
+				meta.RequestTimeout = time.Minute * 5
+			default:
+				meta.RequestTimeout = time.Minute * 3
+			}
+		}
+	default:
+		if common.IsJSONContentType(req.Header.Get("Content-Type")) {
+			stream, _ = isStream(req)
+			if stream {
+				meta.RequestTimeout = time.Minute
+			} else {
+				meta.RequestTimeout = time.Minute * 3
+			}
+		}
+	}
+
+	if stream {
+		if timeout := meta.ModelConfig.StreamRequestTimeout(); timeout != 0 {
+			meta.RequestTimeout = timeout
+		}
+	} else {
+		if timeout := meta.ModelConfig.RequestTimeout(); timeout != 0 {
+			meta.RequestTimeout = timeout
+		}
+	}
+
+	if meta.RequestTimeout != 0 {
+		log := common.GetLoggerFromReq(req)
+		log.Data["req_timeout"] = common.TruncateDuration(meta.RequestTimeout).String()
+	}
+
+	return do.ConvertRequest(meta, store, req)
+}
+
+func isStream(req *http.Request) (bool, error) {
+	body, err := common.GetRequestBodyReusable(req)
+	if err != nil {
+		return false, nil
+	}
+
+	node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "stream")
+	if err != nil {
+		return false, err
+	}
+
+	return node.Bool()
+}

+ 2 - 1
core/relay/render/claudeevent.go

@@ -4,6 +4,7 @@ import (
 	"net/http"
 
 	"github.com/bytedance/sonic"
+	"github.com/bytedance/sonic/ast"
 	"github.com/labring/aiproxy/core/common/conv"
 )
 
@@ -18,7 +19,7 @@ func (r *Anthropic) Render(w http.ResponseWriter) error {
 	event := r.Event
 
 	if event == "" {
-		eventNode, err := sonic.Get(r.Data, "type")
+		eventNode, err := sonic.GetWithOptions(r.Data, ast.SearchOptions{}, "type")
 		if err != nil {
 			return err
 		}

+ 64 - 3
core/relay/utils/utils.go

@@ -2,14 +2,18 @@ package utils
 
 import (
 	"fmt"
+	"net"
 	"net/http"
+	"strconv"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/bytedance/sonic"
 	"github.com/bytedance/sonic/ast"
 	"github.com/labring/aiproxy/core/common"
 	model "github.com/labring/aiproxy/core/relay/model"
+	"github.com/patrickmn/go-cache"
 )
 
 func UnmarshalGeneralThinking(req *http.Request) (model.GeneralOpenAIThinkingRequest, error) {
@@ -125,10 +129,67 @@ func UnmarshalMap(req *http.Request) (map[string]any, error) {
 	return request, nil
 }
 
-var defaultClient = &http.Client{}
+const (
+	defaultHeaderTimeout = time.Minute * 15
+	tlsHandshakeTimeout  = time.Second * 5
+)
+
+var (
+	defaultTransport *http.Transport
+	defaultClient    *http.Client
+	defaultDialer    = &net.Dialer{
+		Timeout:   10 * time.Second,
+		KeepAlive: 30 * time.Second,
+	}
+	clientCache = cache.New(time.Minute, time.Minute)
+)
+
+func init() {
+	defaultTransport, _ = http.DefaultTransport.(*http.Transport)
+	if defaultTransport == nil {
+		panic("http default transport is not http.Transport type")
+	}
+
+	defaultTransport = defaultTransport.Clone()
+	defaultTransport.DialContext = defaultDialer.DialContext
+	defaultTransport.ResponseHeaderTimeout = defaultHeaderTimeout
+	defaultTransport.TLSHandshakeTimeout = tlsHandshakeTimeout
+
+	defaultClient = &http.Client{
+		Transport: defaultTransport,
+	}
+}
+
+func loadHTTPClient(timeout time.Duration) *http.Client {
+	if timeout == 0 || timeout == defaultHeaderTimeout {
+		return defaultClient
+	}
+
+	key := strconv.Itoa(int(timeout))
+
+	clientI, ok := clientCache.Get(key)
+	if ok {
+		client, ok := clientI.(*http.Client)
+		if !ok {
+			panic("unknow http client type")
+		}
+
+		return client
+	}
+
+	transport := defaultTransport.Clone()
+	transport.ResponseHeaderTimeout = timeout
+
+	client := &http.Client{
+		Transport: transport,
+	}
+	clientCache.SetDefault(key, client)
+
+	return client
+}
 
-func DoRequest(req *http.Request) (*http.Response, error) {
-	resp, err := defaultClient.Do(req)
+func DoRequest(req *http.Request, timeout time.Duration) (*http.Response, error) {
+	resp, err := loadHTTPClient(timeout).Do(req)
 	if err != nil {
 		return nil, err
 	}

+ 2 - 1
mcp-servers/server.go

@@ -6,6 +6,7 @@ import (
 	"runtime"
 
 	"github.com/bytedance/sonic"
+	"github.com/bytedance/sonic/ast"
 	"github.com/mark3labs/mcp-go/client/transport"
 	"github.com/mark3labs/mcp-go/mcp"
 )
@@ -22,7 +23,7 @@ func (s *client2Server) HandleMessage(
 	ctx context.Context,
 	message json.RawMessage,
 ) mcp.JSONRPCMessage {
-	methodNode, err := sonic.Get(message, "method")
+	methodNode, err := sonic.GetWithOptions(message, ast.SearchOptions{}, "method")
 	if err != nil {
 		return CreateMCPErrorResponse(nil, mcp.PARSE_ERROR, err.Error())
 	}