Browse Source

Add trojan-go multiplex support for trojan inbound

世界 2 năm trước cách đây
mục cha
commit
8dcafa5b33
5 tập tin đã thay đổi với 520 bổ sung3 xóa
  1. 2 2
      inbound/trojan.go
  2. 1 1
      outbound/trojan.go
  3. 66 0
      transport/trojan/mux.go
  4. 313 0
      transport/trojan/protocol.go
  5. 138 0
      transport/trojan/service.go

+ 2 - 2
inbound/trojan.go

@@ -10,6 +10,7 @@ import (
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing-box/transport/trojan"
 	"github.com/sagernet/sing-box/transport/v2ray"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/auth"
@@ -17,7 +18,6 @@ import (
 	F "github.com/sagernet/sing/common/format"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
-	"github.com/sagernet/sing/protocol/trojan"
 )
 
 var (
@@ -157,7 +157,7 @@ func (h *Trojan) NewConnection(ctx context.Context, conn net.Conn, metadata adap
 			return err
 		}
 	}
-	return h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
+	return h.service.NewConnection(adapter.WithContext(ctx, &metadata), conn, adapter.UpstreamMetadata(metadata))
 }
 
 func (h *Trojan) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {

+ 1 - 1
outbound/trojan.go

@@ -11,13 +11,13 @@ import (
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing-box/transport/trojan"
 	"github.com/sagernet/sing-box/transport/v2ray"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
-	"github.com/sagernet/sing/protocol/trojan"
 )
 
 var _ adapter.Outbound = (*Trojan)(nil)

+ 66 - 0
transport/trojan/mux.go

@@ -0,0 +1,66 @@
+package trojan
+
+import (
+	"context"
+	"net"
+
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+	"github.com/sagernet/sing/common/rw"
+	"github.com/sagernet/sing/common/task"
+	"github.com/sagernet/smux"
+)
+
+func HandleMuxConnection(ctx context.Context, conn net.Conn, metadata M.Metadata, handler Handler) error {
+	session, err := smux.Server(conn, smuxConfig())
+	if err != nil {
+		return err
+	}
+	var group task.Group
+	group.Append0(func(ctx context.Context) error {
+		var stream net.Conn
+		for {
+			stream, err = session.AcceptStream()
+			if err != nil {
+				return err
+			}
+			go newMuxConnection(ctx, stream, metadata, handler)
+		}
+	})
+	group.Cleanup(func() {
+		session.Close()
+	})
+	return group.Run(ctx)
+}
+
+func newMuxConnection(ctx context.Context, stream net.Conn, metadata M.Metadata, handler Handler) {
+	err := newMuxConnection0(ctx, stream, metadata, handler)
+	if err != nil {
+		handler.NewError(ctx, E.Cause(err, "process trojan-go multiplex connection"))
+	}
+}
+
+func newMuxConnection0(ctx context.Context, stream net.Conn, metadata M.Metadata, handler Handler) error {
+	command, err := rw.ReadByte(stream)
+	if err != nil {
+		return E.Cause(err, "read command")
+	}
+	metadata.Destination, err = M.SocksaddrSerializer.ReadAddrPort(stream)
+	if err != nil {
+		return E.Cause(err, "read destination")
+	}
+	switch command {
+	case CommandTCP:
+		return handler.NewConnection(ctx, stream, metadata)
+	case CommandUDP:
+		return handler.NewPacketConnection(ctx, &PacketConn{stream}, metadata)
+	default:
+		return E.New("unknown command ", command)
+	}
+}
+
+func smuxConfig() *smux.Config {
+	config := smux.DefaultConfig()
+	config.KeepAliveDisabled = true
+	return config
+}

+ 313 - 0
transport/trojan/protocol.go

@@ -0,0 +1,313 @@
+package trojan
+
+import (
+	"crypto/sha256"
+	"encoding/binary"
+	"encoding/hex"
+	"io"
+	"net"
+	"os"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/bufio"
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/common/rw"
+)
+
+const (
+	KeyLength  = 56
+	CommandTCP = 1
+	CommandUDP = 3
+	CommandMux = 0x7f
+)
+
+var CRLF = []byte{'\r', '\n'}
+
+type ClientConn struct {
+	N.ExtendedConn
+	key           [KeyLength]byte
+	destination   M.Socksaddr
+	headerWritten bool
+}
+
+func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn {
+	return &ClientConn{
+		ExtendedConn: bufio.NewExtendedConn(conn),
+		key:          key,
+		destination:  destination,
+	}
+}
+
+func (c *ClientConn) Write(p []byte) (n int, err error) {
+	if c.headerWritten {
+		return c.ExtendedConn.Write(p)
+	}
+	err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p)
+	if err != nil {
+		return
+	}
+	n = len(p)
+	c.headerWritten = true
+	return
+}
+
+func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error {
+	if c.headerWritten {
+		return c.ExtendedConn.WriteBuffer(buffer)
+	}
+	err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer)
+	if err != nil {
+		return err
+	}
+	c.headerWritten = true
+	return nil
+}
+
+func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) {
+	if !c.headerWritten {
+		return bufio.ReadFrom0(c, r)
+	}
+	return bufio.Copy(c.ExtendedConn, r)
+}
+
+func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) {
+	return bufio.Copy(w, c.ExtendedConn)
+}
+
+func (c *ClientConn) FrontHeadroom() int {
+	if !c.headerWritten {
+		return KeyLength + 5 + M.MaxSocksaddrLength
+	}
+	return 0
+}
+
+func (c *ClientConn) Upstream() any {
+	return c.ExtendedConn
+}
+
+type ClientPacketConn struct {
+	net.Conn
+	key           [KeyLength]byte
+	headerWritten bool
+}
+
+func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {
+	return &ClientPacketConn{
+		Conn: conn,
+		key:  key,
+	}
+}
+
+func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
+	return ReadPacket(c.Conn, buffer)
+}
+
+func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
+	if !c.headerWritten {
+		err := ClientHandshakePacket(c.Conn, c.key, destination, buffer)
+		c.headerWritten = true
+		return err
+	}
+	return WritePacket(c.Conn, buffer, destination)
+}
+
+func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+	buffer := buf.With(p)
+	destination, err := c.ReadPacket(buffer)
+	if err != nil {
+		return
+	}
+	n = buffer.Len()
+	addr = destination.UDPAddr()
+	return
+}
+
+func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+	return bufio.WritePacket(c, p, addr)
+}
+
+func (c *ClientPacketConn) Read(p []byte) (n int, err error) {
+	n, _, err = c.ReadFrom(p)
+	return
+}
+
+func (c *ClientPacketConn) Write(p []byte) (n int, err error) {
+	return 0, os.ErrInvalid
+}
+
+func (c *ClientPacketConn) FrontHeadroom() int {
+	if !c.headerWritten {
+		return KeyLength + 2*M.MaxSocksaddrLength + 9
+	}
+	return M.MaxSocksaddrLength + 4
+}
+
+func (c *ClientPacketConn) Upstream() any {
+	return c.Conn
+}
+
+func Key(password string) [KeyLength]byte {
+	var key [KeyLength]byte
+	hash := sha256.New224()
+	common.Must1(hash.Write([]byte(password)))
+	hex.Encode(key[:], hash.Sum(nil))
+	return key
+}
+
+func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error {
+	_, err := conn.Write(key[:])
+	if err != nil {
+		return err
+	}
+	_, err = conn.Write(CRLF)
+	if err != nil {
+		return err
+	}
+	_, err = conn.Write([]byte{command})
+	if err != nil {
+		return err
+	}
+	err = M.SocksaddrSerializer.WriteAddrPort(conn, destination)
+	if err != nil {
+		return err
+	}
+	_, err = conn.Write(CRLF)
+	if err != nil {
+		return err
+	}
+	if len(payload) > 0 {
+		_, err = conn.Write(payload)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error {
+	headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5
+	var header *buf.Buffer
+	defer header.Release()
+	var writeHeader bool
+	if len(payload) > 0 && headerLen+len(payload) < 65535 {
+		buffer := buf.StackNewSize(headerLen + len(payload))
+		defer common.KeepAlive(buffer)
+		header = common.Dup(buffer)
+	} else {
+		buffer := buf.StackNewSize(headerLen)
+		defer common.KeepAlive(buffer)
+		header = common.Dup(buffer)
+		writeHeader = true
+	}
+	common.Must1(header.Write(key[:]))
+	common.Must1(header.Write(CRLF))
+	common.Must(header.WriteByte(CommandTCP))
+	common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
+	common.Must1(header.Write(CRLF))
+	if !writeHeader {
+		common.Must1(header.Write(payload))
+	}
+
+	_, err := conn.Write(header.Bytes())
+	if err != nil {
+		return E.Cause(err, "write request")
+	}
+
+	if writeHeader {
+		_, err = conn.Write(payload)
+		if err != nil {
+			return E.Cause(err, "write payload")
+		}
+	}
+	return nil
+}
+
+func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
+	header := buf.With(payload.ExtendHeader(KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5))
+	common.Must1(header.Write(key[:]))
+	common.Must1(header.Write(CRLF))
+	common.Must(header.WriteByte(CommandTCP))
+	common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
+	common.Must1(header.Write(CRLF))
+
+	_, err := conn.Write(payload.Bytes())
+	if err != nil {
+		return E.Cause(err, "write request")
+	}
+	return nil
+}
+
+func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
+	headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9
+	payloadLen := payload.Len()
+	var header *buf.Buffer
+	defer header.Release()
+	var writeHeader bool
+	if payload.Start() >= headerLen {
+		header = buf.With(payload.ExtendHeader(headerLen))
+	} else {
+		buffer := buf.StackNewSize(headerLen)
+		defer common.KeepAlive(buffer)
+		header = common.Dup(buffer)
+		writeHeader = true
+	}
+	common.Must1(header.Write(key[:]))
+	common.Must1(header.Write(CRLF))
+	common.Must(header.WriteByte(CommandUDP))
+	common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
+	common.Must1(header.Write(CRLF))
+	common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
+	common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
+	common.Must1(header.Write(CRLF))
+
+	if writeHeader {
+		_, err := conn.Write(header.Bytes())
+		if err != nil {
+			return E.Cause(err, "write request")
+		}
+	}
+
+	_, err := conn.Write(payload.Bytes())
+	if err != nil {
+		return E.Cause(err, "write payload")
+	}
+	return nil
+}
+
+func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
+	destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
+	if err != nil {
+		return M.Socksaddr{}, E.Cause(err, "read destination")
+	}
+
+	var length uint16
+	err = binary.Read(conn, binary.BigEndian, &length)
+	if err != nil {
+		return M.Socksaddr{}, E.Cause(err, "read chunk length")
+	}
+
+	err = rw.SkipN(conn, 2)
+	if err != nil {
+		return M.Socksaddr{}, E.Cause(err, "skip crlf")
+	}
+
+	_, err = buffer.ReadFullFrom(conn, int(length))
+	return destination, err
+}
+
+func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error {
+	defer buffer.Release()
+	bufferLen := buffer.Len()
+	header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4))
+	common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
+	common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
+	common.Must1(header.Write(CRLF))
+	_, err := conn.Write(buffer.Bytes())
+	if err != nil {
+		return E.Cause(err, "write packet")
+	}
+	return nil
+}

