Selaa lähdekoodia

Merge branch 'main-upstream' into fix/volcengine_default_baseurl

# Conflicts:
#	main.go
Seefs 3 kuukautta sitten
vanhempi
sitoutus
9a1ef8b957

+ 32 - 1
common/sys_log.go

@@ -2,9 +2,10 @@ package common
 
 import (
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"os"
 	"time"
+
+	"github.com/gin-gonic/gin"
 )
 
 func SysLog(s string) {
@@ -22,3 +23,33 @@ func FatalLog(v ...any) {
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
 	os.Exit(1)
 }
+
+func LogStartupSuccess(startTime time.Time, port string) {
+
+	duration := time.Since(startTime)
+	durationMs := duration.Milliseconds()
+
+	// Get network IPs
+	networkIps := GetNetworkIps()
+
+	// Print blank line for spacing
+	fmt.Fprintf(gin.DefaultWriter, "\n")
+
+	// Print the main success message
+	fmt.Fprintf(gin.DefaultWriter, "  \033[32m%s %s\033[0m  ready in %d ms\n", SystemName, Version, durationMs)
+	fmt.Fprintf(gin.DefaultWriter, "\n")
+
+	// Skip fancy startup message in container environments
+	if !IsRunningInContainer() {
+		// Print local URL
+		fmt.Fprintf(gin.DefaultWriter, "  ➜  \033[1mLocal:\033[0m   http://localhost:%s/\n", port)
+	}
+
+	// Print network URLs
+	for _, ip := range networkIps {
+		fmt.Fprintf(gin.DefaultWriter, "  ➜  \033[1mNetwork:\033[0m http://%s:%s/\n", ip, port)
+	}
+
+	// Print blank line for spacing
+	fmt.Fprintf(gin.DefaultWriter, "\n")
+}

+ 72 - 0
common/utils.go

@@ -68,6 +68,78 @@ func GetIp() (ip string) {
 	return
 }
 
+func GetNetworkIps() []string {
+	var networkIps []string
+	ips, err := net.InterfaceAddrs()
+	if err != nil {
+		log.Println(err)
+		return networkIps
+	}
+
+	for _, a := range ips {
+		if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
+			if ipNet.IP.To4() != nil {
+				ip := ipNet.IP.String()
+				// Include common private network ranges
+				if strings.HasPrefix(ip, "10.") ||
+					strings.HasPrefix(ip, "172.") ||
+					strings.HasPrefix(ip, "192.168.") {
+					networkIps = append(networkIps, ip)
+				}
+			}
+		}
+	}
+	return networkIps
+}
+
+// IsRunningInContainer detects if the application is running inside a container
+func IsRunningInContainer() bool {
+	// Method 1: Check for .dockerenv file (Docker containers)
+	if _, err := os.Stat("/.dockerenv"); err == nil {
+		return true
+	}
+
+	// Method 2: Check cgroup for container indicators
+	if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
+		content := string(data)
+		if strings.Contains(content, "docker") ||
+			strings.Contains(content, "containerd") ||
+			strings.Contains(content, "kubepods") ||
+			strings.Contains(content, "/lxc/") {
+			return true
+		}
+	}
+
+	// Method 3: Check environment variables commonly set by container runtimes
+	containerEnvVars := []string{
+		"KUBERNETES_SERVICE_HOST",
+		"DOCKER_CONTAINER",
+		"container",
+	}
+
+	for _, envVar := range containerEnvVars {
+		if os.Getenv(envVar) != "" {
+			return true
+		}
+	}
+
+	// Method 4: Check if init process is not the traditional init
+	if data, err := os.ReadFile("/proc/1/comm"); err == nil {
+		comm := strings.TrimSpace(string(data))
+		// In containers, process 1 is often not "init" or "systemd"
+		if comm != "init" && comm != "systemd" {
+			// Additional check: if it's a common container entrypoint
+			if strings.Contains(comm, "docker") ||
+				strings.Contains(comm, "containerd") ||
+				strings.Contains(comm, "runc") {
+				return true
+			}
+		}
+	}
+
+	return false
+}
+
 var sizeKB = 1024
 var sizeMB = sizeKB * 1024
 var sizeGB = sizeMB * 1024

+ 8 - 0
dto/channel_settings.go

@@ -19,4 +19,12 @@ const (
 type ChannelOtherSettings struct {
 	AzureResponsesVersion string        `json:"azure_responses_version,omitempty"`
 	VertexKeyType         VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
+	OpenRouterEnterprise  *bool         `json:"openrouter_enterprise,omitempty"`
+}
+
+func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {
+	if s == nil || s.OpenRouterEnterprise == nil {
+		return false
+	}
+	return *s.OpenRouterEnterprise
 }

+ 7 - 1
main.go

@@ -18,6 +18,7 @@ import (
 	"os"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/gin-contrib/sessions"
@@ -35,6 +36,7 @@ var buildFS embed.FS
 var indexPage []byte
 
 func main() {
+	startTime := time.Now()
 
 	err := InitResources()
 	if err != nil {
@@ -168,6 +170,10 @@ func main() {
 	if port == "" {
 		port = strconv.Itoa(*common.Port)
 	}
+
+	// Log startup success message
+	common.LogStartupSuccess(startTime, port)
+
 	err = server.Run(":" + port)
 	if err != nil {
 		common.FatalLog("failed to start HTTP server: " + err.Error())
@@ -222,4 +228,4 @@ func InitResources() error {
 		return err
 	}
 	return nil
-}
+}

+ 1 - 0
relay/channel/api_request.go

@@ -265,6 +265,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 
 	resp, err := client.Do(req)
 	if err != nil {
+		logger.LogError(c, "do request failed: "+err.Error())
 		return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
 	}
 	if resp == nil {

+ 18 - 34
relay/channel/ollama/adaptor.go

@@ -10,6 +10,7 @@ import (
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	"one-api/types"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 )
@@ -17,10 +18,7 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("not implemented") }
 
 func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
 	openaiAdaptor := openai.Adaptor{}
@@ -31,32 +29,21 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
 	openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
 		IncludeUsage: true,
 	}
-	return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
+	// map to ollama chat request (Claude -> OpenAI -> Ollama chat)
+	return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest))
 }
 
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") }
 
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return nil, errors.New("not implemented") }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	if info.RelayFormat == types.RelayFormatClaude {
-		return info.ChannelBaseUrl + "/v1/chat/completions", nil
-	}
-	switch info.RelayMode {
-	case relayconstant.RelayModeEmbeddings:
-		return info.ChannelBaseUrl + "/api/embed", nil
-	default:
-		return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
-	}
+    if info.RelayMode == relayconstant.RelayModeEmbeddings { return info.ChannelBaseUrl + "/api/embed", nil }
+    if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return info.ChannelBaseUrl + "/api/generate", nil }
+    return info.ChannelBaseUrl + "/api/chat", nil
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -66,10 +53,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
 }
 
 func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
-	if request == nil {
-		return nil, errors.New("request is nil")
+	if request == nil { return nil, errors.New("request is nil") }
+	// decide generate or chat
+	if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
+		return openAIToGenerate(c, request)
 	}
