device_stack_gonet.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. //go:build with_gvisor
  2. package wireguard
  3. import (
  4. "context"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "net/netip"
  9. "time"
  10. "github.com/sagernet/gvisor/pkg/tcpip"
  11. "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
  12. "github.com/sagernet/gvisor/pkg/tcpip/stack"
  13. "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
  14. "github.com/sagernet/gvisor/pkg/waiter"
  15. "github.com/sagernet/sing-tun"
  16. M "github.com/sagernet/sing/common/metadata"
  17. )
  18. func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*gonet.TCPConn, error) {
  19. // Create TCP endpoint, then connect.
  20. var wq waiter.Queue
  21. ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
  22. if err != nil {
  23. return nil, errors.New(err.String())
  24. }
  25. // Create wait queue entry that notifies a channel.
  26. //
  27. // We do this unconditionally as Connect will always return an error.
  28. waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents)
  29. wq.EventRegister(&waitEntry)
  30. defer wq.EventUnregister(&waitEntry)
  31. select {
  32. case <-ctx.Done():
  33. return nil, ctx.Err()
  34. default:
  35. }
  36. // Bind before connect if requested.
  37. if localAddr != (tcpip.FullAddress{}) {
  38. if err = ep.Bind(localAddr); err != nil {
  39. return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err)
  40. }
  41. }
  42. err = ep.Connect(remoteAddr)
  43. if _, ok := err.(*tcpip.ErrConnectStarted); ok {
  44. select {
  45. case <-ctx.Done():
  46. ep.Close()
  47. return nil, ctx.Err()
  48. case <-notifyCh:
  49. }
  50. err = ep.LastError()
  51. }
  52. if err != nil {
  53. ep.Close()
  54. return nil, &net.OpError{
  55. Op: "connect",
  56. Net: "tcp",
  57. Addr: M.SocksaddrFromNetIP(netip.AddrPortFrom(tun.AddrFromAddress(remoteAddr.Addr), remoteAddr.Port)).TCPAddr(),
  58. Err: errors.New(err.String()),
  59. }
  60. }
  61. // sing-box added: set keepalive
  62. ep.SocketOptions().SetKeepAlive(true)
  63. keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
  64. ep.SetSockOpt(&keepAliveIdle)
  65. keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
  66. ep.SetSockOpt(&keepAliveInterval)
  67. return gonet.NewTCPConn(&wq, ep), nil
  68. }