hysteria.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. //go:build with_quic
  2. package inbound
  3. import (
  4. "bytes"
  5. "context"
  6. "sync"
  7. "github.com/sagernet/quic-go"
  8. "github.com/sagernet/quic-go/congestion"
  9. "github.com/sagernet/sing-box/adapter"
  10. "github.com/sagernet/sing-box/common/tls"
  11. C "github.com/sagernet/sing-box/constant"
  12. "github.com/sagernet/sing-box/log"
  13. "github.com/sagernet/sing-box/option"
  14. "github.com/sagernet/sing-box/transport/hysteria"
  15. "github.com/sagernet/sing-dns"
  16. "github.com/sagernet/sing/common"
  17. E "github.com/sagernet/sing/common/exceptions"
  18. M "github.com/sagernet/sing/common/metadata"
  19. N "github.com/sagernet/sing/common/network"
  20. )
  21. var _ adapter.Inbound = (*Hysteria)(nil)
  22. type Hysteria struct {
  23. myInboundAdapter
  24. quicConfig *quic.Config
  25. tlsConfig tls.ServerConfig
  26. authKey []byte
  27. xplusKey []byte
  28. sendBPS uint64
  29. recvBPS uint64
  30. listener quic.Listener
  31. udpAccess sync.RWMutex
  32. udpSessionId uint32
  33. udpSessions map[uint32]chan *hysteria.UDPMessage
  34. udpDefragger hysteria.Defragger
  35. }
  36. func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) {
  37. quicConfig := &quic.Config{
  38. InitialStreamReceiveWindow: options.ReceiveWindowConn,
  39. MaxStreamReceiveWindow: options.ReceiveWindowConn,
  40. InitialConnectionReceiveWindow: options.ReceiveWindowClient,
  41. MaxConnectionReceiveWindow: options.ReceiveWindowClient,
  42. MaxIncomingStreams: int64(options.MaxConnClient),
  43. KeepAlivePeriod: hysteria.KeepAlivePeriod,
  44. DisablePathMTUDiscovery: options.DisableMTUDiscovery || !(C.IsLinux || C.IsWindows),
  45. EnableDatagrams: true,
  46. }
  47. if options.ReceiveWindowConn == 0 {
  48. quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
  49. quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
  50. }
  51. if options.ReceiveWindowClient == 0 {
  52. quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
  53. quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
  54. }
  55. if quicConfig.MaxIncomingStreams == 0 {
  56. quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams
  57. }
  58. var auth []byte
  59. if len(options.Auth) > 0 {
  60. auth = options.Auth
  61. } else {
  62. auth = []byte(options.AuthString)
  63. }
  64. var xplus []byte
  65. if options.Obfs != "" {
  66. xplus = []byte(options.Obfs)
  67. }
  68. var up, down uint64
  69. if len(options.Up) > 0 {
  70. up = hysteria.StringToBps(options.Up)
  71. if up == 0 {
  72. return nil, E.New("invalid up speed format: ", options.Up)
  73. }
  74. } else {
  75. up = uint64(options.UpMbps) * hysteria.MbpsToBps
  76. }
  77. if len(options.Down) > 0 {
  78. down = hysteria.StringToBps(options.Down)
  79. if down == 0 {
  80. return nil, E.New("invalid down speed format: ", options.Down)
  81. }
  82. } else {
  83. down = uint64(options.DownMbps) * hysteria.MbpsToBps
  84. }
  85. if up < hysteria.MinSpeedBPS {
  86. return nil, E.New("invalid up speed")
  87. }
  88. if down < hysteria.MinSpeedBPS {
  89. return nil, E.New("invalid down speed")
  90. }
  91. inbound := &Hysteria{
  92. myInboundAdapter: myInboundAdapter{
  93. protocol: C.TypeHysteria,
  94. network: []string{N.NetworkUDP},
  95. ctx: ctx,
  96. router: router,
  97. logger: logger,
  98. tag: tag,
  99. listenOptions: options.ListenOptions,
  100. },
  101. quicConfig: quicConfig,
  102. authKey: auth,
  103. xplusKey: xplus,
  104. sendBPS: up,
  105. recvBPS: down,
  106. udpSessions: make(map[uint32]chan *hysteria.UDPMessage),
  107. }
  108. if options.TLS == nil || !options.TLS.Enabled {
  109. return nil, C.ErrTLSRequired
  110. }
  111. if len(options.TLS.ALPN) == 0 {
  112. options.TLS.ALPN = []string{hysteria.DefaultALPN}
  113. }
  114. tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
  115. if err != nil {
  116. return nil, err
  117. }
  118. inbound.tlsConfig = tlsConfig
  119. return inbound, nil
  120. }
  121. func (h *Hysteria) Start() error {
  122. packetConn, err := h.myInboundAdapter.ListenUDP()
  123. if err != nil {
  124. return err
  125. }
  126. if len(h.xplusKey) > 0 {
  127. packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
  128. packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
  129. }
  130. err = h.tlsConfig.Start()
  131. if err != nil {
  132. return err
  133. }
  134. rawConfig, err := h.tlsConfig.Config()
  135. if err != nil {
  136. return err
  137. }
  138. listener, err := quic.Listen(packetConn, rawConfig, h.quicConfig)
  139. if err != nil {
  140. return err
  141. }
  142. h.listener = listener
  143. h.logger.Info("udp server started at ", listener.Addr())
  144. go h.acceptLoop()
  145. return nil
  146. }
  147. func (h *Hysteria) acceptLoop() {
  148. for {
  149. ctx := log.ContextWithNewID(h.ctx)
  150. conn, err := h.listener.Accept(ctx)
  151. if err != nil {
  152. return
  153. }
  154. h.logger.InfoContext(ctx, "inbound connection from ", conn.RemoteAddr())
  155. go func() {
  156. hErr := h.accept(ctx, conn)
  157. if hErr != nil {
  158. conn.CloseWithError(0, "")
  159. NewError(h.logger, ctx, E.Cause(hErr, "process connection from ", conn.RemoteAddr()))
  160. }
  161. }()
  162. }
  163. }
  164. func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error {
  165. controlStream, err := conn.AcceptStream(ctx)
  166. if err != nil {
  167. return err
  168. }
  169. clientHello, err := hysteria.ReadClientHello(controlStream)
  170. if err != nil {
  171. return err
  172. }
  173. if !bytes.Equal(clientHello.Auth, h.authKey) {
  174. err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
  175. Message: "wrong password",
  176. })
  177. return E.Errors(E.New("wrong password: ", string(clientHello.Auth)), err)
  178. }
  179. if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 {
  180. return E.New("invalid rate from client")
  181. }
  182. serverSendBPS, serverRecvBPS := clientHello.RecvBPS, clientHello.SendBPS
  183. if h.sendBPS > 0 && serverSendBPS > h.sendBPS {
  184. serverSendBPS = h.sendBPS
  185. }
  186. if h.recvBPS > 0 && serverRecvBPS > h.recvBPS {
  187. serverRecvBPS = h.recvBPS
  188. }
  189. err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
  190. OK: true,
  191. SendBPS: serverSendBPS,
  192. RecvBPS: serverRecvBPS,
  193. })
  194. if err != nil {
  195. return err
  196. }
  197. conn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverSendBPS)))
  198. go h.udpRecvLoop(conn)
  199. for {
  200. var stream quic.Stream
  201. stream, err = conn.AcceptStream(ctx)
  202. if err != nil {
  203. return err
  204. }
  205. go func() {
  206. hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream)
  207. if hErr != nil {
  208. stream.Close()
  209. NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
  210. }
  211. }()
  212. }
  213. }
  214. func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
  215. for {
  216. packet, err := conn.ReceiveMessage()
  217. if err != nil {
  218. return
  219. }
  220. message, err := hysteria.ParseUDPMessage(packet)
  221. if err != nil {
  222. h.logger.Error("parse udp message: ", err)
  223. continue
  224. }
  225. dfMsg := h.udpDefragger.Feed(message)
  226. if dfMsg == nil {
  227. continue
  228. }
  229. h.udpAccess.RLock()
  230. ch, ok := h.udpSessions[dfMsg.SessionID]
  231. if ok {
  232. select {
  233. case ch <- dfMsg:
  234. // OK
  235. default:
  236. // Silently drop the message when the channel is full
  237. }
  238. }
  239. h.udpAccess.RUnlock()
  240. }
  241. }
  242. func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error {
  243. request, err := hysteria.ReadClientRequest(stream)
  244. if err != nil {
  245. return err
  246. }
  247. err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
  248. OK: true,
  249. })
  250. if err != nil {
  251. return err
  252. }
  253. var metadata adapter.InboundContext
  254. metadata.Inbound = h.tag
  255. metadata.InboundType = C.TypeHysteria
  256. metadata.SniffEnabled = h.listenOptions.SniffEnabled
  257. metadata.SniffOverrideDestination = h.listenOptions.SniffOverrideDestination
  258. metadata.DomainStrategy = dns.DomainStrategy(h.listenOptions.DomainStrategy)
  259. metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr())
  260. metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr())
  261. metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port)
  262. if !request.UDP {
  263. h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
  264. return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination), metadata)
  265. } else {
  266. h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
  267. var id uint32
  268. h.udpAccess.Lock()
  269. id = h.udpSessionId
  270. nCh := make(chan *hysteria.UDPMessage, 1024)
  271. h.udpSessions[id] = nCh
  272. h.udpSessionId += 1
  273. h.udpAccess.Unlock()
  274. packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(func() error {
  275. h.udpAccess.Lock()
  276. if ch, ok := h.udpSessions[id]; ok {
  277. close(ch)
  278. delete(h.udpSessions, id)
  279. }
  280. h.udpAccess.Unlock()
  281. return nil
  282. }))
  283. go packetConn.Hold()
  284. return h.router.RoutePacketConnection(ctx, packetConn, metadata)
  285. }
  286. }
  287. func (h *Hysteria) Close() error {
  288. h.udpAccess.Lock()
  289. for _, session := range h.udpSessions {
  290. close(session)
  291. }
  292. h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
  293. h.udpAccess.Unlock()
  294. return common.Close(
  295. &h.myInboundAdapter,
  296. h.listener,
  297. h.tlsConfig,
  298. )
  299. }