hysteria.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. //go:build with_quic
  2. package inbound
  3. import (
  4. "context"
  5. "sync"
  6. "github.com/sagernet/quic-go"
  7. "github.com/sagernet/quic-go/congestion"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/common/tls"
  10. C "github.com/sagernet/sing-box/constant"
  11. "github.com/sagernet/sing-box/log"
  12. "github.com/sagernet/sing-box/option"
  13. "github.com/sagernet/sing-box/transport/hysteria"
  14. "github.com/sagernet/sing/common"
  15. "github.com/sagernet/sing/common/auth"
  16. E "github.com/sagernet/sing/common/exceptions"
  17. F "github.com/sagernet/sing/common/format"
  18. M "github.com/sagernet/sing/common/metadata"
  19. N "github.com/sagernet/sing/common/network"
  20. "golang.org/x/exp/slices"
  21. )
  22. var _ adapter.Inbound = (*Hysteria)(nil)
  23. type Hysteria struct {
  24. myInboundAdapter
  25. quicConfig *quic.Config
  26. tlsConfig tls.ServerConfig
  27. authKey []string
  28. authUser []string
  29. xplusKey []byte
  30. sendBPS uint64
  31. recvBPS uint64
  32. listener quic.Listener
  33. udpAccess sync.RWMutex
  34. udpSessionId uint32
  35. udpSessions map[uint32]chan *hysteria.UDPMessage
  36. udpDefragger hysteria.Defragger
  37. }
  38. func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) {
  39. options.UDPFragmentDefault = true
  40. quicConfig := &quic.Config{
  41. InitialStreamReceiveWindow: options.ReceiveWindowConn,
  42. MaxStreamReceiveWindow: options.ReceiveWindowConn,
  43. InitialConnectionReceiveWindow: options.ReceiveWindowClient,
  44. MaxConnectionReceiveWindow: options.ReceiveWindowClient,
  45. MaxIncomingStreams: int64(options.MaxConnClient),
  46. KeepAlivePeriod: hysteria.KeepAlivePeriod,
  47. DisablePathMTUDiscovery: options.DisableMTUDiscovery || !(C.IsLinux || C.IsWindows),
  48. EnableDatagrams: true,
  49. }
  50. if options.ReceiveWindowConn == 0 {
  51. quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
  52. quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
  53. }
  54. if options.ReceiveWindowClient == 0 {
  55. quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
  56. quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
  57. }
  58. if quicConfig.MaxIncomingStreams == 0 {
  59. quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams
  60. }
  61. authKey := common.Map(options.Users, func(it option.HysteriaUser) string {
  62. if len(it.Auth) > 0 {
  63. return string(it.Auth)
  64. } else {
  65. return it.AuthString
  66. }
  67. })
  68. authUser := common.Map(options.Users, func(it option.HysteriaUser) string {
  69. return it.Name
  70. })
  71. var xplus []byte
  72. if options.Obfs != "" {
  73. xplus = []byte(options.Obfs)
  74. }
  75. var up, down uint64
  76. if len(options.Up) > 0 {
  77. up = hysteria.StringToBps(options.Up)
  78. if up == 0 {
  79. return nil, E.New("invalid up speed format: ", options.Up)
  80. }
  81. } else {
  82. up = uint64(options.UpMbps) * hysteria.MbpsToBps
  83. }
  84. if len(options.Down) > 0 {
  85. down = hysteria.StringToBps(options.Down)
  86. if down == 0 {
  87. return nil, E.New("invalid down speed format: ", options.Down)
  88. }
  89. } else {
  90. down = uint64(options.DownMbps) * hysteria.MbpsToBps
  91. }
  92. if up < hysteria.MinSpeedBPS {
  93. return nil, E.New("invalid up speed")
  94. }
  95. if down < hysteria.MinSpeedBPS {
  96. return nil, E.New("invalid down speed")
  97. }
  98. inbound := &Hysteria{
  99. myInboundAdapter: myInboundAdapter{
  100. protocol: C.TypeHysteria,
  101. network: []string{N.NetworkUDP},
  102. ctx: ctx,
  103. router: router,
  104. logger: logger,
  105. tag: tag,
  106. listenOptions: options.ListenOptions,
  107. },
  108. quicConfig: quicConfig,
  109. authKey: authKey,
  110. authUser: authUser,
  111. xplusKey: xplus,
  112. sendBPS: up,
  113. recvBPS: down,
  114. udpSessions: make(map[uint32]chan *hysteria.UDPMessage),
  115. }
  116. if options.TLS == nil || !options.TLS.Enabled {
  117. return nil, C.ErrTLSRequired
  118. }
  119. if len(options.TLS.ALPN) == 0 {
  120. options.TLS.ALPN = []string{hysteria.DefaultALPN}
  121. }
  122. tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
  123. if err != nil {
  124. return nil, err
  125. }
  126. inbound.tlsConfig = tlsConfig
  127. return inbound, nil
  128. }
  129. func (h *Hysteria) Start() error {
  130. packetConn, err := h.myInboundAdapter.ListenUDP()
  131. if err != nil {
  132. return err
  133. }
  134. if len(h.xplusKey) > 0 {
  135. packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
  136. packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
  137. }
  138. err = h.tlsConfig.Start()
  139. if err != nil {
  140. return err
  141. }
  142. rawConfig, err := h.tlsConfig.Config()
  143. if err != nil {
  144. return err
  145. }
  146. listener, err := quic.Listen(packetConn, rawConfig, h.quicConfig)
  147. if err != nil {
  148. return err
  149. }
  150. h.listener = listener
  151. h.logger.Info("udp server started at ", listener.Addr())
  152. go h.acceptLoop()
  153. return nil
  154. }
  155. func (h *Hysteria) acceptLoop() {
  156. for {
  157. ctx := log.ContextWithNewID(h.ctx)
  158. conn, err := h.listener.Accept(ctx)
  159. if err != nil {
  160. return
  161. }
  162. go func() {
  163. hErr := h.accept(ctx, conn)
  164. if hErr != nil {
  165. conn.CloseWithError(0, "")
  166. NewError(h.logger, ctx, E.Cause(hErr, "process connection from ", conn.RemoteAddr()))
  167. }
  168. }()
  169. }
  170. }
  171. func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error {
  172. controlStream, err := conn.AcceptStream(ctx)
  173. if err != nil {
  174. return err
  175. }
  176. clientHello, err := hysteria.ReadClientHello(controlStream)
  177. if err != nil {
  178. return err
  179. }
  180. userIndex := slices.Index(h.authKey, string(clientHello.Auth))
  181. if userIndex == -1 {
  182. err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
  183. Message: "wrong password",
  184. })
  185. return E.Errors(E.New("wrong password: ", string(clientHello.Auth)), err)
  186. }
  187. user := h.authUser[userIndex]
  188. if user == "" {
  189. user = F.ToString(userIndex)
  190. } else {
  191. ctx = auth.ContextWithUser(ctx, user)
  192. }
  193. h.logger.InfoContext(ctx, "[", user, "] inbound connection from ", conn.RemoteAddr())
  194. h.logger.DebugContext(ctx, "peer send speed: ", clientHello.SendBPS/1024/1024, " MBps, peer recv speed: ", clientHello.RecvBPS/1024/1024, " MBps")
  195. if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 {
  196. return E.New("invalid rate from client")
  197. }
  198. serverSendBPS, serverRecvBPS := clientHello.RecvBPS, clientHello.SendBPS
  199. if h.sendBPS > 0 && serverSendBPS > h.sendBPS {
  200. serverSendBPS = h.sendBPS
  201. }
  202. if h.recvBPS > 0 && serverRecvBPS > h.recvBPS {
  203. serverRecvBPS = h.recvBPS
  204. }
  205. err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
  206. OK: true,
  207. SendBPS: serverSendBPS,
  208. RecvBPS: serverRecvBPS,
  209. })
  210. if err != nil {
  211. return err
  212. }
  213. conn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverSendBPS)))
  214. go h.udpRecvLoop(conn)
  215. for {
  216. var stream quic.Stream
  217. stream, err = conn.AcceptStream(ctx)
  218. if err != nil {
  219. return err
  220. }
  221. go func() {
  222. hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream)
  223. if hErr != nil {
  224. stream.Close()
  225. NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
  226. }
  227. }()
  228. }
  229. }
  230. func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
  231. for {
  232. packet, err := conn.ReceiveMessage()
  233. if err != nil {
  234. return
  235. }
  236. message, err := hysteria.ParseUDPMessage(packet)
  237. if err != nil {
  238. h.logger.Error("parse udp message: ", err)
  239. continue
  240. }
  241. dfMsg := h.udpDefragger.Feed(message)
  242. if dfMsg == nil {
  243. continue
  244. }
  245. h.udpAccess.RLock()
  246. ch, ok := h.udpSessions[dfMsg.SessionID]
  247. if ok {
  248. select {
  249. case ch <- dfMsg:
  250. // OK
  251. default:
  252. // Silently drop the message when the channel is full
  253. }
  254. }
  255. h.udpAccess.RUnlock()
  256. }
  257. }
  258. func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error {
  259. request, err := hysteria.ReadClientRequest(stream)
  260. if err != nil {
  261. return err
  262. }
  263. var metadata adapter.InboundContext
  264. metadata.Inbound = h.tag
  265. metadata.InboundType = C.TypeHysteria
  266. metadata.InboundOptions = h.listenOptions.InboundOptions
  267. metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
  268. metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
  269. metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port).Unwrap()
  270. if !request.UDP {
  271. err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
  272. OK: true,
  273. })
  274. if err != nil {
  275. return err
  276. }
  277. h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
  278. return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination, false), metadata)
  279. } else {
  280. h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
  281. var id uint32
  282. h.udpAccess.Lock()
  283. id = h.udpSessionId
  284. nCh := make(chan *hysteria.UDPMessage, 1024)
  285. h.udpSessions[id] = nCh
  286. h.udpSessionId += 1
  287. h.udpAccess.Unlock()
  288. err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
  289. OK: true,
  290. UDPSessionID: id,
  291. })
  292. if err != nil {
  293. return err
  294. }
  295. packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(func() error {
  296. h.udpAccess.Lock()
  297. if ch, ok := h.udpSessions[id]; ok {
  298. close(ch)
  299. delete(h.udpSessions, id)
  300. }
  301. h.udpAccess.Unlock()
  302. return nil
  303. }))
  304. go packetConn.Hold()
  305. return h.router.RoutePacketConnection(ctx, packetConn, metadata)
  306. }
  307. }
  308. func (h *Hysteria) Close() error {
  309. h.udpAccess.Lock()
  310. for _, session := range h.udpSessions {
  311. close(session)
  312. }
  313. h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
  314. h.udpAccess.Unlock()
  315. return common.Close(
  316. &h.myInboundAdapter,
  317. h.listener,
  318. h.tlsConfig,
  319. )
  320. }