hysteria.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. //go:build with_quic
  2. package inbound
  3. import (
  4. "context"
  5. "sync"
  6. "github.com/sagernet/quic-go"
  7. "github.com/sagernet/quic-go/congestion"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/common/tls"
  10. C "github.com/sagernet/sing-box/constant"
  11. "github.com/sagernet/sing-box/log"
  12. "github.com/sagernet/sing-box/option"
  13. "github.com/sagernet/sing-box/transport/hysteria"
  14. "github.com/sagernet/sing/common"
  15. "github.com/sagernet/sing/common/auth"
  16. E "github.com/sagernet/sing/common/exceptions"
  17. F "github.com/sagernet/sing/common/format"
  18. M "github.com/sagernet/sing/common/metadata"
  19. N "github.com/sagernet/sing/common/network"
  20. "golang.org/x/exp/slices"
  21. )
  22. var _ adapter.Inbound = (*Hysteria)(nil)
  23. type Hysteria struct {
  24. myInboundAdapter
  25. quicConfig *quic.Config
  26. tlsConfig tls.ServerConfig
  27. authKey []string
  28. authUser []string
  29. xplusKey []byte
  30. sendBPS uint64
  31. recvBPS uint64
  32. listener quic.Listener
  33. udpAccess sync.RWMutex
  34. udpSessionId uint32
  35. udpSessions map[uint32]chan *hysteria.UDPMessage
  36. udpDefragger hysteria.Defragger
  37. }
  38. func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) {
  39. options.UDPFragmentDefault = true
  40. quicConfig := &quic.Config{
  41. InitialStreamReceiveWindow: options.ReceiveWindowConn,
  42. MaxStreamReceiveWindow: options.ReceiveWindowConn,
  43. InitialConnectionReceiveWindow: options.ReceiveWindowClient,
  44. MaxConnectionReceiveWindow: options.ReceiveWindowClient,
  45. MaxIncomingStreams: int64(options.MaxConnClient),
  46. KeepAlivePeriod: hysteria.KeepAlivePeriod,
  47. DisablePathMTUDiscovery: options.DisableMTUDiscovery || !(C.IsLinux || C.IsWindows),
  48. EnableDatagrams: true,
  49. }
  50. if options.ReceiveWindowConn == 0 {
  51. quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
  52. quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
  53. }
  54. if options.ReceiveWindowClient == 0 {
  55. quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
  56. quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
  57. }
  58. if quicConfig.MaxIncomingStreams == 0 {
  59. quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams
  60. }
  61. authKey := common.Map(options.Users, func(it option.HysteriaUser) string {
  62. if len(it.Auth) > 0 {
  63. return string(it.Auth)
  64. } else {
  65. return it.AuthString
  66. }
  67. })
  68. authUser := common.Map(options.Users, func(it option.HysteriaUser) string {
  69. return it.Name
  70. })
  71. var xplus []byte
  72. if options.Obfs != "" {
  73. xplus = []byte(options.Obfs)
  74. }
  75. var up, down uint64
  76. if len(options.Up) > 0 {
  77. up = hysteria.StringToBps(options.Up)
  78. if up == 0 {
  79. return nil, E.New("invalid up speed format: ", options.Up)
  80. }
  81. } else {
  82. up = uint64(options.UpMbps) * hysteria.MbpsToBps
  83. }
  84. if len(options.Down) > 0 {
  85. down = hysteria.StringToBps(options.Down)
  86. if down == 0 {
  87. return nil, E.New("invalid down speed format: ", options.Down)
  88. }
  89. } else {
  90. down = uint64(options.DownMbps) * hysteria.MbpsToBps
  91. }
  92. if up < hysteria.MinSpeedBPS {
  93. return nil, E.New("invalid up speed")
  94. }
  95. if down < hysteria.MinSpeedBPS {
  96. return nil, E.New("invalid down speed")
  97. }
  98. inbound := &Hysteria{
  99. myInboundAdapter: myInboundAdapter{
  100. protocol: C.TypeHysteria,
  101. network: []string{N.NetworkUDP},
  102. ctx: ctx,
  103. router: router,
  104. logger: logger,
  105. tag: tag,
  106. listenOptions: options.ListenOptions,
  107. },
  108. quicConfig: quicConfig,
  109. authKey: authKey,
  110. authUser: authUser,
  111. xplusKey: xplus,
  112. sendBPS: up,
  113. recvBPS: down,
  114. udpSessions: make(map[uint32]chan *hysteria.UDPMessage),
  115. }
  116. if options.TLS == nil || !options.TLS.Enabled {
  117. return nil, C.ErrTLSRequired
  118. }
  119. if len(options.TLS.ALPN) == 0 {
  120. options.TLS.ALPN = []string{hysteria.DefaultALPN}
  121. }
  122. tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS))
  123. if err != nil {
  124. return nil, err
  125. }
  126. inbound.tlsConfig = tlsConfig
  127. return inbound, nil
  128. }
  129. func (h *Hysteria) Start() error {
  130. packetConn, err := h.myInboundAdapter.ListenUDP()
  131. if err != nil {
  132. return err
  133. }
  134. if len(h.xplusKey) > 0 {
  135. packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
  136. packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
  137. }
  138. err = h.tlsConfig.Start()
  139. if err != nil {
  140. return err
  141. }
  142. rawConfig, err := h.tlsConfig.Config()
  143. if err != nil {
  144. return err
  145. }
  146. listener, err := quic.Listen(packetConn, rawConfig, h.quicConfig)
  147. if err != nil {
  148. return err
  149. }
  150. h.listener = listener
  151. h.logger.Info("udp server started at ", listener.Addr())
  152. go h.acceptLoop()
  153. return nil
  154. }
  155. func (h *Hysteria) acceptLoop() {
  156. for {
  157. ctx := log.ContextWithNewID(h.ctx)
  158. conn, err := h.listener.Accept(ctx)
  159. if err != nil {
  160. return
  161. }
  162. go func() {
  163. hErr := h.accept(ctx, conn)
  164. if hErr != nil {
  165. conn.CloseWithError(0, "")
  166. NewError(h.logger, ctx, E.Cause(hErr, "process connection from ", conn.RemoteAddr()))
  167. }
  168. }()
  169. }
  170. }
  171. func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error {
  172. controlStream, err := conn.AcceptStream(ctx)
  173. if err != nil {
  174. return err
  175. }
  176. clientHello, err := hysteria.ReadClientHello(controlStream)
  177. if err != nil {
  178. return err
  179. }
  180. if len(h.authKey) > 0 {
  181. userIndex := slices.Index(h.authKey, string(clientHello.Auth))
  182. if userIndex == -1 {
  183. err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
  184. Message: "wrong password",
  185. })
  186. return E.Errors(E.New("wrong password: ", string(clientHello.Auth)), err)
  187. }
  188. user := h.authUser[userIndex]
  189. if user == "" {
  190. user = F.ToString(userIndex)
  191. } else {
  192. ctx = auth.ContextWithUser(ctx, user)
  193. }
  194. h.logger.InfoContext(ctx, "[", user, "] inbound connection from ", conn.RemoteAddr())
  195. } else {
  196. h.logger.InfoContext(ctx, "inbound connection from ", conn.RemoteAddr())
  197. }
  198. h.logger.DebugContext(ctx, "peer send speed: ", clientHello.SendBPS/1024/1024, " MBps, peer recv speed: ", clientHello.RecvBPS/1024/1024, " MBps")
  199. if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 {
  200. return E.New("invalid rate from client")
  201. }
  202. serverSendBPS, serverRecvBPS := clientHello.RecvBPS, clientHello.SendBPS
  203. if h.sendBPS > 0 && serverSendBPS > h.sendBPS {
  204. serverSendBPS = h.sendBPS
  205. }
  206. if h.recvBPS > 0 && serverRecvBPS > h.recvBPS {
  207. serverRecvBPS = h.recvBPS
  208. }
  209. err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
  210. OK: true,
  211. SendBPS: serverSendBPS,
  212. RecvBPS: serverRecvBPS,
  213. })
  214. if err != nil {
  215. return err
  216. }
  217. conn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverSendBPS)))
  218. go h.udpRecvLoop(conn)
  219. for {
  220. var stream quic.Stream
  221. stream, err = conn.AcceptStream(ctx)
  222. if err != nil {
  223. return err
  224. }
  225. go func() {
  226. hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream)
  227. if hErr != nil {
  228. stream.Close()
  229. NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
  230. }
  231. }()
  232. }
  233. }
  234. func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
  235. for {
  236. packet, err := conn.ReceiveMessage()
  237. if err != nil {
  238. return
  239. }
  240. message, err := hysteria.ParseUDPMessage(packet)
  241. if err != nil {
  242. h.logger.Error("parse udp message: ", err)
  243. continue
  244. }
  245. dfMsg := h.udpDefragger.Feed(message)
  246. if dfMsg == nil {
  247. continue
  248. }
  249. h.udpAccess.RLock()
  250. ch, ok := h.udpSessions[dfMsg.SessionID]
  251. if ok {
  252. select {
  253. case ch <- dfMsg:
  254. // OK
  255. default:
  256. // Silently drop the message when the channel is full
  257. }
  258. }
  259. h.udpAccess.RUnlock()
  260. }
  261. }
  262. func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error {
  263. request, err := hysteria.ReadClientRequest(stream)
  264. if err != nil {
  265. return err
  266. }
  267. var metadata adapter.InboundContext
  268. metadata.Inbound = h.tag
  269. metadata.InboundType = C.TypeHysteria
  270. metadata.InboundOptions = h.listenOptions.InboundOptions
  271. metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
  272. metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
  273. metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port).Unwrap()
  274. if !request.UDP {
  275. err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
  276. OK: true,
  277. })
  278. if err != nil {
  279. return err
  280. }
  281. h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
  282. return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination, false), metadata)
  283. } else {
  284. h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
  285. var id uint32
  286. h.udpAccess.Lock()
  287. id = h.udpSessionId
  288. nCh := make(chan *hysteria.UDPMessage, 1024)
  289. h.udpSessions[id] = nCh
  290. h.udpSessionId += 1
  291. h.udpAccess.Unlock()
  292. err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
  293. OK: true,
  294. UDPSessionID: id,
  295. })
  296. if err != nil {
  297. return err
  298. }
  299. packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(func() error {
  300. h.udpAccess.Lock()
  301. if ch, ok := h.udpSessions[id]; ok {
  302. close(ch)
  303. delete(h.udpSessions, id)
  304. }
  305. h.udpAccess.Unlock()
  306. return nil
  307. }))
  308. go packetConn.Hold()
  309. return h.router.RoutePacketConnection(ctx, packetConn, metadata)
  310. }
  311. }
  312. func (h *Hysteria) Close() error {
  313. h.udpAccess.Lock()
  314. for _, session := range h.udpSessions {
  315. close(session)
  316. }
  317. h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
  318. h.udpAccess.Unlock()
  319. return common.Close(
  320. &h.myInboundAdapter,
  321. h.listener,
  322. h.tlsConfig,
  323. )
  324. }