瀏覽代碼

Add option for custom wireguard reserved bytes

世界 3 年之前
父節點
當前提交
35886b88d7
共有 3 個文件被更改,包括 22 次插入2 次删除
  1. 1 0
      option/wireguard.go
  2. 8 1
      outbound/wireguard.go
  3. 13 1
      transport/wireguard/client_bind.go

+ 1 - 0
option/wireguard.go

@@ -9,6 +9,7 @@ type WireGuardOutboundOptions struct {
 	PrivateKey      string                 `json:"private_key"`
 	PeerPublicKey   string                 `json:"peer_public_key"`
 	PreSharedKey    string                 `json:"pre_shared_key,omitempty"`
+	Reserved        []uint8                `json:"reserved,omitempty"`
 	MTU             uint32                 `json:"mtu,omitempty"`
 	Network         NetworkList            `json:"network,omitempty"`
 }

+ 8 - 1
outbound/wireguard.go

@@ -45,8 +45,15 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
 			tag:      tag,
 		},
 	}
+	var reserved [3]uint8
+	if len(options.Reserved) > 0 {
+		if len(options.Reserved) != 3 {
+			return nil, E.New("invalid reserved value, required 3 bytes, got ", len(options.Reserved))
+		}
+		copy(reserved[:], options.Reserved)
+	}
 	peerAddr := options.ServerOptions.Build()
-	outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), peerAddr)
+	outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), peerAddr, reserved)
 	localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build)
 	if len(localPrefixes) == 0 {
 		return nil, E.New("missing local address")

+ 13 - 1
transport/wireguard/client_bind.go

@@ -18,16 +18,18 @@ type ClientBind struct {
 	ctx        context.Context
 	dialer     N.Dialer
 	peerAddr   M.Socksaddr
+	reserved   [3]uint8
 	connAccess sync.Mutex
 	conn       *wireConn
 	done       chan struct{}
 }
 
-func NewClientBind(ctx context.Context, dialer N.Dialer, peerAddr M.Socksaddr) *ClientBind {
+func NewClientBind(ctx context.Context, dialer N.Dialer, peerAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
 	return &ClientBind{
 		ctx:      ctx,
 		dialer:   dialer,
 		peerAddr: peerAddr,
+		reserved: reserved,
 	}
 }
 
@@ -89,6 +91,11 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
 		}
 		return
 	}
+	if n > 3 {
+		b[1] = 0
+		b[2] = 0
+		b[3] = 0
+	}
 	ep = Endpoint(c.peerAddr)
 	return
 }
@@ -119,6 +126,11 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
 	if err != nil {
 		return err
 	}
+	if len(b) > 3 {
+		b[1] = c.reserved[0]
+		b[2] = c.reserved[1]
+		b[3] = c.reserved[2]
+	}
 	_, err = udpConn.Write(b)
 	if err != nil {
 		udpConn.Close()