-	return requestOpenAI2Ollama(c, request)
+	return openAIChatToOllamaChat(c, request)
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -80,10 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return requestOpenAI2Embeddings(request), nil
 }
 
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
-	// TODO implement me
-	return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("not implemented") }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
@@ -92,15 +78,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	switch info.RelayMode {
 	case relayconstant.RelayModeEmbeddings:
-		usage, err = ollamaEmbeddingHandler(c, info, resp)
+		return ollamaEmbeddingHandler(c, info, resp)
 	default:
 		if info.IsStream {
-			usage, err = openai.OaiStreamHandler(c, info, resp)
-		} else {
-			usage, err = openai.OpenaiHandler(c, info, resp)
+			return ollamaStreamHandler(c, info, resp)
 		}
+		return ollamaChatHandler(c, info, resp)
 	}
-	return
 }
 
 func (a *Adaptor) GetModelList() []string {

+ 57 - 36
relay/channel/ollama/dto.go

@@ -2,48 +2,69 @@ package ollama
 
 import (
 	"encoding/json"
-	"one-api/dto"
 )
 
-type OllamaRequest struct {
-	Model            string                `json:"model,omitempty"`
-	Messages         []dto.Message         `json:"messages,omitempty"`
-	Stream           bool                  `json:"stream,omitempty"`
-	Temperature      *float64              `json:"temperature,omitempty"`
-	Seed             float64               `json:"seed,omitempty"`
-	Topp             float64               `json:"top_p,omitempty"`
-	TopK             int                   `json:"top_k,omitempty"`
-	Stop             any                   `json:"stop,omitempty"`
-	MaxTokens        uint                  `json:"max_tokens,omitempty"`
-	Tools            []dto.ToolCallRequest `json:"tools,omitempty"`
-	ResponseFormat   any                   `json:"response_format,omitempty"`
-	FrequencyPenalty float64               `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64               `json:"presence_penalty,omitempty"`
-	Suffix           any                   `json:"suffix,omitempty"`
-	StreamOptions    *dto.StreamOptions    `json:"stream_options,omitempty"`
-	Prompt           any                   `json:"prompt,omitempty"`
-	Think            json.RawMessage       `json:"think,omitempty"`
-}
-
-type Options struct {
-	Seed             int      `json:"seed,omitempty"`
-	Temperature      *float64 `json:"temperature,omitempty"`
-	TopK             int      `json:"top_k,omitempty"`
-	TopP             float64  `json:"top_p,omitempty"`
-	FrequencyPenalty float64  `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64  `json:"presence_penalty,omitempty"`
-	NumPredict       int      `json:"num_predict,omitempty"`
-	NumCtx           int      `json:"num_ctx,omitempty"`
+type OllamaChatMessage struct {
+	Role      string            `json:"role"`
+	Content   string            `json:"content,omitempty"`
+	Images    []string          `json:"images,omitempty"`
+	ToolCalls []OllamaToolCall  `json:"tool_calls,omitempty"`
+	ToolName  string            `json:"tool_name,omitempty"`
+	Thinking  json.RawMessage   `json:"thinking,omitempty"`
+}
+
+type OllamaToolFunction struct {
+	Name        string      `json:"name"`
+	Description string      `json:"description,omitempty"`
+	Parameters  interface{} `json:"parameters,omitempty"`
+}
+
+type OllamaTool struct {
+	Type     string            `json:"type"`
+	Function OllamaToolFunction `json:"function"`
+}
+
+type OllamaToolCall struct {
+	Function struct {
+		Name      string      `json:"name"`
+		Arguments interface{} `json:"arguments"`
+	} `json:"function"`
+}
+
+type OllamaChatRequest struct {
+	Model     string              `json:"model"`
+	Messages  []OllamaChatMessage `json:"messages"`
+	Tools     interface{}         `json:"tools,omitempty"`
+	Format    interface{}         `json:"format,omitempty"`
+	Stream    bool                `json:"stream,omitempty"`
+	Options   map[string]any      `json:"options,omitempty"`
+	KeepAlive interface{}         `json:"keep_alive,omitempty"`
+	Think     json.RawMessage     `json:"think,omitempty"`
+}
+
+type OllamaGenerateRequest struct {
+	Model     string         `json:"model"`
+	Prompt    string         `json:"prompt,omitempty"`
+	Suffix    string         `json:"suffix,omitempty"`
+	Images    []string       `json:"images,omitempty"`
+	Format    interface{}    `json:"format,omitempty"`
+	Stream    bool           `json:"stream,omitempty"`
+	Options   map[string]any `json:"options,omitempty"`
+	KeepAlive interface{}    `json:"keep_alive,omitempty"`
+	Think     json.RawMessage `json:"think,omitempty"`
 }
 
 type OllamaEmbeddingRequest struct {
-	Model   string   `json:"model,omitempty"`
-	Input   []string `json:"input"`
-	Options *Options `json:"options,omitempty"`
+	Model     string         `json:"model"`
+	Input     interface{}    `json:"input"`
+	Options   map[string]any `json:"options,omitempty"`
+	Dimensions int            `json:"dimensions,omitempty"`
 }
 
 type OllamaEmbeddingResponse struct {
-	Error     string      `json:"error,omitempty"`
-	Model     string      `json:"model"`
-	Embedding [][]float64 `json:"embeddings,omitempty"`
+	Error           string        `json:"error,omitempty"`
+	Model           string        `json:"model"`
+	Embeddings      [][]float64   `json:"embeddings"`
+	PromptEvalCount int           `json:"prompt_eval_count,omitempty"`
 }
+

+ 157 - 101
relay/channel/ollama/relay-ollama.go

@@ -1,6 +1,7 @@
 package ollama
 
 import (
+	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -14,121 +15,176 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
-	messages := make([]dto.Message, 0, len(request.Messages))
-	for _, message := range request.Messages {
-		if !message.IsStringContent() {
-			mediaMessages := message.ParseContent()
-			for j, mediaMessage := range mediaMessages {
-				if mediaMessage.Type == dto.ContentTypeImageURL {
-					imageUrl := mediaMessage.GetImageMedia()
-					// check if not base64
-					if strings.HasPrefix(imageUrl.Url, "http") {
-						fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
-						if err != nil {
-							return nil, err
+func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
+	chatReq := &OllamaChatRequest{
+		Model:   r.Model,
+		Stream:  r.Stream,
+		Options: map[string]any{},
+		Think:   r.Think,
+	}
+	if r.ResponseFormat != nil {
+		if r.ResponseFormat.Type == "json" {
+			chatReq.Format = "json"
+		} else if r.ResponseFormat.Type == "json_schema" {
+			if len(r.ResponseFormat.JsonSchema) > 0 {
+				var schema any
+				_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
+				chatReq.Format = schema
+			}
+		}
+	}
+
+	// options mapping
+	if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature }
+	if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP }
+	if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK }
+	if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty }
+	if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty }
+	if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) }
+	if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) }
+
+	if r.Stop != nil {
+		switch v := r.Stop.(type) {
+		case string:
+			chatReq.Options["stop"] = []string{v}
+		case []string:
+			chatReq.Options["stop"] = v
+		case []any:
+			arr := make([]string,0,len(v))
+			for _, i := range v { if s,ok:=i.(string); ok { arr = append(arr,s) } }
+			if len(arr)>0 { chatReq.Options["stop"] = arr }
+		}
+	}
+
+	if len(r.Tools) > 0 {
+		tools := make([]OllamaTool,0,len(r.Tools))
+		for _, t := range r.Tools {
+			tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}})
+		}
+		chatReq.Tools = tools
+	}
+
+	chatReq.Messages = make([]OllamaChatMessage,0,len(r.Messages))
+	for _, m := range r.Messages {
+		var textBuilder strings.Builder
+		var images []string
+		if m.IsStringContent() {
+			textBuilder.WriteString(m.StringContent())
+		} else {
+			parts := m.ParseContent()
+			for _, part := range parts {
+				if part.Type == dto.ContentTypeImageURL {
+					img := part.GetImageMedia()
+					if img != nil && img.Url != "" {
+						var base64Data string
+						if strings.HasPrefix(img.Url, "http") {
+							fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
+							if err != nil { return nil, err }
+							base64Data = fileData.Base64Data
+						} else if strings.HasPrefix(img.Url, "data:") {
+							if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { base64Data = img.Url[idx+1:] }
+						} else {
+							base64Data = img.Url
 						}
-						imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
+						if base64Data != "" { images = append(images, base64Data) }
 					}
-					mediaMessage.ImageUrl = imageUrl
-					mediaMessages[j] = mediaMessage
+				} else if part.Type == dto.ContentTypeText {
+					textBuilder.WriteString(part.Text)
+				}
+			}
+		}
+		cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
+		if len(images)>0 { cm.Images = images }
+		if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name }
+		if m.ToolCalls != nil && len(m.ToolCalls) > 0 {
+			parsed := m.ParseToolCalls()
+			if len(parsed) > 0 {
+				calls := make([]OllamaToolCall,0,len(parsed))
+				for _, tc := range parsed {
+					var args interface{}
+					if tc.Function.Arguments != "" { _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) }
+					if args==nil { args = map[string]any{} }
+					oc := OllamaToolCall{}
+					oc.Function.Name = tc.Function.Name
+					oc.Function.Arguments = args
+					calls = append(calls, oc)
 				}
+				cm.ToolCalls = calls
 			}
