Browse Source

Improve multiplexer

世界 3 years ago
parent
commit
03890151d7

+ 28 - 27
common/mux/client.go

@@ -15,40 +15,44 @@ import (
 	M "github.com/sagernet/sing/common/metadata"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/common/x/list"
 	"github.com/sagernet/sing/common/x/list"
-
-	"github.com/hashicorp/yamux"
 )
 )
 
 
 var _ N.Dialer = (*Client)(nil)
 var _ N.Dialer = (*Client)(nil)
 
 
 type Client struct {
 type Client struct {
 	access         sync.Mutex
 	access         sync.Mutex
-	connections    list.List[*yamux.Session]
+	connections    list.List[abstractSession]
 	ctx            context.Context
 	ctx            context.Context
 	dialer         N.Dialer
 	dialer         N.Dialer
+	protocol       Protocol
 	maxConnections int
 	maxConnections int
 	minStreams     int
 	minStreams     int
 	maxStreams     int
 	maxStreams     int
 }
 }
 
 
-func NewClient(ctx context.Context, dialer N.Dialer, maxConnections int, minStreams int, maxStreams int) *Client {
+func NewClient(ctx context.Context, dialer N.Dialer, protocol Protocol, maxConnections int, minStreams int, maxStreams int) *Client {
 	return &Client{
 	return &Client{
 		ctx:            ctx,
 		ctx:            ctx,
 		dialer:         dialer,
 		dialer:         dialer,
+		protocol:       protocol,
 		maxConnections: maxConnections,
 		maxConnections: maxConnections,
 		minStreams:     minStreams,
 		minStreams:     minStreams,
 		maxStreams:     maxStreams,
 		maxStreams:     maxStreams,
 	}
 	}
 }
 }
 
 
-func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) N.Dialer {
+func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) (N.Dialer, error) {
 	if !options.Enabled {
 	if !options.Enabled {
-		return dialer
+		return dialer, nil
 	}
 	}
 	if options.MaxConnections == 0 && options.MaxStreams == 0 {
 	if options.MaxConnections == 0 && options.MaxStreams == 0 {
 		options.MinStreams = 8
 		options.MinStreams = 8
 	}
 	}
-	return NewClient(ctx, dialer, options.MaxConnections, options.MinStreams, options.MaxStreams)
+	protocol, err := ParseProtocol(options.Protocol)
+	if err != nil {
+		return nil, err
+	}
+	return NewClient(ctx, dialer, protocol, options.MaxConnections, options.MinStreams, options.MaxStreams), nil
 }
 }
 
 
 func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
 func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
