Browse Source

feat: enhance HTTP client with custom redirect handling and SSRF protection

CaIon 3 months ago
parent
commit
7b732ec4b7
2 changed files with 22 additions and 4 deletions
  1. 2 2
      service/download.go
  2. 20 2
      service/http_client.go

+ 2 - 2
service/download.go

@@ -45,7 +45,7 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
 		return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
 	}
 
-	return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
+	return GetHttpClient().Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
 }
 
 func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
@@ -64,6 +64,6 @@ func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response,
 		}
 
 		common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", ")))
-		return http.Get(originUrl)
+		return GetHttpClient().Get(originUrl)
 	}
 }

+ 20 - 2
service/http_client.go

@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"net/url"
 	"one-api/common"
+	"one-api/setting/system_setting"
 	"sync"
 	"time"
 
@@ -19,12 +20,27 @@ var (
 	proxyClients    = make(map[string]*http.Client)
 )
 
+func checkRedirect(req *http.Request, via []*http.Request) error {
+	fetchSetting := system_setting.GetFetchSetting()
+	urlStr := req.URL.String()
+	if err := common.ValidateURLWithFetchSetting(urlStr, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
+		return fmt.Errorf("redirect to %s blocked: %v", urlStr, err)
+	}
+	if len(via) >= 10 {
+		return fmt.Errorf("stopped after 10 redirects")
+	}
+	return nil
+}
+
 func InitHttpClient() {
 	if common.RelayTimeout == 0 {
-		httpClient = &http.Client{}
+		httpClient = &http.Client{
+			CheckRedirect: checkRedirect,
+		}
 	} else {
 		httpClient = &http.Client{
-			Timeout: time.Duration(common.RelayTimeout) * time.Second,
+			Timeout:       time.Duration(common.RelayTimeout) * time.Second,
+			CheckRedirect: checkRedirect,
 		}
 	}
 }
@@ -69,6 +85,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
 			Transport: &http.Transport{
 				Proxy: http.ProxyURL(parsedURL),
 			},
+			CheckRedirect: checkRedirect,
 		}
 		client.Timeout = time.Duration(common.RelayTimeout) * time.Second
 		proxyClientLock.Lock()
@@ -102,6 +119,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
 					return dialer.Dial(network, addr)
 				},
 			},
+			CheckRedirect: checkRedirect,
 		}
 		client.Timeout = time.Duration(common.RelayTimeout) * time.Second
 		proxyClientLock.Lock()