Browse Source

Migrate QUIC wrapper and protocol implementations to library

世界 2 years ago
parent
commit
bd7adcbb7e
43 changed files with 82 additions and 5973 deletions
  1. 0 120
      common/qtls/wrapper.go
  2. 5 5
      common/tls/ech_quic.go
  3. 1 0
      go.mod
  4. 2 0
      go.sum
  5. 2 2
      inbound/hysteria.go
  6. 32 18
      inbound/hysteria2.go
  7. 1 1
      inbound/naive_quic.go
  8. 33 19
      inbound/tuic.go
  9. 1 1
      outbound/hysteria.go
  10. 1 1
      outbound/hysteria2.go
  11. 1 1
      outbound/tuic.go
  12. 0 314
      transport/hysteria2/client.go
  13. 0 47
      transport/hysteria2/client_paclet.go
  14. 0 151
      transport/hysteria2/congestion/brutal.go
  15. 0 86
      transport/hysteria2/congestion/pacer.go
  16. 0 68
      transport/hysteria2/internal/protocol/http.go
  17. 0 31
      transport/hysteria2/internal/protocol/padding.go
  18. 0 266
      transport/hysteria2/internal/protocol/proxy.go
  19. 0 450
      transport/hysteria2/packet.go
  20. 0 106
      transport/hysteria2/salamander.go
  21. 0 344
      transport/hysteria2/server.go
  22. 0 55
      transport/hysteria2/server_packet.go
  23. 0 10
      transport/tuic/address.go
  24. 0 307
      transport/tuic/client.go
  25. 0 112
      transport/tuic/client_packet.go
  26. 0 46
      transport/tuic/congestion.go
  27. 0 3
      transport/tuic/congestion/README.md
  28. 0 25
      transport/tuic/congestion/bandwidth.go
  29. 0 374
      transport/tuic/congestion/bandwidth_sampler.go
  30. 0 1000
      transport/tuic/congestion/bbr_sender.go
  31. 0 20
      transport/tuic/congestion/clock.go
  32. 0 213
      transport/tuic/congestion/cubic.go
  33. 0 318
      transport/tuic/congestion/cubic_sender.go
  34. 0 112
      transport/tuic/congestion/hybrid_slow_start.go
  35. 0 72
      transport/tuic/congestion/minmax.go
  36. 0 81
      transport/tuic/congestion/pacer.go
  37. 0 132
      transport/tuic/congestion/windowed_filter.go
  38. 0 532
      transport/tuic/packet.go
  39. 0 15
      transport/tuic/protocol.go
  40. 0 437
      transport/tuic/server.go
  41. 0 75
      transport/tuic/server_packet.go
  42. 1 1
      transport/v2rayquic/client.go
  43. 2 2
      transport/v2rayquic/server.go

+ 0 - 120
common/qtls/wrapper.go

@@ -1,120 +0,0 @@
-package qtls
-
-import (
-	"context"
-	"crypto/tls"
-	"net"
-	"net/http"
-
-	"github.com/sagernet/quic-go"
-	"github.com/sagernet/quic-go/http3"
-	M "github.com/sagernet/sing/common/metadata"
-	aTLS "github.com/sagernet/sing/common/tls"
-)
-
-type QUICConfig interface {
-	Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.Connection, error)
-	DialEarly(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.EarlyConnection, error)
-	CreateTransport(conn net.PacketConn, quicConnPtr *quic.EarlyConnection, serverAddr M.Socksaddr, quicConfig *quic.Config, enableDatagrams bool) http.RoundTripper
-}
-
-type QUICServerConfig interface {
-	Listen(conn net.PacketConn, config *quic.Config) (QUICListener, error)
-	ListenEarly(conn net.PacketConn, config *quic.Config) (QUICEarlyListener, error)
-	ConfigureHTTP3()
-}
-
-type QUICListener interface {
-	Accept(ctx context.Context) (quic.Connection, error)
-	Close() error
-	Addr() net.Addr
-}
-
-type QUICEarlyListener interface {
-	Accept(ctx context.Context) (quic.EarlyConnection, error)
-	Close() error
-	Addr() net.Addr
-}
-
-func Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config aTLS.Config, quicConfig *quic.Config) (quic.Connection, error) {
-	if quicTLSConfig, isQUICConfig := config.(QUICConfig); isQUICConfig {
-		return quicTLSConfig.Dial(ctx, conn, addr, quicConfig)
-	}
-	tlsConfig, err := config.Config()
-	if err != nil {
-		return nil, err
-	}
-	return quic.Dial(ctx, conn, addr, tlsConfig, quicConfig)
-}
-
-func DialEarly(ctx context.Context, conn net.PacketConn, addr net.Addr, config aTLS.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
-	if quicTLSConfig, isQUICConfig := config.(QUICConfig); isQUICConfig {
-		return quicTLSConfig.DialEarly(ctx, conn, addr, quicConfig)
-	}
-	tlsConfig, err := config.Config()
-	if err != nil {
-		return nil, err
-	}
-	return quic.DialEarly(ctx, conn, addr, tlsConfig, quicConfig)
-}
-
-func CreateTransport(conn net.PacketConn, quicConnPtr *quic.EarlyConnection, serverAddr M.Socksaddr, config aTLS.Config, quicConfig *quic.Config, enableDatagrams bool) (http.RoundTripper, error) {
-	if quicTLSConfig, isQUICConfig := config.(QUICConfig); isQUICConfig {
-		return quicTLSConfig.CreateTransport(conn, quicConnPtr, serverAddr, quicConfig, enableDatagrams), nil
-	}
-	tlsConfig, err := config.Config()
-	if err != nil {
-		return nil, err
-	}
-	return &http3.RoundTripper{
-		TLSClientConfig: tlsConfig,
-		QuicConfig:      quicConfig,
-		EnableDatagrams: enableDatagrams,
-		Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
-			quicConn, err := quic.DialEarly(ctx, conn, serverAddr.UDPAddr(), tlsCfg, cfg)
-			if err != nil {
-				return nil, err
-			}
-			*quicConnPtr = quicConn
-			return quicConn, nil
-		},
-	}, nil
-}
-
-func Listen(conn net.PacketConn, config aTLS.ServerConfig, quicConfig *quic.Config) (QUICListener, error) {
-	if quicTLSConfig, isQUICConfig := config.(QUICServerConfig); isQUICConfig {
-		return quicTLSConfig.Listen(conn, quicConfig)
-	}
-	tlsConfig, err := config.Config()
-	if err != nil {
-		return nil, err
-	}
-	return quic.Listen(conn, tlsConfig, quicConfig)
-}
-
-func ListenEarly(conn net.PacketConn, config aTLS.ServerConfig, quicConfig *quic.Config) (QUICEarlyListener, error) {
-	if quicTLSConfig, isQUICConfig := config.(QUICServerConfig); isQUICConfig {
-		return quicTLSConfig.ListenEarly(conn, quicConfig)
-	}
-	tlsConfig, err := config.Config()
-	if err != nil {
-		return nil, err
-	}
-	return quic.ListenEarly(conn, tlsConfig, quicConfig)
-}
-
-func ConfigureHTTP3(config aTLS.ServerConfig) error {
-	if len(config.NextProtos()) == 0 {
-		config.SetNextProtos([]string{http3.NextProtoH3})
-	}
-	if quicTLSConfig, isQUICConfig := config.(QUICServerConfig); isQUICConfig {
-		quicTLSConfig.ConfigureHTTP3()
-		return nil
-	}
-	tlsConfig, err := config.Config()
-	if err != nil {
-		return err
-	}
-	http3.ConfigureTLSConfig(tlsConfig)
-	return nil
-}

+ 5 - 5
common/tls/ech_quic.go

@@ -10,13 +10,13 @@ import (
 	"github.com/sagernet/cloudflare-tls"
 	"github.com/sagernet/quic-go/ech"
 	"github.com/sagernet/quic-go/http3_ech"
-	"github.com/sagernet/sing-box/common/qtls"
+	"github.com/sagernet/sing-quic"
 	M "github.com/sagernet/sing/common/metadata"
 )
 
 var (
-	_ qtls.QUICConfig       = (*echClientConfig)(nil)
-	_ qtls.QUICServerConfig = (*echServerConfig)(nil)
+	_ qtls.Config       = (*echClientConfig)(nil)
+	_ qtls.ServerConfig = (*echServerConfig)(nil)
 )
 
 func (c *echClientConfig) Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.Connection, error) {
@@ -43,11 +43,11 @@ func (c *echClientConfig) CreateTransport(conn net.PacketConn, quicConnPtr *quic
 	}
 }
 
-func (c *echServerConfig) Listen(conn net.PacketConn, config *quic.Config) (qtls.QUICListener, error) {
+func (c *echServerConfig) Listen(conn net.PacketConn, config *quic.Config) (qtls.Listener, error) {
 	return quic.Listen(conn, c.config, config)
 }
 
-func (c *echServerConfig) ListenEarly(conn net.PacketConn, config *quic.Config) (qtls.QUICEarlyListener, error) {
+func (c *echServerConfig) ListenEarly(conn net.PacketConn, config *quic.Config) (qtls.EarlyListener, error) {
 	return quic.ListenEarly(conn, c.config, config)
 }
 

+ 1 - 0
go.mod

@@ -29,6 +29,7 @@ require (
 	github.com/sagernet/sing v0.2.10-0.20230912050851-1453c7c8c20d
 	github.com/sagernet/sing-dns v0.1.9-0.20230911082806-425022bdc92b
 	github.com/sagernet/sing-mux v0.1.3-0.20230908032617-759a1886a400
+	github.com/sagernet/sing-quic v0.0.0-20230915093242-b55f3531e703
 	github.com/sagernet/sing-shadowsocks v0.2.5-0.20230907005610-126234728ca0
 	github.com/sagernet/sing-shadowsocks2 v0.1.4-0.20230907005906-5d2917b29248
 	github.com/sagernet/sing-shadowtls v0.1.4

+ 2 - 0
go.sum

@@ -118,6 +118,8 @@ github.com/sagernet/sing-dns v0.1.9-0.20230911082806-425022bdc92b h1:m/UWg2voyb9
 github.com/sagernet/sing-dns v0.1.9-0.20230911082806-425022bdc92b/go.mod h1:Kg98PBJEg/08jsNFtmZWmPomhskn9Ausn50ecNm4M+8=
 github.com/sagernet/sing-mux v0.1.3-0.20230908032617-759a1886a400 h1:LtpYd5c5AJtUSxjyH4KjUS8HT+2XgilyozjbCq/x3EM=
 github.com/sagernet/sing-mux v0.1.3-0.20230908032617-759a1886a400/go.mod h1:TKxqIvfQQgd36jp2tzsPavGjYTVZilV+atip1cssjIY=
+github.com/sagernet/sing-quic v0.0.0-20230915093242-b55f3531e703 h1:BbJZ5RkY3jQk5P9G5Ra0VhmDNKdT0aIP1FszEDyQL+o=
+github.com/sagernet/sing-quic v0.0.0-20230915093242-b55f3531e703/go.mod h1:Mh5Senu4XDuX+RxSPQEoUB0j6kVmGais2h62Cnfj6Xk=
 github.com/sagernet/sing-shadowsocks v0.2.5-0.20230907005610-126234728ca0 h1:9wHYWxH+fcs01PM2+DylA8LNNY3ElnZykQo9rysng8U=
 github.com/sagernet/sing-shadowsocks v0.2.5-0.20230907005610-126234728ca0/go.mod h1:80fNKP0wnqlu85GZXV1H1vDPC/2t+dQbFggOw4XuFUM=
 github.com/sagernet/sing-shadowsocks2 v0.1.4-0.20230907005906-5d2917b29248 h1:JTFfy/LDmVFEK4KZJEujmC1iO8+aoF4unYhhZZRzRq4=

+ 2 - 2
inbound/hysteria.go

@@ -9,12 +9,12 @@ import (
 	"github.com/sagernet/quic-go"
 	"github.com/sagernet/quic-go/congestion"
 	"github.com/sagernet/sing-box/adapter"
-	"github.com/sagernet/sing-box/common/qtls"
 	"github.com/sagernet/sing-box/common/tls"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-box/transport/hysteria"
+	"github.com/sagernet/sing-quic"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/auth"
 	E "github.com/sagernet/sing/common/exceptions"
@@ -36,7 +36,7 @@ type Hysteria struct {
 	xplusKey     []byte
 	sendBPS      uint64
 	recvBPS      uint64
-	listener     qtls.QUICListener
+	listener     qtls.Listener
 	udpAccess    sync.RWMutex
 	udpSessionId uint32
 	udpSessions  map[uint32]chan *hysteria.UDPMessage

+ 32 - 18
inbound/hysteria2.go

@@ -14,7 +14,7 @@ import (
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
-	"github.com/sagernet/sing-box/transport/hysteria2"
+	"github.com/sagernet/sing-quic/hysteria2"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/auth"
 	E "github.com/sagernet/sing/common/exceptions"
@@ -25,8 +25,9 @@ var _ adapter.Inbound = (*Hysteria2)(nil)
 
 type Hysteria2 struct {
 	myInboundAdapter
-	tlsConfig tls.ServerConfig
-	server    *hysteria2.Server
+	tlsConfig    tls.ServerConfig
+	service      *hysteria2.Service[int]
+	userNameList []string
 }
 
 func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2InboundOptions) (*Hysteria2, error) {
@@ -84,16 +85,13 @@ func NewHysteria2(ctx context.Context, router adapter.Router, logger log.Context
 		},
 		tlsConfig: tlsConfig,
 	}
-	server, err := hysteria2.NewServer(hysteria2.ServerOptions{
-		Context:            ctx,
-		Logger:             logger,
-		SendBPS:            uint64(options.UpMbps * 1024 * 1024),
-		ReceiveBPS:         uint64(options.DownMbps * 1024 * 1024),
-		SalamanderPassword: salamanderPassword,
-		TLSConfig:          tlsConfig,
-		Users: common.Map(options.Users, func(it option.Hysteria2User) hysteria2.User {
-			return hysteria2.User(it)
-		}),
+	service, err := hysteria2.NewService[int](hysteria2.ServiceOptions{
+		Context:               ctx,
+		Logger:                logger,
+		SendBPS:               uint64(options.UpMbps * 1024 * 1024),
+		ReceiveBPS:            uint64(options.DownMbps * 1024 * 1024),
+		SalamanderPassword:    salamanderPassword,
+		TLSConfig:             tlsConfig,
 		IgnoreClientBandwidth: options.IgnoreClientBandwidth,
 		Handler:               adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil),
 		MasqueradeHandler:     masqueradeHandler,
@@ -101,7 +99,17 @@ func NewHysteria2(ctx context.Context, router adapter.Router, logger log.Context
 	if err != nil {
 		return nil, err
 	}
-	inbound.server = server
+	userList := make([]int, 0, len(options.Users))
+	userNameList := make([]string, 0, len(options.Users))
+	userPasswordList := make([]string, 0, len(options.Users))
+	for index, user := range options.Users {
+		userList = append(userList, index)
+		userNameList = append(userNameList, user.Name)
+		userPasswordList = append(userPasswordList, user.Password)
+	}
+	service.UpdateUsers(userList, userPasswordList)
+	inbound.service = service
+	inbound.userNameList = userNameList
 	return inbound, nil
 }
 
@@ -109,14 +117,20 @@ func (h *Hysteria2) newConnection(ctx context.Context, conn net.Conn, metadata a
 	ctx = log.ContextWithNewID(ctx)
 	h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
 	metadata = h.createMetadata(conn, metadata)
-	metadata.User, _ = auth.UserFromContext[string](ctx)
+	userID, _ := auth.UserFromContext[int](ctx)
+	if userName := h.userNameList[userID]; userName != "" {
+		metadata.User = userName
+	}
 	return h.router.RouteConnection(ctx, conn, metadata)
 }
 
 func (h *Hysteria2) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
 	ctx = log.ContextWithNewID(ctx)
 	metadata = h.createPacketMetadata(conn, metadata)
-	metadata.User, _ = auth.UserFromContext[string](ctx)
+	userID, _ := auth.UserFromContext[int](ctx)
+	if userName := h.userNameList[userID]; userName != "" {
+		metadata.User = userName
+	}
 	h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
 	return h.router.RoutePacketConnection(ctx, conn, metadata)
 }
@@ -132,13 +146,13 @@ func (h *Hysteria2) Start() error {
 	if err != nil {
 		return err
 	}
-	return h.server.Start(packetConn)
+	return h.service.Start(packetConn)
 }
 
 func (h *Hysteria2) Close() error {
 	return common.Close(
 		&h.myInboundAdapter,
 		h.tlsConfig,
-		common.PtrOrNil(h.server),
+		common.PtrOrNil(h.service),
 	)
 }

+ 1 - 1
inbound/naive_quic.go

@@ -5,7 +5,7 @@ package inbound
 import (
 	"github.com/sagernet/quic-go"
 	"github.com/sagernet/quic-go/http3"
-	"github.com/sagernet/sing-box/common/qtls"
+	"github.com/sagernet/sing-quic"
 	E "github.com/sagernet/sing/common/exceptions"
 )
 

+ 33 - 19
inbound/tuic.go

@@ -12,7 +12,7 @@ import (
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
-	"github.com/sagernet/sing-box/transport/tuic"
+	"github.com/sagernet/sing-quic/tuic"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/auth"
 	E "github.com/sagernet/sing/common/exceptions"
@@ -25,8 +25,9 @@ var _ adapter.Inbound = (*TUIC)(nil)
 
 type TUIC struct {
 	myInboundAdapter
-	server    *tuic.Server
-	tlsConfig tls.ServerConfig
+	tlsConfig    tls.ServerConfig
+	server       *tuic.Service[int]
+	userNameList []string
 }
 
 func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (*TUIC, error) {
@@ -38,17 +39,6 @@ func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogge
 	if err != nil {
 		return nil, err
 	}
-	var users []tuic.User
-	for index, user := range options.Users {
-		if user.UUID == "" {
-			return nil, E.New("missing uuid for user ", index)
-		}
-		userUUID, err := uuid.FromString(user.UUID)
-		if err != nil {
-			return nil, E.Cause(err, "invalid uuid for user ", index)
-		}
-		users = append(users, tuic.User{Name: user.Name, UUID: userUUID, Password: user.Password})
-	}
 	inbound := &TUIC{
 		myInboundAdapter: myInboundAdapter{
 			protocol:      C.TypeTUIC,
@@ -60,11 +50,10 @@ func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogge
 			listenOptions: options.ListenOptions,
 		},
 	}
-	server, err := tuic.NewServer(tuic.ServerOptions{
+	service, err := tuic.NewService[int](tuic.ServiceOptions{
 		Context:           ctx,
 		Logger:            logger,
 		TLSConfig:         tlsConfig,
-		Users:             users,
 		CongestionControl: options.CongestionControl,
 		AuthTimeout:       time.Duration(options.AuthTimeout),
 		ZeroRTTHandshake:  options.ZeroRTTHandshake,
@@ -74,7 +63,26 @@ func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogge
 	if err != nil {
 		return nil, err
 	}
-	inbound.server = server
+	var userList []int
+	var userNameList []string
+	var userUUIDList [][16]byte
+	var userPasswordList []string
+	for index, user := range options.Users {
+		if user.UUID == "" {
+			return nil, E.New("missing uuid for user ", index)
+		}
+		userUUID, err := uuid.FromString(user.UUID)
+		if err != nil {
+			return nil, E.Cause(err, "invalid uuid for user ", index)
+		}
+		userList = append(userList, index)
+		userNameList = append(userNameList, user.Name)
+		userUUIDList = append(userUUIDList, userUUID)
+		userPasswordList = append(userPasswordList, user.Password)
+	}
+	service.UpdateUsers(userList, userUUIDList, userPasswordList)
+	inbound.server = service
+	inbound.userNameList = userNameList
 	return inbound, nil
 }
 
@@ -82,14 +90,20 @@ func (h *TUIC) newConnection(ctx context.Context, conn net.Conn, metadata adapte
 	ctx = log.ContextWithNewID(ctx)
 	h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
 	metadata = h.createMetadata(conn, metadata)
-	metadata.User, _ = auth.UserFromContext[string](ctx)
+	userID, _ := auth.UserFromContext[int](ctx)
+	if userName := h.userNameList[userID]; userName != "" {
+		metadata.User = userName
+	}
 	return h.router.RouteConnection(ctx, conn, metadata)
 }
 
 func (h *TUIC) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
 	ctx = log.ContextWithNewID(ctx)
 	metadata = h.createPacketMetadata(conn, metadata)
-	metadata.User, _ = auth.UserFromContext[string](ctx)
+	userID, _ := auth.UserFromContext[int](ctx)
+	if userName := h.userNameList[userID]; userName != "" {
+		metadata.User = userName
+	}
 	h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
 	return h.router.RoutePacketConnection(ctx, conn, metadata)
 }

+ 1 - 1
outbound/hysteria.go

@@ -11,12 +11,12 @@ import (
 	"github.com/sagernet/quic-go/congestion"
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/dialer"
-	"github.com/sagernet/sing-box/common/qtls"
 	"github.com/sagernet/sing-box/common/tls"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-box/transport/hysteria"
+	"github.com/sagernet/sing-quic"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"

+ 1 - 1
outbound/hysteria2.go

@@ -13,7 +13,7 @@ import (
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
-	"github.com/sagernet/sing-box/transport/hysteria2"
+	"github.com/sagernet/sing-quic/hysteria2"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"

+ 1 - 1
outbound/tuic.go

@@ -14,7 +14,7 @@ import (
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
-	"github.com/sagernet/sing-box/transport/tuic"
+	"github.com/sagernet/sing-quic/tuic"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"

+ 0 - 314
transport/hysteria2/client.go

@@ -1,314 +0,0 @@
-package hysteria2
-
-import (
-	"context"
-	"io"
-	"net"
-	"net/http"
-	"net/url"
-	"os"
-	"runtime"
-	"sync"
-	"time"
-
-	"github.com/sagernet/quic-go"
-	"github.com/sagernet/sing-box/common/qtls"
-	"github.com/sagernet/sing-box/common/tls"
-	"github.com/sagernet/sing-box/transport/hysteria2/congestion"
-	"github.com/sagernet/sing-box/transport/hysteria2/internal/protocol"
-	tuicCongestion "github.com/sagernet/sing-box/transport/tuic/congestion"
-	"github.com/sagernet/sing/common/baderror"
-	"github.com/sagernet/sing/common/bufio"
-	E "github.com/sagernet/sing/common/exceptions"
-	M "github.com/sagernet/sing/common/metadata"
-	N "github.com/sagernet/sing/common/network"
-)
-
-const (
-	defaultStreamReceiveWindow = 8388608                            // 8MB
-	defaultConnReceiveWindow   = defaultStreamReceiveWindow * 5 / 2 // 20MB
-	defaultMaxIdleTimeout      = 30 * time.Second
-	defaultKeepAlivePeriod     = 10 * time.Second
-)
-
-type ClientOptions struct {
-	Context            context.Context
-	Dialer             N.Dialer
-	ServerAddress      M.Socksaddr
-	SendBPS            uint64
-	ReceiveBPS         uint64
-	SalamanderPassword string
-	Password           string
-	TLSConfig          tls.Config
-	UDPDisabled        bool
-}
-
-type Client struct {
-	ctx                context.Context
-	dialer             N.Dialer
-	serverAddr         M.Socksaddr
-	sendBPS            uint64
-	receiveBPS         uint64
-	salamanderPassword string
-	password           string
-	tlsConfig          tls.Config
-	quicConfig         *quic.Config
-	udpDisabled        bool
-
-	connAccess sync.RWMutex
-	conn       *clientQUICConnection
-}
-
-func NewClient(options ClientOptions) (*Client, error) {
-	quicConfig := &quic.Config{
-		DisablePathMTUDiscovery:        !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
-		EnableDatagrams:                true,
-		InitialStreamReceiveWindow:     defaultStreamReceiveWindow,
-		MaxStreamReceiveWindow:         defaultStreamReceiveWindow,
-		InitialConnectionReceiveWindow: defaultConnReceiveWindow,
-		MaxConnectionReceiveWindow:     defaultConnReceiveWindow,
-		MaxIdleTimeout:                 defaultMaxIdleTimeout,
-		KeepAlivePeriod:                defaultKeepAlivePeriod,
-	}
-	return &Client{
-		ctx:                options.Context,
-		dialer:             options.Dialer,
-		serverAddr:         options.ServerAddress,
-		sendBPS:            options.SendBPS,
-		receiveBPS:         options.ReceiveBPS,
-		salamanderPassword: options.SalamanderPassword,
-		password:           options.Password,
-		tlsConfig:          options.TLSConfig,
-		quicConfig:         quicConfig,
-		udpDisabled:        options.UDPDisabled,
-	}, nil
-}
-
-func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
-	conn := c.conn
-	if conn != nil && conn.active() {
-		return conn, nil
-	}
-	c.connAccess.Lock()
-	defer c.connAccess.Unlock()
-	conn = c.conn
-	if conn != nil && conn.active() {
-		return conn, nil
-	}
-	conn, err := c.offerNew(ctx)
-	if err != nil {
-		return nil, err
-	}
-	return conn, nil
-}
-
-func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
-	udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr)
-	if err != nil {
-		return nil, err
-	}
-	var packetConn net.PacketConn
-	packetConn = bufio.NewUnbindPacketConn(udpConn)
-	if c.salamanderPassword != "" {
-		packetConn = NewSalamanderConn(packetConn, []byte(c.salamanderPassword))
-	}
-	var quicConn quic.EarlyConnection
-	http3Transport, err := qtls.CreateTransport(packetConn, &quicConn, c.serverAddr, c.tlsConfig, c.quicConfig, true)
-	if err != nil {
-		udpConn.Close()
-		return nil, err
-	}
-	request := &http.Request{
-		Method: http.MethodPost,
-		URL: &url.URL{
-			Scheme: "https",
-			Host:   protocol.URLHost,
-			Path:   protocol.URLPath,
-		},
-		Header: make(http.Header),
-	}
-	protocol.AuthRequestToHeader(request.Header, protocol.AuthRequest{Auth: c.password, Rx: c.receiveBPS})
-	response, err := http3Transport.RoundTrip(request.WithContext(ctx))
-	if err != nil {
-		if quicConn != nil {
-			quicConn.CloseWithError(0, "")
-		}
-		udpConn.Close()
-		return nil, err
-	}
-	if response.StatusCode != protocol.StatusAuthOK {
-		if quicConn != nil {
-			quicConn.CloseWithError(0, "")
-		}
-		udpConn.Close()
-		return nil, E.New("authentication failed, status code: ", response.StatusCode)
-	}
-	response.Body.Close()
-	authResponse := protocol.AuthResponseFromHeader(response.Header)
-	actualTx := authResponse.Rx
-	if actualTx == 0 || actualTx > c.sendBPS {
-		actualTx = c.sendBPS
-	}
-	if !authResponse.RxAuto && actualTx > 0 {
-		quicConn.SetCongestionControl(congestion.NewBrutalSender(actualTx))
-	} else {
-		quicConn.SetCongestionControl(tuicCongestion.NewBBRSender(
-			tuicCongestion.DefaultClock{},
-			tuicCongestion.GetInitialPacketSize(quicConn.RemoteAddr()),
-			tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
-			tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
-		))
-	}
-	conn := &clientQUICConnection{
-		quicConn:    quicConn,
-		rawConn:     udpConn,
-		connDone:    make(chan struct{}),
-		udpDisabled: c.udpDisabled || !authResponse.UDPEnabled,
-		udpConnMap:  make(map[uint32]*udpPacketConn),
-	}
-	if !c.udpDisabled {
-		go c.loopMessages(conn)
-	}
-	c.conn = conn
-	return conn, nil
-}
-
-func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
-	conn, err := c.offer(ctx)
-	if err != nil {
-		return nil, err
-	}
-	stream, err := conn.quicConn.OpenStream()
-	if err != nil {
-		return nil, err
-	}
-	return &clientConn{
-		Stream:      stream,
-		destination: destination,
-	}, nil
-}
-
-func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) {
-	if c.udpDisabled {
-		return nil, os.ErrInvalid
-	}
-	conn, err := c.offer(ctx)
-	if err != nil {
-		return nil, err
-	}
-	if conn.udpDisabled {
-		return nil, E.New("UDP disabled by server")
-	}
-	var sessionID uint32
-	clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, func() {
-		conn.udpAccess.Lock()
-		delete(conn.udpConnMap, sessionID)
-		conn.udpAccess.Unlock()
-	})
-	conn.udpAccess.Lock()
-	sessionID = conn.udpSessionID
-	conn.udpSessionID++
-	conn.udpConnMap[sessionID] = clientPacketConn
-	conn.udpAccess.Unlock()
-	clientPacketConn.sessionID = sessionID
-	return clientPacketConn, nil
-}
-
-func (c *Client) CloseWithError(err error) error {
-	conn := c.conn
-	if conn != nil {
-		conn.closeWithError(err)
-	}
-	return nil
-}
-
-type clientQUICConnection struct {
-	quicConn     quic.Connection
-	rawConn      io.Closer
-	closeOnce    sync.Once
-	connDone     chan struct{}
-	connErr      error
-	udpDisabled  bool
-	udpAccess    sync.RWMutex
-	udpConnMap   map[uint32]*udpPacketConn
-	udpSessionID uint32
-}
-
-func (c *clientQUICConnection) active() bool {
-	select {
-	case <-c.quicConn.Context().Done():
-		return false
-	default:
-	}
-	select {
-	case <-c.connDone:
-		return false
-	default:
-	}
-	return true
-}
-
-func (c *clientQUICConnection) closeWithError(err error) {
-	c.closeOnce.Do(func() {
-		c.connErr = err
-		close(c.connDone)
-		c.quicConn.CloseWithError(0, "")
-	})
-}
-
-type clientConn struct {
-	quic.Stream
-	destination    M.Socksaddr
-	requestWritten bool
-	responseRead   bool
-}
-
-func (c *clientConn) NeedHandshake() bool {
-	return !c.requestWritten
-}
-
-func (c *clientConn) Read(p []byte) (n int, err error) {
-	if c.responseRead {
-		n, err = c.Stream.Read(p)
-		return n, baderror.WrapQUIC(err)
-	}
-	status, errorMessage, err := protocol.ReadTCPResponse(c.Stream)
-	if err != nil {
-		return 0, baderror.WrapQUIC(err)
-	}
-	if !status {
-		err = E.New("remote error: ", errorMessage)
-		return
-	}
-	c.responseRead = true
-	n, err = c.Stream.Read(p)
-	return n, baderror.WrapQUIC(err)
-}
-
-func (c *clientConn) Write(p []byte) (n int, err error) {
-	if !c.requestWritten {
-		buffer := protocol.WriteTCPRequest(c.destination.String(), p)
-		defer buffer.Release()
-		_, err = c.Stream.Write(buffer.Bytes())
-		if err != nil {
-			return
-		}
-		c.requestWritten = true
-		return len(p), nil
-	}
-	n, err = c.Stream.Write(p)
-	return n, baderror.WrapQUIC(err)
-}
-
-func (c *clientConn) LocalAddr() net.Addr {
-	return M.Socksaddr{}
-}
-
-func (c *clientConn) RemoteAddr() net.Addr {
-	return M.Socksaddr{}
-}
-
-func (c *clientConn) Close() error {
-	c.Stream.CancelRead(0)
-	return c.Stream.Close()
-}

+ 0 - 47
transport/hysteria2/client_paclet.go

@@ -1,47 +0,0 @@
-package hysteria2
-
-import E "github.com/sagernet/sing/common/exceptions"
-
-func (c *Client) loopMessages(conn *clientQUICConnection) {
-	for {
-		message, err := conn.quicConn.ReceiveMessage(c.ctx)
-		if err != nil {
-			conn.closeWithError(E.Cause(err, "receive message"))
-			return
-		}
-		go func() {
-			hErr := c.handleMessage(conn, message)
-			if hErr != nil {
-				conn.closeWithError(E.Cause(hErr, "handle message"))
-			}
-		}()
-	}
-}
-
-func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
-	message := allocMessage()
-	err := decodeUDPMessage(message, data)
-	if err != nil {
-		message.release()
-		return E.Cause(err, "decode UDP message")
-	}
-	conn.handleUDPMessage(message)
-	return nil
-}
-
-func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) {
-	c.udpAccess.RLock()
-	udpConn, loaded := c.udpConnMap[message.sessionID]
-	c.udpAccess.RUnlock()
-	if !loaded {
-		message.releaseMessage()
-		return
-	}
-	select {
-	case <-udpConn.ctx.Done():
-		message.releaseMessage()
-		return
-	default:
-	}
-	udpConn.inputPacket(message)
-}

+ 0 - 151
transport/hysteria2/congestion/brutal.go

@@ -1,151 +0,0 @@
-package congestion
-
-import (
-	"time"
-
-	"github.com/sagernet/quic-go/congestion"
-)
-
-const (
-	initMaxDatagramSize = 1252
-
-	pktInfoSlotCount = 4
-	minSampleCount   = 50
-	minAckRate       = 0.8
-)
-
-var _ congestion.CongestionControl = &BrutalSender{}
-
-type BrutalSender struct {
-	rttStats        congestion.RTTStatsProvider
-	bps             congestion.ByteCount
-	maxDatagramSize congestion.ByteCount
-	pacer           *pacer
-
-	pktInfoSlots [pktInfoSlotCount]pktInfo
-	ackRate      float64
-}
-
-type pktInfo struct {
-	Timestamp int64
-	AckCount  uint64
-	LossCount uint64
-}
-
-func NewBrutalSender(bps uint64) *BrutalSender {
-	bs := &BrutalSender{
-		bps:             congestion.ByteCount(bps),
-		maxDatagramSize: initMaxDatagramSize,
-		ackRate:         1,
-	}
-	bs.pacer = newPacer(func() congestion.ByteCount {
-		return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
-	})
-	return bs
-}
-
-func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
-	b.rttStats = rttStats
-}
-
-func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
-	return b.pacer.TimeUntilSend()
-}
-
-func (b *BrutalSender) HasPacingBudget(now time.Time) bool {
-	return b.pacer.Budget(now) >= b.maxDatagramSize
-}
-
-func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
-	return bytesInFlight < b.GetCongestionWindow()
-}
-
-func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
-	rtt := b.rttStats.SmoothedRTT()
-	if rtt <= 0 {
-		return 10240
-	}
-	return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate)
-}
-
-func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
-	packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool,
-) {
-	b.pacer.SentPacket(sentTime, bytes)
-}
-
-func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
-	priorInFlight congestion.ByteCount, eventTime time.Time,
-) {
-	currentTimestamp := eventTime.Unix()
-	slot := currentTimestamp % pktInfoSlotCount
-	if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
-		b.pktInfoSlots[slot].AckCount++
-	} else {
-		// uninitialized slot or too old, reset
-		b.pktInfoSlots[slot].Timestamp = currentTimestamp
-		b.pktInfoSlots[slot].AckCount = 1
-		b.pktInfoSlots[slot].LossCount = 0
-	}
-	b.updateAckRate(currentTimestamp)
-}
-
-func (b *BrutalSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount,
-	priorInFlight congestion.ByteCount,
-) {
-	currentTimestamp := time.Now().Unix()
-	slot := currentTimestamp % pktInfoSlotCount
-	if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
-		b.pktInfoSlots[slot].LossCount++
-	} else {
-		// uninitialized slot or too old, reset
-		b.pktInfoSlots[slot].Timestamp = currentTimestamp
-		b.pktInfoSlots[slot].AckCount = 0
-		b.pktInfoSlots[slot].LossCount = 1
-	}
-	b.updateAckRate(currentTimestamp)
-}
-
-func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
-	b.maxDatagramSize = size
-	b.pacer.SetMaxDatagramSize(size)
-}
-
-func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
-	minTimestamp := currentTimestamp - pktInfoSlotCount
-	var ackCount, lossCount uint64
-	for _, info := range b.pktInfoSlots {
-		if info.Timestamp < minTimestamp {
-			continue
-		}
-		ackCount += info.AckCount
-		lossCount += info.LossCount
-	}
-	if ackCount+lossCount < minSampleCount {
-		b.ackRate = 1
-	}
-	rate := float64(ackCount) / float64(ackCount+lossCount)
-	if rate < minAckRate {
-		b.ackRate = minAckRate
-	}
-	b.ackRate = rate
-}
-
-func (b *BrutalSender) InSlowStart() bool {
-	return false
-}
-
-func (b *BrutalSender) InRecovery() bool {
-	return false
-}
-
-func (b *BrutalSender) MaybeExitSlowStart() {}
-
-func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
-
-func maxDuration(a, b time.Duration) time.Duration {
-	if a > b {
-		return a
-	}
-	return b
-}

