Преглед на файлове

feat: glm coding api support (#397)

* feat glm coding api support

* fix: ci lint
zijiren преди 2 месеца
родител
ревизия
5fb2596ed7
променени са 6 файла, в които са добавени 233 реда и са изтрити 13 реда
  1. 20 9
      core/controller/token.go
  2. 2 0
      core/model/chtype.go
  3. 89 3
      core/model/token.go
  4. 15 1
      core/relay/adaptor/anthropic/main.go
  5. 105 0
      core/relay/adaptor/zhipucoding/adaptor.go
  6. 2 0
      core/relay/adaptors/register.go

+ 20 - 9
core/controller/token.go

@@ -38,10 +38,13 @@ func (t *TokenResponse) MarshalJSON() ([]byte, error) {
 
 type (
 	AddTokenRequest struct {
-		Name    string   `json:"name"`
-		Subnets []string `json:"subnets"`
-		Models  []string `json:"models"`
-		Quota   float64  `json:"quota"`
+		Name                 string   `json:"name"`
+		Subnets              []string `json:"subnets"`
+		Models               []string `json:"models"`
+		Quota                float64  `json:"quota"`
+		PeriodQuota          float64  `json:"period_quota"`
+		PeriodType           string   `json:"period_type"`
+		PeriodLastUpdateTime int64    `json:"period_last_update_time"`
 	}
 
 	UpdateTokenStatusRequest struct {
@@ -54,12 +57,20 @@ type (
 )
 
 func (at *AddTokenRequest) ToToken() *model.Token {
-	return &model.Token{
-		Name:    model.EmptyNullString(at.Name),
-		Subnets: at.Subnets,
-		Models:  at.Models,
-		Quota:   at.Quota,
+	token := &model.Token{
+		Name:        model.EmptyNullString(at.Name),
+		Subnets:     at.Subnets,
+		Models:      at.Models,
+		Quota:       at.Quota,
+		PeriodQuota: at.PeriodQuota,
+		PeriodType:  model.EmptyNullString(at.PeriodType),
 	}
+
+	if at.PeriodLastUpdateTime > 0 {
+		token.PeriodLastUpdateTime = time.UnixMilli(at.PeriodLastUpdateTime)
+	}
+
+	return token
 }
 
 func validateToken(token AddTokenRequest) error {

+ 2 - 0
core/model/chtype.go

@@ -51,6 +51,7 @@ const (
 	ChannelTypeQianfan                 ChannelType = 49
 	ChannelTypeSangforAICP             ChannelType = 50
 	ChannelTypeStreamlake              ChannelType = 51
+	ChannelTypeZhipuCoding             ChannelType = 52
 )
 
 var channelTypeNames = map[ChannelType]string{
@@ -93,6 +94,7 @@ var channelTypeNames = map[ChannelType]string{
 	ChannelTypeQianfan:                 "qianfan",
 	ChannelTypeSangforAICP:             "Sangfor AICP",
 	ChannelTypeStreamlake:              "Streamlake",
+	ChannelTypeZhipuCoding:             "zhipu coding",
 }
 
 func AllChannelTypes() []ChannelType {

+ 89 - 3
core/model/token.go

@@ -886,11 +886,87 @@ func UpdateTokenUsedAmount(id int, amount float64, requestCount int) (err error)
 	return HandleUpdateResult(result, ErrTokenNotFound)
 }
 
+// calculateNextPeriodStartTime calculates the next period start time based on the last update time and period type
+// This finds the most recent period boundary by incrementing from lastUpdateTime until we reach the current time
+// This maintains period continuity - e.g., if reset was on Jan 15, next periods are Feb 15, Mar 15, etc.
+func calculateNextPeriodStartTime(lastUpdateTime time.Time, periodType EmptyNullString) time.Time {
+	if lastUpdateTime.IsZero() {
+		// If never initialized, return current time
+		return time.Now()
+	}
+
+	now := time.Now()
+
+	// If we haven't passed the period yet, no reset needed
+	if !now.After(lastUpdateTime) {
+		return lastUpdateTime
+	}
+
+	switch periodType {
+	case "", PeriodTypeMonthly:
+		// Start from lastUpdateTime and keep adding months until we find the most recent period start
+		nextPeriod := lastUpdateTime
+		for {
+			// Calculate next month period
+			candidate := time.Date(
+				nextPeriod.Year(),
+				nextPeriod.Month()+1,
+				nextPeriod.Day(),
+				nextPeriod.Hour(),
+				nextPeriod.Minute(),
+				nextPeriod.Second(),
+				nextPeriod.Nanosecond(),
+				nextPeriod.Location(),
+			)
+
+			// If candidate is in the future, the current nextPeriod is the one we want
+			if candidate.After(now) {
+				return nextPeriod
+			}
+
+			nextPeriod = candidate
+		}
+
+	case PeriodTypeWeekly:
+		// Calculate how many complete weeks have passed since lastUpdateTime
+		daysSinceLastUpdate := now.Sub(lastUpdateTime).Hours() / 24
+		weeksPassed := int(daysSinceLastUpdate / 7)
+
+		if weeksPassed == 0 {
+			// Still in the same week period, no reset needed
+			return lastUpdateTime
+		}
+
+		// Return the start of the most recent week period
+		// This is lastUpdateTime + (weeksPassed * 7 days)
+		return lastUpdateTime.Add(time.Duration(weeksPassed*7*24) * time.Hour)
+
+	case PeriodTypeDaily:
+		// Calculate how many complete days have passed since lastUpdateTime
+		daysSinceLastUpdate := int(now.Sub(lastUpdateTime).Hours() / 24)
+
+		if daysSinceLastUpdate == 0 {
+			// Still in the same day period, no reset needed
+			return lastUpdateTime
+		}
+
+		// Return the start of the most recent day period
+		// This is lastUpdateTime + (daysPassed * 1 day)
+		return lastUpdateTime.Add(time.Duration(daysSinceLastUpdate*24) * time.Hour)
+
+	default:
+		// Fallback to current time for unknown period types
+		return now
+	}
+}
+
 // ResetTokenPeriodUsage resets the period usage for a token with concurrency safety
 // This updates PeriodLastUpdateTime and PeriodLastUpdateAmount to current values
 func ResetTokenPeriodUsage(id int) error {
 	token := &Token{}
 
+	var newPeriodStartTime time.Time
+
 	// Use database transaction with optimistic locking to prevent concurrent resets
 	err := DB.Transaction(func(tx *gorm.DB) error {
 		// First, read the current state with FOR UPDATE lock
@@ -911,6 +987,16 @@ func ResetTokenPeriodUsage(id int) error {
 			return nil
 		}
 
+		// Calculate the correct next period start time based on period type
+		newPeriodStartTime = calculateNextPeriodStartTime(
+			token.PeriodLastUpdateTime,
+			token.PeriodType,
+		)
+
+		if newPeriodStartTime.IsZero() {
+			return errors.New("next period start time is zero")
+		}
+
 		// Perform the reset with the lock held - update period last update time and amount
 		result := tx.
 			Model(token).
@@ -922,7 +1008,7 @@ func ResetTokenPeriodUsage(id int) error {
 			Where("id = ?", id).
 			Updates(
 				map[string]any{
-					"period_last_update_time": time.Now(),
+					"period_last_update_time": newPeriodStartTime,
 					"period_last_update_amount": gorm.Expr(
 						"used_amount",
 					), // Set to current total usage
@@ -933,8 +1019,8 @@ func ResetTokenPeriodUsage(id int) error {
 	})
 
 	// Update cache only if database update succeeded
-	if err == nil && token.Key != "" {
-		if cacheErr := CacheResetTokenPeriodUsage(token.Key, time.Now(), token.UsedAmount); cacheErr != nil {
+	if err == nil && token.Key != "" && !newPeriodStartTime.IsZero() {
+		if cacheErr := CacheResetTokenPeriodUsage(token.Key, newPeriodStartTime, token.UsedAmount); cacheErr != nil {
 			log.Error("reset token period usage in cache failed: " + cacheErr.Error())
 		}
 	}

+ 15 - 1
core/relay/adaptor/anthropic/main.go

@@ -24,7 +24,11 @@ import (
 	"golang.org/x/sync/semaphore"
 )
 
-func ConvertRequest(meta *meta.Meta, req *http.Request) (adaptor.ConvertResult, error) {
+func ConvertRequest(
+	meta *meta.Meta,
+	req *http.Request,
+	callbacks ...func(node *ast.Node) error,
+) (adaptor.ConvertResult, error) {
 	// Parse request body into AST node
 	node, err := common.UnmarshalRequest2NodeReusable(req)
 	if err != nil {
@@ -43,6 +47,16 @@ func ConvertRequest(meta *meta.Meta, req *http.Request) (adaptor.ConvertResult,
 		return adaptor.ConvertResult{}, err
 	}
 
+	for _, callback := range callbacks {
+		if callback == nil {
+			continue
+		}
+
+		if err := callback(&node); err != nil {
+			return adaptor.ConvertResult{}, err
+		}
+	}
+
 	// Serialize the modified node
 	newBody, err := node.MarshalJSON()
 	if err != nil {

+ 105 - 0
core/relay/adaptor/zhipucoding/adaptor.go

@@ -0,0 +1,105 @@
+package zhipucoding
+
+import (
+	"net/http"
+	"net/url"
+
+	"github.com/bytedance/sonic/ast"
+	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/model"
+	"github.com/labring/aiproxy/core/relay/adaptor"
+	"github.com/labring/aiproxy/core/relay/adaptor/anthropic"
+	"github.com/labring/aiproxy/core/relay/adaptor/openai"
+	"github.com/labring/aiproxy/core/relay/adaptor/zhipu"
+	"github.com/labring/aiproxy/core/relay/meta"
+	"github.com/labring/aiproxy/core/relay/mode"
+	"github.com/labring/aiproxy/core/relay/utils"
+)
+
+var _ adaptor.Adaptor = (*Adaptor)(nil)
+
+type Adaptor struct {
+	openai.Adaptor
+}
+
+const baseURL = "https://open.bigmodel.cn"
+
+func (a *Adaptor) DefaultBaseURL() string {
+	return baseURL
+}
+
+func (a *Adaptor) SupportMode(m mode.Mode) bool {
+	return m == mode.ChatCompletions ||
+		m == mode.Completions ||
+		m == mode.Anthropic
+}
+
+func (a *Adaptor) GetRequestURL(meta *meta.Meta, store adaptor.Store) (adaptor.RequestURL, error) {
+	u := meta.Channel.BaseURL
+
+	switch meta.Mode {
+	case mode.Anthropic:
+		url, err := url.JoinPath(u, "/api/anthropic/v1/messages")
+		if err != nil {
+			return adaptor.RequestURL{}, err
+		}
+
+		return adaptor.RequestURL{
+			Method: http.MethodPost,
+			URL:    url,
+		}, nil
+	default:
+		meta.Channel.BaseURL += "/api/coding/paas/v4"
+		defer func() {
+			meta.Channel.BaseURL = u
+		}()
+
+		return a.Adaptor.GetRequestURL(meta, store)
+	}
+}
+
+func (a *Adaptor) ConvertRequest(
+	meta *meta.Meta,
+	store adaptor.Store,
+	req *http.Request,
+) (adaptor.ConvertResult, error) {
+	switch meta.Mode {
+	case mode.Anthropic:
+		return anthropic.ConvertRequest(meta, req, func(node *ast.Node) error {
+			if !node.Get("max_tokens").Exists() {
+				_, err := node.Set("max_tokens", ast.NewNumber("4096"))
+				return err
+			}
+
+			return nil
+		})
+	default:
+		return a.Adaptor.ConvertRequest(meta, store, req)
+	}
+}
+
+func (a *Adaptor) DoResponse(
+	meta *meta.Meta,
+	store adaptor.Store,
+	c *gin.Context,
+	resp *http.Response,
+) (usage model.Usage, err adaptor.Error) {
+	switch meta.Mode {
+	case mode.Anthropic:
+		if utils.IsStreamResponse(resp) {
+			usage, err = anthropic.StreamHandler(meta, c, resp)
+		} else {
+			usage, err = anthropic.Handler(meta, c, resp)
+		}
+	default:
+		usage, err = a.Adaptor.DoResponse(meta, store, c, resp)
+	}
+
+	return usage, err
+}
+
+func (a *Adaptor) Metadata() adaptor.Metadata {
+	return adaptor.Metadata{
+		Models: zhipu.ModelList,
+	}
+}

+ 2 - 0
core/relay/adaptors/register.go

@@ -42,6 +42,7 @@ import (
 	"github.com/labring/aiproxy/core/relay/adaptor/xai"
 	"github.com/labring/aiproxy/core/relay/adaptor/xunfei"
 	"github.com/labring/aiproxy/core/relay/adaptor/zhipu"
+	"github.com/labring/aiproxy/core/relay/adaptor/zhipucoding"
 	log "github.com/sirupsen/logrus"
 )
 
@@ -85,6 +86,7 @@ var ChannelAdaptor = map[model.ChannelType]adaptor.Adaptor{
 	model.ChannelTypeQianfan:                 &qianfan.Adaptor{},
 	model.ChannelTypeSangforAICP:             &sangforaicp.Adaptor{},
 	model.ChannelTypeStreamlake:              &streamlake.Adaptor{},
+	model.ChannelTypeZhipuCoding:             &zhipucoding.Adaptor{},
 }
 
 func GetAdaptor(channelType model.ChannelType) (adaptor.Adaptor, bool) {