hysteria.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. //go:build with_quic
  2. package inbound
  3. import (
  4. "bytes"
  5. "context"
  6. "sync"
  7. "github.com/sagernet/quic-go"
  8. "github.com/sagernet/quic-go/congestion"
  9. "github.com/sagernet/sing-box/adapter"
  10. C "github.com/sagernet/sing-box/constant"
  11. "github.com/sagernet/sing-box/log"
  12. "github.com/sagernet/sing-box/option"
  13. "github.com/sagernet/sing-box/transport/hysteria"
  14. "github.com/sagernet/sing-dns"
  15. "github.com/sagernet/sing/common"
  16. E "github.com/sagernet/sing/common/exceptions"
  17. M "github.com/sagernet/sing/common/metadata"
  18. N "github.com/sagernet/sing/common/network"
  19. )
  20. var _ adapter.Inbound = (*Hysteria)(nil)
  21. type Hysteria struct {
  22. myInboundAdapter
  23. quicConfig *quic.Config
  24. tlsConfig *TLSConfig
  25. authKey []byte
  26. xplusKey []byte
  27. sendBPS uint64
  28. recvBPS uint64
  29. listener quic.Listener
  30. udpAccess sync.RWMutex
  31. udpSessionId uint32
  32. udpSessions map[uint32]chan *hysteria.UDPMessage
  33. udpDefragger hysteria.Defragger
  34. }
  35. func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) {
  36. quicConfig := &quic.Config{
  37. InitialStreamReceiveWindow: options.ReceiveWindowConn,
  38. MaxStreamReceiveWindow: options.ReceiveWindowConn,
  39. InitialConnectionReceiveWindow: options.ReceiveWindowClient,
  40. MaxConnectionReceiveWindow: options.ReceiveWindowClient,
  41. MaxIncomingStreams: int64(options.MaxConnClient),
  42. KeepAlivePeriod: hysteria.KeepAlivePeriod,
  43. DisablePathMTUDiscovery: options.DisableMTUDiscovery || !(C.IsLinux || C.IsWindows),
  44. EnableDatagrams: true,
  45. }
  46. if options.ReceiveWindowConn == 0 {
  47. quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
  48. quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
  49. }
  50. if options.ReceiveWindowClient == 0 {
  51. quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
  52. quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
  53. }
  54. if quicConfig.MaxIncomingStreams == 0 {
  55. quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams
  56. }
  57. var auth []byte
  58. if len(options.Auth) > 0 {
  59. auth = options.Auth
  60. } else {
  61. auth = []byte(options.AuthString)
  62. }
  63. var xplus []byte
  64. if options.Obfs != "" {
  65. xplus = []byte(options.Obfs)
  66. }
  67. var up, down uint64
  68. if len(options.Up) > 0 {
  69. up = hysteria.StringToBps(options.Up)
  70. if up == 0 {
  71. return nil, E.New("invalid up speed format: ", options.Up)
  72. }
  73. } else {
  74. up = uint64(options.UpMbps) * hysteria.MbpsToBps
  75. }
  76. if len(options.Down) > 0 {
  77. down = hysteria.StringToBps(options.Down)
  78. if down == 0 {
  79. return nil, E.New("invalid down speed format: ", options.Down)
  80. }
  81. } else {
  82. down = uint64(options.DownMbps) * hysteria.MbpsToBps
  83. }
  84. if up < hysteria.MinSpeedBPS {
  85. return nil, E.New("invalid up speed")
  86. }
  87. if down < hysteria.MinSpeedBPS {
  88. return nil, E.New("invalid down speed")
  89. }
  90. inbound := &Hysteria{
  91. myInboundAdapter: myInboundAdapter{
  92. protocol: C.TypeHysteria,
  93. network: []string{N.NetworkUDP},
  94. ctx: ctx,
  95. router: router,
  96. logger: logger,
  97. tag: tag,
  98. listenOptions: options.ListenOptions,
  99. },
  100. quicConfig: quicConfig,
  101. authKey: auth,
  102. xplusKey: xplus,
  103. sendBPS: up,
  104. recvBPS: down,
  105. udpSessions: make(map[uint32]chan *hysteria.UDPMessage),
  106. }
  107. if options.TLS == nil || !options.TLS.Enabled {
  108. return nil, C.ErrTLSRequired
  109. }
  110. if len(options.TLS.ALPN) == 0 {
  111. options.TLS.ALPN = []string{hysteria.DefaultALPN}
  112. }
  113. tlsConfig, err := NewTLSConfig(ctx, logger, common.PtrValueOrDefault(options.TLS))
  114. if err != nil {
  115. return nil, err
  116. }
  117. inbound.tlsConfig = tlsConfig
  118. return inbound, nil
  119. }
  120. func (h *Hysteria) Start() error {
  121. packetConn, err := h.myInboundAdapter.ListenUDP()
  122. if err != nil {
  123. return err
  124. }
  125. if len(h.xplusKey) > 0 {
  126. packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
  127. packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
  128. }
  129. err = h.tlsConfig.Start()
  130. if err != nil {
  131. return err
  132. }
  133. listener, err := quic.Listen(packetConn, h.tlsConfig.Config(), h.quicConfig)
  134. if err != nil {
  135. return err
  136. }
  137. h.listener = listener
  138. h.logger.Info("udp server started at ", listener.Addr())
  139. go h.acceptLoop()
  140. return nil
  141. }
  142. func (h *Hysteria) acceptLoop() {
  143. for {
  144. ctx := log.ContextWithNewID(h.ctx)
  145. conn, err := h.listener.Accept(ctx)
  146. if err != nil {
  147. return
  148. }
  149. h.logger.InfoContext(ctx, "inbound connection from ", conn.RemoteAddr())
  150. go func() {
  151. hErr := h.accept(ctx, conn)
  152. if hErr != nil {
  153. conn.CloseWithError(0, "")
  154. NewError(h.logger, ctx, E.Cause(hErr, "process connection from ", conn.RemoteAddr()))
  155. }
  156. }()
  157. }
  158. }
  159. func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error {
  160. controlStream, err := conn.AcceptStream(ctx)
  161. if err != nil {
  162. return err
  163. }
  164. clientHello, err := hysteria.ReadClientHello(controlStream)
  165. if err != nil {
  166. return err
  167. }
  168. if !bytes.Equal(clientHello.Auth, h.authKey) {
  169. err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
  170. Message: "wrong password",
  171. })
  172. return E.Errors(E.New("wrong password: ", string(clientHello.Auth)), err)
  173. }
  174. if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 {
  175. return E.New("invalid rate from client")
  176. }
  177. serverSendBPS, serverRecvBPS := clientHello.RecvBPS, clientHello.SendBPS
  178. if h.sendBPS > 0 && serverSendBPS > h.sendBPS {
  179. serverSendBPS = h.sendBPS
  180. }
  181. if h.recvBPS > 0 && serverRecvBPS > h.recvBPS {
  182. serverRecvBPS = h.recvBPS
  183. }
  184. err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
  185. OK: true,
  186. SendBPS: serverSendBPS,
  187. RecvBPS: serverRecvBPS,
  188. })
  189. if err != nil {
  190. return err
  191. }
  192. conn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverSendBPS)))
  193. go h.udpRecvLoop(conn)
  194. for {
  195. var stream quic.Stream
  196. stream, err = conn.AcceptStream(ctx)
  197. if err != nil {
  198. return err
  199. }
  200. go func() {
  201. hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream)
  202. if hErr != nil {
  203. stream.Close()
  204. NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
  205. }
  206. }()
  207. }
  208. }
  209. func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
  210. for {
  211. packet, err := conn.ReceiveMessage()
  212. if err != nil {
  213. return
  214. }
  215. message, err := hysteria.ParseUDPMessage(packet)
  216. if err != nil {
  217. h.logger.Error("parse udp message: ", err)
  218. continue
  219. }
  220. dfMsg := h.udpDefragger.Feed(message)
  221. if dfMsg == nil {
  222. continue
  223. }
  224. h.udpAccess.RLock()
  225. ch, ok := h.udpSessions[dfMsg.SessionID]
  226. if ok {
  227. select {
  228. case ch <- dfMsg:
  229. // OK
  230. default:
  231. // Silently drop the message when the channel is full
  232. }
  233. }
  234. h.udpAccess.RUnlock()
  235. }
  236. }
  237. func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error {
  238. request, err := hysteria.ReadClientRequest(stream)
  239. if err != nil {
  240. return err
  241. }
  242. err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
  243. OK: true,
  244. })
  245. if err != nil {
  246. return err
  247. }
  248. var metadata adapter.InboundContext
  249. metadata.Inbound = h.tag
  250. metadata.InboundType = C.TypeHysteria
  251. metadata.SniffEnabled = h.listenOptions.SniffEnabled
  252. metadata.SniffOverrideDestination = h.listenOptions.SniffOverrideDestination
  253. metadata.DomainStrategy = dns.DomainStrategy(h.listenOptions.DomainStrategy)
  254. metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr())
  255. metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr())
  256. metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port)
  257. if !request.UDP {
  258. h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
  259. return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination), metadata)
  260. } else {
  261. h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
  262. var id uint32
  263. h.udpAccess.Lock()
  264. id = h.udpSessionId
  265. nCh := make(chan *hysteria.UDPMessage, 1024)
  266. h.udpSessions[id] = nCh
  267. h.udpSessionId += 1
  268. h.udpAccess.Unlock()
  269. packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(func() error {
  270. h.udpAccess.Lock()
  271. if ch, ok := h.udpSessions[id]; ok {
  272. close(ch)
  273. delete(h.udpSessions, id)
  274. }
  275. h.udpAccess.Unlock()
  276. return nil
  277. }))
  278. go packetConn.Hold()
  279. return h.router.RoutePacketConnection(ctx, packetConn, metadata)
  280. }
  281. }
  282. func (h *Hysteria) Close() error {
  283. h.udpAccess.Lock()
  284. for _, session := range h.udpSessions {
  285. close(session)
  286. }
  287. h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
  288. h.udpAccess.Unlock()
  289. return common.Close(
  290. &h.myInboundAdapter,
  291. h.listener,
  292. common.PtrOrNil(h.tlsConfig),
  293. )
  294. }