feat: channel rpm tpm record (#197)

* feat: channel rpm tpm record

* refactor: request record

* feat: split group rpm and token rpm

* fix: ci lint

* fix: ci lint

* fix: group tpm missing reset to log

* fix: hide request record when group is internal
Author: zijiren (7 months ago)
Commit: fdce696c37

+ 8 - 0
core/common/consume/consume.go

@@ -34,6 +34,8 @@ func AsyncConsume(
 	downstreamResult bool,
 	user string,
 	metadata map[string]string,
+	channelRate model.RequestRate,
+	groupRate model.RequestRate,
 ) {
 	consumeWaitGroup.Add(1)
 	defer func() {
@@ -58,6 +60,8 @@ func AsyncConsume(
 		downstreamResult,
 		user,
 		metadata,
+		channelRate,
+		groupRate,
 	)
 }
 
@@ -76,6 +80,8 @@ func Consume(
 	downstreamResult bool,
 	user string,
 	metadata map[string]string,
+	channelRate model.RequestRate,
+	groupRate model.RequestRate,
 ) {
 	amount := CalculateAmount(code, usage, modelPrice)
 	amount = consumeAmount(ctx, amount, postGroupConsumer, meta)
@@ -94,6 +100,8 @@ func Consume(
 		downstreamResult,
 		user,
 		metadata,
+		channelRate,
+		groupRate,
 	)
 	if err != nil {
 		log.Error("error batch record consume: " + err.Error())

+ 4 - 0
core/common/consume/record.go

@@ -21,6 +21,8 @@ func recordConsume(
 	downstreamResult bool,
 	user string,
 	metadata map[string]string,
+	channelModelRate model.RequestRate,
+	groupModelTokenRate model.RequestRate,
 ) error {
 	return model.BatchRecordLogs(
 		meta.RequestID,
@@ -45,5 +47,7 @@ func recordConsume(
 		amount,
 		user,
 		metadata,
+		channelModelRate,
+		groupModelTokenRate,
 	)
 }

+ 201 - 0
core/common/reqlimit/main.go

@@ -0,0 +1,201 @@
+package reqlimit
+
+import (
+	"context"
+	"time"
+
+	"github.com/labring/aiproxy/core/common"
+	log "github.com/sirupsen/logrus"
+)
+
+var (
+	memoryGroupModelLimiter = NewInMemoryRecord()
+	redisGroupModelLimiter  = newRedisGroupModelRecord()
+)
+
+func PushGroupModelRequest(ctx context.Context, group, model string, overed int64) (int64, int64, int64) {
+	if common.RedisEnabled {
+		count, overLimitCount, secondCount, err := redisGroupModelLimiter.PushRequest(ctx, overed, time.Minute, 1, group, model)
+		if err == nil {
+			return count, overLimitCount, secondCount
+		}
+		log.Error("redis push request error: " + err.Error())
+	}
+	return memoryGroupModelLimiter.PushRequest(overed, time.Minute, 1, group, model)
+}
+
+func GetGroupModelRequest(ctx context.Context, group, model string) (int64, int64) {
+	if model == "" {
+		model = "*"
+	}
+	if common.RedisEnabled {
+		totalCount, secondCount, err := redisGroupModelLimiter.GetRequest(ctx, time.Minute, group, model)
+		if err == nil {
+			return totalCount, secondCount
+		}
+		log.Error("redis get request error: " + err.Error())
+	}
+	return memoryGroupModelLimiter.GetRequest(time.Minute, group, model)
+}
+
+var (
+	memoryGroupModelTokennameLimiter = NewInMemoryRecord()
+	redisGroupModelTokennameLimiter  = newRedisGroupModelTokennameRecord()
+)
+
+func PushGroupModelTokennameRequest(ctx context.Context, group, model, tokenname string) (int64, int64, int64) {
+	if common.RedisEnabled {
+		count, overLimitCount, secondCount, err := redisGroupModelTokennameLimiter.PushRequest(ctx, 0, time.Minute, 1, group, model, tokenname)
+		if err == nil {
+			return count, overLimitCount, secondCount
+		}
+		log.Error("redis push request error: " + err.Error())
+	}
+	return memoryGroupModelTokennameLimiter.PushRequest(0, time.Minute, 1, group, model, tokenname)
+}
+
+func GetGroupModelTokennameRequest(ctx context.Context, group, model, tokenname string) (int64, int64) {
+	if model == "" {
+		model = "*"
+	}
+	if tokenname == "" {
+		tokenname = "*"
+	}
+	if common.RedisEnabled {
+		totalCount, secondCount, err := redisGroupModelTokennameLimiter.GetRequest(ctx, time.Minute, group, model, tokenname)
+		if err == nil {
+			return totalCount, secondCount
+		}
+		log.Error("redis get request error: " + err.Error())
+	}
+	return memoryGroupModelTokennameLimiter.GetRequest(time.Minute, group, model, tokenname)
+}
+
+var (
+	memoryChannelModelRecord = NewInMemoryRecord()
+	redisChannelModelRecord  = newRedisChannelModelRecord()
+)
+
+func PushChannelModelRequest(ctx context.Context, channel, model string) (int64, int64, int64) {
+	if common.RedisEnabled {
+		count, overLimitCount, secondCount, err := redisChannelModelRecord.PushRequest(ctx, 0, time.Minute, 1, channel, model)
+		if err == nil {
+			return count, overLimitCount, secondCount
+		}
+		log.Error("redis push request error: " + err.Error())
+	}
+	return memoryChannelModelRecord.PushRequest(0, time.Minute, 1, channel, model)
+}
+
+func GetChannelModelRequest(ctx context.Context, channel, model string) (int64, int64) {
+	if channel == "" {
+		channel = "*"
+	}
+	if model == "" {
+		model = "*"
+	}
+	if common.RedisEnabled {
+		totalCount, secondCount, err := redisChannelModelRecord.GetRequest(ctx, time.Minute, channel, model)
+		if err == nil {
+			return totalCount, secondCount
+		}
+		log.Error("redis get request error: " + err.Error())
+	}
+	return memoryChannelModelRecord.GetRequest(time.Minute, channel, model)
+}
+
+var (
+	memoryGroupModelTokensLimiter = NewInMemoryRecord()
+	redisGroupModelTokensLimiter  = newRedisGroupModelTokensRecord()
+)
+
+func PushGroupModelTokensRequest(ctx context.Context, group, model string, maxTokens int64, tokens int64) (int64, int64, int64) {
+	if common.RedisEnabled {
+		count, overLimitCount, secondCount, err := redisGroupModelTokensLimiter.PushRequest(ctx, maxTokens, time.Minute, tokens, group, model)
+		if err == nil {
+			return count, overLimitCount, secondCount
+		}
+		log.Error("redis push request error: " + err.Error())
+	}
+	return memoryGroupModelTokensLimiter.PushRequest(maxTokens, time.Minute, tokens, group, model)
+}
+
+func GetGroupModelTokensRequest(ctx context.Context, group, model string) (int64, int64) {
+	if model == "" {
+		model = "*"
+	}
+	if common.RedisEnabled {
+		totalCount, secondCount, err := redisGroupModelTokensLimiter.GetRequest(ctx, time.Minute, group, model)
+		if err == nil {
+			return totalCount, secondCount
+		}
+		log.Error("redis get request error: " + err.Error())
+	}
+	return memoryGroupModelTokensLimiter.GetRequest(time.Minute, group, model)
+}
+
+var (
+	memoryGroupModelTokennameTokensLimiter = NewInMemoryRecord()
+	redisGroupModelTokennameTokensLimiter  = newRedisGroupModelTokennameTokensRecord()
+)
+
+func PushGroupModelTokennameTokensRequest(ctx context.Context, group, model, tokenname string, tokens int64) (int64, int64, int64) {
+	if common.RedisEnabled {
+		count, overLimitCount, secondCount, err := redisGroupModelTokennameTokensLimiter.PushRequest(ctx, 0, time.Minute, tokens, group, model, tokenname)
+		if err == nil {
+			return count, overLimitCount, secondCount
+		}
+		log.Error("redis push request error: " + err.Error())
+	}
+	return memoryGroupModelTokennameTokensLimiter.PushRequest(0, time.Minute, tokens, group, model, tokenname)
+}
+
+func GetGroupModelTokennameTokensRequest(ctx context.Context, group, model, tokenname string) (int64, int64) {
+	if model == "" {
+		model = "*"
+	}
+	if tokenname == "" {
+		tokenname = "*"
+	}
+	if common.RedisEnabled {
+		totalCount, secondCount, err := redisGroupModelTokennameTokensLimiter.GetRequest(ctx, time.Minute, group, model, tokenname)
+		if err == nil {
+			return totalCount, secondCount
+		}
+		log.Error("redis get request error: " + err.Error())
+	}
+	return memoryGroupModelTokennameTokensLimiter.GetRequest(time.Minute, group, model, tokenname)
+}
+
+var (
+	memoryChannelModelTokensRecord = NewInMemoryRecord()
+	redisChannelModelTokensRecord  = newRedisChannelModelTokensRecord()
+)
+
+func PushChannelModelTokensRequest(ctx context.Context, channel, model string, tokens int64) (int64, int64, int64) {
+	if common.RedisEnabled {
+		count, overLimitCount, secondCount, err := redisChannelModelTokensRecord.PushRequest(ctx, 0, time.Minute, tokens, channel, model)
+		if err == nil {
+			return count, overLimitCount, secondCount
+		}
+		log.Error("redis push request error: " + err.Error())
+	}
+	return memoryChannelModelTokensRecord.PushRequest(0, time.Minute, tokens, channel, model)
+}
+
+func GetChannelModelTokensRequest(ctx context.Context, channel, model string) (int64, int64) {
+	if channel == "" {
+		channel = "*"
+	}
+	if model == "" {
+		model = "*"
+	}
+	if common.RedisEnabled {
+		totalCount, secondCount, err := redisChannelModelTokensRecord.GetRequest(ctx, time.Minute, channel, model)
+		if err == nil {
+			return totalCount, secondCount
+		}
+		log.Error("redis get request error: " + err.Error())
+	}
+	return memoryChannelModelTokensRecord.GetRequest(time.Minute, channel, model)
+}

+ 149 - 0
core/common/reqlimit/mem.go

@@ -0,0 +1,149 @@
+package reqlimit
+
+import (
+	"strings"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+type windowCounts struct {
+	normal int64
+	over   int64
+}
+
+type entry struct {
+	sync.Mutex
+	windows    map[int64]*windowCounts
+	lastAccess atomic.Value
+}
+
+type InMemoryRecord struct {
+	entries sync.Map
+}
+
+func NewInMemoryRecord() *InMemoryRecord {
+	rl := &InMemoryRecord{
+		entries: sync.Map{},
+	}
+	go rl.cleanupInactiveEntries(2*time.Minute, 1*time.Minute)
+	return rl
+}
+
+func (m *InMemoryRecord) getEntry(keys []string) *entry {
+	key := strings.Join(keys, ":")
+	actual, _ := m.entries.LoadOrStore(key, &entry{
+		windows: make(map[int64]*windowCounts),
+	})
+	e, _ := actual.(*entry)
+	if e.lastAccess.Load() == nil {
+		e.lastAccess.CompareAndSwap(nil, time.Now())
+	}
+	return e
+}
+
+func (m *InMemoryRecord) cleanupAndCount(e *entry, cutoff int64) (int64, int64) {
+	normalCount := int64(0)
+	overCount := int64(0)
+	for ts, wc := range e.windows {
+		if ts < cutoff {
+			delete(e.windows, ts)
+		} else {
+			normalCount += wc.normal
+			overCount += wc.over
+		}
+	}
+	return normalCount, overCount
+}
+
+func (m *InMemoryRecord) PushRequest(overed int64, duration time.Duration, n int64, keys ...string) (normalCount int64, overCount int64, secondCount int64) {
+	e := m.getEntry(keys)
+
+	e.Lock()
+	defer e.Unlock()
+
+	now := time.Now()
+	e.lastAccess.Store(now)
+
+	windowStart := now.Unix()
+	cutoff := windowStart - int64(duration.Seconds())
+
+	normalCount, overCount = m.cleanupAndCount(e, cutoff)
+
+	wc, exists := e.windows[windowStart]
+	if !exists {
+		wc = &windowCounts{}
+		e.windows[windowStart] = wc
+	}
+
+	if overed == 0 || normalCount <= overed {
+		wc.normal += n
+		normalCount += n
+	} else {
+		wc.over += n
+		overCount += n
+	}
+
+	return normalCount, overCount, wc.normal + wc.over
+}
+
+func (m *InMemoryRecord) GetRequest(duration time.Duration, keys ...string) (totalCount int64, secondCount int64) {
+	nowSecond := time.Now().Unix()
+	cutoff := nowSecond - int64(duration.Seconds())
+
+	m.entries.Range(func(key, value any) bool {
+		k, _ := key.(string)
+		currentKeys := parseKeys(k)
+
+		if matchKeys(keys, currentKeys) {
+			e, _ := value.(*entry)
+			e.Lock()
+			normalCount, overCount := m.cleanupAndCount(e, cutoff)
+			nowWindow := e.windows[nowSecond]
+			e.Unlock()
+			totalCount += normalCount + overCount
+			if nowWindow != nil {
+				secondCount += nowWindow.normal + nowWindow.over
+			}
+		}
+		return true
+	})
+
+	return totalCount, secondCount
+}
+
+func (m *InMemoryRecord) cleanupInactiveEntries(interval time.Duration, maxInactivity time.Duration) {
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+	for range ticker.C {
+		m.entries.Range(func(key, value any) bool {
+			e, _ := value.(*entry)
+			la := e.lastAccess.Load()
+			if la == nil {
+				return true
+			}
+			lastAccess, _ := la.(time.Time)
+			if time.Since(lastAccess) > maxInactivity {
+				m.entries.CompareAndDelete(key, e)
+			}
+			return true
+		})
+	}
+}
+
+func parseKeys(key string) []string {
+	return strings.Split(key, ":")
+}
+
+func matchKeys(pattern []string, keys []string) bool {
+	if len(pattern) != len(keys) {
+		return false
+	}
+
+	for i, p := range pattern {
+		if p != "*" && p != keys[i] {
+			return false
+		}
+	}
+	return true
+}
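
Key matching is positional, with "*" as a per-position wildcard; note that because parseKeys splits on ":", a key segment that itself contains a colon would split into extra parts. A worked illustration of the unexported helpers above:

    // parseKeys("group1:model2")                       -> ["group1", "model2"]
    // matchKeys(["group1", "*"], ["group1", "model2"]) -> true  (model wildcard)
    // matchKeys(["*", "model1"], ["group1", "model2"]) -> false (model differs)
    // matchKeys(["group1"],      ["group1", "model2"]) -> false (length differs)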

+ 272 - 0
core/common/reqlimit/mem_test.go

@@ -0,0 +1,272 @@
+package reqlimit_test
+
+import (
+	"fmt"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/labring/aiproxy/core/common/reqlimit"
+)
+
+func TestNewInMemoryRecord(t *testing.T) {
+	rl := reqlimit.NewInMemoryRecord()
+	if rl == nil {
+		t.Fatal("NewInMemoryRecord should return a non-nil instance")
+	}
+}
+
+func TestPushRequestBasic(t *testing.T) {
+	rl := reqlimit.NewInMemoryRecord()
+
+	normalCount, overCount, secondCount := rl.PushRequest(10, 60*time.Second, 1, "group1", "model1")
+
+	if normalCount != 1 {
+		t.Errorf("Expected normalCount to be 1, got %d", normalCount)
+	}
+	if overCount != 0 {
+		t.Errorf("Expected overCount to be 0, got %d", overCount)
+	}
+	if secondCount != 1 {
+		t.Errorf("Expected secondCount to be 1, got %d", secondCount)
+	}
+}
+
+func TestPushRequestRateLimit(t *testing.T) {
+	rl := reqlimit.NewInMemoryRecord()
+
+	maxReq := int64(2)
+	duration := 60 * time.Second
+
+	for i := range 4 {
+		normalCount, overCount, _ := rl.PushRequest(maxReq, duration, 1, "group1", "model1")
+
+		switch {
+		case i < 2:
+			if normalCount != int64(i+1) {
+				t.Errorf("Request %d: expected normalCount %d, got %d", i+1, i+1, normalCount)
+			}
+			if overCount != 0 {
+				t.Errorf("Request %d: expected overCount 0, got %d", i+1, overCount)
+			}
+		case i == 2:
+			if normalCount != 3 {
+				t.Errorf("Request %d: expected normalCount 3, got %d", i+1, normalCount)
+			}
+			if overCount != 0 {
+				t.Errorf("Request %d: expected overCount 0, got %d", i+1, overCount)
+			}
+		case i == 3:
+			if normalCount != 3 {
+				t.Errorf("Request %d: expected normalCount 3, got %d", i+1, normalCount)
+			}
+			if overCount != 1 {
+				t.Errorf("Request %d: expected overCount 1, got %d", i+1, overCount)
+			}
+		}
+	}
+}
+
+func TestPushRequestUnlimited(t *testing.T) {
+	rl := reqlimit.NewInMemoryRecord()
+
+	for i := range 5 {
+		normalCount, overCount, _ := rl.PushRequest(0, 60*time.Second, 1, "group1", "model1")
+
+		if normalCount != int64(i+1) {
+			t.Errorf("Request %d: expected normalCount %d, got %d", i+1, i+1, normalCount)
+		}
+		if overCount != 0 {
+			t.Errorf("Request %d: expected overCount 0, got %d", i+1, overCount)
+		}
+	}
+}
+
+func TestGetRequest(t *testing.T) {
+	rl := reqlimit.NewInMemoryRecord()
+
+	rl.PushRequest(10, 60*time.Second, 1, "group1", "model1")
+	rl.PushRequest(10, 60*time.Second, 1, "group1", "model2")
+	rl.PushRequest(10, 60*time.Second, 1, "group2", "model1")
+
+	totalCount, secondCount := rl.GetRequest(60*time.Second, "group1", "model1")
+	if totalCount != 1 {
+		t.Errorf("Expected totalCount 1, got %d", totalCount)
+	}
+	if secondCount != 1 {
+		t.Errorf("Expected secondCount 1, got %d", secondCount)
+	}
+
+	totalCount, _ = rl.GetRequest(60*time.Second, "*", "*")
+	if totalCount != 3 {
+		t.Errorf("Expected totalCount 3 for wildcard query, got %d", totalCount)
+	}
+
+	totalCount, _ = rl.GetRequest(60*time.Second, "group1", "*")
+	if totalCount != 2 {
+		t.Errorf("Expected totalCount 2 for group1 wildcard, got %d", totalCount)
+	}
+}
+
+func TestMultipleGroupsAndModels(t *testing.T) {
+	rl := reqlimit.NewInMemoryRecord()
+
+	groups := []string{"group1", "group2", "group3"}
+	models := []string{"model1", "model2"}
+
+	for _, group := range groups {
+		for _, model := range models {
+			rl.PushRequest(10, 60*time.Second, 1, group, model)
+		}
+	}
+
+	totalCount, _ := rl.GetRequest(60*time.Second, "*", "*")
+	expected := len(groups) * len(models)
+	if totalCount != int64(expected) {
+		t.Errorf("Expected totalCount %d, got %d", expected, totalCount)
+	}
+}
+
+func TestTimeWindowCleanup(t *testing.T) {
+	rl := reqlimit.NewInMemoryRecord()
+
+	rl.PushRequest(10, 2*time.Second, 1, "group1", "model1")
+
+	totalCount, _ := rl.GetRequest(2*time.Second, "group1", "model1")
+	if totalCount != 1 {
+		t.Errorf("Expected totalCount 1, got %d", totalCount)
+	}
+
+	time.Sleep(3 * time.Second)
+
+	totalCount, _ = rl.GetRequest(2*time.Second, "group1", "model1")
+	if totalCount != 0 {
+		t.Errorf("Expected totalCount 0 after cleanup, got %d", totalCount)
+	}
+}
+
+func TestConcurrentAccess(t *testing.T) {
+	rl := reqlimit.NewInMemoryRecord()
+
+	const numGoroutines = 100
+	const requestsPerGoroutine = 10
+
+	var wg sync.WaitGroup
+	wg.Add(numGoroutines)
+
+	for i := range numGoroutines {
+		go func(_ int) {
+			defer wg.Done()
+			for range requestsPerGoroutine {
+				rl.PushRequest(0, 60*time.Second, 1, "group1", "model1")
+			}
+		}(i)
+	}
+
+	wg.Wait()
+
+	totalCount, _ := rl.GetRequest(60*time.Second, "group1", "model1")
+	expected := int64(numGoroutines * requestsPerGoroutine)
+	if totalCount != expected {
+		t.Errorf("Expected totalCount %d, got %d", expected, totalCount)
+	}
+}
+
+func TestConcurrentDifferentKeys(t *testing.T) {
+	rl := reqlimit.NewInMemoryRecord()
+
+	const numGoroutines = 50
+	var wg sync.WaitGroup
+	wg.Add(numGoroutines)
+
+	for i := range numGoroutines {
+		go func(id int) {
+			defer wg.Done()
+			group := fmt.Sprintf("group%d", id%5)
+			model := fmt.Sprintf("model%d", id%3)
+			rl.PushRequest(10, 60*time.Second, 1, group, model)
+		}(i)
+	}
+
+	wg.Wait()
+
+	// verify the total count
+	totalCount, _ := rl.GetRequest(60*time.Second, "*", "*")
+	if totalCount != int64(numGoroutines) {
+		t.Errorf("Expected totalCount %d, got %d", numGoroutines, totalCount)
+	}
+}
+
+func TestRateLimitWithOverflow(t *testing.T) {
+	rl := reqlimit.NewInMemoryRecord()
+
+	maxReq := 5
+	duration := 60 * time.Second
+
+	for i := range 10 {
+		normalCount, overCount, _ := rl.PushRequest(int64(maxReq), duration, 1, "group1", "model1")
+
+		if i < maxReq {
+			if normalCount != int64(i+1) || overCount != 0 {
+				t.Errorf("Request %d: expected normal=%d, over=0, got normal=%d, over=%d",
+					i+1, i+1, normalCount, overCount)
+			}
+		} else {
+			expectedOver := int64(i - maxReq)
+			if normalCount != int64(maxReq+1) || overCount != expectedOver {
+				t.Errorf("Request %d: expected normal=5, over=%d, got normal=%d, over=%d",
+					i+1, expectedOver, normalCount, overCount)
+			}
+		}
+	}
+}
+
+func TestEmptyQueries(t *testing.T) {
+	rl := reqlimit.NewInMemoryRecord()
+
+	totalCount, secondCount := rl.GetRequest(60*time.Second, "*", "*")
+	if totalCount != 0 || secondCount != 0 {
+		t.Errorf("Expected empty results, got total=%d, second=%d", totalCount, secondCount)
+	}
+
+	totalCount, secondCount = rl.GetRequest(60*time.Second, "nonexistent", "model")
+	if totalCount != 0 || secondCount != 0 {
+		t.Errorf("Expected empty results for nonexistent key, got total=%d, second=%d", totalCount, secondCount)
+	}
+}
+
+func BenchmarkPushRequest(b *testing.B) {
+	rl := reqlimit.NewInMemoryRecord()
+
+	b.ResetTimer()
+	b.RunParallel(func(pb *testing.PB) {
+		i := 0
+		for pb.Next() {
+			group := fmt.Sprintf("group%d", i%10)
+			model := fmt.Sprintf("model%d", i%5)
+			rl.PushRequest(100, 60*time.Second, 1, group, model)
+			i++
+		}
+	})
+}
+
+func BenchmarkGetRequest(b *testing.B) {
+	rl := reqlimit.NewInMemoryRecord()
+
+	for i := range 100 {
+		group := fmt.Sprintf("group%d", i%10)
+		model := fmt.Sprintf("model%d", i%5)
+		rl.PushRequest(100, 60*time.Second, 1, group, model)
+	}
+
+	b.ResetTimer()
+	b.RunParallel(func(pb *testing.PB) {
+		i := 0
+		for pb.Next() {
+			group := fmt.Sprintf("group%d", i%10)
+			model := fmt.Sprintf("model%d", i%5)
+			rl.GetRequest(60*time.Second, group, model)
+			i++
+		}
+	})
+}
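
The suite has no external dependencies and can be run in isolation, e.g. with go test ./core/common/reqlimit/ (path taken from the file header above); note that TestTimeWindowCleanup sleeps for about three seconds.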

+ 224 - 0
core/common/reqlimit/redis.go

@@ -0,0 +1,224 @@
+package reqlimit
+
+import (
+	"context"
+	"errors"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/labring/aiproxy/core/common"
+	"github.com/redis/go-redis/v9"
+)
+
+type redisRateRecord struct {
+	prefix string
+}
+
+func newRedisGroupModelRecord() *redisRateRecord {
+	return &redisRateRecord{
+		prefix: "group-model-record",
+	}
+}
+
+func newRedisGroupModelTokennameRecord() *redisRateRecord {
+	return &redisRateRecord{
+		prefix: "group-model-tokenname-record",
+	}
+}
+
+func newRedisChannelModelRecord() *redisRateRecord {
+	return &redisRateRecord{
+		prefix: "channel-model-record",
+	}
+}
+
+func newRedisGroupModelTokensRecord() *redisRateRecord {
+	return &redisRateRecord{
+		prefix: "group-model-tokens-record",
+	}
+}
+
+func newRedisGroupModelTokennameTokensRecord() *redisRateRecord {
+	return &redisRateRecord{
+		prefix: "group-model-tokenname-tokens-record",
+	}
+}
+
+func newRedisChannelModelTokensRecord() *redisRateRecord {
+	return &redisRateRecord{
+		prefix: "channel-model-tokens-record",
+	}
+}
+
+const pushRequestLuaScript = `
+local key = KEYS[1]
+local window_seconds = tonumber(ARGV[1])
+local current_time = tonumber(ARGV[2])
+local max_requests = tonumber(ARGV[3])
+local n = tonumber(ARGV[4])
+local cutoff_slice = current_time - window_seconds
+
+local function parse_count(value)
+    if not value then return 0, 0 end
+    local r, e = value:match("^(%d+):(%d+)$")
+    return tonumber(r) or 0, tonumber(e) or 0
+end
+
+local count = 0
+local over_count = 0
+
+local all_fields = redis.call('HGETALL', key)
+for i = 1, #all_fields, 2 do
+    local field_slice = tonumber(all_fields[i])
+    if field_slice < cutoff_slice then
+        redis.call('HDEL', key, all_fields[i])
+	else
+		local c, oc = parse_count(all_fields[i+1])
+		count = count + c
+		over_count = over_count + oc
+	end
+end
+
+local current_value = redis.call('HGET', key, tostring(current_time))
+local current_c, current_oc = parse_count(current_value)
+
+if max_requests == 0 or count <= max_requests then
+	current_c = current_c + n
+    count = count + n
+else
+	current_oc = current_oc + n
+	over_count = over_count + n
+end
+redis.call('HSET', key, current_time, current_c .. ":" .. current_oc)
+
+redis.call('EXPIRE', key, window_seconds)
+local current_second_count = current_c + current_oc
+return string.format("%d:%d:%d", count, over_count, current_second_count)
+`
+
+const getRequestCountLuaScript = `
+local pattern = KEYS[1]
+local window_seconds = tonumber(ARGV[1])
+local current_time = tonumber(ARGV[2])
+local cutoff_slice = current_time - window_seconds
+
+local function parse_count(value)
+    if not value then return 0, 0 end
+    local r, e = value:match("^(%d+):(%d+)$")
+    return tonumber(r) or 0, tonumber(e) or 0
+end
+
+local total = 0
+local current_second_count = 0
+
+local keys = redis.call('KEYS', pattern)
+for _, key in ipairs(keys) do
+    local count = 0
+    local over = 0
+
+    local all_fields = redis.call('HGETALL', key)
+    for i=1, #all_fields, 2 do
+        local field_slice = tonumber(all_fields[i])
+        if field_slice < cutoff_slice then
+			redis.call('HDEL', key, all_fields[i])
+		else
+			local c, oc = parse_count(all_fields[i+1])
+			count = count + c
+			over = over + oc
+            
+            if field_slice == current_time then
+                current_second_count = current_second_count + c + oc
+            end
+		end
+    end
+
+    total = total + count + over
+end
+
+return string.format("%d:%d", total, current_second_count)
+`
+
+var (
+	pushRequestScript     = redis.NewScript(pushRequestLuaScript)
+	getRequestCountScript = redis.NewScript(getRequestCountLuaScript)
+)
+
+func (r *redisRateRecord) buildKey(keys ...string) string {
+	return r.prefix + ":" + strings.Join(keys, ":")
+}
+
+func (r *redisRateRecord) GetRequest(ctx context.Context, duration time.Duration, keys ...string) (totalCount int64, secondCount int64, err error) {
+	if !common.RedisEnabled {
+		return 0, 0, nil
+	}
+
+	pattern := r.buildKey(keys...)
+
+	result, err := getRequestCountScript.Run(
+		ctx,
+		common.RDB,
+		[]string{pattern},
+		duration.Seconds(),
+		time.Now().Unix(),
+	).Text()
+	if err != nil {
+		return 0, 0, err
+	}
+
+	parts := strings.Split(result, ":")
+	if len(parts) != 2 {
+		return 0, 0, errors.New("invalid result format")
+	}
+
+	totalCountInt, err := strconv.ParseInt(parts[0], 10, 64)
+	if err != nil {
+		return 0, 0, err
+	}
+
+	secondCountInt, err := strconv.ParseInt(parts[1], 10, 64)
+	if err != nil {
+		return 0, 0, err
+	}
+
+	return totalCountInt, secondCountInt, nil
+}
+
+func (r *redisRateRecord) PushRequest(ctx context.Context, overed int64, duration time.Duration, n int64, keys ...string) (normalCount int64, overCount int64, secondCount int64, err error) {
+	key := r.buildKey(keys...)
+
+	result, err := pushRequestScript.Run(
+		ctx,
+		common.RDB,
+		[]string{key},
+		duration.Seconds(),
+		time.Now().Unix(),
+		overed,
+		n,
+	).Text()
+	if err != nil {
+		return 0, 0, 0, err
+	}
+
+	parts := strings.Split(result, ":")
+	if len(parts) != 3 {
+		return 0, 0, 0, errors.New("invalid result")
+	}
+
+	countInt, err := strconv.ParseInt(parts[0], 10, 64)
+	if err != nil {
+		return 0, 0, 0, err
+	}
+
+	overLimitCountInt, err := strconv.ParseInt(parts[1], 10, 64)
+	if err != nil {
+		return 0, 0, 0, err
+	}
+
+	secondCountInt, err := strconv.ParseInt(parts[2], 10, 64)
+	if err != nil {
+		return 0, 0, 0, err
+	}
+
+	return countInt, overLimitCountInt, secondCountInt, nil
+}
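
Both Lua scripts share one storage layout: a Redis hash per key tuple whose fields are unix seconds and whose values are "normal:over" pairs. An illustrative state, assuming channel ID 42 and model gpt-4o:

    // HGETALL channel-model-record:42:gpt-4o
    //   "1718000000" -> "3:0"   (3 normal, 0 over-limit in that second)
    //   "1718000001" -> "5:2"   (5 normal, 2 over-limit)
    //
    // PushRequest sums the fields still inside the window, increments one side
    // of the current second's pair, and refreshes the key TTL to the window
    // length; GetRequest runs KEYS on the pattern, so "*" segments expand to
    // every matching tuple.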

+ 0 - 45
core/common/rpmlimit/main.go

@@ -1,45 +0,0 @@
-package rpmlimit
-
-import (
-	"context"
-	"time"
-
-	"github.com/labring/aiproxy/core/common"
-	log "github.com/sirupsen/logrus"
-)
-
-func PushRequest(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) (int64, int64, error) {
-	if common.RedisEnabled {
-		return redisPushRequest(ctx, group, model, maxRequestNum, duration)
-	}
-	count, overLimitCount := MemoryPushRequest(group, model, maxRequestNum, duration)
-	return count, overLimitCount, nil
-}
-
-func PushRequestAnyWay(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) (int64, int64) {
-	if common.RedisEnabled {
-		count, overLimitCount, err := redisPushRequest(ctx, group, model, maxRequestNum, duration)
-		if err == nil {
-			return count, overLimitCount
-		}
-		log.Error("redis push request error: " + err.Error())
-	}
-	return MemoryPushRequest(group, model, maxRequestNum, duration)
-}
-
-func RateLimit(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) (bool, error) {
-	if maxRequestNum == 0 {
-		return true, nil
-	}
-	if common.RedisEnabled {
-		return redisRateLimitRequest(ctx, group, model, maxRequestNum, duration)
-	}
-	return MemoryRateLimit(group, model, maxRequestNum, duration), nil
-}
-
-func GetRPM(ctx context.Context, group, model string) (int64, error) {
-	if common.RedisEnabled {
-		return redisGetRPM(ctx, group, model)
-	}
-	return GetMemoryRPM(group, model)
-}

+ 0 - 154
core/common/rpmlimit/mem.go

@@ -1,154 +0,0 @@
-package rpmlimit
-
-import (
-	"fmt"
-	"strings"
-	"sync"
-	"sync/atomic"
-	"time"
-)
-
-type windowCounts struct {
-	normal int64
-	over   int64
-}
-
-type entry struct {
-	sync.Mutex
-	windows    map[int64]*windowCounts
-	lastAccess atomic.Value
-}
-
-type InMemoryRateLimiter struct {
-	entries sync.Map
-}
-
-func newInMemoryRateLimiter() *InMemoryRateLimiter {
-	rl := &InMemoryRateLimiter{
-		entries: sync.Map{},
-	}
-	go rl.cleanupInactiveEntries(2*time.Minute, 1*time.Minute)
-	return rl
-}
-
-var memoryRateLimiter = newInMemoryRateLimiter()
-
-func (m *InMemoryRateLimiter) getEntry(group, model string) *entry {
-	key := fmt.Sprintf("%s:%s", group, model)
-	actual, _ := m.entries.LoadOrStore(key, &entry{
-		windows: make(map[int64]*windowCounts),
-	})
-	e, _ := actual.(*entry)
-	if e.lastAccess.Load() == nil {
-		e.lastAccess.CompareAndSwap(nil, time.Now())
-	}
-	return e
-}
-
-func (m *InMemoryRateLimiter) cleanupAndCount(e *entry, cutoff int64) (int64, int64) {
-	normalCount := int64(0)
-	overCount := int64(0)
-	for ts, wc := range e.windows {
-		if ts < cutoff {
-			delete(e.windows, ts)
-		} else {
-			normalCount += wc.normal
-			overCount += wc.over
-		}
-	}
-	return normalCount, overCount
-}
-
-func (m *InMemoryRateLimiter) pushRequest(group, model string, maxReq int64, duration time.Duration) (int64, int64) {
-	e := m.getEntry(group, model)
-
-	e.Lock()
-	defer e.Unlock()
-
-	now := time.Now()
-
-	e.lastAccess.Store(now)
-
-	windowStart := now.Unix()
-	cutoff := windowStart - int64(duration.Seconds())
-
-	normalCount, overCount := m.cleanupAndCount(e, cutoff)
-
-	wc, exists := e.windows[windowStart]
-	if !exists {
-		wc = &windowCounts{}
-		e.windows[windowStart] = wc
-	}
-
-	if maxReq == 0 || normalCount <= maxReq {
-		wc.normal++
-		normalCount++
-	} else {
-		wc.over++
-		overCount++
-	}
-
-	return normalCount, overCount
-}
-
-func (m *InMemoryRateLimiter) getRPM(group, model string, duration time.Duration) int {
-	total := 0
-	cutoff := time.Now().Unix() - int64(duration.Seconds())
-
-	m.entries.Range(func(key, value any) bool {
-		k, _ := key.(string)
-		currentGroup, currentModel := parseKey(k)
-
-		if (group == "*" || group == currentGroup) &&
-			(model == "" || model == "*" || model == currentModel) {
-			e, _ := value.(*entry)
-			e.Lock()
-			normalCount, overCount := m.cleanupAndCount(e, cutoff)
-			e.Unlock()
-			total += int(normalCount + overCount)
-		}
-		return true
-	})
-
-	return total
-}
-
-func (m *InMemoryRateLimiter) cleanupInactiveEntries(interval time.Duration, maxInactivity time.Duration) {
-	ticker := time.NewTicker(interval)
-	defer ticker.Stop()
-	for range ticker.C {
-		m.entries.Range(func(key, value any) bool {
-			e, _ := value.(*entry)
-			la := e.lastAccess.Load()
-			if la == nil {
-				return true
-			}
-			lastAccess, _ := la.(time.Time)
-			if time.Since(lastAccess) > maxInactivity {
-				m.entries.CompareAndDelete(key, e)
-			}
-			return true
-		})
-	}
-}
-
-func parseKey(key string) (group, model string) {
-	parts := strings.SplitN(key, ":", 2)
-	if len(parts) != 2 {
-		return "", ""
-	}
-	return parts[0], parts[1]
-}
-
-func MemoryPushRequest(group, model string, maxReq int64, duration time.Duration) (int64, int64) {
-	return memoryRateLimiter.pushRequest(group, model, maxReq, duration)
-}
-
-func MemoryRateLimit(group, model string, maxReq int64, duration time.Duration) bool {
-	current, _ := memoryRateLimiter.pushRequest(group, model, maxReq, duration)
-	return current <= maxReq
-}
-
-func GetMemoryRPM(group, model string) (int64, error) {
-	return int64(memoryRateLimiter.getRPM(group, model, time.Minute)), nil
-}

+ 0 - 165
core/common/rpmlimit/redis.go

@@ -1,165 +0,0 @@
-package rpmlimit
-
-import (
-	"context"
-	"errors"
-	"fmt"
-	"strconv"
-	"strings"
-	"time"
-
-	"github.com/labring/aiproxy/core/common"
-	"github.com/redis/go-redis/v9"
-)
-
-const (
-	groupModelRPMHashKey = "group_model_rpm_hash:%s:%s"
-)
-
-const pushRequestLuaScript = `
-local key = KEYS[1]
-local window_seconds = tonumber(ARGV[1])
-local current_time = tonumber(ARGV[2])
-local max_requests = tonumber(ARGV[3])
-local cutoff_slice = current_time - window_seconds
-
-local function parse_count(value)
-    if not value then return 0, 0 end
-    local r, e = value:match("^(%d+):(%d+)$")
-    return tonumber(r) or 0, tonumber(e) or 0
-end
-
-local count = 0
-local over_count = 0
-
-local all_fields = redis.call('HGETALL', key)
-for i = 1, #all_fields, 2 do
-    local field_slice = tonumber(all_fields[i])
-    if field_slice < cutoff_slice then
-        redis.call('HDEL', key, all_fields[i])
-	else
-		local c, oc = parse_count(all_fields[i+1])
-		count = count + c
-		over_count = over_count + oc
-	end
-end
-
-if max_requests == 0 or count <= max_requests then
-    local current_value = redis.call('HGET', key, tostring(current_time))
-    local c, oc = parse_count(current_value)
-    redis.call('HSET', key, current_time, (c+1) .. ":" .. oc)
-    count = count + 1
-else
-    local current_value = redis.call('HGET', key, tostring(current_time))
-    local c, oc = parse_count(current_value)
-    redis.call('HSET', key, current_time, c .. ":" .. (oc+1))
-    over_count = over_count + 1
-end
-
-redis.call('EXPIRE', key, window_seconds)
-return string.format("%d:%d", count, over_count)
-`
-
-const getRequestCountLuaScript = `
-local pattern = KEYS[1]
-local window_seconds = tonumber(ARGV[1])
-local current_time = tonumber(ARGV[2])
-local cutoff_slice = current_time - window_seconds
-
-local function parse_count(value)
-    if not value then return 0, 0 end
-    local r, e = value:match("^(%d+):(%d+)$")
-    return tonumber(r) or 0, tonumber(e) or 0
-end
-
-local total = 0
-
-local keys = redis.call('KEYS', pattern)
-for _, key in ipairs(keys) do
-    local count = 0
-    local over = 0
-
-    local all_fields = redis.call('HGETALL', key)
-    for i=1, #all_fields, 2 do
-        local field_slice = tonumber(all_fields[i])
-        if field_slice < cutoff_slice then
-			redis.call('HDEL', key, all_fields[i])
-		else
-			local c, oc = parse_count(all_fields[i+1])
-			count = count + c
-			over = over + oc
-		end
-    end
-
-    total = total + count + over
-end
-
-return total
-`
-
-var (
-	pushRequestScript     = redis.NewScript(pushRequestLuaScript)
-	getRequestCountScript = redis.NewScript(getRequestCountLuaScript)
-)
-
-func redisGetRPM(ctx context.Context, group, model string) (int64, error) {
-	if !common.RedisEnabled {
-		return 0, nil
-	}
-
-	var pattern string
-	switch {
-	case model == "":
-		model = "*"
-		fallthrough
-	default:
-		pattern = fmt.Sprintf("group_model_rpm_hash:%s:%s", group, model)
-	}
-
-	result, err := getRequestCountScript.Run(
-		ctx,
-		common.RDB,
-		[]string{pattern},
-		time.Minute.Seconds(),
-		time.Now().Unix(),
-	).Int64()
-	if err != nil {
-		return 0, err
-	}
-	return result, nil
-}
-
-func redisPushRequest(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) (int64, int64, error) {
-	result, err := pushRequestScript.Run(
-		ctx,
-		common.RDB,
-		[]string{fmt.Sprintf(groupModelRPMHashKey, group, model)},
-		duration.Seconds(),
-		time.Now().Unix(),
-		maxRequestNum,
-	).Text()
-	if err != nil {
-		return 0, 0, err
-	}
-	count, overLimitCount, ok := strings.Cut(result, ":")
-	if !ok {
-		return 0, 0, errors.New("invalid result")
-	}
-	countInt, err := strconv.ParseInt(count, 10, 64)
-	if err != nil {
-		return 0, 0, err
-	}
-	overLimitCountInt, err := strconv.ParseInt(overLimitCount, 10, 64)
-	if err != nil {
-		return 0, 0, err
-	}
-	return countInt, overLimitCountInt, nil
-}
-
-func redisRateLimitRequest(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) (bool, error) {
-	count, _, err := PushRequest(ctx, group, model, maxRequestNum, duration)
-	if err != nil {
-		return false, err
-	}
-	return count <= maxRequestNum, nil
-}
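
The three rpmlimit files above are superseded by core/common/reqlimit. An inferred correspondence between the removed surface and its replacement, based on the call sites updated elsewhere in this commit:

    // rpmlimit.PushRequestAnyWay(ctx, group, model, max, time.Minute)
    //     -> reqlimit.PushGroupModelRequest(ctx, group, model, max)
    // rpmlimit.GetRPM(ctx, group, model)                    // (rpm, error)
    //     -> reqlimit.GetGroupModelRequest(ctx, group, model) // (rpm, rps)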

+ 1 - 1
core/controller/channel-test.go

@@ -101,7 +101,7 @@ func testSingleModel(mc *model.ModelCaches, channel *model.Channel, modelName st
 		modelConfig,
 		meta.WithRequestID(channelTestRequestID),
 	)
-	result := relayHandler(meta, newc)
+	result := relayHandler(newc, meta)
 	success := result.Error == nil
 	var respStr string
 	var code int

+ 17 - 31
core/controller/dashboard.go

@@ -9,7 +9,7 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
-	"github.com/labring/aiproxy/core/common/rpmlimit"
+	"github.com/labring/aiproxy/core/common/reqlimit"
 	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"gorm.io/gorm"
@@ -144,7 +144,6 @@ func fillGaps(data []*model.ChartData, start, end time.Time, t model.TimeSpanTyp
 //	@Tags			dashboard
 //	@Produce		json
 //	@Security		ApiKeyAuth
-//	@Param			group			query		string	false	"Group or *"
 //	@Param			channel			query		int		false	"Channel ID"
 //	@Param			type			query		string	false	"Type of time span (day, week, month, two_week)"
 //	@Param			model			query		string	false	"Model name"
@@ -154,19 +153,15 @@ func fillGaps(data []*model.ChartData, start, end time.Time, t model.TimeSpanTyp
 //	@Success		200				{object}	middleware.APIResponse{data=model.DashboardResponse}
 //	@Router			/api/dashboard/ [get]
 func GetDashboard(c *gin.Context) {
-	log := middleware.GetLogger(c)
-
-	group := c.Query("group")
 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 	timezoneLocation, _ := time.LoadLocation(c.DefaultQuery("timezone", "Local"))
 	start, end, timeSpan := getDashboardTime(c.Query("type"), startTimestamp, endTimestamp, timezoneLocation)
 	modelName := c.Query("model")
-	channelID, _ := strconv.Atoi(c.Query("channel"))
-
-	needRPM := channelID != 0
+	channelStr := c.Query("channel")
+	channelID, _ := strconv.Atoi(channelStr)
 
-	dashboards, err := model.GetDashboardData(group, start, end, modelName, channelID, timeSpan, needRPM, timezoneLocation)
+	dashboards, err := model.GetDashboardData(start, end, modelName, channelID, timeSpan, timezoneLocation)
 	if err != nil {
 		middleware.ErrorResponse(c, http.StatusOK, err.Error())
 		return
@@ -174,14 +169,13 @@ func GetDashboard(c *gin.Context) {
 
 	dashboards.ChartData = fillGaps(dashboards.ChartData, start, end, timeSpan)
 
-	if !needRPM {
-		rpm, err := rpmlimit.GetRPM(c.Request.Context(), group, modelName)
-		if err != nil {
-			log.Errorf("failed to get rpm: %v", err)
-		} else {
-			dashboards.RPM = rpm
-		}
+	if channelID == 0 {
+		channelStr = "*"
 	}
+	rpm, _ := reqlimit.GetChannelModelRequest(c.Request.Context(), channelStr, modelName)
+	dashboards.RPM = rpm
+	tpm, _ := reqlimit.GetChannelModelTokensRequest(c.Request.Context(), channelStr, modelName)
+	dashboards.TPM = tpm
 
 	middleware.SuccessResponse(c, dashboards)
 }
@@ -203,8 +197,6 @@ func GetDashboard(c *gin.Context) {
 //	@Success		200				{object}	middleware.APIResponse{data=model.GroupDashboardResponse}
 //	@Router			/api/dashboard/{group} [get]
 func GetGroupDashboard(c *gin.Context) {
-	log := middleware.GetLogger(c)
-
 	group := c.Param("group")
 	if group == "" || group == "*" {
 		middleware.ErrorResponse(c, http.StatusOK, "invalid group parameter")
@@ -218,9 +210,7 @@ func GetGroupDashboard(c *gin.Context) {
 	tokenName := c.Query("token_name")
 	modelName := c.Query("model")
 
-	needRPM := tokenName != ""
-
-	dashboards, err := model.GetGroupDashboardData(group, start, end, tokenName, modelName, timeSpan, needRPM, timezoneLocation)
+	dashboards, err := model.GetGroupDashboardData(group, start, end, tokenName, modelName, timeSpan, timezoneLocation)
 	if err != nil {
 		middleware.ErrorResponse(c, http.StatusOK, "failed to get statistics")
 		return
@@ -228,14 +218,10 @@ func GetGroupDashboard(c *gin.Context) {
 
 	dashboards.ChartData = fillGaps(dashboards.ChartData, start, end, timeSpan)
 
-	if !needRPM {
-		rpm, err := rpmlimit.GetRPM(c.Request.Context(), group, modelName)
-		if err != nil {
-			log.Errorf("failed to get rpm: %v", err)
-		} else {
-			dashboards.RPM = rpm
-		}
-	}
+	rpm, _ := reqlimit.GetGroupModelTokennameRequest(c.Request.Context(), group, modelName, tokenName)
+	dashboards.RPM = rpm
+	tpm, _ := reqlimit.GetGroupModelTokennameTokensRequest(c.Request.Context(), group, modelName, tokenName)
+	dashboards.TPM = tpm
 
 	middleware.SuccessResponse(c, dashboards)
 }
@@ -293,7 +279,7 @@ func GetGroupDashboardModels(c *gin.Context) {
 //	@Param			channel			query		int		false	"Channel ID"
 //	@Param			start_timestamp	query		int64	false	"Start timestamp"
 //	@Param			end_timestamp	query		int64	false	"End timestamp"
-//	@Success		200				{object}	middleware.APIResponse{data=[]model.ModelCostRank}
+//	@Success		200				{object}	middleware.APIResponse{data=[]model.CostRank}
 //	@Router			/api/model_cost_rank/ [get]
 func GetModelCostRank(c *gin.Context) {
 	group := c.Query("group")
@@ -317,7 +303,7 @@ func GetModelCostRank(c *gin.Context) {
 //	@Param			group			path		string	true	"Group"
 //	@Param			start_timestamp	query		int64	false	"Start timestamp"
 //	@Param			end_timestamp	query		int64	false	"End timestamp"
-//	@Success		200				{object}	middleware.APIResponse{data=[]model.ModelCostRank}
+//	@Success		200				{object}	middleware.APIResponse{data=[]model.CostRank}
 //	@Router			/api/model_cost_rank/{group} [get]
 func GetGroupModelCostRank(c *gin.Context) {
 	group := c.Param("group")
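
With no channel filter the dashboard handler now reads the wildcard counters, so RPM/TPM aggregate every channel; a sketch of the two lookups (model name illustrative):

    // channel "*" matches every channel; an empty model would widen to "*" too.
    rpm, rps := reqlimit.GetChannelModelRequest(c.Request.Context(), "*", "gpt-4o")
    tpm, tps := reqlimit.GetChannelModelTokensRequest(c.Request.Context(), "*", "gpt-4o")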

+ 144 - 23
core/controller/relay-controller.go

@@ -17,10 +17,14 @@ import (
 	"github.com/labring/aiproxy/core/common/config"
 	"github.com/labring/aiproxy/core/common/consume"
 	"github.com/labring/aiproxy/core/common/notify"
+	"github.com/labring/aiproxy/core/common/reqlimit"
 	"github.com/labring/aiproxy/core/common/trylock"
 	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/monitor"
+	"github.com/labring/aiproxy/core/relay/adaptor"
+	"github.com/labring/aiproxy/core/relay/adaptor/openai"
+	"github.com/labring/aiproxy/core/relay/adaptors"
 	"github.com/labring/aiproxy/core/relay/controller"
 	"github.com/labring/aiproxy/core/relay/meta"
 	"github.com/labring/aiproxy/core/relay/mode"
@@ -31,7 +35,7 @@ import (
 // https://platform.openai.com/docs/api-reference/chat
 
 type (
-	RelayHandler    func(*meta.Meta, *gin.Context) *controller.HandleResult
+	RelayHandler    func(*gin.Context, *meta.Meta) *controller.HandleResult
 	GetRequestUsage func(*gin.Context, *model.ModelConfig) (model.Usage, error)
 	GetRequestPrice func(*gin.Context, *model.ModelConfig) (model.Price, error)
 )
@@ -42,10 +46,137 @@ type RelayController struct {
 	Handler         RelayHandler
 }
 
-func relayHandler(meta *meta.Meta, c *gin.Context) *controller.HandleResult {
+var ErrInvalidChannelTypeCode = "invalid_channel_type"
+
+type warpAdaptor struct {
+	adaptor.Adaptor
+}
+
+const (
+	MetaChannelModelKeyRPM = "channel_model_rpm"
+	MetaChannelModelKeyRPS = "channel_model_rps"
+	MetaChannelModelKeyTPM = "channel_model_tpm"
+	MetaChannelModelKeyTPS = "channel_model_tps"
+
+	MetaGroupModelTokennameTPM = "group_model_tokenname_tpm"
+	MetaGroupModelTokennameTPS = "group_model_tokenname_tps"
+)
+
+func getChannelModelRequestRate(meta *meta.Meta) model.RequestRate {
+	rate := model.RequestRate{}
+
+	if rpm, ok := meta.Get(MetaChannelModelKeyRPM); ok {
+		rate.RPM, _ = rpm.(int64)
+		rate.RPS = meta.GetInt64(MetaChannelModelKeyRPS)
+	} else {
+		rpm, rps := reqlimit.GetChannelModelRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
+		rate.RPM = rpm
+		rate.RPS = rps
+	}
+
+	if tpm, ok := meta.Get(MetaChannelModelKeyTPM); ok {
+		rate.TPM, _ = tpm.(int64)
+		rate.TPS = meta.GetInt64(MetaChannelModelKeyTPS)
+	} else {
+		tpm, tps := reqlimit.GetChannelModelTokensRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
+		rate.TPM = tpm
+		rate.TPS = tps
+	}
+
+	return rate
+}
+
+func getGroupModelTokenRequestRate(c *gin.Context, meta *meta.Meta) model.RequestRate {
+	r := model.RequestRate{
+		RPM: middleware.GetGroupModelTokenRPM(c),
+		RPS: middleware.GetGroupModelTokenRPS(c),
+		TPM: middleware.GetGroupModelTokenTPM(c),
+		TPS: middleware.GetGroupModelTokenTPS(c),
+	}
+
+	if tpm, ok := meta.Get(MetaGroupModelTokennameTPM); ok {
+		r.TPM, _ = tpm.(int64)
+		r.TPS = meta.GetInt64(MetaGroupModelTokennameTPS)
+	}
+
+	return r
+}
+
+func (w *warpAdaptor) DoRequest(meta *meta.Meta, c *gin.Context, req *http.Request) (*http.Response, error) {
+	count, overLimitCount, secondCount := reqlimit.PushChannelModelRequest(
+		context.Background(),
+		strconv.Itoa(meta.Channel.ID),
+		meta.OriginModel,
+	)
+	log := middleware.GetLogger(c)
+	meta.Set(MetaChannelModelKeyRPM, count+overLimitCount)
+	meta.Set(MetaChannelModelKeyRPS, secondCount)
+	log.Data["ch_rpm"] = count + overLimitCount
+	log.Data["ch_rps"] = secondCount
+	return w.Adaptor.DoRequest(meta, c, req)
+}
+
+func (w *warpAdaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage, *relaymodel.ErrorWithStatusCode) {
+	usage, relayErr := w.Adaptor.DoResponse(meta, c, resp)
+	if usage == nil {
+		return nil, relayErr
+	}
+
+	count, overLimitCount, secondCount := reqlimit.PushChannelModelTokensRequest(
+		context.Background(),
+		strconv.Itoa(meta.Channel.ID),
+		meta.OriginModel,
+		int64(usage.TotalTokens),
+	)
+	log := middleware.GetLogger(c)
+	meta.Set(MetaChannelModelKeyTPM, count+overLimitCount)
+	meta.Set(MetaChannelModelKeyTPS, secondCount)
+	log.Data["ch_tpm"] = count + overLimitCount
+	log.Data["ch_tps"] = secondCount
+
+	count, overLimitCount, secondCount = reqlimit.PushGroupModelTokensRequest(
+		context.Background(),
+		meta.Group.ID,
+		meta.OriginModel,
+		meta.ModelConfig.TPM,
+		int64(usage.TotalTokens),
+	)
+	if meta.Group.Status != model.GroupStatusInternal {
+		log.Data["group_tpm"] = count + overLimitCount
+		log.Data["group_tps"] = secondCount
+	}
+
+	count, overLimitCount, secondCount = reqlimit.PushGroupModelTokennameTokensRequest(
+		context.Background(),
+		meta.Group.ID,
+		meta.OriginModel,
+		meta.Token.Name,
+		int64(usage.TotalTokens),
+	)
+	meta.Set(MetaGroupModelTokennameTPM, count+overLimitCount)
+	meta.Set(MetaGroupModelTokennameTPS, secondCount)
+	// log.Data["tpm"] = count + overLimitCount
+	// log.Data["tps"] = secondCount
+
+	return usage, relayErr
+}
+
+func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
 	log := middleware.GetLogger(c)
 	middleware.SetLogFieldsFromMeta(meta, log.Data)
-	return controller.Handle(meta, c)
+
+	adaptor, ok := adaptors.GetAdaptor(meta.Channel.Type)
+	if !ok {
+		return &controller.HandleResult{
+			Error: openai.ErrorWrapperWithMessage(
+				fmt.Sprintf("invalid channel type: %d", meta.Channel.Type),
+				ErrInvalidChannelTypeCode,
+				http.StatusInternalServerError,
+			),
+		}
+	}
+
+	return controller.Handle(&warpAdaptor{adaptor}, c, meta)
 }
 
 func relayController(m mode.Mode) RelayController {
@@ -87,8 +218,8 @@ func relayController(m mode.Mode) RelayController {
 	return c
 }
 
-func RelayHelper(meta *meta.Meta, c *gin.Context, handel RelayHandler) (*controller.HandleResult, bool) {
-	result := handel(meta, c)
+func RelayHelper(c *gin.Context, meta *meta.Meta, handel RelayHandler) (*controller.HandleResult, bool) {
+	result := handel(c, meta)
 	if result.Error == nil {
 		if _, _, err := monitor.AddRequest(
 			context.Background(),
@@ -165,20 +296,8 @@ func notifyChannelIssue(meta *meta.Meta, issueType string, titleSuffix string, e
 			notifyFunc = notify.Error
 		}
 
-		now := time.Now()
-		group := "*"
-		rpm, rpmErr := model.GetRPM(group, now, "", meta.OriginModel, meta.Channel.ID)
-		tpm, tpmErr := model.GetTPM(group, now, "", meta.OriginModel, meta.Channel.ID)
-		if rpmErr != nil {
-			message += fmt.Sprintf("\nrpm: %v", rpmErr)
-		} else {
-			message += fmt.Sprintf("\nrpm: %d", rpm)
-		}
-		if tpmErr != nil {
-			message += fmt.Sprintf("\ntpm: %v", tpmErr)
-		} else {
-			message += fmt.Sprintf("\ntpm: %d", tpm)
-		}
+		rate := getChannelModelRequestRate(meta)
+		message += fmt.Sprintf("\nrpm: %d\nrps: %d\ntpm: %d\ntps: %d", rate.RPM, rate.RPS, rate.TPM, rate.TPS)
 	}
 
 	notifyFunc(
@@ -348,7 +467,7 @@ func relay(c *gin.Context, mode mode.Mode, relayController RelayController) {
 	}
 
 	// First attempt
-	result, retry := RelayHelper(meta, c, relayController.Handler)
+	result, retry := RelayHelper(c, meta, relayController.Handler)
 
 	retryTimes := int(config.GetRetryTimes())
 	if mc.RetryTimes > 0 {
@@ -436,6 +555,8 @@ func recordResult(
 		downstreamResult,
 		user,
 		metadata,
+		getChannelModelRequestRate(meta),
+		getGroupModelTokenRequestRate(c, meta),
 	)
 }
 
@@ -603,7 +724,7 @@ func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayControlle
 			meta.WithRetryAt(time.Now()),
 		)
 		var retry bool
-		state.result, retry = RelayHelper(state.meta, c, relayController)
+		state.result, retry = RelayHelper(c, state.meta, relayController)
 
 		done := handleRetryResult(c, retry, newChannel, state)
 		if done || i == state.retryTimes-1 {
@@ -693,7 +814,7 @@ var channelNoRetryStatusCodesMap = map[int]struct{}{
 
// Only channel errors need to be recorded; user request parameter errors do not.
 func shouldRetry(_ *gin.Context, relayErr relaymodel.ErrorWithStatusCode) bool {
-	if relayErr.Error.Code == controller.ErrInvalidChannelTypeCode {
+	if relayErr.Error.Code == ErrInvalidChannelTypeCode {
 		return false
 	}
 	_, ok := channelNoRetryStatusCodesMap[relayErr.StatusCode]
@@ -708,7 +829,7 @@ var channelNoPermissionStatusCodesMap = map[int]struct{}{
 }
 
 func channelHasPermission(relayErr relaymodel.ErrorWithStatusCode) bool {
-	if relayErr.Error.Code == controller.ErrInvalidChannelTypeCode {
+	if relayErr.Error.Code == ErrInvalidChannelTypeCode {
 		return false
 	}
 	_, ok := channelNoPermissionStatusCodesMap[relayErr.StatusCode]
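
warpAdaptor stores the counters it pushes on the request meta, so getChannelModelRequestRate above falls back to a shared-counter query only when the request never reached DoRequest/DoResponse (an early failure, for instance); per the last commit bullet, the group_tpm/group_tps log fields are skipped for internal groups. A condensed view of that read-through pattern:

    // Prefer the value cached on the meta during this request ...
    if rpm, ok := meta.Get(MetaChannelModelKeyRPM); ok {
        rate.RPM, _ = rpm.(int64)
        rate.RPS = meta.GetInt64(MetaChannelModelKeyRPS)
    } else {
        // ... and query the rolling counters only when nothing was recorded.
        rate.RPM, rate.RPS = reqlimit.GetChannelModelRequest(
            context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
    }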

+ 82 - 40
core/docs/docs.go

@@ -937,12 +937,6 @@ const docTemplate = `{
                 ],
                 "summary": "Get dashboard data",
                 "parameters": [
-                    {
-                        "type": "string",
-                        "description": "Group or *",
-                        "name": "group",
-                        "in": "query"
-                    },
                     {
                         "type": "integer",
                         "description": "Channel ID",
@@ -4438,7 +4432,7 @@ const docTemplate = `{
                                         "data": {
                                             "type": "array",
                                             "items": {
-                                                "$ref": "#/definitions/model.ModelCostRank"
+                                                "$ref": "#/definitions/model.CostRank"
                                             }
                                         }
                                     }
@@ -4499,7 +4493,7 @@ const docTemplate = `{
                                         "data": {
                                             "type": "array",
                                             "items": {
-                                                "$ref": "#/definitions/model.ModelCostRank"
+                                                "$ref": "#/definitions/model.CostRank"
                                             }
                                         }
                                     }
@@ -8512,6 +8506,18 @@ const docTemplate = `{
                 "input_tokens": {
                     "type": "integer"
                 },
+                "max_rpm": {
+                    "type": "integer"
+                },
+                "max_rps": {
+                    "type": "integer"
+                },
+                "max_tpm": {
+                    "type": "integer"
+                },
+                "max_tps": {
+                    "type": "integer"
+                },
                 "output_tokens": {
                     "type": "integer"
                 },
@@ -8549,6 +8555,50 @@ const docTemplate = `{
                 }
             }
         },
+        "model.CostRank": {
+            "type": "object",
+            "properties": {
+                "cache_creation_tokens": {
+                    "type": "integer"
+                },
+                "cached_tokens": {
+                    "type": "integer"
+                },
+                "input_tokens": {
+                    "type": "integer"
+                },
+                "max_rpm": {
+                    "type": "integer"
+                },
+                "max_rps": {
+                    "type": "integer"
+                },
+                "max_tpm": {
+                    "type": "integer"
+                },
+                "max_tps": {
+                    "type": "integer"
+                },
+                "model": {
+                    "type": "string"
+                },
+                "output_tokens": {
+                    "type": "integer"
+                },
+                "request_count": {
+                    "type": "integer"
+                },
+                "total_tokens": {
+                    "type": "integer"
+                },
+                "used_amount": {
+                    "type": "number"
+                },
+                "web_search_count": {
+                    "type": "integer"
+                }
+            }
+        },
         "model.DashboardResponse": {
             "type": "object",
             "properties": {
@@ -8576,6 +8626,18 @@ const docTemplate = `{
                 "input_tokens": {
                     "type": "integer"
                 },
+                "max_rpm": {
+                    "type": "integer"
+                },
+                "max_rps": {
+                    "type": "integer"
+                },
+                "max_tpm": {
+                    "type": "integer"
+                },
+                "max_tps": {
+                    "type": "integer"
+                },
                 "output_tokens": {
                     "type": "integer"
                 },
@@ -8923,6 +8985,18 @@ const docTemplate = `{
                 "input_tokens": {
                     "type": "integer"
                 },
+                "max_rpm": {
+                    "type": "integer"
+                },
+                "max_rps": {
+                    "type": "integer"
+                },
+                "max_tpm": {
+                    "type": "integer"
+                },
+                "max_tps": {
+                    "type": "integer"
+                },
                 "models": {
                     "type": "array",
                     "items": {
@@ -9433,38 +9507,6 @@ const docTemplate = `{
                 }
             }
         },
-        "model.ModelCostRank": {
-            "type": "object",
-            "properties": {
-                "cache_creation_tokens": {
-                    "type": "integer"
-                },
-                "cached_tokens": {
-                    "type": "integer"
-                },
-                "input_tokens": {
-                    "type": "integer"
-                },
-                "model": {
-                    "type": "string"
-                },
-                "output_tokens": {
-                    "type": "integer"
-                },
-                "request_count": {
-                    "type": "integer"
-                },
-                "total_tokens": {
-                    "type": "integer"
-                },
-                "used_amount": {
-                    "type": "number"
-                },
-                "web_search_count": {
-                    "type": "integer"
-                }
-            }
-        },
         "model.ModelOwner": {
             "type": "string",
             "enum": [

+ 82 - 40
core/docs/swagger.json

@@ -928,12 +928,6 @@
                 ],
                 "summary": "Get dashboard data",
                 "parameters": [
-                    {
-                        "type": "string",
-                        "description": "Group or *",
-                        "name": "group",
-                        "in": "query"
-                    },
                     {
                         "type": "integer",
                         "description": "Channel ID",
@@ -4429,7 +4423,7 @@
                                         "data": {
                                             "type": "array",
                                             "items": {
-                                                "$ref": "#/definitions/model.ModelCostRank"
+                                                "$ref": "#/definitions/model.CostRank"
                                             }
                                         }
                                     }
@@ -4490,7 +4484,7 @@
                                         "data": {
                                             "type": "array",
                                             "items": {
-                                                "$ref": "#/definitions/model.ModelCostRank"
+                                                "$ref": "#/definitions/model.CostRank"
                                             }
                                         }
                                     }
@@ -8503,6 +8497,18 @@
                 "input_tokens": {
                     "type": "integer"
                 },
+                "max_rpm": {
+                    "type": "integer"
+                },
+                "max_rps": {
+                    "type": "integer"
+                },
+                "max_tpm": {
+                    "type": "integer"
+                },
+                "max_tps": {
+                    "type": "integer"
+                },
                 "output_tokens": {
                     "type": "integer"
                 },
@@ -8540,6 +8546,50 @@
                 }
             }
         },
+        "model.CostRank": {
+            "type": "object",
+            "properties": {
+                "cache_creation_tokens": {
+                    "type": "integer"
+                },
+                "cached_tokens": {
+                    "type": "integer"
+                },
+                "input_tokens": {
+                    "type": "integer"
+                },
+                "max_rpm": {
+                    "type": "integer"
+                },
+                "max_rps": {
+                    "type": "integer"
+                },
+                "max_tpm": {
+                    "type": "integer"
+                },
+                "max_tps": {
+                    "type": "integer"
+                },
+                "model": {
+                    "type": "string"
+                },
+                "output_tokens": {
+                    "type": "integer"
+                },
+                "request_count": {
+                    "type": "integer"
+                },
+                "total_tokens": {
+                    "type": "integer"
+                },
+                "used_amount": {
+                    "type": "number"
+                },
+                "web_search_count": {
+                    "type": "integer"
+                }
+            }
+        },
         "model.DashboardResponse": {
             "type": "object",
             "properties": {
@@ -8567,6 +8617,18 @@
                 "input_tokens": {
                     "type": "integer"
                 },
+                "max_rpm": {
+                    "type": "integer"
+                },
+                "max_rps": {
+                    "type": "integer"
+                },
+                "max_tpm": {
+                    "type": "integer"
+                },
+                "max_tps": {
+                    "type": "integer"
+                },
                 "output_tokens": {
                     "type": "integer"
                 },
@@ -8914,6 +8976,18 @@
                 "input_tokens": {
                     "type": "integer"
                 },
+                "max_rpm": {
+                    "type": "integer"
+                },
+                "max_rps": {
+                    "type": "integer"
+                },
+                "max_tpm": {
+                    "type": "integer"
+                },
+                "max_tps": {
+                    "type": "integer"
+                },
                 "models": {
                     "type": "array",
                     "items": {
@@ -9424,38 +9498,6 @@
                 }
             }
         },
-        "model.ModelCostRank": {
-            "type": "object",
-            "properties": {
-                "cache_creation_tokens": {
-                    "type": "integer"
-                },
-                "cached_tokens": {
-                    "type": "integer"
-                },
-                "input_tokens": {
-                    "type": "integer"
-                },
-                "model": {
-                    "type": "string"
-                },
-                "output_tokens": {
-                    "type": "integer"
-                },
-                "request_count": {
-                    "type": "integer"
-                },
-                "total_tokens": {
-                    "type": "integer"
-                },
-                "used_amount": {
-                    "type": "number"
-                },
-                "web_search_count": {
-                    "type": "integer"
-                }
-            }
-        },
         "model.ModelOwner": {
             "type": "string",
             "enum": [

+ 55 - 27
core/docs/swagger.yaml

@@ -654,6 +654,14 @@ definitions:
         type: integer
       input_tokens:
         type: integer
+      max_rpm:
+        type: integer
+      max_rps:
+        type: integer
+      max_tpm:
+        type: integer
+      max_tps:
+        type: integer
       output_tokens:
         type: integer
       request_count:
@@ -678,6 +686,35 @@ definitions:
       rejected_prediction_tokens:
         type: integer
     type: object
+  model.CostRank:
+    properties:
+      cache_creation_tokens:
+        type: integer
+      cached_tokens:
+        type: integer
+      input_tokens:
+        type: integer
+      max_rpm:
+        type: integer
+      max_rps:
+        type: integer
+      max_tpm:
+        type: integer
+      max_tps:
+        type: integer
+      model:
+        type: string
+      output_tokens:
+        type: integer
+      request_count:
+        type: integer
+      total_tokens:
+        type: integer
+      used_amount:
+        type: number
+      web_search_count:
+        type: integer
+    type: object
   model.DashboardResponse:
     properties:
       cache_creation_tokens:
@@ -696,6 +733,14 @@ definitions:
         type: integer
       input_tokens:
         type: integer
+      max_rpm:
+        type: integer
+      max_rps:
+        type: integer
+      max_tpm:
+        type: integer
+      max_tps:
+        type: integer
       output_tokens:
         type: integer
       rpm:
@@ -930,6 +975,14 @@ definitions:
         type: integer
       input_tokens:
         type: integer
+      max_rpm:
+        type: integer
+      max_rps:
+        type: integer
+      max_tpm:
+        type: integer
+      max_tps:
+        type: integer
       models:
         items:
           type: string
@@ -1272,27 +1325,6 @@ definitions:
       updated_at:
         type: string
     type: object
-  model.ModelCostRank:
-    properties:
-      cache_creation_tokens:
-        type: integer
-      cached_tokens:
-        type: integer
-      input_tokens:
-        type: integer
-      model:
-        type: string
-      output_tokens:
-        type: integer
-      request_count:
-        type: integer
-      total_tokens:
-        type: integer
-      used_amount:
-        type: number
-      web_search_count:
-        type: integer
-    type: object
   model.ModelOwner:
     enum:
     - openai
@@ -2240,10 +2272,6 @@ paths:
       description: Returns the general dashboard data including usage statistics and
         metrics
       parameters:
-      - description: Group or *
-        in: query
-        name: group
-        type: string
       - description: Channel ID
         in: query
         name: channel
@@ -4368,7 +4396,7 @@ paths:
             - properties:
                 data:
                   items:
-                    $ref: '#/definitions/model.ModelCostRank'
+                    $ref: '#/definitions/model.CostRank'
                   type: array
               type: object
       security:
@@ -4404,7 +4432,7 @@ paths:
             - properties:
                 data:
                   items:
-                    $ref: '#/definitions/model.ModelCostRank'
+                    $ref: '#/definitions/model.CostRank'
                   type: array
               type: object
       security:

+ 15 - 11
core/middleware/ctxkey.go

@@ -1,15 +1,19 @@
 package middleware
 
 const (
-	Channel         = "channel"
-	Group           = "group"
-	Token           = "token"
-	GroupBalance    = "group_balance"
-	RequestModel    = "request_model"
-	RequestUser     = "request_user"
-	RequestMetadata = "request_metadata"
-	RequestAt       = "request_at"
-	RequestID       = "request_id"
-	ModelCaches     = "model_caches"
-	ModelConfig     = "model_config"
+	Channel            = "channel"
+	GroupModelTokenRPM = "group_model_token_rpm"
+	GroupModelTokenRPS = "group_model_token_rps"
+	GroupModelTokenTPM = "group_model_token_tpm"
+	GroupModelTokenTPS = "group_model_token_tps"
+	Group              = "group"
+	Token              = "token"
+	GroupBalance       = "group_balance"
+	RequestModel       = "request_model"
+	RequestUser        = "request_user"
+	RequestMetadata    = "request_metadata"
+	RequestAt          = "request_at"
+	RequestID          = "request_id"
+	ModelCaches        = "model_caches"
+	ModelConfig        = "model_config"
 )
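
The four new GroupModelToken* keys follow this file's existing convention of string context keys read back through gin's typed getters. A minimal round-trip sketch, assuming nothing beyond the standard gin API (the function name is illustrative):

// illustrateTokenRateKeys shows the set/read round trip: values stored
// with c.Set under the new keys come back via c.GetInt64, which returns
// 0 for an unset key, so readers never need an existence check.
func illustrateTokenRateKeys(c *gin.Context) {
	c.Set(GroupModelTokenRPM, int64(42))   // requests seen in the last minute
	c.Set(GroupModelTokenTPM, int64(1024)) // tokens seen in the last minute

	_ = c.GetInt64(GroupModelTokenRPM) // 42
	_ = c.GetInt64(GroupModelTokenTPS) // 0: never set on this request
}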

+ 63 - 18
core/middleware/distributor.go

@@ -16,7 +16,7 @@ import (
 	"github.com/labring/aiproxy/core/common/config"
 	"github.com/labring/aiproxy/core/common/consume"
 	"github.com/labring/aiproxy/core/common/notify"
-	"github.com/labring/aiproxy/core/common/rpmlimit"
+	"github.com/labring/aiproxy/core/common/reqlimit"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/meta"
 	"github.com/labring/aiproxy/core/relay/mode"
@@ -96,39 +96,55 @@ func setTpmHeaders(c *gin.Context, tpm int64, remainingRequests int64) {
 	c.Header(XRateLimitResetTokens, "1m0s")
 }
 
-func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, mc *model.ModelConfig) error {
+func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, mc *model.ModelConfig, tokenName string) error {
 	log := GetLogger(c)
 
 	adjustedModelConfig := GetGroupAdjustedModelConfig(group, *mc)
 
-	count, overLimitCount := rpmlimit.PushRequestAnyWay(c.Request.Context(), group.ID, mc.Model, adjustedModelConfig.RPM, time.Minute)
-	log.Data["rpm"] = strconv.FormatInt(count+overLimitCount, 10)
+	groupModelCount, groupModelOverLimitCount, groupModelSecondCount := reqlimit.PushGroupModelRequest(c.Request.Context(), group.ID, mc.Model, adjustedModelConfig.RPM)
+	if group.Status != model.GroupStatusInternal {
+		log.Data["group_rpm"] = strconv.FormatInt(groupModelCount+groupModelOverLimitCount, 10)
+		log.Data["group_rps"] = strconv.FormatInt(groupModelSecondCount, 10)
+	}
+
+	groupModelTokenCount, groupModelTokenOverLimitCount, groupModelTokenSecondCount := reqlimit.PushGroupModelTokennameRequest(c.Request.Context(), group.ID, mc.Model, tokenName)
+	c.Set(GroupModelTokenRPM, groupModelTokenCount+groupModelTokenOverLimitCount)
+	c.Set(GroupModelTokenRPS, groupModelTokenSecondCount)
+	// log.Data["rpm"] = strconv.FormatInt(groupModelTokenCount+groupModelTokenOverLimitCount, 10)
+	// log.Data["rps"] = strconv.FormatInt(groupModelTokenSecondCount, 10)
+
 	if group.Status != model.GroupStatusInternal &&
 		adjustedModelConfig.RPM > 0 {
-		log.Data["rpm_limit"] = strconv.FormatInt(adjustedModelConfig.RPM, 10)
-		if count > adjustedModelConfig.RPM {
+		log.Data["group_rpm_limit"] = strconv.FormatInt(adjustedModelConfig.RPM, 10)
+		if groupModelCount > adjustedModelConfig.RPM {
 			setRpmHeaders(c, adjustedModelConfig.RPM, 0)
 			return ErrRequestRateLimitExceeded
 		}
-		setRpmHeaders(c, adjustedModelConfig.RPM, adjustedModelConfig.RPM-count)
+		setRpmHeaders(c, adjustedModelConfig.RPM, adjustedModelConfig.RPM-groupModelCount)
 	}
 
+	groupModelCountTPM, groupModelCountTPS := reqlimit.GetGroupModelTokensRequest(c.Request.Context(), group.ID, mc.Model)
+	if group.Status != model.GroupStatusInternal {
+		log.Data["group_tpm"] = strconv.FormatInt(groupModelCountTPM, 10)
+		log.Data["group_tps"] = strconv.FormatInt(groupModelCountTPS, 10)
+	}
+
+	groupModelTokenCountTPM, groupModelTokenCountTPS := reqlimit.GetGroupModelTokennameTokensRequest(c.Request.Context(), group.ID, mc.Model, tokenName)
+	c.Set(GroupModelTokenTPM, groupModelTokenCountTPM)
+	c.Set(GroupModelTokenTPS, groupModelTokenCountTPS)
+	// log.Data["tpm"] = strconv.FormatInt(groupModelTokenCountTPM, 10)
+	// log.Data["tps"] = strconv.FormatInt(groupModelTokenCountTPS, 10)
+
 	if group.Status != model.GroupStatusInternal &&
 		adjustedModelConfig.TPM > 0 {
-		tpm, err := model.CacheGetGroupModelTPM(group.ID, mc.Model)
-		if err != nil {
-			log.Errorf("get group model tpm (%s:%s) error: %s", group.ID, mc.Model, err.Error())
-			// ignore error
-			return nil
-		}
-		log.Data["tpm_limit"] = strconv.FormatInt(adjustedModelConfig.TPM, 10)
-		log.Data["tpm"] = strconv.FormatInt(tpm, 10)
-		if tpm >= adjustedModelConfig.TPM {
+		log.Data["group_tpm_limit"] = strconv.FormatInt(adjustedModelConfig.TPM, 10)
+		if groupModelCountTPM >= adjustedModelConfig.TPM {
 			setTpmHeaders(c, adjustedModelConfig.TPM, 0)
 			return ErrRequestTpmLimitExceeded
 		}
-		setTpmHeaders(c, adjustedModelConfig.TPM, adjustedModelConfig.TPM-tpm)
+		setTpmHeaders(c, adjustedModelConfig.TPM, adjustedModelConfig.TPM-groupModelCountTPM)
 	}
+
 	return nil
 }
 
@@ -378,7 +394,9 @@ func distribute(c *gin.Context, mode mode.Mode) {
 	}
 	c.Set(RequestMetadata, metadata)
 
-	if err := checkGroupModelRPMAndTPM(c, group, mc); err != nil {
+	token := GetToken(c)
+
+	if err := checkGroupModelRPMAndTPM(c, group, mc, token.Name); err != nil {
 		errMsg := err.Error()
 		consume.AsyncConsume(
 			nil,
@@ -394,6 +412,8 @@ func distribute(c *gin.Context, mode mode.Mode) {
 			true,
 			user,
 			metadata,
+			model.RequestRate{},
+			GetGroupModelTokenRequestRate(c),
 		)
 		AbortLogWithMessage(c, http.StatusTooManyRequests, errMsg, &ErrorField{
 			Type: "invalid_request_error",
@@ -405,6 +425,31 @@ func distribute(c *gin.Context, mode mode.Mode) {
 	c.Next()
 }
 
+func GetGroupModelTokenRequestRate(c *gin.Context) model.RequestRate {
+	return model.RequestRate{
+		RPM: GetGroupModelTokenRPM(c),
+		RPS: GetGroupModelTokenRPS(c),
+		TPM: GetGroupModelTokenTPM(c),
+		TPS: GetGroupModelTokenTPS(c),
+	}
+}
+
+func GetGroupModelTokenRPM(c *gin.Context) int64 {
+	return c.GetInt64(GroupModelTokenRPM)
+}
+
+func GetGroupModelTokenRPS(c *gin.Context) int64 {
+	return c.GetInt64(GroupModelTokenRPS)
+}
+
+func GetGroupModelTokenTPM(c *gin.Context) int64 {
+	return c.GetInt64(GroupModelTokenTPM)
+}
+
+func GetGroupModelTokenTPS(c *gin.Context) int64 {
+	return c.GetInt64(GroupModelTokenTPS)
+}
+
 func GetRequestModel(c *gin.Context) string {
 	return c.GetString(RequestModel)
 }
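
GetGroupModelTokenRequestRate simply repackages the four values captured during checkGroupModelRPMAndTPM, so the consume path can record them without touching the limiter again. A sketch of a caller (the helper name is hypothetical, not part of this diff):

// recordRates reads the per-token rate snapshot taken at distribution
// time; the channel-side rate stays at its zero value when no channel
// measurement exists for this request.
func recordRates(c *gin.Context) (channelRate, groupRate model.RequestRate) {
	return model.RequestRate{}, GetGroupModelTokenRequestRate(c)
}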

+ 41 - 4
core/model/batch.go

@@ -223,6 +223,13 @@ func processSummaryUpdates(wg *sync.WaitGroup) {
 	}
 }
 
+type RequestRate struct {
+	RPM int64
+	RPS int64
+	TPM int64
+	TPS int64
+}
+
 func BatchRecordLogs(
 	requestID string,
 	requestAt time.Time,
@@ -246,6 +253,8 @@ func BatchRecordLogs(
 	amount float64,
 	user string,
 	metadata map[string]string,
+	channelModelRate RequestRate,
+	groupModelTokenRate RequestRate,
 ) (err error) {
 	now := time.Now()
 
@@ -310,10 +319,10 @@ func BatchRecordLogs(
 	updateTokenData(tokenID, amount, amountDecimal)
 
 	if channelID != 0 {
-		updateSummaryData(channelID, modelName, now, code, amountDecimal, usage)
+		updateSummaryData(channelID, modelName, now, code, amountDecimal, usage, channelModelRate)
 	}
 
-	updateGroupSummaryData(group, tokenName, modelName, now, code, amountDecimal, usage)
+	updateGroupSummaryData(group, tokenName, modelName, now, code, amountDecimal, usage, groupModelTokenRate)
 
 	return err
 }
@@ -360,7 +369,7 @@ func updateTokenData(tokenID int, amount float64, amountDecimal decimal.Decimal)
 	}
 }
 
-func updateGroupSummaryData(group string, tokenName string, modelName string, createAt time.Time, code int, amountDecimal decimal.Decimal, usage Usage) {
+func updateGroupSummaryData(group string, tokenName string, modelName string, createAt time.Time, code int, amountDecimal decimal.Decimal, usage Usage, groupModelTokenRate RequestRate) {
 	groupUnique := GroupSummaryUnique{
 		GroupID:       group,
 		TokenName:     tokenName,
@@ -381,13 +390,27 @@ func updateGroupSummaryData(group string, tokenName string, modelName string, cr
 	groupSummary.UsedAmount = amountDecimal.
 		Add(decimal.NewFromFloat(groupSummary.UsedAmount)).
 		InexactFloat64()
+
+	if groupModelTokenRate.RPM > groupSummary.MaxRPM {
+		groupSummary.MaxRPM = groupModelTokenRate.RPM
+	}
+	if groupModelTokenRate.RPS > groupSummary.MaxRPS {
+		groupSummary.MaxRPS = groupModelTokenRate.RPS
+	}
+	if groupModelTokenRate.TPM > groupSummary.MaxTPM {
+		groupSummary.MaxTPM = groupModelTokenRate.TPM
+	}
+	if groupModelTokenRate.TPS > groupSummary.MaxTPS {
+		groupSummary.MaxTPS = groupModelTokenRate.TPS
+	}
+
 	groupSummary.Usage.Add(&usage)
 	if code != http.StatusOK {
 		groupSummary.ExceptionCount++
 	}
 }
 
-func updateSummaryData(channelID int, modelName string, createAt time.Time, code int, amountDecimal decimal.Decimal, usage Usage) {
+func updateSummaryData(channelID int, modelName string, createAt time.Time, code int, amountDecimal decimal.Decimal, usage Usage, channelModelRate RequestRate) {
 	summaryUnique := SummaryUnique{
 		ChannelID:     channelID,
 		Model:         modelName,
@@ -407,6 +430,20 @@ func updateSummaryData(channelID int, modelName string, createAt time.Time, code
 	summary.UsedAmount = amountDecimal.
 		Add(decimal.NewFromFloat(summary.UsedAmount)).
 		InexactFloat64()
+
+	if channelModelRate.RPM > summary.MaxRPM {
+		summary.MaxRPM = channelModelRate.RPM
+	}
+	if channelModelRate.RPS > summary.MaxRPS {
+		summary.MaxRPS = channelModelRate.RPS
+	}
+	if channelModelRate.TPM > summary.MaxTPM {
+		summary.MaxTPM = channelModelRate.TPM
+	}
+	if channelModelRate.TPS > summary.MaxTPS {
+		summary.MaxTPS = channelModelRate.TPS
+	}
+
 	summary.Usage.Add(&usage)
 	if code != http.StatusOK {
 		summary.ExceptionCount++
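
Both summary paths apply the same running-max fold over the four rate fields; the repeated if-blocks are equivalent to this small helper (a sketch, not part of the diff):

// mergeMax folds one RequestRate observation into the stored maxima,
// so each summary bucket ends up holding the peak rates seen in it.
func mergeMax(maxRPM, maxRPS, maxTPM, maxTPS *int64, r RequestRate) {
	if r.RPM > *maxRPM {
		*maxRPM = r.RPM
	}
	if r.RPS > *maxRPS {
		*maxRPS = r.RPS
	}
	if r.TPM > *maxTPM {
		*maxTPM = r.TPM
	}
	if r.TPS > *maxTPS {
		*maxTPS = r.TPS
	}
}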

+ 0 - 31
core/model/cache.go

@@ -411,37 +411,6 @@ func CacheUpdateGroupUsedAmountOnlyIncrease(id string, amount float64) error {
 	return updateGroupUsedAmountOnlyIncreaseScript.Run(context.Background(), common.RDB, []string{fmt.Sprintf(GroupCacheKey, id)}, amount).Err()
 }
 
-//nolint:gosec
-func CacheGetGroupModelTPM(group string, model string) (int64, error) {
-	if !common.RedisEnabled {
-		return GetGroupModelTPM(group, model)
-	}
-
-	cacheKey := fmt.Sprintf(GroupModelTPMKey, group)
-	tpm, err := common.RDB.HGet(context.Background(), cacheKey, model).Int64()
-	if err == nil {
-		return tpm, nil
-	} else if !errors.Is(err, redis.Nil) {
-		log.Errorf("get group model tpm (%s:%s) from redis error: %s", group, model, err.Error())
-	}
-
-	tpm, err = GetGroupModelTPM(group, model)
-	if err != nil {
-		return 0, err
-	}
-
-	pipe := common.RDB.Pipeline()
-	pipe.HSet(context.Background(), cacheKey, model, tpm)
-	// 2-5 seconds
-	pipe.Expire(context.Background(), cacheKey, 2*time.Second+time.Duration(rand.Int64N(3))*time.Second)
-	_, err = pipe.Exec(context.Background())
-	if err != nil {
-		log.Errorf("set group model tpm (%s:%s) to redis error: %s", group, model, err.Error())
-	}
-
-	return tpm, nil
-}
-
 //nolint:revive
 type ModelConfigCache interface {
 	GetModelConfig(model string) (*ModelConfig, bool)
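
With CacheGetGroupModelTPM gone, TPM reads come from the reqlimit sliding-window counters rather than a 2-5 second Redis cache over SQL sums, so the limiter sees near-real-time token usage. A sketch of the replacement read path, using the signature introduced in distributor.go above (and assuming the caller imports core/common/reqlimit):

// groupModelTPM returns the one-minute token count and the last-second
// count from the sliding-window record (Redis-backed when enabled,
// in-memory otherwise), with no cache staleness window.
func groupModelTPM(ctx context.Context, group, model string) (tpm, tps int64) {
	return reqlimit.GetGroupModelTokensRequest(ctx, group, model)
}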

+ 38 - 117
core/model/log.go

@@ -1139,6 +1139,11 @@ type ChartData struct {
 	TotalTokens         int64   `json:"total_tokens,omitempty"`
 	ExceptionCount      int64   `json:"exception_count"`
 	WebSearchCount      int64   `json:"web_search_count,omitempty"`
+
+	MaxRPM int64 `json:"max_rpm"`
+	MaxTPM int64 `json:"max_tpm"`
+	MaxRPS int64 `json:"max_rps"`
+	MaxTPS int64 `json:"max_tps"`
 }
 
 type DashboardResponse struct {
@@ -1149,6 +1154,11 @@ type DashboardResponse struct {
 	RPM int64 `json:"rpm"`
 	TPM int64 `json:"tpm"`
 
+	MaxRPM int64 `json:"max_rpm"`
+	MaxTPM int64 `json:"max_tpm"`
+	MaxRPS int64 `json:"max_rps"`
+	MaxTPS int64 `json:"max_tps"`
+
 	UsedAmount          float64 `json:"used_amount"`
 	InputTokens         int64   `json:"input_tokens,omitempty"`
 	OutputTokens        int64   `json:"output_tokens,omitempty"`
@@ -1206,6 +1216,19 @@ func aggregateHourDataToDay(hourlyData []*ChartData, timezone *time.Location) []
 		day.CacheCreationTokens += data.CacheCreationTokens
 		day.TotalTokens += data.TotalTokens
 		day.WebSearchCount += data.WebSearchCount
+
+		if data.MaxRPM > day.MaxRPM {
+			day.MaxRPM = data.MaxRPM
+		}
+		if data.MaxTPM > day.MaxTPM {
+			day.MaxTPM = data.MaxTPM
+		}
+		if data.MaxRPS > day.MaxRPS {
+			day.MaxRPS = data.MaxRPS
+		}
+		if data.MaxTPS > day.MaxTPS {
+			day.MaxTPS = data.MaxTPS
+		}
 	}
 
 	result := make([]*ChartData, 0, len(dayData))
@@ -1279,70 +1302,30 @@ func sumDashboardResponse(chartData []*ChartData) DashboardResponse {
 		dashboardResponse.CachedTokens += data.CachedTokens
 		dashboardResponse.CacheCreationTokens += data.CacheCreationTokens
 		dashboardResponse.WebSearchCount += data.WebSearchCount
+
+		if data.MaxRPM > dashboardResponse.MaxRPM {
+			dashboardResponse.MaxRPM = data.MaxRPM
+		}
+		if data.MaxTPM > dashboardResponse.MaxTPM {
+			dashboardResponse.MaxTPM = data.MaxTPM
+		}
+		if data.MaxRPS > dashboardResponse.MaxRPS {
+			dashboardResponse.MaxRPS = data.MaxRPS
+		}
+		if data.MaxTPS > dashboardResponse.MaxTPS {
+			dashboardResponse.MaxTPS = data.MaxTPS
+		}
 	}
 	dashboardResponse.UsedAmount = usedAmount.InexactFloat64()
 	return dashboardResponse
 }
 
-func GetRPM(group string, end time.Time, tokenName, modelName string, channelID int) (int64, error) {
-	query := LogDB.Model(&Log{})
-
-	if group == "" {
-		query = query.Where("group_id = ''")
-	} else if group != "*" {
-		query = query.Where("group_id = ?", group)
-	}
-	if channelID != 0 {
-		query = query.Where("channel_id = ?", channelID)
-	}
-	if modelName != "" {
-		query = query.Where("model = ?", modelName)
-	}
-	if tokenName != "" {
-		query = query.Where("token_name = ?", tokenName)
-	}
-
-	var count int64
-	err := query.
-		Where("created_at BETWEEN ? AND ?", end.Add(-time.Minute), end).
-		Count(&count).Error
-	return count, err
-}
-
-func GetTPM(group string, end time.Time, tokenName, modelName string, channelID int) (int64, error) {
-	query := LogDB.Model(&Log{}).
-		Select("COALESCE(SUM(total_tokens), 0)")
-
-	if group == "" {
-		query = query.Where("group_id = ''")
-	} else if group != "*" {
-		query = query.Where("group_id = ?", group)
-	}
-	if channelID != 0 {
-		query = query.Where("channel_id = ?", channelID)
-	}
-	if modelName != "" {
-		query = query.Where("model = ?", modelName)
-	}
-	if tokenName != "" {
-		query = query.Where("token_name = ?", tokenName)
-	}
-
-	var tpm int64
-	err := query.
-		Where("created_at BETWEEN ? AND ?", end.Add(-time.Minute), end).
-		Scan(&tpm).Error
-	return tpm, err
-}
-
 func GetDashboardData(
-	group string,
 	start,
 	end time.Time,
 	modelName string,
 	channelID int,
 	timeSpan TimeSpanType,
-	needRPM bool,
 	timezone *time.Location,
 ) (*DashboardResponse, error) {
 	if end.IsZero() {
@@ -1353,8 +1336,6 @@ func GetDashboardData(
 
 	var (
 		chartData []*ChartData
-		rpm       int64
-		tpm       int64
 		channels  []int
 	)
 
@@ -1362,27 +1343,13 @@ func GetDashboardData(
 
 	g.Go(func() error {
 		var err error
-		chartData, err = getChartData(group, start, end, "", modelName, channelID, timeSpan, timezone)
-		return err
-	})
-
-	if needRPM {
-		g.Go(func() error {
-			var err error
-			rpm, err = GetRPM(group, end, "", modelName, channelID)
-			return err
-		})
-	}
-
-	g.Go(func() error {
-		var err error
-		tpm, err = GetTPM(group, end, "", modelName, channelID)
+		chartData, err = getChartData("*", start, end, "", modelName, channelID, timeSpan, timezone)
 		return err
 	})
 
 	g.Go(func() error {
 		var err error
-		channels, err = GetUsedChannels(group, start, end)
+		channels, err = GetUsedChannels("*", start, end)
 		return err
 	})
 
@@ -1392,8 +1359,6 @@ func GetDashboardData(
 
 	dashboardResponse := sumDashboardResponse(chartData)
 	dashboardResponse.Channels = channels
-	dashboardResponse.RPM = rpm
-	dashboardResponse.TPM = tpm
 
 	return &dashboardResponse, nil
 }
@@ -1404,7 +1369,6 @@ func GetGroupDashboardData(
 	tokenName string,
 	modelName string,
 	timeSpan TimeSpanType,
-	needRPM bool,
 	timezone *time.Location,
 ) (*GroupDashboardResponse, error) {
 	if group == "" || group == "*" {
@@ -1421,8 +1385,6 @@ func GetGroupDashboardData(
 		chartData  []*ChartData
 		tokenNames []string
 		models     []string
-		rpm        int64
-		tpm        int64
 	)
 
 	g := new(errgroup.Group)
@@ -1445,27 +1407,11 @@ func GetGroupDashboardData(
 		return err
 	})
 
-	if needRPM {
-		g.Go(func() error {
-			var err error
-			rpm, err = GetRPM(group, end, tokenName, modelName, 0)
-			return err
-		})
-	}
-
-	g.Go(func() error {
-		var err error
-		tpm, err = GetTPM(group, end, tokenName, modelName, 0)
-		return err
-	})
-
 	if err := g.Wait(); err != nil {
 		return nil, err
 	}
 
 	dashboardResponse := sumDashboardResponse(chartData)
-	dashboardResponse.RPM = rpm
-	dashboardResponse.TPM = tpm
 
 	return &GroupDashboardResponse{
 		DashboardResponse: dashboardResponse,
@@ -1474,31 +1420,6 @@ func GetGroupDashboardData(
 	}, nil
 }
 
-func GetGroupModelTPM(group string, model string) (int64, error) {
-	end := time.Now()
-	start := end.Add(-time.Minute)
-	var tpm int64
-	err := LogDB.
-		Model(&Log{}).
-		Where("group_id = ? AND created_at >= ? AND created_at <= ? AND model = ?", group, start, end, model).
-		Select("COALESCE(SUM(total_tokens), 0)").
-		Scan(&tpm).Error
-	return tpm, err
-}
-
-//nolint:revive
-type ModelCostRank struct {
-	Model               string  `json:"model"`
-	UsedAmount          float64 `json:"used_amount"`
-	InputTokens         int64   `json:"input_tokens"`
-	OutputTokens        int64   `json:"output_tokens"`
-	CachedTokens        int64   `json:"cached_tokens"`
-	CacheCreationTokens int64   `json:"cache_creation_tokens"`
-	TotalTokens         int64   `json:"total_tokens"`
-	RequestCount        int64   `json:"request_count"`
-	WebSearchCount      int64   `json:"web_search_count"`
-}
-
 func GetIPGroups(threshold int, start, end time.Time) (map[string][]string, error) {
 	if threshold < 1 {
 		threshold = 1
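
Note that the rate fields roll up with max, not sum: a peak is only meaningful as the highest bucket value, so daily and dashboard figures report the largest hourly peak. A worked sketch:

func illustrateRollup() int64 {
	day := &ChartData{}
	for _, h := range []*ChartData{{MaxRPM: 10}, {MaxRPM: 25}, {MaxRPM: 7}} {
		if h.MaxRPM > day.MaxRPM {
			day.MaxRPM = h.MaxRPM
		}
	}
	return day.MaxRPM // 25, not 42: peaks compose as max, never as sum
}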

+ 45 - 8
core/model/summary.go

@@ -28,20 +28,40 @@ type SummaryData struct {
 	RequestCount   int64   `json:"request_count"`
 	UsedAmount     float64 `json:"used_amount"`
 	ExceptionCount int64   `json:"exception_count"`
+	MaxRPM         int64   `json:"max_rpm"`
+	MaxRPS         int64   `json:"max_rps"`
+	MaxTPM         int64   `json:"max_tpm"`
+	MaxTPS         int64   `json:"max_tps"`
 	Usage          Usage   `gorm:"embedded"        json:"usage,omitempty"`
 }
 
 func (d *SummaryData) buildUpdateData(tableName string) map[string]any {
 	data := map[string]any{}
 	if d.RequestCount > 0 {
-		data["request_count"] = gorm.Expr(fmt.Sprintf("COALESCE(%s.request_count, 0) + ?", tableName), d.RequestCount)
+		data["request_count"] = gorm.Expr(tableName+".request_count + ?", d.RequestCount)
 	}
 	if d.UsedAmount > 0 {
-		data["used_amount"] = gorm.Expr(fmt.Sprintf("COALESCE(%s.used_amount, 0) + ?", tableName), d.UsedAmount)
+		data["used_amount"] = gorm.Expr(tableName+".used_amount + ?", d.UsedAmount)
 	}
 	if d.ExceptionCount > 0 {
-		data["exception_count"] = gorm.Expr(fmt.Sprintf("COALESCE(%s.exception_count, 0) + ?", tableName), d.ExceptionCount)
+		data["exception_count"] = gorm.Expr(tableName+".exception_count + ?", d.ExceptionCount)
 	}
+
+	// track running maxima for rpm/rps/tpm/tps
+	if d.MaxRPM > 0 {
+		data["max_rpm"] = gorm.Expr(fmt.Sprintf("CASE WHEN %s.max_rpm < ? THEN ? ELSE %s.max_rpm END", tableName, tableName), d.MaxRPM, d.MaxRPM)
+	}
+	if d.MaxRPS > 0 {
+		data["max_rps"] = gorm.Expr(fmt.Sprintf("CASE WHEN %s.max_rps < ? THEN ? ELSE %s.max_rps END", tableName, tableName), d.MaxRPS, d.MaxRPS)
+	}
+	if d.MaxTPM > 0 {
+		data["max_tpm"] = gorm.Expr(fmt.Sprintf("CASE WHEN %s.max_tpm < ? THEN ? ELSE %s.max_tpm END", tableName, tableName), d.MaxTPM, d.MaxTPM)
+	}
+	if d.MaxTPS > 0 {
+		data["max_tps"] = gorm.Expr(fmt.Sprintf("CASE WHEN %s.max_tps < ? THEN ? ELSE %s.max_tps END", tableName, tableName), d.MaxTPS, d.MaxTPS)
+	}
+
+	// usage update
 	if d.Usage.InputTokens > 0 {
 		data["input_tokens"] = gorm.Expr(fmt.Sprintf("COALESCE(%s.input_tokens, 0) + ?", tableName), d.Usage.InputTokens)
 	}
@@ -191,7 +211,7 @@ func getChartData(
 	}
 
 	query = query.
-		Select("hour_timestamp as timestamp, sum(request_count) as request_count, sum(used_amount) as used_amount, sum(exception_count) as exception_count, sum(input_tokens) as input_tokens, sum(output_tokens) as output_tokens, sum(cached_tokens) as cached_tokens, sum(cache_creation_tokens) as cache_creation_tokens, sum(total_tokens) as total_tokens, sum(web_search_count) as web_search_count").
+		Select("hour_timestamp as timestamp, sum(request_count) as request_count, sum(used_amount) as used_amount, sum(exception_count) as exception_count, sum(input_tokens) as input_tokens, sum(output_tokens) as output_tokens, sum(cached_tokens) as cached_tokens, sum(cache_creation_tokens) as cache_creation_tokens, sum(total_tokens) as total_tokens, sum(web_search_count) as web_search_count, max(max_rpm) as max_rpm, max(max_rps) as max_rps, max(max_tpm) as max_tpm, max(max_tps) as max_tps").
 		Group("timestamp").
 		Order("timestamp ASC")
 
@@ -281,8 +301,25 @@ func getLogGroupByValues[T cmp.Ordered](field string, group string, tokenName st
 	return values, nil
 }
 
-func GetModelCostRank(group string, channelID int, start, end time.Time) ([]*ModelCostRank, error) {
-	var ranks []*ModelCostRank
+type CostRank struct {
+	Model               string  `json:"model"`
+	UsedAmount          float64 `json:"used_amount"`
+	InputTokens         int64   `json:"input_tokens"`
+	OutputTokens        int64   `json:"output_tokens"`
+	CachedTokens        int64   `json:"cached_tokens"`
+	CacheCreationTokens int64   `json:"cache_creation_tokens"`
+	TotalTokens         int64   `json:"total_tokens"`
+	RequestCount        int64   `json:"request_count"`
+	WebSearchCount      int64   `json:"web_search_count"`
+
+	MaxRPM int64 `json:"max_rpm"`
+	MaxRPS int64 `json:"max_rps"`
+	MaxTPM int64 `json:"max_tpm"`
+	MaxTPS int64 `json:"max_tps"`
+}
+
+func GetModelCostRank(group string, channelID int, start, end time.Time) ([]*CostRank, error) {
+	var ranks []*CostRank
 
 	var query *gorm.DB
 	if group == "*" || channelID != 0 {
@@ -305,7 +342,7 @@ func GetModelCostRank(group string, channelID int, start, end time.Time) ([]*Mod
 	}
 
 	query = query.
-		Select("model, SUM(used_amount) as used_amount, SUM(request_count) as request_count, SUM(input_tokens) as input_tokens, SUM(output_tokens) as output_tokens, SUM(cached_tokens) as cached_tokens, SUM(cache_creation_tokens) as cache_creation_tokens, SUM(total_tokens) as total_tokens").
+		Select("model, SUM(used_amount) as used_amount, SUM(request_count) as request_count, SUM(input_tokens) as input_tokens, SUM(output_tokens) as output_tokens, SUM(cached_tokens) as cached_tokens, SUM(cache_creation_tokens) as cache_creation_tokens, SUM(total_tokens) as total_tokens, MAX(max_rpm) as max_rpm, MAX(max_rps) as max_rps, MAX(max_tpm) as max_tpm, MAX(max_tps) as max_tps").
 		Group("model")
 
 	err := query.Scan(&ranks).Error
@@ -313,7 +350,7 @@ func GetModelCostRank(group string, channelID int, start, end time.Time) ([]*Mod
 		return nil, err
 	}
 
-	slices.SortFunc(ranks, func(a, b *ModelCostRank) int {
+	slices.SortFunc(ranks, func(a, b *CostRank) int {
 		if a.UsedAmount != b.UsedAmount {
 			return cmp.Compare(b.UsedAmount, a.UsedAmount)
 		}
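
The CASE WHEN expressions in buildUpdateData push the max comparison into the upsert itself, so concurrent writers never race a read-modify-write in Go. A sketch of the expanded assignment for a hypothetical tableName "summaries" and an observed MaxRPM of 30:

// gorm binds both placeholders to 30 and the database keeps whichever
// side of the comparison is larger:
//
//	max_rpm = CASE WHEN summaries.max_rpm < 30 THEN 30 ELSE summaries.max_rpm END
func maxRPMExpr() any {
	return gorm.Expr(
		"CASE WHEN summaries.max_rpm < ? THEN ? ELSE summaries.max_rpm END",
		int64(30), int64(30),
	)
}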

+ 2 - 19
core/relay/controller/handle.go

@@ -1,15 +1,11 @@
 package controller
 
 import (
-	"fmt"
-	"net/http"
-
 	"github.com/gin-gonic/gin"
 	"github.com/labring/aiproxy/core/common/config"
 	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
-	"github.com/labring/aiproxy/core/relay/adaptor/openai"
-	"github.com/labring/aiproxy/core/relay/adaptors"
+	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
 	relaymodel "github.com/labring/aiproxy/core/relay/model"
 )
@@ -21,22 +17,9 @@ type HandleResult struct {
 	Detail *RequestDetail
 }
 
-var ErrInvalidChannelTypeCode = "invalid_channel_type"
-
-func Handle(meta *meta.Meta, c *gin.Context) *HandleResult {
+func Handle(adaptor adaptor.Adaptor, c *gin.Context, meta *meta.Meta) *HandleResult {
 	log := middleware.GetLogger(c)
 
-	adaptor, ok := adaptors.GetAdaptor(meta.Channel.Type)
-	if !ok {
-		return &HandleResult{
-			Error: openai.ErrorWrapperWithMessage(
-				fmt.Sprintf("invalid channel type: %d", meta.Channel.Type),
-				ErrInvalidChannelTypeCode,
-				http.StatusInternalServerError,
-			),
-		}
-	}
-
 	usage, detail, respErr := DoHelper(adaptor, c, meta)
 	if respErr != nil {
 		var logDetail *RequestDetail
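
Handle no longer resolves the adaptor itself; the caller looks it up once and owns the invalid-channel-type error path. A sketch of the expected call site (the wrapper name is hypothetical):

func relay(c *gin.Context, m *meta.Meta) *HandleResult {
	// the lookup that used to live inside Handle now happens here
	a, ok := adaptors.GetAdaptor(m.Channel.Type)
	if !ok {
		return nil // caller surfaces its own invalid-channel-type error
	}
	return Handle(a, c, m)
}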

+ 9 - 0
core/relay/meta/meta.go

@@ -160,6 +160,15 @@ func (m *Meta) GetBool(key string) bool {
 	return b
 }
 
+func (m *Meta) GetInt64(key string) int64 {
+	v, ok := m.Get(key)
+	if !ok {
+		return 0
+	}
+	i, _ := v.(int64)
+	return i
+}
+
 func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) {
 	if len(modelName) == 0 {
 		return modelName, false
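
GetInt64 mirrors the existing GetBool accessor: an absent key and a value stored under a different type both come back as 0. A usage sketch (the key name is hypothetical):

// firstTokenAt returns 0 when the key was never set or holds a
// non-int64 value, matching GetBool's zero-value behaviour.
func firstTokenAt(m *Meta) int64 {
	return m.GetInt64("first_token_at")
}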