Browse Source

Merge pull request #2398 from seefs001/fix/video-proxy

fix: Use channel proxy settings for task query scenarios
Calcium-Ion 1 month ago
parent
commit
e346f0bf16

+ 2 - 1
controller/task.go

@@ -116,9 +116,10 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
 	if adaptor == nil {
 		return errors.New("adaptor not found")
 	}
+	proxy := channel.GetSetting().Proxy
 	resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
 		"ids": taskIds,
-	})
+	}, proxy)
 	if err != nil {
 		common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
 		return err

+ 2 - 1
controller/task_video.go

@@ -67,6 +67,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
 	if channel.GetBaseURL() != "" {
 		baseURL = channel.GetBaseURL()
 	}
+	proxy := channel.GetSetting().Proxy
 
 	task := taskM[taskId]
 	if task == nil {
@@ -76,7 +77,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
 	resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
 		"task_id": taskId,
 		"action":  task.Action,
-	})
+	}, proxy)
 	if err != nil {
 		return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
 	}

+ 16 - 3
controller/video_proxy.go

@@ -1,6 +1,7 @@
 package controller
 
 import (
+	"context"
 	"fmt"
 	"io"
 	"net/http"
@@ -10,6 +11,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/logger"
 	"github.com/QuantumNous/new-api/model"
+	"github.com/QuantumNous/new-api/service"
 
 	"github.com/gin-gonic/gin"
 )
@@ -75,11 +77,22 @@ func VideoProxy(c *gin.Context) {
 	}
 
 	var videoURL string
-	client := &http.Client{
-		Timeout: 60 * time.Second,
+	proxy := channel.GetSetting().Proxy
+	client, err := service.GetHttpClientWithProxy(proxy)
+	if err != nil {
+		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"error": gin.H{
+				"message": "Failed to create proxy client",
+				"type":    "server_error",
+			},
+		})
+		return
 	}
 
-	req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, "", nil)
+	ctx, cancel := context.WithTimeout(c.Request.Context(), 60*time.Second)
+	defer cancel()
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
 		c.JSON(http.StatusInternalServerError, gin.H{

+ 2 - 1
controller/video_proxy_gemini.go

@@ -35,10 +35,11 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string)
 		return "", fmt.Errorf("api key not available for task")
 	}
 
+	proxy := channel.GetSetting().Proxy
 	resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{
 		"task_id": task.TaskID,
 		"action":  task.Action,
-	})
+	}, proxy)
 	if err != nil {
 		return "", fmt.Errorf("fetch task failed: %w", err)
 	}

+ 1 - 1
relay/channel/adapter.go

@@ -47,7 +47,7 @@ type TaskAdaptor interface {
 	GetChannelName() string
 
 	// FetchTask
-	FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
+	FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
 
 	ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
 }

+ 6 - 2
relay/channel/task/ali/adaptor.go

@@ -393,7 +393,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 }
 
 // FetchTask 查询任务状态
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	taskID, ok := body["task_id"].(string)
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
@@ -408,7 +408,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 
 	req.Header.Set("Authorization", "Bearer "+key)
 
-	return service.GetHttpClient().Do(req)
+	client, err := service.GetHttpClientWithProxy(proxy)
+	if err != nil {
+		return nil, fmt.Errorf("new proxy http client failed: %w", err)
+	}
+	return client.Do(req)
 }
 
 func (a *TaskAdaptor) GetModelList() []string {

+ 6 - 2
relay/channel/task/doubao/adaptor.go

@@ -146,7 +146,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 }
 
 // FetchTask fetch task status
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	taskID, ok := body["task_id"].(string)
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
@@ -163,7 +163,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 	req.Header.Set("Content-Type", "application/json")
 	req.Header.Set("Authorization", "Bearer "+key)
 
-	return service.GetHttpClient().Do(req)
+	client, err := service.GetHttpClientWithProxy(proxy)
+	if err != nil {
+		return nil, fmt.Errorf("new proxy http client failed: %w", err)
+	}
+	return client.Do(req)
 }
 
 func (a *TaskAdaptor) GetModelList() []string {

+ 6 - 2
relay/channel/task/gemini/adaptor.go

@@ -200,7 +200,7 @@ func (a *TaskAdaptor) GetChannelName() string {
 }
 
 // FetchTask fetch task status
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	taskID, ok := body["task_id"].(string)
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
@@ -223,7 +223,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 	req.Header.Set("Accept", "application/json")
 	req.Header.Set("x-goog-api-key", key)
 
-	return service.GetHttpClient().Do(req)
+	client, err := service.GetHttpClientWithProxy(proxy)
+	if err != nil {
+		return nil, fmt.Errorf("new proxy http client failed: %w", err)
+	}
+	return client.Do(req)
 }
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {

+ 6 - 2
relay/channel/task/hailuo/adaptor.go

@@ -110,7 +110,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	return hResp.TaskID, responseBody, nil
 }
 
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	taskID, ok := body["task_id"].(string)
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
@@ -126,7 +126,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 	req.Header.Set("Accept", "application/json")
 	req.Header.Set("Authorization", "Bearer "+key)
 
