server.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. package hysteria2
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "net/http"
  7. "os"
  8. "runtime"
  9. "strings"
  10. "sync"
  11. "github.com/sagernet/quic-go"
  12. "github.com/sagernet/quic-go/http3"
  13. "github.com/sagernet/sing-box/common/qtls"
  14. "github.com/sagernet/sing-box/common/tls"
  15. "github.com/sagernet/sing-box/transport/hysteria2/congestion"
  16. "github.com/sagernet/sing-box/transport/hysteria2/internal/protocol"
  17. tuicCongestion "github.com/sagernet/sing-box/transport/tuic/congestion"
  18. "github.com/sagernet/sing/common"
  19. "github.com/sagernet/sing/common/auth"
  20. "github.com/sagernet/sing/common/baderror"
  21. E "github.com/sagernet/sing/common/exceptions"
  22. "github.com/sagernet/sing/common/logger"
  23. M "github.com/sagernet/sing/common/metadata"
  24. N "github.com/sagernet/sing/common/network"
  25. )
  26. type ServerOptions struct {
  27. Context context.Context
  28. Logger logger.Logger
  29. SendBPS uint64
  30. ReceiveBPS uint64
  31. IgnoreClientBandwidth bool
  32. SalamanderPassword string
  33. TLSConfig tls.ServerConfig
  34. Users []User
  35. UDPDisabled bool
  36. Handler ServerHandler
  37. MasqueradeHandler http.Handler
  38. }
  39. type User struct {
  40. Name string
  41. Password string
  42. }
  43. type ServerHandler interface {
  44. N.TCPConnectionHandler
  45. N.UDPConnectionHandler
  46. }
  47. type Server struct {
  48. ctx context.Context
  49. logger logger.Logger
  50. sendBPS uint64
  51. receiveBPS uint64
  52. ignoreClientBandwidth bool
  53. salamanderPassword string
  54. tlsConfig tls.ServerConfig
  55. quicConfig *quic.Config
  56. userMap map[string]User
  57. udpDisabled bool
  58. handler ServerHandler
  59. masqueradeHandler http.Handler
  60. quicListener io.Closer
  61. }
  62. func NewServer(options ServerOptions) (*Server, error) {
  63. quicConfig := &quic.Config{
  64. DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
  65. EnableDatagrams: !options.UDPDisabled,
  66. MaxIncomingStreams: 1 << 60,
  67. InitialStreamReceiveWindow: defaultStreamReceiveWindow,
  68. MaxStreamReceiveWindow: defaultStreamReceiveWindow,
  69. InitialConnectionReceiveWindow: defaultConnReceiveWindow,
  70. MaxConnectionReceiveWindow: defaultConnReceiveWindow,
  71. MaxIdleTimeout: defaultMaxIdleTimeout,
  72. KeepAlivePeriod: defaultKeepAlivePeriod,
  73. }
  74. if len(options.Users) == 0 {
  75. return nil, E.New("missing users")
  76. }
  77. userMap := make(map[string]User)
  78. for _, user := range options.Users {
  79. userMap[user.Password] = user
  80. }
  81. if options.MasqueradeHandler == nil {
  82. options.MasqueradeHandler = http.NotFoundHandler()
  83. }
  84. return &Server{
  85. ctx: options.Context,
  86. logger: options.Logger,
  87. sendBPS: options.SendBPS,
  88. receiveBPS: options.ReceiveBPS,
  89. ignoreClientBandwidth: options.IgnoreClientBandwidth,
  90. salamanderPassword: options.SalamanderPassword,
  91. tlsConfig: options.TLSConfig,
  92. quicConfig: quicConfig,
  93. userMap: userMap,
  94. udpDisabled: options.UDPDisabled,
  95. handler: options.Handler,
  96. masqueradeHandler: options.MasqueradeHandler,
  97. }, nil
  98. }
  99. func (s *Server) Start(conn net.PacketConn) error {
  100. if s.salamanderPassword != "" {
  101. conn = NewSalamanderConn(conn, []byte(s.salamanderPassword))
  102. }
  103. err := qtls.ConfigureHTTP3(s.tlsConfig)
  104. if err != nil {
  105. return err
  106. }
  107. listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig)
  108. if err != nil {
  109. return err
  110. }
  111. s.quicListener = listener
  112. go s.loopConnections(listener)
  113. return nil
  114. }
  115. func (s *Server) Close() error {
  116. return common.Close(
  117. s.quicListener,
  118. )
  119. }
  120. func (s *Server) loopConnections(listener qtls.QUICListener) {
  121. for {
  122. connection, err := listener.Accept(s.ctx)
  123. if err != nil {
  124. if strings.Contains(err.Error(), "server closed") {
  125. s.logger.Debug(E.Cause(err, "listener closed"))
  126. } else {
  127. s.logger.Error(E.Cause(err, "listener closed"))
  128. }
  129. return
  130. }
  131. go s.handleConnection(connection)
  132. }
  133. }
  134. func (s *Server) handleConnection(connection quic.Connection) {
  135. session := &serverSession{
  136. Server: s,
  137. ctx: s.ctx,
  138. quicConn: connection,
  139. source: M.SocksaddrFromNet(connection.RemoteAddr()),
  140. connDone: make(chan struct{}),
  141. udpConnMap: make(map[uint32]*udpPacketConn),
  142. }
  143. httpServer := http3.Server{
  144. Handler: session,
  145. StreamHijacker: session.handleStream0,
  146. }
  147. _ = httpServer.ServeQUICConn(connection)
  148. _ = connection.CloseWithError(0, "")
  149. }
  150. type serverSession struct {
  151. *Server
  152. ctx context.Context
  153. quicConn quic.Connection
  154. source M.Socksaddr
  155. connAccess sync.Mutex
  156. connDone chan struct{}
  157. connErr error
  158. authenticated bool
  159. authUser *User
  160. udpAccess sync.RWMutex
  161. udpConnMap map[uint32]*udpPacketConn
  162. }
  163. func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  164. if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath {
  165. if s.authenticated {
  166. protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
  167. UDPEnabled: !s.udpDisabled,
  168. Rx: s.receiveBPS,
  169. RxAuto: s.ignoreClientBandwidth,
  170. })
  171. w.WriteHeader(protocol.StatusAuthOK)
  172. return
  173. }
  174. request := protocol.AuthRequestFromHeader(r.Header)
  175. user, loaded := s.userMap[request.Auth]
  176. if !loaded {
  177. s.masqueradeHandler.ServeHTTP(w, r)
  178. return
  179. }
  180. s.authUser = &user
  181. s.authenticated = true
  182. if !s.ignoreClientBandwidth && request.Rx > 0 {
  183. var sendBps uint64
  184. if s.sendBPS > 0 && s.sendBPS < request.Rx {
  185. sendBps = s.sendBPS
  186. } else {
  187. sendBps = request.Rx
  188. }
  189. s.quicConn.SetCongestionControl(congestion.NewBrutalSender(sendBps))
  190. } else {
  191. s.quicConn.SetCongestionControl(tuicCongestion.NewBBRSender(
  192. tuicCongestion.DefaultClock{},
  193. tuicCongestion.GetInitialPacketSize(s.quicConn.RemoteAddr()),
  194. tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
  195. tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
  196. ))
  197. }
  198. protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
  199. UDPEnabled: !s.udpDisabled,
  200. Rx: s.receiveBPS,
  201. RxAuto: s.ignoreClientBandwidth,
  202. })
  203. w.WriteHeader(protocol.StatusAuthOK)
  204. if s.ctx.Done() != nil {
  205. go func() {
  206. select {
  207. case <-s.ctx.Done():
  208. s.closeWithError(s.ctx.Err())
  209. case <-s.connDone:
  210. }
  211. }()
  212. }
  213. if !s.udpDisabled {
  214. go s.loopMessages()
  215. }
  216. } else {
  217. s.masqueradeHandler.ServeHTTP(w, r)
  218. }
  219. }
  220. func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) {
  221. if !s.authenticated || err != nil {
  222. return false, nil
  223. }
  224. if frameType != protocol.FrameTypeTCPRequest {
  225. return false, nil
  226. }
  227. go func() {
  228. hErr := s.handleStream(stream)
  229. stream.CancelRead(0)
  230. stream.Close()
  231. if hErr != nil {
  232. stream.CancelRead(0)
  233. stream.Close()
  234. s.logger.Error(E.Cause(hErr, "handle stream request"))
  235. }
  236. }()
  237. return true, nil
  238. }
  239. func (s *serverSession) handleStream(stream quic.Stream) error {
  240. destinationString, err := protocol.ReadTCPRequest(stream)
  241. if err != nil {
  242. return E.New("read TCP request")
  243. }
  244. ctx := s.ctx
  245. if s.authUser.Name != "" {
  246. ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
  247. }
  248. _ = s.handler.NewConnection(ctx, &serverConn{Stream: stream}, M.Metadata{
  249. Source: s.source,
  250. Destination: M.ParseSocksaddr(destinationString),
  251. })
  252. return nil
  253. }
  254. func (s *serverSession) closeWithError(err error) {
  255. s.connAccess.Lock()
  256. defer s.connAccess.Unlock()
  257. select {
  258. case <-s.connDone:
  259. return
  260. default:
  261. s.connErr = err
  262. close(s.connDone)
  263. }
  264. if E.IsClosedOrCanceled(err) {
  265. s.logger.Debug(E.Cause(err, "connection failed"))
  266. } else {
  267. s.logger.Error(E.Cause(err, "connection failed"))
  268. }
  269. _ = s.quicConn.CloseWithError(0, "")
  270. }
  271. type serverConn struct {
  272. quic.Stream
  273. responseWritten bool
  274. }
  275. func (c *serverConn) HandshakeFailure(err error) error {
  276. if c.responseWritten {
  277. return os.ErrClosed
  278. }
  279. c.responseWritten = true
  280. buffer := protocol.WriteTCPResponse(false, err.Error(), nil)
  281. defer buffer.Release()
  282. return common.Error(c.Stream.Write(buffer.Bytes()))
  283. }
  284. func (c *serverConn) HandshakeSuccess() error {
  285. if c.responseWritten {
  286. return nil
  287. }
  288. c.responseWritten = true
  289. buffer := protocol.WriteTCPResponse(true, "", nil)
  290. defer buffer.Release()
  291. return common.Error(c.Stream.Write(buffer.Bytes()))
  292. }
  293. func (c *serverConn) Read(p []byte) (n int, err error) {
  294. n, err = c.Stream.Read(p)
  295. return n, baderror.WrapQUIC(err)
  296. }
  297. func (c *serverConn) Write(p []byte) (n int, err error) {
  298. if !c.responseWritten {
  299. c.responseWritten = true
  300. buffer := protocol.WriteTCPResponse(true, "", p)
  301. defer buffer.Release()
  302. _, err = c.Stream.Write(buffer.Bytes())
  303. if err != nil {
  304. return 0, baderror.WrapQUIC(err)
  305. }
  306. return len(p), nil
  307. }
  308. n, err = c.Stream.Write(p)
  309. return n, baderror.WrapQUIC(err)
  310. }
  311. func (c *serverConn) LocalAddr() net.Addr {
  312. return M.Socksaddr{}
  313. }
  314. func (c *serverConn) RemoteAddr() net.Addr {
  315. return M.Socksaddr{}
  316. }
  317. func (c *serverConn) Close() error {
  318. c.Stream.CancelRead(0)
  319. return c.Stream.Close()
  320. }