Browse Source

refactor: improve request type validation and enhance sensitive information masking

CaIon 6 months ago
parent
commit
5fe1ce89ec

+ 46 - 43
common/str.go

@@ -117,6 +117,48 @@ func MaskEmail(email string) string {
 	return "***@" + email[atIndex+1:]
 }
 
// maskHostTail returns the trailing labels of a dot-split host that should
// stay visible after masking.
//
// Normally only the last label (the TLD) is kept. When the host has at least
// three labels, the last label is two characters long and the label before it
// is at most three characters long, the pair is treated as a likely
// country-code suffix (e.g. co.uk, com.cn) and both labels are kept. This is a
// length heuristic, not a public-suffix lookup.
//
// A host with fewer than two labels is returned unchanged. A two-label host
// always keeps just its last label: keeping both would preserve the entire
// host (t.co would stay fully readable) and mask nothing.
func maskHostTail(parts []string) []string {
	n := len(parts)
	if n < 2 {
		return parts
	}
	tld, sld := parts[n-1], parts[n-2]
	// Require a label in front of the pair so something is left to mask.
	if n >= 3 && len(tld) == 2 && len(sld) <= 3 {
		return []string{sld, tld}
	}
	return []string{tld}
}
+
+// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail.
+// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk
+func maskHostForURL(host string) string {
+	parts := strings.Split(host, ".")
+	if len(parts) < 2 {
+		return "***"
+	}
+	tail := maskHostTail(parts)
+	return "***." + strings.Join(tail, ".")
+}
+
+// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***.
+// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk
+func maskHostForPlainDomain(domain string) string {
+	parts := strings.Split(domain, ".")
+	if len(parts) < 2 {
+		return domain
+	}
+	tail := maskHostTail(parts)
+	numStars := len(parts) - len(tail)
+	if numStars < 1 {
+		numStars = 1
+	}
+	stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".")
+	return stars + "." + strings.Join(tail, ".")
+}
+
 // MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
 // Example:
 // http://example.com -> http://***.com
@@ -140,32 +182,8 @@ func MaskSensitiveInfo(str string) string {
 			return urlStr
 		}
 
-		// Split host by dots
-		parts := strings.Split(host, ".")
-		if len(parts) < 2 {
-			// If less than 2 parts, just mask the whole host
-			return u.Scheme + "://***" + u.Path
-		}
-
-		// Keep the TLD (Top Level Domain) and mask the rest
-		var maskedHost string
-		if len(parts) == 2 {
-			// example.com -> ***.com
-			maskedHost = "***." + parts[len(parts)-1]
-		} else {
-			// Handle cases like sub.domain.co.uk or api.example.com
-			// Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.)
-			lastPart := parts[len(parts)-1]
-			secondLastPart := parts[len(parts)-2]
-
-			if len(lastPart) == 2 && len(secondLastPart) <= 3 {
-				// Likely country code TLD like co.uk, com.cn
-				maskedHost = "***." + secondLastPart + "." + lastPart
-			} else {
-				// Regular TLD like .com, .org
-				maskedHost = "***." + lastPart
-			}
-		}
+		// Mask host with unified logic
+		maskedHost := maskHostForURL(host)
 
 		result := u.Scheme + "://" + maskedHost
 
@@ -208,26 +226,11 @@ func MaskSensitiveInfo(str string) string {
 	// Mask domain names without protocol (like openai.com, www.openai.com)
 	domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
 	str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
-		// Skip if it's already been processed as part of a URL
+		// Skip if it's already part of a URL to avoid partial masking
 		if strings.Contains(str, "://"+domain) {
 			return domain
 		}
-
-		parts := strings.Split(domain, ".")
-		if len(parts) < 2 {
-			return domain
-		}
-
-		// Handle different domain patterns
-		if len(parts) == 2 {
-			// openai.com -> ***.com
-			return "***." + parts[1]
-		} else {
-			// www.openai.com -> ***.***.com
-			// api.openai.com -> ***.***.com
-			lastPart := parts[len(parts)-1]
-			return "***.***." + lastPart
-		}
+		return maskHostForPlainDomain(domain)
 	})
 
 	// Mask IP addresses

+ 4 - 3
controller/relay.go

