ssrf_protection.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. package common
  2. import (
  3. "fmt"
  4. "net"
  5. "net/url"
  6. "strconv"
  7. "strings"
  8. )
// SSRFProtection is the rule set ValidateURL applies to outbound request URLs
// to guard against server-side request forgery (SSRF).
type SSRFProtection struct {
	// AllowPrivateIp, when false, rejects any address isPrivateIP reports as
	// private (loopback, link-local, RFC 1918 ranges, multicast, reserved, ...).
	AllowPrivateIp bool
	// DomainFilterMode selects how DomainList is applied: true = whitelist
	// (only listed domains pass, so an empty list rejects every domain),
	// false = blacklist (listed domains are rejected).
	DomainFilterMode bool
	DomainList       []string // domain format, e.g. example.com, *.example.com
	// IpFilterMode selects how IpList is applied: true = whitelist (an empty
	// list rejects every IP), false = blacklist.
	IpFilterMode bool
	IpList       []string // CIDR or single IP, e.g. 10.0.0.0/8, 192.0.2.1
	// AllowedPorts lists permitted port numbers (ranges such as "8000-9000" are
	// expanded to individual ports by parsePortRanges). Empty means any port.
	AllowedPorts []int
	// ApplyIPFilterForDomain, when true, resolves domain hosts and applies
	// AllowPrivateIp and the IP filter to every resolved address. When false,
	// those IP rules only apply to URLs whose host is a literal IP address.
	ApplyIPFilterForDomain bool
}
  19. // DefaultSSRFProtection 默认SSRF防护配置
  20. var DefaultSSRFProtection = &SSRFProtection{
  21. AllowPrivateIp: false,
  22. DomainFilterMode: true,
  23. DomainList: []string{},
  24. IpFilterMode: true,
  25. IpList: []string{},
  26. AllowedPorts: []int{},
  27. }
  28. // isPrivateIP 检查IP是否为私有地址
  29. func isPrivateIP(ip net.IP) bool {
  30. if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
  31. return true
  32. }
  33. // 检查私有网段
  34. private := []net.IPNet{
  35. {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
  36. {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
  37. {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
  38. {IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
  39. {IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地)
  40. {IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播)
  41. {IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留)
  42. }
  43. for _, privateNet := range private {
  44. if privateNet.Contains(ip) {
  45. return true
  46. }
  47. }
  48. // 检查IPv6私有地址
  49. if ip.To4() == nil {
  50. // IPv6 loopback
  51. if ip.Equal(net.IPv6loopback) {
  52. return true
  53. }
  54. // IPv6 link-local
  55. if strings.HasPrefix(ip.String(), "fe80:") {
  56. return true
  57. }
  58. // IPv6 unique local
  59. if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") {
  60. return true
  61. }
  62. }
  63. return false
  64. }
  65. // parsePortRanges 解析端口范围配置
  66. // 支持格式: "80", "443", "8000-9000"
  67. func parsePortRanges(portConfigs []string) ([]int, error) {
  68. var ports []int
  69. for _, config := range portConfigs {
  70. config = strings.TrimSpace(config)
  71. if config == "" {
  72. continue
  73. }
  74. if strings.Contains(config, "-") {
  75. // 处理端口范围 "8000-9000"
  76. parts := strings.Split(config, "-")
  77. if len(parts) != 2 {
  78. return nil, fmt.Errorf("invalid port range format: %s", config)
  79. }
  80. startPort, err := strconv.Atoi(strings.TrimSpace(parts[0]))
  81. if err != nil {
  82. return nil, fmt.Errorf("invalid start port in range %s: %v", config, err)
  83. }
  84. endPort, err := strconv.Atoi(strings.TrimSpace(parts[1]))
  85. if err != nil {
  86. return nil, fmt.Errorf("invalid end port in range %s: %v", config, err)
  87. }
  88. if startPort > endPort {
  89. return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config)
  90. }
  91. if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 {
  92. return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config)
  93. }
  94. // 添加范围内的所有端口
  95. for port := startPort; port <= endPort; port++ {
  96. ports = append(ports, port)
  97. }
  98. } else {
  99. // 处理单个端口 "80"
  100. port, err := strconv.Atoi(config)
  101. if err != nil {
  102. return nil, fmt.Errorf("invalid port number: %s", config)
  103. }
  104. if port < 1 || port > 65535 {
  105. return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port)
  106. }
  107. ports = append(ports, port)
  108. }
  109. }
  110. return ports, nil
  111. }
  112. // isAllowedPort 检查端口是否被允许
  113. func (p *SSRFProtection) isAllowedPort(port int) bool {
  114. if len(p.AllowedPorts) == 0 {
  115. return true // 如果没有配置端口限制,则允许所有端口
  116. }
  117. for _, allowedPort := range p.AllowedPorts {
  118. if port == allowedPort {
  119. return true
  120. }
  121. }
  122. return false
  123. }
  124. // isDomainWhitelisted 检查域名是否在白名单中
  125. func isDomainListed(domain string, list []string) bool {
  126. if len(list) == 0 {
  127. return false
  128. }
  129. domain = strings.ToLower(domain)
  130. for _, item := range list {
  131. item = strings.ToLower(strings.TrimSpace(item))
  132. if item == "" {
  133. continue
  134. }
  135. // 精确匹配
  136. if domain == item {
  137. return true
  138. }
  139. // 通配符匹配 (*.example.com)
  140. if strings.HasPrefix(item, "*.") {
  141. suffix := strings.TrimPrefix(item, "*.")
  142. if strings.HasSuffix(domain, "."+suffix) || domain == suffix {
  143. return true
  144. }
  145. }
  146. }
  147. return false
  148. }
  149. func (p *SSRFProtection) isDomainAllowed(domain string) bool {
  150. listed := isDomainListed(domain, p.DomainList)
  151. if p.DomainFilterMode { // 白名单
  152. return listed
  153. }
  154. // 黑名单
  155. return !listed
  156. }
  157. // isIPWhitelisted 检查IP是否在白名单中
  158. func isIPListed(ip net.IP, list []string) bool {
  159. if len(list) == 0 {
  160. return false
  161. }
  162. for _, whitelistCIDR := range list {
  163. _, network, err := net.ParseCIDR(whitelistCIDR)
  164. if err != nil {
  165. // 尝试作为单个IP处理
  166. if whitelistIP := net.ParseIP(whitelistCIDR); whitelistIP != nil {
  167. if ip.Equal(whitelistIP) {
  168. return true
  169. }
  170. }
  171. continue
  172. }
  173. if network.Contains(ip) {
  174. return true
  175. }
  176. }
  177. return false
  178. }
  179. // IsIPAccessAllowed 检查IP是否允许访问
  180. func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
  181. // 私有IP限制
  182. if isPrivateIP(ip) && !p.AllowPrivateIp {
  183. return false
  184. }
  185. listed := isIPListed(ip, p.IpList)
  186. if p.IpFilterMode { // 白名单
  187. return listed
  188. }
  189. // 黑名单
  190. return !listed
  191. }
  192. // ValidateURL 验证URL是否安全
  193. func (p *SSRFProtection) ValidateURL(urlStr string) error {
  194. // 解析URL
  195. u, err := url.Parse(urlStr)
  196. if err != nil {
  197. return fmt.Errorf("invalid URL format: %v", err)
  198. }
  199. // 只允许HTTP/HTTPS协议
  200. if u.Scheme != "http" && u.Scheme != "https" {
  201. return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme)
  202. }
  203. // 解析主机和端口
  204. host, portStr, err := net.SplitHostPort(u.Host)
  205. if err != nil {
  206. // 没有端口,使用默认端口
  207. host = u.Hostname()
  208. if u.Scheme == "https" {
  209. portStr = "443"
  210. } else {
  211. portStr = "80"
  212. }
  213. }
  214. // 验证端口
  215. port, err := strconv.Atoi(portStr)
  216. if err != nil {
  217. return fmt.Errorf("invalid port: %s", portStr)
  218. }
  219. if !p.isAllowedPort(port) {
  220. return fmt.Errorf("port %d is not allowed", port)
  221. }
  222. // 如果 host 是 IP,则跳过域名检查
  223. if ip := net.ParseIP(host); ip != nil {
  224. if !p.IsIPAccessAllowed(ip) {
  225. if isPrivateIP(ip) {
  226. return fmt.Errorf("private IP address not allowed: %s", ip.String())
  227. }
  228. if p.IpFilterMode {
  229. return fmt.Errorf("ip not in whitelist: %s", ip.String())
  230. }
  231. return fmt.Errorf("ip in blacklist: %s", ip.String())
  232. }
  233. return nil
  234. }
  235. // 先进行域名过滤
  236. if !p.isDomainAllowed(host) {
  237. if p.DomainFilterMode {
  238. return fmt.Errorf("domain not in whitelist: %s", host)
  239. }
  240. return fmt.Errorf("domain in blacklist: %s", host)
  241. }
  242. // 若未启用对域名应用IP过滤,则到此通过
  243. if !p.ApplyIPFilterForDomain {
  244. return nil
  245. }
  246. // 解析域名对应IP并检查
  247. ips, err := net.LookupIP(host)
  248. if err != nil {
  249. return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
  250. }
  251. for _, ip := range ips {
  252. if !p.IsIPAccessAllowed(ip) {
  253. if isPrivateIP(ip) && !p.AllowPrivateIp {
  254. return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
  255. }
  256. if p.IpFilterMode {
  257. return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String())
  258. }
  259. return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String())
  260. }
  261. }
  262. return nil
  263. }
  264. // ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
  265. func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error {
  266. // 如果SSRF防护被禁用,直接返回成功
  267. if !enableSSRFProtection {
  268. return nil
  269. }
  270. // 解析端口范围配置
  271. allowedPortInts, err := parsePortRanges(allowedPorts)
  272. if err != nil {
  273. return fmt.Errorf("request reject - invalid port configuration: %v", err)
  274. }
  275. protection := &SSRFProtection{
  276. AllowPrivateIp: allowPrivateIp,
  277. DomainFilterMode: domainFilterMode,
  278. DomainList: domainList,
  279. IpFilterMode: ipFilterMode,
  280. IpList: ipList,
  281. AllowedPorts: allowedPortInts,
  282. ApplyIPFilterForDomain: applyIPFilterForDomain,
  283. }
  284. return protection.ValidateURL(urlStr)
  285. }