+ 0 - 86
transport/hysteria2/congestion/pacer.go

@@ -1,86 +0,0 @@
-package congestion
-
-import (
-	"math"
-	"time"
-
-	"github.com/sagernet/quic-go/congestion"
-)
-
-const (
-	maxBurstPackets = 10
-	minPacingDelay  = time.Millisecond
-)
-
-// The pacer implements a token bucket pacing algorithm.
-type pacer struct {
-	budgetAtLastSent congestion.ByteCount
-	maxDatagramSize  congestion.ByteCount
-	lastSentTime     time.Time
-	getBandwidth     func() congestion.ByteCount // in bytes/s
-}
-
-func newPacer(getBandwidth func() congestion.ByteCount) *pacer {
-	p := &pacer{
-		budgetAtLastSent: maxBurstPackets * initMaxDatagramSize,
-		maxDatagramSize:  initMaxDatagramSize,
-		getBandwidth:     getBandwidth,
-	}
-	return p
-}
-
-func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
-	budget := p.Budget(sendTime)
-	if size > budget {
-		p.budgetAtLastSent = 0
-	} else {
-		p.budgetAtLastSent = budget - size
-	}
-	p.lastSentTime = sendTime
-}
-
-func (p *pacer) Budget(now time.Time) congestion.ByteCount {
-	if p.lastSentTime.IsZero() {
-		return p.maxBurstSize()
-	}
-	budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
-	return minByteCount(p.maxBurstSize(), budget)
-}
-
-func (p *pacer) maxBurstSize() congestion.ByteCount {
-	return maxByteCount(
-		congestion.ByteCount((minPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9,
-		maxBurstPackets*p.maxDatagramSize,
-	)
-}
-
-// TimeUntilSend returns when the next packet should be sent.
-// It returns the zero value of time.Time if a packet can be sent immediately.
-func (p *pacer) TimeUntilSend() time.Time {
-	if p.budgetAtLastSent >= p.maxDatagramSize {
-		return time.Time{}
-	}
-	return p.lastSentTime.Add(maxDuration(
-		minPacingDelay,
-		time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/
-			float64(p.getBandwidth())))*time.Nanosecond,
-	))
-}
-
-func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
-	p.maxDatagramSize = s
-}
-
-func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
-	if a < b {
-		return b
-	}
-	return a
-}
-
-func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
-	if a < b {
-		return a
-	}
-	return b
-}

+ 0 - 68
transport/hysteria2/internal/protocol/http.go

@@ -1,68 +0,0 @@
-package protocol
-
-import (
-	"net/http"
-	"strconv"
-)
-
-const (
-	URLHost = "hysteria"
-	URLPath = "/auth"
-
-	RequestHeaderAuth        = "Hysteria-Auth"
-	ResponseHeaderUDPEnabled = "Hysteria-UDP"
-	CommonHeaderCCRX         = "Hysteria-CC-RX"
-	CommonHeaderPadding      = "Hysteria-Padding"
-
-	StatusAuthOK = 233
-)
-
-// AuthRequest is what client sends to server for authentication.
-type AuthRequest struct {
-	Auth string
-	Rx   uint64 // 0 = unknown, client asks server to use bandwidth detection
-}
-
-// AuthResponse is what server sends to client when authentication is passed.
-type AuthResponse struct {
-	UDPEnabled bool
-	Rx         uint64 // 0 = unlimited
-	RxAuto     bool   // true = server asks client to use bandwidth detection
-}
-
-func AuthRequestFromHeader(h http.Header) AuthRequest {
-	rx, _ := strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64)
-	return AuthRequest{
-		Auth: h.Get(RequestHeaderAuth),
-		Rx:   rx,
-	}
-}
-
-func AuthRequestToHeader(h http.Header, req AuthRequest) {
-	h.Set(RequestHeaderAuth, req.Auth)
-	h.Set(CommonHeaderCCRX, strconv.FormatUint(req.Rx, 10))
-	h.Set(CommonHeaderPadding, authRequestPadding.String())
-}
-
-func AuthResponseFromHeader(h http.Header) AuthResponse {
-	resp := AuthResponse{}
-	resp.UDPEnabled, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled))
-	rxStr := h.Get(CommonHeaderCCRX)
-	if rxStr == "auto" {
-		// Special case for server requesting client to use bandwidth detection
-		resp.RxAuto = true
-	} else {
-		resp.Rx, _ = strconv.ParseUint(rxStr, 10, 64)
-	}
-	return resp
-}
-
-func AuthResponseToHeader(h http.Header, resp AuthResponse) {
-	h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(resp.UDPEnabled))
-	if resp.RxAuto {
-		h.Set(CommonHeaderCCRX, "auto")
-	} else {
-		h.Set(CommonHeaderCCRX, strconv.FormatUint(resp.Rx, 10))
-	}
-	h.Set(CommonHeaderPadding, authResponsePadding.String())
-}

+ 0 - 31
transport/hysteria2/internal/protocol/padding.go

@@ -1,31 +0,0 @@
-package protocol
-
-import (
-	"math/rand"
-)
-
-const (
-	paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
-)
-
-// padding specifies a half-open range [Min, Max).
-type padding struct {
-	Min int
-	Max int
-}
-
-func (p padding) String() string {
-	n := p.Min + rand.Intn(p.Max-p.Min)
-	bs := make([]byte, n)
-	for i := range bs {
-		bs[i] = paddingChars[rand.Intn(len(paddingChars))]
-	}
-	return string(bs)
-}
-
-var (
-	authRequestPadding  = padding{Min: 256, Max: 2048}
-	authResponsePadding = padding{Min: 256, Max: 2048}
-	tcpRequestPadding   = padding{Min: 64, Max: 512}
-	tcpResponsePadding  = padding{Min: 128, Max: 1024}
-)

+ 0 - 266
transport/hysteria2/internal/protocol/proxy.go

@@ -1,266 +0,0 @@
-package protocol
-
-import (
-	"bytes"
-	"encoding/binary"
-	"fmt"
-	"io"
-
-	"github.com/sagernet/quic-go/quicvarint"
-	"github.com/sagernet/sing/common"
-	"github.com/sagernet/sing/common/buf"
-	E "github.com/sagernet/sing/common/exceptions"
-	"github.com/sagernet/sing/common/rw"
-)
-
-const (
-	FrameTypeTCPRequest = 0x401
-
-	// Max length values are for preventing DoS attacks
-
-	MaxAddressLength = 2048
-	MaxMessageLength = 2048
-	MaxPaddingLength = 4096
-
-	MaxUDPSize = 4096
-
-	maxVarInt1 = 63
-	maxVarInt2 = 16383
-	maxVarInt4 = 1073741823
-	maxVarInt8 = 4611686018427387903
-)
-
-// TCPRequest format:
-// 0x401 (QUIC varint)
-// Address length (QUIC varint)
-// Address (bytes)
-// Padding length (QUIC varint)
-// Padding (bytes)
-
-func ReadTCPRequest(r io.Reader) (string, error) {
-	bReader := quicvarint.NewReader(r)
-	addrLen, err := quicvarint.Read(bReader)
-	if err != nil {
-		return "", err
-	}
-	if addrLen == 0 || addrLen > MaxAddressLength {
-		return "", E.New("invalid address length")
-	}
-	addrBuf := make([]byte, addrLen)
-	_, err = io.ReadFull(r, addrBuf)
-	if err != nil {
-		return "", err
-	}
-	paddingLen, err := quicvarint.Read(bReader)
-	if err != nil {
-		return "", err
-	}
-	if paddingLen > MaxPaddingLength {
-		return "", E.New("invalid padding length")
-	}
-	if paddingLen > 0 {
-		_, err = io.CopyN(io.Discard, r, int64(paddingLen))
-		if err != nil {
-			return "", err
-		}
-	}
-	return string(addrBuf), nil
-}
-
-func WriteTCPRequest(addr string, payload []byte) *buf.Buffer {
-	padding := tcpRequestPadding.String()
-	paddingLen := len(padding)
-	addrLen := len(addr)
-	sz := int(quicvarint.Len(FrameTypeTCPRequest)) +
-		int(quicvarint.Len(uint64(addrLen))) + addrLen +
-		int(quicvarint.Len(uint64(paddingLen))) + paddingLen
-	buffer := buf.NewSize(sz + len(payload))
-	bufferContent := buffer.Extend(sz)
-	i := varintPut(bufferContent, FrameTypeTCPRequest)
-	i += varintPut(bufferContent[i:], uint64(addrLen))
-	i += copy(bufferContent[i:], addr)
-	i += varintPut(bufferContent[i:], uint64(paddingLen))
-	copy(bufferContent[i:], padding)
-	buffer.Write(payload)
-	return buffer
-}
-
-// TCPResponse format:
-// Status (byte, 0=ok, 1=error)
-// Message length (QUIC varint)
-// Message (bytes)
-// Padding length (QUIC varint)
-// Padding (bytes)
-
-func ReadTCPResponse(r io.Reader) (bool, string, error) {
-	var status [1]byte
-	if _, err := io.ReadFull(r, status[:]); err != nil {
-		return false, "", err
-	}
-	bReader := quicvarint.NewReader(r)
-	msg, err := ReadVString(bReader)
-	if err != nil {
-		return false, "", err
-	}
-	paddingLen, err := quicvarint.Read(bReader)
-	if err != nil {
-		return false, "", err
-	}
-	if paddingLen > MaxPaddingLength {
-		return false, "", E.New("invalid padding length")
-	}
-	if paddingLen > 0 {
-		_, err = io.CopyN(io.Discard, r, int64(paddingLen))
-		if err != nil {
-			return false, "", err
-		}
-	}
-	return status[0] == 0, msg, nil
-}
-
-func WriteTCPResponse(ok bool, msg string, payload []byte) *buf.Buffer {
-	padding := tcpResponsePadding.String()
-	paddingLen := len(padding)
-	msgLen := len(msg)
-	sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
-		int(quicvarint.Len(uint64(paddingLen))) + paddingLen
-	buffer := buf.NewSize(sz + len(payload))
-	if ok {
-		buffer.WriteByte(0)
-	} else {
-		buffer.WriteByte(1)
-	}
-	WriteVString(buffer, msg)
-	WriteUVariant(buffer, uint64(paddingLen))
-	buffer.Extend(paddingLen)
-	buffer.Write(payload)
-	return buffer
-}
-
-// UDPMessage format:
-// Session ID (uint32 BE)
-// Packet ID (uint16 BE)
-// Fragment ID (uint8)
-// Fragment count (uint8)
-// Address length (QUIC varint)
-// Address (bytes)
-// Data...
-
-type UDPMessage struct {
-	SessionID uint32 // 4
-	PacketID  uint16 // 2
-	FragID    uint8  // 1
-	FragCount uint8  // 1
-	Addr      string // varint + bytes
-	Data      []byte
-}
-
-func (m *UDPMessage) HeaderSize() int {
-	lAddr := len(m.Addr)
-	return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr
-}
-
-func (m *UDPMessage) Size() int {
-	return m.HeaderSize() + len(m.Data)
-}
-
-func (m *UDPMessage) Serialize(buf []byte) int {
-	// Make sure the buffer is big enough
-	if len(buf) < m.Size() {
-		return -1
-	}
-	binary.BigEndian.PutUint32(buf, m.SessionID)
-	binary.BigEndian.PutUint16(buf[4:], m.PacketID)
-	buf[6] = m.FragID
-	buf[7] = m.FragCount
-	i := varintPut(buf[8:], uint64(len(m.Addr)))
-	i += copy(buf[8+i:], m.Addr)
-	i += copy(buf[8+i:], m.Data)
-	return 8 + i
-}
-
-func ParseUDPMessage(msg []byte) (*UDPMessage, error) {
-	m := &UDPMessage{}
-	buf := bytes.NewBuffer(msg)
-	if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil {
-		return nil, err
-	}
-	if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil {
-		return nil, err
-	}
-	if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil {
-		return nil, err
-	}
-	if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil {
-		return nil, err
-	}
-	lAddr, err := quicvarint.Read(buf)
-	if err != nil {
-		return nil, err
-	}
-	if lAddr == 0 || lAddr > MaxMessageLength {
-		return nil, E.New("invalid address length")
-	}
-	bs := buf.Bytes()
-	m.Addr = string(bs[:lAddr])
-	m.Data = bs[lAddr:]
-	return m, nil
-}
-
-func ReadVString(reader io.Reader) (string, error) {
-	length, err := quicvarint.Read(quicvarint.NewReader(reader))
-	if err != nil {
-		return "", err
-	}
-	value, err := rw.ReadBytes(reader, int(length))
-	if err != nil {
-		return "", err
-	}
-	return string(value), nil
-}
-
-func WriteVString(writer io.Writer, value string) error {
-	err := WriteUVariant(writer, uint64(len(value)))
-	if err != nil {
-		return err
-	}
-	return rw.WriteString(writer, value)
-}
-
-func WriteUVariant(writer io.Writer, value uint64) error {
-	var b [8]byte
-	return common.Error(writer.Write(b[:varintPut(b[:], value)]))
-}
-
-// varintPut is like quicvarint.Append, but instead of appending to a slice,
-// it writes to a fixed-size buffer. Returns the number of bytes written.
-func varintPut(b []byte, i uint64) int {
-	if i <= maxVarInt1 {
-		b[0] = uint8(i)
-		return 1
-	}
-	if i <= maxVarInt2 {
-		b[0] = uint8(i>>8) | 0x40
-		b[1] = uint8(i)
-		return 2
-	}
-	if i <= maxVarInt4 {
-		b[0] = uint8(i>>24) | 0x80
-		b[1] = uint8(i >> 16)
-		b[2] = uint8(i >> 8)
-		b[3] = uint8(i)
-		return 4
-	}
-	if i <= maxVarInt8 {
-		b[0] = uint8(i>>56) | 0xc0
-		b[1] = uint8(i >> 48)
-		b[2] = uint8(i >> 40)
-		b[3] = uint8(i >> 32)
-		b[4] = uint8(i >> 24)
-		b[5] = uint8(i >> 16)
-		b[6] = uint8(i >> 8)
-		b[7] = uint8(i)
-		return 8
-	}
-	panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
-}

+ 0 - 450
transport/hysteria2/packet.go

