123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321 |
- package trojan
- import (
- "crypto/sha256"
- "encoding/binary"
- "encoding/hex"
- "net"
- "os"
- "sync"
- "github.com/sagernet/sing/common"
- "github.com/sagernet/sing/common/buf"
- "github.com/sagernet/sing/common/bufio"
- E "github.com/sagernet/sing/common/exceptions"
- M "github.com/sagernet/sing/common/metadata"
- N "github.com/sagernet/sing/common/network"
- "github.com/sagernet/sing/common/rw"
- )
// Protocol constants for the trojan request header.
const (
	// KeyLength is the length of the hex-encoded password digest produced by Key.
	KeyLength = 56

	// Command bytes, sent right after the key and its CRLF in a request header.
	CommandTCP = 0x01
	CommandUDP = 0x03
	CommandMux = 0x7f
)

// CRLF terminates the key and address fields of the request header and the
// length field of each UDP frame.
var CRLF = []byte("\r\n")
// Compile-time check that *ClientConn implements N.EarlyWriter.
var _ N.EarlyWriter = (*ClientConn)(nil)
// ClientConn is the client side of a trojan TCP connection. The request header
// (key, CommandTCP, destination) is not sent at construction time; the first
// Write or WriteBuffer sends it together with that call's data.
type ClientConn struct {
	N.ExtendedConn
	key           [KeyLength]byte
	destination   M.Socksaddr
	headerWritten bool // true once the first Write/WriteBuffer has succeeded
}
- func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn {
- return &ClientConn{
- ExtendedConn: bufio.NewExtendedConn(conn),
- key: key,
- destination: destination,
- }
- }
- func (c *ClientConn) NeedHandshakeForWrite() bool {
- return !c.headerWritten
- }
- func (c *ClientConn) Write(p []byte) (n int, err error) {
- if c.headerWritten {
- return c.ExtendedConn.Write(p)
- }
- err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p)
- if err != nil {
- return
- }
- n = len(p)
- c.headerWritten = true
- return
- }
- func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error {
- if c.headerWritten {
- return c.ExtendedConn.WriteBuffer(buffer)
- }
- err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer)
- if err != nil {
- return err
- }
- c.headerWritten = true
- return nil
- }
- func (c *ClientConn) FrontHeadroom() int {
- if !c.headerWritten {
- return KeyLength + 5 + M.MaxSocksaddrLength
- }
- return 0
- }
- func (c *ClientConn) Upstream() any {
- return c.ExtendedConn
- }
- func (c *ClientConn) ReaderReplaceable() bool {
- return c.headerWritten
- }
- func (c *ClientConn) WriterReplaceable() bool {
- return c.headerWritten
- }
// ClientPacketConn is the client side of a trojan UDP-over-stream connection.
// Each packet is framed on the underlying conn as
// address | uint16 big-endian length | CRLF | payload, and the first
// WritePacket additionally sends the CommandUDP request header.
type ClientPacketConn struct {
	net.Conn
	access          sync.Mutex // guards sending the request header
	key             [KeyLength]byte
	headerWritten   bool              // set by the first WritePacket, even if it failed
	readWaitOptions N.ReadWaitOptions // NOTE(review): unused in this section; presumably used by read-waiter code elsewhere — confirm
}
- func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {
- return &ClientPacketConn{
- Conn: conn,
- key: key,
- }
- }
- func (c *ClientPacketConn) NeedHandshake() bool {
- return !c.headerWritten
- }
- func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
- return ReadPacket(c.Conn, buffer)
- }
- func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
- if !c.headerWritten {
- c.access.Lock()
- if c.headerWritten {
- c.access.Unlock()
- } else {
- err := ClientHandshakePacket(c.Conn, c.key, destination, buffer)
- c.headerWritten = true
- c.access.Unlock()
- return err
- }
- }
- return WritePacket(c.Conn, buffer, destination)
- }
- func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
- buffer := buf.With(p)
- destination, err := c.ReadPacket(buffer)
- if err != nil {
- return
- }
- n = buffer.Len()
- if destination.IsFqdn() {
- addr = destination
- } else {
- addr = destination.UDPAddr()
- }
- return
- }
- func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
- return bufio.WritePacket(c, p, addr)
- }
- func (c *ClientPacketConn) Read(p []byte) (n int, err error) {
- n, _, err = c.ReadFrom(p)
- return
- }
- func (c *ClientPacketConn) Write(p []byte) (n int, err error) {
- return 0, os.ErrInvalid
- }
- func (c *ClientPacketConn) FrontHeadroom() int {
- if !c.headerWritten {
- return KeyLength + 2*M.MaxSocksaddrLength + 9
- }
- return M.MaxSocksaddrLength + 4
- }
- func (c *ClientPacketConn) Upstream() any {
- return c.Conn
- }
- func Key(password string) [KeyLength]byte {
- var key [KeyLength]byte
- hash := sha256.New224()
- common.Must1(hash.Write([]byte(password)))
- hex.Encode(key[:], hash.Sum(nil))
- return key
- }
- func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error {
- _, err := conn.Write(key[:])
- if err != nil {
- return err
- }
- _, err = conn.Write(CRLF)
- if err != nil {
- return err
- }
- _, err = conn.Write([]byte{command})
- if err != nil {
- return err
- }
- err = M.SocksaddrSerializer.WriteAddrPort(conn, destination)
- if err != nil {
- return err
- }
- _, err = conn.Write(CRLF)
- if err != nil {
- return err
- }
- if len(payload) > 0 {
- _, err = conn.Write(payload)
- if err != nil {
- return err
- }
- }
- return nil
- }
- func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error {
- headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5
- header := buf.NewSize(headerLen + len(payload))
- defer header.Release()
- common.Must1(header.Write(key[:]))
- common.Must1(header.Write(CRLF))
- common.Must(header.WriteByte(CommandTCP))
- err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
- if err != nil {
- return err
- }
- common.Must1(header.Write(CRLF))
- common.Must1(header.Write(payload))
- _, err = conn.Write(header.Bytes())
- if err != nil {
- return E.Cause(err, "write request")
- }
- return nil
- }
- func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
- header := buf.With(payload.ExtendHeader(KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5))
- common.Must1(header.Write(key[:]))
- common.Must1(header.Write(CRLF))
- common.Must(header.WriteByte(CommandTCP))
- err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
- if err != nil {
- return err
- }
- common.Must1(header.Write(CRLF))
- _, err = conn.Write(payload.Bytes())
- if err != nil {
- return E.Cause(err, "write request")
- }
- return nil
- }
- func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
- headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9
- payloadLen := payload.Len()
- var header *buf.Buffer
- var writeHeader bool
- if payload.Start() >= headerLen {
- header = buf.With(payload.ExtendHeader(headerLen))
- } else {
- header = buf.NewSize(headerLen)
- defer header.Release()
- writeHeader = true
- }
- common.Must1(header.Write(key[:]))
- common.Must1(header.Write(CRLF))
- common.Must(header.WriteByte(CommandUDP))
- err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
- if err != nil {
- return err
- }
- common.Must1(header.Write(CRLF))
- common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
- common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
- common.Must1(header.Write(CRLF))
- if writeHeader {
- _, err := conn.Write(header.Bytes())
- if err != nil {
- return E.Cause(err, "write request")
- }
- }
- _, err = conn.Write(payload.Bytes())
- if err != nil {
- return E.Cause(err, "write payload")
- }
- return nil
- }
- func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
- destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
- if err != nil {
- return M.Socksaddr{}, E.Cause(err, "read destination")
- }
- var length uint16
- err = binary.Read(conn, binary.BigEndian, &length)
- if err != nil {
- return M.Socksaddr{}, E.Cause(err, "read chunk length")
- }
- err = rw.SkipN(conn, 2)
- if err != nil {
- return M.Socksaddr{}, E.Cause(err, "skip crlf")
- }
- _, err = buffer.ReadFullFrom(conn, int(length))
- return destination, err
- }
- func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error {
- defer buffer.Release()
- bufferLen := buffer.Len()
- header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4))
- err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
- if err != nil {
- return err
- }
- common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
- common.Must1(header.Write(CRLF))
- _, err = conn.Write(buffer.Bytes())
- if err != nil {
- return E.Cause(err, "write packet")
- }
- return nil
- }
|