protocol.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package mux
  import (
	"encoding/binary"
	"errors"
	"io"
	"net"

	"github.com/hashicorp/yamux"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/rw"
	"github.com/sagernet/smux"
)
  16. var Destination = M.Socksaddr{
  17. Fqdn: "sp.mux.sing-box.arpa",
  18. Port: 444,
  19. }
  // Protocol selects the stream-multiplexing implementation used for a
// session. It is a single byte because EncodeRequest writes it to the wire
// as one byte.
type Protocol byte

// Known multiplexing protocols. The numeric values appear on the wire, so
// they are spelled out explicitly and must not be renumbered.
const (
	ProtocolSMux  Protocol = 0
	ProtocolYAMux Protocol = 1
)
  25. func ParseProtocol(name string) (Protocol, error) {
  26. switch name {
  27. case "", "smux":
  28. return ProtocolSMux, nil
  29. case "yamux":
  30. return ProtocolYAMux, nil
  31. default:
  32. return ProtocolYAMux, E.New("unknown multiplex protocol: ", name)
  33. }
  34. }
  35. func (p Protocol) newServer(conn net.Conn) (abstractSession, error) {
  36. switch p {
  37. case ProtocolSMux:
  38. session, err := smux.Server(conn, smuxConfig())
  39. if err != nil {
  40. return nil, err
  41. }
  42. return &smuxSession{session}, nil
  43. case ProtocolYAMux:
  44. return yamux.Server(conn, yaMuxConfig())
  45. default:
  46. panic("unknown protocol")
  47. }
  48. }
  49. func (p Protocol) newClient(conn net.Conn) (abstractSession, error) {
  50. switch p {
  51. case ProtocolSMux:
  52. session, err := smux.Client(conn, smuxConfig())
  53. if err != nil {
  54. return nil, err
  55. }
  56. return &smuxSession{session}, nil
  57. case ProtocolYAMux:
  58. return yamux.Client(conn, yaMuxConfig())
  59. default:
  60. panic("unknown protocol")
  61. }
  62. }
  63. func smuxConfig() *smux.Config {
  64. config := smux.DefaultConfig()
  65. config.KeepAliveDisabled = true
  66. return config
  67. }
  68. func yaMuxConfig() *yamux.Config {
  69. config := yamux.DefaultConfig()
  70. config.LogOutput = io.Discard
  71. config.StreamCloseTimeout = C.TCPTimeout
  72. config.StreamOpenTimeout = C.TCPTimeout
  73. return config
  74. }
  75. func (p Protocol) String() string {
  76. switch p {
  77. case ProtocolSMux:
  78. return "smux"
  79. case ProtocolYAMux:
  80. return "yamux"
  81. default:
  82. return "unknown"
  83. }
  84. }
  // version0 is the only session-request version that EncodeRequest emits
// and ReadRequest accepts.
const version0 = 0
  88. type Request struct {
  89. Protocol Protocol
  90. }
  91. func ReadRequest(reader io.Reader) (*Request, error) {
  92. version, err := rw.ReadByte(reader)
  93. if err != nil {
  94. return nil, err
  95. }
  96. if version != version0 {
  97. return nil, E.New("unsupported version: ", version)
  98. }
  99. protocol, err := rw.ReadByte(reader)
  100. if err != nil {
  101. return nil, err
  102. }
  103. if protocol > byte(ProtocolYAMux) {
  104. return nil, E.New("unsupported protocol: ", protocol)
  105. }
  106. return &Request{Protocol: Protocol(protocol)}, nil
  107. }
  108. func EncodeRequest(buffer *buf.Buffer, request Request) {
  109. buffer.WriteByte(version0)
  110. buffer.WriteByte(byte(request.Protocol))
  111. }
  // Bits of the big-endian uint16 flags field at the start of a stream
// request (see EncodeStreamRequest / ReadStreamRequest).
const (
	flagUDP  = 1 << 0 // stream uses UDP rather than TCP
	flagAddr = 1 << 1 // maps to StreamRequest.PacketAddr; only honoured for UDP when read
)

// Status byte values at the start of a stream response. An error status is
// followed by a message string (see ReadStreamResponse).
const (
	statusSuccess = 0
	statusError   = 1
)
  118. type StreamRequest struct {
  119. Network string
  120. Destination M.Socksaddr
  121. PacketAddr bool
  122. }
  123. func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) {
  124. var flags uint16
  125. err := binary.Read(reader, binary.BigEndian, &flags)
  126. if err != nil {
  127. return nil, err
  128. }
  129. destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
  130. if err != nil {
  131. return nil, err
  132. }
  133. var network string
  134. var udpAddr bool
  135. if flags&flagUDP == 0 {
  136. network = N.NetworkTCP
  137. } else {
  138. network = N.NetworkUDP
  139. udpAddr = flags&flagAddr != 0
  140. }
  141. return &StreamRequest{network, destination, udpAddr}, nil
  142. }
  143. func requestLen(request StreamRequest) int {
  144. var rLen int
  145. rLen += 1 // version
  146. rLen += 2 // flags
  147. rLen += M.SocksaddrSerializer.AddrPortLen(request.Destination)
  148. return rLen
  149. }
  150. func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) {
  151. destination := request.Destination
  152. var flags uint16
  153. if request.Network == N.NetworkUDP {
  154. flags |= flagUDP
  155. }
  156. if request.PacketAddr {
  157. flags |= flagAddr
  158. if !destination.IsValid() {
  159. destination = Destination
  160. }
  161. }
  162. common.Must(
  163. binary.Write(buffer, binary.BigEndian, flags),
  164. M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
  165. )
  166. }
  // StreamResponse is the reply header parsed by ReadStreamResponse.
type StreamResponse struct {
	// Status is the leading status byte (statusSuccess or statusError).
	Status uint8
	// Message is populated only when Status is statusError.
	Message string
}
  171. func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) {
  172. var response StreamResponse
  173. status, err := rw.ReadByte(reader)
  174. if err != nil {
  175. return nil, err
  176. }
  177. response.Status = status
  178. if status == statusError {
  179. response.Message, err = rw.ReadVString(reader)
  180. if err != nil {
  181. return nil, err
  182. }
  183. }
  184. return &response, nil
  185. }
  186. type wrapStream struct {
  187. net.Conn
  188. }
  189. func (w *wrapStream) Read(p []byte) (n int, err error) {
  190. n, err = w.Conn.Read(p)
  191. err = wrapError(err)
  192. return
  193. }
  194. func (w *wrapStream) Write(p []byte) (n int, err error) {
  195. n, err = w.Conn.Write(p)
  196. err = wrapError(err)
  197. return
  198. }
  199. func (w *wrapStream) WriteIsThreadUnsafe() {
  200. }
  201. func (w *wrapStream) Upstream() any {
  202. return w.Conn
  203. }
  204. func wrapError(err error) error {
  205. switch err {
  206. case yamux.ErrStreamClosed:
  207. return io.EOF
  208. default:
  209. return err
  210. }
  211. }