-			message.SetMediaContent(mediaMessages)
 		}
-		messages = append(messages, dto.Message{
-			Role:       message.Role,
-			Content:    message.Content,
-			ToolCalls:  message.ToolCalls,
-			ToolCallId: message.ToolCallId,
-		})
+		chatReq.Messages = append(chatReq.Messages, cm)
 	}
-	str, ok := request.Stop.(string)
-	var Stop []string
-	if ok {
-		Stop = []string{str}
-	} else {
-		Stop, _ = request.Stop.([]string)
+	return chatReq, nil
+}
+
+// openAIToGenerate converts OpenAI completions request to Ollama generate
+func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) {
+	gen := &OllamaGenerateRequest{
+		Model:   r.Model,
+		Stream:  r.Stream,
+		Options: map[string]any{},
+		Think:   r.Think,
+	}
+	// Prompt may be in r.Prompt (string or []any)
+	if r.Prompt != nil {
+		switch v := r.Prompt.(type) {
+		case string:
+			gen.Prompt = v
+		case []any:
+			var sb strings.Builder
+			for _, it := range v { if s,ok:=it.(string); ok { sb.WriteString(s) } }
+			gen.Prompt = sb.String()
+		default:
+			gen.Prompt = fmt.Sprintf("%v", r.Prompt)
+		}
 	}
-	ollamaRequest := &OllamaRequest{
-		Model:            request.Model,
-		Messages:         messages,
-		Stream:           request.Stream,
-		Temperature:      request.Temperature,
-		Seed:             request.Seed,
-		Topp:             request.TopP,
-		TopK:             request.TopK,
-		Stop:             Stop,
-		Tools:            request.Tools,
-		MaxTokens:        request.GetMaxTokens(),
-		ResponseFormat:   request.ResponseFormat,
-		FrequencyPenalty: request.FrequencyPenalty,
-		PresencePenalty:  request.PresencePenalty,
-		Prompt:           request.Prompt,
-		StreamOptions:    request.StreamOptions,
-		Suffix:           request.Suffix,
+	if r.Suffix != nil { if s,ok:=r.Suffix.(string); ok { gen.Suffix = s } }
+	if r.ResponseFormat != nil {
+		if r.ResponseFormat.Type == "json" { gen.Format = "json" } else if r.ResponseFormat.Type == "json_schema" { var schema any; _ = json.Unmarshal(r.ResponseFormat.JsonSchema,&schema); gen.Format=schema }
 	}
-	ollamaRequest.Think = request.Think
-	return ollamaRequest, nil
+	if r.Temperature != nil { gen.Options["temperature"] = r.Temperature }
+	if r.TopP != 0 { gen.Options["top_p"] = r.TopP }
+	if r.TopK != 0 { gen.Options["top_k"] = r.TopK }
+	if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty }
+	if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty }
+	if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) }
+	if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) }
+	if r.Stop != nil {
+		switch v := r.Stop.(type) {
+		case string: gen.Options["stop"] = []string{v}
+		case []string: gen.Options["stop"] = v
+		case []any: arr:=make([]string,0,len(v)); for _,i:= range v { if s,ok:=i.(string); ok { arr=append(arr,s) } }; if len(arr)>0 { gen.Options["stop"]=arr }
+		}
+	}
+	return gen, nil
 }
 
-func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
-	return &OllamaEmbeddingRequest{
-		Model: request.Model,
-		Input: request.ParseInput(),
-		Options: &Options{
-			Seed:             int(request.Seed),
-			Temperature:      request.Temperature,
-			TopP:             request.TopP,
-			FrequencyPenalty: request.FrequencyPenalty,
-			PresencePenalty:  request.PresencePenalty,
-		},
-	}
+func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest {
+	opts := map[string]any{}
+	if r.Temperature != nil { opts["temperature"] = r.Temperature }
+	if r.TopP != 0 { opts["top_p"] = r.TopP }
+	if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty }
+	if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty }
+	if r.Seed != 0 { opts["seed"] = int(r.Seed) }
+	if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions }
+	input := r.ParseInput()
+	if len(input)==1 { return &OllamaEmbeddingRequest{Model:r.Model, Input: input[0], Options: opts, Dimensions:r.Dimensions} }
+	return &OllamaEmbeddingRequest{Model:r.Model, Input: input, Options: opts, Dimensions:r.Dimensions}
 }
 
 func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
-	var ollamaEmbeddingResponse OllamaEmbeddingResponse
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
-	}
+	var oResp OllamaEmbeddingResponse
+	body, err := io.ReadAll(resp.Body)
+	if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
 	service.CloseResponseBodyGracefully(resp)