+ 138 - 0
transport/trojan/service.go

@@ -0,0 +1,138 @@
+package trojan
+
+import (
+	"context"
+	"net"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/auth"
+	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/bufio"
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/common/rw"
+)
+
+type Handler interface {
+	N.TCPConnectionHandler
+	N.UDPConnectionHandler
+	E.Handler
+}
+
+type Service[K comparable] struct {
+	users           map[K][56]byte
+	keys            map[[56]byte]K
+	handler         Handler
+	fallbackHandler N.TCPConnectionHandler
+}
+
+func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandler) *Service[K] {
+	return &Service[K]{
+		users:           make(map[K][56]byte),
+		keys:            make(map[[56]byte]K),
+		handler:         handler,
+		fallbackHandler: fallbackHandler,
+	}
+}
+
+var ErrUserExists = E.New("user already exists")
+
+func (s *Service[K]) UpdateUsers(userList []K, passwordList []string) error {
+	users := make(map[K][56]byte)
+	keys := make(map[[56]byte]K)
+	for i, user := range userList {
+		if _, loaded := users[user]; loaded {
+			return ErrUserExists
+		}
+		key := Key(passwordList[i])
+		if oldUser, loaded := keys[key]; loaded {
+			return E.Extend(ErrUserExists, "password used by ", oldUser)
+		}
+		users[user] = key
+		keys[key] = user
+	}
+	s.users = users
+	s.keys = keys
+	return nil
+}
+
+func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
+	var key [KeyLength]byte
+	n, err := conn.Read(common.Dup(key[:]))
+	if err != nil {
+		return err
+	} else if n != KeyLength {
+		return s.fallback(ctx, conn, metadata, key[:n], E.New("bad request size"))
+	}
+
+	if user, loaded := s.keys[key]; loaded {
+		ctx = auth.ContextWithUser(ctx, user)
+	} else {
+		return s.fallback(ctx, conn, metadata, key[:], E.New("bad request"))
+	}
+
+	err = rw.SkipN(conn, 2)
+	if err != nil {
+		return E.Cause(err, "skip crlf")
+	}
+
+	command, err := rw.ReadByte(conn)
+	if err != nil {
+		return E.Cause(err, "read command")
+	}
+
+	switch command {
+	case CommandTCP, CommandUDP, CommandMux:
+	default:
+		return E.New("unknown command ", command)
+	}
+
+	// var destination M.Socksaddr
+	destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
+	if err != nil {
+		return E.Cause(err, "read destination")
+	}
+
+	err = rw.SkipN(conn, 2)
+	if err != nil {
+		return E.Cause(err, "skip crlf")
+	}
+
+	metadata.Protocol = "trojan"
+	metadata.Destination = destination
+
+	switch command {
+	case CommandTCP:
+		return s.handler.NewConnection(ctx, conn, metadata)
+	case CommandUDP:
+		return s.handler.NewPacketConnection(ctx, &PacketConn{conn}, metadata)
+	// case CommandMux:
+	default:
+		return HandleMuxConnection(ctx, conn, metadata, s.handler)
+	}
+}
+
+func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Metadata, header []byte, err error) error {
+	if s.fallbackHandler == nil {
+		return E.Extend(err, "fallback disabled")
+	}
+	conn = bufio.NewCachedConn(conn, buf.As(header).ToOwned())
+	return s.fallbackHandler.NewConnection(ctx, conn, metadata)
+}
+
+type PacketConn struct {
+	net.Conn
+}
+
+func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
+	return ReadPacket(c.Conn, buffer)
+}
+
+func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
+	return WritePacket(c.Conn, buffer, destination)
+}
+
+func (c *PacketConn) FrontHeadroom() int {
+	return M.MaxSocksaddrLength + 4
+}