Просмотр исходного кода

Improve random IP compatibility: support IPv4, add srcip option, and sync client source IP via sendthrough (#4671)

Aubrey Yang 8 месяцев назад
Родитель
Сommit
5e6a5ae01d
1 измененных файлов с 43 добавлено и 22 удалено
  1. 43 22
      app/proxyman/outbound/handler.go

+ 43 - 22
app/proxyman/outbound/handler.go

@@ -241,7 +241,9 @@ func (h *Handler) DestIpAddress() net.IP {
 // Dial implements internet.Dialer.
 func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connection, error) {
 	if h.senderSettings != nil {
+
 		if h.senderSettings.ProxySettings.HasTag() {
+
 			tag := h.senderSettings.ProxySettings.Tag
 			handler := h.outboundManager.GetHandler(tag)
 			if handler != nil {
@@ -270,22 +272,40 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
 		}
 
 		if h.senderSettings.Via != nil {
+
 			outbounds := session.OutboundsFromContext(ctx)
 			ob := outbounds[len(outbounds)-1]
-			if h.senderSettings.ViaCidr == "" {
-				if h.senderSettings.Via.AsAddress().Family().IsDomain() && h.senderSettings.Via.AsAddress().Domain() == "origin" {
-					if inbound := session.InboundFromContext(ctx); inbound != nil {
-						origin, _, err := net.SplitHostPort(inbound.Conn.LocalAddr().String())
-						if err == nil {
-							ob.Gateway = net.ParseAddress(origin)
-						}
+			addr := h.senderSettings.Via.AsAddress()
+			var domain string
+			if addr.Family().IsDomain() {
+				domain = addr.Domain()
+			}
+			switch {
+			case h.senderSettings.ViaCidr != "":
+				ob.Gateway = ParseRandomIP(addr, h.senderSettings.ViaCidr)
+
+			case domain == "origin":
+
+				if inbound := session.InboundFromContext(ctx); inbound != nil {
+					origin, _, err := net.SplitHostPort(inbound.Conn.LocalAddr().String())
+					if err == nil {
+						ob.Gateway = net.ParseAddress(origin)
+					}
+
+				}
+			case domain == "srcip":
+				if inbound := session.InboundFromContext(ctx); inbound != nil {
+					srcip, _, err := net.SplitHostPort(inbound.Conn.RemoteAddr().String())
+					if err == nil {
+						ob.Gateway = net.ParseAddress(srcip)
 					}
-				} else {
-					ob.Gateway = h.senderSettings.Via.AsAddress()
 				}
-			} else { //Get a random address.
-				ob.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr)
+			//case addr.Family().IsDomain():
+			default:
+				ob.Gateway = addr
+
 			}
+
 		}
 	}
 
@@ -329,20 +349,21 @@ func (h *Handler) Close() error {
 	return nil
 }
 
-func ParseRandomIPv6(address net.Address, prefix string) net.Address {
-	_, network, _ := gonet.ParseCIDR(address.IP().String() + "/" + prefix)
+func ParseRandomIP(addr net.Address, prefix string) net.Address {
+
+	_, ipnet, _ := gonet.ParseCIDR(addr.IP().String() + "/" + prefix)
 
-	maskSize, totalBits := network.Mask.Size()
-	subnetSize := big.NewInt(1).Lsh(big.NewInt(1), uint(totalBits-maskSize))
+	ones, bits := ipnet.Mask.Size()
+	subnetSize := new(big.Int).Lsh(big.NewInt(1), uint(bits-ones))
 
-	// random
-	randomBigInt, _ := rand.Int(rand.Reader, subnetSize)
+	rnd, _ := rand.Int(rand.Reader, subnetSize)
 
-	startIPBigInt := big.NewInt(0).SetBytes(network.IP.To16())
-	randomIPBigInt := big.NewInt(0).Add(startIPBigInt, randomBigInt)
+	startInt := new(big.Int).SetBytes(ipnet.IP)
+	rndInt := new(big.Int).Add(startInt, rnd)
 
-	randomIPBytes := randomIPBigInt.Bytes()
-	randomIPBytes = append(make([]byte, 16-len(randomIPBytes)), randomIPBytes...)
+	rndBytes := rndInt.Bytes()
+	padded := make([]byte, len(ipnet.IP))
+	copy(padded[len(padded)-len(rndBytes):], rndBytes)
 
-	return net.ParseAddress(gonet.IP(randomIPBytes).String())
+	return net.ParseAddress(gonet.IP(padded).String())
 }