service.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package trojan
  2. import (
  3. "context"
  4. "net"
  5. "github.com/sagernet/sing/common/auth"
  6. "github.com/sagernet/sing/common/buf"
  7. "github.com/sagernet/sing/common/bufio"
  8. E "github.com/sagernet/sing/common/exceptions"
  9. M "github.com/sagernet/sing/common/metadata"
  10. N "github.com/sagernet/sing/common/network"
  11. "github.com/sagernet/sing/common/rw"
  12. )
  13. type Handler interface {
  14. N.TCPConnectionHandler
  15. N.UDPConnectionHandler
  16. E.Handler
  17. }
  18. type Service[K comparable] struct {
  19. users map[K][56]byte
  20. keys map[[56]byte]K
  21. handler Handler
  22. fallbackHandler N.TCPConnectionHandler
  23. }
  24. func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandler) *Service[K] {
  25. return &Service[K]{
  26. users: make(map[K][56]byte),
  27. keys: make(map[[56]byte]K),
  28. handler: handler,
  29. fallbackHandler: fallbackHandler,
  30. }
  31. }
  32. var ErrUserExists = E.New("user already exists")
  33. func (s *Service[K]) UpdateUsers(userList []K, passwordList []string) error {
  34. users := make(map[K][56]byte)
  35. keys := make(map[[56]byte]K)
  36. for i, user := range userList {
  37. if _, loaded := users[user]; loaded {
  38. return ErrUserExists
  39. }
  40. key := Key(passwordList[i])
  41. if oldUser, loaded := keys[key]; loaded {
  42. return E.Extend(ErrUserExists, "password used by ", oldUser)
  43. }
  44. users[user] = key
  45. keys[key] = user
  46. }
  47. s.users = users
  48. s.keys = keys
  49. return nil
  50. }
  51. func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
  52. var key [KeyLength]byte
  53. n, err := conn.Read(key[:])
  54. if err != nil {
  55. return err
  56. } else if n != KeyLength {
  57. return s.fallback(ctx, conn, metadata, key[:n], E.New("bad request size"))
  58. }
  59. if user, loaded := s.keys[key]; loaded {
  60. ctx = auth.ContextWithUser(ctx, user)
  61. } else {
  62. return s.fallback(ctx, conn, metadata, key[:], E.New("bad request"))
  63. }
  64. err = rw.SkipN(conn, 2)
  65. if err != nil {
  66. return E.Cause(err, "skip crlf")
  67. }
  68. command, err := rw.ReadByte(conn)
  69. if err != nil {
  70. return E.Cause(err, "read command")
  71. }
  72. switch command {
  73. case CommandTCP, CommandUDP, CommandMux:
  74. default:
  75. return E.New("unknown command ", command)
  76. }
  77. // var destination M.Socksaddr
  78. destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
  79. if err != nil {
  80. return E.Cause(err, "read destination")
  81. }
  82. err = rw.SkipN(conn, 2)
  83. if err != nil {
  84. return E.Cause(err, "skip crlf")
  85. }
  86. metadata.Protocol = "trojan"
  87. metadata.Destination = destination
  88. switch command {
  89. case CommandTCP:
  90. return s.handler.NewConnection(ctx, conn, metadata)
  91. case CommandUDP:
  92. return s.handler.NewPacketConnection(ctx, &PacketConn{Conn: conn}, metadata)
  93. // case CommandMux:
  94. default:
  95. return HandleMuxConnection(ctx, conn, metadata, s.handler)
  96. }
  97. }
  98. func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Metadata, header []byte, err error) error {
  99. if s.fallbackHandler == nil {
  100. return E.Extend(err, "fallback disabled")
  101. }
  102. conn = bufio.NewCachedConn(conn, buf.As(header).ToOwned())
  103. return s.fallbackHandler.NewConnection(ctx, conn, metadata)
  104. }
  105. type PacketConn struct {
  106. net.Conn
  107. readWaitOptions N.ReadWaitOptions
  108. }
  109. func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
  110. return ReadPacket(c.Conn, buffer)
  111. }
  112. func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  113. return WritePacket(c.Conn, buffer, destination)
  114. }
  115. func (c *PacketConn) FrontHeadroom() int {
  116. return M.MaxSocksaddrLength + 4
  117. }
  118. func (c *PacketConn) NeedAdditionalReadDeadline() bool {
  119. return true
  120. }
  121. func (c *PacketConn) Upstream() any {
  122. return c.Conn
  123. }