@@ -1,450 +0,0 @@
-package hysteria2
-
-import (
-	"bytes"
-	"context"
-	"encoding/binary"
-	"errors"
-	"io"
-	"math"
-	"net"
-	"os"
-	"sync"
-	"time"
-
-	"github.com/sagernet/quic-go"
-	"github.com/sagernet/quic-go/quicvarint"
-	"github.com/sagernet/sing-box/transport/hysteria2/internal/protocol"
-	"github.com/sagernet/sing/common"
-	"github.com/sagernet/sing/common/atomic"
-	"github.com/sagernet/sing/common/buf"
-	"github.com/sagernet/sing/common/cache"
-	M "github.com/sagernet/sing/common/metadata"
-)
-
-var udpMessagePool = sync.Pool{
-	New: func() interface{} {
-		return new(udpMessage)
-	},
-}
-
-func allocMessage() *udpMessage {
-	message := udpMessagePool.Get().(*udpMessage)
-	message.referenced = true
-	return message
-}
-
-func releaseMessages(messages []*udpMessage) {
-	for _, message := range messages {
-		if message != nil {
-			message.release()
-		}
-	}
-}
-
-type udpMessage struct {
-	sessionID     uint32
-	packetID      uint16
-	fragmentID    uint8
-	fragmentTotal uint8
-	destination   string
-	data          *buf.Buffer
-	referenced    bool
-}
-
-func (m *udpMessage) release() {
-	if !m.referenced {
-		return
-	}
-	*m = udpMessage{}
-	udpMessagePool.Put(m)
-}
-
-func (m *udpMessage) releaseMessage() {
-	m.data.Release()
-	m.release()
-}
-
-func (m *udpMessage) pack() *buf.Buffer {
-	buffer := buf.NewSize(m.headerSize() + m.data.Len())
-	common.Must(
-		binary.Write(buffer, binary.BigEndian, m.sessionID),
-		binary.Write(buffer, binary.BigEndian, m.packetID),
-		binary.Write(buffer, binary.BigEndian, m.fragmentID),
-		binary.Write(buffer, binary.BigEndian, m.fragmentTotal),
-		protocol.WriteVString(buffer, m.destination),
-		common.Error(buffer.Write(m.data.Bytes())),
-	)
-	return buffer
-}
-
-func (m *udpMessage) headerSize() int {
-	return 8 + int(quicvarint.Len(uint64(len(m.destination)))) + len(m.destination)
-}
-
-func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
-	if message.data.Len() <= maxPacketSize {
-		return []*udpMessage{message}
-	}
-	var fragments []*udpMessage
-	originPacket := message.data.Bytes()
-	udpMTU := maxPacketSize - message.headerSize()
-	for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
-		fragment := allocMessage()
-		*fragment = *message
-		if remaining > udpMTU {
-			fragment.data = buf.As(originPacket[:udpMTU])
-			originPacket = originPacket[udpMTU:]
-		} else {
-			fragment.data = buf.As(originPacket)
-			originPacket = nil
-		}
-		fragments = append(fragments, fragment)
-	}
-	fragmentTotal := uint16(len(fragments))
-	for index, fragment := range fragments {
-		fragment.fragmentID = uint8(index)
-		fragment.fragmentTotal = uint8(fragmentTotal)
-		/*if index > 0 {
-			fragment.destination = ""
-			// not work in hysteria
-		}*/
-	}
-	return fragments
-}
-
-type udpPacketConn struct {
-	ctx        context.Context
-	cancel     common.ContextCancelCauseFunc
-	sessionID  uint32
-	quicConn   quic.Connection
-	data       chan *udpMessage
-	udpMTU     int
-	udpMTUTime time.Time
-	packetId   atomic.Uint32
-	closeOnce  sync.Once
-	defragger  *udpDefragger
-	onDestroy  func()
-}
-
-func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn {
-	ctx, cancel := common.ContextWithCancelCause(ctx)
-	return &udpPacketConn{
-		ctx:       ctx,
-		cancel:    cancel,
-		quicConn:  quicConn,
-		data:      make(chan *udpMessage, 64),
-		defragger: newUDPDefragger(),
-		onDestroy: onDestroy,
-	}
-}
-
-func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
-	select {
-	case p := <-c.data:
-		buffer = p.data
-		destination = M.ParseSocksaddr(p.destination)
-		p.release()
-		return
-	case <-c.ctx.Done():
-		return nil, M.Socksaddr{}, io.ErrClosedPipe
-	}
-}
-
-func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
-	select {
-	case p := <-c.data:
-		_, err = buffer.ReadOnceFrom(p.data)
-		destination = M.ParseSocksaddr(p.destination)
-		p.releaseMessage()
-		return
-	case <-c.ctx.Done():
-		return M.Socksaddr{}, io.ErrClosedPipe
-	}
-}
-
-func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
-	select {
-	case p := <-c.data:
-		_, err = newBuffer().ReadOnceFrom(p.data)
-		destination = M.ParseSocksaddr(p.destination)
-		p.releaseMessage()
-		return
-	case <-c.ctx.Done():
-		return M.Socksaddr{}, io.ErrClosedPipe
-	}
-}
-
-func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
-	select {
-	case pkt := <-c.data:
-		n = copy(p, pkt.data.Bytes())
-		destination := M.ParseSocksaddr(pkt.destination)
-		if destination.IsFqdn() {
-			addr = destination
-		} else {
-			addr = destination.UDPAddr()
-		}
-		pkt.releaseMessage()
-		return n, addr, nil
-	case <-c.ctx.Done():
-		return 0, nil, io.ErrClosedPipe
-	}
-}
-
-func (c *udpPacketConn) needFragment() bool {
-	nowTime := time.Now()
-	if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second {
-		c.udpMTUTime = nowTime
-		return true
-	}
-	return false
-}
-
-func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
-	defer buffer.Release()
-	select {
-	case <-c.ctx.Done():
-		return net.ErrClosed
-	default:
-	}
-	if buffer.Len() > 0xffff {
-		return quic.ErrMessageTooLarge(0xffff)
-	}
-	packetId := c.packetId.Add(1)
-	if packetId > math.MaxUint16 {
-		c.packetId.Store(0)
-		packetId = 0
-	}
-	message := allocMessage()
-	*message = udpMessage{
-		sessionID:     c.sessionID,
-		packetID:      uint16(packetId),
-		fragmentTotal: 1,
-		destination:   destination.String(),
-		data:          buffer,
-	}
-	defer message.releaseMessage()
-	var err error
-	if c.needFragment() && buffer.Len() > c.udpMTU {
-		err = c.writePackets(fragUDPMessage(message, c.udpMTU))
-	} else {
-		err = c.writePacket(message)
-	}
-	if err == nil {
-		return nil
-	}
-	var tooLargeErr quic.ErrMessageTooLarge
-	if !errors.As(err, &tooLargeErr) {
-		return err
-	}
-	c.udpMTU = int(tooLargeErr)
-	c.udpMTUTime = time.Now()
-	return c.writePackets(fragUDPMessage(message, c.udpMTU))
-}
-
-func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
-	select {
-	case <-c.ctx.Done():
-		return 0, net.ErrClosed
-	default:
-	}
-	if len(p) > 0xffff {
-		return 0, quic.ErrMessageTooLarge(0xffff)
-	}
-	packetId := c.packetId.Add(1)
-	if packetId > math.MaxUint16 {
-		c.packetId.Store(0)
-		packetId = 0
-	}
-	message := allocMessage()
-	*message = udpMessage{
-		sessionID:     c.sessionID,
-		packetID:      uint16(packetId),
-		fragmentTotal: 1,
-		destination:   addr.String(),
-		data:          buf.As(p),
-	}
-	if c.needFragment() && len(p) > c.udpMTU {
-		err = c.writePackets(fragUDPMessage(message, c.udpMTU))
-		if err == nil {
-			return len(p), nil
-		}
-	} else {
-		err = c.writePacket(message)
-	}
-	if err == nil {
-		return len(p), nil
-	}
-	var tooLargeErr quic.ErrMessageTooLarge
-	if !errors.As(err, &tooLargeErr) {
-		return
-	}
-	c.udpMTU = int(tooLargeErr)
-	c.udpMTUTime = time.Now()
-	err = c.writePackets(fragUDPMessage(message, c.udpMTU))
-	if err == nil {
-		return len(p), nil
-	}
-	return
-}
-
-func (c *udpPacketConn) inputPacket(message *udpMessage) {
-	if message.fragmentTotal <= 1 {
-		select {
-		case c.data <- message:
-		default:
-		}
-	} else {
-		newMessage := c.defragger.feed(message)
-		if newMessage != nil {
-			select {
-			case c.data <- newMessage:
-			default:
-			}
-		}
-	}
-}
-
-func (c *udpPacketConn) writePackets(messages []*udpMessage) error {
-	defer releaseMessages(messages)
-	for _, message := range messages {
-		err := c.writePacket(message)
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-func (c *udpPacketConn) writePacket(message *udpMessage) error {
-	buffer := message.pack()
-	defer buffer.Release()
-	return c.quicConn.SendMessage(buffer.Bytes())
-}
-
-func (c *udpPacketConn) Close() error {
-	c.closeOnce.Do(func() {
-		c.closeWithError(os.ErrClosed)
-		c.onDestroy()
-	})
-	return nil
-}
-
-func (c *udpPacketConn) closeWithError(err error) {
-	c.cancel(err)
-}
-
-func (c *udpPacketConn) LocalAddr() net.Addr {
-	return c.quicConn.LocalAddr()
-}
-
-func (c *udpPacketConn) SetDeadline(t time.Time) error {
-	return os.ErrInvalid
-}
-
-func (c *udpPacketConn) SetReadDeadline(t time.Time) error {
-	return os.ErrInvalid
-}
-
-func (c *udpPacketConn) SetWriteDeadline(t time.Time) error {
-	return os.ErrInvalid
-}
-
-type udpDefragger struct {
-	packetMap *cache.LruCache[uint16, *packetItem]
-}
-
-func newUDPDefragger() *udpDefragger {
-	return &udpDefragger{
-		packetMap: cache.New(
-			cache.WithAge[uint16, *packetItem](10),
-			cache.WithUpdateAgeOnGet[uint16, *packetItem](),
-			cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) {
-				releaseMessages(value.messages)
-			}),
-		),
-	}
-}
-
-type packetItem struct {
-	access   sync.Mutex
-	messages []*udpMessage
-	count    uint8
-}
-
-func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
-	if m.fragmentTotal <= 1 {
-		return m
-	}
-	if m.fragmentID >= m.fragmentTotal {
-		return nil
-	}
-	item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem)
-	item.access.Lock()
-	defer item.access.Unlock()
-	if int(m.fragmentTotal) != len(item.messages) {
-		releaseMessages(item.messages)
-		item.messages = make([]*udpMessage, m.fragmentTotal)
-		item.count = 1
-		item.messages[m.fragmentID] = m
-		return nil
-	}
-	if item.messages[m.fragmentID] != nil {
-		return nil
-	}
-	item.messages[m.fragmentID] = m
-	item.count++
-	if int(item.count) != len(item.messages) {
-		return nil
-	}
-	newMessage := allocMessage()
-	newMessage.sessionID = m.sessionID
-	newMessage.packetID = m.packetID
-	newMessage.destination = item.messages[0].destination
-	var finalLength int
-	for _, message := range item.messages {
-		finalLength += message.data.Len()
-	}
-	if finalLength > 0 {
-		newMessage.data = buf.NewSize(finalLength)
-		for _, message := range item.messages {
-			newMessage.data.Write(message.data.Bytes())
-			message.releaseMessage()
-		}
-		item.messages = nil
-		return newMessage
-	}
-	item.messages = nil
-	return nil
-}
-
-func newPacketItem() *packetItem {
-	return new(packetItem)
-}
-
-func decodeUDPMessage(message *udpMessage, data []byte) error {
-	reader := bytes.NewReader(data)
-	err := binary.Read(reader, binary.BigEndian, &message.sessionID)
-	if err != nil {
-		return err
-	}
-	err = binary.Read(reader, binary.BigEndian, &message.packetID)
-	if err != nil {
-		return err
-	}
-	err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
-	if err != nil {
-		return err
-	}
-	err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
-	if err != nil {
-		return err
-	}
-	message.destination, err = protocol.ReadVString(reader)
-	if err != nil {
-		return err
-	}
-	message.data = buf.As(data[len(data)-reader.Len():])
-	return nil
-}

+ 0 - 106
transport/hysteria2/salamander.go

@@ -1,106 +0,0 @@
-package hysteria2
-
-import (
-	"net"
-
-	"github.com/sagernet/sing/common"
-	"github.com/sagernet/sing/common/buf"
-	"github.com/sagernet/sing/common/bufio"
-	E "github.com/sagernet/sing/common/exceptions"
-	M "github.com/sagernet/sing/common/metadata"
-	N "github.com/sagernet/sing/common/network"
-
-	"golang.org/x/crypto/blake2b"
-)
-
-const salamanderSaltLen = 8
-
-const ObfsTypeSalamander = "salamander"
-
-type Salamander struct {
-	net.PacketConn
-	password []byte
-}
-
-func NewSalamanderConn(conn net.PacketConn, password []byte) net.PacketConn {
-	writer, isVectorised := bufio.CreateVectorisedPacketWriter(conn)
-	if isVectorised {
-		return &VectorisedSalamander{
-			Salamander: Salamander{
-				PacketConn: conn,
-				password:   password,
-			},
-			writer: writer,
-		}
-	} else {
-		return &Salamander{
-			PacketConn: conn,
-			password:   password,
-		}
-	}
-}
-
-func (s *Salamander) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
-	n, addr, err = s.PacketConn.ReadFrom(p)
-	if err != nil {
-		return
-	}
-	if n <= salamanderSaltLen {
-		return 0, nil, E.New("salamander: packet too short")
-	}
-	key := blake2b.Sum256(append(s.password, p[:salamanderSaltLen]...))
-	for index, c := range p[salamanderSaltLen:n] {
-		p[index] = c ^ key[index%blake2b.Size256]
-	}
-	return n - salamanderSaltLen, addr, nil
-}
-
-func (s *Salamander) WriteTo(p []byte, addr net.Addr) (n int, err error) {
-	buffer := buf.NewSize(len(p) + salamanderSaltLen)
-	defer buffer.Release()
-	buffer.WriteRandom(salamanderSaltLen)
-	key := blake2b.Sum256(append(s.password, buffer.Bytes()...))
-	for index, c := range p {
-		common.Must(buffer.WriteByte(c ^ key[index%blake2b.Size256]))
-	}
-	_, err = s.PacketConn.WriteTo(buffer.Bytes(), addr)
-	if err != nil {
-		return
-	}
-	return len(p), nil
-}
-
-type VectorisedSalamander struct {
-	Salamander
-	writer N.VectorisedPacketWriter
-}
-
-func (s *VectorisedSalamander) WriteTo(p []byte, addr net.Addr) (n int, err error) {
-	buffer := buf.NewSize(salamanderSaltLen)
-	buffer.WriteRandom(salamanderSaltLen)
-	key := blake2b.Sum256(append(s.password, buffer.Bytes()...))
-	for i := range p {
-		p[i] ^= key[i%blake2b.Size256]
-	}
-	err = s.writer.WriteVectorisedPacket([]*buf.Buffer{buffer, buf.As(p)}, M.SocksaddrFromNet(addr))
-	if err != nil {
-		return
-	}
-	return len(p), nil
-}
-
-func (s *VectorisedSalamander) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
-	header := buf.NewSize(salamanderSaltLen)
-	defer header.Release()
-	header.WriteRandom(salamanderSaltLen)
-	key := blake2b.Sum256(append(s.password, header.Bytes()...))
-	var bufferIndex int
-	for _, buffer := range buffers {
-		content := buffer.Bytes()
-		for index, c := range content {
-			content[bufferIndex+index] = c ^ key[bufferIndex+index%blake2b.Size256]
-		}
-		bufferIndex += len(content)
-	}
-	return s.writer.WriteVectorisedPacket(append([]*buf.Buffer{header}, buffers...), destination)
-}

+ 0 - 344
transport/hysteria2/server.go

@@ -1,344 +0,0 @@
-package hysteria2
-
-import (
-	"context"
-	"io"
-	"net"
-	"net/http"
-	"os"
-	"runtime"
-	"strings"
-	"sync"
-
-	"github.com/sagernet/quic-go"
-	"github.com/sagernet/quic-go/http3"
-	"github.com/sagernet/sing-box/common/qtls"
-	"github.com/sagernet/sing-box/common/tls"
-	"github.com/sagernet/sing-box/transport/hysteria2/congestion"
-	"github.com/sagernet/sing-box/transport/hysteria2/internal/protocol"
-	tuicCongestion "github.com/sagernet/sing-box/transport/tuic/congestion"
-	"github.com/sagernet/sing/common"
-	"github.com/sagernet/sing/common/auth"
-	"github.com/sagernet/sing/common/baderror"
-	E "github.com/sagernet/sing/common/exceptions"
-	"github.com/sagernet/sing/common/logger"
-	M "github.com/sagernet/sing/common/metadata"
-	N "github.com/sagernet/sing/common/network"
-)
-
-type ServerOptions struct {
-	Context               context.Context
-	Logger                logger.Logger
-	SendBPS               uint64
-	ReceiveBPS            uint64
-	IgnoreClientBandwidth bool
-	SalamanderPassword    string
-	TLSConfig             tls.ServerConfig
-	Users                 []User
-	UDPDisabled           bool
-	Handler               ServerHandler
-	MasqueradeHandler     http.Handler
-}
-
-type User struct {
-	Name     string
-	Password string
-}
-
-type ServerHandler interface {
-	N.TCPConnectionHandler
-	N.UDPConnectionHandler
-}
-
-type Server struct {
-	ctx                   context.Context
-	logger                logger.Logger
-	sendBPS               uint64
-	receiveBPS            uint64
-	ignoreClientBandwidth bool
-	salamanderPassword    string
-	tlsConfig             tls.ServerConfig
-	quicConfig            *quic.Config
-	userMap               map[string]User
-	udpDisabled           bool
-	handler               ServerHandler
-	masqueradeHandler     http.Handler
-	quicListener          io.Closer
-}
-
-func NewServer(options ServerOptions) (*Server, error) {
-	quicConfig := &quic.Config{
-		DisablePathMTUDiscovery:        !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
-		EnableDatagrams:                !options.UDPDisabled,
-		MaxIncomingStreams:             1 << 60,
-		InitialStreamReceiveWindow:     defaultStreamReceiveWindow,
-		MaxStreamReceiveWindow:         defaultStreamReceiveWindow,
-		InitialConnectionReceiveWindow: defaultConnReceiveWindow,
-		MaxConnectionReceiveWindow:     defaultConnReceiveWindow,
-		MaxIdleTimeout:                 defaultMaxIdleTimeout,
-		KeepAlivePeriod:                defaultKeepAlivePeriod,
-	}
-	if len(options.Users) == 0 {
-		return nil, E.New("missing users")
-	}
-	userMap := make(map[string]User)
-	for _, user := range options.Users {
-		userMap[user.Password] = user
-	}
-	if options.MasqueradeHandler == nil {
-		options.MasqueradeHandler = http.NotFoundHandler()
-	}
-	return &Server{
-		ctx:                   options.Context,
-		logger:                options.Logger,
-		sendBPS:               options.SendBPS,
-		receiveBPS:            options.ReceiveBPS,
-		ignoreClientBandwidth: options.IgnoreClientBandwidth,
-		salamanderPassword:    options.SalamanderPassword,
-		tlsConfig:             options.TLSConfig,
-		quicConfig:            quicConfig,
-		userMap:               userMap,
-		udpDisabled:           options.UDPDisabled,
-		handler:               options.Handler,
-		masqueradeHandler:     options.MasqueradeHandler,
-	}, nil
-}
-
-func (s *Server) Start(conn net.PacketConn) error {
-	if s.salamanderPassword != "" {
-		conn = NewSalamanderConn(conn, []byte(s.salamanderPassword))
-	}
-	err := qtls.ConfigureHTTP3(s.tlsConfig)
-	if err != nil {
-		return err
-	}
-	listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig)
-	if err != nil {
-		return err
-	}
-	s.quicListener = listener
-	go s.loopConnections(listener)
-	return nil
-}
-
-func (s *Server) Close() error {
-	return common.Close(
-		s.quicListener,
-	)
-}
-
-func (s *Server) loopConnections(listener qtls.QUICListener) {
-	for {
-		connection, err := listener.Accept(s.ctx)
-		if err != nil {
-			if strings.Contains(err.Error(), "server closed") {
-				s.logger.Debug(E.Cause(err, "listener closed"))
-			} else {
-				s.logger.Error(E.Cause(err, "listener closed"))
-			}
-			return
-		}
-		go s.handleConnection(connection)
-	}
-}
-
-func (s *Server) handleConnection(connection quic.Connection) {
-	session := &serverSession{
-		Server:     s,
-		ctx:        s.ctx,
-		quicConn:   connection,
-		source:     M.SocksaddrFromNet(connection.RemoteAddr()),
-		connDone:   make(chan struct{}),
-		udpConnMap: make(map[uint32]*udpPacketConn),
-	}
-	httpServer := http3.Server{
-		Handler:        session,
-		StreamHijacker: session.handleStream0,
-	}
-	_ = httpServer.ServeQUICConn(connection)
-	_ = connection.CloseWithError(0, "")
-}
-
-type serverSession struct {
-	*Server
-	ctx           context.Context
-	quicConn      quic.Connection
-	source        M.Socksaddr
-	connAccess    sync.Mutex
-	connDone      chan struct{}
-	connErr       error
-	authenticated bool
-	authUser      *User
-	udpAccess     sync.RWMutex
-	udpConnMap    map[uint32]*udpPacketConn
-}
-
-func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath {
-		if s.authenticated {
-			protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
-				UDPEnabled: !s.udpDisabled,
-				Rx:         s.receiveBPS,
-				RxAuto:     s.ignoreClientBandwidth,
-			})
-			w.WriteHeader(protocol.StatusAuthOK)
-			return
-		}
-		request := protocol.AuthRequestFromHeader(r.Header)
-		user, loaded := s.userMap[request.Auth]
-		if !loaded {
-			s.masqueradeHandler.ServeHTTP(w, r)
-			return
-		}
-		s.authUser = &user
-		s.authenticated = true
-		if !s.ignoreClientBandwidth && request.Rx > 0 {
-			var sendBps uint64
-			if s.sendBPS > 0 && s.sendBPS < request.Rx {
-				sendBps = s.sendBPS
-			} else {
-				sendBps = request.Rx
-			}
-			s.quicConn.SetCongestionControl(congestion.NewBrutalSender(sendBps))
-		} else {
-			s.quicConn.SetCongestionControl(tuicCongestion.NewBBRSender(
-				tuicCongestion.DefaultClock{},
-				tuicCongestion.GetInitialPacketSize(s.quicConn.RemoteAddr()),
-				tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
-				tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
-			))
-		}
-		protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
-			UDPEnabled: !s.udpDisabled,
-			Rx:         s.receiveBPS,
-			RxAuto:     s.ignoreClientBandwidth,
-		})
-		w.WriteHeader(protocol.StatusAuthOK)
-		if s.ctx.Done() != nil {
-			go func() {
-				select {
-				case <-s.ctx.Done():
-					s.closeWithError(s.ctx.Err())
-				case <-s.connDone:
-				}
-			}()
-		}
-		if !s.udpDisabled {
-			go s.loopMessages()
-		}
-	} else {
-		s.masqueradeHandler.ServeHTTP(w, r)
-	}
-}
-
-func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) {
-	if !s.authenticated || err != nil {
-		return false, nil
-	}
-	if frameType != protocol.FrameTypeTCPRequest {
-		return false, nil
-	}
-	go func() {
-		hErr := s.handleStream(stream)
-		stream.CancelRead(0)
-		stream.Close()
-		if hErr != nil {
-			stream.CancelRead(0)
-			stream.Close()
-			s.logger.Error(E.Cause(hErr, "handle stream request"))
-		}
-	}()
-	return true, nil
-}
-
-func (s *serverSession) handleStream(stream quic.Stream) error {
-	destinationString, err := protocol.ReadTCPRequest(stream)
-	if err != nil {
-		return E.New("read TCP request")
-	}
-	ctx := s.ctx
-	if s.authUser.Name != "" {
-		ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
-	}
-	_ = s.handler.NewConnection(ctx, &serverConn{Stream: stream}, M.Metadata{
-		Source:      s.source,
-		Destination: M.ParseSocksaddr(destinationString),
-	})
-	return nil
-}
-
-func (s *serverSession) closeWithError(err error) {
-	s.connAccess.Lock()
-	defer s.connAccess.Unlock()
-	select {
-	case <-s.connDone:
-		return
-	default:
-		s.connErr = err
-		close(s.connDone)
-	}
-	if E.IsClosedOrCanceled(err) {
-		s.logger.Debug(E.Cause(err, "connection failed"))
-	} else {
-		s.logger.Error(E.Cause(err, "connection failed"))
-	}
-	_ = s.quicConn.CloseWithError(0, "")
-}
-
-type serverConn struct {
-	quic.Stream
-	responseWritten bool
-}
-
-func (c *serverConn) HandshakeFailure(err error) error {
-	if c.responseWritten {
-		return os.ErrClosed
-	}
-	c.responseWritten = true
-	buffer := protocol.WriteTCPResponse(false, err.Error(), nil)
-	defer buffer.Release()
-	return common.Error(c.Stream.Write(buffer.Bytes()))
-}
-
-func (c *serverConn) HandshakeSuccess() error {
-	if c.responseWritten {
-		return nil
-	}
-	c.responseWritten = true
-	buffer := protocol.WriteTCPResponse(true, "", nil)
-	defer buffer.Release()
-	return common.Error(c.Stream.Write(buffer.Bytes()))
-}
-
-func (c *serverConn) Read(p []byte) (n int, err error) {
-	n, err = c.Stream.Read(p)
-	return n, baderror.WrapQUIC(err)
-}
-
-func (c *serverConn) Write(p []byte) (n int, err error) {
-	if !c.responseWritten {
-		c.responseWritten = true
-		buffer := protocol.WriteTCPResponse(true, "", p)
-		defer buffer.Release()
-		_, err = c.Stream.Write(buffer.Bytes())
-		if err != nil {
-			return 0, baderror.WrapQUIC(err)
-		}
-		return len(p), nil
-	}
-	n, err = c.Stream.Write(p)
-	return n, baderror.WrapQUIC(err)
-}
-
-func (c *serverConn) LocalAddr() net.Addr {
-	return M.Socksaddr{}
-}
-
-func (c *serverConn) RemoteAddr() net.Addr {
-	return M.Socksaddr{}
-}
-
-func (c *serverConn) Close() error {
-	c.Stream.CancelRead(0)
-	return c.Stream.Close()
-}

+ 0 - 55
transport/hysteria2/server_packet.go

@@ -1,55 +0,0 @@
-package hysteria2
-
-import (
-	"github.com/sagernet/sing/common"
-	E "github.com/sagernet/sing/common/exceptions"
-	M "github.com/sagernet/sing/common/metadata"
-)
-
-func (s *serverSession) loopMessages() {
-	for {
-		message, err := s.quicConn.ReceiveMessage(s.ctx)
-		if err != nil {
-			s.closeWithError(E.Cause(err, "receive message"))
-			return
-		}
-		hErr := s.handleMessage(message)
-		if hErr != nil {
-			s.closeWithError(E.Cause(hErr, "handle message"))
-			return
-		}
-	}
-}
-
-func (s *serverSession) handleMessage(data []byte) error {
-	message := allocMessage()
-	err := decodeUDPMessage(message, data)
-	if err != nil {
-		message.release()
-		return E.Cause(err, "decode UDP message")
-	}
-	s.handleUDPMessage(message)
-	return nil
-}
-
-func (s *serverSession) handleUDPMessage(message *udpMessage) {
-	s.udpAccess.RLock()
-	udpConn, loaded := s.udpConnMap[message.sessionID]
-	s.udpAccess.RUnlock()
-	if !loaded || common.Done(udpConn.ctx) {
-		udpConn = newUDPPacketConn(s.ctx, s.quicConn, func() {
-			s.udpAccess.Lock()
-			delete(s.udpConnMap, message.sessionID)
-			s.udpAccess.Unlock()
-		})
-		udpConn.sessionID = message.sessionID
-		s.udpAccess.Lock()
-		s.udpConnMap[message.sessionID] = udpConn
-		s.udpAccess.Unlock()
-		go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{
-			Source:      s.source,
-			Destination: M.ParseSocksaddr(message.destination),
-		})
-	}
-	udpConn.inputPacket(message)
-}

+ 0 - 10
transport/tuic/address.go

@@ -1,10 +0,0 @@
-package tuic
-
-import M "github.com/sagernet/sing/common/metadata"
-
-var addressSerializer = M.NewSerializer(
-	M.AddressFamilyByte(0x00, M.AddressFamilyFqdn),
-	M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
-	M.AddressFamilyByte(0x02, M.AddressFamilyIPv6),
-	M.AddressFamilyByte(0xff, M.AddressFamilyEmpty),
-)

