浏览代码

WireGuard Inbound (User-space WireGuard server) (#2477)

* feat: wireguard inbound

* feat(command): generate wireguard compatible keypair

* feat(wireguard): connection idle timeout

* fix(wireguard): close endpoint after connection closed

* fix(wireguard): resolve conflicts

* feat(wireguard): set cubic as default cc algorithm in gVisor TUN

* chore(wireguard): resolve conflict

* chore(wireguard): remove redurant code

* chore(wireguard): remove redurant code

* feat: rework server for gvisor tun

* feat: keep user-space tun as an option

* fix: exclude android from native tun build

* feat: auto kernel tun

* fix: build

* fix: regulate function name & fix test
hax0r31337 1 年之前
父节点
当前提交
0ac7da2fc8

+ 2 - 2
go.mod

@@ -27,6 +27,7 @@ require (
 	golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb
 	google.golang.org/grpc v1.59.0
 	google.golang.org/protobuf v1.31.0
+	gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b
 	h12.io/socks v1.0.3
 	lukechampine.com/blake3 v1.2.1
 )
@@ -48,7 +49,7 @@ require (
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
 	github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
-	github.com/vishvananda/netns v0.0.4 // indirect
+	github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect
 	go.uber.org/mock v0.3.0 // indirect
 	golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect
 	golang.org/x/mod v0.14.0 // indirect
@@ -59,5 +60,4 @@ require (
 	google.golang.org/genproto/googleapis/rpc v0.0.0-20231106174013-bbf56f31fb17 // indirect
 	gopkg.in/yaml.v2 v2.4.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
-	gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b // indirect
 )

+ 1 - 2
go.sum

@@ -168,9 +168,8 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u
 github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
 github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3 h1:tkMT5pTye+1NlKIXETU78NXw0fyjnaNHmJyyLyzw8+U=
 github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3/go.mod h1:cAAsePK2e15YDAMJNyOpGYEWNe4sIghTY7gpz4cX/Ik=
+github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns=
 github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
-github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
-github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
 github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19 h1:capMfFYRgH9BCLd6A3Er/cH3A9Nz3CU2KwxwOQZIePI=
 github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19/go.mod h1:dm4y/1QwzjGaK17ofi0Vs6NpKAHegZky8qk6J2JJZAE=
 github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=

+ 46 - 23
infra/conf/wireguard.go

@@ -13,7 +13,7 @@ type WireGuardPeerConfig struct {
 	PublicKey    string   `json:"publicKey"`
 	PreSharedKey string   `json:"preSharedKey"`
 	Endpoint     string   `json:"endpoint"`
-	KeepAlive    int      `json:"keepAlive"`
+	KeepAlive    uint32   `json:"keepAlive"`
 	AllowedIPs   []string `json:"allowedIPs,omitempty"`
 }
 
@@ -21,9 +21,11 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
 	var err error
 	config := new(wireguard.PeerConfig)
 
-	config.PublicKey, err = parseWireGuardKey(c.PublicKey)
-	if err != nil {
-		return nil, err
+	if c.PublicKey != "" {
+		config.PublicKey, err = parseWireGuardKey(c.PublicKey)
+		if err != nil {
+			return nil, err
+		}
 	}
 
 	if c.PreSharedKey != "" {
@@ -31,13 +33,11 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
 		if err != nil {
 			return nil, err
 		}
-	} else {
-		config.PreSharedKey = "0000000000000000000000000000000000000000000000000000000000000000"
 	}
 
 	config.Endpoint = c.Endpoint
 	// default 0
-	config.KeepAlive = int32(c.KeepAlive)
+	config.KeepAlive = c.KeepAlive
 	if c.AllowedIPs == nil {
 		config.AllowedIps = []string{"0.0.0.0/0", "::0/0"}
 	} else {
@@ -48,11 +48,14 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
 }
 
 type WireGuardConfig struct {
+	IsClient bool `json:""`
+
+	KernelMode     *bool                  `json:"kernelMode"`
 	SecretKey      string                 `json:"secretKey"`
 	Address        []string               `json:"address"`
 	Peers          []*WireGuardPeerConfig `json:"peers"`
-	MTU            int                    `json:"mtu"`
-	NumWorkers     int                    `json:"workers"`
+	MTU            int32                  `json:"mtu"`
+	NumWorkers     int32                  `json:"workers"`
 	Reserved       []byte                 `json:"reserved"`
 	DomainStrategy string                 `json:"domainStrategy"`
 }
@@ -87,11 +90,11 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
 	if c.MTU == 0 {
 		config.Mtu = 1420
 	} else {
-		config.Mtu = int32(c.MTU)
+		config.Mtu = c.MTU
 	}
-	// these a fallback code exists in github.com/nanoda0523/wireguard-go code,
+	// these a fallback code exists in wireguard-go code,
 	// we don't need to process fallback manually
-	config.NumWorkers = int32(c.NumWorkers)
+	config.NumWorkers = c.NumWorkers
 
 	if len(c.Reserved) != 0 && len(c.Reserved) != 3 {
 		return nil, newError(`"reserved" should be empty or 3 bytes`)
@@ -113,22 +116,42 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
 		return nil, newError("unsupported domain strategy: ", c.DomainStrategy)
 	}
 
+	config.IsClient = c.IsClient
+	if c.KernelMode != nil {
+		config.KernelMode = *c.KernelMode
+		if config.KernelMode && !wireguard.KernelTunSupported() {
+			newError("kernel mode is not supported on your OS or permission is insufficient").AtWarning().WriteToLog()
+		}
+	} else {
+		config.KernelMode = wireguard.KernelTunSupported()
+		if config.KernelMode {
+			newError("kernel mode is enabled as it's supported and permission is sufficient").AtDebug().WriteToLog()
+		}
+	}
+
 	return config, nil
 }
 
 func parseWireGuardKey(str string) (string, error) {
-	if len(str) != 64 {
-		// may in base64 form
-		dat, err := base64.StdEncoding.DecodeString(str)
-		if err != nil {
-			return "", err
-		}
-		if len(dat) != 32 {
-			return "", newError("key should be 32 bytes: " + str)
+	var err error
+
+	if len(str)%2 == 0 {
+		_, err = hex.DecodeString(str)
+		if err == nil {
+			return str, nil
 		}
-		return hex.EncodeToString(dat), err
+	}
+
+	var dat []byte
+	str = strings.TrimSuffix(str, "=")
+	if strings.ContainsRune(str, '+') || strings.ContainsRune(str, '/') {
+		dat, err = base64.RawStdEncoding.DecodeString(str)
 	} else {
-		// already hex form
-		return str, nil
+		dat, err = base64.RawURLEncoding.DecodeString(str)
 	}
+	if err == nil {
+		return hex.EncodeToString(dat), nil
+	}
+
+	return "", newError("failed to deserialize key").Base(err)
 }

+ 8 - 7
infra/conf/wireguard_test.go

@@ -7,7 +7,7 @@ import (
 	"github.com/xtls/xray-core/proxy/wireguard"
 )
 
-func TestWireGuardOutbound(t *testing.T) {
+func TestWireGuardConfig(t *testing.T) {
 	creator := func() Buildable {
 		return new(WireGuardConfig)
 	}
@@ -25,7 +25,8 @@ func TestWireGuardOutbound(t *testing.T) {
 				],
 				"mtu": 1300,
 				"workers": 2,
-				"domainStrategy": "ForceIPv6v4"
+				"domainStrategy": "ForceIPv6v4",
+				"kernelMode": false
 			}`,
 			Parser: loadJSON(creator),
 			Output: &wireguard.DeviceConfig{
@@ -35,16 +36,16 @@ func TestWireGuardOutbound(t *testing.T) {
 				Peers: []*wireguard.PeerConfig{
 					{
 						// also can read from hex form directly
-						PublicKey:    "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a",
-						PreSharedKey: "0000000000000000000000000000000000000000000000000000000000000000",
-						Endpoint:     "127.0.0.1:1234",
-						KeepAlive:    0,
-						AllowedIps:   []string{"0.0.0.0/0", "::0/0"},
+						PublicKey:  "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a",
+						Endpoint:   "127.0.0.1:1234",
+						KeepAlive:  0,
+						AllowedIps: []string{"0.0.0.0/0", "::0/0"},
 					},
 				},
 				Mtu:            1300,
 				NumWorkers:     2,
 				DomainStrategy: wireguard.DeviceConfig_FORCE_IP64,
+				KernelMode:     false,
 			},
 		},
 	})

+ 2 - 1
infra/conf/xray.go

@@ -24,6 +24,7 @@ var (
 		"vless":         func() interface{} { return new(VLessInboundConfig) },
 		"vmess":         func() interface{} { return new(VMessInboundConfig) },
 		"trojan":        func() interface{} { return new(TrojanServerConfig) },
+		"wireguard":     func() interface{} { return &WireGuardConfig{IsClient: false} },
 	}, "protocol", "settings")
 
 	outboundConfigLoader = NewJSONConfigLoader(ConfigCreatorCache{
@@ -37,7 +38,7 @@ var (
 		"vmess":       func() interface{} { return new(VMessOutboundConfig) },
 		"trojan":      func() interface{} { return new(TrojanClientConfig) },
 		"dns":         func() interface{} { return new(DNSOutboundConfig) },
-		"wireguard":   func() interface{} { return new(WireGuardConfig) },
+		"wireguard":   func() interface{} { return &WireGuardConfig{IsClient: true} },
 	}, "protocol", "settings")
 
 	ctllog = log.New(os.Stderr, "xctl> ", 0)

+ 12 - 3
main/commands/all/x25519.go

@@ -10,7 +10,7 @@ import (
 )
 
 var cmdX25519 = &base.Command{
-	UsageLine: `{{.Exec}} x25519 [-i "private key (base64.RawURLEncoding)"]`,
+	UsageLine: `{{.Exec}} x25519 [-i "private key (base64.RawURLEncoding)"] [--std-encoding]`,
 	Short:     `Generate key pair for x25519 key exchange`,
 	Long: `
 Generate key pair for x25519 key exchange.
@@ -18,6 +18,7 @@ Generate key pair for x25519 key exchange.
 Random: {{.Exec}} x25519
 
 From private key: {{.Exec}} x25519 -i "private key (base64.RawURLEncoding)"
+For Std Encoding: {{.Exec}} x25519 --std-encoding
 `,
 }
 
@@ -26,12 +27,14 @@ func init() {
 }
 
 var input_base64 = cmdX25519.Flag.String("i", "", "")
+var input_stdEncoding = cmdX25519.Flag.Bool("std-encoding", false, "")
 
 func executeX25519(cmd *base.Command, args []string) {
 	var output string
 	var err error
 	var privateKey []byte
 	var publicKey []byte
+	var encoding *base64.Encoding
 	if len(*input_base64) > 0 {
 		privateKey, err = base64.RawURLEncoding.DecodeString(*input_base64)
 		if err != nil {
@@ -63,9 +66,15 @@ func executeX25519(cmd *base.Command, args []string) {
 		goto out
 	}
 
+	if *input_stdEncoding {
+		encoding = base64.StdEncoding
+	} else {
+		encoding = base64.RawURLEncoding
+	}
+
 	output = fmt.Sprintf("Private key: %v\nPublic key: %v",
-		base64.RawURLEncoding.EncodeToString(privateKey),
-		base64.RawURLEncoding.EncodeToString(publicKey))
+		encoding.EncodeToString(privateKey),
+		encoding.EncodeToString(publicKey))
 out:
 	fmt.Println(output)
 }

+ 59 - 111
proxy/wireguard/bind.go

@@ -27,48 +27,45 @@ type netReadInfo struct {
 	err      error
 }
 
-type netBindClient struct {
-	workers   int
-	dialer    internet.Dialer
+// reduce duplicated code
+type netBind struct {
 	dns       dns.Client
 	dnsOption dns.IPOption
-	reserved  []byte
 
+	workers   int
 	readQueue chan *netReadInfo
 }
 
-func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
-	ipStr, port, _, err := splitAddrPort(s)
+// SetMark implements conn.Bind
+func (bind *netBind) SetMark(mark uint32) error {
+	return nil
+}
+
+// ParseEndpoint implements conn.Bind
+func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
+	ipStr, port, err := net.SplitHostPort(s)
+	if err != nil {
+		return nil, err
+	}
+	portNum, err := strconv.Atoi(port)
 	if err != nil {
 		return nil, err
 	}
 
-	var addr net.IP
-	if IsDomainName(ipStr) {
-		ips, err := bind.dns.LookupIP(ipStr, bind.dnsOption)
+	addr := xnet.ParseAddress(ipStr)
+	if addr.Family() == xnet.AddressFamilyDomain {
+		ips, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
 		if err != nil {
 			return nil, err
 		} else if len(ips) == 0 {
 			return nil, dns.ErrEmptyResponse
 		}
-		addr = ips[0]
-	} else {
-		addr = net.ParseIP(ipStr)
-	}
-	if addr == nil {
-		return nil, errors.New("failed to parse ip: " + ipStr)
-	}
-
-	var ip xnet.Address
-	if p4 := addr.To4(); len(p4) == net.IPv4len {
-		ip = xnet.IPAddress(p4[:])
-	} else {
-		ip = xnet.IPAddress(addr[:])
+		addr = xnet.IPAddress(ips[0])
 	}
 
 	dst := xnet.Destination{
-		Address: ip,
-		Port:    xnet.Port(port),
+		Address: addr,
+		Port:    xnet.Port(portNum),
 		Network: xnet.Network_UDP,
 	}
 
@@ -77,7 +74,13 @@ func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
 	}, nil
 }
 
-func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
+// BatchSize implements conn.Bind
+func (bind *netBind) BatchSize() int {
+	return 1
+}
+
+// Open implements conn.Bind
+func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
 	bind.readQueue = make(chan *netReadInfo)
 
 	fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
@@ -109,13 +112,21 @@ func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error
 	return arr, uint16(uport), nil
 }
 
-func (bind *netBindClient) Close() error {
+// Close implements conn.Bind
+func (bind *netBind) Close() error {
 	if bind.readQueue != nil {
 		close(bind.readQueue)
 	}
 	return nil
 }
 
+type netBindClient struct {
+	netBind
+
+	dialer   internet.Dialer
+	reserved []byte
+}
+
 func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
 	c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
 	if err != nil {
@@ -177,12 +188,29 @@ func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
 	return nil
 }
 
-func (bind *netBindClient) SetMark(mark uint32) error {
-	return nil
+type netBindServer struct {
+	netBind
 }
 
-func (bind *netBindClient) BatchSize() int {
-	return 1
+func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
+	var err error
+
+	nend, ok := endpoint.(*netEndpoint)
+	if !ok {
+		return conn.ErrWrongEndpointType
+	}
+
+	if nend.conn == nil {
+		return newError("connection not open yet")
+	}
+
+	for _, buff := range buff {
+		if _, err = nend.conn.Write(buff); err != nil {
+			return err
+		}
+	}
+
+	return err
 }
 
 type netEndpoint struct {
@@ -193,7 +221,7 @@ type netEndpoint struct {
 func (netEndpoint) ClearSrc() {}
 
 func (e netEndpoint) DstIP() netip.Addr {
-	return toNetIpAddr(e.dst.Address)
+	return netip.Addr{}
 }
 
 func (e netEndpoint) SrcIP() netip.Addr {
@@ -232,83 +260,3 @@ func toNetIpAddr(addr xnet.Address) netip.Addr {
 		return netip.AddrFrom16(arr)
 	}
 }
-
-func stringsLastIndexByte(s string, b byte) int {
-	for i := len(s) - 1; i >= 0; i-- {
-		if s[i] == b {
-			return i
-		}
-	}
-	return -1
-}
-
-func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
-	i := stringsLastIndexByte(s, ':')
-	if i == -1 {
-		return "", 0, false, errors.New("not an ip:port")
-	}
-
-	ip = s[:i]
-	portStr := s[i+1:]
-	if len(ip) == 0 {
-		return "", 0, false, errors.New("no IP")
-	}
-	if len(portStr) == 0 {
-		return "", 0, false, errors.New("no port")
-	}
-	port64, err := strconv.ParseUint(portStr, 10, 16)
-	if err != nil {
-		return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s))
-	}
-	port = uint16(port64)
-	if ip[0] == '[' {
-		if len(ip) < 2 || ip[len(ip)-1] != ']' {
-			return "", 0, false, errors.New("missing ]")
-		}
-		ip = ip[1 : len(ip)-1]
-		v6 = true
-	}
-
-	return ip, port, v6, nil
-}
-
-func IsDomainName(s string) bool {
-	l := len(s)
-	if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
-		return false
-	}
-	last := byte('.')
-	nonNumeric := false
-	partlen := 0
-	for i := 0; i < len(s); i++ {
-		c := s[i]
-		switch {
-		default:
-			return false
-		case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
-			nonNumeric = true
-			partlen++
-		case '0' <= c && c <= '9':
-			partlen++
-		case c == '-':
-			if last == '.' {
-				return false
-			}
-			partlen++
-			nonNumeric = true
-		case c == '.':
-			if last == '.' || last == '-' {
-				return false
-			}
-			if partlen > 63 || partlen == 0 {
-				return false
-			}
-			partlen = 0
-		}
-		last = c
-	}
-	if last == '-' || partlen > 63 {
-		return false
-	}
-	return nonNumeric
-}

+ 255 - 0
proxy/wireguard/client.go

@@ -0,0 +1,255 @@
+/*
+
+Some of codes are copied from https://github.com/octeep/wireproxy, license below.
+
+Copyright (c) 2022 Wind T.F. Wong <[email protected]>
+
+Permission to use, copy, modify, and distribute this software for any
+purpose with or without fee is hereby granted, provided that the above
+copyright notice and this permission notice appear in all copies.
+
+THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+*/
+
+package wireguard
+
+import (
+	"context"
+	"net/netip"
+	"sync"
+
+	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
+	"github.com/xtls/xray-core/common/dice"
+	"github.com/xtls/xray-core/common/log"
+	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/protocol"
+	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/common/signal"
+	"github.com/xtls/xray-core/common/task"
+	"github.com/xtls/xray-core/core"
+	"github.com/xtls/xray-core/features/dns"
+	"github.com/xtls/xray-core/features/policy"
+	"github.com/xtls/xray-core/transport"
+	"github.com/xtls/xray-core/transport/internet"
+)
+
+// Handler is an outbound connection that silently swallow the entire payload.
+type Handler struct {
+	conf          *DeviceConfig
+	net           Tunnel
+	bind          *netBindClient
+	policyManager policy.Manager
+	dns           dns.Client
+	// cached configuration
+	ipc              string
+	endpoints        []netip.Addr
+	hasIPv4, hasIPv6 bool
+	wgLock           sync.Mutex
+}
+
+// New creates a new wireguard handler.
+func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
+	v := core.MustFromContext(ctx)
+
+	endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
+	if err != nil {
+		return nil, err
+	}
+
+	d := v.GetFeature(dns.ClientType()).(dns.Client)
+	return &Handler{
+		conf:          conf,
+		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
+		dns:           d,
+		ipc:           createIPCRequest(conf),
+		endpoints:     endpoints,
+		hasIPv4:       hasIPv4,
+		hasIPv6:       hasIPv6,
+	}, nil
+}
+
+func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
+	h.wgLock.Lock()
+	defer h.wgLock.Unlock()
+
+	if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
+		return nil
+	}
+
+	log.Record(&log.GeneralMessage{
+		Severity: log.Severity_Info,
+		Content:  "switching dialer",
+	})
+
+	if h.net != nil {
+		_ = h.net.Close()
+		h.net = nil
+	}
+	if h.bind != nil {
+		_ = h.bind.Close()
+		h.bind = nil
+	}
+
+	// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
+	bind := &netBindClient{
+		netBind: netBind{
+			dns: h.dns,
+			dnsOption: dns.IPOption{
+				IPv4Enable: h.hasIPv4,
+				IPv6Enable: h.hasIPv6,
+			},
+			workers: int(h.conf.NumWorkers),
+		},
+		dialer:   dialer,
+		reserved: h.conf.Reserved,
+	}
+	defer func() {
+		if err != nil {
+			_ = bind.Close()
+		}
+	}()
+
+	h.net, err = h.makeVirtualTun(bind)
+	if err != nil {
+		return newError("failed to create virtual tun interface").Base(err)
+	}
+	h.bind = bind
+	return nil
+}
+
+// Process implements OutboundHandler.Dispatch().
+func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
+	outbound := session.OutboundFromContext(ctx)
+	if outbound == nil || !outbound.Target.IsValid() {
+		return newError("target not specified")
+	}
+	outbound.Name = "wireguard"
+	inbound := session.InboundFromContext(ctx)
+	if inbound != nil {
+		inbound.SetCanSpliceCopy(3)
+	}
+
+	if err := h.processWireGuard(dialer); err != nil {
+		return err
+	}
+
+	// Destination of the inner request.
+	destination := outbound.Target
+	command := protocol.RequestCommandTCP
+	if destination.Network == net.Network_UDP {
+		command = protocol.RequestCommandUDP
+	}
+
+	// resolve dns
+	addr := destination.Address
+	if addr.Family().IsDomain() {
+		ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
+			IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
+			IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
+		})
+		{ // Resolve fallback
+			if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
+				ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
+					IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
+					IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
+				})
+			}
+		}
+		if err != nil {
+			return newError("failed to lookup DNS").Base(err)
+		} else if len(ips) == 0 {
+			return dns.ErrEmptyResponse
+		}
+		addr = net.IPAddress(ips[dice.Roll(len(ips))])
+	}
+
+	var newCtx context.Context
+	var newCancel context.CancelFunc
+	if session.TimeoutOnlyFromContext(ctx) {
+		newCtx, newCancel = context.WithCancel(context.Background())
+	}
+
+	p := h.policyManager.ForLevel(0)
+
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, func() {
+		cancel()
+		if newCancel != nil {
+			newCancel()
+		}
+	}, p.Timeouts.ConnectionIdle)
+	addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
+
+	var requestFunc func() error
+	var responseFunc func() error
+
+	if command == protocol.RequestCommandTCP {
+		conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
+		if err != nil {
+			return newError("failed to create TCP connection").Base(err)
+		}
+		defer conn.Close()
+
+		requestFunc = func() error {
+			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
+			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
+		}
+		responseFunc = func() error {
+			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
+			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
+		}
+	} else if command == protocol.RequestCommandUDP {
+		conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
+		if err != nil {
+			return newError("failed to create UDP connection").Base(err)
+		}
+		defer conn.Close()
+
+		requestFunc = func() error {
+			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
+			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
+		}
+		responseFunc = func() error {
+			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
+			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
+		}
+	}
+
+	if newCtx != nil {
+		ctx = newCtx
+	}
+
+	responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
+	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
+		common.Interrupt(link.Reader)
+		common.Interrupt(link.Writer)
+		return newError("connection ends").Base(err)
+	}
+
+	return nil
+}
+
+// creates a tun interface on netstack given a configuration
+func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
+	t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil)
+	if err != nil {
+		return nil, err
+	}
+
+	bind.dnsOption.IPv4Enable = h.hasIPv4
+	bind.dnsOption.IPv6Enable = h.hasIPv6
+
+	if err = t.BuildDevice(h.ipc, bind); err != nil {
+		_ = t.Close()
+		return nil, err
+	}
+	return t, nil
+}

+ 7 - 0
proxy/wireguard/config.go

@@ -23,3 +23,10 @@ func (c *DeviceConfig) fallbackIP4() bool {
 func (c *DeviceConfig) fallbackIP6() bool {
 	return c.DomainStrategy == DeviceConfig_FORCE_IP46
 }
+
+func (c *DeviceConfig) createTun() tunCreator {
+	if c.KernelMode {
+		return createKernelTun
+	}
+	return createGVisorTun
+}

+ 39 - 19
proxy/wireguard/config.pb.go

@@ -1,7 +1,7 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // versions:
-// 	protoc-gen-go v1.31.0
-// 	protoc        v4.23.1
+// 	protoc-gen-go v1.28.1
+// 	protoc        v4.25.0
 // source: proxy/wireguard/config.proto
 
 package wireguard
@@ -83,7 +83,7 @@ type PeerConfig struct {
 	PublicKey    string   `protobuf:"bytes,1,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
 	PreSharedKey string   `protobuf:"bytes,2,opt,name=pre_shared_key,json=preSharedKey,proto3" json:"pre_shared_key,omitempty"`
 	Endpoint     string   `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
-	KeepAlive    int32    `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"`
+	KeepAlive    uint32   `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"`
 	AllowedIps   []string `protobuf:"bytes,5,rep,name=allowed_ips,json=allowedIps,proto3" json:"allowed_ips,omitempty"`
 }
 
@@ -140,7 +140,7 @@ func (x *PeerConfig) GetEndpoint() string {
 	return ""
 }
 
-func (x *PeerConfig) GetKeepAlive() int32 {
+func (x *PeerConfig) GetKeepAlive() uint32 {
 	if x != nil {
 		return x.KeepAlive
 	}
@@ -166,6 +166,8 @@ type DeviceConfig struct {
 	NumWorkers     int32                       `protobuf:"varint,5,opt,name=num_workers,json=numWorkers,proto3" json:"num_workers,omitempty"`
 	Reserved       []byte                      `protobuf:"bytes,6,opt,name=reserved,proto3" json:"reserved,omitempty"`
 	DomainStrategy DeviceConfig_DomainStrategy `protobuf:"varint,7,opt,name=domain_strategy,json=domainStrategy,proto3,enum=xray.proxy.wireguard.DeviceConfig_DomainStrategy" json:"domain_strategy,omitempty"`
+	IsClient       bool                        `protobuf:"varint,8,opt,name=is_client,json=isClient,proto3" json:"is_client,omitempty"`
+	KernelMode     bool                        `protobuf:"varint,9,opt,name=kernel_mode,json=kernelMode,proto3" json:"kernel_mode,omitempty"`
 }
 
 func (x *DeviceConfig) Reset() {
@@ -249,6 +251,20 @@ func (x *DeviceConfig) GetDomainStrategy() DeviceConfig_DomainStrategy {
 	return DeviceConfig_FORCE_IP
 }
 
+func (x *DeviceConfig) GetIsClient() bool {
+	if x != nil {
+		return x.IsClient
+	}
+	return false
+}
+
+func (x *DeviceConfig) GetKernelMode() bool {
+	if x != nil {
+		return x.KernelMode
+	}
+	return false
+}
+
 var File_proxy_wireguard_config_proto protoreflect.FileDescriptor
 
 var file_proxy_wireguard_config_proto_rawDesc = []byte{
@@ -263,10 +279,10 @@ var file_proxy_wireguard_config_proto_rawDesc = []byte{
 	0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70,
 	0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x64, 0x70,
 	0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, 0x69,
-	0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c,
+	0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c,
 	0x69, 0x76, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x69,
 	0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65,
-	0x64, 0x49, 0x70, 0x73, 0x22, 0x8a, 0x03, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43,
+	0x64, 0x49, 0x70, 0x73, 0x22, 0xc8, 0x03, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43,
 	0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f,
 	0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x63, 0x72, 0x65,
 	0x74, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
@@ -285,19 +301,23 @@ var file_proxy_wireguard_config_proto_rawDesc = []byte{
 	0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x6f,
 	0x6e, 0x66, 0x69, 0x67, 0x2e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
 	0x65, 0x67, 0x79, 0x52, 0x0e, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
-	0x65, 0x67, 0x79, 0x22, 0x5c, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72,
-	0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49,
-	0x50, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34,
-	0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10,
-	0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10,
-	0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10,
-	0x04, 0x42, 0x5e, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72,
-	0x6f, 0x78, 0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a,
-	0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73,
-	0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79,
-	0x2f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61,
-	0x79, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72,
-	0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+	0x65, 0x67, 0x79, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x73, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74,
+	0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x69, 0x73, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74,
+	0x12, 0x1f, 0x0a, 0x0b, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x18,
+	0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x4d, 0x6f, 0x64,
+	0x65, 0x22, 0x5c, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
+	0x65, 0x67, 0x79, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x10,
+	0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x10, 0x01,
+	0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10, 0x02, 0x12,
+	0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10, 0x03, 0x12,
+	0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10, 0x04, 0x42,
+	0x5e, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78,
+	0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a, 0x29, 0x67,
+	0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78,
+	0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x77,
+	0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61, 0x79, 0x2e,
+	0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72, 0x64, 0x62,
+	0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
 }
 
 var (

+ 21 - 19
proxy/wireguard/config.proto

@@ -7,26 +7,28 @@ option java_package = "com.xray.proxy.wireguard";
 option java_multiple_files = true;
 
 message PeerConfig {
-    string public_key = 1;
-    string pre_shared_key = 2;
-    string endpoint = 3;
-    int32 keep_alive = 4;
-    repeated string allowed_ips = 5;
+  string public_key = 1;
+  string pre_shared_key = 2;
+  string endpoint = 3;
+  uint32 keep_alive = 4;
+  repeated string allowed_ips = 5;
 }
 
 message DeviceConfig {
-    enum DomainStrategy {
-        FORCE_IP = 0;
-        FORCE_IP4 = 1;
-        FORCE_IP6 = 2;
-        FORCE_IP46 = 3;
-        FORCE_IP64 = 4;
-    }
-    string secret_key = 1;
-    repeated string endpoint = 2;
-    repeated PeerConfig peers = 3;
-    int32 mtu = 4;
-    int32 num_workers = 5;
-    bytes reserved = 6;
-    DomainStrategy domain_strategy = 7;
+  enum DomainStrategy {
+    FORCE_IP = 0;
+    FORCE_IP4 = 1;
+    FORCE_IP6 = 2;
+    FORCE_IP46 = 3;
+    FORCE_IP64 = 4;
+  }
+  string secret_key = 1;
+  repeated string endpoint = 2;
+  repeated PeerConfig peers = 3;
+  int32 mtu = 4;
+  int32 num_workers = 5;
+  bytes reserved = 6;
+  DomainStrategy domain_strategy = 7;
+  bool is_client = 8;
+  bool kernel_mode = 9;
 }

+ 230 - 0
proxy/wireguard/gvisortun/tun.go

@@ -0,0 +1,230 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
+ */
+
+package gvisortun
+
+import (
+	"context"
+	"fmt"
+	"net/netip"
+	"os"
+	"syscall"
+
+	"golang.zx2c4.com/wireguard/tun"
+	"gvisor.dev/gvisor/pkg/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+type netTun struct {
+	ep             *channel.Endpoint
+	stack          *stack.Stack
+	events         chan tun.Event
+	incomingPacket chan *buffer.View
+	mtu            int
+	hasV4, hasV6   bool
+}
+
+type Net netTun
+
+func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) {
+	opts := stack.Options{
+		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
+		HandleLocal:        !promiscuousMode,
+	}
+	dev := &netTun{
+		ep:             channel.New(1024, uint32(mtu), ""),
+		stack:          stack.New(opts),
+		events:         make(chan tun.Event, 1),
+		incomingPacket: make(chan *buffer.View),
+		mtu:            mtu,
+	}
+	dev.ep.AddNotify(dev)
+	tcpipErr := dev.stack.CreateNIC(1, dev.ep)
+	if tcpipErr != nil {
+		return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
+	}
+	for _, ip := range localAddresses {
+		var protoNumber tcpip.NetworkProtocolNumber
+		if ip.Is4() {
+			protoNumber = ipv4.ProtocolNumber
+		} else if ip.Is6() {
+			protoNumber = ipv6.ProtocolNumber
+		}
+		protoAddr := tcpip.ProtocolAddress{
+			Protocol:          protoNumber,
+			AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
+		}
+		tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+		if tcpipErr != nil {
+			return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
+		}
+		if ip.Is4() {
+			dev.hasV4 = true
+		} else if ip.Is6() {
+			dev.hasV6 = true
+		}
+	}
+	if dev.hasV4 {
+		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
+	}
+	if dev.hasV6 {
+		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
+	}
+	if promiscuousMode {
+		// enable promiscuous mode to handle all packets processed by netstack
+		dev.stack.SetPromiscuousMode(1, true)
+		dev.stack.SetSpoofing(1, true)
+	}
+
+	opt := tcpip.CongestionControlOption("cubic")
+	if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+		return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
+	}
+
+	dev.events <- tun.EventUp
+	return dev, (*Net)(dev), dev.stack, nil
+}
+
+// BatchSize implements tun.Device
+func (tun *netTun) BatchSize() int {
+	return 1
+}
+
+// Name implements tun.Device
+func (tun *netTun) Name() (string, error) {
+	return "go", nil
+}
+
+// File implements tun.Device
+func (tun *netTun) File() *os.File {
+	return nil
+}
+
+// Events implements tun.Device
+func (tun *netTun) Events() <-chan tun.Event {
+	return tun.events
+}
+
+// Read implements tun.Device
+
+func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
+	view, ok := <-tun.incomingPacket
+	if !ok {
+		return 0, os.ErrClosed
+	}
+
+	n, err := view.Read(buf[0][offset:])
+	if err != nil {
+		return 0, err
+	}
+	sizes[0] = n
+	return 1, nil
+}
+
+// Write implements tun.Device
+func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
+	for _, buf := range buf {
+		packet := buf[offset:]
+		if len(packet) == 0 {
+			continue
+		}
+
+		pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
+		switch packet[0] >> 4 {
+		case 4:
+			tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
+		case 6:
+			tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
+		default:
+			return 0, syscall.EAFNOSUPPORT
+		}
+	}
+	return len(buf), nil
+}
+
+// WriteNotify implements channel.Notification
+func (tun *netTun) WriteNotify() {
+	pkt := tun.ep.Read()
+	if pkt.IsNil() {
+		return
+	}
+
+	view := pkt.ToView()
+	pkt.DecRef()
+
+	tun.incomingPacket <- view
+}
+
+// Flush  implements tun.Device
+func (tun *netTun) Flush() error {
+	return nil
+}
+
+// Close implements tun.Device
+func (tun *netTun) Close() error {
+	tun.stack.RemoveNIC(1)
+
+	if tun.events != nil {
+		close(tun.events)
+	}
+
+	tun.ep.Close()
+
+	if tun.incomingPacket != nil {
+		close(tun.incomingPacket)
+	}
+
+	return nil
+}
+
+// MTU  implements tun.Device
+func (tun *netTun) MTU() (int, error) {
+	return tun.mtu, nil
+}
+
+func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
+	var protoNumber tcpip.NetworkProtocolNumber
+	if endpoint.Addr().Is4() {
+		protoNumber = ipv4.ProtocolNumber
+	} else {
+		protoNumber = ipv6.ProtocolNumber
+	}
+	return tcpip.FullAddress{
+		NIC:  1,
+		Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
+		Port: endpoint.Port(),
+	}, protoNumber
+}
+
+func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
+	fa, pn := convertToFullAddr(addr)
+	return gonet.DialContextTCP(ctx, net.stack, fa, pn)
+}
+
+func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
+	var lfa, rfa *tcpip.FullAddress
+	var pn tcpip.NetworkProtocolNumber
+	if laddr.IsValid() || laddr.Port() > 0 {
+		var addr tcpip.FullAddress
+		addr, pn = convertToFullAddr(laddr)
+		lfa = &addr
+	}
+	if raddr.IsValid() || raddr.Port() > 0 {
+		var addr tcpip.FullAddress
+		addr, pn = convertToFullAddr(raddr)
+		rfa = &addr
+	}
+	return gonet.DialUDP(net.stack, lfa, rfa, pn)
+}

+ 181 - 0
proxy/wireguard/server.go

@@ -0,0 +1,181 @@
+package wireguard
+
+import (
+	"context"
+	"errors"
+	"io"
+
+	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
+	"github.com/xtls/xray-core/common/log"
+	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/common/signal"
+	"github.com/xtls/xray-core/common/task"
+	"github.com/xtls/xray-core/core"
+	"github.com/xtls/xray-core/features/dns"
+	"github.com/xtls/xray-core/features/policy"
+	"github.com/xtls/xray-core/features/routing"
+	"github.com/xtls/xray-core/transport/internet/stat"
+)
+
+var nullDestination = net.TCPDestination(net.AnyIP, 0)
+
+type Server struct {
+	bindServer *netBindServer
+
+	info          routingInfo
+	policyManager policy.Manager
+}
+
+type routingInfo struct {
+	ctx         context.Context
+	dispatcher  routing.Dispatcher
+	inboundTag  *session.Inbound
+	outboundTag *session.Outbound
+	contentTag  *session.Content
+}
+
+func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
+	v := core.MustFromContext(ctx)
+
+	endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
+	if err != nil {
+		return nil, err
+	}
+
+	server := &Server{
+		bindServer: &netBindServer{
+			netBind: netBind{
+				dns: v.GetFeature(dns.ClientType()).(dns.Client),
+				dnsOption: dns.IPOption{
+					IPv4Enable: hasIPv4,
+					IPv6Enable: hasIPv6,
+				},
+			},
+		},
+		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
+	}
+
+	tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
+	if err != nil {
+		return nil, err
+	}
+
+	if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil {
+		_ = tun.Close()
+		return nil, err
+	}
+
+	return server, nil
+}
+
+// Network implements proxy.Inbound.
+func (*Server) Network() []net.Network {
+	return []net.Network{net.Network_UDP}
+}
+
+// Process implements proxy.Inbound.
+func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
+	s.info = routingInfo{
+		ctx:         core.ToBackgroundDetachedContext(ctx),
+		dispatcher:  dispatcher,
+		inboundTag:  session.InboundFromContext(ctx),
+		outboundTag: session.OutboundFromContext(ctx),
+		contentTag:  session.ContentFromContext(ctx),
+	}
+
+	ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
+	if err != nil {
+		return err
+	}
+
+	nep := ep.(*netEndpoint)
+	nep.conn = conn
+
+	reader := buf.NewPacketReader(conn)
+	for {
+		mpayload, err := reader.ReadMultiBuffer()
+		if err != nil {
+			return err
+		}
+
+		for _, payload := range mpayload {
+			v, ok := <-s.bindServer.readQueue
+			if !ok {
+				return nil
+			}
+			i, err := payload.Read(v.buff)
+
+			v.bytes = i
+			v.endpoint = nep
+			v.err = err
+			v.waiter.Done()
+			if err != nil && errors.Is(err, io.EOF) {
+				nep.conn = nil
+				return nil
+			}
+		}
+	}
+}
+
+func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
+	if s.info.dispatcher == nil {
+		newError("unexpected: dispatcher == nil").AtError().WriteToLog()
+		return
+	}
+	defer conn.Close()
+
+	ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
+	plcy := s.policyManager.ForLevel(0)
+	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
+
+	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
+		From:   nullDestination,
+		To:     dest,
+		Status: log.AccessAccepted,
+		Reason: "",
+	})
+
+	if s.info.inboundTag != nil {
+		ctx = session.ContextWithInbound(ctx, s.info.inboundTag)
+	}
+	if s.info.outboundTag != nil {
+		ctx = session.ContextWithOutbound(ctx, s.info.outboundTag)
+	}
+	if s.info.contentTag != nil {
+		ctx = session.ContextWithContent(ctx, s.info.contentTag)
+	}
+
+	link, err := s.info.dispatcher.Dispatch(ctx, dest)
+	if err != nil {
+		newError("dispatch connection").Base(err).AtError().WriteToLog()
+	}
+	defer cancel()
+
+	requestDone := func() error {
+		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
+		if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
+			return newError("failed to transport all TCP request").Base(err)
+		}
+
+		return nil
+	}
+
+	responseDone := func() error {
+		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
+		if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
+			return newError("failed to transport all TCP response").Base(err)
+		}
+
+		return nil
+	}
+
+	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
+	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
+		common.Interrupt(link.Reader)
+		common.Interrupt(link.Writer)
+		newError("connection ends").Base(err).AtDebug().WriteToLog()
+		return
+	}
+}

