123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- //go:build with_quic
- package inbound
- import (
- "context"
- "sync"
- "github.com/sagernet/quic-go"
- "github.com/sagernet/quic-go/congestion"
- "github.com/sagernet/sing-box/adapter"
- "github.com/sagernet/sing-box/common/tls"
- C "github.com/sagernet/sing-box/constant"
- "github.com/sagernet/sing-box/log"
- "github.com/sagernet/sing-box/option"
- "github.com/sagernet/sing-box/transport/hysteria"
- "github.com/sagernet/sing/common"
- "github.com/sagernet/sing/common/auth"
- E "github.com/sagernet/sing/common/exceptions"
- F "github.com/sagernet/sing/common/format"
- M "github.com/sagernet/sing/common/metadata"
- N "github.com/sagernet/sing/common/network"
- "golang.org/x/exp/slices"
- )
- var _ adapter.Inbound = (*Hysteria)(nil)
- type Hysteria struct {
- myInboundAdapter
- quicConfig *quic.Config
- tlsConfig tls.ServerConfig
- authKey []string
- authUser []string
- xplusKey []byte
- sendBPS uint64
- recvBPS uint64
- listener quic.Listener
- udpAccess sync.RWMutex
- udpSessionId uint32
- udpSessions map[uint32]chan *hysteria.UDPMessage
- udpDefragger hysteria.Defragger
- }
- func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) {
- options.UDPFragmentDefault = true
- quicConfig := &quic.Config{
- InitialStreamReceiveWindow: options.ReceiveWindowConn,
- MaxStreamReceiveWindow: options.ReceiveWindowConn,
- InitialConnectionReceiveWindow: options.ReceiveWindowClient,
- MaxConnectionReceiveWindow: options.ReceiveWindowClient,
- MaxIncomingStreams: int64(options.MaxConnClient),
- KeepAlivePeriod: hysteria.KeepAlivePeriod,
- DisablePathMTUDiscovery: options.DisableMTUDiscovery || !(C.IsLinux || C.IsWindows),
- EnableDatagrams: true,
- }
- if options.ReceiveWindowConn == 0 {
- quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
- quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
- }
- if options.ReceiveWindowClient == 0 {
- quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
- quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
- }
- if quicConfig.MaxIncomingStreams == 0 {
- quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams
- }
- authKey := common.Map(options.Users, func(it option.HysteriaUser) string {
- if len(it.Auth) > 0 {
- return string(it.Auth)
- } else {
- return it.AuthString
- }
- })
- authUser := common.Map(options.Users, func(it option.HysteriaUser) string {
- return it.Name
- })
- var xplus []byte
- if options.Obfs != "" {
- xplus = []byte(options.Obfs)
- }
- var up, down uint64
- if len(options.Up) > 0 {
- up = hysteria.StringToBps(options.Up)
- if up == 0 {
- return nil, E.New("invalid up speed format: ", options.Up)
- }
- } else {
- up = uint64(options.UpMbps) * hysteria.MbpsToBps
- }
- if len(options.Down) > 0 {
- down = hysteria.StringToBps(options.Down)
- if down == 0 {
- return nil, E.New("invalid down speed format: ", options.Down)
- }
- } else {
- down = uint64(options.DownMbps) * hysteria.MbpsToBps
- }
- if up < hysteria.MinSpeedBPS {
- return nil, E.New("invalid up speed")
- }
- if down < hysteria.MinSpeedBPS {
- return nil, E.New("invalid down speed")
- }
- inbound := &Hysteria{
- myInboundAdapter: myInboundAdapter{
- protocol: C.TypeHysteria,
- network: []string{N.NetworkUDP},
- ctx: ctx,
- router: router,
- logger: logger,
- tag: tag,
- listenOptions: options.ListenOptions,
- },
- quicConfig: quicConfig,
- authKey: authKey,
- authUser: authUser,
- xplusKey: xplus,
- sendBPS: up,
- recvBPS: down,
- udpSessions: make(map[uint32]chan *hysteria.UDPMessage),
- }
- if options.TLS == nil || !options.TLS.Enabled {
- return nil, C.ErrTLSRequired
- }
- if len(options.TLS.ALPN) == 0 {
- options.TLS.ALPN = []string{hysteria.DefaultALPN}
- }
- tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS))
- if err != nil {
- return nil, err
- }
- inbound.tlsConfig = tlsConfig
- return inbound, nil
- }
- func (h *Hysteria) Start() error {
- packetConn, err := h.myInboundAdapter.ListenUDP()
- if err != nil {
- return err
- }
- if len(h.xplusKey) > 0 {
- packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
- packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
- }
- err = h.tlsConfig.Start()
- if err != nil {
- return err
- }
- rawConfig, err := h.tlsConfig.Config()
- if err != nil {
- return err
- }
- listener, err := quic.Listen(packetConn, rawConfig, h.quicConfig)
- if err != nil {
- return err
- }
- h.listener = listener
- h.logger.Info("udp server started at ", listener.Addr())
- go h.acceptLoop()
- return nil
- }
- func (h *Hysteria) acceptLoop() {
- for {
- ctx := log.ContextWithNewID(h.ctx)
- conn, err := h.listener.Accept(ctx)
- if err != nil {
- return
- }
- go func() {
- hErr := h.accept(ctx, conn)
- if hErr != nil {
- conn.CloseWithError(0, "")
- NewError(h.logger, ctx, E.Cause(hErr, "process connection from ", conn.RemoteAddr()))
- }
- }()
- }
- }
- func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error {
- controlStream, err := conn.AcceptStream(ctx)
- if err != nil {
- return err
- }
- clientHello, err := hysteria.ReadClientHello(controlStream)
- if err != nil {
- return err
- }
- if len(h.authKey) > 0 {
- userIndex := slices.Index(h.authKey, string(clientHello.Auth))
- if userIndex == -1 {
- err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
- Message: "wrong password",
- })
- return E.Errors(E.New("wrong password: ", string(clientHello.Auth)), err)
- }
- user := h.authUser[userIndex]
- if user == "" {
- user = F.ToString(userIndex)
- } else {
- ctx = auth.ContextWithUser(ctx, user)
- }
- h.logger.InfoContext(ctx, "[", user, "] inbound connection from ", conn.RemoteAddr())
- } else {
- h.logger.InfoContext(ctx, "inbound connection from ", conn.RemoteAddr())
- }
- h.logger.DebugContext(ctx, "peer send speed: ", clientHello.SendBPS/1024/1024, " MBps, peer recv speed: ", clientHello.RecvBPS/1024/1024, " MBps")
- if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 {
- return E.New("invalid rate from client")
- }
- serverSendBPS, serverRecvBPS := clientHello.RecvBPS, clientHello.SendBPS
- if h.sendBPS > 0 && serverSendBPS > h.sendBPS {
- serverSendBPS = h.sendBPS
- }
- if h.recvBPS > 0 && serverRecvBPS > h.recvBPS {
- serverRecvBPS = h.recvBPS
- }
- err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
- OK: true,
- SendBPS: serverSendBPS,
- RecvBPS: serverRecvBPS,
- })
- if err != nil {
- return err
- }
- conn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverSendBPS)))
- go h.udpRecvLoop(conn)
- for {
- var stream quic.Stream
- stream, err = conn.AcceptStream(ctx)
- if err != nil {
- return err
- }
- go func() {
- hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream)
- if hErr != nil {
- stream.Close()
- NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
- }
- }()
- }
- }
- func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
- for {
- packet, err := conn.ReceiveMessage()
- if err != nil {
- return
- }
- message, err := hysteria.ParseUDPMessage(packet)
- if err != nil {
- h.logger.Error("parse udp message: ", err)
- continue
- }
- dfMsg := h.udpDefragger.Feed(message)
- if dfMsg == nil {
- continue
- }
- h.udpAccess.RLock()
- ch, ok := h.udpSessions[dfMsg.SessionID]
- if ok {
- select {
- case ch <- dfMsg:
- // OK
- default:
- // Silently drop the message when the channel is full
- }
- }
- h.udpAccess.RUnlock()
- }
- }
- func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error {
- request, err := hysteria.ReadClientRequest(stream)
- if err != nil {
- return err
- }
- var metadata adapter.InboundContext
- metadata.Inbound = h.tag
- metadata.InboundType = C.TypeHysteria
- metadata.InboundOptions = h.listenOptions.InboundOptions
- metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
- metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
- metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port).Unwrap()
- if !request.UDP {
- err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
- OK: true,
- })
- if err != nil {
- return err
- }
- h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
- return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination, false), metadata)
- } else {
- h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
- var id uint32
- h.udpAccess.Lock()
- id = h.udpSessionId
- nCh := make(chan *hysteria.UDPMessage, 1024)
- h.udpSessions[id] = nCh
- h.udpSessionId += 1
- h.udpAccess.Unlock()
- err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
- OK: true,
- UDPSessionID: id,
- })
- if err != nil {
- return err
- }
- packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(func() error {
- h.udpAccess.Lock()
- if ch, ok := h.udpSessions[id]; ok {
- close(ch)
- delete(h.udpSessions, id)
- }
- h.udpAccess.Unlock()
- return nil
- }))
- go packetConn.Hold()
- return h.router.RoutePacketConnection(ctx, packetConn, metadata)
- }
- }
- func (h *Hysteria) Close() error {
- h.udpAccess.Lock()
- for _, session := range h.udpSessions {
- close(session)
- }
- h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
- h.udpAccess.Unlock()
- return common.Close(
- &h.myInboundAdapter,
- h.listener,
- h.tlsConfig,
- )
- }
|