+ 0 - 307
transport/tuic/client.go

@@ -1,307 +0,0 @@
-//go:build with_quic
-
-package tuic
-
-import (
-	"context"
-	"io"
-	"net"
-	"runtime"
-	"sync"
-	"time"
-
-	"github.com/sagernet/quic-go"
-	"github.com/sagernet/sing-box/common/qtls"
-	"github.com/sagernet/sing-box/common/tls"
-	"github.com/sagernet/sing/common"
-	"github.com/sagernet/sing/common/baderror"
-	"github.com/sagernet/sing/common/buf"
-	"github.com/sagernet/sing/common/bufio"
-	E "github.com/sagernet/sing/common/exceptions"
-	M "github.com/sagernet/sing/common/metadata"
-	N "github.com/sagernet/sing/common/network"
-
-	"github.com/gofrs/uuid/v5"
-)
-
-type ClientOptions struct {
-	Context           context.Context
-	Dialer            N.Dialer
-	ServerAddress     M.Socksaddr
-	TLSConfig         tls.Config
-	UUID              uuid.UUID
-	Password          string
-	CongestionControl string
-	UDPStream         bool
-	ZeroRTTHandshake  bool
-	Heartbeat         time.Duration
-}
-
-type Client struct {
-	ctx               context.Context
-	dialer            N.Dialer
-	serverAddr        M.Socksaddr
-	tlsConfig         tls.Config
-	quicConfig        *quic.Config
-	uuid              uuid.UUID
-	password          string
-	congestionControl string
-	udpStream         bool
-	zeroRTTHandshake  bool
-	heartbeat         time.Duration
-
-	connAccess sync.RWMutex
-	conn       *clientQUICConnection
-}
-
-func NewClient(options ClientOptions) (*Client, error) {
-	if options.Heartbeat == 0 {
-		options.Heartbeat = 10 * time.Second
-	}
-	quicConfig := &quic.Config{
-		DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
-		MaxDatagramFrameSize:    1400,
-		EnableDatagrams:         true,
-		MaxIncomingUniStreams:   1 << 60,
-	}
-	switch options.CongestionControl {
-	case "":
-		options.CongestionControl = "cubic"
-	case "cubic", "new_reno", "bbr":
-	default:
-		return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
-	}
-	return &Client{
-		ctx:               options.Context,
-		dialer:            options.Dialer,
-		serverAddr:        options.ServerAddress,
-		tlsConfig:         options.TLSConfig,
-		quicConfig:        quicConfig,
-		uuid:              options.UUID,
-		password:          options.Password,
-		congestionControl: options.CongestionControl,
-		udpStream:         options.UDPStream,
-		zeroRTTHandshake:  options.ZeroRTTHandshake,
-		heartbeat:         options.Heartbeat,
-	}, nil
-}
-
-func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
-	conn := c.conn
-	if conn != nil && conn.active() {
-		return conn, nil
-	}
-	c.connAccess.Lock()
-	defer c.connAccess.Unlock()
-	conn = c.conn
-	if conn != nil && conn.active() {
-		return conn, nil
-	}
-	conn, err := c.offerNew(ctx)
-	if err != nil {
-		return nil, err
-	}
-	return conn, nil
-}
-
-func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
-	udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr)
-	if err != nil {
-		return nil, err
-	}
-	var quicConn quic.Connection
-	if c.zeroRTTHandshake {
-		quicConn, err = qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
-	} else {
-		quicConn, err = qtls.Dial(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
-	}
-	if err != nil {
-		udpConn.Close()
-		return nil, E.Cause(err, "open connection")
-	}
-	setCongestion(c.ctx, quicConn, c.congestionControl)
-	conn := &clientQUICConnection{
-		quicConn:   quicConn,
-		rawConn:    udpConn,
-		connDone:   make(chan struct{}),
-		udpConnMap: make(map[uint16]*udpPacketConn),
-	}
-	go func() {
-		hErr := c.clientHandshake(quicConn)
-		if hErr != nil {
-			conn.closeWithError(hErr)
-		}
-	}()
-	if c.udpStream {
-		go c.loopUniStreams(conn)
-	}
-	go c.loopMessages(conn)
-	go c.loopHeartbeats(conn)
-	c.conn = conn
-	return conn, nil
-}
-
-func (c *Client) clientHandshake(conn quic.Connection) error {
-	authStream, err := conn.OpenUniStream()
-	if err != nil {
-		return E.Cause(err, "open handshake stream")
-	}
-	defer authStream.Close()
-	handshakeState := conn.ConnectionState()
-	tuicAuthToken, err := handshakeState.ExportKeyingMaterial(string(c.uuid[:]), []byte(c.password), 32)
-	if err != nil {
-		return E.Cause(err, "export keying material")
-	}
-	authRequest := buf.NewSize(AuthenticateLen)
-	authRequest.WriteByte(Version)
-	authRequest.WriteByte(CommandAuthenticate)
-	authRequest.Write(c.uuid[:])
-	authRequest.Write(tuicAuthToken)
-	return common.Error(authStream.Write(authRequest.Bytes()))
-}
-
-func (c *Client) loopHeartbeats(conn *clientQUICConnection) {
-	ticker := time.NewTicker(c.heartbeat)
-	defer ticker.Stop()
-	for {
-		select {
-		case <-conn.connDone:
-			return
-		case <-ticker.C:
-			err := conn.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
-			if err != nil {
-				conn.closeWithError(E.Cause(err, "send heartbeat"))
-			}
-		}
-	}
-}
-
-func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
-	conn, err := c.offer(ctx)
-	if err != nil {
-		return nil, err
-	}
-	stream, err := conn.quicConn.OpenStream()
-	if err != nil {
-		return nil, err
-	}
-	return &clientConn{
-		Stream:      stream,
-		parent:      conn,
-		destination: destination,
-	}, nil
-}
-
-func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) {
-	conn, err := c.offer(ctx)
-	if err != nil {
-		return nil, err
-	}
-	var sessionID uint16
-	clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, c.udpStream, false, func() {
-		conn.udpAccess.Lock()
-		delete(conn.udpConnMap, sessionID)
-		conn.udpAccess.Unlock()
-	})
-	conn.udpAccess.Lock()
-	sessionID = conn.udpSessionID
-	conn.udpSessionID++
-	conn.udpConnMap[sessionID] = clientPacketConn
-	conn.udpAccess.Unlock()
-	clientPacketConn.sessionID = sessionID
-	return clientPacketConn, nil
-}
-
-func (c *Client) CloseWithError(err error) error {
-	conn := c.conn
-	if conn != nil {
-		conn.closeWithError(err)
-	}
-	return nil
-}
-
-type clientQUICConnection struct {
-	quicConn     quic.Connection
-	rawConn      io.Closer
-	closeOnce    sync.Once
-	connDone     chan struct{}
-	connErr      error
-	udpAccess    sync.RWMutex
-	udpConnMap   map[uint16]*udpPacketConn
-	udpSessionID uint16
-}
-
-func (c *clientQUICConnection) active() bool {
-	select {
-	case <-c.quicConn.Context().Done():
-		return false
-	default:
-	}
-	select {
-	case <-c.connDone:
-		return false
-	default:
-	}
-	return true
-}
-
-func (c *clientQUICConnection) closeWithError(err error) {
-	c.closeOnce.Do(func() {
-		c.connErr = err
-		close(c.connDone)
-		_ = c.quicConn.CloseWithError(0, "")
-		_ = c.rawConn.Close()
-	})
-}
-
-type clientConn struct {
-	quic.Stream
-	parent         *clientQUICConnection
-	destination    M.Socksaddr
-	requestWritten bool
-}
-
-func (c *clientConn) NeedHandshake() bool {
-	return !c.requestWritten
-}
-
-func (c *clientConn) Read(b []byte) (n int, err error) {
-	n, err = c.Stream.Read(b)
-	return n, baderror.WrapQUIC(err)
-}
-
-func (c *clientConn) Write(b []byte) (n int, err error) {
-	if !c.requestWritten {
-		request := buf.NewSize(2 + addressSerializer.AddrPortLen(c.destination) + len(b))
-		defer request.Release()
-		request.WriteByte(Version)
-		request.WriteByte(CommandConnect)
-		err = addressSerializer.WriteAddrPort(request, c.destination)
-		if err != nil {
-			return
-		}
-		request.Write(b)
-		_, err = c.Stream.Write(request.Bytes())
-		if err != nil {
-			c.parent.closeWithError(E.Cause(err, "create new connection"))
-			return 0, baderror.WrapQUIC(err)
-		}
-		c.requestWritten = true
-		return len(b), nil
-	}
-	n, err = c.Stream.Write(b)
-	return n, baderror.WrapQUIC(err)
-}
-
-func (c *clientConn) Close() error {
-	c.Stream.CancelRead(0)
-	return c.Stream.Close()
-}
-
-func (c *clientConn) LocalAddr() net.Addr {
-	return M.Socksaddr{}
-}
-
-func (c *clientConn) RemoteAddr() net.Addr {
-	return c.destination
-}

+ 0 - 112
transport/tuic/client_packet.go

@@ -1,112 +0,0 @@
-//go:build with_quic
-
-package tuic
-
-import (
-	"io"
-
-	"github.com/sagernet/quic-go"
-	"github.com/sagernet/sing/common/buf"
-	"github.com/sagernet/sing/common/bufio"
-	E "github.com/sagernet/sing/common/exceptions"
-)
-
-func (c *Client) loopMessages(conn *clientQUICConnection) {
-	for {
-		message, err := conn.quicConn.ReceiveMessage(c.ctx)
-		if err != nil {
-			conn.closeWithError(E.Cause(err, "receive message"))
-			return
-		}
-		go func() {
-			hErr := c.handleMessage(conn, message)
-			if hErr != nil {
-				conn.closeWithError(E.Cause(hErr, "handle message"))
-			}
-		}()
-	}
-}
-
-func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
-	if len(data) < 2 {
-		return E.New("invalid message")
-	}
-	if data[0] != Version {
-		return E.New("unknown version ", data[0])
-	}
-	switch data[1] {
-	case CommandPacket:
-		message := allocMessage()
-		err := decodeUDPMessage(message, data[2:])
-		if err != nil {
-			message.release()
-			return E.Cause(err, "decode UDP message")
-		}
-		conn.handleUDPMessage(message)
-		return nil
-	case CommandHeartbeat:
-		return nil
-	default:
-		return E.New("unknown command ", data[0])
-	}
-}
-
-func (c *Client) loopUniStreams(conn *clientQUICConnection) {
-	for {
-		stream, err := conn.quicConn.AcceptUniStream(c.ctx)
-		if err != nil {
-			conn.closeWithError(E.Cause(err, "handle uni stream"))
-			return
-		}
-		go func() {
-			hErr := c.handleUniStream(conn, stream)
-			if hErr != nil {
-				conn.closeWithError(hErr)
-			}
-		}()
-	}
-}
-
-func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.ReceiveStream) error {
-	defer stream.CancelRead(0)
-	buffer := buf.NewPacket()
-	defer buffer.Release()
-	_, err := buffer.ReadAtLeastFrom(stream, 2)
-	if err != nil {
-		return err
-	}
-	version, _ := buffer.ReadByte()
-	if version != Version {
-		return E.New("unknown version ", version)
-	}
-	command, _ := buffer.ReadByte()
-	if command != CommandPacket {
-		return E.New("unknown command ", command)
-	}
-	reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream)
-	message := allocMessage()
-	err = readUDPMessage(message, reader)
-	if err != nil {
-		message.release()
-		return err
-	}
-	conn.handleUDPMessage(message)
-	return nil
-}
-
-func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) {
-	c.udpAccess.RLock()
-	udpConn, loaded := c.udpConnMap[message.sessionID]
-	c.udpAccess.RUnlock()
-	if !loaded {
-		message.releaseMessage()
-		return
-	}
-	select {
-	case <-udpConn.ctx.Done():
-		message.releaseMessage()
-		return
-	default:
-	}
-	udpConn.inputPacket(message)
-}

+ 0 - 46
transport/tuic/congestion.go

@@ -1,46 +0,0 @@
-package tuic
-
-import (
-	"context"
-	"time"
-
-	"github.com/sagernet/quic-go"
-	"github.com/sagernet/sing-box/transport/tuic/congestion"
-	"github.com/sagernet/sing/common/ntp"
-)
-
-func setCongestion(ctx context.Context, connection quic.Connection, congestionName string) {
-	timeFunc := ntp.TimeFuncFromContext(ctx)
-	if timeFunc == nil {
-		timeFunc = time.Now
-	}
-	switch congestionName {
-	case "cubic":
-		connection.SetCongestionControl(
-			congestion.NewCubicSender(
-				congestion.DefaultClock{TimeFunc: timeFunc},
-				congestion.GetInitialPacketSize(connection.RemoteAddr()),
-				false,
-				nil,
-			),
-		)
-	case "new_reno":
-		connection.SetCongestionControl(
-			congestion.NewCubicSender(
-				congestion.DefaultClock{TimeFunc: timeFunc},
-				congestion.GetInitialPacketSize(connection.RemoteAddr()),
-				true,
-				nil,
-			),
-		)
-	case "bbr":
-		connection.SetCongestionControl(
-			congestion.NewBBRSender(
-				congestion.DefaultClock{},
-				congestion.GetInitialPacketSize(connection.RemoteAddr()),
-				congestion.InitialCongestionWindow*congestion.InitialMaxDatagramSize,
-				congestion.DefaultBBRMaxCongestionWindow*congestion.InitialMaxDatagramSize,
-			),
-		)
-	}
-}

+ 0 - 3
transport/tuic/congestion/README.md

@@ -1,3 +0,0 @@
-# congestion
-
-mod from https://github.com/MetaCubeX/Clash.Meta/tree/53f9e1ee7104473da2b4ff5da29965563084482d/transport/tuic/congestion

+ 0 - 25
transport/tuic/congestion/bandwidth.go

@@ -1,25 +0,0 @@
-package congestion
-
-import (
-	"math"
-	"time"
-
-	"github.com/sagernet/quic-go/congestion"
-)
-
-// Bandwidth of a connection
-type Bandwidth uint64
-
-const infBandwidth Bandwidth = math.MaxUint64
-
-const (
-	// BitsPerSecond is 1 bit per second
-	BitsPerSecond Bandwidth = 1
-	// BytesPerSecond is 1 byte per second
-	BytesPerSecond = 8 * BitsPerSecond
-)
-
-// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
-func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth {
-	return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
-}

+ 0 - 374
transport/tuic/congestion/bandwidth_sampler.go

@@ -1,374 +0,0 @@
-package congestion
-
-import (
-	"math"
-	"time"
-
-	"github.com/sagernet/quic-go/congestion"
-)
-
-var InfiniteBandwidth = Bandwidth(math.MaxUint64)
-
-// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned
-// to the caller when the packet is acked or lost.
-type SendTimeState struct {
-	// Whether other states in this object is valid.
-	isValid bool
-	// Whether the sender is app limited at the time the packet was sent.
-	// App limited bandwidth sample might be artificially low because the sender
-	// did not have enough data to send in order to saturate the link.
-	isAppLimited bool
-	// Total number of sent bytes at the time the packet was sent.
-	// Includes the packet itself.
-	totalBytesSent congestion.ByteCount
-	// Total number of acked bytes at the time the packet was sent.
-	totalBytesAcked congestion.ByteCount
-	// Total number of lost bytes at the time the packet was sent.
-	totalBytesLost congestion.ByteCount
-}
-
-// ConnectionStateOnSentPacket represents the information about a sent packet
-// and the state of the connection at the moment the packet was sent,
-// specifically the information about the most recently acknowledged packet at
-// that moment.
-type ConnectionStateOnSentPacket struct {
-	packetNumber congestion.PacketNumber
-	// Time at which the packet is sent.
-	sendTime time.Time
-	// Size of the packet.
-	size congestion.ByteCount
-	// The value of |totalBytesSentAtLastAckedPacket| at the time the
-	// packet was sent.
-	totalBytesSentAtLastAckedPacket congestion.ByteCount
-	// The value of |lastAckedPacketSentTime| at the time the packet was
-	// sent.
-	lastAckedPacketSentTime time.Time
-	// The value of |lastAckedPacketAckTime| at the time the packet was
-	// sent.
-	lastAckedPacketAckTime time.Time
-	// Send time states that are returned to the congestion controller when the
-	// packet is acked or lost.
-	sendTimeState SendTimeState
-}
-
-// BandwidthSample
-type BandwidthSample struct {
-	// The bandwidth at that particular sample. Zero if no valid bandwidth sample
-	// is available.
-	bandwidth Bandwidth
-	// The RTT measurement at this particular sample.  Zero if no RTT sample is
-	// available.  Does not correct for delayed ack time.
-	rtt time.Duration
-	// States captured when the packet was sent.
-	stateAtSend SendTimeState
-}
-
-func NewBandwidthSample() *BandwidthSample {
-	return &BandwidthSample{
-		// FIXME: the default value of original code is zero.
-		rtt: InfiniteRTT,
-	}
-}
-
-// BandwidthSampler keeps track of sent and acknowledged packets and outputs a
-// bandwidth sample for every packet acknowledged. The samples are taken for
-// individual packets, and are not filtered; the consumer has to filter the
-// bandwidth samples itself. In certain cases, the sampler will locally severely
-// underestimate the bandwidth, hence a maximum filter with a size of at least
-// one RTT is recommended.
-//
-// This class bases its samples on the slope of two curves: the number of bytes
-// sent over time, and the number of bytes acknowledged as received over time.
-// It produces a sample of both slopes for every packet that gets acknowledged,
-// based on a slope between two points on each of the corresponding curves. Note
-// that due to the packet loss, the number of bytes on each curve might get
-// further and further away from each other, meaning that it is not feasible to
-// compare byte values coming from different curves with each other.
-//
-// The obvious points for measuring slope sample are the ones corresponding to
-// the packet that was just acknowledged. Let us denote them as S_1 (point at
-// which the current packet was sent) and A_1 (point at which the current packet
-// was acknowledged). However, taking a slope requires two points on each line,
-// so estimating bandwidth requires picking a packet in the past with respect to
-// which the slope is measured.
-//
-// For that purpose, BandwidthSampler always keeps track of the most recently
-// acknowledged packet, and records it together with every outgoing packet.
-// When a packet gets acknowledged (A_1), it has not only information about when
-// it itself was sent (S_1), but also the information about the latest
-// acknowledged packet right before it was sent (S_0 and A_0).
-//
-// Based on that data, send and ack rate are estimated as:
-//
-//	send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0))
-//	ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0))
-//
-// Here, the ack rate is intuitively the rate we want to treat as bandwidth.
-// However, in certain cases (e.g. ack compression) the ack rate at a point may
-// end up higher than the rate at which the data was originally sent, which is
-// not indicative of the real bandwidth. Hence, we use the send rate as an upper
-// bound, and the sample value is
-//
-//	rate_sample = min(send_rate, ack_rate)
-//
-// An important edge case handled by the sampler is tracking the app-limited
-// samples. There are multiple meaning of "app-limited" used interchangeably,
-// hence it is important to understand and to be able to distinguish between
-// them.
-//
-// Meaning 1: connection state. The connection is said to be app-limited when
-// there is no outstanding data to send. This means that certain bandwidth
-// samples in the future would not be an accurate indication of the link
-// capacity, and it is important to inform consumer about that. Whenever
-// connection becomes app-limited, the sampler is notified via OnAppLimited()
-// method.
-//
-// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth
-// sampler becomes notified about the connection being app-limited, it enters
-// app-limited phase. In that phase, all *sent* packets are marked as
-// app-limited. Note that the connection itself does not have to be
-// app-limited during the app-limited phase, and in fact it will not be
-// (otherwise how would it send packets?). The boolean flag below indicates
-// whether the sampler is in that phase.
-//
-// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is
-// sent during the app-limited phase, the resulting sample related to the
-// packet will be marked as app-limited.
-//
-// With the terminology issue out of the way, let us consider the question of
-// what kind of situation it addresses.
-//
-// Consider a scenario where we first send packets 1 to 20 at a regular
-// bandwidth, and then immediately run out of data. After a few seconds, we send
-// packets 21 to 60, and only receive ack for 21 between sending packets 40 and
-// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0
-// we use to compute the slope is going to be packet 20, a few seconds apart
-// from the current packet, hence the resulting estimate would be extremely low
-// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21,
-// meaning that the bandwidth sample would exclude the quiescence.
-//
-// Based on the analysis of that scenario, we implement the following rule: once
-// OnAppLimited() is called, all sent packets will produce app-limited samples
-// up until an ack for a packet that was sent after OnAppLimited() was called.
-// Note that while the scenario above is not the only scenario when the
-// connection is app-limited, the approach works in other cases too.
-type BandwidthSampler struct {
-	// The total number of congestion controlled bytes sent during the connection.
-	totalBytesSent congestion.ByteCount
-	// The total number of congestion controlled bytes which were acknowledged.
-	totalBytesAcked congestion.ByteCount
-	// The total number of congestion controlled bytes which were lost.
-	totalBytesLost congestion.ByteCount
-	// The value of |totalBytesSent| at the time the last acknowledged packet
-	// was sent. Valid only when |lastAckedPacketSentTime| is valid.
-	totalBytesSentAtLastAckedPacket congestion.ByteCount
-	// The time at which the last acknowledged packet was sent. Set to
-	// QuicTime::Zero() if no valid timestamp is available.
-	lastAckedPacketSentTime time.Time
-	// The time at which the most recent packet was acknowledged.
-	lastAckedPacketAckTime time.Time
-	// The most recently sent packet.
-	lastSendPacket congestion.PacketNumber
-	// Indicates whether the bandwidth sampler is currently in an app-limited
-	// phase.
-	isAppLimited bool
-	// The packet that will be acknowledged after this one will cause the sampler
-	// to exit the app-limited phase.
-	endOfAppLimitedPhase congestion.PacketNumber
-	// Record of the connection state at the point where each packet in flight was
-	// sent, indexed by the packet number.
-	connectionStats *ConnectionStates
-}
-
-func NewBandwidthSampler() *BandwidthSampler {
-	return &BandwidthSampler{
-		connectionStats: &ConnectionStates{
-			stats: make(map[congestion.PacketNumber]*ConnectionStateOnSentPacket),
-		},
-	}
-}
-
-// OnPacketSent Inputs the sent packet information into the sampler. Assumes that all
-// packets are sent in order. The information about the packet will not be
-// released from the sampler until it the packet is either acknowledged or
-// declared lost.
-func (s *BandwidthSampler) OnPacketSent(sentTime time.Time, lastSentPacket congestion.PacketNumber, sentBytes, bytesInFlight congestion.ByteCount, hasRetransmittableData bool) {
-	s.lastSendPacket = lastSentPacket
-
-	if !hasRetransmittableData {
-		return
-	}
-
-	s.totalBytesSent += sentBytes
-
-	// If there are no packets in flight, the time at which the new transmission
-	// opens can be treated as the A_0 point for the purpose of bandwidth
-	// sampling. This underestimates bandwidth to some extent, and produces some
-	// artificially low samples for most packets in flight, but it provides with
-	// samples at important points where we would not have them otherwise, most
-	// importantly at the beginning of the connection.
-	if bytesInFlight == 0 {
-		s.lastAckedPacketAckTime = sentTime
-		s.totalBytesSentAtLastAckedPacket = s.totalBytesSent
-
-		// In this situation ack compression is not a concern, set send rate to
-		// effectively infinite.
-		s.lastAckedPacketSentTime = sentTime
-	}
-
-	s.connectionStats.Insert(lastSentPacket, sentTime, sentBytes, s)
-}
-
-// OnPacketAcked Notifies the sampler that the |lastAckedPacket| is acknowledged. Returns a
-// bandwidth sample. If no bandwidth sample is available,
-// QuicBandwidth::Zero() is returned.
-func (s *BandwidthSampler) OnPacketAcked(ackTime time.Time, lastAckedPacket congestion.PacketNumber) *BandwidthSample {
-	sentPacketState := s.connectionStats.Get(lastAckedPacket)
-	if sentPacketState == nil {
-		return NewBandwidthSample()
-	}
-
-	sample := s.onPacketAckedInner(ackTime, lastAckedPacket, sentPacketState)
-	s.connectionStats.Remove(lastAckedPacket)
-
-	return sample
-}
-
-// onPacketAckedInner Handles the actual bandwidth calculations, whereas the outer method handles
-// retrieving and removing |sentPacket|.
-func (s *BandwidthSampler) onPacketAckedInner(ackTime time.Time, lastAckedPacket congestion.PacketNumber, sentPacket *ConnectionStateOnSentPacket) *BandwidthSample {
-	s.totalBytesAcked += sentPacket.size
-
-	s.totalBytesSentAtLastAckedPacket = sentPacket.sendTimeState.totalBytesSent
-	s.lastAckedPacketSentTime = sentPacket.sendTime
-	s.lastAckedPacketAckTime = ackTime
-
-	// Exit app-limited phase once a packet that was sent while the connection is
-	// not app-limited is acknowledged.
-	if s.isAppLimited && lastAckedPacket > s.endOfAppLimitedPhase {
-		s.isAppLimited = false
-	}
-
-	// There might have been no packets acknowledged at the moment when the
-	// current packet was sent. In that case, there is no bandwidth sample to
-	// make.
-	if sentPacket.lastAckedPacketSentTime.IsZero() {
-		return NewBandwidthSample()
-	}
-
-	// Infinite rate indicates that the sampler is supposed to discard the
-	// current send rate sample and use only the ack rate.
-	sendRate := InfiniteBandwidth
-	if sentPacket.sendTime.After(sentPacket.lastAckedPacketSentTime) {
-		sendRate = BandwidthFromDelta(sentPacket.sendTimeState.totalBytesSent-sentPacket.totalBytesSentAtLastAckedPacket, sentPacket.sendTime.Sub(sentPacket.lastAckedPacketSentTime))
-	}
-
-	// During the slope calculation, ensure that ack time of the current packet is
-	// always larger than the time of the previous packet, otherwise division by
-	// zero or integer underflow can occur.
-	if !ackTime.After(sentPacket.lastAckedPacketAckTime) {
-		// TODO(wub): Compare this code count before and after fixing clock jitter
-		// issue.
-		// if sentPacket.lastAckedPacketAckTime.Equal(sentPacket.sendTime) {
-		// This is the 1st packet after quiescense.
-		// QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 1, 2);
-		// } else {
-		//   QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 2, 2);
-		// }
-
-		return NewBandwidthSample()
-	}
-
-	ackRate := BandwidthFromDelta(s.totalBytesAcked-sentPacket.sendTimeState.totalBytesAcked,
-		ackTime.Sub(sentPacket.lastAckedPacketAckTime))
-
-	// Note: this sample does not account for delayed acknowledgement time.  This
-	// means that the RTT measurements here can be artificially high, especially
-	// on low bandwidth connections.
-	sample := &BandwidthSample{
-		bandwidth: minBandwidth(sendRate, ackRate),
-		rtt:       ackTime.Sub(sentPacket.sendTime),
-	}
-
-	SentPacketToSendTimeState(sentPacket, &sample.stateAtSend)
-	return sample
-}
-
-// OnPacketLost Informs the sampler that a packet is considered lost and it should no
-// longer keep track of it.
-func (s *BandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber) SendTimeState {
-	ok, sentPacket := s.connectionStats.Remove(packetNumber)
-	sendTimeState := SendTimeState{
-		isValid: ok,
-	}
-	if sentPacket != nil {
-		s.totalBytesLost += sentPacket.size
-		SentPacketToSendTimeState(sentPacket, &sendTimeState)
-	}
-
-	return sendTimeState
-}
-
-// OnAppLimited Informs the sampler that the connection is currently app-limited, causing
-// the sampler to enter the app-limited phase.  The phase will expire by
-// itself.
-func (s *BandwidthSampler) OnAppLimited() {
-	s.isAppLimited = true
-	s.endOfAppLimitedPhase = s.lastSendPacket
-}
-
-// SentPacketToSendTimeState Copy a subset of the (private) ConnectionStateOnSentPacket to the (public)
-// SendTimeState. Always set send_time_state->is_valid to true.
-func SentPacketToSendTimeState(sentPacket *ConnectionStateOnSentPacket, sendTimeState *SendTimeState) {
-	sendTimeState.isAppLimited = sentPacket.sendTimeState.isAppLimited
-	sendTimeState.totalBytesSent = sentPacket.sendTimeState.totalBytesSent
-	sendTimeState.totalBytesAcked = sentPacket.sendTimeState.totalBytesAcked
-	sendTimeState.totalBytesLost = sentPacket.sendTimeState.totalBytesLost
-	sendTimeState.isValid = true
-}
-
-// ConnectionStates Record of the connection state at the point where each packet in flight was
-// sent, indexed by the packet number.
-// FIXME: using LinkedList replace map to fast remove all the packets lower than the specified packet number.
-type ConnectionStates struct {
-	stats map[congestion.PacketNumber]*ConnectionStateOnSentPacket
-}
-
-func (s *ConnectionStates) Insert(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) bool {
-	if _, ok := s.stats[packetNumber]; ok {
-		return false
-	}
-
-	s.stats[packetNumber] = NewConnectionStateOnSentPacket(packetNumber, sentTime, bytes, sampler)
-	return true
-}
-
-func (s *ConnectionStates) Get(packetNumber congestion.PacketNumber) *ConnectionStateOnSentPacket {
-	return s.stats[packetNumber]
-}
-
-func (s *ConnectionStates) Remove(packetNumber congestion.PacketNumber) (bool, *ConnectionStateOnSentPacket) {
-	state, ok := s.stats[packetNumber]
-	if ok {
-		delete(s.stats, packetNumber)
-	}
-	return ok, state
-}
-
-func NewConnectionStateOnSentPacket(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) *ConnectionStateOnSentPacket {
-	return &ConnectionStateOnSentPacket{
-		packetNumber:                    packetNumber,
-		sendTime:                        sentTime,
-		size:                            bytes,
-		lastAckedPacketSentTime:         sampler.lastAckedPacketSentTime,
-		lastAckedPacketAckTime:          sampler.lastAckedPacketAckTime,
-		totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket,
-		sendTimeState: SendTimeState{
-			isValid:         true,
-			isAppLimited:    sampler.isAppLimited,
-			totalBytesSent:  sampler.totalBytesSent,
-			totalBytesAcked: sampler.totalBytesAcked,
-			totalBytesLost:  sampler.totalBytesLost,
-		},
-	}
-}