+ 100 - 0
proxy/wireguard/tun.go

@@ -10,14 +10,26 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/xtls/xray-core/common/log"
+	xnet "github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/proxy/wireguard/gvisortun"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+	"gvisor.dev/gvisor/pkg/waiter"
 
 	"golang.zx2c4.com/wireguard/conn"
 	"golang.zx2c4.com/wireguard/device"
 	"golang.zx2c4.com/wireguard/tun"
 )
 
+type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error)
+
+type promiscuousModeHandler func(dest xnet.Destination, conn net.Conn)
+
 type Tunnel interface {
 	BuildDevice(ipc string, bind conn.Bind) error
 	DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
@@ -103,3 +115,91 @@ func CalculateInterfaceName(name string) (tunName string) {
 	tunName = fmt.Sprintf("%s%d", tunName, tunIndex)
 	return
 }
+
+var _ Tunnel = (*gvisorNet)(nil)
+
+type gvisorNet struct {
+	tunnel
+	net *gvisortun.Net
+}
+
+func (g *gvisorNet) Close() error {
+	return g.tunnel.Close()
+}
+
+func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
+	net.Conn, error,
+) {
+	return g.net.DialContextTCPAddrPort(ctx, addr)
+}
+
+func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
+	return g.net.DialUDPAddrPort(laddr, raddr)
+}
+
+func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
+	out := &gvisorNet{}
+	tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
+	if err != nil {
+		return nil, err
+	}
+
+	if handler != nil {
+		// handler is only used for promiscuous mode
+		// capture all packets and send to handler
+
+		tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
+			go func(r *tcp.ForwarderRequest) {
+				var (
+					wq waiter.Queue
+					id = r.ID()
+				)
+
+				// Perform a TCP three-way handshake.
+				ep, err := r.CreateEndpoint(&wq)
+				if err != nil {
+					newError(err.String()).AtError().WriteToLog()
+					r.Complete(true)
+					return
+				}
+				r.Complete(false)
+				defer ep.Close()
+
+				// enable tcp keep-alive to prevent hanging connections
+				ep.SocketOptions().SetKeepAlive(true)
+
+				// local address is actually destination
+				handler(xnet.TCPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
+			}(r)
+		})
+		stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
+
+		udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) {
+			go func(r *udp.ForwarderRequest) {
+				var (
+					wq waiter.Queue
+					id = r.ID()
+				)
+
+				ep, err := r.CreateEndpoint(&wq)
+				if err != nil {
+					newError(err.String()).AtError().WriteToLog()
+					return
+				}
+				defer ep.Close()
+
+				// prevents hanging connections and ensure timely release
+				ep.SocketOptions().SetLinger(tcpip.LingerOption{
+					Enabled: true,
+					Timeout: 15 * time.Second,
+				})
+
+				handler(xnet.UDPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewUDPConn(stack, &wq, ep))
+			}(r)
+		})
+		stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
+	}
+
+	out.tun, out.net = tun, n
+	return out, nil
+}

