Răsfoiți Sursa

Set TCP keepalive for WireGuard gVisor TCP connections

世界 2 ani în urmă
părinte
comite
f949ddc0ab
2 a modificat fișierele cu 79 adăugiri și 1 ștergeri
  1. 1 1
      transport/wireguard/device_stack.go
  2. 78 0
      transport/wireguard/gonet.go

+ 1 - 1
transport/wireguard/device_stack.go

@@ -112,7 +112,7 @@ func (w *StackDevice) DialContext(ctx context.Context, network string, destinati
 	}
 	switch N.NetworkName(network) {
 	case N.NetworkTCP:
-		tcpConn, err := gonet.DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
+		tcpConn, err := DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
 		if err != nil {
 			return nil, err
 		}

+ 78 - 0
transport/wireguard/gonet.go

@@ -0,0 +1,78 @@
+//go:build with_gvisor
+
+package wireguard
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"time"
+
+	M "github.com/sagernet/sing/common/metadata"
+
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*gonet.TCPConn, error) {
+	// Create TCP endpoint, then connect.
+	var wq waiter.Queue
+	ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
+	if err != nil {
+		return nil, errors.New(err.String())
+	}
+
+	// Create wait queue entry that notifies a channel.
+	//
+	// We do this unconditionally as Connect will always return an error.
+	waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents)
+	wq.EventRegister(&waitEntry)
+	defer wq.EventUnregister(&waitEntry)
+
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	default:
+	}
+
+	// Bind before connect if requested.
+	if localAddr != (tcpip.FullAddress{}) {
+		if err = ep.Bind(localAddr); err != nil {
+			return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err)
+		}
+	}
+
+	err = ep.Connect(remoteAddr)
+	if _, ok := err.(*tcpip.ErrConnectStarted); ok {
+		select {
+		case <-ctx.Done():
+			ep.Close()
+			return nil, ctx.Err()
+		case <-notifyCh:
+		}
+
+		err = ep.LastError()
+	}
+	if err != nil {
+		ep.Close()
+		return nil, &net.OpError{
+			Op:   "connect",
+			Net:  "tcp",
+			Addr: M.SocksaddrFrom(M.AddrFromIP(net.IP(remoteAddr.Addr)), remoteAddr.Port).TCPAddr(),
+			Err:  errors.New(err.String()),
+		}
+	}
+
+	// sing-box added: set keepalive
+	ep.SocketOptions().SetKeepAlive(true)
+	keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
+	ep.SetSockOpt(&keepAliveIdle)
+	keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
+	ep.SetSockOpt(&keepAliveInterval)
+
+	return gonet.NewTCPConn(&wq, ep), nil
+}