+ 0 - 1000
transport/tuic/congestion/bbr_sender.go

@@ -1,1000 +0,0 @@
-package congestion
-
-// src from https://quiche.googlesource.com/quiche.git/+/66dea072431f94095dfc3dd2743cb94ef365f7ef/quic/core/congestion_control/bbr_sender.cc
-
-import (
-	"fmt"
-	"math"
-	"math/rand"
-	"net"
-	"time"
-
-	"github.com/sagernet/quic-go/congestion"
-)
-
-const (
-	// InitialMaxDatagramSize is the default maximum packet size used in QUIC for congestion window computations in bytes.
-	InitialMaxDatagramSize        = 1252
-	InitialPacketSizeIPv4         = 1252
-	InitialPacketSizeIPv6         = 1232
-	InitialCongestionWindow       = 32
-	DefaultBBRMaxCongestionWindow = 10000
-)
-
-func GetInitialPacketSize(addr net.Addr) congestion.ByteCount {
-	maxSize := congestion.ByteCount(1200)
-	// If this is not a UDP address, we don't know anything about the MTU.
-	// Use the minimum size of an Initial packet as the max packet size.
-	if udpAddr, ok := addr.(*net.UDPAddr); ok {
-		if udpAddr.IP.To4() != nil {
-			maxSize = InitialPacketSizeIPv4
-		} else {
-			maxSize = InitialPacketSizeIPv6
-		}
-	}
-	return congestion.ByteCount(maxSize)
-}
-
-var (
-
-	// Default initial rtt used before any samples are received.
-	InitialRtt = 100 * time.Millisecond
-
-	// The gain used for the STARTUP, equal to  4*ln(2).
-	DefaultHighGain = 2.77
-
-	// The gain used in STARTUP after loss has been detected.
-	// 1.5 is enough to allow for 25% exogenous loss and still observe a 25% growth
-	// in measured bandwidth.
-	StartupAfterLossGain = 1.5
-
-	// The cycle of gains used during the PROBE_BW stage.
-	PacingGain = []float64{1.25, 0.75, 1, 1, 1, 1, 1, 1}
-
-	// The length of the gain cycle.
-	GainCycleLength = len(PacingGain)
-
-	// The size of the bandwidth filter window, in round-trips.
-	BandwidthWindowSize = GainCycleLength + 2
-
-	// The time after which the current min_rtt value expires.
-	MinRttExpiry = 10 * time.Second
-
-	// The minimum time the connection can spend in PROBE_RTT mode.
-	ProbeRttTime = time.Millisecond * 200
-
-	// If the bandwidth does not increase by the factor of |kStartupGrowthTarget|
-	// within |kRoundTripsWithoutGrowthBeforeExitingStartup| rounds, the connection
-	// will exit the STARTUP mode.
-	StartupGrowthTarget                         = 1.25
-	RoundTripsWithoutGrowthBeforeExitingStartup = int64(3)
-
-	// Coefficient of target congestion window to use when basing PROBE_RTT on BDP.
-	ModerateProbeRttMultiplier = 0.75
-
-	// Coefficient to determine if a new RTT is sufficiently similar to min_rtt that
-	// we don't need to enter PROBE_RTT.
-	SimilarMinRttThreshold = 1.125
-
-	// Congestion window gain for QUIC BBR during PROBE_BW phase.
-	DefaultCongestionWindowGainConst = 2.0
-)
-
-type bbrMode int
-
-const (
-	// Startup phase of the connection.
-	STARTUP = iota
-	// After achieving the highest possible bandwidth during the startup, lower
-	// the pacing rate in order to drain the queue.
-	DRAIN
-	// Cruising mode.
-	PROBE_BW
-	// Temporarily slow down sending in order to empty the buffer and measure
-	// the real minimum RTT.
-	PROBE_RTT
-)
-
-type bbrRecoveryState int
-
-const (
-	// Do not limit.
-	NOT_IN_RECOVERY = iota
-
-	// Allow an extra outstanding byte for each byte acknowledged.
-	CONSERVATION
-
-	// Allow two extra outstanding bytes for each byte acknowledged (slow
-	// start).
-	GROWTH
-)
-
-type bbrSender struct {
-	mode          bbrMode
-	clock         Clock
-	rttStats      congestion.RTTStatsProvider
-	bytesInFlight congestion.ByteCount
-	// return total bytes of unacked packets.
-	// GetBytesInFlight func() congestion.ByteCount
-	// Bandwidth sampler provides BBR with the bandwidth measurements at
-	// individual points.
-	sampler *BandwidthSampler
-	// The number of the round trips that have occurred during the connection.
-	roundTripCount int64
-	// The packet number of the most recently sent packet.
-	lastSendPacket congestion.PacketNumber
-	// Acknowledgement of any packet after |current_round_trip_end_| will cause
-	// the round trip counter to advance.
-	currentRoundTripEnd congestion.PacketNumber
-	// The filter that tracks the maximum bandwidth over the multiple recent
-	// round-trips.
-	maxBandwidth *WindowedFilter
-	// Tracks the maximum number of bytes acked faster than the sending rate.
-	maxAckHeight *WindowedFilter
-	// The time this aggregation started and the number of bytes acked during it.
-	aggregationEpochStartTime time.Time
-	aggregationEpochBytes     congestion.ByteCount
-	// Minimum RTT estimate.  Automatically expires within 10 seconds (and
-	// triggers PROBE_RTT mode) if no new value is sampled during that period.
-	minRtt time.Duration
-	// The time at which the current value of |min_rtt_| was assigned.
-	minRttTimestamp time.Time
-	// The maximum allowed number of bytes in flight.
-	congestionWindow congestion.ByteCount
-	// The initial value of the |congestion_window_|.
-	initialCongestionWindow congestion.ByteCount
-	// The largest value the |congestion_window_| can achieve.
-	initialMaxCongestionWindow congestion.ByteCount
-	// The smallest value the |congestion_window_| can achieve.
-	// minCongestionWindow congestion.ByteCount
-	// The pacing gain applied during the STARTUP phase.
-	highGain float64
-	// The CWND gain applied during the STARTUP phase.
-	highCwndGain float64
-	// The pacing gain applied during the DRAIN phase.
-	drainGain float64
-	// The current pacing rate of the connection.
-	pacingRate Bandwidth
-	// The gain currently applied to the pacing rate.
-	pacingGain float64
-	// The gain currently applied to the congestion window.
-	congestionWindowGain float64
-	// The gain used for the congestion window during PROBE_BW.  Latched from
-	// quic_bbr_cwnd_gain flag.
-	congestionWindowGainConst float64
-	// The number of RTTs to stay in STARTUP mode.  Defaults to 3.
-	numStartupRtts int64
-	// If true, exit startup if 1RTT has passed with no bandwidth increase and
-	// the connection is in recovery.
-	exitStartupOnLoss bool
-	// Number of round-trips in PROBE_BW mode, used for determining the current
-	// pacing gain cycle.
-	cycleCurrentOffset int
-	// The time at which the last pacing gain cycle was started.
-	lastCycleStart time.Time
-	// Indicates whether the connection has reached the full bandwidth mode.
-	isAtFullBandwidth bool
-	// Number of rounds during which there was no significant bandwidth increase.
-	roundsWithoutBandwidthGain int64
-	// The bandwidth compared to which the increase is measured.
-	bandwidthAtLastRound Bandwidth
-	// Set to true upon exiting quiescence.
-	exitingQuiescence bool
-	// Time at which PROBE_RTT has to be exited.  Setting it to zero indicates
-	// that the time is yet unknown as the number of packets in flight has not
-	// reached the required value.
-	exitProbeRttAt time.Time
-	// Indicates whether a round-trip has passed since PROBE_RTT became active.
-	probeRttRoundPassed bool
-	// Indicates whether the most recent bandwidth sample was marked as
-	// app-limited.
-	lastSampleIsAppLimited bool
-	// Indicates whether any non app-limited samples have been recorded.
-	hasNoAppLimitedSample bool
-	// Indicates app-limited calls should be ignored as long as there's
-	// enough data inflight to see more bandwidth when necessary.
-	flexibleAppLimited bool
-	// Current state of recovery.
-	recoveryState bbrRecoveryState
-	// Receiving acknowledgement of a packet after |end_recovery_at_| will cause
-	// BBR to exit the recovery mode.  A value above zero indicates at least one
-	// loss has been detected, so it must not be set back to zero.
-	endRecoveryAt congestion.PacketNumber
-	// A window used to limit the number of bytes in flight during loss recovery.
-	recoveryWindow congestion.ByteCount
-	// If true, consider all samples in recovery app-limited.
-	isAppLimitedRecovery bool
-	// When true, pace at 1.5x and disable packet conservation in STARTUP.
-	slowerStartup bool
-	// When true, disables packet conservation in STARTUP.
-	rateBasedStartup bool
-	// When non-zero, decreases the rate in STARTUP by the total number of bytes
-	// lost in STARTUP divided by CWND.
-	startupRateReductionMultiplier int64
-	// Sum of bytes lost in STARTUP.
-	startupBytesLost congestion.ByteCount
-	// When true, add the most recent ack aggregation measurement during STARTUP.
-	enableAckAggregationDuringStartup bool
-	// When true, expire the windowed ack aggregation values in STARTUP when
-	// bandwidth increases more than 25%.
-	expireAckAggregationInStartup bool
-	// If true, will not exit low gain mode until bytes_in_flight drops below BDP
-	// or it's time for high gain mode.
-	drainToTarget bool
-	// If true, use a CWND of 0.75*BDP during probe_rtt instead of 4 packets.
-	probeRttBasedOnBdp bool
-	// If true, skip probe_rtt and update the timestamp of the existing min_rtt to
-	// now if min_rtt over the last cycle is within 12.5% of the current min_rtt.
-	// Even if the min_rtt is 12.5% too low, the 25% gain cycling and 2x CWND gain
-	// should overcome an overly small min_rtt.
-	probeRttSkippedIfSimilarRtt bool
-	// If true, disable PROBE_RTT entirely as long as the connection was recently
-	// app limited.
-	probeRttDisabledIfAppLimited bool
-	appLimitedSinceLastProbeRtt  bool
-	minRttSinceLastProbeRtt      time.Duration
-	// Latched value of --quic_always_get_bw_sample_when_acked.
-	alwaysGetBwSampleWhenAcked bool
-
-	pacer *pacer
-
-	maxDatagramSize congestion.ByteCount
-}
-
-func NewBBRSender(
-	clock Clock,
-	initialMaxDatagramSize,
-	initialCongestionWindow,
-	initialMaxCongestionWindow congestion.ByteCount,
-) *bbrSender {
-	b := &bbrSender{
-		mode:                      STARTUP,
-		clock:                     clock,
-		sampler:                   NewBandwidthSampler(),
-		maxBandwidth:              NewWindowedFilter(int64(BandwidthWindowSize), MaxFilter),
-		maxAckHeight:              NewWindowedFilter(int64(BandwidthWindowSize), MaxFilter),
-		congestionWindow:          initialCongestionWindow,
-		initialCongestionWindow:   initialCongestionWindow,
-		highGain:                  DefaultHighGain,
-		highCwndGain:              DefaultHighGain,
-		drainGain:                 1.0 / DefaultHighGain,
-		pacingGain:                1.0,
-		congestionWindowGain:      1.0,
-		congestionWindowGainConst: DefaultCongestionWindowGainConst,
-		numStartupRtts:            RoundTripsWithoutGrowthBeforeExitingStartup,
-		recoveryState:             NOT_IN_RECOVERY,
-		recoveryWindow:            initialMaxCongestionWindow,
-		minRttSinceLastProbeRtt:   InfiniteRTT,
-		maxDatagramSize:           initialMaxDatagramSize,
-	}
-	b.pacer = newPacer(b.BandwidthEstimate)
-	return b
-}
-
-func (b *bbrSender) maxCongestionWindow() congestion.ByteCount {
-	return b.maxDatagramSize * DefaultBBRMaxCongestionWindow
-}
-
-func (b *bbrSender) minCongestionWindow() congestion.ByteCount {
-	return b.maxDatagramSize * b.initialCongestionWindow
-}
-
-func (b *bbrSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
-	b.rttStats = provider
-}
-
-func (b *bbrSender) GetBytesInFlight() congestion.ByteCount {
-	return b.bytesInFlight
-}
-
-// TimeUntilSend returns when the next packet should be sent.
-func (b *bbrSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
-	b.bytesInFlight = bytesInFlight
-	return b.pacer.TimeUntilSend()
-}
-
-func (b *bbrSender) HasPacingBudget(now time.Time) bool {
-	return b.pacer.Budget(now) >= b.maxDatagramSize
-}
-
-func (b *bbrSender) SetMaxDatagramSize(s congestion.ByteCount) {
-	if s < b.maxDatagramSize {
-		panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", b.maxDatagramSize, s))
-	}
-	cwndIsMinCwnd := b.congestionWindow == b.minCongestionWindow()
-	b.maxDatagramSize = s
-	if cwndIsMinCwnd {
-		b.congestionWindow = b.minCongestionWindow()
-	}
-	b.pacer.SetMaxDatagramSize(s)
-}
-
-func (b *bbrSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool) {
-	b.pacer.SentPacket(sentTime, bytes)
-	b.lastSendPacket = packetNumber
-
-	b.bytesInFlight = bytesInFlight
-	if bytesInFlight == 0 && b.sampler.isAppLimited {
-		b.exitingQuiescence = true
-	}
-
-	if b.aggregationEpochStartTime.IsZero() {
-		b.aggregationEpochStartTime = sentTime
-	}
-
-	b.sampler.OnPacketSent(sentTime, packetNumber, bytes, bytesInFlight, isRetransmittable)
-}
-
-func (b *bbrSender) CanSend(bytesInFlight congestion.ByteCount) bool {
-	b.bytesInFlight = bytesInFlight
-	return bytesInFlight < b.GetCongestionWindow()
-}
-
-func (b *bbrSender) GetCongestionWindow() congestion.ByteCount {
-	if b.mode == PROBE_RTT {
-		return b.ProbeRttCongestionWindow()
-	}
-
-	if b.InRecovery() && !(b.rateBasedStartup && b.mode == STARTUP) {
-		return minByteCount(b.congestionWindow, b.recoveryWindow)
-	}
-
-	return b.congestionWindow
-}
-
-func (b *bbrSender) MaybeExitSlowStart() {
-}
-
-func (b *bbrSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, priorInFlight congestion.ByteCount, eventTime time.Time) {
-	totalBytesAckedBefore := b.sampler.totalBytesAcked
-	isRoundStart, minRttExpired := false, false
-	lastAckedPacket := number
-
-	isRoundStart = b.UpdateRoundTripCounter(lastAckedPacket)
-	minRttExpired = b.UpdateBandwidthAndMinRtt(eventTime, number, ackedBytes)
-	b.UpdateRecoveryState(false, isRoundStart)
-	bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore
-	excessAcked := b.UpdateAckAggregationBytes(eventTime, bytesAcked)
-
-	// Handle logic specific to STARTUP and DRAIN modes.
-	if isRoundStart && !b.isAtFullBandwidth {
-		b.CheckIfFullBandwidthReached()
-	}
-	b.MaybeExitStartupOrDrain(eventTime)
-
-	// Handle logic specific to PROBE_RTT.
-	b.MaybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired)
-
-	// After the model is updated, recalculate the pacing rate and congestion
-	// window.
-	b.CalculatePacingRate()
-	b.CalculateCongestionWindow(bytesAcked, excessAcked)
-	b.CalculateRecoveryWindow(bytesAcked, congestion.ByteCount(0))
-}
-
-func (b *bbrSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount, priorInFlight congestion.ByteCount) {
-	eventTime := time.Now()
-	totalBytesAckedBefore := b.sampler.totalBytesAcked
-	isRoundStart, minRttExpired := false, false
-
-	b.DiscardLostPackets(number, lostBytes)
-
-	// Input the new data into the BBR model of the connection.
-	var excessAcked congestion.ByteCount
-
-	// Handle logic specific to PROBE_BW mode.
-	if b.mode == PROBE_BW {
-		b.UpdateGainCyclePhase(time.Now(), priorInFlight, true)
-	}
-
-	// Handle logic specific to STARTUP and DRAIN modes.
-	b.MaybeExitStartupOrDrain(eventTime)
-
-	// Handle logic specific to PROBE_RTT.
-	b.MaybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired)
-
-	// Calculate number of packets acked and lost.
-	bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore
-	bytesLost := lostBytes
-
-	// After the model is updated, recalculate the pacing rate and congestion
-	// window.
-	b.CalculatePacingRate()
-	b.CalculateCongestionWindow(bytesAcked, excessAcked)
-	b.CalculateRecoveryWindow(bytesAcked, bytesLost)
-}
-
-//func (b *bbrSender) OnCongestionEvent(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets, lostPackets []*congestion.Packet) {
-//	totalBytesAckedBefore := b.sampler.totalBytesAcked
-//	isRoundStart, minRttExpired := false, false
-//
-//	if lostPackets != nil {
-//		b.DiscardLostPackets(lostPackets)
-//	}
-//
-//	// Input the new data into the BBR model of the connection.
-//	var excessAcked congestion.ByteCount
-//	if len(ackedPackets) > 0 {
-//		lastAckedPacket := ackedPackets[len(ackedPackets)-1].PacketNumber
-//		isRoundStart = b.UpdateRoundTripCounter(lastAckedPacket)
-//		minRttExpired = b.UpdateBandwidthAndMinRtt(eventTime, ackedPackets)
-//		b.UpdateRecoveryState(lastAckedPacket, len(lostPackets) > 0, isRoundStart)
-//		bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore
-//		excessAcked = b.UpdateAckAggregationBytes(eventTime, bytesAcked)
-//	}
-//
-//	// Handle logic specific to PROBE_BW mode.
-//	if b.mode == PROBE_BW {
-//		b.UpdateGainCyclePhase(eventTime, priorInFlight, len(lostPackets) > 0)
-//	}
-//
-//	// Handle logic specific to STARTUP and DRAIN modes.
-//	if isRoundStart && !b.isAtFullBandwidth {
-//		b.CheckIfFullBandwidthReached()
-//	}
-//	b.MaybeExitStartupOrDrain(eventTime)
-//
-//	// Handle logic specific to PROBE_RTT.
-//	b.MaybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired)
-//
-//	// Calculate number of packets acked and lost.
-//	bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore
-//	bytesLost := congestion.ByteCount(0)
-//	for _, packet := range lostPackets {
-//		bytesLost += packet.Length
-//	}
-//
-//	// After the model is updated, recalculate the pacing rate and congestion
-//	// window.
-//	b.CalculatePacingRate()
-//	b.CalculateCongestionWindow(bytesAcked, excessAcked)
-//	b.CalculateRecoveryWindow(bytesAcked, bytesLost)
-//}
-
-//func (b *bbrSender) SetNumEmulatedConnections(n int) {
-//
-//}
-
-func (b *bbrSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
-}
-
-//func (b *bbrSender) OnConnectionMigration() {
-//
-//}
-
-//// Experiments
-//func (b *bbrSender) SetSlowStartLargeReduction(enabled bool) {
-//
-//}
-
-//func (b *bbrSender) BandwidthEstimate() Bandwidth {
-//	return Bandwidth(b.maxBandwidth.GetBest())
-//}
-
-// BandwidthEstimate returns the current bandwidth estimate
-func (b *bbrSender) BandwidthEstimate() Bandwidth {
-	if b.rttStats == nil {
-		return infBandwidth
-	}
-	srtt := b.rttStats.SmoothedRTT()
-	if srtt == 0 {
-		// If we haven't measured an rtt, the bandwidth estimate is unknown.
-		return infBandwidth
-	}
-	return BandwidthFromDelta(b.GetCongestionWindow(), srtt)
-}
-
-//func (b *bbrSender) HybridSlowStart() *HybridSlowStart {
-//	return nil
-//}
-
-//func (b *bbrSender) SlowstartThreshold() congestion.ByteCount {
-//	return 0
-//}
-
-//func (b *bbrSender) RenoBeta() float32 {
-//	return 0.0
-//}
-
-func (b *bbrSender) InRecovery() bool {
-	return b.recoveryState != NOT_IN_RECOVERY
-}
-
-func (b *bbrSender) InSlowStart() bool {
-	return b.mode == STARTUP
-}
-
-//func (b *bbrSender) ShouldSendProbingPacket() bool {
-//	if b.pacingGain <= 1 {
-//		return false
-//	}
-//	// TODO(b/77975811): If the pipe is highly under-utilized, consider not
-//	// sending a probing transmission, because the extra bandwidth is not needed.
-//	// If flexible_app_limited is enabled, check if the pipe is sufficiently full.
-//	if b.flexibleAppLimited {
-//		return !b.IsPipeSufficientlyFull()
-//	} else {
-//		return true
-//	}
-//}
-
-//func (b *bbrSender) IsPipeSufficientlyFull() bool {
-//	// See if we need more bytes in flight to see more bandwidth.
-//	if b.mode == STARTUP {
-//		// STARTUP exits if it doesn't observe a 25% bandwidth increase, so the CWND
-//		// must be more than 25% above the target.
-//		return b.GetBytesInFlight() >= b.GetTargetCongestionWindow(1.5)
-//	}
-//	if b.pacingGain > 1 {
-//		// Super-unity PROBE_BW doesn't exit until 1.25 * BDP is achieved.
-//		return b.GetBytesInFlight() >= b.GetTargetCongestionWindow(b.pacingGain)
-//	}
-//	// If bytes_in_flight are above the target congestion window, it should be
-//	// possible to observe the same or more bandwidth if it's available.
-//	return b.GetBytesInFlight() >= b.GetTargetCongestionWindow(1.1)
-//}
-
-//func (b *bbrSender) SetFromConfig() {
-//	// TODO: not impl.
-//}
-
-func (b *bbrSender) UpdateRoundTripCounter(lastAckedPacket congestion.PacketNumber) bool {
-	if b.currentRoundTripEnd == 0 || lastAckedPacket > b.currentRoundTripEnd {
-		b.currentRoundTripEnd = lastAckedPacket
-		b.roundTripCount++
-		// if b.rttStats != nil && b.InSlowStart() {
-		// TODO: ++stats_->slowstart_num_rtts;
-		// }
-		return true
-	}
-	return false
-}
-
-func (b *bbrSender) UpdateBandwidthAndMinRtt(now time.Time, number congestion.PacketNumber, ackedBytes congestion.ByteCount) bool {
-	sampleMinRtt := InfiniteRTT
-
-	if !b.alwaysGetBwSampleWhenAcked && ackedBytes == 0 {
-		// Skip acked packets with 0 in flight bytes when updating bandwidth.
-		return false
-	}
-	bandwidthSample := b.sampler.OnPacketAcked(now, number)
-	if b.alwaysGetBwSampleWhenAcked && !bandwidthSample.stateAtSend.isValid {
-		// From the sampler's perspective, the packet has never been sent, or the
-		// packet has been acked or marked as lost previously.
-		return false
-	}
-	b.lastSampleIsAppLimited = bandwidthSample.stateAtSend.isAppLimited
-	//     has_non_app_limited_sample_ |=
-	//        !bandwidth_sample.state_at_send.is_app_limited;
-	if !bandwidthSample.stateAtSend.isAppLimited {
-		b.hasNoAppLimitedSample = true
-	}
-	if bandwidthSample.rtt > 0 {
-		sampleMinRtt = minRtt(sampleMinRtt, bandwidthSample.rtt)
-	}
-	if !bandwidthSample.stateAtSend.isAppLimited || bandwidthSample.bandwidth > b.BandwidthEstimate() {
-		b.maxBandwidth.Update(int64(bandwidthSample.bandwidth), b.roundTripCount)
-	}
-
-	// If none of the RTT samples are valid, return immediately.
-	if sampleMinRtt == InfiniteRTT {
-		return false
-	}
-
-	b.minRttSinceLastProbeRtt = minRtt(b.minRttSinceLastProbeRtt, sampleMinRtt)
-	// Do not expire min_rtt if none was ever available.
-	minRttExpired := b.minRtt > 0 && (now.After(b.minRttTimestamp.Add(MinRttExpiry)))
-	if minRttExpired || sampleMinRtt < b.minRtt || b.minRtt == 0 {
-		if minRttExpired && b.ShouldExtendMinRttExpiry() {
-			minRttExpired = false
-		} else {
-			b.minRtt = sampleMinRtt
-		}
-		b.minRttTimestamp = now
-		// Reset since_last_probe_rtt fields.
-		b.minRttSinceLastProbeRtt = InfiniteRTT
-		b.appLimitedSinceLastProbeRtt = false
-	}
-
-	return minRttExpired
-}
-
-func (b *bbrSender) ShouldExtendMinRttExpiry() bool {
-	if b.probeRttDisabledIfAppLimited && b.appLimitedSinceLastProbeRtt {
-		// Extend the current min_rtt if we've been app limited recently.
-		return true
-	}
-
-	minRttIncreasedSinceLastProbe := b.minRttSinceLastProbeRtt > time.Duration(float64(b.minRtt)*SimilarMinRttThreshold)
-	if b.probeRttSkippedIfSimilarRtt && b.appLimitedSinceLastProbeRtt && !minRttIncreasedSinceLastProbe {
-		// Extend the current min_rtt if we've been app limited recently and an rtt
-		// has been measured in that time that's less than 12.5% more than the
-		// current min_rtt.
-		return true
-	}
-
-	return false
-}
-
-func (b *bbrSender) DiscardLostPackets(number congestion.PacketNumber, lostBytes congestion.ByteCount) {
-	b.sampler.OnPacketLost(number)
-	if b.mode == STARTUP {
-		// if b.rttStats != nil {
-		// TODO: slow start.
-		// }
-		if b.startupRateReductionMultiplier != 0 {
-			b.startupBytesLost += lostBytes
-		}
-	}
-}
-
-func (b *bbrSender) UpdateRecoveryState(hasLosses, isRoundStart bool) {
-	// Exit recovery when there are no losses for a round.
-	if !hasLosses {
-		b.endRecoveryAt = b.lastSendPacket
-	}
-	switch b.recoveryState {
-	case NOT_IN_RECOVERY:
-		// Enter conservation on the first loss.
-		if hasLosses {
-			b.recoveryState = CONSERVATION
-			// This will cause the |recovery_window_| to be set to the correct
-			// value in CalculateRecoveryWindow().
-			b.recoveryWindow = 0
-			// Since the conservation phase is meant to be lasting for a whole
-			// round, extend the current round as if it were started right now.
-			b.currentRoundTripEnd = b.lastSendPacket
-			if false && b.lastSampleIsAppLimited {
-				b.isAppLimitedRecovery = true
-			}
-		}
-	case CONSERVATION:
-		if isRoundStart {
-			b.recoveryState = GROWTH
-		}
-		fallthrough
-	case GROWTH:
-		// Exit recovery if appropriate.
-		if !hasLosses && b.lastSendPacket > b.endRecoveryAt {
-			b.recoveryState = NOT_IN_RECOVERY
-			b.isAppLimitedRecovery = false
-		}
-	}
-
-	if b.recoveryState != NOT_IN_RECOVERY && b.isAppLimitedRecovery {
-		b.sampler.OnAppLimited()
-	}
-}
-
-func (b *bbrSender) UpdateAckAggregationBytes(ackTime time.Time, ackedBytes congestion.ByteCount) congestion.ByteCount {
-	// Compute how many bytes are expected to be delivered, assuming max bandwidth
-	// is correct.
-	expectedAckedBytes := congestion.ByteCount(b.maxBandwidth.GetBest()) *
-		congestion.ByteCount((ackTime.Sub(b.aggregationEpochStartTime)))
-	// Reset the current aggregation epoch as soon as the ack arrival rate is less
-	// than or equal to the max bandwidth.
-	if b.aggregationEpochBytes <= expectedAckedBytes {
-		// Reset to start measuring a new aggregation epoch.
-		b.aggregationEpochBytes = ackedBytes
-		b.aggregationEpochStartTime = ackTime
-		return 0
-	}
-	// Compute how many extra bytes were delivered vs max bandwidth.
-	// Include the bytes most recently acknowledged to account for stretch acks.
-	b.aggregationEpochBytes += ackedBytes
-	b.maxAckHeight.Update(int64(b.aggregationEpochBytes-expectedAckedBytes), b.roundTripCount)
-	return b.aggregationEpochBytes - expectedAckedBytes
-}
-
-func (b *bbrSender) UpdateGainCyclePhase(now time.Time, priorInFlight congestion.ByteCount, hasLosses bool) {
-	bytesInFlight := b.GetBytesInFlight()
-	// In most cases, the cycle is advanced after an RTT passes.
-	shouldAdvanceGainCycling := now.Sub(b.lastCycleStart) > b.GetMinRtt()
-
-	// If the pacing gain is above 1.0, the connection is trying to probe the
-	// bandwidth by increasing the number of bytes in flight to at least
-	// pacing_gain * BDP.  Make sure that it actually reaches the target, as long
-	// as there are no losses suggesting that the buffers are not able to hold
-	// that much.
-	if b.pacingGain > 1.0 && !hasLosses && priorInFlight < b.GetTargetCongestionWindow(b.pacingGain) {
-		shouldAdvanceGainCycling = false
-	}
-	// If pacing gain is below 1.0, the connection is trying to drain the extra
-	// queue which could have been incurred by probing prior to it.  If the number
-	// of bytes in flight falls down to the estimated BDP value earlier, conclude
-	// that the queue has been successfully drained and exit this cycle early.
-	if b.pacingGain < 1.0 && bytesInFlight <= b.GetTargetCongestionWindow(1.0) {
-		shouldAdvanceGainCycling = true
-	}
-
-	if shouldAdvanceGainCycling {
-		b.cycleCurrentOffset = (b.cycleCurrentOffset + 1) % GainCycleLength
-		b.lastCycleStart = now
-		// Stay in low gain mode until the target BDP is hit.
-		// Low gain mode will be exited immediately when the target BDP is achieved.
-		if b.drainToTarget && b.pacingGain < 1.0 && PacingGain[b.cycleCurrentOffset] == 1.0 &&
-			bytesInFlight > b.GetTargetCongestionWindow(1.0) {
-			return
-		}
-		b.pacingGain = PacingGain[b.cycleCurrentOffset]
-	}
-}
-
-func (b *bbrSender) GetTargetCongestionWindow(gain float64) congestion.ByteCount {
-	bdp := congestion.ByteCount(b.GetMinRtt()) * congestion.ByteCount(b.BandwidthEstimate())
-	congestionWindow := congestion.ByteCount(gain * float64(bdp))
-
-	// BDP estimate will be zero if no bandwidth samples are available yet.
-	if congestionWindow == 0 {
-		congestionWindow = congestion.ByteCount(gain * float64(b.initialCongestionWindow))
-	}
-
-	return maxByteCount(congestionWindow, b.minCongestionWindow())
-}
-
-func (b *bbrSender) CheckIfFullBandwidthReached() {
-	if b.lastSampleIsAppLimited {
-		return
-	}
-
-	target := Bandwidth(float64(b.bandwidthAtLastRound) * StartupGrowthTarget)
-	if b.BandwidthEstimate() >= target {
-		b.bandwidthAtLastRound = b.BandwidthEstimate()
-		b.roundsWithoutBandwidthGain = 0
-		if b.expireAckAggregationInStartup {
-			// Expire old excess delivery measurements now that bandwidth increased.
-			b.maxAckHeight.Reset(0, b.roundTripCount)
-		}
-		return
-	}
-	b.roundsWithoutBandwidthGain++
-	if b.roundsWithoutBandwidthGain >= b.numStartupRtts || (b.exitStartupOnLoss && b.InRecovery()) {
-		b.isAtFullBandwidth = true
-	}
-}
-
-func (b *bbrSender) MaybeExitStartupOrDrain(now time.Time) {
-	if b.mode == STARTUP && b.isAtFullBandwidth {
-		b.OnExitStartup(now)
-		b.mode = DRAIN
-		b.pacingGain = b.drainGain
-		b.congestionWindowGain = b.highCwndGain
-	}
-	if b.mode == DRAIN && b.GetBytesInFlight() <= b.GetTargetCongestionWindow(1) {
-		b.EnterProbeBandwidthMode(now)
-	}
-}
-
-func (b *bbrSender) EnterProbeBandwidthMode(now time.Time) {
-	b.mode = PROBE_BW
-	b.congestionWindowGain = b.congestionWindowGainConst
-
-	// Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is
-	// excluded because in that case increased gain and decreased gain would not
-	// follow each other.
-	b.cycleCurrentOffset = rand.Int() % (GainCycleLength - 1)
-	if b.cycleCurrentOffset >= 1 {
-		b.cycleCurrentOffset += 1
-	}
-
-	b.lastCycleStart = now
-	b.pacingGain = PacingGain[b.cycleCurrentOffset]
-}
-
-func (b *bbrSender) MaybeEnterOrExitProbeRtt(now time.Time, isRoundStart, minRttExpired bool) {
-	if minRttExpired && !b.exitingQuiescence && b.mode != PROBE_RTT {
-		if b.InSlowStart() {
-			b.OnExitStartup(now)
-		}
-		b.mode = PROBE_RTT
-		b.pacingGain = 1.0
-		// Do not decide on the time to exit PROBE_RTT until the |bytes_in_flight|
-		// is at the target small value.
-		b.exitProbeRttAt = time.Time{}
-	}
-
-	if b.mode == PROBE_RTT {
-		b.sampler.OnAppLimited()
-		if b.exitProbeRttAt.IsZero() {
-			// If the window has reached the appropriate size, schedule exiting
-			// PROBE_RTT.  The CWND during PROBE_RTT is kMinimumCongestionWindow, but
-			// we allow an extra packet since QUIC checks CWND before sending a
-			// packet.
-			if b.GetBytesInFlight() < b.ProbeRttCongestionWindow()+b.maxDatagramSize {
-				b.exitProbeRttAt = now.Add(ProbeRttTime)
-				b.probeRttRoundPassed = false
-			}
-		} else {
-			if isRoundStart {
-				b.probeRttRoundPassed = true
-			}
-			if !now.Before(b.exitProbeRttAt) && b.probeRttRoundPassed {
-				b.minRttTimestamp = now
-				if !b.isAtFullBandwidth {
-					b.EnterStartupMode(now)
-				} else {
-					b.EnterProbeBandwidthMode(now)
-				}
-			}
-		}
-	}
-	b.exitingQuiescence = false
-}
-
-func (b *bbrSender) ProbeRttCongestionWindow() congestion.ByteCount {
-	if b.probeRttBasedOnBdp {
-		return b.GetTargetCongestionWindow(ModerateProbeRttMultiplier)
-	} else {
-		return b.minCongestionWindow()
-	}
-}
-
-func (b *bbrSender) EnterStartupMode(now time.Time) {
-	// if b.rttStats != nil {
-	// TODO: slow start.
-	// }
-	b.mode = STARTUP
-	b.pacingGain = b.highGain
-	b.congestionWindowGain = b.highCwndGain
-}
-
-func (b *bbrSender) OnExitStartup(now time.Time) {
-	if b.rttStats == nil {
-		return
-	}
-	// TODO: slow start.
-}
-
-func (b *bbrSender) CalculatePacingRate() {
-	if b.BandwidthEstimate() == 0 {
-		return
-	}
-
-	targetRate := Bandwidth(b.pacingGain * float64(b.BandwidthEstimate()))
-	if b.isAtFullBandwidth {
-		b.pacingRate = targetRate
-		return
-	}
-
-	// Pace at the rate of initial_window / RTT as soon as RTT measurements are
-	// available.
-	if b.pacingRate == 0 && b.rttStats.MinRTT() > 0 {
-		b.pacingRate = BandwidthFromDelta(b.initialCongestionWindow, b.rttStats.MinRTT())
-		return
-	}
-	// Slow the pacing rate in STARTUP once loss has ever been detected.
-	hasEverDetectedLoss := b.endRecoveryAt > 0
-	if b.slowerStartup && hasEverDetectedLoss && b.hasNoAppLimitedSample {
-		b.pacingRate = Bandwidth(StartupAfterLossGain * float64(b.BandwidthEstimate()))
-		return
-	}
-
-	// Slow the pacing rate in STARTUP by the bytes_lost / CWND.
-	if b.startupRateReductionMultiplier != 0 && hasEverDetectedLoss && b.hasNoAppLimitedSample {
-		b.pacingRate = Bandwidth((1.0 - (float64(b.startupBytesLost) * float64(b.startupRateReductionMultiplier) / float64(b.congestionWindow))) * float64(targetRate))
-		// Ensure the pacing rate doesn't drop below the startup growth target times
-		// the bandwidth estimate.
-		b.pacingRate = maxBandwidth(b.pacingRate, Bandwidth(StartupGrowthTarget*float64(b.BandwidthEstimate())))
-		return
-	}
-
-	// Do not decrease the pacing rate during startup.
-	b.pacingRate = maxBandwidth(b.pacingRate, targetRate)
-}
-
-func (b *bbrSender) CalculateCongestionWindow(ackedBytes, excessAcked congestion.ByteCount) {
-	if b.mode == PROBE_RTT {
-		return
-	}
-
-	targetWindow := b.GetTargetCongestionWindow(b.congestionWindowGain)
-	if b.isAtFullBandwidth {
-		// Add the max recently measured ack aggregation to CWND.
-		targetWindow += congestion.ByteCount(b.maxAckHeight.GetBest())
-	} else if b.enableAckAggregationDuringStartup {
-		// Add the most recent excess acked.  Because CWND never decreases in
-		// STARTUP, this will automatically create a very localized max filter.
-		targetWindow += excessAcked
-	}
-
-	// Instead of immediately setting the target CWND as the new one, BBR grows
-	// the CWND towards |target_window| by only increasing it |bytes_acked| at a
-	// time.
-	addBytesAcked := true || !b.InRecovery()
-	if b.isAtFullBandwidth {
-		b.congestionWindow = minByteCount(targetWindow, b.congestionWindow+ackedBytes)
-	} else if addBytesAcked && (b.congestionWindow < targetWindow || b.sampler.totalBytesAcked < b.initialCongestionWindow) {
-		// If the connection is not yet out of startup phase, do not decrease the
-		// window.
-		b.congestionWindow += ackedBytes
-	}
-
-	// Enforce the limits on the congestion window.
-	b.congestionWindow = maxByteCount(b.congestionWindow, b.minCongestionWindow())
-	b.congestionWindow = minByteCount(b.congestionWindow, b.maxCongestionWindow())
-}
-
-func (b *bbrSender) CalculateRecoveryWindow(ackedBytes, lostBytes congestion.ByteCount) {
-	if b.rateBasedStartup && b.mode == STARTUP {
-		return
-	}
-
-	if b.recoveryState == NOT_IN_RECOVERY {
-		return
-	}
-
-	// Set up the initial recovery window.
-	if b.recoveryWindow == 0 {
-		b.recoveryWindow = maxByteCount(b.GetBytesInFlight()+ackedBytes, b.minCongestionWindow())
-		return
-	}
-
-	// Remove losses from the recovery window, while accounting for a potential
-	// integer underflow.
-	if b.recoveryWindow >= lostBytes {
-		b.recoveryWindow -= lostBytes
-	} else {
-		b.recoveryWindow = congestion.ByteCount(b.maxDatagramSize)
-	}
-	// In CONSERVATION mode, just subtracting losses is sufficient.  In GROWTH,
-	// release additional |bytes_acked| to achieve a slow-start-like behavior.
-	if b.recoveryState == GROWTH {
-		b.recoveryWindow += ackedBytes
-	}
-	// Sanity checks.  Ensure that we always allow to send at least an MSS or
-	// |bytes_acked| in response, whichever is larger.
-	b.recoveryWindow = maxByteCount(b.recoveryWindow, b.GetBytesInFlight()+ackedBytes)
-	b.recoveryWindow = maxByteCount(b.recoveryWindow, b.minCongestionWindow())
-}
-
-var _ congestion.CongestionControl = (*bbrSender)(nil)
-
-func (b *bbrSender) GetMinRtt() time.Duration {
-	if b.minRtt > 0 {
-		return b.minRtt
-	} else {
-		return InitialRtt
-	}
-}
-
-func minRtt(a, b time.Duration) time.Duration {
-	if a < b {
-		return a
-	} else {
-		return b
-	}
-}
-
-func minBandwidth(a, b Bandwidth) Bandwidth {
-	if a < b {
-		return a
-	} else {
-		return b
-	}
-}
-
-func maxBandwidth(a, b Bandwidth) Bandwidth {
-	if a > b {
-		return a
-	} else {
-		return b
-	}
-}
-
-func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
-	if a > b {
-		return a
-	} else {
-		return b
-	}
-}
-
-func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
-	if a < b {
-		return a
-	} else {
-		return b
-	}
-}
-
-var InfiniteRTT = time.Duration(math.MaxInt64)