+ 6 - 32
proxy/wireguard/tun_default.go

@@ -1,42 +1,16 @@
-//go:build !linux
+//go:build !linux || android
 
 package wireguard
 
 import (
-	"context"
-	"net"
+	"errors"
 	"net/netip"
-
-	"golang.zx2c4.com/wireguard/tun/netstack"
 )
 
-var _ Tunnel = (*gvisorNet)(nil)
-
-type gvisorNet struct {
-	tunnel
-	net *netstack.Net
-}
-
-func (g *gvisorNet) Close() error {
-	return g.tunnel.Close()
-}
-
-func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
-	net.Conn, error,
-) {
-	return g.net.DialContextTCPAddrPort(ctx, addr)
-}
-
-func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
-	return g.net.DialUDPAddrPort(laddr, raddr)
+func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
+	return nil, errors.New("not implemented")
 }
 
-func CreateTun(localAddresses []netip.Addr, mtu int) (Tunnel, error) {
-	out := &gvisorNet{}
-	tun, n, err := netstack.CreateNetTUN(localAddresses, nil, mtu)
-	if err != nil {
-		return nil, err
-	}
-	out.tun, out.net = tun, n
-	return out, nil
+func KernelTunSupported() bool {
+	return false
 }

+ 15 - 1
proxy/wireguard/tun_linux.go

@@ -1,3 +1,5 @@
+//go:build linux && !android
+
 package wireguard
 
 import (
@@ -69,7 +71,11 @@ func (d *deviceNet) Close() (err error) {
 	return errors.Join(errs...)
 }
 
-func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) {
+func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
+	if handler != nil {
+		return nil, newError("TODO: support promiscuous mode")
+	}
+
 	var v4, v6 *netip.Addr
 	for _, prefixes := range localAddresses {
 		if v4 == nil && prefixes.Is4() {
@@ -221,3 +227,11 @@ func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) {
 	out.tun = wgt
 	return out, nil
 }
+
+func KernelTunSupported() bool {
+	// run a superuser permission check to check
+	// if the current user has the sufficient permission
+	// to create a tun device.
+
+	return unix.Geteuid() == 0 // 0 means root
+}

+ 64 - 279
proxy/wireguard/wireguard.go

@@ -1,326 +1,111 @@
-/*
-
-Some of codes are copied from https://github.com/octeep/wireproxy, license below.
-
-Copyright (c) 2022 Wind T.F. Wong <[email protected]>
-
-Permission to use, copy, modify, and distribute this software for any
-purpose with or without fee is hereby granted, provided that the above
-copyright notice and this permission notice appear in all copies.
-
-THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
-WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
-MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
-ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
-WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
-ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
-OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-
-*/
-
 package wireguard
 
 import (
-	"bytes"
 	"context"
 	"fmt"
-	stdnet "net"
 	"net/netip"
 	"strings"
-	"sync"
 
 	"github.com/xtls/xray-core/common"
-	"github.com/xtls/xray-core/common/buf"
-	"github.com/xtls/xray-core/common/dice"
 	"github.com/xtls/xray-core/common/log"
-	"github.com/xtls/xray-core/common/net"
-	"github.com/xtls/xray-core/common/protocol"
-	"github.com/xtls/xray-core/common/session"
-	"github.com/xtls/xray-core/common/signal"
-	"github.com/xtls/xray-core/common/task"
-	"github.com/xtls/xray-core/core"
-	"github.com/xtls/xray-core/features/dns"
-	"github.com/xtls/xray-core/features/policy"
-	"github.com/xtls/xray-core/transport"
-	"github.com/xtls/xray-core/transport/internet"
+	"golang.zx2c4.com/wireguard/device"
 )
 
-// Handler is an outbound connection that silently swallow the entire payload.
-type Handler struct {
-	conf          *DeviceConfig
-	net           Tunnel
-	bind          *netBindClient
-	policyManager policy.Manager
-	dns           dns.Client
-	// cached configuration
-	ipc              string
-	endpoints        []netip.Addr
-	hasIPv4, hasIPv6 bool
-	wgLock           sync.Mutex
-}
-
-// New creates a new wireguard handler.
-func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
-	v := core.MustFromContext(ctx)
-
-	endpoints, err := parseEndpoints(conf)
-	if err != nil {
-		return nil, err
-	}
-
-	hasIPv4, hasIPv6 := false, false
-	for _, e := range endpoints {
-		if e.Is4() {
-			hasIPv4 = true
-		}
-		if e.Is6() {
-			hasIPv6 = true
-		}
-	}
-
-	d := v.GetFeature(dns.ClientType()).(dns.Client)
-	return &Handler{
-		conf:          conf,
-		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
-		dns:           d,
-		ipc:           createIPCRequest(conf, d, hasIPv6),
-		endpoints:     endpoints,
-		hasIPv4:       hasIPv4,
-		hasIPv6:       hasIPv6,
-	}, nil
-}
-
-func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
-	h.wgLock.Lock()
-	defer h.wgLock.Unlock()
-
-	if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
-		return nil
-	}
-
-	log.Record(&log.GeneralMessage{
-		Severity: log.Severity_Info,
-		Content:  "switching dialer",
-	})
-
-	if h.net != nil {
-		_ = h.net.Close()
-		h.net = nil
-	}
-	if h.bind != nil {
-		_ = h.bind.Close()
-		h.bind = nil
-	}
-
-	// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
-	bind := &netBindClient{
-		dialer:   dialer,
-		workers:  int(h.conf.NumWorkers),
-		dns:      h.dns,
-		reserved: h.conf.Reserved,
-	}
-	defer func() {
-		if err != nil {
-			_ = bind.Close()
-		}
-	}()
-
-	h.net, err = h.makeVirtualTun(bind)
-	if err != nil {
-		return newError("failed to create virtual tun interface").Base(err)
-	}
-	h.bind = bind
-	return nil
-}
-
-// Process implements OutboundHandler.Dispatch().
-func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
-	outbound := session.OutboundFromContext(ctx)
-	if outbound == nil || !outbound.Target.IsValid() {
-		return newError("target not specified")
-	}
-	outbound.Name = "wireguard"
-	inbound := session.InboundFromContext(ctx)
-	if inbound != nil {
-		inbound.SetCanSpliceCopy(3)
-	}
-
-	if err := h.processWireGuard(dialer); err != nil {
-		return err
-	}
-
-	// Destination of the inner request.
-	destination := outbound.Target
-	command := protocol.RequestCommandTCP
-	if destination.Network == net.Network_UDP {
-		command = protocol.RequestCommandUDP
-	}
+//go:generate go run github.com/xtls/xray-core/common/errors/errorgen
 
