mux.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package trojan
  2. import (
  3. std_bufio "bufio"
  4. "context"
  5. "net"
  6. "os"
  7. "github.com/sagernet/sing/common/buf"
  8. "github.com/sagernet/sing/common/bufio"
  9. E "github.com/sagernet/sing/common/exceptions"
  10. "github.com/sagernet/sing/common/logger"
  11. M "github.com/sagernet/sing/common/metadata"
  12. N "github.com/sagernet/sing/common/network"
  13. "github.com/sagernet/sing/common/task"
  14. "github.com/sagernet/smux"
  15. )
  16. func HandleMuxConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, handler Handler, logger logger.ContextLogger, onClose N.CloseHandlerFunc) error {
  17. session, err := smux.Server(conn, smuxConfig())
  18. if err != nil {
  19. return err
  20. }
  21. var group task.Group
  22. group.Append0(func(_ context.Context) error {
  23. var stream net.Conn
  24. for {
  25. stream, err = session.AcceptStream()
  26. if err != nil {
  27. return err
  28. }
  29. go newMuxConnection(ctx, stream, source, handler, logger)
  30. }
  31. })
  32. group.Cleanup(func() {
  33. session.Close()
  34. if onClose != nil {
  35. onClose(os.ErrClosed)
  36. }
  37. })
  38. return group.Run(ctx)
  39. }
  40. func newMuxConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, handler Handler, logger logger.ContextLogger) {
  41. err := newMuxConnection0(ctx, conn, source, handler)
  42. if err != nil {
  43. logger.ErrorContext(ctx, E.Cause(err, "process trojan-go multiplex connection"))
  44. }
  45. }
  46. func newMuxConnection0(ctx context.Context, conn net.Conn, source M.Socksaddr, handler Handler) error {
  47. reader := std_bufio.NewReader(conn)
  48. command, err := reader.ReadByte()
  49. if err != nil {
  50. return E.Cause(err, "read command")
  51. }
  52. destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
  53. if err != nil {
  54. return E.Cause(err, "read destination")
  55. }
  56. if reader.Buffered() > 0 {
  57. buffer := buf.NewSize(reader.Buffered())
  58. _, err = buffer.ReadFullFrom(reader, buffer.Len())
  59. if err != nil {
  60. return err
  61. }
  62. conn = bufio.NewCachedConn(conn, buffer)
  63. }
  64. switch command {
  65. case CommandTCP:
  66. handler.NewConnectionEx(ctx, conn, source, destination, nil)
  67. case CommandUDP:
  68. handler.NewPacketConnectionEx(ctx, &PacketConn{Conn: conn}, source, destination, nil)
  69. default:
  70. return E.New("unknown command ", command)
  71. }
  72. return nil
  73. }
  74. func smuxConfig() *smux.Config {
  75. config := smux.DefaultConfig()
  76. config.KeepAliveDisabled = true
  77. return config
  78. }