@@ -80,8 +84,8 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
 
 
 func (c *Client) openStream() (net.Conn, error) {
 func (c *Client) openStream() (net.Conn, error) {
 	var (
 	var (
-		session *yamux.Session
-		stream  *yamux.Stream
+		session abstractSession
+		stream  net.Conn
 		err     error
 		err     error
 	)
 	)
 	for attempts := 0; attempts < 2; attempts++ {
 	for attempts := 0; attempts < 2; attempts++ {
@@ -89,7 +93,7 @@ func (c *Client) openStream() (net.Conn, error) {
 		if err != nil {
 		if err != nil {
 			continue
 			continue
 		}
 		}
-		stream, err = session.OpenStream()
+		stream, err = session.Open()
 		if err != nil {
 		if err != nil {
 			continue
 			continue
 		}
 		}
@@ -101,11 +105,11 @@ func (c *Client) openStream() (net.Conn, error) {
 	return &wrapStream{stream}, nil
 	return &wrapStream{stream}, nil
 }
 }
 
 
-func (c *Client) offer() (*yamux.Session, error) {
+func (c *Client) offer() (abstractSession, error) {
 	c.access.Lock()
 	c.access.Lock()
 	defer c.access.Unlock()
 	defer c.access.Unlock()
 
 
-	sessions := make([]*yamux.Session, 0, c.maxConnections)
+	sessions := make([]abstractSession, 0, c.maxConnections)
 	for element := c.connections.Front(); element != nil; {
 	for element := c.connections.Front(); element != nil; {
 		if element.Value.IsClosed() {
 		if element.Value.IsClosed() {
 			nextElement := element.Next()
 			nextElement := element.Next()
@@ -120,10 +124,7 @@ func (c *Client) offer() (*yamux.Session, error) {
 	if sLen == 0 {
 	if sLen == 0 {
 		return c.offerNew()
 		return c.offerNew()
 	}
 	}
-	// session := common.MinBy(sessions, yamux.Session.NumStreams)
-	session := common.MinBy(sessions, func(it *yamux.Session) int {
-		return it.NumStreams()
-	})
+	session := common.MinBy(sessions, abstractSession.NumStreams)
 	numStreams := session.NumStreams()
 	numStreams := session.NumStreams()
 	if numStreams == 0 {
 	if numStreams == 0 {
 		return session, nil
 		return session, nil
@@ -140,12 +141,12 @@ func (c *Client) offer() (*yamux.Session, error) {
 	return c.offerNew()
 	return c.offerNew()
 }
 }
 
 
-func (c *Client) offerNew() (*yamux.Session, error) {
+func (c *Client) offerNew() (abstractSession, error) {
 	conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, Destination)
 	conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, Destination)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	session, err := yamux.Client(conn, newMuxConfig())
+	session, err := c.protocol.newClient(&protocolConn{Conn: conn, protocol: c.protocol})
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -170,7 +171,7 @@ type ClientConn struct {
 }
 }
 
 
 func (c *ClientConn) readResponse() error {
 func (c *ClientConn) readResponse() error {
-	response, err := ReadResponse(c.Conn)
+	response, err := ReadStreamResponse(c.Conn)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -195,7 +196,7 @@ func (c *ClientConn) Write(b []byte) (n int, err error) {
 	if c.requestWrite {
 	if c.requestWrite {
 		return c.Conn.Write(b)
 		return c.Conn.Write(b)
 	}
 	}
-	request := Request{
+	request := StreamRequest{
 		Network:     N.NetworkTCP,
 		Network:     N.NetworkTCP,
 		Destination: c.destination,
 		Destination: c.destination,
 	}
 	}
@@ -203,7 +204,7 @@ func (c *ClientConn) Write(b []byte) (n int, err error) {
 	defer common.KeepAlive(_buffer)
 	defer common.KeepAlive(_buffer)
 	buffer := common.Dup(_buffer)
 	buffer := common.Dup(_buffer)
 	defer buffer.Release()
 	defer buffer.Release()
-	EncodeRequest(request, buffer)
+	EncodeStreamRequest(request, buffer)
 	buffer.Write(b)
 	buffer.Write(b)
 	_, err = c.Conn.Write(buffer.Bytes())
 	_, err = c.Conn.Write(buffer.Bytes())
 	if err != nil {
 	if err != nil {
@@ -255,7 +256,7 @@ type ClientPacketConn struct {
 }
 }
 
 
 func (c *ClientPacketConn) readResponse() error {
 func (c *ClientPacketConn) readResponse() error {
-	response, err := ReadResponse(c.ExtendedConn)
+	response, err := ReadStreamResponse(c.ExtendedConn)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -285,7 +286,7 @@ func (c *ClientPacketConn) Read(b []byte) (n int, err error) {
 }
 }
 
 
 func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
 func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
-	request := Request{
+	request := StreamRequest{
 		Network:     N.NetworkUDP,
 		Network:     N.NetworkUDP,
 		Destination: c.destination,
 		Destination: c.destination,
 	}
 	}
@@ -297,7 +298,7 @@ func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
 	defer common.KeepAlive(_buffer)
 	defer common.KeepAlive(_buffer)
 	buffer := common.Dup(_buffer)
 	buffer := common.Dup(_buffer)
 	defer buffer.Release()
 	defer buffer.Release()
-	EncodeRequest(request, buffer)
+	EncodeStreamRequest(request, buffer)
 	if len(payload) > 0 {
 	if len(payload) > 0 {
 		common.Must(
 		common.Must(
 			binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
 			binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
@@ -363,7 +364,7 @@ type ClientPacketAddrConn struct {
 }
 }
 
 
 func (c *ClientPacketAddrConn) readResponse() error {
 func (c *ClientPacketAddrConn) readResponse() error {
-	response, err := ReadResponse(c.ExtendedConn)
+	response, err := ReadStreamResponse(c.ExtendedConn)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -399,7 +400,7 @@ func (c *ClientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
 }
 }
 
 
 func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) {
 func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) {
-	request := Request{
+	request := StreamRequest{
 		Network:     N.NetworkUDP,
 		Network:     N.NetworkUDP,
 		Destination: c.destination,
 		Destination: c.destination,
 		PacketAddr:  true,
 		PacketAddr:  true,
@@ -412,7 +413,7 @@ func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa
 	defer common.KeepAlive(_buffer)
 	defer common.KeepAlive(_buffer)
 	buffer := common.Dup(_buffer)
 	buffer := common.Dup(_buffer)
 	defer buffer.Release()
 	defer buffer.Release()
-	EncodeRequest(request, buffer)
+	EncodeStreamRequest(request, buffer)
 	if len(payload) > 0 {
 	if len(payload) > 0 {
 		common.Must(
 		common.Must(
 			M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
 			M.SocksaddrSerializer.WriteAddrPort(buffer, destination),

+ 99 - 17
common/mux/protocol.go

@@ -14,6 +14,7 @@ import (
 	"github.com/sagernet/sing/common/rw"
 	"github.com/sagernet/sing/common/rw"
 
 
 	"github.com/hashicorp/yamux"
 	"github.com/hashicorp/yamux"
+	"github.com/xtaci/smux"
 )
 )
 
 
 var Destination = M.Socksaddr{
 var Destination = M.Socksaddr{
@@ -21,7 +22,55 @@ var Destination = M.Socksaddr{
 	Port: 444,
 	Port: 444,
 }
 }
 
 
-func newMuxConfig() *yamux.Config {
+const (
+	ProtocolYAMux Protocol = 0
+	ProtocolSMux  Protocol = 1
+)
+
+type Protocol byte
+
+func ParseProtocol(name string) (Protocol, error) {
+	switch name {
+	case "", "yamux":
+		return ProtocolYAMux, nil
+	case "smux":
+		return ProtocolSMux, nil
+	default:
+		return ProtocolYAMux, E.New("unknown multiplex protocol: ", name)
+	}
+}
+
+func (p Protocol) newServer(conn net.Conn) (abstractSession, error) {
+	switch p {
+	case ProtocolYAMux:
+		return yamux.Server(conn, yaMuxConfig())
+	case ProtocolSMux:
+		session, err := smux.Server(conn, nil)
+		if err != nil {
+			return nil, err
+		}
+		return &smuxSession{session}, nil
+	default:
+		panic("unknown protocol")
+	}
+}
+
+func (p Protocol) newClient(conn net.Conn) (abstractSession, error) {
+	switch p {
+	case ProtocolYAMux:
+		return yamux.Client(conn, yaMuxConfig())
+	case ProtocolSMux:
+		session, err := smux.Client(conn, nil)
+		if err != nil {
+			return nil, err
+		}
+		return &smuxSession{session}, nil
+	default:
+		panic("unknown protocol")
+	}
+}
+
+func yaMuxConfig() *yamux.Config {
 	config := yamux.DefaultConfig()
 	config := yamux.DefaultConfig()
 	config.LogOutput = io.Discard
 	config.LogOutput = io.Discard
 	config.StreamCloseTimeout = C.TCPTimeout
 	config.StreamCloseTimeout = C.TCPTimeout
@@ -29,18 +78,23 @@ func newMuxConfig() *yamux.Config {
 	return config
 	return config
 }
 }
 
 
+func (p Protocol) String() string {
+	switch p {
+	case ProtocolYAMux:
+		return "yamux"
+	case ProtocolSMux:
+		return "smux"
+	default:
+		return "unknown"
+	}
+}
+
 const (
 const (
-	version0      = 0
-	flagUDP       = 1
-	flagAddr      = 2
-	statusSuccess = 0
-	statusError   = 1
+	version0 = 0
 )
 )
 
 
 type Request struct {
 type Request struct {
-	Network     string
-	Destination M.Socksaddr
-	PacketAddr  bool
+	Protocol Protocol
 }
 }
 
 
 func ReadRequest(reader io.Reader) (*Request, error) {
 func ReadRequest(reader io.Reader) (*Request, error) {
@@ -51,8 +105,37 @@ func ReadRequest(reader io.Reader) (*Request, error) {
 	if version != version0 {
 	if version != version0 {
 		return nil, E.New("unsupported version: ", version)
 		return nil, E.New("unsupported version: ", version)
 	}
 	}
+	protocol, err := rw.ReadByte(reader)
+	if err != nil {
+		return nil, err
+	}
+	if protocol > byte(ProtocolSMux) {
+		return nil, E.New("unsupported protocol: ", protocol)
+	}
+	return &Request{Protocol: Protocol(protocol)}, nil
+}
+
+func EncodeRequest(buffer *buf.Buffer, request Request) {
+	buffer.WriteByte(version0)
+	buffer.WriteByte(byte(request.Protocol))
+}
+
+const (
+	flagUDP       = 1
+	flagAddr      = 2
+	statusSuccess = 0
+	statusError   = 1
+)
+
+type StreamRequest struct {
+	Network     string
+	Destination M.Socksaddr
+	PacketAddr  bool
+}
+
+func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) {
 	var flags uint16
 	var flags uint16
-	err = binary.Read(reader, binary.BigEndian, &flags)
+	err := binary.Read(reader, binary.BigEndian, &flags)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -68,10 +151,10 @@ func ReadRequest(reader io.Reader) (*Request, error) {
 		network = N.NetworkUDP
 		network = N.NetworkUDP
 		udpAddr = flags&flagAddr != 0
 		udpAddr = flags&flagAddr != 0
 	}
 	}
-	return &Request{network, destination, udpAddr}, nil
+	return &StreamRequest{network, destination, udpAddr}, nil
 }
 }
 
 
-func requestLen(request Request) int {
+func requestLen(request StreamRequest) int {
 	var rLen int
 	var rLen int
 	rLen += 1 // version
 	rLen += 1 // version
 	rLen += 2 // flags
 	rLen += 2 // flags
@@ -79,7 +162,7 @@ func requestLen(request Request) int {
 	return rLen
 	return rLen
 }
 }
 
 
-func EncodeRequest(request Request, buffer *buf.Buffer) {
+func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) {
 	destination := request.Destination
 	destination := request.Destination
 	var flags uint16
 	var flags uint16
 	if request.Network == N.NetworkUDP {
 	if request.Network == N.NetworkUDP {
@@ -92,19 +175,18 @@ func EncodeRequest(request Request, buffer *buf.Buffer) {
 		}
 		}
 	}
 	}
 	common.Must(
 	common.Must(
-		buffer.WriteByte(version0),
 		binary.Write(buffer, binary.BigEndian, flags),
 		binary.Write(buffer, binary.BigEndian, flags),
 		M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
 		M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
 	)
 	)
 }
 }
 
 
-type Response struct {
+type StreamResponse struct {
 	Status  uint8
 	Status  uint8
 	Message string
 	Message string
 }
 }
 
 
-func ReadResponse(reader io.Reader) (*Response, error) {
-	var response Response
+func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) {
+	var response StreamResponse
 	status, err := rw.ReadByte(reader)
 	status, err := rw.ReadByte(reader)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err

+ 6 - 4
common/mux/service.go

@@ -14,12 +14,14 @@ import (
 	M "github.com/sagernet/sing/common/metadata"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/common/rw"
 	"github.com/sagernet/sing/common/rw"
-
-	"github.com/hashicorp/yamux"
 )
 )
 
 
 func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, conn net.Conn, metadata adapter.InboundContext) error {
 func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, conn net.Conn, metadata adapter.InboundContext) error {
-	session, err := yamux.Server(conn, newMuxConfig())
+	request, err := ReadRequest(conn)
+	if err != nil {
+		return err
+	}
+	session, err := request.Protocol.newServer(conn)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -34,7 +36,7 @@ func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Ha
 
 
 func newConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, stream net.Conn, metadata adapter.InboundContext) {
 func newConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, stream net.Conn, metadata adapter.InboundContext) {
 	stream = &wrapStream{stream}
 	stream = &wrapStream{stream}
-	request, err := ReadRequest(stream)
+	request, err := ReadStreamRequest(stream)
 	if err != nil {
 	if err != nil {
 		logger.ErrorContext(ctx, err)
 		logger.ErrorContext(ctx, err)
 		return
 		return

+ 71 - 0
common/mux/session.go

@@ -0,0 +1,71 @@
+package mux
+
+import (
+	"io"
+	"net"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/bufio"
+
+	"github.com/xtaci/smux"
+)
+
+type abstractSession interface {
+	Open() (net.Conn, error)
+	Accept() (net.Conn, error)
+	NumStreams() int
+	Close() error
+	IsClosed() bool
+}
+
+var _ abstractSession = (*smuxSession)(nil)
+
+type smuxSession struct {
+	*smux.Session
+}
+
+func (s *smuxSession) Open() (net.Conn, error) {
+	return s.OpenStream()
+}
+
+func (s *smuxSession) Accept() (net.Conn, error) {
+	return s.AcceptStream()
+}
+
+type protocolConn struct {
+	net.Conn
+	protocol        Protocol
+	protocolWritten bool
+}
+
+func (c *protocolConn) Write(p []byte) (n int, err error) {
+	if c.protocolWritten {
+		return c.Conn.Write(p)
+	}
+	_buffer := buf.StackNewSize(2 + len(p))
+	defer common.KeepAlive(_buffer)
+	buffer := common.Dup(_buffer)
+	defer buffer.Release()
+	EncodeRequest(buffer, Request{
+		Protocol: c.protocol,
+	})
+	common.Must(common.Error(buffer.Write(p)))
+	n, err = c.Conn.Write(buffer.Bytes())
+	if err == nil {
+		n--
+	}
+	c.protocolWritten = true
+	return n, err
+}
+
+func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) {
+	if !c.protocolWritten {
+		return bufio.ReadFrom0(c, r)
+	}
+	return bufio.Copy(c.Conn, r)
+}
+
+func (c *protocolConn) Upstream() any {
+	return c.Conn
+}

+ 12 - 0
docs/configuration/shared/multiplex.md

@@ -7,6 +7,7 @@
 ```json
 ```json
 {
 {
   "enabled": true,
   "enabled": true,
+  "protocol": "yamux",
   "max_connections": 4,
   "max_connections": 4,
   "min_streams": 4,
   "min_streams": 4,
   "max_streams": 0
   "max_streams": 0
@@ -19,6 +20,17 @@
 
 
 Enable multiplex.
 Enable multiplex.
 
 
+#### protocol
+
+Multiplex protocol.
+
+| Protocol | Description                        |
+|----------|------------------------------------|
+| yamux    | https://github.com/hashicorp/yamux |
+| smux     | https://github.com/xtaci/smux      |
+
+YAMux is used by default.
+
 #### max_connections
 #### max_connections
 
 
 Maximum connections.
 Maximum connections.

+ 1 - 0
go.mod

@@ -20,6 +20,7 @@ require (
 	github.com/sagernet/sing-vmess v0.0.0-20220802053753-a38d3b22e6b9
 	github.com/sagernet/sing-vmess v0.0.0-20220802053753-a38d3b22e6b9
 	github.com/spf13/cobra v1.5.0
 	github.com/spf13/cobra v1.5.0
 	github.com/stretchr/testify v1.8.0
 	github.com/stretchr/testify v1.8.0
+	github.com/xtaci/smux v1.5.16
 	go.uber.org/atomic v1.9.0
 	go.uber.org/atomic v1.9.0
 	golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
 	golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
 	golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b
 	golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b

+ 2 - 0
go.sum

@@ -201,6 +201,8 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u
 github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
 github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
 github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
 github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
 github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
 github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
+github.com/xtaci/smux v1.5.16 h1:FBPYOkW8ZTjLKUM4LI4xnnuuDC8CQ/dB04HD519WoEk=
+github.com/xtaci/smux v1.5.16/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
 github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
 go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
 go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=

+ 5 - 4
option/outbound.go

@@ -100,8 +100,9 @@ func (o ServerOptions) Build() M.Socksaddr {
 }
 }
 
 
 type MultiplexOptions struct {
 type MultiplexOptions struct {
-	Enabled        bool `json:"enabled,omitempty"`
-	MaxConnections int  `json:"max_connections,omitempty"`
-	MinStreams     int  `json:"min_streams,omitempty"`
-	MaxStreams     int  `json:"max_streams,omitempty"`
+	Enabled        bool   `json:"enabled,omitempty"`
+	Protocol       string `json:"protocol,omitempty"`
+	MaxConnections int    `json:"max_connections,omitempty"`
+	MinStreams     int    `json:"min_streams,omitempty"`
+	MaxStreams     int    `json:"max_streams,omitempty"`
 }
 }

+ 4 - 1
outbound/shadowsocks.go

@@ -46,7 +46,10 @@ func NewShadowsocks(ctx context.Context, router adapter.Router, logger log.Conte
 		method:     method,
 		method:     method,
 		serverAddr: options.ServerOptions.Build(),
 		serverAddr: options.ServerOptions.Build(),
 	}
 	}
-	outbound.multiplexDialer = mux.NewClientWithOptions(ctx, (*shadowsocksDialer)(outbound), common.PtrValueOrDefault(options.Multiplex))
+	outbound.multiplexDialer, err = mux.NewClientWithOptions(ctx, (*shadowsocksDialer)(outbound), common.PtrValueOrDefault(options.Multiplex))
+	if err != nil {
+		return nil, err
+	}
 	return outbound, nil
 	return outbound, nil
 }
 }
 
 

+ 2 - 2
route/router_dns.go

@@ -32,13 +32,13 @@ func (r *Router) Exchange(ctx context.Context, message *dnsmessage.Message) (*dn
 }
 }
 
 
 func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
 func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
-	r.dnsLogger.Debug(ctx, "lookup domain ", domain)
+	r.dnsLogger.DebugContext(ctx, "lookup domain ", domain)
 	ctx, transport := r.matchDNS(ctx)
 	ctx, transport := r.matchDNS(ctx)
 	ctx, cancel := context.WithTimeout(ctx, C.DNSTimeout)
 	ctx, cancel := context.WithTimeout(ctx, C.DNSTimeout)
 	defer cancel()
 	defer cancel()
 	addrs, err := r.dnsClient.Lookup(ctx, transport, domain, strategy)
 	addrs, err := r.dnsClient.Lookup(ctx, transport, domain, strategy)
 	if len(addrs) > 0 {
 	if len(addrs) > 0 {
-		r.logger.InfoContext(ctx, "lookup succeed for ", domain, ": ", F.MapToString(addrs))
+		r.logger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(addrs), " "))
 	} else {
 	} else {
 		r.logger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
 		r.logger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
 	}
 	}

+ 1 - 0
test/go.mod

@@ -58,6 +58,7 @@ require (
 	github.com/sagernet/sing-vmess v0.0.0-20220802053753-a38d3b22e6b9 // indirect
 	github.com/sagernet/sing-vmess v0.0.0-20220802053753-a38d3b22e6b9 // indirect
 	github.com/sirupsen/logrus v1.8.1 // indirect
 	github.com/sirupsen/logrus v1.8.1 // indirect
 	github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
 	github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
+	github.com/xtaci/smux v1.5.16 // indirect
 	go.uber.org/atomic v1.9.0 // indirect
 	go.uber.org/atomic v1.9.0 // indirect
 	golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect
 	golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect
 	golang.org/x/mod v0.5.1 // indirect
 	golang.org/x/mod v0.5.1 // indirect

+ 2 - 0
test/go.sum

@@ -228,6 +228,8 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u
 github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
 github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
 github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
 github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
 github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
 github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
+github.com/xtaci/smux v1.5.16 h1:FBPYOkW8ZTjLKUM4LI4xnnuuDC8CQ/dB04HD519WoEk=
+github.com/xtaci/smux v1.5.16/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
 github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=

+ 14 - 1
test/mux_test.go

@@ -4,12 +4,24 @@ import (
 	"net/netip"
 	"net/netip"
 	"testing"
 	"testing"
 
 
+	"github.com/sagernet/sing-box/common/mux"
 	C "github.com/sagernet/sing-box/constant"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-shadowsocks/shadowaead_2022"
 	"github.com/sagernet/sing-shadowsocks/shadowaead_2022"
 )
 )
 
 
 func TestShadowsocksMux(t *testing.T) {
 func TestShadowsocksMux(t *testing.T) {
+	for _, protocol := range []mux.Protocol{
+		mux.ProtocolYAMux,
+		mux.ProtocolSMux,
+	} {
+		t.Run(protocol.String(), func(t *testing.T) {
+			testShadowsocksMux(t, protocol.String())
+		})
+	}
+}
+
+func testShadowsocksMux(t *testing.T, protocol string) {
 	method := shadowaead_2022.List[0]
 	method := shadowaead_2022.List[0]
 	password := mkBase64(t, 16)
 	password := mkBase64(t, 16)
 	startInstance(t, option.Options{
 	startInstance(t, option.Options{
@@ -54,7 +66,8 @@ func TestShadowsocksMux(t *testing.T) {
 					Method:   method,
 					Method:   method,
 					Password: password,
 					Password: password,
 					Multiplex: &option.MultiplexOptions{
 					Multiplex: &option.MultiplexOptions{
-						Enabled: true,
+						Enabled:  true,
+						Protocol: protocol,
 					},
 					},
 				},
 				},
 			},
 			},