-	// resolve dns
-	addr := destination.Address
-	if addr.Family().IsDomain() {
-		ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
-			IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
-			IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
+var wgLogger = &device.Logger{
+	Verbosef: func(format string, args ...any) {
+		log.Record(&log.GeneralMessage{
+			Severity: log.Severity_Debug,
+			Content:  fmt.Sprintf(format, args...),
 		})
-		{ // Resolve fallback
-			if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
-				ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
-					IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
-					IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
-				})
-			}
-		}
-		if err != nil {
-			return newError("failed to lookup DNS").Base(err)
-		} else if len(ips) == 0 {
-			return dns.ErrEmptyResponse
-		}
-		addr = net.IPAddress(ips[dice.Roll(len(ips))])
-	}
-
-	var newCtx context.Context
-	var newCancel context.CancelFunc
-	if session.TimeoutOnlyFromContext(ctx) {
-		newCtx, newCancel = context.WithCancel(context.Background())
-	}
-
-	p := h.policyManager.ForLevel(0)
-
-	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, func() {
-		cancel()
-		if newCancel != nil {
-			newCancel()
-		}
-	}, p.Timeouts.ConnectionIdle)
-	addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
-
-	var requestFunc func() error
-	var responseFunc func() error
-
-	if command == protocol.RequestCommandTCP {
-		conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
-		if err != nil {
-			return newError("failed to create TCP connection").Base(err)
-		}
-		defer conn.Close()
-
-		requestFunc = func() error {
-			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
-			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
-		}
-		responseFunc = func() error {
-			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
-			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
-		}
-	} else if command == protocol.RequestCommandUDP {
-		conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
-		if err != nil {
-			return newError("failed to create UDP connection").Base(err)
-		}
-		defer conn.Close()
-
-		requestFunc = func() error {
-			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
-			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
-		}
-		responseFunc = func() error {
-			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
-			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
-		}
-	}
-
-	if newCtx != nil {
-		ctx = newCtx
-	}
-
-	responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
-	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
-		common.Interrupt(link.Reader)
-		common.Interrupt(link.Writer)
-		return newError("connection ends").Base(err)
-	}
-
-	return nil
+	},
+	Errorf: func(format string, args ...any) {
+		log.Record(&log.GeneralMessage{
+			Severity: log.Severity_Error,
+			Content:  fmt.Sprintf(format, args...),
+		})
+	},
 }
 