-	return service.GetHttpClient().Do(req)
+	client, err := service.GetHttpClientWithProxy(proxy)
+	if err != nil {
+		return nil, fmt.Errorf("new proxy http client failed: %w", err)
+	}
+	return client.Do(req)
 }
 
 func (a *TaskAdaptor) GetModelList() []string {

+ 6 - 2
relay/channel/task/jimeng/adaptor.go

@@ -210,7 +210,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 }
 
 // FetchTask fetch task status
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	taskID, ok := body["task_id"].(string)
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
@@ -251,7 +251,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 			return nil, errors.Wrap(err, "sign request failed")
 		}
 	}
-	return service.GetHttpClient().Do(req)
+	client, err := service.GetHttpClientWithProxy(proxy)
+	if err != nil {
+		return nil, fmt.Errorf("new proxy http client failed: %w", err)
+	}
+	return client.Do(req)
 }
 
 func (a *TaskAdaptor) GetModelList() []string {

+ 6 - 2
relay/channel/task/kling/adaptor.go

@@ -199,7 +199,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 }
 
 // FetchTask fetch task status
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	taskID, ok := body["task_id"].(string)
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
@@ -228,7 +228,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 	req.Header.Set("Authorization", "Bearer "+token)
 	req.Header.Set("User-Agent", "kling-sdk/1.0")
 
-	return service.GetHttpClient().Do(req)
+	client, err := service.GetHttpClientWithProxy(proxy)
+	if err != nil {
+		return nil, fmt.Errorf("new proxy http client failed: %w", err)
+	}
+	return client.Do(req)
 }
 
 func (a *TaskAdaptor) GetModelList() []string {

+ 6 - 2
relay/channel/task/sora/adaptor.go

@@ -125,7 +125,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
 }
 
 // FetchTask fetch task status
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	taskID, ok := body["task_id"].(string)
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
@@ -140,7 +140,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 
 	req.Header.Set("Authorization", "Bearer "+key)
 
-	return service.GetHttpClient().Do(req)
+	client, err := service.GetHttpClientWithProxy(proxy)
+	if err != nil {
+		return nil, fmt.Errorf("new proxy http client failed: %w", err)
+	}
+	return client.Do(req)
 }
 
 func (a *TaskAdaptor) GetModelList() []string {

+ 4 - 4
relay/channel/task/suno/adaptor.go

@@ -132,7 +132,7 @@ func (a *TaskAdaptor) GetChannelName() string {
 	return ChannelName
 }
 
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
 	byteBody, err := json.Marshal(body)
 	if err != nil {
@@ -153,11 +153,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 	req = req.WithContext(ctx)
 	req.Header.Set("Content-Type", "application/json")
 	req.Header.Set("Authorization", "Bearer "+key)
-	resp, err := service.GetHttpClient().Do(req)
+	client, err := service.GetHttpClientWithProxy(proxy)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("new proxy http client failed: %w", err)
 	}
-	return resp, nil
+	return client.Do(req)
 }
 
 func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {

+ 12 - 4
relay/channel/task/vertex/adaptor.go

@@ -120,7 +120,11 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 		return fmt.Errorf("failed to decode credentials: %w", err)
 	}
 
-	token, err := vertexcore.AcquireAccessToken(*adc, "")
+	proxy := ""
+	if info != nil {
+		proxy = info.ChannelSetting.Proxy
+	}
+	token, err := vertexcore.AcquireAccessToken(*adc, proxy)
 	if err != nil {
 		return fmt.Errorf("failed to acquire access token: %w", err)
 	}
@@ -216,7 +220,7 @@ func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generat
 func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
 
 // FetchTask fetch task status
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	taskID, ok := body["task_id"].(string)
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
@@ -249,7 +253,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 	if err := json.Unmarshal([]byte(key), adc); err != nil {
 		return nil, fmt.Errorf("failed to decode credentials: %w", err)
 	}
