123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- package trojan
- import (
- "context"
- "encoding/binary"
- "net"
- "github.com/sagernet/sing/common/auth"
- "github.com/sagernet/sing/common/buf"
- "github.com/sagernet/sing/common/bufio"
- E "github.com/sagernet/sing/common/exceptions"
- "github.com/sagernet/sing/common/logger"
- M "github.com/sagernet/sing/common/metadata"
- N "github.com/sagernet/sing/common/network"
- "github.com/sagernet/sing/common/rw"
- )
- type Handler interface {
- N.TCPConnectionHandlerEx
- N.UDPConnectionHandlerEx
- }
- type Service[K comparable] struct {
- users map[K][56]byte
- keys map[[56]byte]K
- handler Handler
- fallbackHandler N.TCPConnectionHandlerEx
- logger logger.ContextLogger
- }
- func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandlerEx, logger logger.ContextLogger) *Service[K] {
- return &Service[K]{
- users: make(map[K][56]byte),
- keys: make(map[[56]byte]K),
- handler: handler,
- fallbackHandler: fallbackHandler,
- logger: logger,
- }
- }
- var ErrUserExists = E.New("user already exists")
- func (s *Service[K]) UpdateUsers(userList []K, passwordList []string) error {
- users := make(map[K][56]byte)
- keys := make(map[[56]byte]K)
- for i, user := range userList {
- if _, loaded := users[user]; loaded {
- return ErrUserExists
- }
- key := Key(passwordList[i])
- if oldUser, loaded := keys[key]; loaded {
- return E.Extend(ErrUserExists, "password used by ", oldUser)
- }
- users[user] = key
- keys[key] = user
- }
- s.users = users
- s.keys = keys
- return nil
- }
- func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc) error {
- var key [KeyLength]byte
- n, err := conn.Read(key[:])
- if err != nil {
- return err
- } else if n != KeyLength {
- return s.fallback(ctx, conn, source, key[:n], E.New("bad request size"), onClose)
- }
- if user, loaded := s.keys[key]; loaded {
- ctx = auth.ContextWithUser(ctx, user)
- } else {
- return s.fallback(ctx, conn, source, key[:], E.New("bad request"), onClose)
- }
- err = rw.SkipN(conn, 2)
- if err != nil {
- return E.Cause(err, "skip crlf")
- }
- var command byte
- err = binary.Read(conn, binary.BigEndian, &command)
- if err != nil {
- return E.Cause(err, "read command")
- }
- switch command {
- case CommandTCP, CommandUDP, CommandMux:
- default:
- return E.New("unknown command ", command)
- }
- // var destination M.Socksaddr
- destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
- if err != nil {
- return E.Cause(err, "read destination")
- }
- err = rw.SkipN(conn, 2)
- if err != nil {
- return E.Cause(err, "skip crlf")
- }
- switch command {
- case CommandTCP:
- s.handler.NewConnectionEx(ctx, conn, source, destination, onClose)
- case CommandUDP:
- s.handler.NewPacketConnectionEx(ctx, &PacketConn{Conn: conn}, source, destination, onClose)
- // case CommandMux:
- default:
- return HandleMuxConnection(ctx, conn, source, s.handler, s.logger, onClose)
- }
- return nil
- }
- func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, source M.Socksaddr, header []byte, err error, onClose N.CloseHandlerFunc) error {
- if s.fallbackHandler == nil {
- return E.Extend(err, "fallback disabled")
- }
- conn = bufio.NewCachedConn(conn, buf.As(header).ToOwned())
- s.fallbackHandler.NewConnectionEx(ctx, conn, source, M.Socksaddr{}, onClose)
- return nil
- }
- type PacketConn struct {
- net.Conn
- readWaitOptions N.ReadWaitOptions
- }
- func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
- return ReadPacket(c.Conn, buffer)
- }
- func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
- return WritePacket(c.Conn, buffer, destination)
- }
- func (c *PacketConn) FrontHeadroom() int {
- return M.MaxSocksaddrLength + 4
- }
- func (c *PacketConn) NeedAdditionalReadDeadline() bool {
- return true
- }
- func (c *PacketConn) Upstream() any {
- return c.Conn
- }
|