-// serialize the config into an IPC request
-func createIPCRequest(conf *DeviceConfig, d dns.Client, resolveEndPointToV4 bool) string {
-	var request bytes.Buffer
-
-	request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
-
-	for _, peer := range conf.Peers {
-		endpoint := peer.Endpoint
-		host, port, err := net.SplitHostPort(endpoint)
-		if resolveEndPointToV4 && err == nil {
-			_, err = netip.ParseAddr(host)
-			if err != nil {
-				ipList, err := d.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: false})
-				if err == nil && len(ipList) > 0 {
-					endpoint = stdnet.JoinHostPort(ipList[0].String(), port)
-				}
-			}
-		}
-
-		request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n",
-			peer.PublicKey, endpoint, peer.KeepAlive, peer.PreSharedKey))
-
-		for _, ip := range peer.AllowedIps {
-			request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
+func init() {
+	common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
+		deviceConfig := config.(*DeviceConfig)
+		if deviceConfig.IsClient {
+			return New(ctx, deviceConfig)
+		} else {
+			return NewServer(ctx, deviceConfig)
 		}
-	}
-
-	return request.String()[:request.Len()]
+	}))
 }
 
 // convert endpoint string to netip.Addr
-func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) {
+func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, bool, bool, error) {
+	var hasIPv4, hasIPv6 bool
+
 	endpoints := make([]netip.Addr, len(conf.Endpoint))
 	for i, str := range conf.Endpoint {
 		var addr netip.Addr
 		if strings.Contains(str, "/") {
 			prefix, err := netip.ParsePrefix(str)
 			if err != nil {
-				return nil, err
+				return nil, false, false, err
 			}
 			addr = prefix.Addr()
 			if prefix.Bits() != addr.BitLen() {
-				return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
+				return nil, false, false, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
 			}
 		} else {
 			var err error
 			addr, err = netip.ParseAddr(str)
 			if err != nil {
-				return nil, err
+				return nil, false, false, err
 			}
 		}
 		endpoints[i] = addr
+
+		if addr.Is4() {
+			hasIPv4 = true
+		} else if addr.Is6() {
+			hasIPv6 = true
+		}
 	}
 
-	return endpoints, nil
+	return endpoints, hasIPv4, hasIPv6, nil
 }
 