-	err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
-	if err != nil {
-		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
-	}
-	if ollamaEmbeddingResponse.Error != "" {
-		return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
-	}
-	flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
-	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
-	data = append(data, dto.OpenAIEmbeddingResponseItem{
-		Embedding: flattenedEmbeddings,
-		Object:    "embedding",
-	})
-	usage := &dto.Usage{
-		TotalTokens:      info.PromptTokens,
-		CompletionTokens: 0,
-		PromptTokens:     info.PromptTokens,
-	}
-	embeddingResponse := &dto.OpenAIEmbeddingResponse{
-		Object: "list",
-		Data:   data,
-		Model:  info.UpstreamModelName,
-		Usage:  *usage,
-	}
-	doResponseBody, err := common.Marshal(embeddingResponse)
-	if err != nil {
-		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
-	}
-	service.IOCopyBytesGracefully(c, resp, doResponseBody)
+	if err = common.Unmarshal(body, &oResp); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+	if oResp.Error != "" { return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+	data := make([]dto.OpenAIEmbeddingResponseItem,0,len(oResp.Embeddings))
+	for i, emb := range oResp.Embeddings { data = append(data, dto.OpenAIEmbeddingResponseItem{Index:i,Object:"embedding",Embedding:emb}) }
+	usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens:0, TotalTokens: oResp.PromptEvalCount}
+	embResp := &dto.OpenAIEmbeddingResponse{Object:"list", Data:data, Model: info.UpstreamModelName, Usage:*usage}
+	out, _ := common.Marshal(embResp)
+	service.IOCopyBytesGracefully(c, resp, out)
 	return usage, nil
 }
 
-func flattenEmbeddings(embeddings [][]float64) []float64 {
-	flattened := []float64{}
-	for _, row := range embeddings {
-		flattened = append(flattened, row...)
-	}
-	return flattened
-}

+ 210 - 0
relay/channel/ollama/stream.go

@@ -0,0 +1,210 @@
+package ollama
+
+import (
+    "bufio"
+    "encoding/json"
+    "fmt"
+    "io"
+    "net/http"
+    "one-api/common"
+    "one-api/dto"
+    "one-api/logger"
+    relaycommon "one-api/relay/common"
+    "one-api/relay/helper"
+    "one-api/service"
+    "one-api/types"
+    "strings"
+    "time"
+
+    "github.com/gin-gonic/gin"
+)
+
+type ollamaChatStreamChunk struct {
+    Model            string `json:"model"`
+    CreatedAt        string `json:"created_at"`
+    // chat
+    Message *struct {
+        Role      string `json:"role"`
+        Content   string `json:"content"`
+        Thinking  json.RawMessage `json:"thinking"`
+        ToolCalls []struct {
+            Function struct {
+                Name      string      `json:"name"`
+                Arguments interface{} `json:"arguments"`
+            } `json:"function"`
+        } `json:"tool_calls"`
+    } `json:"message"`
+    // generate
+    Response string `json:"response"`
+    Done         bool    `json:"done"`
+    DoneReason   string  `json:"done_reason"`
+    TotalDuration int64  `json:"total_duration"`
+    LoadDuration  int64  `json:"load_duration"`
+    PromptEvalCount int  `json:"prompt_eval_count"`
+    EvalCount       int  `json:"eval_count"`
+    PromptEvalDuration int64 `json:"prompt_eval_duration"`
+    EvalDuration       int64 `json:"eval_duration"`
+}
+
+func toUnix(ts string) int64 {
+    if ts == "" { return time.Now().Unix() }
+    // try time.RFC3339 or with nanoseconds
+    t, err := time.Parse(time.RFC3339Nano, ts)
+    if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() }
+    return t.Unix()
+}
+
+func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+    if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) }
+    defer service.CloseResponseBodyGracefully(resp)
+
+    helper.SetEventStreamHeaders(c)
+    scanner := bufio.NewScanner(resp.Body)
+    usage := &dto.Usage{}
+    var model = info.UpstreamModelName
+    var responseId = common.GetUUID()
+    var created = time.Now().Unix()
+    var toolCallIndex int
+    start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
+    if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
+
+    for scanner.Scan() {
+        line := scanner.Text()
+        line = strings.TrimSpace(line)
+        if line == "" { continue }
+        var chunk ollamaChatStreamChunk
+        if err := json.Unmarshal([]byte(line), &chunk); err != nil {
+            logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
+            return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+        }
+        if chunk.Model != "" { model = chunk.Model }
+        created = toUnix(chunk.CreatedAt)
+
+        if !chunk.Done {
+            // delta content
+            var content string
+            if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response }
+            delta := dto.ChatCompletionsStreamResponse{
+                Id:      responseId,
+                Object:  "chat.completion.chunk",
+                Created: created,
+                Model:   model,
+                Choices: []dto.ChatCompletionsStreamResponseChoice{ {
+                    Index: 0,
+                    Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" },
+                } },
+            }
+            if content != "" { delta.Choices[0].Delta.SetContentString(content) }
+            if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
+                raw := strings.TrimSpace(string(chunk.Message.Thinking))
+                if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) }
+            }
+            // tool calls
+            if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
+                delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls))
+                for _, tc := range chunk.Message.ToolCalls {
+                    // arguments -> string
+                    argBytes, _ := json.Marshal(tc.Function.Arguments)
+                    toolId := fmt.Sprintf("call_%d", toolCallIndex)
+                    tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
+                    tr.SetIndex(toolCallIndex)
+                    toolCallIndex++
+                    delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
+                }
+            }
+            if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) }
+            continue
+        }
+        // done frame
+        // finalize once and break loop
+        usage.PromptTokens = chunk.PromptEvalCount
+        usage.CompletionTokens = chunk.EvalCount
+        usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+    finishReason := chunk.DoneReason
+    if finishReason == "" { finishReason = "stop" }
+        // emit stop delta
+        if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
+            if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
+        }
+        // emit usage frame
+        if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
+            if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) }
+        }
+        // send [DONE]
+        helper.Done(c)
+        break
+    }
+    if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) }
+    return usage, nil
+}
+
+// non-stream handler for chat/generate
+func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+    body, err := io.ReadAll(resp.Body)
+    if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) }
+    service.CloseResponseBodyGracefully(resp)
+    raw := string(body)
+    if common.DebugEnabled { println("ollama non-stream raw resp:", raw) }
+
+    lines := strings.Split(raw, "\n")
+    var (
+        aggContent strings.Builder
+        reasoningBuilder strings.Builder
+        lastChunk ollamaChatStreamChunk
+        parsedAny bool
+    )
+    for _, ln := range lines {
+        ln = strings.TrimSpace(ln)
+        if ln == "" { continue }
+        var ck ollamaChatStreamChunk
+        if err := json.Unmarshal([]byte(ln), &ck); err != nil {
+            if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+            continue
+        }
+        parsedAny = true
+        lastChunk = ck
+        if ck.Message != nil && len(ck.Message.Thinking) > 0 {
+            raw := strings.TrimSpace(string(ck.Message.Thinking))
+            if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) }
+        }
+        if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
+    }
+
+    if !parsedAny {
+        var single ollamaChatStreamChunk
+        if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+        lastChunk = single
+        if single.Message != nil {
+            if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } }
+            aggContent.WriteString(single.Message.Content)
+        } else { aggContent.WriteString(single.Response) }
+    }
+
+    model := lastChunk.Model
+    if model == "" { model = info.UpstreamModelName }
+    created := toUnix(lastChunk.CreatedAt)
+    usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
+    content := aggContent.String()
+    finishReason := lastChunk.DoneReason
+    if finishReason == "" { finishReason = "stop" }
+
+    msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
+    if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc }
+    full := dto.OpenAITextResponse{
+        Id:      common.GetUUID(),
+        Model:   model,
+        Object:  "chat.completion",
+        Created: created,
+        Choices: []dto.OpenAITextResponseChoice{ {
+            Index: 0,
+            Message: msg,
+            FinishReason: finishReason,
+        } },
+        Usage: *usage,
+    }
+    out, _ := common.Marshal(full)
+    service.IOCopyBytesGracefully(c, resp, out)
+    return usage, nil
+}
+
+func contentPtr(s string) *string { if s=="" { return nil }; return &s }

