|
@@ -0,0 +1,78 @@
|
|
|
|
+//go:build with_gvisor
|
|
|
|
+
|
|
|
|
+package wireguard
|
|
|
|
+
|
|
|
|
+import (
|
|
|
|
+ "context"
|
|
|
|
+ "errors"
|
|
|
|
+ "fmt"
|
|
|
|
+ "net"
|
|
|
|
+ "time"
|
|
|
|
+
|
|
|
|
+ M "github.com/sagernet/sing/common/metadata"
|
|
|
|
+
|
|
|
|
+ "gvisor.dev/gvisor/pkg/tcpip"
|
|
|
|
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
|
|
|
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
|
|
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
|
|
|
+ "gvisor.dev/gvisor/pkg/waiter"
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*gonet.TCPConn, error) {
|
|
|
|
+ // Create TCP endpoint, then connect.
|
|
|
|
+ var wq waiter.Queue
|
|
|
|
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, errors.New(err.String())
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Create wait queue entry that notifies a channel.
|
|
|
|
+ //
|
|
|
|
+ // We do this unconditionally as Connect will always return an error.
|
|
|
|
+ waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents)
|
|
|
|
+ wq.EventRegister(&waitEntry)
|
|
|
|
+ defer wq.EventUnregister(&waitEntry)
|
|
|
|
+
|
|
|
|
+ select {
|
|
|
|
+ case <-ctx.Done():
|
|
|
|
+ return nil, ctx.Err()
|
|
|
|
+ default:
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Bind before connect if requested.
|
|
|
|
+ if localAddr != (tcpip.FullAddress{}) {
|
|
|
|
+ if err = ep.Bind(localAddr); err != nil {
|
|
|
|
+ return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err)
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ err = ep.Connect(remoteAddr)
|
|
|
|
+ if _, ok := err.(*tcpip.ErrConnectStarted); ok {
|
|
|
|
+ select {
|
|
|
|
+ case <-ctx.Done():
|
|
|
|
+ ep.Close()
|
|
|
|
+ return nil, ctx.Err()
|
|
|
|
+ case <-notifyCh:
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ err = ep.LastError()
|
|
|
|
+ }
|
|
|
|
+ if err != nil {
|
|
|
|
+ ep.Close()
|
|
|
|
+ return nil, &net.OpError{
|
|
|
|
+ Op: "connect",
|
|
|
|
+ Net: "tcp",
|
|
|
|
+ Addr: M.SocksaddrFrom(M.AddrFromIP(net.IP(remoteAddr.Addr)), remoteAddr.Port).TCPAddr(),
|
|
|
|
+ Err: errors.New(err.String()),
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // sing-box added: set keepalive
|
|
|
|
+ ep.SocketOptions().SetKeepAlive(true)
|
|
|
|
+ keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
|
|
|
|
+ ep.SetSockOpt(&keepAliveIdle)
|
|
|
|
+ keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
|
|
|
|
+ ep.SetSockOpt(&keepAliveInterval)
|
|
|
|
+
|
|
|
|
+ return gonet.NewTCPConn(&wq, ep), nil
|
|
|
|
+}
|