-// creates a tun interface on netstack given a configuration
-func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
-	t, err := CreateTun(h.endpoints, int(h.conf.Mtu))
-	if err != nil {
-		return nil, err
+// serialize the config into an IPC request
+func createIPCRequest(conf *DeviceConfig) string {
+	var request strings.Builder
+
+	request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
+
+	if !conf.IsClient {
+		// placeholder, we'll handle actual port listening on Xray
+		request.WriteString("listen_port=1337\n")
 	}
 
-	bind.dnsOption.IPv4Enable = h.hasIPv4
-	bind.dnsOption.IPv6Enable = h.hasIPv6
+	for _, peer := range conf.Peers {
+		if peer.PublicKey != "" {
+			request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey))
+		}
 
-	if err = t.BuildDevice(h.ipc, bind); err != nil {
-		_ = t.Close()
-		return nil, err
+		if peer.PreSharedKey != "" {
+			request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey))
+		}
+
+		if peer.Endpoint != "" {
+			request.WriteString(fmt.Sprintf("endpoint=%s\n", peer.Endpoint))
+		}
+
+		for _, ip := range peer.AllowedIps {
+			request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
+		}
+
+		if peer.KeepAlive != 0 {
+			request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive))
+		}
 	}
-	return t, nil
-}
 
-func init() {
-	common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
-		return New(ctx, config.(*DeviceConfig))
-	}))
+	return request.String()[:request.Len()]
 }