+ 18 - 0
relay/channel/openai/relay-openai.go

@@ -12,6 +12,7 @@ import (
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/logger"
+	"one-api/relay/channel/openrouter"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/service"
@@ -185,10 +186,27 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 	if common.DebugEnabled {
 		println("upstream response body:", string(responseBody))
 	}
+	// Unmarshal to simpleResponse
+	if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
+		// 尝试解析为 openrouter enterprise
+		var enterpriseResponse openrouter.OpenRouterEnterpriseResponse
+		err = common.Unmarshal(responseBody, &enterpriseResponse)
+		if err != nil {
+			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+		}
+		if enterpriseResponse.Success {
+			responseBody = enterpriseResponse.Data
+		} else {
+			logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data))
+			return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+		}
+	}
+
 	err = common.Unmarshal(responseBody, &simpleResponse)
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
+
 	if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
 		return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
 	}

+ 7 - 0
relay/channel/openrouter/dto.go

@@ -1,5 +1,7 @@
 package openrouter
 
+import "encoding/json"
+
 type RequestReasoning struct {
 	// One of the following (not both):
 	Effort    string `json:"effort,omitempty"`     // Can be "high", "medium", or "low" (OpenAI-style)
@@ -7,3 +9,8 @@ type RequestReasoning struct {
 	// Optional: Default is false. All models support this.
 	Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
 }
+
+type OpenRouterEnterpriseResponse struct {
+	Data    json.RawMessage `json:"data"`
+	Success bool            `json:"success"`
+}

+ 13 - 6
relay/channel/volcengine/adaptor.go

@@ -9,6 +9,7 @@ import (
 	"mime/multipart"
 	"net/http"
 	"net/textproto"
+	channelconstant "one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
@@ -188,20 +189,26 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	// 支持自定义域名,如果未设置则使用默认域名
+	baseUrl := info.ChannelBaseUrl
+	if baseUrl == "" {
+		baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
+	}
+
 	switch info.RelayMode {
 	case constant.RelayModeChatCompletions:
 		if strings.HasPrefix(info.UpstreamModelName, "bot") {
-			return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
+			return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
 		}
-		return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
+		return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
 	case constant.RelayModeEmbeddings:
-		return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
+		return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
 	case constant.RelayModeImagesGenerations:
-		return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
+		return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
 	case constant.RelayModeImagesEdits:
-		return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
+		return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
 	case constant.RelayModeRerank:
-		return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
+		return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
 	default:
 	}
 	return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)

+ 5 - 0
relay/channel/volcengine/constants.go

@@ -9,6 +9,11 @@ var ModelList = []string{
 	"Doubao-lite-4k",
 	"Doubao-embedding",
 	"doubao-seedream-4-0-250828",
+	"seedream-4-0-250828",
+	"doubao-seedance-1-0-pro-250528",
+	"seedance-1-0-pro-250528",
+	"doubao-seed-1-6-thinking-250715",
+	"seed-1-6-thinking-250715",
 }
 
 var ChannelName = "volcengine"

+ 3 - 4
relay/channel/xunfei/relay-xunfei.go

@@ -207,10 +207,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
 		return nil, nil, err
 	}
 
