Răsfoiți Sursa

feat: ssrf支持域名和ip黑白名单过滤模式

creamlike1024 3 luni în urmă
părinte
comite
168ebb1cd4
4 a modificat fișierele cu 74 adăugiri și 133 ștergeri
  1. 70 129
      common/ssrf_protection.go
  2. 2 2
      service/download.go
  3. 1 1
      service/user_notify.go
  4. 1 1
      service/webhook.go

+ 70 - 129
common/ssrf_protection.go

@@ -11,16 +11,20 @@ import (
 // SSRFProtection SSRF防护配置
 type SSRFProtection struct {
 	AllowPrivateIp   bool
-	WhitelistDomains []string // domain format, e.g. example.com, *.example.com
-	WhitelistIps     []string // CIDR format
+	DomainFilterMode bool     // true: 白名单, false: 黑名单
+	DomainList       []string // domain format, e.g. example.com, *.example.com
+	IpFilterMode     bool     // true: 白名单, false: 黑名单
+	IpList           []string // CIDR or single IP
 	AllowedPorts     []int    // 允许的端口范围
 }
 
 // DefaultSSRFProtection 默认SSRF防护配置
 var DefaultSSRFProtection = &SSRFProtection{
 	AllowPrivateIp:   false,
-	WhitelistDomains: []string{},
-	WhitelistIps:     []string{},
+	DomainFilterMode: true,
+	DomainList:       []string{},
+	IpFilterMode:     true,
+	IpList:           []string{},
 	AllowedPorts:     []int{},
 }
 
@@ -138,44 +142,25 @@ func (p *SSRFProtection) isAllowedPort(port int) bool {
 	return false
 }
 
-// isAllowedPortFromRanges 从端口范围字符串检查端口是否被允许
-func isAllowedPortFromRanges(port int, portRanges []string) bool {
-	if len(portRanges) == 0 {
-		return true // 如果没有配置端口限制,则允许所有端口
-	}
-
-	allowedPorts, err := parsePortRanges(portRanges)
-	if err != nil {
-		// 如果解析失败,为安全起见拒绝访问
-		return false
-	}
-
-	for _, allowedPort := range allowedPorts {
-		if port == allowedPort {
-			return true
-		}
-	}
-	return false
-}
-
 // isDomainWhitelisted 检查域名是否在白名单中
-func (p *SSRFProtection) isDomainWhitelisted(domain string) bool {
-	if len(p.WhitelistDomains) == 0 {
+func isDomainListed(domain string, list []string) bool {
+	if len(list) == 0 {
 		return false
 	}
 
 	domain = strings.ToLower(domain)
-	for _, whitelistDomain := range p.WhitelistDomains {
-		whitelistDomain = strings.ToLower(whitelistDomain)
-
+	for _, item := range list {
+		item = strings.ToLower(strings.TrimSpace(item))
+		if item == "" {
+			continue
+		}
 		// 精确匹配
-		if domain == whitelistDomain {
+		if domain == item {
 			return true
 		}
-
 		// 通配符匹配 (*.example.com)
-		if strings.HasPrefix(whitelistDomain, "*.") {
-			suffix := strings.TrimPrefix(whitelistDomain, "*.")
+		if strings.HasPrefix(item, "*.") {
+			suffix := strings.TrimPrefix(item, "*.")
 			if strings.HasSuffix(domain, "."+suffix) || domain == suffix {
 				return true
 			}
@@ -184,13 +169,23 @@ func (p *SSRFProtection) isDomainWhitelisted(domain string) bool {
 	return false
 }
 
+func (p *SSRFProtection) isDomainAllowed(domain string) bool {
+	listed := isDomainListed(domain, p.DomainList)
+	if p.DomainFilterMode { // 白名单
+		return listed
+	}
+	// 黑名单
+	return !listed
+}
+
 // isIPWhitelisted 检查IP是否在白名单中
-func (p *SSRFProtection) isIPWhitelisted(ip net.IP) bool {
-	if len(p.WhitelistIps) == 0 {
+
+func isIPListed(ip net.IP, list []string) bool {
+	if len(list) == 0 {
 		return false
 	}
 
-	for _, whitelistCIDR := range p.WhitelistIps {
+	for _, whitelistCIDR := range list {
 		_, network, err := net.ParseCIDR(whitelistCIDR)
 		if err != nil {
 			// 尝试作为单个IP处理
@@ -211,22 +206,17 @@ func (p *SSRFProtection) isIPWhitelisted(ip net.IP) bool {
 
 // IsIPAccessAllowed 检查IP是否允许访问
 func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
-	// 如果IP在白名单中,直接允许访问(绕过私有IP检查)
-	if p.isIPWhitelisted(ip) {
-		return true
+	// 私有IP限制
+	if isPrivateIP(ip) && !p.AllowPrivateIp {
+		return false
 	}
 
-	// 如果IP白名单为空,允许所有IP(但仍需通过私有IP检查)
-	if len(p.WhitelistIps) == 0 {
-		// 检查私有IP限制
-		if isPrivateIP(ip) && !p.AllowPrivateIp {
-			return false
-		}
-		return true
+	listed := isIPListed(ip, p.IpList)
+	if p.IpFilterMode { // 白名单
+		return listed
 	}
-
-	// 如果IP白名单不为空且IP不在白名单中,拒绝访问
-	return false
+	// 黑名单
+	return !listed
 }
 
 // ValidateURL 验证URL是否安全
@@ -264,28 +254,44 @@ func (p *SSRFProtection) ValidateURL(urlStr string) error {
 		return fmt.Errorf("port %d is not allowed", port)
 	}
 
-	// 检查域名白名单
-	if p.isDomainWhitelisted(host) {
-		return nil // 白名单域名直接通过
+	// 如果 host 是 IP,则跳过域名检查
+	if ip := net.ParseIP(host); ip != nil {
+		if !p.IsIPAccessAllowed(ip) {
+			if isPrivateIP(ip) {
+				return fmt.Errorf("private IP address not allowed: %s", ip.String())
+			}
+			if p.IpFilterMode {
+				return fmt.Errorf("ip not in whitelist: %s", ip.String())
+			}
+			return fmt.Errorf("ip in blacklist: %s", ip.String())
+		}
+		return nil
 	}
 
-	// DNS解析获取IP地址
+	// 先进行域名过滤
+	if !p.isDomainAllowed(host) {
+		if p.DomainFilterMode {
+			return fmt.Errorf("domain not in whitelist: %s", host)
+		}
+		return fmt.Errorf("domain in blacklist: %s", host)
+	}
+
+	// 解析域名对应IP并检查
 	ips, err := net.LookupIP(host)
 	if err != nil {
 		return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
 	}
-
-	// 检查所有解析的IP地址
 	for _, ip := range ips {
 		if !p.IsIPAccessAllowed(ip) {
-			if isPrivateIP(ip) {
+			if isPrivateIP(ip) && !p.AllowPrivateIp {
 				return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
-			} else {
-				return fmt.Errorf("IP address not in whitelist: %s resolves to %s", host, ip.String())
 			}
+			if p.IpFilterMode {
+				return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String())
+			}
+			return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String())
 		}
 	}
-
 	return nil
 }
 
@@ -295,7 +301,7 @@ func ValidateURLWithDefaults(urlStr string) error {
 }
 
 // ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
-func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, whitelistDomains, whitelistIps, allowedPorts []string) error {
+func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string) error {
 	// 如果SSRF防护被禁用,直接返回成功
 	if !enableSSRFProtection {
 		return nil
@@ -309,76 +315,11 @@ func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPriva
 
 	protection := &SSRFProtection{
 		AllowPrivateIp:   allowPrivateIp,
-		WhitelistDomains: whitelistDomains,
-		WhitelistIps:     whitelistIps,
+		DomainFilterMode: domainFilterMode,
+		DomainList:       domainList,
+		IpFilterMode:     ipFilterMode,
+		IpList:           ipList,
 		AllowedPorts:     allowedPortInts,
 	}
 	return protection.ValidateURL(urlStr)
 }
-
-// ValidateURLWithPortRanges 直接使用端口范围字符串验证URL(更高效的版本)
-func ValidateURLWithPortRanges(urlStr string, allowPrivateIp bool, whitelistDomains, whitelistIps, allowedPorts []string) error {
-	// 解析URL
-	u, err := url.Parse(urlStr)
-	if err != nil {
-		return fmt.Errorf("invalid URL format: %v", err)
-	}
-
-	// 只允许HTTP/HTTPS协议
-	if u.Scheme != "http" && u.Scheme != "https" {
-		return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme)
-	}
-
-	// 解析主机和端口
-	host, portStr, err := net.SplitHostPort(u.Host)
-	if err != nil {
-		// 没有端口,使用默认端口
-		host = u.Host
-		if u.Scheme == "https" {
-			portStr = "443"
-		} else {
-			portStr = "80"
-		}
-	}
-
-	// 验证端口
-	port, err := strconv.Atoi(portStr)
-	if err != nil {
-		return fmt.Errorf("invalid port: %s", portStr)
-	}
-
-	if !isAllowedPortFromRanges(port, allowedPorts) {
-		return fmt.Errorf("port %d is not allowed", port)
-	}
-
-	// 创建临时的SSRFProtection来复用域名和IP检查逻辑
-	protection := &SSRFProtection{
-		AllowPrivateIp:   allowPrivateIp,
-		WhitelistDomains: whitelistDomains,
-		WhitelistIps:     whitelistIps,
-	}
-
-	// 检查域名白名单
-	if protection.isDomainWhitelisted(host) {
-		return nil // 白名单域名直接通过
-	}
-
-	// DNS解析获取IP地址
-	ips, err := net.LookupIP(host)
-	if err != nil {
-		return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
-	}
-
-	// 检查所有解析的IP地址
-	for _, ip := range ips {
-		if !protection.IsIPAccessAllowed(ip) {
-			if isPrivateIP(ip) {
-				return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
-			} else {
-				return fmt.Errorf("IP address not in whitelist: %s resolves to %s", host, ip.String())
-			}
-		}
-	}
-
-	return nil
-}

+ 2 - 2
service/download.go

@@ -30,7 +30,7 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
 
 	// SSRF防护:验证请求URL
 	fetchSetting := system_setting.GetFetchSetting()
-	if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
+	if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
 		return nil, fmt.Errorf("request reject: %v", err)
 	}
 
@@ -59,7 +59,7 @@ func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response,
 	} else {
 		// SSRF防护:验证请求URL(非Worker模式)
 		fetchSetting := system_setting.GetFetchSetting()
-		if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
+		if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
 			return nil, fmt.Errorf("request reject: %v", err)
 		}
 

+ 1 - 1
service/user_notify.go

@@ -115,7 +115,7 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
 	} else {
 		// SSRF防护:验证Bark URL(非Worker模式)
 		fetchSetting := system_setting.GetFetchSetting()
-		if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
+		if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
 			return fmt.Errorf("request reject: %v", err)
 		}
 

+ 1 - 1
service/webhook.go

@@ -89,7 +89,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
 	} else {
 		// SSRF防护:验证Webhook URL(非Worker模式)
 		fetchSetting := system_setting.GetFetchSetting()
-		if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
+		if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
 			return fmt.Errorf("request reject: %v", err)
 		}