+ 0 - 20
transport/tuic/congestion/clock.go

@@ -1,20 +0,0 @@
-package congestion
-
-import "time"
-
-// A Clock returns the current time
-type Clock interface {
-	Now() time.Time
-}
-
-// DefaultClock implements the Clock interface using the Go stdlib clock.
-type DefaultClock struct {
-	TimeFunc func() time.Time
-}
-
-var _ Clock = DefaultClock{}
-
-// Now gets the current time
-func (c DefaultClock) Now() time.Time {
-	return c.TimeFunc()
-}

+ 0 - 213
transport/tuic/congestion/cubic.go

@@ -1,213 +0,0 @@
-package congestion
-
-import (
-	"math"
-	"time"
-
-	"github.com/sagernet/quic-go/congestion"
-)
-
-// This cubic implementation is based on the one found in Chromiums's QUIC
-// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
-
-// Constants based on TCP defaults.
-// The following constants are in 2^10 fractions of a second instead of ms to
-// allow a 10 shift right to divide.
-
-// 1024*1024^3 (first 1024 is from 0.100^3)
-// where 0.100 is 100 ms which is the scaling round trip time.
-const (
-	cubeScale                                      = 40
-	cubeCongestionWindowScale                      = 410
-	cubeFactor                congestion.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
-	// TODO: when re-enabling cubic, make sure to use the actual packet size here
-	maxDatagramSize = congestion.ByteCount(InitialPacketSizeIPv4)
-)
-
-const defaultNumConnections = 1
-
-// Default Cubic backoff factor
-const beta float32 = 0.7
-
-// Additional backoff factor when loss occurs in the concave part of the Cubic
-// curve. This additional backoff factor is expected to give up bandwidth to
-// new concurrent flows and speed up convergence.
-const betaLastMax float32 = 0.85
-
-// Cubic implements the cubic algorithm from TCP
-type Cubic struct {
-	clock Clock
-
-	// Number of connections to simulate.
-	numConnections int
-
-	// Time when this cycle started, after last loss event.
-	epoch time.Time
-
-	// Max congestion window used just before last loss event.
-	// Note: to improve fairness to other streams an additional back off is
-	// applied to this value if the new value is below our latest value.
-	lastMaxCongestionWindow congestion.ByteCount
-
-	// Number of acked bytes since the cycle started (epoch).
-	ackedBytesCount congestion.ByteCount
-
-	// TCP Reno equivalent congestion window in packets.
-	estimatedTCPcongestionWindow congestion.ByteCount
-
-	// Origin point of cubic function.
-	originPointCongestionWindow congestion.ByteCount
-
-	// Time to origin point of cubic function in 2^10 fractions of a second.
-	timeToOriginPoint uint32
-
-	// Last congestion window in packets computed by cubic function.
-	lastTargetCongestionWindow congestion.ByteCount
-}
-
-// NewCubic returns a new Cubic instance
-func NewCubic(clock Clock) *Cubic {
-	c := &Cubic{
-		clock:          clock,
-		numConnections: defaultNumConnections,
-	}
-	c.Reset()
-	return c
-}
-
-// Reset is called after a timeout to reset the cubic state
-func (c *Cubic) Reset() {
-	c.epoch = time.Time{}
-	c.lastMaxCongestionWindow = 0
-	c.ackedBytesCount = 0
-	c.estimatedTCPcongestionWindow = 0
-	c.originPointCongestionWindow = 0
-	c.timeToOriginPoint = 0
-	c.lastTargetCongestionWindow = 0
-}
-
-func (c *Cubic) alpha() float32 {
-	// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
-	// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
-	// We derive the equivalent alpha for an N-connection emulation as:
-	b := c.beta()
-	return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
-}
-
-func (c *Cubic) beta() float32 {
-	// kNConnectionBeta is the backoff factor after loss for our N-connection
-	// emulation, which emulates the effective backoff of an ensemble of N
-	// TCP-Reno connections on a single loss event. The effective multiplier is
-	// computed as:
-	return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
-}
-
-func (c *Cubic) betaLastMax() float32 {
-	// betaLastMax is the additional backoff factor after loss for our
-	// N-connection emulation, which emulates the additional backoff of
-	// an ensemble of N TCP-Reno connections on a single loss event. The
-	// effective multiplier is computed as:
-	return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
-}
-
-// OnApplicationLimited is called on ack arrival when sender is unable to use
-// the available congestion window. Resets Cubic state during quiescence.
-func (c *Cubic) OnApplicationLimited() {
-	// When sender is not using the available congestion window, the window does
-	// not grow. But to be RTT-independent, Cubic assumes that the sender has been
-	// using the entire window during the time since the beginning of the current
-	// "epoch" (the end of the last loss recovery period). Since
-	// application-limited periods break this assumption, we reset the epoch when
-	// in such a period. This reset effectively freezes congestion window growth
-	// through application-limited periods and allows Cubic growth to continue
-	// when the entire window is being used.
-	c.epoch = time.Time{}
-}
-
-// CongestionWindowAfterPacketLoss computes a new congestion window to use after
-// a loss event. Returns the new congestion window in packets. The new
-// congestion window is a multiplicative decrease of our current window.
-func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow congestion.ByteCount) congestion.ByteCount {
-	if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow {
-		// We never reached the old max, so assume we are competing with another
-		// flow. Use our extra back off factor to allow the other flow to go up.
-		c.lastMaxCongestionWindow = congestion.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
-	} else {
-		c.lastMaxCongestionWindow = currentCongestionWindow
-	}
-	c.epoch = time.Time{} // Reset time.
-	return congestion.ByteCount(float32(currentCongestionWindow) * c.beta())
-}
-
-// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
-// Returns the new congestion window in packets. The new congestion window
-// follows a cubic function that depends on the time passed since last
-// packet loss.
-func (c *Cubic) CongestionWindowAfterAck(
-	ackedBytes congestion.ByteCount,
-	currentCongestionWindow congestion.ByteCount,
-	delayMin time.Duration,
-	eventTime time.Time,
-) congestion.ByteCount {
-	c.ackedBytesCount += ackedBytes
-
-	if c.epoch.IsZero() {
-		// First ACK after a loss event.
-		c.epoch = eventTime            // Start of epoch.
-		c.ackedBytesCount = ackedBytes // Reset count.
-		// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
-		c.estimatedTCPcongestionWindow = currentCongestionWindow
-		if c.lastMaxCongestionWindow <= currentCongestionWindow {
-			c.timeToOriginPoint = 0
-			c.originPointCongestionWindow = currentCongestionWindow
-		} else {
-			c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
-			c.originPointCongestionWindow = c.lastMaxCongestionWindow
-		}
-	}
-
-	// Change the time unit from microseconds to 2^10 fractions per second. Take
-	// the round trip time in account. This is done to allow us to use shift as a
-	// divide operator.
-	elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
-
-	// Right-shifts of negative, signed numbers have implementation-dependent
-	// behavior, so force the offset to be positive, as is done in the kernel.
-	offset := int64(c.timeToOriginPoint) - elapsedTime
-	if offset < 0 {
-		offset = -offset
-	}
-
-	deltaCongestionWindow := congestion.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
-	var targetCongestionWindow congestion.ByteCount
-	if elapsedTime > int64(c.timeToOriginPoint) {
-		targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
-	} else {
-		targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
-	}
-	// Limit the CWND increase to half the acked bytes.
-	targetCongestionWindow = Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
-
-	// Increase the window by approximately Alpha * 1 MSS of bytes every
-	// time we ack an estimated tcp window of bytes.  For small
-	// congestion windows (less than 25), the formula below will
-	// increase slightly slower than linearly per estimated tcp window
-	// of bytes.
-	c.estimatedTCPcongestionWindow += congestion.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
-	c.ackedBytesCount = 0
-
-	// We have a new cubic congestion window.
-	c.lastTargetCongestionWindow = targetCongestionWindow
-
-	// Compute target congestion_window based on cubic target and estimated TCP
-	// congestion_window, use highest (fastest).
-	if targetCongestionWindow < c.estimatedTCPcongestionWindow {
-		targetCongestionWindow = c.estimatedTCPcongestionWindow
-	}
-	return targetCongestionWindow
-}
-
-// SetNumConnections sets the number of emulated connections
-func (c *Cubic) SetNumConnections(n int) {
-	c.numConnections = n
-}

+ 0 - 318
transport/tuic/congestion/cubic_sender.go