-	defer func() {
-		conn.Close()
-	}()
-
 	data := requestOpenAI2Xunfei(textRequest, appId, domain)
 	err = conn.WriteJSON(data)
 	if err != nil {
@@ -220,6 +216,9 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
 	dataChan := make(chan XunfeiChatResponse)
 	stopChan := make(chan bool)
 	go func() {
+		defer func() {
+			conn.Close()
+		}()
 		for {
 			_, msg, err := conn.ReadMessage()
 			if err != nil {

+ 75 - 1
web/src/components/table/channels/modals/EditChannelModal.jsx

@@ -164,6 +164,8 @@ const EditChannelModal = (props) => {
     settings: '',
     // 仅 Vertex: 密钥格式(存入 settings.vertex_key_type)
     vertex_key_type: 'json',
+    // 企业账户设置
+    is_enterprise_account: false,
   };
   const [batch, setBatch] = useState(false);
   const [multiToSingle, setMultiToSingle] = useState(false);
@@ -189,6 +191,7 @@ const EditChannelModal = (props) => {
   const [channelSearchValue, setChannelSearchValue] = useState('');
   const [useManualInput, setUseManualInput] = useState(false); // 是否使用手动输入模式
   const [keyMode, setKeyMode] = useState('append'); // 密钥模式:replace(覆盖)或 append(追加)
+  const [isEnterpriseAccount, setIsEnterpriseAccount] = useState(false); // 是否为企业账户
 
   // 2FA验证查看密钥相关状态
   const [twoFAState, setTwoFAState] = useState({
@@ -235,7 +238,7 @@ const EditChannelModal = (props) => {
     pass_through_body_enabled: false,
     system_prompt: '',
   });
-  const showApiConfigCard = inputs.type !== 45; // 控制是否显示 API 配置卡片(仅当渠道类型不是 豆包 时显示)
+  const showApiConfigCard = true; // 控制是否显示 API 配置卡片
   const getInitValues = () => ({ ...originInputs });
 
   // 处理渠道额外设置的更新
@@ -342,6 +345,10 @@ const EditChannelModal = (props) => {
         case 36:
           localModels = ['suno_music', 'suno_lyrics'];
           break;
+        case 45:
+          localModels = getChannelModels(value);
+          setInputs((prevInputs) => ({ ...prevInputs, base_url: 'https://ark.cn-beijing.volces.com' }));
+          break;
         default:
           localModels = getChannelModels(value);
           break;
@@ -433,15 +440,19 @@ const EditChannelModal = (props) => {
             parsedSettings.azure_responses_version || '';
           // 读取 Vertex 密钥格式
           data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
+          // 读取企业账户设置
+          data.is_enterprise_account = parsedSettings.openrouter_enterprise === true;
         } catch (error) {
           console.error('解析其他设置失败:', error);
           data.azure_responses_version = '';
           data.region = '';
           data.vertex_key_type = 'json';
+          data.is_enterprise_account = false;
         }
       } else {
         // 兼容历史数据:老渠道没有 settings 时,默认按 json 展示
         data.vertex_key_type = 'json';
+        data.is_enterprise_account = false;
       }
 
       setInputs(data);
@@ -453,6 +464,8 @@ const EditChannelModal = (props) => {
       } else {
         setAutoBan(true);
       }
+      // 同步企业账户状态
+      setIsEnterpriseAccount(data.is_enterprise_account || false);
       setBasicModels(getChannelModels(data.type));
       // 同步更新channelSettings状态显示
       setChannelSettings({
@@ -712,6 +725,8 @@ const EditChannelModal = (props) => {
     });
     // 重置密钥模式状态
     setKeyMode('append');
+    // 重置企业账户状态
+    setIsEnterpriseAccount(false);
     // 清空表单中的key_mode字段
     if (formApiRef.current) {
       formApiRef.current.setValue('key_mode', undefined);
@@ -844,6 +859,10 @@ const EditChannelModal = (props) => {
       showInfo(t('请至少选择一个模型!'));
       return;
     }
+    if (localInputs.type === 45 && (!localInputs.base_url || localInputs.base_url.trim() === '')) {
+      showInfo(t('请输入API地址!'));
+      return;
+    }
     if (
       localInputs.model_mapping &&
       localInputs.model_mapping !== '' &&
@@ -873,6 +892,21 @@ const EditChannelModal = (props) => {
     };
     localInputs.setting = JSON.stringify(channelExtraSettings);
 
+    // 处理type === 20的企业账户设置
+    if (localInputs.type === 20) {
+      let settings = {};
+      if (localInputs.settings) {
+        try {
+          settings = JSON.parse(localInputs.settings);
+        } catch (error) {
+          console.error('解析settings失败:', error);
+        }
+      }
+      // 设置企业账户标识,无论是true还是false都要传到后端
+      settings.openrouter_enterprise = localInputs.is_enterprise_account === true;
+      localInputs.settings = JSON.stringify(settings);
+    }
+
     // 清理不需要发送到后端的字段
     delete localInputs.force_format;
     delete localInputs.thinking_to_content;
@@ -880,6 +914,7 @@ const EditChannelModal = (props) => {
     delete localInputs.pass_through_body_enabled;
     delete localInputs.system_prompt;
     delete localInputs.system_prompt_override;
+    delete localInputs.is_enterprise_account;
     // 顶层的 vertex_key_type 不应发送给后端
     delete localInputs.vertex_key_type;
 
@@ -1264,6 +1299,21 @@ const EditChannelModal = (props) => {
                     onChange={(value) => handleInputChange('type', value)}
                   />
 
+                  {inputs.type === 20 && (
+                    <Form.Switch
+                      field='is_enterprise_account'
+                      label={t('是否为企业账户')}
+                      checkedText={t('是')}
+                      uncheckedText={t('否')}
+                      onChange={(value) => {
+                        setIsEnterpriseAccount(value);
+                        handleInputChange('is_enterprise_account', value);
+                      }}
+                      extraText={t('企业账户为特殊返回格式,需要特殊处理,如果非企业账户,请勿勾选')}
+                      initValue={inputs.is_enterprise_account}
+                    />
+                  )}
+
                   <Form.Input
                     field='name'
                     label={t('名称')}
@@ -1883,6 +1933,30 @@ const EditChannelModal = (props) => {
                         />
                       </div>
                     )}
+
+                    {inputs.type === 45 && (
+                        <div>
+                          <Form.Select
+                              field='base_url'
+                              label={t('API地址')}
+                              placeholder={t('请选择API地址')}
+                              onChange={(value) =>
+                                  handleInputChange('base_url', value)
+                              }
+                              optionList={[
+                                {
+                                  value: 'https://ark.cn-beijing.volces.com',
+                                  label: 'https://ark.cn-beijing.volces.com'
+                                },
+                                {
+                                  value: 'https://ark.ap-southeast.bytepluses.com',
+                                  label: 'https://ark.ap-southeast.bytepluses.com'
+                                }
+                              ]}
+                              defaultValue='https://ark.cn-beijing.volces.com'
+                          />
+                        </div>
+                    )}
                   </Card>
                 )}
 

+ 195 - 162
web/src/hooks/channels/useChannelsData.jsx

@@ -25,13 +25,9 @@ import {
   showInfo,
   showSuccess,
   loadChannelModels,
-  copy,
+  copy
 } from '../../helpers';
-import {
-  CHANNEL_OPTIONS,
-  ITEMS_PER_PAGE,
-  MODEL_TABLE_PAGE_SIZE,
-} from '../../constants';
+import { CHANNEL_OPTIONS, ITEMS_PER_PAGE, MODEL_TABLE_PAGE_SIZE } from '../../constants';
 import { useIsMobile } from '../common/useIsMobile';
 import { useTableCompactMode } from '../common/useTableCompactMode';
 import { Modal } from '@douyinfe/semi-ui';
@@ -68,7 +64,7 @@ export const useChannelsData = () => {
 
   // Status filter
   const [statusFilter, setStatusFilter] = useState(
-    localStorage.getItem('channel-status-filter') || 'all',
+    localStorage.getItem('channel-status-filter') || 'all'
   );
 
   // Type tabs states
@@ -83,9 +79,10 @@ export const useChannelsData = () => {
   const [testingModels, setTestingModels] = useState(new Set());
   const [selectedModelKeys, setSelectedModelKeys] = useState([]);
   const [isBatchTesting, setIsBatchTesting] = useState(false);
-  const [testQueue, setTestQueue] = useState([]);
-  const [isProcessingQueue, setIsProcessingQueue] = useState(false);
   const [modelTablePage, setModelTablePage] = useState(1);
+  
+  // 使用 ref 来避免闭包问题,类似旧版实现
+  const shouldStopBatchTestingRef = useRef(false);
 
   // Multi-key management states
   const [showMultiKeyManageModal, setShowMultiKeyManageModal] = useState(false);
@@ -119,12 +116,9 @@ export const useChannelsData = () => {
   // Initialize from localStorage
   useEffect(() => {
     const localIdSort = localStorage.getItem('id-sort') === 'true';
-    const localPageSize =
-      parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE;
-    const localEnableTagMode =
-      localStorage.getItem('enable-tag-mode') === 'true';
-    const localEnableBatchDelete =
-      localStorage.getItem('enable-batch-delete') === 'true';
+    const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE;
+    const localEnableTagMode = localStorage.getItem('enable-tag-mode') === 'true';
+    const localEnableBatchDelete = localStorage.getItem('enable-batch-delete') === 'true';
 
     setIdSort(localIdSort);
     setPageSize(localPageSize);
@@ -182,10 +176,7 @@ export const useChannelsData = () => {
   // Save column preferences
   useEffect(() => {
     if (Object.keys(visibleColumns).length > 0) {
-      localStorage.setItem(
-        'channels-table-columns',
-        JSON.stringify(visibleColumns),
-      );
+      localStorage.setItem('channels-table-columns', JSON.stringify(visibleColumns));
     }
   }, [visibleColumns]);
 
@@ -299,21 +290,14 @@ export const useChannelsData = () => {
     const { searchKeyword, searchGroup, searchModel } = getFormValues();
     if (searchKeyword !== '' || searchGroup !== '' || searchModel !== '') {
       setLoading(true);
-      await searchChannels(
-        enableTagMode,
-        typeKey,
-        statusF,
-        page,
-        pageSize,
-        idSort,
-      );
+      await searchChannels(enableTagMode, typeKey, statusF, page, pageSize, idSort);
       setLoading(false);
       return;
     }
 
     const reqId = ++requestCounter.current;
     setLoading(true);
-    const typeParam = typeKey !== 'all' ? `&type=${typeKey}` : '';
+    const typeParam = (typeKey !== 'all') ? `&type=${typeKey}` : '';
     const statusParam = statusF !== 'all' ? `&status=${statusF}` : '';
     const res = await API.get(
       `/api/channel/?p=${page}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}${typeParam}${statusParam}`,
@@ -327,10 +311,7 @@ export const useChannelsData = () => {
     if (success) {
       const { items, total, type_counts } = data;
       if (type_counts) {
-        const sumAll = Object.values(type_counts).reduce(
-          (acc, v) => acc + v,
-          0,
-        );
+        const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0);
         setTypeCounts({ ...type_counts, all: sumAll });
       }
       setChannelFormat(items, enableTagMode);
@@ -354,18 +335,11 @@ export const useChannelsData = () => {
     setSearching(true);
     try {
       if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
-        await loadChannels(
-          page,
-          pageSz,
-          sortFlag,
-          enableTagMode,
-          typeKey,
-          statusF,
-        );
+        await loadChannels(page, pageSz, sortFlag, enableTagMode, typeKey, statusF);
         return;
       }
 
-      const typeParam = typeKey !== 'all' ? `&type=${typeKey}` : '';
+      const typeParam = (typeKey !== 'all') ? `&type=${typeKey}` : '';
       const statusParam = statusF !== 'all' ? `&status=${statusF}` : '';
       const res = await API.get(
         `/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}&id_sort=${sortFlag}&tag_mode=${enableTagMode}&p=${page}&page_size=${pageSz}${typeParam}${statusParam}`,
@@ -373,10 +347,7 @@ export const useChannelsData = () => {
       const { success, message, data } = res.data;
       if (success) {
         const { items = [], total = 0, type_counts = {} } = data;
-        const sumAll = Object.values(type_counts).reduce(
-          (acc, v) => acc + v,
-          0,
-        );
+        const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0);
         setTypeCounts({ ...type_counts, all: sumAll });
         setChannelFormat(items, enableTagMode);
         setChannelCount(total);
@@ -395,14 +366,7 @@ export const useChannelsData = () => {
     if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
       await loadChannels(page, pageSize, idSort, enableTagMode);
     } else {
-      await searchChannels(
-        enableTagMode,
-        activeTypeKey,
-        statusFilter,
-        page,
-        pageSize,
-        idSort,
-      );
+      await searchChannels(enableTagMode, activeTypeKey, statusFilter, page, pageSize, idSort);
     }
   };
 
@@ -488,16 +452,9 @@ export const useChannelsData = () => {
     const { searchKeyword, searchGroup, searchModel } = getFormValues();
     setActivePage(page);
     if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
-      loadChannels(page, pageSize, idSort, enableTagMode).then(() => {});
+      loadChannels(page, pageSize, idSort, enableTagMode).then(() => { });
     } else {
-      searchChannels(
-        enableTagMode,
-        activeTypeKey,
-        statusFilter,
-        page,
-        pageSize,
-        idSort,
-      );
+      searchChannels(enableTagMode, activeTypeKey, statusFilter, page, pageSize, idSort);
     }
   };
 
@@ -513,14 +470,7 @@ export const useChannelsData = () => {
           showError(reason);
         });
     } else {
-      searchChannels(
-        enableTagMode,
-        activeTypeKey,
-        statusFilter,
-        1,
-        size,
-        idSort,
-      );
+      searchChannels(enableTagMode, activeTypeKey, statusFilter, 1, size, idSort);
     }
   };
 
@@ -551,10 +501,7 @@ export const useChannelsData = () => {
         showError(res?.data?.message || t('渠道复制失败'));
       }
     } catch (error) {
-      showError(
-        t('渠道复制失败: ') +
-          (error?.response?.data?.message || error?.message || error),
-      );
+      showError(t('渠道复制失败: ') + (error?.response?.data?.message || error?.message || error));
     }
   };
 
@@ -593,11 +540,7 @@ export const useChannelsData = () => {
         data.priority = parseInt(data.priority);
         break;
       case 'weight':
-        if (
-          data.weight === undefined ||
-          data.weight < 0 ||
-          data.weight === ''
-        ) {
+        if (data.weight === undefined || data.weight < 0 || data.weight === '') {
           showInfo('权重必须是非负整数!');
           return;
         }
@@ -740,136 +683,226 @@ export const useChannelsData = () => {
     const res = await API.post(`/api/channel/fix`);
     const { success, message, data } = res.data;
     if (success) {
-      showSuccess(
-        t('已修复 ${success} 个通道,失败 ${fails} 个通道。')
-          .replace('${success}', data.success)
-          .replace('${fails}', data.fails),
-      );
+      showSuccess(t('已修复 ${success} 个通道,失败 ${fails} 个通道。').replace('${success}', data.success).replace('${fails}', data.fails));
       await refresh();
     } else {
       showError(message);
     }
   };
 
-  // Test channel
+  // Test channel - 单个模型测试,参考旧版实现
   const testChannel = async (record, model) => {
-    setTestQueue((prev) => [...prev, { channel: record, model }]);
-    if (!isProcessingQueue) {
-      setIsProcessingQueue(true);
+    const testKey = `${record.id}-${model}`;
+
+    // 检查是否应该停止批量测试
+    if (shouldStopBatchTestingRef.current && isBatchTesting) {
+      return Promise.resolve();
     }
-  };
 
-  // Process test queue
-  const processTestQueue = async () => {
-    if (!isProcessingQueue || testQueue.length === 0) return;
+    // 添加到正在测试的模型集合
+    setTestingModels(prev => new Set([...prev, model]));
 
-    const { channel, model, indexInFiltered } = testQueue[0];
+    try {
+      const res = await API.get(`/api/channel/test/${record.id}?model=${model}`);
 
-    if (currentTestChannel && currentTestChannel.id === channel.id) {
-      let pageNo;
-      if (indexInFiltered !== undefined) {
-        pageNo = Math.floor(indexInFiltered / MODEL_TABLE_PAGE_SIZE) + 1;
-      } else {
-        const filteredModelsList = currentTestChannel.models
-          .split(',')
-          .filter((m) =>
-            m.toLowerCase().includes(modelSearchKeyword.toLowerCase()),
-          );
-        const modelIdx = filteredModelsList.indexOf(model);
-        pageNo =
-          modelIdx !== -1
-            ? Math.floor(modelIdx / MODEL_TABLE_PAGE_SIZE) + 1
-            : 1;
+      // 检查是否在请求期间被停止
+      if (shouldStopBatchTestingRef.current && isBatchTesting) {
+        return Promise.resolve();
       }
-      setModelTablePage(pageNo);
-    }
 
-    try {
-      setTestingModels((prev) => new Set([...prev, model]));
-      const res = await API.get(
-        `/api/channel/test/${channel.id}?model=${model}`,
-      );
       const { success, message, time } = res.data;
 
-      setModelTestResults((prev) => ({
+      // 更新测试结果
+      setModelTestResults(prev => ({
         ...prev,
-        [`${channel.id}-${model}`]: { success, time },
+        [testKey]: {
+          success,
+          message,
+          time: time || 0,
+          timestamp: Date.now()
+        }
       }));
 
       if (success) {
-        updateChannelProperty(channel.id, (ch) => {
-          ch.response_time = time * 1000;
-          ch.test_time = Date.now() / 1000;
+        // 更新渠道响应时间
+        updateChannelProperty(record.id, (channel) => {
+          channel.response_time = time * 1000;
+          channel.test_time = Date.now() / 1000;
         });
-        if (!model) {
+
+        if (!model || model === '') {
           showInfo(
             t('通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。')
-              .replace('${name}', channel.name)
+              .replace('${name}', record.name)
+              .replace('${time.toFixed(2)}', time.toFixed(2)),
+          );
+        } else {
+          showInfo(
+            t('通道 ${name} 测试成功,模型 ${model} 耗时 ${time.toFixed(2)} 秒。')
+              .replace('${name}', record.name)
+              .replace('${model}', model)
               .replace('${time.toFixed(2)}', time.toFixed(2)),
           );
         }
       } else {
-        showError(message);
+        showError(`${t('模型')} ${model}: ${message}`);
       }
     } catch (error) {
-      showError(error.message);
+      // 处理网络错误
+      const testKey = `${record.id}-${model}`;
+      setModelTestResults(prev => ({
+        ...prev,
+        [testKey]: {
+          success: false,
+          message: error.message || t('网络错误'),
+          time: 0,
+          timestamp: Date.now()
+        }
+      }));
+      showError(`${t('模型')} ${model}: ${error.message || t('测试失败')}`);
     } finally {
-      setTestingModels((prev) => {
+      // 从正在测试的模型集合中移除
+      setTestingModels(prev => {
         const newSet = new Set(prev);
         newSet.delete(model);
         return newSet;
       });
     }
-
-    setTestQueue((prev) => prev.slice(1));
   };
 
-  // Monitor queue changes
-  useEffect(() => {
-    if (testQueue.length > 0 && isProcessingQueue) {
-      processTestQueue();
-    } else if (testQueue.length === 0 && isProcessingQueue) {
-      setIsProcessingQueue(false);
-      setIsBatchTesting(false);
+  // 批量测试单个渠道的所有模型,参考旧版实现
+  const batchTestModels = async () => {
+    if (!currentTestChannel || !currentTestChannel.models) {
+      showError(t('渠道模型信息不完整'));
+      return;
     }
-  }, [testQueue, isProcessingQueue]);
 
-  // Batch test models
-  const batchTestModels = async () => {
-    if (!currentTestChannel) return;
+    const models = currentTestChannel.models.split(',').filter(model =>
+      model.toLowerCase().includes(modelSearchKeyword.toLowerCase())
+    );
+
+    if (models.length === 0) {
+      showError(t('没有找到匹配的模型'));
+      return;
+    }
 
     setIsBatchTesting(true);
-    setModelTablePage(1);
+    shouldStopBatchTestingRef.current = false; // 重置停止标志
+
+    // 清空该渠道之前的测试结果
+    setModelTestResults(prev => {
+      const newResults = { ...prev };
+      models.forEach(model => {
+        const testKey = `${currentTestChannel.id}-${model}`;
+        delete newResults[testKey];
+      });
+      return newResults;
+    });
 
-    const filteredModels = currentTestChannel.models
-      .split(',')
-      .filter((model) =>
-        model.toLowerCase().includes(modelSearchKeyword.toLowerCase()),
-      );
+    try {
+      showInfo(t('开始批量测试 ${count} 个模型,已清空上次结果...').replace('${count}', models.length));
 
-    setTestQueue(
-      filteredModels.map((model, idx) => ({
-        channel: currentTestChannel,
-        model,
-        indexInFiltered: idx,
-      })),
-    );
-    setIsProcessingQueue(true);
+      // 提高并发数量以加快测试速度,参考旧版的并发限制
+      const concurrencyLimit = 5;
+      const results = [];
+
+      for (let i = 0; i < models.length; i += concurrencyLimit) {
+        // 检查是否应该停止
+        if (shouldStopBatchTestingRef.current) {
+          showInfo(t('批量测试已停止'));
+          break;
+        }
+
+        const batch = models.slice(i, i + concurrencyLimit);
+        showInfo(t('正在测试第 ${current} - ${end} 个模型 (共 ${total} 个)')
+          .replace('${current}', i + 1)
+          .replace('${end}', Math.min(i + concurrencyLimit, models.length))
+          .replace('${total}', models.length)
+        );
+
+        const batchPromises = batch.map(model => testChannel(currentTestChannel, model));
+        const batchResults = await Promise.allSettled(batchPromises);
+        results.push(...batchResults);
+
+        // 再次检查是否应该停止
+        if (shouldStopBatchTestingRef.current) {
+          showInfo(t('批量测试已停止'));
+          break;
+        }
+
+        // 短暂延迟避免过于频繁的请求
+        if (i + concurrencyLimit < models.length) {
+          await new Promise(resolve => setTimeout(resolve, 100));
+        }
+      }
+
+      if (!shouldStopBatchTestingRef.current) {
+        // 等待一小段时间确保所有结果都已更新
+        await new Promise(resolve => setTimeout(resolve, 300));
+
+        // 使用当前状态重新计算结果统计
+        setModelTestResults(currentResults => {
+          let successCount = 0;
+          let failCount = 0;
+
+          models.forEach(model => {
+            const testKey = `${currentTestChannel.id}-${model}`;
+            const result = currentResults[testKey];
+            if (result && result.success) {
+              successCount++;
+            } else {
+              failCount++;
+            }
+          });
+
+          // 显示完成消息
+          setTimeout(() => {
+            showSuccess(t('批量测试完成!成功: ${success}, 失败: ${fail}, 总计: ${total}')
+              .replace('${success}', successCount)
+              .replace('${fail}', failCount)
+              .replace('${total}', models.length)
+            );
+          }, 100);
+
+          return currentResults; // 不修改状态,只是为了获取最新值
+        });
+      }
+    } catch (error) {
+      showError(t('批量测试过程中发生错误: ') + error.message);
+    } finally {
+      setIsBatchTesting(false);
+    }
+  };
+
+  // 停止批量测试
+  const stopBatchTesting = () => {
+    shouldStopBatchTestingRef.current = true;
+    setIsBatchTesting(false);
+    setTestingModels(new Set());
+    showInfo(t('已停止批量测试'));
+  };
+
+  // 清空测试结果
+  const clearTestResults = () => {
+    setModelTestResults({});
+    showInfo(t('已清空测试结果'));
   };
 
   // Handle close modal
   const handleCloseModal = () => {
+    // 如果正在批量测试,先停止测试
     if (isBatchTesting) {
-      setTestQueue([]);
-      setIsProcessingQueue(false);
-      setIsBatchTesting(false);
-      showSuccess(t('已停止测试'));
-    } else {
-      setShowModelTestModal(false);
-      setModelSearchKeyword('');
-      setSelectedModelKeys([]);
-      setModelTablePage(1);
+      shouldStopBatchTestingRef.current = true;
+      showInfo(t('关闭弹窗,已停止批量测试'));
     }
+
+    setShowModelTestModal(false);
+    setModelSearchKeyword('');
+    setIsBatchTesting(false);
+    setTestingModels(new Set());
+    setSelectedModelKeys([]);
+    setModelTablePage(1);
+    // 可选择性保留测试结果,这里不清空以便用户查看
   };
 
   // Type counts
@@ -1012,4 +1045,4 @@ export const useChannelsData = () => {
     setCompactMode,
     setActivePage,
   };
-};
+};