Browse Source

ipn/ipnlocal: add ProxyProtocol support to VIP service TCP handler (#18175)

tcpHandlerForVIPService was missing ProxyProtocol support that
tcpHandlerForServe already had. Extract the shared logic into
forwardTCPWithProxyProtocol helper and use it in both handlers.

Fixes #18172

Signed-off-by: Raj Singh <[email protected]>
Raj Singh 2 months ago
parent
commit
65182f2119
1 changed files with 81 additions and 90 deletions
  1. 81 90
      ipn/ipnlocal/serve.go

+ 81 - 90
ipn/ipnlocal/serve.go

@@ -591,16 +591,7 @@ func (b *LocalBackend) tcpHandlerForVIPService(dstAddr, srcAddr netip.AddrPort)
 				})
 			}
 
-			errc := make(chan error, 1)
-			go func() {
-				_, err := io.Copy(backConn, conn)
-				errc <- err
-			}()
-			go func() {
-				_, err := io.Copy(conn, backConn)
-				errc <- err
-			}()
-			return <-errc
+			return b.forwardTCPWithProxyProtocol(conn, backConn, tcph.ProxyProtocol(), srcAddr, dport, backDst)
 		}
 	}
 
@@ -678,93 +669,93 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort,
 				})
 			}
 
-			var proxyHeader []byte
-			if ver := tcph.ProxyProtocol(); ver > 0 {
-				// backAddr is the final "destination" of the connection,
-				// which is the connection to the proxied-to backend.
-				backAddr := backConn.RemoteAddr().(*net.TCPAddr)
-
-				// We always want to format the PROXY protocol
-				// header based on the IPv4 or IPv6-ness of
-				// the client. The SourceAddr and
-				// DestinationAddr need to match in type, so we
-				// need to be careful to not e.g. set a
-				// SourceAddr of type IPv6 and DestinationAddr
-				// of type IPv4.
-				//
-				// If this is an IPv6-mapped IPv4 address,
-				// though, unmap it.
-				proxySrcAddr := srcAddr
-				if proxySrcAddr.Addr().Is4In6() {
-					proxySrcAddr = netip.AddrPortFrom(
-						proxySrcAddr.Addr().Unmap(),
-						proxySrcAddr.Port(),
-					)
-				}
-
-				is4 := proxySrcAddr.Addr().Is4()
+			// TODO(bradfitz): do the RegisterIPPortIdentity and
+			// UnregisterIPPortIdentity stuff that netstack does
+			return b.forwardTCPWithProxyProtocol(conn, backConn, tcph.ProxyProtocol(), srcAddr, dport, backDst)
+		}
+	}
 
-				var destAddr netip.Addr
-				if self := b.currentNode().Self(); self.Valid() {
-					if is4 {
-						destAddr = nodeIP(self, netip.Addr.Is4)
-					} else {
-						destAddr = nodeIP(self, netip.Addr.Is6)
-					}
-				}
-				if !destAddr.IsValid() {
-					// Pick a best-effort destination address of localhost.
-					if is4 {
-						destAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
-					} else {
-						destAddr = netip.IPv6Loopback()
-					}
-				}
+	return nil
+}
 
-				header := &proxyproto.Header{
-					Version:    byte(ver),
-					Command:    proxyproto.PROXY,
-					SourceAddr: net.TCPAddrFromAddrPort(proxySrcAddr),
-					DestinationAddr: &net.TCPAddr{
-						IP:   destAddr.AsSlice(),
-						Port: backAddr.Port,
-					},
-				}
-				if is4 {
-					header.TransportProtocol = proxyproto.TCPv4
-				} else {
-					header.TransportProtocol = proxyproto.TCPv6
-				}
-				var err error
-				proxyHeader, err = header.Format()
-				if err != nil {
-					b.logf("localbackend: failed to format proxy protocol header for port %v (from %v) to %s: %v", dport, srcAddr, backDst, err)
-				}
+// forwardTCPWithProxyProtocol forwards TCP traffic between conn and backConn,
+// optionally prepending a PROXY protocol header if proxyProtoVer > 0.
+// The srcAddr is the original client address used to build the PROXY header.
+func (b *LocalBackend) forwardTCPWithProxyProtocol(conn, backConn net.Conn, proxyProtoVer int, srcAddr netip.AddrPort, dport uint16, backDst string) error {
+	var proxyHeader []byte
+	if proxyProtoVer > 0 {
+		backAddr := backConn.RemoteAddr().(*net.TCPAddr)
+
+		// We always want to format the PROXY protocol header based on
+		// the IPv4 or IPv6-ness of the client. The SourceAddr and
+		// DestinationAddr need to match in type.
+		// If this is an IPv6-mapped IPv4 address, unmap it.
+		proxySrcAddr := srcAddr
+		if proxySrcAddr.Addr().Is4In6() {
+			proxySrcAddr = netip.AddrPortFrom(
+				proxySrcAddr.Addr().Unmap(),
+				proxySrcAddr.Port(),
+			)
+		}
+
+		is4 := proxySrcAddr.Addr().Is4()
+
+		var destAddr netip.Addr
+		if self := b.currentNode().Self(); self.Valid() {
+			if is4 {
+				destAddr = nodeIP(self, netip.Addr.Is4)
+			} else {
+				destAddr = nodeIP(self, netip.Addr.Is6)
 			}
+		}
+		if !destAddr.IsValid() {
+			// Unexpected: we couldn't determine the node's IP address.
+			// Pick a best-effort destination address of localhost.
+			if is4 {
+				destAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
+			} else {
+				destAddr = netip.IPv6Loopback()
+			}
+		}
 
-			// TODO(bradfitz): do the RegisterIPPortIdentity and
-			// UnregisterIPPortIdentity stuff that netstack does
-			errc := make(chan error, 1)
-			go func() {
-				if len(proxyHeader) > 0 {
-					if _, err := backConn.Write(proxyHeader); err != nil {
-						errc <- err
-						backConn.Close() // to ensure that the other side gets EOF
-						return
-					}
-				}
-				_, err := io.Copy(backConn, conn)
-				errc <- err
-			}()
-			go func() {
-				_, err := io.Copy(conn, backConn)
-				errc <- err
-			}()
-			return <-errc
+		header := &proxyproto.Header{
+			Version:    byte(proxyProtoVer),
+			Command:    proxyproto.PROXY,
+			SourceAddr: net.TCPAddrFromAddrPort(proxySrcAddr),
+			DestinationAddr: &net.TCPAddr{
+				IP:   destAddr.AsSlice(),
+				Port: backAddr.Port,
+			},
+		}
+		if is4 {
+			header.TransportProtocol = proxyproto.TCPv4
+		} else {
+			header.TransportProtocol = proxyproto.TCPv6
+		}
+		var err error
+		proxyHeader, err = header.Format()
+		if err != nil {
+			b.logf("localbackend: failed to format proxy protocol header for port %v (from %v) to %s: %v", dport, srcAddr, backDst, err)
 		}
 	}
 
-	return nil
+	errc := make(chan error, 1)
+	go func() {
+		if len(proxyHeader) > 0 {
+			if _, err := backConn.Write(proxyHeader); err != nil {
+				errc <- err
+				backConn.Close()
+				return
+			}
+		}
+		_, err := io.Copy(backConn, conn)
+		errc <- err
+	}()
+	go func() {
+		_, err := io.Copy(conn, backConn)
+		errc <- err
+	}()
+	return <-errc
 }
 
 func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, at string, ok bool) {