-	token, err := vertexcore.AcquireAccessToken(*adc, "")
+	token, err := vertexcore.AcquireAccessToken(*adc, proxy)
 	if err != nil {
 		return nil, fmt.Errorf("failed to acquire access token: %w", err)
 	}
@@ -261,7 +265,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 	req.Header.Set("Accept", "application/json")
 	req.Header.Set("Authorization", "Bearer "+token)
 	req.Header.Set("x-goog-user-project", adc.ProjectID)
-	return service.GetHttpClient().Do(req)
+	client, err := service.GetHttpClientWithProxy(proxy)
+	if err != nil {
+		return nil, fmt.Errorf("new proxy http client failed: %w", err)
+	}
+	return client.Do(req)
 }
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {

+ 6 - 2
relay/channel/task/vidu/adaptor.go

@@ -188,7 +188,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	return vResp.TaskId, responseBody, nil
 }
 
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	taskID, ok := body["task_id"].(string)
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
@@ -204,7 +204,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 	req.Header.Set("Accept", "application/json")
 	req.Header.Set("Authorization", "Token "+key)
 
-	return service.GetHttpClient().Do(req)
+	client, err := service.GetHttpClientWithProxy(proxy)
+	if err != nil {
+		return nil, fmt.Errorf("new proxy http client failed: %w", err)
+	}
+	return client.Do(req)
 }
 
 func (a *TaskAdaptor) GetModelList() []string {

+ 2 - 1
relay/relay_task.go

@@ -326,6 +326,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
 		if channelModel.GetBaseURL() != "" {
 			baseURL = channelModel.GetBaseURL()
 		}
+		proxy := channelModel.GetSetting().Proxy
 		adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
 		if adaptor == nil {
 			return
@@ -333,7 +334,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
 		resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
 			"task_id": originTask.TaskID,
 			"action":  originTask.Action,
-		})
+		}, proxy)
 		if err2 != nil || resp == nil {
 			return
 		}

+ 18 - 10
service/http_client.go

@@ -35,9 +35,9 @@ func checkRedirect(req *http.Request, via []*http.Request) error {
 
 func InitHttpClient() {
 	transport := &http.Transport{
-		MaxIdleConns:          common.RelayMaxIdleConns,
-		MaxIdleConnsPerHost:   common.RelayMaxIdleConnsPerHost,
-		ForceAttemptHTTP2:     true,
+		MaxIdleConns:        common.RelayMaxIdleConns,
+		MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
+		ForceAttemptHTTP2:   true,
 	}
 
 	if common.RelayTimeout == 0 {
@@ -58,6 +58,14 @@ func GetHttpClient() *http.Client {
 	return httpClient
 }
 
+// GetHttpClientWithProxy returns the default client or a proxy-enabled one when proxyURL is provided.
+func GetHttpClientWithProxy(proxyURL string) (*http.Client, error) {
+	if proxyURL == "" {
+		return GetHttpClient(), nil
+	}
+	return NewProxyHttpClient(proxyURL)
+}
+
 // ResetProxyClientCache 清空代理客户端缓存,确保下次使用时重新初始化
 func ResetProxyClientCache() {
 	proxyClientLock.Lock()
@@ -92,10 +100,10 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
 	case "http", "https":
 		client := &http.Client{
 			Transport: &http.Transport{
-				MaxIdleConns:          common.RelayMaxIdleConns,
-				MaxIdleConnsPerHost:   common.RelayMaxIdleConnsPerHost,
-				ForceAttemptHTTP2:     true,
-				Proxy: http.ProxyURL(parsedURL),
+				MaxIdleConns:        common.RelayMaxIdleConns,
+				MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
+				ForceAttemptHTTP2:   true,
+				Proxy:               http.ProxyURL(parsedURL),
 			},
 			CheckRedirect: checkRedirect,
 		}
@@ -127,9 +135,9 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
 
 		client := &http.Client{
 			Transport: &http.Transport{
-				MaxIdleConns:          common.RelayMaxIdleConns,
-				MaxIdleConnsPerHost:   common.RelayMaxIdleConnsPerHost,
-				ForceAttemptHTTP2:     true,
+				MaxIdleConns:        common.RelayMaxIdleConns,
+				MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
+				ForceAttemptHTTP2:   true,
 				DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
 					return dialer.Dial(network, addr)
 				},