@@ -1,318 +0,0 @@
-package congestion
-
-import (
-	"fmt"
-	"time"
-
-	"github.com/sagernet/quic-go/congestion"
-	"github.com/sagernet/quic-go/logging"
-)
-
-const (
-	maxBurstPackets            = 3
-	renoBeta                   = 0.7 // Reno backoff factor.
-	minCongestionWindowPackets = 2
-	initialCongestionWindow    = 32
-)
-
-const (
-	InvalidPacketNumber        congestion.PacketNumber = -1
-	MaxCongestionWindowPackets                         = 20000
-	MaxByteCount                                       = congestion.ByteCount(1<<62 - 1)
-)
-
-type cubicSender struct {
-	hybridSlowStart HybridSlowStart
-	rttStats        congestion.RTTStatsProvider
-	cubic           *Cubic
-	pacer           *pacer
-	clock           Clock
-
-	reno bool
-
-	// Track the largest packet that has been sent.
-	largestSentPacketNumber congestion.PacketNumber
-
-	// Track the largest packet that has been acked.
-	largestAckedPacketNumber congestion.PacketNumber
-
-	// Track the largest packet number outstanding when a CWND cutback occurs.
-	largestSentAtLastCutback congestion.PacketNumber
-
-	// Whether the last loss event caused us to exit slowstart.
-	// Used for stats collection of slowstartPacketsLost
-	lastCutbackExitedSlowstart bool
-
-	// Congestion window in bytes.
-	congestionWindow congestion.ByteCount
-
-	// Slow start congestion window in bytes, aka ssthresh.
-	slowStartThreshold congestion.ByteCount
-
-	// ACK counter for the Reno implementation.
-	numAckedPackets uint64
-
-	initialCongestionWindow    congestion.ByteCount
-	initialMaxCongestionWindow congestion.ByteCount
-
-	maxDatagramSize congestion.ByteCount
-
-	lastState logging.CongestionState
-	tracer    logging.ConnectionTracer
-}
-
-var _ congestion.CongestionControl = &cubicSender{}
-
-// NewCubicSender makes a new cubic sender
-func NewCubicSender(
-	clock Clock,
-	initialMaxDatagramSize congestion.ByteCount,
-	reno bool,
-	tracer logging.ConnectionTracer,
-) *cubicSender {
-	return newCubicSender(
-		clock,
-		reno,
-		initialMaxDatagramSize,
-		initialCongestionWindow*initialMaxDatagramSize,
-		MaxCongestionWindowPackets*initialMaxDatagramSize,
-		tracer,
-	)
-}
-
-func newCubicSender(
-	clock Clock,
-	reno bool,
-	initialMaxDatagramSize,
-	initialCongestionWindow,
-	initialMaxCongestionWindow congestion.ByteCount,
-	tracer logging.ConnectionTracer,
-) *cubicSender {
-	c := &cubicSender{
-		largestSentPacketNumber:    InvalidPacketNumber,
-		largestAckedPacketNumber:   InvalidPacketNumber,
-		largestSentAtLastCutback:   InvalidPacketNumber,
-		initialCongestionWindow:    initialCongestionWindow,
-		initialMaxCongestionWindow: initialMaxCongestionWindow,
-		congestionWindow:           initialCongestionWindow,
-		slowStartThreshold:         MaxByteCount,
-		cubic:                      NewCubic(clock),
-		clock:                      clock,
-		reno:                       reno,
-		tracer:                     tracer,
-		maxDatagramSize:            initialMaxDatagramSize,
-	}
-	c.pacer = newPacer(c.BandwidthEstimate)
-	if c.tracer != nil {
-		c.lastState = logging.CongestionStateSlowStart
-		c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
-	}
-	return c
-}
-
-func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
-	c.rttStats = provider
-}
-
-// TimeUntilSend returns when the next packet should be sent.
-func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time {
-	return c.pacer.TimeUntilSend()
-}
-
-func (c *cubicSender) HasPacingBudget(now time.Time) bool {
-	return c.pacer.Budget(now) >= c.maxDatagramSize
-}
-
-func (c *cubicSender) maxCongestionWindow() congestion.ByteCount {
-	return c.maxDatagramSize * MaxCongestionWindowPackets
-}
-
-func (c *cubicSender) minCongestionWindow() congestion.ByteCount {
-	return c.maxDatagramSize * minCongestionWindowPackets
-}
-
-func (c *cubicSender) OnPacketSent(
-	sentTime time.Time,
-	_ congestion.ByteCount,
-	packetNumber congestion.PacketNumber,
-	bytes congestion.ByteCount,
-	isRetransmittable bool,
-) {
-	c.pacer.SentPacket(sentTime, bytes)
-	if !isRetransmittable {
-		return
-	}
-	c.largestSentPacketNumber = packetNumber
-	c.hybridSlowStart.OnPacketSent(packetNumber)
-}
-
-func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool {
-	return bytesInFlight < c.GetCongestionWindow()
-}
-
-func (c *cubicSender) InRecovery() bool {
-	return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
-}
-
-func (c *cubicSender) InSlowStart() bool {
-	return c.GetCongestionWindow() < c.slowStartThreshold
-}
-
-func (c *cubicSender) GetCongestionWindow() congestion.ByteCount {
-	return c.congestionWindow
-}
-
-func (c *cubicSender) MaybeExitSlowStart() {
-	if c.InSlowStart() &&
-		c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
-		// exit slow start
-		c.slowStartThreshold = c.congestionWindow
-		c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
-	}
-}
-
-func (c *cubicSender) OnPacketAcked(
-	ackedPacketNumber congestion.PacketNumber,
-	ackedBytes congestion.ByteCount,
-	priorInFlight congestion.ByteCount,
-	eventTime time.Time,
-) {
-	c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber)
-	if c.InRecovery() {
-		return
-	}
-	c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
-	if c.InSlowStart() {
-		c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
-	}
-}
-
-func (c *cubicSender) OnPacketLost(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
-	// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
-	// already sent should be treated as a single loss event, since it's expected.
-	if packetNumber <= c.largestSentAtLastCutback {
-		return
-	}
-	c.lastCutbackExitedSlowstart = c.InSlowStart()
-	c.maybeTraceStateChange(logging.CongestionStateRecovery)
-
-	if c.reno {
-		c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta)
-	} else {
-		c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
-	}
-	if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
-		c.congestionWindow = minCwnd
-	}
-	c.slowStartThreshold = c.congestionWindow
-	c.largestSentAtLastCutback = c.largestSentPacketNumber
-	// reset packet count from congestion avoidance mode. We start
-	// counting again when we're out of recovery.
-	c.numAckedPackets = 0
-}
-
-// Called when we receive an ack. Normal TCP tracks how many packets one ack
-// represents, but quic has a separate ack for each packet.
-func (c *cubicSender) maybeIncreaseCwnd(
-	_ congestion.PacketNumber,
-	ackedBytes congestion.ByteCount,
-	priorInFlight congestion.ByteCount,
-	eventTime time.Time,
-) {
-	// Do not increase the congestion window unless the sender is close to using
-	// the current window.
-	if !c.isCwndLimited(priorInFlight) {
-		c.cubic.OnApplicationLimited()
-		c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
-		return
-	}
-	if c.congestionWindow >= c.maxCongestionWindow() {
-		return
-	}
-	if c.InSlowStart() {
-		// TCP slow start, exponential growth, increase by one for each ACK.
-		c.congestionWindow += c.maxDatagramSize
-		c.maybeTraceStateChange(logging.CongestionStateSlowStart)
-		return
-	}
-	// Congestion avoidance
-	c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
-	if c.reno {
-		// Classic Reno congestion avoidance.
-		c.numAckedPackets++
-		if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
-			c.congestionWindow += c.maxDatagramSize
-			c.numAckedPackets = 0
-		}
-	} else {
-		c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
-	}
-}
-
-func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool {
-	congestionWindow := c.GetCongestionWindow()
-	if bytesInFlight >= congestionWindow {
-		return true
-	}
-	availableBytes := congestionWindow - bytesInFlight
-	slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
-	return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
-}
-
-// BandwidthEstimate returns the current bandwidth estimate
-func (c *cubicSender) BandwidthEstimate() Bandwidth {
-	if c.rttStats == nil {
-		return infBandwidth
-	}
-	srtt := c.rttStats.SmoothedRTT()
-	if srtt == 0 {
-		// If we haven't measured an rtt, the bandwidth estimate is unknown.
-		return infBandwidth
-	}
-	return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
-}
-
-// OnRetransmissionTimeout is called on an retransmission timeout
-func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
-	c.largestSentAtLastCutback = InvalidPacketNumber
-	if !packetsRetransmitted {
-		return
-	}
-	c.hybridSlowStart.Restart()
-	c.cubic.Reset()
-	c.slowStartThreshold = c.congestionWindow / 2
-	c.congestionWindow = c.minCongestionWindow()
-}
-
-// OnConnectionMigration is called when the connection is migrated (?)
-func (c *cubicSender) OnConnectionMigration() {
-	c.hybridSlowStart.Restart()
-	c.largestSentPacketNumber = InvalidPacketNumber
-	c.largestAckedPacketNumber = InvalidPacketNumber
-	c.largestSentAtLastCutback = InvalidPacketNumber
-	c.lastCutbackExitedSlowstart = false
-	c.cubic.Reset()
-	c.numAckedPackets = 0
-	c.congestionWindow = c.initialCongestionWindow
-	c.slowStartThreshold = c.initialMaxCongestionWindow
-}
-
-func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
-	if c.tracer == nil || new == c.lastState {
-		return
-	}
-	c.tracer.UpdatedCongestionState(new)
-	c.lastState = new
-}
-
-func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) {
-	if s < c.maxDatagramSize {
-		panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
-	}
-	cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
-	c.maxDatagramSize = s
-	if cwndIsMinCwnd {
-		c.congestionWindow = c.minCongestionWindow()
-	}
-	c.pacer.SetMaxDatagramSize(s)
-}

+ 0 - 112
transport/tuic/congestion/hybrid_slow_start.go

@@ -1,112 +0,0 @@
-package congestion
-
-import (
-	"time"
-
-	"github.com/sagernet/quic-go/congestion"
-)
-
-// Note(pwestin): the magic clamping numbers come from the original code in
-// tcp_cubic.c.
-const hybridStartLowWindow = congestion.ByteCount(16)
-
-// Number of delay samples for detecting the increase of delay.
-const hybridStartMinSamples = uint32(8)
-
-// Exit slow start if the min rtt has increased by more than 1/8th.
-const hybridStartDelayFactorExp = 3 // 2^3 = 8
-// The original paper specifies 2 and 8ms, but those have changed over time.
-const (
-	hybridStartDelayMinThresholdUs = int64(4000)
-	hybridStartDelayMaxThresholdUs = int64(16000)
-)
-
-// HybridSlowStart implements the TCP hybrid slow start algorithm
-type HybridSlowStart struct {
-	endPacketNumber      congestion.PacketNumber
-	lastSentPacketNumber congestion.PacketNumber
-	started              bool
-	currentMinRTT        time.Duration
-	rttSampleCount       uint32
-	hystartFound         bool
-}
-
-// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
-func (s *HybridSlowStart) StartReceiveRound(lastSent congestion.PacketNumber) {
-	s.endPacketNumber = lastSent
-	s.currentMinRTT = 0
-	s.rttSampleCount = 0
-	s.started = true
-}
-
-// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
-func (s *HybridSlowStart) IsEndOfRound(ack congestion.PacketNumber) bool {
-	return s.endPacketNumber < ack
-}
-
-// ShouldExitSlowStart should be called on every new ack frame, since a new
-// RTT measurement can be made then.
-// rtt: the RTT for this ack packet.
-// minRTT: is the lowest delay (RTT) we have seen during the session.
-// congestionWindow: the congestion window in packets.
-func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow congestion.ByteCount) bool {
-	if !s.started {
-		// Time to start the hybrid slow start.
-		s.StartReceiveRound(s.lastSentPacketNumber)
-	}
-	if s.hystartFound {
-		return true
-	}
-	// Second detection parameter - delay increase detection.
-	// Compare the minimum delay (s.currentMinRTT) of the current
-	// burst of packets relative to the minimum delay during the session.
-	// Note: we only look at the first few(8) packets in each burst, since we
-	// only want to compare the lowest RTT of the burst relative to previous
-	// bursts.
-	s.rttSampleCount++
-	if s.rttSampleCount <= hybridStartMinSamples {
-		if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
-			s.currentMinRTT = latestRTT
-		}
-	}
-	// We only need to check this once per round.
-	if s.rttSampleCount == hybridStartMinSamples {
-		// Divide minRTT by 8 to get a rtt increase threshold for exiting.
-		minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
-		// Ensure the rtt threshold is never less than 2ms or more than 16ms.
-		minRTTincreaseThresholdUs = Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
-		minRTTincreaseThreshold := time.Duration(Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
-
-		if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
-			s.hystartFound = true
-		}
-	}
-	// Exit from slow start if the cwnd is greater than 16 and
-	// increasing delay is found.
-	return congestionWindow >= hybridStartLowWindow && s.hystartFound
-}
-
-// OnPacketSent is called when a packet was sent
-func (s *HybridSlowStart) OnPacketSent(packetNumber congestion.PacketNumber) {
-	s.lastSentPacketNumber = packetNumber
-}
-
-// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
-// the round when the final packet of the burst is received and start it on
-// the next incoming ack.
-func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber congestion.PacketNumber) {
-	if s.IsEndOfRound(ackedPacketNumber) {
-		s.started = false
-	}
-}
-
-// Started returns true if started
-func (s *HybridSlowStart) Started() bool {
-	return s.started
-}
-
-// Restart the slow start phase
-func (s *HybridSlowStart) Restart() {
-	s.started = false
-	s.hystartFound = false
-}

+ 0 - 72
transport/tuic/congestion/minmax.go

@@ -1,72 +0,0 @@
-package congestion
-
-import (
-	"math"
-	"time"
-
-	"golang.org/x/exp/constraints"
-)
-
-// InfDuration is a duration of infinite length
-const InfDuration = time.Duration(math.MaxInt64)
-
-func Max[T constraints.Ordered](a, b T) T {
-	if a < b {
-		return b
-	}
-	return a
-}
-
-func Min[T constraints.Ordered](a, b T) T {
-	if a < b {
-		return a
-	}
-	return b
-}
-
-// MinNonZeroDuration return the minimum duration that's not zero.
-func MinNonZeroDuration(a, b time.Duration) time.Duration {
-	if a == 0 {
-		return b
-	}
-	if b == 0 {
-		return a
-	}
-	return Min(a, b)
-}
-
-// AbsDuration returns the absolute value of a time duration
-func AbsDuration(d time.Duration) time.Duration {
-	if d >= 0 {
-		return d
-	}
-	return -d
-}
-
-// MinTime returns the earlier time
-func MinTime(a, b time.Time) time.Time {
-	if a.After(b) {
-		return b
-	}
-	return a
-}
-
-// MinNonZeroTime returns the earlist time that is not time.Time{}
-// If both a and b are time.Time{}, it returns time.Time{}
-func MinNonZeroTime(a, b time.Time) time.Time {
-	if a.IsZero() {
-		return b
-	}
-	if b.IsZero() {
-		return a
-	}
-	return MinTime(a, b)
-}
-
-// MaxTime returns the later time
-func MaxTime(a, b time.Time) time.Time {
-	if a.After(b) {
-		return a
-	}
-	return b
-}

+ 0 - 81
transport/tuic/congestion/pacer.go

@@ -1,81 +0,0 @@
-package congestion
-
-import (
-	"math"
-	"time"
-
-	"github.com/sagernet/quic-go/congestion"
-)
-
-const (
-	initialMaxDatagramSize = congestion.ByteCount(1252)
-	MinPacingDelay         = time.Millisecond
-	TimerGranularity       = time.Millisecond
-	maxBurstSizePackets    = 10
-)
-
-// The pacer implements a token bucket pacing algorithm.
-type pacer struct {
-	budgetAtLastSent     congestion.ByteCount
-	maxDatagramSize      congestion.ByteCount
-	lastSentTime         time.Time
-	getAdjustedBandwidth func() uint64 // in bytes/s
-}
-
-func newPacer(getBandwidth func() Bandwidth) *pacer {
-	p := &pacer{
-		maxDatagramSize: initialMaxDatagramSize,
-		getAdjustedBandwidth: func() uint64 {
-			// Bandwidth is in bits/s. We need the value in bytes/s.
-			bw := uint64(getBandwidth() / BytesPerSecond)
-			// Use a slightly higher value than the actual measured bandwidth.
-			// RTT variations then won't result in under-utilization of the congestion window.
-			// Ultimately, this will  result in sending packets as acknowledgments are received rather than when timers fire,
-			// provided the congestion window is fully utilized and acknowledgments arrive at regular intervals.
-			return bw * 5 / 4
-		},
-	}
-	p.budgetAtLastSent = p.maxBurstSize()
-	return p
-}
-
-func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
-	budget := p.Budget(sendTime)
-	if size > budget {
-		p.budgetAtLastSent = 0
-	} else {
-		p.budgetAtLastSent = budget - size
-	}
-	p.lastSentTime = sendTime
-}
-
-func (p *pacer) Budget(now time.Time) congestion.ByteCount {
-	if p.lastSentTime.IsZero() {
-		return p.maxBurstSize()
-	}
-	budget := p.budgetAtLastSent + (congestion.ByteCount(p.getAdjustedBandwidth())*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
-	return Min(p.maxBurstSize(), budget)
-}
-
-func (p *pacer) maxBurstSize() congestion.ByteCount {
-	return Max(
-		congestion.ByteCount(uint64((MinPacingDelay+TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9,
-		maxBurstSizePackets*p.maxDatagramSize,
-	)
-}
-
-// TimeUntilSend returns when the next packet should be sent.
-// It returns the zero value of time.Time if a packet can be sent immediately.
-func (p *pacer) TimeUntilSend() time.Time {
-	if p.budgetAtLastSent >= p.maxDatagramSize {
-		return time.Time{}
-	}
-	return p.lastSentTime.Add(Max(
-		MinPacingDelay,
-		time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond,
-	))
-}
-
-func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
-	p.maxDatagramSize = s
-}

+ 0 - 132
transport/tuic/congestion/windowed_filter.go

@@ -1,132 +0,0 @@
-package congestion
-
-// WindowedFilter Use the following to construct a windowed filter object of type T.
-// For example, a min filter using QuicTime as the time type:
-//
-//	WindowedFilter<T, MinFilter<T>, QuicTime, QuicTime::Delta> ObjectName;
-//
-// A max filter using 64-bit integers as the time type:
-//
-//	WindowedFilter<T, MaxFilter<T>, uint64_t, int64_t> ObjectName;
-//
-// Specifically, this template takes four arguments:
-//  1. T -- type of the measurement that is being filtered.
-//  2. Compare -- MinFilter<T> or MaxFilter<T>, depending on the type of filter
-//     desired.
-//  3. TimeT -- the type used to represent timestamps.
-//  4. TimeDeltaT -- the type used to represent continuous time intervals between
-//     two timestamps.  Has to be the type of (a - b) if both |a| and |b| are
-//     of type TimeT.
-type WindowedFilter struct {
-	// Time length of window.
-	windowLength int64
-	estimates    []Sample
-	comparator   func(int64, int64) bool
-}
-
-type Sample struct {
-	sample int64
-	time   int64
-}
-
-// Compares two values and returns true if the first is greater than or equal
-// to the second.
-func MaxFilter(a, b int64) bool {
-	return a >= b
-}
-
-// Compares two values and returns true if the first is less than or equal
-// to the second.
-func MinFilter(a, b int64) bool {
-	return a <= b
-}
-
-func NewWindowedFilter(windowLength int64, comparator func(int64, int64) bool) *WindowedFilter {
-	return &WindowedFilter{
-		windowLength: windowLength,
-		estimates:    make([]Sample, 3),
-		comparator:   comparator,
-	}
-}
-
-// Changes the window length.  Does not update any current samples.
-func (f *WindowedFilter) SetWindowLength(windowLength int64) {
-	f.windowLength = windowLength
-}
-
-func (f *WindowedFilter) GetBest() int64 {
-	return f.estimates[0].sample
-}
-
-func (f *WindowedFilter) GetSecondBest() int64 {
-	return f.estimates[1].sample
-}
-
-func (f *WindowedFilter) GetThirdBest() int64 {
-	return f.estimates[2].sample
-}
-
-func (f *WindowedFilter) Update(sample int64, time int64) {
-	if f.estimates[0].time == 0 || f.comparator(sample, f.estimates[0].sample) || (time-f.estimates[2].time) > f.windowLength {
-		f.Reset(sample, time)
-		return
-	}
-
-	if f.comparator(sample, f.estimates[1].sample) {
-		f.estimates[1].sample = sample
-		f.estimates[1].time = time
-		f.estimates[2].sample = sample
-		f.estimates[2].time = time
-	} else if f.comparator(sample, f.estimates[2].sample) {
-		f.estimates[2].sample = sample
-		f.estimates[2].time = time
-	}
-
-	// Expire and update estimates as necessary.
-	if time-f.estimates[0].time > f.windowLength {
-		// The best estimate hasn't been updated for an entire window, so promote
-		// second and third best estimates.
-		f.estimates[0].sample = f.estimates[1].sample
-		f.estimates[0].time = f.estimates[1].time
-		f.estimates[1].sample = f.estimates[2].sample
-		f.estimates[1].time = f.estimates[2].time
-		f.estimates[2].sample = sample
-		f.estimates[2].time = time
-		// Need to iterate one more time. Check if the new best estimate is
-		// outside the window as well, since it may also have been recorded a
-		// long time ago. Don't need to iterate once more since we cover that
-		// case at the beginning of the method.
-		if time-f.estimates[0].time > f.windowLength {
-			f.estimates[0].sample = f.estimates[1].sample
-			f.estimates[0].time = f.estimates[1].time
-			f.estimates[1].sample = f.estimates[2].sample
-			f.estimates[1].time = f.estimates[2].time
-		}
-		return
-	}
-	if f.estimates[1].sample == f.estimates[0].sample && time-f.estimates[1].time > f.windowLength>>2 {
-		// A quarter of the window has passed without a better sample, so the
-		// second-best estimate is taken from the second quarter of the window.
-		f.estimates[1].sample = sample
-		f.estimates[1].time = time
-		f.estimates[2].sample = sample
-		f.estimates[2].time = time
-		return
-	}
-
-	if f.estimates[2].sample == f.estimates[1].sample && time-f.estimates[2].time > f.windowLength>>1 {
-		// We've passed a half of the window without a better estimate, so take
-		// a third-best estimate from the second half of the window.
-		f.estimates[2].sample = sample
-		f.estimates[2].time = time
-	}
-}
-
-func (f *WindowedFilter) Reset(newSample int64, newTime int64) {
-	f.estimates[0].sample = newSample
-	f.estimates[0].time = newTime
-	f.estimates[1].sample = newSample
-	f.estimates[1].time = newTime
-	f.estimates[2].sample = newSample
-	f.estimates[2].time = newTime
-}

+ 0 - 532
transport/tuic/packet.go

@@ -1,532 +0,0 @@
-package tuic
-
-import (
-	"bytes"
-	"context"
-	"encoding/binary"
-	"errors"
-	"io"
-	"math"
-	"net"
-	"os"
-	"sync"
-	"time"
-
-	"github.com/sagernet/quic-go"
-	"github.com/sagernet/sing/common"
-	"github.com/sagernet/sing/common/atomic"
-	"github.com/sagernet/sing/common/buf"
-	"github.com/sagernet/sing/common/cache"
-	E "github.com/sagernet/sing/common/exceptions"
-	M "github.com/sagernet/sing/common/metadata"
-)
-
-var udpMessagePool = sync.Pool{
-	New: func() interface{} {
-		return new(udpMessage)
-	},
-}
-
-func allocMessage() *udpMessage {
-	message := udpMessagePool.Get().(*udpMessage)
-	message.referenced = true
-	return message
-}
-
-func releaseMessages(messages []*udpMessage) {
-	for _, message := range messages {
-		if message != nil {
-			message.release()
-		}
-	}
-}
-
-type udpMessage struct {
-	sessionID     uint16
-	packetID      uint16
-	fragmentTotal uint8
-	fragmentID    uint8
-	destination   M.Socksaddr
-	data          *buf.Buffer
-	referenced    bool
-}
-
-func (m *udpMessage) release() {
-	if !m.referenced {
-		return
-	}
-	*m = udpMessage{}
-	udpMessagePool.Put(m)
-}
-
-func (m *udpMessage) releaseMessage() {
-	m.data.Release()
-	m.release()
-}
-
-func (m *udpMessage) pack() *buf.Buffer {
-	buffer := buf.NewSize(m.headerSize() + m.data.Len())
-	common.Must(
-		buffer.WriteByte(Version),
-		buffer.WriteByte(CommandPacket),
-		binary.Write(buffer, binary.BigEndian, m.sessionID),
-		binary.Write(buffer, binary.BigEndian, m.packetID),
-		binary.Write(buffer, binary.BigEndian, m.fragmentTotal),
-		binary.Write(buffer, binary.BigEndian, m.fragmentID),
-		binary.Write(buffer, binary.BigEndian, uint16(m.data.Len())),
-		addressSerializer.WriteAddrPort(buffer, m.destination),
-		common.Error(buffer.Write(m.data.Bytes())),
-	)
-	return buffer
-}
-
-func (m *udpMessage) headerSize() int {
-	return 10 + addressSerializer.AddrPortLen(m.destination)
-}
-
-func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
-	if message.data.Len() <= maxPacketSize {
-		return []*udpMessage{message}
-	}
-	var fragments []*udpMessage
-	originPacket := message.data.Bytes()
-	udpMTU := maxPacketSize - message.headerSize()
-	for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
-		fragment := allocMessage()
-		*fragment = *message
-		if remaining > udpMTU {
-			fragment.data = buf.As(originPacket[:udpMTU])
-			originPacket = originPacket[udpMTU:]
-		} else {
-			fragment.data = buf.As(originPacket)
-			originPacket = nil
-		}
-		fragments = append(fragments, fragment)
-	}
-	fragmentTotal := uint16(len(fragments))
-	for index, fragment := range fragments {
-		fragment.fragmentID = uint8(index)
-		fragment.fragmentTotal = uint8(fragmentTotal)
-		if index > 0 {
-			fragment.destination = M.Socksaddr{}
-		}
-	}
-	return fragments
-}
-
-type udpPacketConn struct {
-	ctx        context.Context
-	cancel     common.ContextCancelCauseFunc
-	sessionID  uint16
-	quicConn   quic.Connection
-	data       chan *udpMessage
-	udpStream  bool
-	udpMTU     int
-	udpMTUTime time.Time
-	packetId   atomic.Uint32
-	closeOnce  sync.Once
-	isServer   bool
-	defragger  *udpDefragger
-	onDestroy  func()
-}
-
-func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream bool, isServer bool, onDestroy func()) *udpPacketConn {
-	ctx, cancel := common.ContextWithCancelCause(ctx)
-	return &udpPacketConn{
-		ctx:       ctx,
-		cancel:    cancel,
-		quicConn:  quicConn,
-		data:      make(chan *udpMessage, 64),
-		udpStream: udpStream,
-		isServer:  isServer,
-		defragger: newUDPDefragger(),
-		onDestroy: onDestroy,
-	}
-}
-
-func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
-	select {
-	case p := <-c.data:
-		buffer = p.data
-		destination = p.destination
-		p.release()
-		return
-	case <-c.ctx.Done():
-		return nil, M.Socksaddr{}, io.ErrClosedPipe
-	}
-}
-
-func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
-	select {
-	case p := <-c.data:
-		_, err = buffer.ReadOnceFrom(p.data)
-		destination = p.destination
-		p.releaseMessage()
-		return
-	case <-c.ctx.Done():
-		return M.Socksaddr{}, io.ErrClosedPipe
-	}
-}
-
-func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
-	select {
-	case p := <-c.data:
-		_, err = newBuffer().ReadOnceFrom(p.data)
-		destination = p.destination
-		p.releaseMessage()
-		return
-	case <-c.ctx.Done():
-		return M.Socksaddr{}, io.ErrClosedPipe
-	}
-}
-
-func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
-	select {
-	case pkt := <-c.data:
-		n = copy(p, pkt.data.Bytes())
-		if pkt.destination.IsFqdn() {
-			addr = pkt.destination
-		} else {
-			addr = pkt.destination.UDPAddr()
-		}
-		pkt.releaseMessage()
-		return n, addr, nil
-	case <-c.ctx.Done():
-		return 0, nil, io.ErrClosedPipe
-	}
-}
-
-func (c *udpPacketConn) needFragment() bool {
-	nowTime := time.Now()
-	if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second {
-		c.udpMTUTime = nowTime
-		return true
-	}
-	return false
-}
-
-func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
-	defer buffer.Release()
-	select {
-	case <-c.ctx.Done():
-		return net.ErrClosed
-	default:
-	}
-	if buffer.Len() > 0xffff {
-		return quic.ErrMessageTooLarge(0xffff)
-	}
-	if !destination.IsValid() {
-		return E.New("invalid destination address")
-	}
-	packetId := c.packetId.Add(1)
-	if packetId > math.MaxUint16 {
-		c.packetId.Store(0)
-		packetId = 0
-	}
-	message := allocMessage()
-	*message = udpMessage{
-		sessionID:     c.sessionID,
-		packetID:      uint16(packetId),
-		fragmentTotal: 1,
-		destination:   destination,
-		data:          buffer,
-	}
-	defer message.releaseMessage()
-	var err error
-	if !c.udpStream && c.needFragment() && buffer.Len() > c.udpMTU {
-		err = c.writePackets(fragUDPMessage(message, c.udpMTU))
-	} else {
-		err = c.writePacket(message)
-	}
-	if err == nil {
-		return nil
-	}
-	var tooLargeErr quic.ErrMessageTooLarge
-	if !errors.As(err, &tooLargeErr) {
-		return err
-	}
-	c.udpMTU = int(tooLargeErr)
-	c.udpMTUTime = time.Now()
-	return c.writePackets(fragUDPMessage(message, c.udpMTU))
-}
-
-func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
-	select {
-	case <-c.ctx.Done():
-		return 0, net.ErrClosed
-	default:
-	}
-	if len(p) > 0xffff {
-		return 0, quic.ErrMessageTooLarge(0xffff)
-	}
-	destination := M.SocksaddrFromNet(addr)
-	if !destination.IsValid() {
-		return 0, E.New("invalid destination address")
-	}
-	packetId := c.packetId.Add(1)
-	if packetId > math.MaxUint16 {
-		c.packetId.Store(0)
-		packetId = 0
-	}
-	message := allocMessage()
-	*message = udpMessage{
-		sessionID:     c.sessionID,
-		packetID:      uint16(packetId),
-		fragmentTotal: 1,
-		destination:   destination,
-		data:          buf.As(p),
-	}
-	if !c.udpStream && c.needFragment() && len(p) > c.udpMTU {
-		err = c.writePackets(fragUDPMessage(message, c.udpMTU))
-		if err == nil {
-			return len(p), nil
-		}
-	} else {
-		err = c.writePacket(message)
-	}
-	if err == nil {
-		return len(p), nil
-	}
-	var tooLargeErr quic.ErrMessageTooLarge
-	if !errors.As(err, &tooLargeErr) {
-		return
-	}
-	c.udpMTU = int(tooLargeErr)
-	c.udpMTUTime = time.Now()
-	err = c.writePackets(fragUDPMessage(message, c.udpMTU))
-	if err == nil {
-		return len(p), nil
-	}
-	return
-}
-
-func (c *udpPacketConn) inputPacket(message *udpMessage) {
-	if message.fragmentTotal <= 1 {
-		select {
-		case c.data <- message:
-		default:
-		}
-	} else {
-		newMessage := c.defragger.feed(message)
-		if newMessage != nil {
-			select {
-			case c.data <- newMessage:
-			default:
-			}
-		}
-	}
-}
-
-func (c *udpPacketConn) writePackets(messages []*udpMessage) error {
-	defer releaseMessages(messages)
-	for _, message := range messages {
-		err := c.writePacket(message)
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-func (c *udpPacketConn) writePacket(message *udpMessage) error {
-	if !c.udpStream {
-		buffer := message.pack()
-		err := c.quicConn.SendMessage(buffer.Bytes())
-		buffer.Release()
-		if err != nil {
-			return err
-		}
-	} else {
-		stream, err := c.quicConn.OpenUniStream()
-		if err != nil {
-			return err
-		}
-		buffer := message.pack()
-		_, err = stream.Write(buffer.Bytes())
-		buffer.Release()
-		stream.Close()
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-func (c *udpPacketConn) Close() error {
-	c.closeOnce.Do(func() {
-		c.closeWithError(os.ErrClosed)
-		c.onDestroy()
-	})
-	return nil
-}
-
-func (c *udpPacketConn) closeWithError(err error) {
-	c.cancel(err)
-	if !c.isServer {
-		buffer := buf.NewSize(4)
-		defer buffer.Release()
-		buffer.WriteByte(Version)
-		buffer.WriteByte(CommandDissociate)
-		binary.Write(buffer, binary.BigEndian, c.sessionID)
-		sendStream, openErr := c.quicConn.OpenUniStream()
-		if openErr != nil {
-			return
-		}
-		defer sendStream.Close()
-		sendStream.Write(buffer.Bytes())
-	}
-}
-
-func (c *udpPacketConn) LocalAddr() net.Addr {
-	return c.quicConn.LocalAddr()
-}
-
-func (c *udpPacketConn) SetDeadline(t time.Time) error {
-	return os.ErrInvalid
-}
-
-func (c *udpPacketConn) SetReadDeadline(t time.Time) error {
-	return os.ErrInvalid
-}
-
-func (c *udpPacketConn) SetWriteDeadline(t time.Time) error {
-	return os.ErrInvalid
-}
-
-type udpDefragger struct {
-	packetMap *cache.LruCache[uint16, *packetItem]
-}
-
-func newUDPDefragger() *udpDefragger {
-	return &udpDefragger{
-		packetMap: cache.New(
-			cache.WithAge[uint16, *packetItem](10),
-			cache.WithUpdateAgeOnGet[uint16, *packetItem](),
-			cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) {
-				releaseMessages(value.messages)
-			}),
-		),
-	}
-}
-
-type packetItem struct {
-	access   sync.Mutex
-	messages []*udpMessage
-	count    uint8
-}
-
-func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
-	if m.fragmentTotal <= 1 {
-		return m
-	}
-	if m.fragmentID >= m.fragmentTotal {
-		return nil
-	}
-	item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem)
-	item.access.Lock()
-	defer item.access.Unlock()
-	if int(m.fragmentTotal) != len(item.messages) {
-		releaseMessages(item.messages)
-		item.messages = make([]*udpMessage, m.fragmentTotal)
-		item.count = 1
-		item.messages[m.fragmentID] = m
-		return nil
-	}
-	if item.messages[m.fragmentID] != nil {
-		return nil
-	}
-	item.messages[m.fragmentID] = m
-	item.count++
-	if int(item.count) != len(item.messages) {
-		return nil
-	}
-	newMessage := allocMessage()
-	*newMessage = *item.messages[0]
-	var dataLength uint16
-	for _, message := range item.messages {
-		dataLength += uint16(message.data.Len())
-	}
-	if dataLength > 0 {
-		newMessage.data = buf.NewSize(int(dataLength))
-		for _, message := range item.messages {
-			common.Must1(newMessage.data.Write(message.data.Bytes()))
-			message.releaseMessage()
-		}
-		item.messages = nil
-		return newMessage
-	}
-	item.messages = nil
-	return nil
-}
-
-func newPacketItem() *packetItem {
-	return new(packetItem)
-}
-
-func readUDPMessage(message *udpMessage, reader io.Reader) error {
-	err := binary.Read(reader, binary.BigEndian, &message.sessionID)
-	if err != nil {
-		return err
-	}
-	err = binary.Read(reader, binary.BigEndian, &message.packetID)
-	if err != nil {
-		return err
-	}
-	err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
-	if err != nil {
-		return err
-	}
-	err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
-	if err != nil {
-		return err
-	}
-	var dataLength uint16
-	err = binary.Read(reader, binary.BigEndian, &dataLength)
-	if err != nil {
-		return err
-	}
-	message.destination, err = addressSerializer.ReadAddrPort(reader)
-	if err != nil {
-		return err
-	}
-	message.data = buf.NewSize(int(dataLength))
-	_, err = message.data.ReadFullFrom(reader, message.data.FreeLen())
-	if err != nil {
-		return err
-	}
-	return nil
-}
-
-func decodeUDPMessage(message *udpMessage, data []byte) error {
-	reader := bytes.NewReader(data)
-	err := binary.Read(reader, binary.BigEndian, &message.sessionID)
-	if err != nil {
-		return err
-	}
-	err = binary.Read(reader, binary.BigEndian, &message.packetID)
-	if err != nil {
-		return err
-	}
-	err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
-	if err != nil {
-		return err
-	}
-	err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
-	if err != nil {
-		return err
-	}
-	var dataLength uint16
-	err = binary.Read(reader, binary.BigEndian, &dataLength)
-	if err != nil {
-		return err
-	}
-	message.destination, err = addressSerializer.ReadAddrPort(reader)
-	if err != nil {
-		return err
-	}
-	if reader.Len() != int(dataLength) {
-		return io.ErrUnexpectedEOF
-	}
-	message.data = buf.As(data[len(data)-reader.Len():])
-	return nil
-}

