| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311 |
- package common
- import (
- "fmt"
- "net"
- "net/url"
- "strconv"
- "strings"
- )
- // SSRFProtection SSRF防护配置
- type SSRFProtection struct {
- AllowPrivateIp bool
- DomainFilterMode bool // true: 白名单, false: 黑名单
- DomainList []string // domain format, e.g. example.com, *.example.com
- IpFilterMode bool // true: 白名单, false: 黑名单
- IpList []string // CIDR or single IP
- AllowedPorts []int // 允许的端口范围
- ApplyIPFilterForDomain bool // 对域名启用IP过滤
- }
- // DefaultSSRFProtection 默认SSRF防护配置
- var DefaultSSRFProtection = &SSRFProtection{
- AllowPrivateIp: false,
- DomainFilterMode: true,
- DomainList: []string{},
- IpFilterMode: true,
- IpList: []string{},
- AllowedPorts: []int{},
- }
- // isPrivateIP 检查IP是否为私有地址
- func isPrivateIP(ip net.IP) bool {
- if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
- return true
- }
- // 检查私有网段
- private := []net.IPNet{
- {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
- {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
- {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
- {IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
- {IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地)
- {IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播)
- {IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留)
- }
- for _, privateNet := range private {
- if privateNet.Contains(ip) {
- return true
- }
- }
- // 检查IPv6私有地址
- if ip.To4() == nil {
- // IPv6 loopback
- if ip.Equal(net.IPv6loopback) {
- return true
- }
- // IPv6 link-local
- if strings.HasPrefix(ip.String(), "fe80:") {
- return true
- }
- // IPv6 unique local
- if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") {
- return true
- }
- }
- return false
- }
- // parsePortRanges 解析端口范围配置
- // 支持格式: "80", "443", "8000-9000"
- func parsePortRanges(portConfigs []string) ([]int, error) {
- var ports []int
- for _, config := range portConfigs {
- config = strings.TrimSpace(config)
- if config == "" {
- continue
- }
- if strings.Contains(config, "-") {
- // 处理端口范围 "8000-9000"
- parts := strings.Split(config, "-")
- if len(parts) != 2 {
- return nil, fmt.Errorf("invalid port range format: %s", config)
- }
- startPort, err := strconv.Atoi(strings.TrimSpace(parts[0]))
- if err != nil {
- return nil, fmt.Errorf("invalid start port in range %s: %v", config, err)
- }
- endPort, err := strconv.Atoi(strings.TrimSpace(parts[1]))
- if err != nil {
- return nil, fmt.Errorf("invalid end port in range %s: %v", config, err)
- }
- if startPort > endPort {
- return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config)
- }
- if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 {
- return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config)
- }
- // 添加范围内的所有端口
- for port := startPort; port <= endPort; port++ {
- ports = append(ports, port)
- }
- } else {
- // 处理单个端口 "80"
- port, err := strconv.Atoi(config)
- if err != nil {
- return nil, fmt.Errorf("invalid port number: %s", config)
- }
- if port < 1 || port > 65535 {
- return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port)
- }
- ports = append(ports, port)
- }
- }
- return ports, nil
- }
- // isAllowedPort 检查端口是否被允许
- func (p *SSRFProtection) isAllowedPort(port int) bool {
- if len(p.AllowedPorts) == 0 {
- return true // 如果没有配置端口限制,则允许所有端口
- }
- for _, allowedPort := range p.AllowedPorts {
- if port == allowedPort {
- return true
- }
- }
- return false
- }
- // isDomainWhitelisted 检查域名是否在白名单中
- func isDomainListed(domain string, list []string) bool {
- if len(list) == 0 {
- return false
- }
- domain = strings.ToLower(domain)
- for _, item := range list {
- item = strings.ToLower(strings.TrimSpace(item))
- if item == "" {
- continue
- }
- // 精确匹配
- if domain == item {
- return true
- }
- // 通配符匹配 (*.example.com)
- if strings.HasPrefix(item, "*.") {
- suffix := strings.TrimPrefix(item, "*.")
- if strings.HasSuffix(domain, "."+suffix) || domain == suffix {
- return true
- }
- }
- }
- return false
- }
- func (p *SSRFProtection) isDomainAllowed(domain string) bool {
- listed := isDomainListed(domain, p.DomainList)
- if p.DomainFilterMode { // 白名单
- return listed
- }
- // 黑名单
- return !listed
- }
- // isIPWhitelisted 检查IP是否在白名单中
- func isIPListed(ip net.IP, list []string) bool {
- if len(list) == 0 {
- return false
- }
- return IsIpInCIDRList(ip, list)
- }
- // IsIPAccessAllowed 检查IP是否允许访问
- func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
- // 私有IP限制
- if isPrivateIP(ip) && !p.AllowPrivateIp {
- return false
- }
- listed := isIPListed(ip, p.IpList)
- if p.IpFilterMode { // 白名单
- return listed
- }
- // 黑名单
- return !listed
- }
- // ValidateURL 验证URL是否安全
- func (p *SSRFProtection) ValidateURL(urlStr string) error {
- // 解析URL
- u, err := url.Parse(urlStr)
- if err != nil {
- return fmt.Errorf("invalid URL format: %v", err)
- }
- // 只允许HTTP/HTTPS协议
- if u.Scheme != "http" && u.Scheme != "https" {
- return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme)
- }
- // 解析主机和端口
- host, portStr, err := net.SplitHostPort(u.Host)
- if err != nil {
- // 没有端口,使用默认端口
- host = u.Hostname()
- if u.Scheme == "https" {
- portStr = "443"
- } else {
- portStr = "80"
- }
- }
- // 验证端口
- port, err := strconv.Atoi(portStr)
- if err != nil {
- return fmt.Errorf("invalid port: %s", portStr)
- }
- if !p.isAllowedPort(port) {
- return fmt.Errorf("port %d is not allowed", port)
- }
- // 如果 host 是 IP,则跳过域名检查
- if ip := net.ParseIP(host); ip != nil {
- if !p.IsIPAccessAllowed(ip) {
- if isPrivateIP(ip) {
- return fmt.Errorf("private IP address not allowed: %s", ip.String())
- }
- if p.IpFilterMode {
- return fmt.Errorf("ip not in whitelist: %s", ip.String())
- }
- return fmt.Errorf("ip in blacklist: %s", ip.String())
- }
- return nil
- }
- // 先进行域名过滤
- if !p.isDomainAllowed(host) {
- if p.DomainFilterMode {
- return fmt.Errorf("domain not in whitelist: %s", host)
- }
- return fmt.Errorf("domain in blacklist: %s", host)
- }
- // 若未启用对域名应用IP过滤,则到此通过
- if !p.ApplyIPFilterForDomain {
- return nil
- }
- // 解析域名对应IP并检查
- ips, err := net.LookupIP(host)
- if err != nil {
- return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
- }
- for _, ip := range ips {
- if !p.IsIPAccessAllowed(ip) {
- if isPrivateIP(ip) && !p.AllowPrivateIp {
- return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
- }
- if p.IpFilterMode {
- return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String())
- }
- return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String())
- }
- }
- return nil
- }
- // ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
- func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error {
- // 如果SSRF防护被禁用,直接返回成功
- if !enableSSRFProtection {
- return nil
- }
- // 解析端口范围配置
- allowedPortInts, err := parsePortRanges(allowedPorts)
- if err != nil {
- return fmt.Errorf("request reject - invalid port configuration: %v", err)
- }
- protection := &SSRFProtection{
- AllowPrivateIp: allowPrivateIp,
- DomainFilterMode: domainFilterMode,
- DomainList: domainList,
- IpFilterMode: ipFilterMode,
- IpList: ipList,
- AllowedPorts: allowedPortInts,
- ApplyIPFilterForDomain: applyIPFilterForDomain,
- }
- return protection.ValidateURL(urlStr)
- }
|