@@ -113,8 +113,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 	meta := request.GetTokenCountMeta()
 
 	if setting.ShouldCheckPromptSensitive() {
-		words, err := service.CheckSensitiveText(meta.CombineText)
-		if err != nil {
+		contains, words := service.CheckSensitiveText(meta.CombineText)
+		if contains {
 			logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
 			newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
 			return
@@ -139,7 +139,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 	}
 
 	defer func() {
-		if newAPIError != nil {
+		// Only return quota if downstream failed and quota was actually pre-consumed
+		if newAPIError != nil && preConsumedQuota != 0 {
 			service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
 		}
 	}()

+ 6 - 3
logger/logger.go

@@ -4,8 +4,6 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/bytedance/gopkg/util/gopool"
-	"github.com/gin-gonic/gin"
 	"io"
 	"log"
 	"one-api/common"
@@ -13,6 +11,9 @@ import (
 	"path/filepath"
 	"sync"
 	"time"
+
+	"github.com/bytedance/gopkg/util/gopool"
+	"github.com/gin-gonic/gin"
 )
 
 const (
@@ -29,6 +30,9 @@ var setupLogLock sync.Mutex
 var setupLogWorking bool
 
 func SetupLogger() {
+	defer func() {
+		setupLogWorking = false
+	}()
 	if *common.LogDir != "" {
 		ok := setupLogLock.TryLock()
 		if !ok {
@@ -37,7 +41,6 @@ func SetupLogger() {
 		}
 		defer func() {
 			setupLogLock.Unlock()
-			setupLogWorking = false
 		}()
 		logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
 		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)

+ 5 - 1
relay/channel/xunfei/relay-xunfei.go

@@ -206,6 +206,11 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
 	if err != nil || resp.StatusCode != 101 {
 		return nil, nil, err
 	}
+
+	defer func() {
+		conn.Close()
+	}()
+
 	data := requestOpenAI2Xunfei(textRequest, appId, domain)
 	err = conn.WriteJSON(data)
 	if err != nil {
@@ -229,7 +234,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
 			}
 			dataChan <- response
 			if response.Payload.Choices.Status == 2 {
-				err := conn.Close()
 				if err != nil {
 					common.SysLog("error closing websocket connection: " + err.Error())
 				}

+ 1 - 1
relay/claude_handler.go

@@ -24,7 +24,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 	textRequest, ok := info.Request.(*dto.ClaudeRequest)
 
 	if !ok {
-		common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request))
+		common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request))
 	}
 
 	err := helper.ModelMappedHelper(c, info, textRequest)

+ 18 - 20
relay/common/relay_info.go

@@ -87,26 +87,24 @@ type RelayInfo struct {
 	UsePrice               bool
 	RelayMode              int
 	OriginModelName        string
-	//RecodeModelName      string
-	RequestURLPath string
-	PromptTokens   int
-	//SupportStreamOptions  bool
-	ShouldIncludeUsage    bool
-	DisablePing           bool // 是否禁止向下游发送自定义 Ping
-	ClientWs              *websocket.Conn
-	TargetWs              *websocket.Conn
-	InputAudioFormat      string
-	OutputAudioFormat     string
-	RealtimeTools         []dto.RealTimeTool
-	IsFirstRequest        bool
-	AudioUsage            bool
-	ReasoningEffort       string
-	UserSetting           dto.UserSetting
-	UserEmail             string
-	UserQuota             int
-	RelayFormat           types.RelayFormat
-	SendResponseCount     int
-	FinalPreConsumedQuota int // 最终预消耗的配额
+	RequestURLPath         string
+	PromptTokens           int
+	ShouldIncludeUsage     bool
+	DisablePing            bool // 是否禁止向下游发送自定义 Ping
+	ClientWs               *websocket.Conn
+	TargetWs               *websocket.Conn
+	InputAudioFormat       string
+	OutputAudioFormat      string
+	RealtimeTools          []dto.RealTimeTool
+	IsFirstRequest         bool
+	AudioUsage             bool
+	ReasoningEffort        string
+	UserSetting            dto.UserSetting
+	UserEmail              string
+	UserQuota              int
+	RelayFormat            types.RelayFormat
+	SendResponseCount      int
+	FinalPreConsumedQuota  int // 最终预消耗的配额
 
 	PriceData types.PriceData
 

+ 1 - 1
relay/embedding_handler.go

@@ -21,7 +21,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 
 	embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
 	if !ok {
-		common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request))
+		common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request))
 	}
 
 	err := helper.ModelMappedHelper(c, info, embeddingRequest)

+ 1 - 1
relay/gemini_handler.go

@@ -55,7 +55,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 	request, ok := info.Request.(*dto.GeminiChatRequest)
 	if !ok {
-		common.FatalLog(fmt.Sprintf("invalid request type, expected dto.GeminiChatRequest, got %T", info.Request))
+		common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request))
 	}
 
 	// model mapped 模型映射

+ 3 - 0
relay/helper/common.go

@@ -122,6 +122,9 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
 }
 
 func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) {
+	if ws == nil {
+		return
+	}
 	errorObj := &dto.RealtimeEvent{
 		Type:    "error",
 		EventId: GetLocalRealtimeID(c),

+ 2 - 20
service/sensitive.go

@@ -2,7 +2,6 @@ package service
 
 import (
 	"errors"
-	"fmt"
 	"one-api/dto"
 	"one-api/setting"
 	"strings"
@@ -32,25 +31,8 @@ func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
 	return nil, nil
 }
 
-func CheckSensitiveText(text string) ([]string, error) {
-	if ok, words := SensitiveWordContains(text); ok {
-		return words, errors.New("sensitive words detected")
-	}
-	return nil, nil
-}
-
-func CheckSensitiveInput(input any) ([]string, error) {
-	switch v := input.(type) {
-	case string:
-		return CheckSensitiveText(v)
-	case []string:
-		var builder strings.Builder
-		for _, s := range v {
-			builder.WriteString(s)
-		}
-		return CheckSensitiveText(builder.String())
-	}
-	return CheckSensitiveText(fmt.Sprintf("%v", input))
+func CheckSensitiveText(text string) (bool, []string) {
+	return SensitiveWordContains(text)
 }
 
 // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表

+ 0 - 9
types/error.go

@@ -121,15 +121,6 @@ func (e *NewAPIError) MaskSensitiveError() string {
 		return string(e.errorCode)
 	}
 	errStr := e.Err.Error()
-	if e.StatusCode == http.StatusServiceUnavailable {
-		if e.errorCode == ErrorCodeModelNotFound {
-			errStr = "上游分组模型服务不可用,请稍后再试"
-		} else {
-			if strings.Contains(errStr, "分组") || strings.Contains(errStr, "渠道") {
-				errStr = "上游分组模型服务不可用,请稍后再试"
-			}
-		}
-	}
 	return common.MaskSensitiveInfo(errStr)
 }