1
0

service.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package trojan
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "net"
  6. "github.com/sagernet/sing/common/auth"
  7. "github.com/sagernet/sing/common/buf"
  8. "github.com/sagernet/sing/common/bufio"
  9. E "github.com/sagernet/sing/common/exceptions"
  10. "github.com/sagernet/sing/common/logger"
  11. M "github.com/sagernet/sing/common/metadata"
  12. N "github.com/sagernet/sing/common/network"
  13. "github.com/sagernet/sing/common/rw"
  14. )
  15. type Handler interface {
  16. N.TCPConnectionHandlerEx
  17. N.UDPConnectionHandlerEx
  18. }
  19. type Service[K comparable] struct {
  20. users map[K][56]byte
  21. keys map[[56]byte]K
  22. handler Handler
  23. fallbackHandler N.TCPConnectionHandlerEx
  24. logger logger.ContextLogger
  25. }
  26. func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandlerEx, logger logger.ContextLogger) *Service[K] {
  27. return &Service[K]{
  28. users: make(map[K][56]byte),
  29. keys: make(map[[56]byte]K),
  30. handler: handler,
  31. fallbackHandler: fallbackHandler,
  32. logger: logger,
  33. }
  34. }
  35. var ErrUserExists = E.New("user already exists")
  36. func (s *Service[K]) UpdateUsers(userList []K, passwordList []string) error {
  37. users := make(map[K][56]byte)
  38. keys := make(map[[56]byte]K)
  39. for i, user := range userList {
  40. if _, loaded := users[user]; loaded {
  41. return ErrUserExists
  42. }
  43. key := Key(passwordList[i])
  44. if oldUser, loaded := keys[key]; loaded {
  45. return E.Extend(ErrUserExists, "password used by ", oldUser)
  46. }
  47. users[user] = key
  48. keys[key] = user
  49. }
  50. s.users = users
  51. s.keys = keys
  52. return nil
  53. }
  54. func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc) error {
  55. var key [KeyLength]byte
  56. n, err := conn.Read(key[:])
  57. if err != nil {
  58. return err
  59. } else if n != KeyLength {
  60. return s.fallback(ctx, conn, source, key[:n], E.New("bad request size"), onClose)
  61. }
  62. if user, loaded := s.keys[key]; loaded {
  63. ctx = auth.ContextWithUser(ctx, user)
  64. } else {
  65. return s.fallback(ctx, conn, source, key[:], E.New("bad request"), onClose)
  66. }
  67. err = rw.SkipN(conn, 2)
  68. if err != nil {
  69. return E.Cause(err, "skip crlf")
  70. }
  71. var command byte
  72. err = binary.Read(conn, binary.BigEndian, &command)
  73. if err != nil {
  74. return E.Cause(err, "read command")
  75. }
  76. switch command {
  77. case CommandTCP, CommandUDP, CommandMux:
  78. default:
  79. return E.New("unknown command ", command)
  80. }
  81. // var destination M.Socksaddr
  82. destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
  83. if err != nil {
  84. return E.Cause(err, "read destination")
  85. }
  86. err = rw.SkipN(conn, 2)
  87. if err != nil {
  88. return E.Cause(err, "skip crlf")
  89. }
  90. switch command {
  91. case CommandTCP:
  92. s.handler.NewConnectionEx(ctx, conn, source, destination, onClose)
  93. case CommandUDP:
  94. s.handler.NewPacketConnectionEx(ctx, &PacketConn{Conn: conn}, source, destination, onClose)
  95. // case CommandMux:
  96. default:
  97. return HandleMuxConnection(ctx, conn, source, s.handler, s.logger, onClose)
  98. }
  99. return nil
  100. }
  101. func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, source M.Socksaddr, header []byte, err error, onClose N.CloseHandlerFunc) error {
  102. if s.fallbackHandler == nil {
  103. return E.Extend(err, "fallback disabled")
  104. }
  105. conn = bufio.NewCachedConn(conn, buf.As(header).ToOwned())
  106. s.fallbackHandler.NewConnectionEx(ctx, conn, source, M.Socksaddr{}, onClose)
  107. return nil
  108. }
  109. type PacketConn struct {
  110. net.Conn
  111. readWaitOptions N.ReadWaitOptions
  112. }
  113. func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
  114. return ReadPacket(c.Conn, buffer)
  115. }
  116. func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  117. return WritePacket(c.Conn, buffer, destination)
  118. }
  119. func (c *PacketConn) FrontHeadroom() int {
  120. return M.MaxSocksaddrLength + 4
  121. }
  122. func (c *PacketConn) NeedAdditionalReadDeadline() bool {
  123. return true
  124. }
  125. func (c *PacketConn) Upstream() any {
  126. return c.Conn
  127. }