+ 0 - 15
transport/tuic/protocol.go

@@ -1,15 +0,0 @@
-package tuic
-
-const (
-	Version = 5
-)
-
-const (
-	CommandAuthenticate = iota
-	CommandConnect
-	CommandPacket
-	CommandDissociate
-	CommandHeartbeat
-)
-
-const AuthenticateLen = 2 + 16 + 32

+ 0 - 437
transport/tuic/server.go

@@ -1,437 +0,0 @@
-//go:build with_quic
-
-package tuic
-
-import (
-	"bytes"
-	"context"
-	"encoding/binary"
-	"io"
-	"net"
-	"runtime"
-	"strings"
-	"sync"
-	"time"
-
-	"github.com/sagernet/quic-go"
-	"github.com/sagernet/sing-box/common/qtls"
-	"github.com/sagernet/sing-box/common/tls"
-	"github.com/sagernet/sing/common"
-	"github.com/sagernet/sing/common/auth"
-	"github.com/sagernet/sing/common/baderror"
-	"github.com/sagernet/sing/common/buf"
-	"github.com/sagernet/sing/common/bufio"
-	E "github.com/sagernet/sing/common/exceptions"
-	"github.com/sagernet/sing/common/logger"
-	M "github.com/sagernet/sing/common/metadata"
-	N "github.com/sagernet/sing/common/network"
-
-	"github.com/gofrs/uuid/v5"
-)
-
-type ServerOptions struct {
-	Context           context.Context
-	Logger            logger.Logger
-	TLSConfig         tls.ServerConfig
-	Users             []User
-	CongestionControl string
-	AuthTimeout       time.Duration
-	ZeroRTTHandshake  bool
-	Heartbeat         time.Duration
-	Handler           ServerHandler
-}
-
-type User struct {
-	Name     string
-	UUID     uuid.UUID
-	Password string
-}
-
-type ServerHandler interface {
-	N.TCPConnectionHandler
-	N.UDPConnectionHandler
-}
-
-type Server struct {
-	ctx               context.Context
-	logger            logger.Logger
-	tlsConfig         tls.ServerConfig
-	heartbeat         time.Duration
-	quicConfig        *quic.Config
-	userMap           map[uuid.UUID]User
-	congestionControl string
-	authTimeout       time.Duration
-	handler           ServerHandler
-
-	quicListener io.Closer
-}
-
-func NewServer(options ServerOptions) (*Server, error) {
-	if options.AuthTimeout == 0 {
-		options.AuthTimeout = 3 * time.Second
-	}
-	if options.Heartbeat == 0 {
-		options.Heartbeat = 10 * time.Second
-	}
-	quicConfig := &quic.Config{
-		DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
-		MaxDatagramFrameSize:    1400,
-		EnableDatagrams:         true,
-		Allow0RTT:               options.ZeroRTTHandshake,
-		MaxIncomingStreams:      1 << 60,
-		MaxIncomingUniStreams:   1 << 60,
-	}
-	switch options.CongestionControl {
-	case "":
-		options.CongestionControl = "cubic"
-	case "cubic", "new_reno", "bbr":
-	default:
-		return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
-	}
-	if len(options.Users) == 0 {
-		return nil, E.New("missing users")
-	}
-	userMap := make(map[uuid.UUID]User)
-	for _, user := range options.Users {
-		userMap[user.UUID] = user
-	}
-	return &Server{
-		ctx:               options.Context,
-		logger:            options.Logger,
-		tlsConfig:         options.TLSConfig,
-		heartbeat:         options.Heartbeat,
-		quicConfig:        quicConfig,
-		userMap:           userMap,
-		congestionControl: options.CongestionControl,
-		authTimeout:       options.AuthTimeout,
-		handler:           options.Handler,
-	}, nil
-}
-
-func (s *Server) Start(conn net.PacketConn) error {
-	if !s.quicConfig.Allow0RTT {
-		listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig)
-		if err != nil {
-			return err
-		}
-		s.quicListener = listener
-		go func() {
-			for {
-				connection, hErr := listener.Accept(s.ctx)
-				if hErr != nil {
-					if strings.Contains(hErr.Error(), "server closed") {
-						s.logger.Debug(E.Cause(hErr, "listener closed"))
-					} else {
-						s.logger.Error(E.Cause(hErr, "listener closed"))
-					}
-					return
-				}
-				go s.handleConnection(connection)
-			}
-		}()
-	} else {
-		listener, err := qtls.ListenEarly(conn, s.tlsConfig, s.quicConfig)
-		if err != nil {
-			return err
-		}
-		s.quicListener = listener
-		go func() {
-			for {
-				connection, hErr := listener.Accept(s.ctx)
-				if hErr != nil {
-					if strings.Contains(hErr.Error(), "server closed") {
-						s.logger.Debug(E.Cause(hErr, "listener closed"))
-					} else {
-						s.logger.Error(E.Cause(hErr, "listener closed"))
-					}
-					return
-				}
-				go s.handleConnection(connection)
-			}
-		}()
-	}
-	return nil
-}
-
-func (s *Server) Close() error {
-	return common.Close(
-		s.quicListener,
-	)
-}
-
-func (s *Server) handleConnection(connection quic.Connection) {
-	setCongestion(s.ctx, connection, s.congestionControl)
-	session := &serverSession{
-		Server:     s,
-		ctx:        s.ctx,
-		quicConn:   connection,
-		source:     M.SocksaddrFromNet(connection.RemoteAddr()),
-		connDone:   make(chan struct{}),
-		authDone:   make(chan struct{}),
-		udpConnMap: make(map[uint16]*udpPacketConn),
-	}
-	session.handle()
-}
-
-type serverSession struct {
-	*Server
-	ctx        context.Context
-	quicConn   quic.Connection
-	source     M.Socksaddr
-	connAccess sync.Mutex
-	connDone   chan struct{}
-	connErr    error
-	authDone   chan struct{}
-	authUser   *User
-	udpAccess  sync.RWMutex
-	udpConnMap map[uint16]*udpPacketConn
-}
-
-func (s *serverSession) handle() {
-	if s.ctx.Done() != nil {
-		go func() {
-			select {
-			case <-s.ctx.Done():
-				s.closeWithError(s.ctx.Err())
-			case <-s.connDone:
-			}
-		}()
-	}
-	go s.loopUniStreams()
-	go s.loopStreams()
-	go s.loopMessages()
-	go s.handleAuthTimeout()
-	go s.loopHeartbeats()
-}
-
-func (s *serverSession) loopUniStreams() {
-	for {
-		uniStream, err := s.quicConn.AcceptUniStream(s.ctx)
-		if err != nil {
-			return
-		}
-		go func() {
-			err = s.handleUniStream(uniStream)
-			if err != nil {
-				s.closeWithError(E.Cause(err, "handle uni stream"))
-			}
-		}()
-	}
-}
-
-func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
-	defer stream.CancelRead(0)
-	buffer := buf.New()
-	defer buffer.Release()
-	_, err := buffer.ReadAtLeastFrom(stream, 2)
-	if err != nil {
-		return E.Cause(err, "read request")
-	}
-	version := buffer.Byte(0)
-	if version != Version {
-		return E.New("unknown version ", buffer.Byte(0))
-	}
-	command := buffer.Byte(1)
-	switch command {
-	case CommandAuthenticate:
-		select {
-		case <-s.authDone:
-			return E.New("authentication: multiple authentication requests")
-		default:
-		}
-		if buffer.Len() < AuthenticateLen {
-			_, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len())
-			if err != nil {
-				return E.Cause(err, "authentication: read request")
-			}
-		}
-		userUUID := uuid.FromBytesOrNil(buffer.Range(2, 2+16))
-		user, loaded := s.userMap[userUUID]
-		if !loaded {
-			return E.New("authentication: unknown user ", userUUID)
-		}
-		handshakeState := s.quicConn.ConnectionState()
-		tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32)
-		if err != nil {
-			return E.Cause(err, "authentication: export keying material")
-		}
-		if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) {
-			return E.New("authentication: token mismatch")
-		}
-		s.authUser = &user
-		close(s.authDone)
-		return nil
-	case CommandPacket:
-		select {
-		case <-s.connDone:
-			return s.connErr
-		case <-s.authDone:
-		}
-		message := allocMessage()
-		err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream))
-		if err != nil {
-			message.release()
-			return err
-		}
-		s.handleUDPMessage(message, true)
-		return nil
-	case CommandDissociate:
-		select {
-		case <-s.connDone:
-			return s.connErr
-		case <-s.authDone:
-		}
-		if buffer.Len() > 4 {
-			return E.New("invalid dissociate message")
-		}
-		var sessionID uint16
-		err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID)
-		if err != nil {
-			return err
-		}
-		s.udpAccess.RLock()
-		udpConn, loaded := s.udpConnMap[sessionID]
-		s.udpAccess.RUnlock()
-		if loaded {
-			udpConn.closeWithError(E.New("remote closed"))
-			s.udpAccess.Lock()
-			delete(s.udpConnMap, sessionID)
-			s.udpAccess.Unlock()
-		}
-		return nil
-	default:
-		return E.New("unknown command ", command)
-	}
-}
-
-func (s *serverSession) handleAuthTimeout() {
-	select {
-	case <-s.connDone:
-	case <-s.authDone:
-	case <-time.After(s.authTimeout):
-		s.closeWithError(E.New("authentication timeout"))
-	}
-}
-
-func (s *serverSession) loopStreams() {
-	for {
-		stream, err := s.quicConn.AcceptStream(s.ctx)
-		if err != nil {
-			return
-		}
-		go func() {
-			err = s.handleStream(stream)
-			if err != nil {
-				stream.CancelRead(0)
-				stream.Close()
-				s.logger.Error(E.Cause(err, "handle stream request"))
-			}
-		}()
-	}
-}
-
-func (s *serverSession) handleStream(stream quic.Stream) error {
-	buffer := buf.NewSize(2 + M.MaxSocksaddrLength)
-	defer buffer.Release()
-	_, err := buffer.ReadAtLeastFrom(stream, 2)
-	if err != nil {
-		return E.Cause(err, "read request")
-	}
-	version, _ := buffer.ReadByte()
-	if version != Version {
-		return E.New("unknown version ", buffer.Byte(0))
-	}
-	command, _ := buffer.ReadByte()
-	if command != CommandConnect {
-		return E.New("unsupported stream command ", command)
-	}
-	destination, err := addressSerializer.ReadAddrPort(io.MultiReader(buffer, stream))
-	if err != nil {
-		return E.Cause(err, "read request destination")
-	}
-	select {
-	case <-s.connDone:
-		return s.connErr
-	case <-s.authDone:
-	}
-	var conn net.Conn = &serverConn{
-		Stream:      stream,
-		destination: destination,
-	}
-	if buffer.IsEmpty() {
-		buffer.Release()
-	} else {
-		conn = bufio.NewCachedConn(conn, buffer)
-	}
-	ctx := s.ctx
-	if s.authUser.Name != "" {
-		ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
-	}
-	_ = s.handler.NewConnection(ctx, conn, M.Metadata{
-		Source:      s.source,
-		Destination: destination,
-	})
-	return nil
-}
-
-func (s *serverSession) loopHeartbeats() {
-	ticker := time.NewTicker(s.heartbeat)
-	defer ticker.Stop()
-	for {
-		select {
-		case <-s.connDone:
-			return
-		case <-ticker.C:
-			err := s.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
-			if err != nil {
-				s.closeWithError(E.Cause(err, "send heartbeat"))
-			}
-		}
-	}
-}
-
-func (s *serverSession) closeWithError(err error) {
-	s.connAccess.Lock()
-	defer s.connAccess.Unlock()
-	select {
-	case <-s.connDone:
-		return
-	default:
-		s.connErr = err
-		close(s.connDone)
-	}
-	if E.IsClosedOrCanceled(err) {
-		s.logger.Debug(E.Cause(err, "connection failed"))
-	} else {
-		s.logger.Error(E.Cause(err, "connection failed"))
-	}
-	_ = s.quicConn.CloseWithError(0, "")
-}
-
-type serverConn struct {
-	quic.Stream
-	destination M.Socksaddr
-}
-
-func (c *serverConn) Read(p []byte) (n int, err error) {
-	n, err = c.Stream.Read(p)
-	return n, baderror.WrapQUIC(err)
-}
-
-func (c *serverConn) Write(p []byte) (n int, err error) {
-	n, err = c.Stream.Write(p)
-	return n, baderror.WrapQUIC(err)
-}
-
-func (c *serverConn) LocalAddr() net.Addr {
-	return c.destination
-}
-
-func (c *serverConn) RemoteAddr() net.Addr {
-	return M.Socksaddr{}
-}
-
-func (c *serverConn) Close() error {
-	c.Stream.CancelRead(0)
-	return c.Stream.Close()
-}

+ 0 - 75
transport/tuic/server_packet.go

@@ -1,75 +0,0 @@
-//go:build with_quic
-
-package tuic
-
-import (
-	"github.com/sagernet/sing/common"
-	E "github.com/sagernet/sing/common/exceptions"
-	M "github.com/sagernet/sing/common/metadata"
-)
-
-func (s *serverSession) loopMessages() {
-	select {
-	case <-s.connDone:
-		return
-	case <-s.authDone:
-	}
-	for {
-		message, err := s.quicConn.ReceiveMessage(s.ctx)
-		if err != nil {
-			s.closeWithError(E.Cause(err, "receive message"))
-			return
-		}
-		hErr := s.handleMessage(message)
-		if hErr != nil {
-			s.closeWithError(E.Cause(hErr, "handle message"))
-			return
-		}
-	}
-}
-
-func (s *serverSession) handleMessage(data []byte) error {
-	if len(data) < 2 {
-		return E.New("invalid message")
-	}
-	if data[0] != Version {
-		return E.New("unknown version ", data[0])
-	}
-	switch data[1] {
-	case CommandPacket:
-		message := allocMessage()
-		err := decodeUDPMessage(message, data[2:])
-		if err != nil {
-			message.release()
-			return E.Cause(err, "decode UDP message")
-		}
-		s.handleUDPMessage(message, false)
-		return nil
-	case CommandHeartbeat:
-		return nil
-	default:
-		return E.New("unknown command ", data[0])
-	}
-}
-
-func (s *serverSession) handleUDPMessage(message *udpMessage, udpStream bool) {
-	s.udpAccess.RLock()
-	udpConn, loaded := s.udpConnMap[message.sessionID]
-	s.udpAccess.RUnlock()
-	if !loaded || common.Done(udpConn.ctx) {
-		udpConn = newUDPPacketConn(s.ctx, s.quicConn, udpStream, true, func() {
-			s.udpAccess.Lock()
-			delete(s.udpConnMap, message.sessionID)
-			s.udpAccess.Unlock()
-		})
-		udpConn.sessionID = message.sessionID
-		s.udpAccess.Lock()
-		s.udpConnMap[message.sessionID] = udpConn
-		s.udpAccess.Unlock()
-		go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{
-			Source:      s.source,
-			Destination: message.destination,
-		})
-	}
-	udpConn.inputPacket(message)
-}

+ 1 - 1
transport/v2rayquic/client.go

@@ -9,11 +9,11 @@ import (
 
 	"github.com/sagernet/quic-go"
 	"github.com/sagernet/sing-box/adapter"
-	"github.com/sagernet/sing-box/common/qtls"
 	"github.com/sagernet/sing-box/common/tls"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-box/transport/hysteria"
+	"github.com/sagernet/sing-quic"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/bufio"
 	M "github.com/sagernet/sing/common/metadata"

+ 2 - 2
transport/v2rayquic/server.go

@@ -9,11 +9,11 @@ import (
 
 	"github.com/sagernet/quic-go"
 	"github.com/sagernet/sing-box/adapter"
-	"github.com/sagernet/sing-box/common/qtls"
 	"github.com/sagernet/sing-box/common/tls"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-box/transport/hysteria"
+	"github.com/sagernet/sing-quic"
 	"github.com/sagernet/sing/common"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
@@ -27,7 +27,7 @@ type Server struct {
 	quicConfig   *quic.Config
 	handler      adapter.V2RayServerTransportHandler
 	udpListener  net.PacketConn
-	quicListener qtls.QUICListener
+	quicListener qtls.Listener
 }
 
 func NewServer(ctx context.Context, options option.V2RayQUICOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) {