1
0
世界 2 жил өмнө
parent
commit
917420e79a

+ 3 - 0
constant/proxy.go

@@ -21,6 +21,7 @@ const (
 	TypeShadowTLS    = "shadowtls"
 	TypeShadowTLS    = "shadowtls"
 	TypeShadowsocksR = "shadowsocksr"
 	TypeShadowsocksR = "shadowsocksr"
 	TypeVLESS        = "vless"
 	TypeVLESS        = "vless"
+	TypeTUIC         = "tuic"
 )
 )
 
 
 const (
 const (
@@ -62,6 +63,8 @@ func ProxyDisplayName(proxyType string) string {
 		return "ShadowsocksR"
 		return "ShadowsocksR"
 	case TypeVLESS:
 	case TypeVLESS:
 		return "VLESS"
 		return "VLESS"
+	case TypeTUIC:
+		return "TUIC"
 	case TypeSelector:
 	case TypeSelector:
 		return "Selector"
 		return "Selector"
 	case TypeURLTest:
 	case TypeURLTest:

+ 2 - 0
inbound/builder.go

@@ -44,6 +44,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o
 		return NewShadowTLS(ctx, router, logger, options.Tag, options.ShadowTLSOptions)
 		return NewShadowTLS(ctx, router, logger, options.Tag, options.ShadowTLSOptions)
 	case C.TypeVLESS:
 	case C.TypeVLESS:
 		return NewVLESS(ctx, router, logger, options.Tag, options.VLESSOptions)
 		return NewVLESS(ctx, router, logger, options.Tag, options.VLESSOptions)
+	case C.TypeTUIC:
+		return NewTUIC(ctx, router, logger, options.Tag, options.TUICOptions)
 	default:
 	default:
 		return nil, E.New("unknown inbound type: ", options.Type)
 		return nil, E.New("unknown inbound type: ", options.Type)
 	}
 	}

+ 11 - 0
inbound/default.go

@@ -153,6 +153,17 @@ func (a *myInboundAdapter) createMetadata(conn net.Conn, metadata adapter.Inboun
 	return metadata
 	return metadata
 }
 }
 
 
+func (a *myInboundAdapter) createPacketMetadata(conn N.PacketConn, metadata adapter.InboundContext) adapter.InboundContext {
+	metadata.Inbound = a.tag
+	metadata.InboundType = a.protocol
+	metadata.InboundDetour = a.listenOptions.Detour
+	metadata.InboundOptions = a.listenOptions.InboundOptions
+	if !metadata.Destination.IsValid() {
+		metadata.Destination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
+	}
+	return metadata
+}
+
 func (a *myInboundAdapter) newError(err error) {
 func (a *myInboundAdapter) newError(err error) {
 	a.logger.Error(err)
 	a.logger.Error(err)
 }
 }

+ 114 - 0
inbound/tuic.go

@@ -0,0 +1,114 @@
+//go:build with_quic
+
+package inbound
+
+import (
+	"context"
+	"net"
+	"time"
+
+	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/tls"
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/log"
+	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing-box/transport/tuic"
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/auth"
+	E "github.com/sagernet/sing/common/exceptions"
+	N "github.com/sagernet/sing/common/network"
+
+	"github.com/gofrs/uuid/v5"
+)
+
+var _ adapter.Inbound = (*TUIC)(nil)
+
+type TUIC struct {
+	myInboundAdapter
+	server    *tuic.Server
+	tlsConfig tls.ServerConfig
+}
+
+func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (*TUIC, error) {
+	options.UDPFragmentDefault = true
+	if options.TLS == nil || !options.TLS.Enabled {
+		return nil, C.ErrTLSRequired
+	}
+	tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS))
+	if err != nil {
+		return nil, err
+	}
+	rawConfig, err := tlsConfig.Config()
+	if err != nil {
+		return nil, err
+	}
+	var users []tuic.User
+	for index, user := range options.Users {
+		if user.UUID == "" {
+			return nil, E.New("missing uuid for user ", index)
+		}
+		userUUID, err := uuid.FromString(user.UUID)
+		if err != nil {
+			return nil, E.Cause(err, "invalid uuid for user ", index)
+		}
+		users = append(users, tuic.User{Name: user.Name, UUID: userUUID, Password: user.Password})
+	}
+	inbound := &TUIC{
+		myInboundAdapter: myInboundAdapter{
+			protocol:      C.TypeTUIC,
+			network:       []string{N.NetworkUDP},
+			ctx:           ctx,
+			router:        router,
+			logger:        logger,
+			tag:           tag,
+			listenOptions: options.ListenOptions,
+		},
+	}
+	server, err := tuic.NewServer(tuic.ServerOptions{
+		Context:           ctx,
+		Logger:            logger,
+		TLSConfig:         rawConfig,
+		Users:             users,
+		CongestionControl: options.CongestionControl,
+		AuthTimeout:       time.Duration(options.AuthTimeout),
+		ZeroRTTHandshake:  options.ZeroRTTHandshake,
+		Heartbeat:         time.Duration(options.Heartbeat),
+		Handler:           adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil),
+	})
+	if err != nil {
+		return nil, err
+	}
+	inbound.server = server
+	return inbound, nil
+}
+
+func (h *TUIC) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
+	ctx = log.ContextWithNewID(ctx)
+	h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
+	metadata = h.createMetadata(conn, metadata)
+	metadata.User, _ = auth.UserFromContext[string](ctx)
+	return h.router.RouteConnection(ctx, conn, metadata)
+}
+
+func (h *TUIC) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
+	ctx = log.ContextWithNewID(ctx)
+	metadata = h.createPacketMetadata(conn, metadata)
+	metadata.User, _ = auth.UserFromContext[string](ctx)
+	h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
+	return h.router.RoutePacketConnection(ctx, conn, metadata)
+}
+
+func (h *TUIC) Start() error {
+	packetConn, err := h.myInboundAdapter.ListenUDP()
+	if err != nil {
+		return err
+	}
+	return h.server.Start(packetConn)
+}
+
+func (h *TUIC) Close() error {
+	return common.Close(
+		&h.myInboundAdapter,
+		common.PtrOrNil(h.server),
+	)
+}

+ 16 - 0
inbound/tuic_stub.go

@@ -0,0 +1,16 @@
+//go:build !with_quic
+
+package inbound
+
+import (
+	"context"
+
+	"github.com/sagernet/sing-box/adapter"
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/log"
+	"github.com/sagernet/sing-box/option"
+)
+
+func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (adapter.Inbound, error) {
+	return nil, C.ErrQUICNotIncluded
+}

+ 5 - 0
option/inbound.go

@@ -23,6 +23,7 @@ type _Inbound struct {
 	HysteriaOptions    HysteriaInboundOptions    `json:"-"`
 	HysteriaOptions    HysteriaInboundOptions    `json:"-"`
 	ShadowTLSOptions   ShadowTLSInboundOptions   `json:"-"`
 	ShadowTLSOptions   ShadowTLSInboundOptions   `json:"-"`
 	VLESSOptions       VLESSInboundOptions       `json:"-"`
 	VLESSOptions       VLESSInboundOptions       `json:"-"`
+	TUICOptions        TUICInboundOptions        `json:"-"`
 }
 }
 
 
 type Inbound _Inbound
 type Inbound _Inbound
@@ -58,6 +59,8 @@ func (h Inbound) MarshalJSON() ([]byte, error) {
 		v = h.ShadowTLSOptions
 		v = h.ShadowTLSOptions
 	case C.TypeVLESS:
 	case C.TypeVLESS:
 		v = h.VLESSOptions
 		v = h.VLESSOptions
+	case C.TypeTUIC:
+		v = h.TUICOptions
 	default:
 	default:
 		return nil, E.New("unknown inbound type: ", h.Type)
 		return nil, E.New("unknown inbound type: ", h.Type)
 	}
 	}
@@ -99,6 +102,8 @@ func (h *Inbound) UnmarshalJSON(bytes []byte) error {
 		v = &h.ShadowTLSOptions
 		v = &h.ShadowTLSOptions
 	case C.TypeVLESS:
 	case C.TypeVLESS:
 		v = &h.VLESSOptions
 		v = &h.VLESSOptions
+	case C.TypeTUIC:
+		v = &h.TUICOptions
 	default:
 	default:
 		return E.New("unknown inbound type: ", h.Type)
 		return E.New("unknown inbound type: ", h.Type)
 	}
 	}

+ 5 - 0
option/outbound.go

@@ -23,6 +23,7 @@ type _Outbound struct {
 	ShadowTLSOptions    ShadowTLSOutboundOptions    `json:"-"`
 	ShadowTLSOptions    ShadowTLSOutboundOptions    `json:"-"`
 	ShadowsocksROptions ShadowsocksROutboundOptions `json:"-"`
 	ShadowsocksROptions ShadowsocksROutboundOptions `json:"-"`
 	VLESSOptions        VLESSOutboundOptions        `json:"-"`
 	VLESSOptions        VLESSOutboundOptions        `json:"-"`
+	TUICOptions         TUICOutboundOptions         `json:"-"`
 	SelectorOptions     SelectorOutboundOptions     `json:"-"`
 	SelectorOptions     SelectorOutboundOptions     `json:"-"`
 	URLTestOptions      URLTestOutboundOptions      `json:"-"`
 	URLTestOptions      URLTestOutboundOptions      `json:"-"`
 }
 }
@@ -60,6 +61,8 @@ func (h Outbound) MarshalJSON() ([]byte, error) {
 		v = h.ShadowsocksROptions
 		v = h.ShadowsocksROptions
 	case C.TypeVLESS:
 	case C.TypeVLESS:
 		v = h.VLESSOptions
 		v = h.VLESSOptions
+	case C.TypeTUIC:
+		v = h.TUICOptions
 	case C.TypeSelector:
 	case C.TypeSelector:
 		v = h.SelectorOptions
 		v = h.SelectorOptions
 	case C.TypeURLTest:
 	case C.TypeURLTest:
@@ -105,6 +108,8 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error {
 		v = &h.ShadowsocksROptions
 		v = &h.ShadowsocksROptions
 	case C.TypeVLESS:
 	case C.TypeVLESS:
 		v = &h.VLESSOptions
 		v = &h.VLESSOptions
+	case C.TypeTUIC:
+		v = &h.TUICOptions
 	case C.TypeSelector:
 	case C.TypeSelector:
 		v = &h.SelectorOptions
 		v = &h.SelectorOptions
 	case C.TypeURLTest:
 	case C.TypeURLTest:

+ 30 - 0
option/tuic.go

@@ -0,0 +1,30 @@
+package option
+
+type TUICInboundOptions struct {
+	ListenOptions
+	Users             []TUICUser         `json:"users,omitempty"`
+	CongestionControl string             `json:"congestion_control,omitempty"`
+	AuthTimeout       Duration           `json:"auth_timeout,omitempty"`
+	ZeroRTTHandshake  bool               `json:"zero_rtt_handshake,omitempty"`
+	Heartbeat         Duration           `json:"heartbeat,omitempty"`
+	TLS               *InboundTLSOptions `json:"tls,omitempty"`
+}
+
+type TUICUser struct {
+	Name     string `json:"name,omitempty"`
+	UUID     string `json:"uuid,omitempty"`
+	Password string `json:"password,omitempty"`
+}
+
+type TUICOutboundOptions struct {
+	DialerOptions
+	ServerOptions
+	UUID              string              `json:"uuid,omitempty"`
+	Password          string              `json:"password,omitempty"`
+	CongestionControl string              `json:"congestion_control,omitempty"`
+	UDPRelayMode      string              `json:"udp_relay_mode,omitempty"`
+	ZeroRTTHandshake  bool                `json:"zero_rtt_handshake,omitempty"`
+	Heartbeat         Duration            `json:"heartbeat,omitempty"`
+	Network           NetworkList         `json:"network,omitempty"`
+	TLS               *OutboundTLSOptions `json:"tls,omitempty"`
+}

+ 2 - 0
outbound/builder.go

@@ -51,6 +51,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, t
 		return NewShadowsocksR(ctx, router, logger, tag, options.ShadowsocksROptions)
 		return NewShadowsocksR(ctx, router, logger, tag, options.ShadowsocksROptions)
 	case C.TypeVLESS:
 	case C.TypeVLESS:
 		return NewVLESS(ctx, router, logger, tag, options.VLESSOptions)
 		return NewVLESS(ctx, router, logger, tag, options.VLESSOptions)
+	case C.TypeTUIC:
+		return NewTUIC(ctx, router, logger, tag, options.TUICOptions)
 	case C.TypeSelector:
 	case C.TypeSelector:
 		return NewSelector(router, logger, tag, options.SelectorOptions)
 		return NewSelector(router, logger, tag, options.SelectorOptions)
 	case C.TypeURLTest:
 	case C.TypeURLTest:

+ 123 - 0
outbound/tuic.go

@@ -0,0 +1,123 @@
+//go:build with_quic
+
+package outbound
+
+import (
+	"context"
+	"net"
+	"os"
+	"time"
+
+	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
+	"github.com/sagernet/sing-box/common/tls"
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/log"
+	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing-box/transport/tuic"
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/bufio"
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	"github.com/gofrs/uuid/v5"
+)
+
+var (
+	_ adapter.Outbound                = (*TUIC)(nil)
+	_ adapter.InterfaceUpdateListener = (*TUIC)(nil)
+)
+
+type TUIC struct {
+	myOutboundAdapter
+	client *tuic.Client
+}
+
+func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICOutboundOptions) (*TUIC, error) {
+	options.UDPFragmentDefault = true
+	if options.TLS == nil || !options.TLS.Enabled {
+		return nil, C.ErrTLSRequired
+	}
+	abstractTLSConfig, err := tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS))
+	if err != nil {
+		return nil, err
+	}
+	tlsConfig, err := abstractTLSConfig.Config()
+	if err != nil {
+		return nil, err
+	}
+	userUUID, err := uuid.FromString(options.UUID)
+	if err != nil {
+		return nil, E.Cause(err, "invalid uuid")
+	}
+	var udpStream bool
+	switch options.UDPRelayMode {
+	case "native":
+	case "quic":
+		udpStream = true
+	}
+	client, err := tuic.NewClient(tuic.ClientOptions{
+		Context:           ctx,
+		Dialer:            dialer.New(router, options.DialerOptions),
+		ServerAddress:     options.ServerOptions.Build(),
+		TLSConfig:         tlsConfig,
+		UUID:              userUUID,
+		Password:          options.Password,
+		CongestionControl: options.CongestionControl,
+		UDPStream:         udpStream,
+		ZeroRTTHandshake:  options.ZeroRTTHandshake,
+		Heartbeat:         time.Duration(options.Heartbeat),
+	})
+	if err != nil {
+		return nil, err
+	}
+	return &TUIC{
+		myOutboundAdapter: myOutboundAdapter{
+			protocol:     C.TypeTUIC,
+			network:      options.Network.Build(),
+			router:       router,
+			logger:       logger,
+			tag:          tag,
+			dependencies: withDialerDependency(options.DialerOptions),
+		},
+		client: client,
+	}, nil
+}
+
+func (h *TUIC) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
+	switch N.NetworkName(network) {
+	case N.NetworkTCP:
+		h.logger.InfoContext(ctx, "outbound connection to ", destination)
+		return h.client.DialConn(ctx, destination)
+	case N.NetworkUDP:
+		conn, err := h.ListenPacket(ctx, destination)
+		if err != nil {
+			return nil, err
+		}
+		return bufio.NewBindPacketConn(conn, destination), nil
+	default:
+		return nil, E.New("unsupported network: ", network)
+	}
+}
+
+func (h *TUIC) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
+	h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
+	return h.client.ListenPacket(ctx)
+}
+
+func (h *TUIC) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
+	return NewConnection(ctx, h, conn, metadata)
+}
+
+func (h *TUIC) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
+	return NewPacketConnection(ctx, h, conn, metadata)
+}
+
+func (h *TUIC) InterfaceUpdated() {
+	_ = h.client.CloseWithError(E.New("network changed"))
+}
+
+func (h *TUIC) Close() error {
+	return h.client.CloseWithError(os.ErrClosed)
+}

+ 16 - 0
outbound/tuic_stub.go

@@ -0,0 +1,16 @@
+//go:build !with_quic
+
+package outbound
+
+import (
+	"context"
+
+	"github.com/sagernet/sing-box/adapter"
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/log"
+	"github.com/sagernet/sing-box/option"
+)
+
+func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICOutboundOptions) (adapter.Outbound, error) {
+	return nil, C.ErrQUICNotIncluded
+}

+ 4 - 0
test/clash_test.go

@@ -38,6 +38,8 @@ const (
 	ImageShadowsocksR          = "teddysun/shadowsocks-r:latest"
 	ImageShadowsocksR          = "teddysun/shadowsocks-r:latest"
 	ImageXRayCore              = "teddysun/xray:latest"
 	ImageXRayCore              = "teddysun/xray:latest"
 	ImageShadowsocksLegacy     = "mritd/shadowsocks:latest"
 	ImageShadowsocksLegacy     = "mritd/shadowsocks:latest"
+	ImageTUICServer            = ""
+	ImageTUICClient            = ""
 )
 )
 
 
 var allImages = []string{
 var allImages = []string{
@@ -53,6 +55,8 @@ var allImages = []string{
 	ImageShadowsocksR,
 	ImageShadowsocksR,
 	ImageXRayCore,
 	ImageXRayCore,
 	ImageShadowsocksLegacy,
 	ImageShadowsocksLegacy,
+	// ImageTUICServer,
+	// ImageTUICClient,
 }
 }
 
 
 var localIP = netip.MustParseAddr("127.0.0.1")
 var localIP = netip.MustParseAddr("127.0.0.1")

+ 14 - 0
test/config/tuic-client.json

@@ -0,0 +1,14 @@
+{
+  "relay": {
+    "server": "127.0.0.1:10000",
+    "uuid": "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D",
+    "password": "tuic",
+    "certificates": [
+      "/etc/tuic/ca.pem"
+    ]
+  },
+  "local": {
+    "server": "127.0.0.1:10001"
+  },
+  "log_level": "debug"
+}

+ 9 - 0
test/config/tuic-server.json

@@ -0,0 +1,9 @@
+{
+    "server": "[::]:10000",
+    "users": {
+        "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D": "tuic"
+    },
+    "certificate":  "/etc/tuic/cert.pem",
+    "private_key": "/etc/tuic/key.pem",
+    "log_level": "debug"
+}

+ 178 - 0
test/tuic_test.go

@@ -0,0 +1,178 @@
+package main
+
+import (
+	"net/netip"
+	"testing"
+
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/option"
+
+	"github.com/gofrs/uuid/v5"
+)
+
+func TestTUICSelf(t *testing.T) {
+	t.Run("self", func(t *testing.T) {
+		testTUICSelf(t, false, false)
+	})
+	t.Run("self-udp-stream", func(t *testing.T) {
+		testTUICSelf(t, true, false)
+	})
+	t.Run("self-early", func(t *testing.T) {
+		testTUICSelf(t, false, true)
+	})
+}
+
+func testTUICSelf(t *testing.T, udpStream bool, zeroRTTHandshake bool) {
+	_, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
+	var udpRelayMode string
+	if udpStream {
+		udpRelayMode = "quic"
+	}
+	startInstance(t, option.Options{
+		Inbounds: []option.Inbound{
+			{
+				Type: C.TypeMixed,
+				Tag:  "mixed-in",
+				MixedOptions: option.HTTPMixedInboundOptions{
+					ListenOptions: option.ListenOptions{
+						Listen:     option.NewListenAddress(netip.IPv4Unspecified()),
+						ListenPort: clientPort,
+					},
+				},
+			},
+			{
+				Type: C.TypeTUIC,
+				TUICOptions: option.TUICInboundOptions{
+					ListenOptions: option.ListenOptions{
+						Listen:     option.NewListenAddress(netip.IPv4Unspecified()),
+						ListenPort: serverPort,
+					},
+					Users: []option.TUICUser{{
+						UUID: uuid.Nil.String(),
+					}},
+					ZeroRTTHandshake: zeroRTTHandshake,
+					TLS: &option.InboundTLSOptions{
+						Enabled:         true,
+						ServerName:      "example.org",
+						CertificatePath: certPem,
+						KeyPath:         keyPem,
+					},
+				},
+			},
+		},
+		Outbounds: []option.Outbound{
+			{
+				Type: C.TypeDirect,
+			},
+			{
+				Type: C.TypeTUIC,
+				Tag:  "tuic-out",
+				TUICOptions: option.TUICOutboundOptions{
+					ServerOptions: option.ServerOptions{
+						Server:     "127.0.0.1",
+						ServerPort: serverPort,
+					},
+					UUID:             uuid.Nil.String(),
+					UDPRelayMode:     udpRelayMode,
+					ZeroRTTHandshake: zeroRTTHandshake,
+					TLS: &option.OutboundTLSOptions{
+						Enabled:         true,
+						ServerName:      "example.org",
+						CertificatePath: certPem,
+					},
+				},
+			},
+		},
+		Route: &option.RouteOptions{
+			Rules: []option.Rule{
+				{
+					DefaultOptions: option.DefaultRule{
+						Inbound:  []string{"mixed-in"},
+						Outbound: "tuic-out",
+					},
+				},
+			},
+		},
+	})
+	testSuit(t, clientPort, testPort)
+}
+
+func TestTUICInbound(t *testing.T) {
+	caPem, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
+	startInstance(t, option.Options{
+		Inbounds: []option.Inbound{
+			{
+				Type: C.TypeTUIC,
+				TUICOptions: option.TUICInboundOptions{
+					ListenOptions: option.ListenOptions{
+						Listen:     option.NewListenAddress(netip.IPv4Unspecified()),
+						ListenPort: serverPort,
+					},
+					Users: []option.TUICUser{{
+						UUID:     "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D",
+						Password: "tuic",
+					}},
+					TLS: &option.InboundTLSOptions{
+						Enabled:         true,
+						ServerName:      "example.org",
+						CertificatePath: certPem,
+						KeyPath:         keyPem,
+					},
+				},
+			},
+		},
+	})
+	startDockerContainer(t, DockerOptions{
+		Image: ImageTUICClient,
+		Ports: []uint16{serverPort, clientPort},
+		Bind: map[string]string{
+			"tuic-client.json": "/etc/tuic/config.json",
+			caPem:              "/etc/tuic/ca.pem",
+		},
+	})
+}
+
+func TestTUICOutbound(t *testing.T) {
+	_, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
+	startDockerContainer(t, DockerOptions{
+		Image: ImageTUICServer,
+		Ports: []uint16{testPort},
+		Bind: map[string]string{
+			"tuic-server.json": "/etc/tuic/config.json",
+			certPem:            "/etc/tuic/cert.pem",
+			keyPem:             "/etc/tuic/key.pem",
+		},
+	})
+	startInstance(t, option.Options{
+		Inbounds: []option.Inbound{
+			{
+				Type: C.TypeMixed,
+				MixedOptions: option.HTTPMixedInboundOptions{
+					ListenOptions: option.ListenOptions{
+						Listen:     option.NewListenAddress(netip.IPv4Unspecified()),
+						ListenPort: clientPort,
+					},
+				},
+			},
+		},
+		Outbounds: []option.Outbound{
+			{
+				Type: C.TypeTUIC,
+				TUICOptions: option.TUICOutboundOptions{
+					ServerOptions: option.ServerOptions{
+						Server:     "127.0.0.1",
+						ServerPort: serverPort,
+					},
+					UUID:     "FE35D05B-8803-45C4-BAE6-723AD2CD5D3D",
+					Password: "tuic",
+					TLS: &option.OutboundTLSOptions{
+						Enabled:         true,
+						ServerName:      "example.org",
+						CertificatePath: certPem,
+					},
+				},
+			},
+		},
+	})
+	testSuit(t, clientPort, testPort)
+}

+ 10 - 0
transport/tuic/address.go

@@ -0,0 +1,10 @@
+package tuic
+
+import M "github.com/sagernet/sing/common/metadata"
+
+var addressSerializer = M.NewSerializer(
+	M.AddressFamilyByte(0x00, M.AddressFamilyFqdn),
+	M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
+	M.AddressFamilyByte(0x02, M.AddressFamilyIPv6),
+	M.AddressFamilyByte(0xff, M.AddressFamilyEmpty),
+)

+ 322 - 0
transport/tuic/client.go

@@ -0,0 +1,322 @@
+package tuic
+
+import (
+	"context"
+	"crypto/tls"
+	"io"
+	"net"
+	"os"
+	"runtime"
+	"sync"
+	"time"
+
+	"github.com/sagernet/quic-go"
+	"github.com/sagernet/sing-box/common/baderror"
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/bufio"
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	"github.com/gofrs/uuid/v5"
+)
+
+type ClientOptions struct {
+	Context           context.Context
+	Dialer            N.Dialer
+	ServerAddress     M.Socksaddr
+	TLSConfig         *tls.Config
+	UUID              uuid.UUID
+	Password          string
+	CongestionControl string
+	UDPStream         bool
+	ZeroRTTHandshake  bool
+	Heartbeat         time.Duration
+}
+
+type Client struct {
+	ctx               context.Context
+	dialer            N.Dialer
+	serverAddr        M.Socksaddr
+	tlsConfig         *tls.Config
+	quicConfig        *quic.Config
+	uuid              uuid.UUID
+	password          string
+	congestionControl string
+	udpStream         bool
+	zeroRTTHandshake  bool
+	heartbeat         time.Duration
+
+	connAccess sync.RWMutex
+	conn       *clientQUICConnection
+}
+
+func NewClient(options ClientOptions) (*Client, error) {
+	if options.Heartbeat == 0 {
+		options.Heartbeat = 10 * time.Second
+	}
+	quicConfig := &quic.Config{
+		DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
+		MaxDatagramFrameSize:    1400,
+		EnableDatagrams:         true,
+		MaxIncomingUniStreams:   1 << 60,
+	}
+	switch options.CongestionControl {
+	case "":
+		options.CongestionControl = "cubic"
+	case "cubic", "new_reno", "bbr":
+	default:
+		return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
+	}
+	return &Client{
+		ctx:               options.Context,
+		dialer:            options.Dialer,
+		serverAddr:        options.ServerAddress,
+		tlsConfig:         options.TLSConfig,
+		quicConfig:        quicConfig,
+		uuid:              options.UUID,
+		password:          options.Password,
+		congestionControl: options.CongestionControl,
+		udpStream:         options.UDPStream,
+		zeroRTTHandshake:  options.ZeroRTTHandshake,
+		heartbeat:         options.Heartbeat,
+	}, nil
+}
+
+func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
+	conn := c.conn
+	if conn != nil && conn.active() {
+		return conn, nil
+	}
+	c.connAccess.Lock()
+	defer c.connAccess.Unlock()
+	conn = c.conn
+	if conn != nil && conn.active() {
+		return conn, nil
+	}
+	conn, err := c.offerNew(ctx)
+	if err != nil {
+		return nil, err
+	}
+	return conn, nil
+}
+
+func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
+	udpConn, err := c.dialer.DialContext(ctx, "udp", c.serverAddr)
+	if err != nil {
+		return nil, err
+	}
+	var quicConn quic.Connection
+	if c.zeroRTTHandshake {
+		quicConn, err = quic.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
+	} else {
+		quicConn, err = quic.Dial(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
+	}
+	if err != nil {
+		udpConn.Close()
+		return nil, E.Cause(err, "open connection")
+	}
+	setCongestion(c.ctx, quicConn, c.congestionControl)
+	conn := &clientQUICConnection{
+		quicConn:   quicConn,
+		rawConn:    udpConn,
+		connDone:   make(chan struct{}),
+		udpConnMap: make(map[uint16]*udpPacketConn),
+	}
+	go func() {
+		hErr := c.clientHandshake(quicConn)
+		if hErr != nil {
+			conn.closeWithError(hErr)
+		}
+	}()
+	if c.udpStream {
+		go c.loopUniStreams(conn)
+	}
+	go c.loopMessages(conn)
+	go c.loopHeartbeats(conn)
+	c.conn = conn
+	return conn, nil
+}
+
+func (c *Client) clientHandshake(conn quic.Connection) error {
+	authStream, err := conn.OpenUniStream()
+	if err != nil {
+		return err
+	}
+	defer authStream.Close()
+	handshakeState := conn.ConnectionState().TLS
+	tuicAuthToken, err := handshakeState.ExportKeyingMaterial(string(c.uuid[:]), []byte(c.password), 32)
+	if err != nil {
+		return err
+	}
+	authRequest := buf.NewSize(AuthenticateLen)
+	authRequest.WriteByte(Version)
+	authRequest.WriteByte(CommandAuthenticate)
+	authRequest.Write(c.uuid[:])
+	authRequest.Write(tuicAuthToken)
+	return common.Error(authStream.Write(authRequest.Bytes()))
+}
+
+func (c *Client) loopHeartbeats(conn *clientQUICConnection) {
+	ticker := time.NewTicker(c.heartbeat)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-conn.connDone:
+			return
+		case <-ticker.C:
+			err := conn.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
+			if err != nil {
+				conn.closeWithError(E.Cause(err, "send heartbeat"))
+			}
+		}
+	}
+}
+
+func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
+	conn, err := c.offer(ctx)
+	if err != nil {
+		return nil, err
+	}
+	stream, err := conn.quicConn.OpenStream()
+	if err != nil {
+		return nil, err
+	}
+	return &clientConn{
+		parent:      conn,
+		stream:      stream,
+		destination: destination,
+	}, nil
+}
+
+func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) {
+	conn, err := c.offer(ctx)
+	if err != nil {
+		return nil, err
+	}
+	var sessionID uint16
+	clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, c.udpStream, false, func() {
+		conn.udpAccess.Lock()
+		delete(conn.udpConnMap, sessionID)
+		conn.udpAccess.Unlock()
+	})
+	conn.udpAccess.Lock()
+	sessionID = conn.udpSessionID
+	conn.udpSessionID++
+	conn.udpConnMap[sessionID] = clientPacketConn
+	conn.udpAccess.Unlock()
+	clientPacketConn.sessionID = sessionID
+	return clientPacketConn, nil
+}
+
+func (c *Client) CloseWithError(err error) error {
+	conn := c.conn
+	if conn != nil {
+		conn.closeWithError(err)
+	}
+	return nil
+}
+
+type clientQUICConnection struct {
+	quicConn     quic.Connection
+	rawConn      io.Closer
+	closeOnce    sync.Once
+	connDone     chan struct{}
+	connErr      error
+	udpAccess    sync.RWMutex
+	udpConnMap   map[uint16]*udpPacketConn
+	udpSessionID uint16
+}
+
+func (c *clientQUICConnection) active() bool {
+	select {
+	case <-c.quicConn.Context().Done():
+		return false
+	default:
+	}
+	select {
+	case <-c.connDone:
+		return false
+	default:
+	}
+	return true
+}
+
+func (c *clientQUICConnection) closeWithError(err error) {
+	c.closeOnce.Do(func() {
+		c.connErr = err
+		close(c.connDone)
+		_ = c.quicConn.CloseWithError(0, "")
+		_ = c.rawConn.Close()
+	})
+}
+
+type clientConn struct {
+	parent         *clientQUICConnection
+	stream         quic.Stream
+	destination    M.Socksaddr
+	requestWritten bool
+}
+
+func (c *clientConn) Read(b []byte) (n int, err error) {
+	n, err = c.stream.Read(b)
+	return n, baderror.WrapQUIC(err)
+}
+
+func (c *clientConn) Write(b []byte) (n int, err error) {
+	if !c.requestWritten {
+		request := buf.NewSize(2 + addressSerializer.AddrPortLen(c.destination) + len(b))
+		request.WriteByte(Version)
+		request.WriteByte(CommandConnect)
+		addressSerializer.WriteAddrPort(request, c.destination)
+		request.Write(b)
+		_, err = c.stream.Write(request.Bytes())
+		if err != nil {
+			c.parent.closeWithError(E.Cause(err, "create new connection"))
+			return 0, baderror.WrapQUIC(err)
+		}
+		c.requestWritten = true
+		return len(b), nil
+	}
+	n, err = c.stream.Write(b)
+	return n, baderror.WrapQUIC(err)
+}
+
+func (c *clientConn) Close() error {
+	stream := c.stream
+	if stream == nil {
+		return nil
+	}
+	stream.CancelRead(0)
+	return stream.Close()
+}
+
+func (c *clientConn) LocalAddr() net.Addr {
+	return M.Socksaddr{}
+}
+
+func (c *clientConn) RemoteAddr() net.Addr {
+	return c.destination
+}
+
+func (c *clientConn) SetDeadline(t time.Time) error {
+	if c.stream == nil {
+		return os.ErrInvalid
+	}
+	return c.stream.SetDeadline(t)
+}
+
+func (c *clientConn) SetReadDeadline(t time.Time) error {
+	if c.stream == nil {
+		return os.ErrInvalid
+	}
+	return c.stream.SetReadDeadline(t)
+}
+
+func (c *clientConn) SetWriteDeadline(t time.Time) error {
+	if c.stream == nil {
+		return os.ErrInvalid
+	}
+	return c.stream.SetWriteDeadline(t)
+}

+ 110 - 0
transport/tuic/client_packet.go

@@ -0,0 +1,110 @@
+package tuic
+
+import (
+	"io"
+
+	"github.com/sagernet/quic-go"
+	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/bufio"
+	E "github.com/sagernet/sing/common/exceptions"
+)
+
+func (c *Client) loopMessages(conn *clientQUICConnection) {
+	for {
+		message, err := conn.quicConn.ReceiveMessage(c.ctx)
+		if err != nil {
+			conn.closeWithError(E.Cause(err, "receive message"))
+			return
+		}
+		go func() {
+			hErr := c.handleMessage(conn, message)
+			if hErr != nil {
+				conn.closeWithError(E.Cause(hErr, "handle message"))
+			}
+		}()
+	}
+}
+
+func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
+	if len(data) < 2 {
+		return E.New("invalid message")
+	}
+	if data[0] != Version {
+		return E.New("unknown version ", data[0])
+	}
+	switch data[1] {
+	case CommandPacket:
+		message := udpMessagePool.Get().(*udpMessage)
+		err := decodeUDPMessage(message, data[2:])
+		if err != nil {
+			message.release()
+			return E.Cause(err, "decode UDP message")
+		}
+		conn.handleUDPMessage(message)
+		return nil
+	case CommandHeartbeat:
+		return nil
+	default:
+		return E.New("unknown command ", data[0])
+	}
+}
+
+func (c *Client) loopUniStreams(conn *clientQUICConnection) {
+	for {
+		stream, err := conn.quicConn.AcceptUniStream(c.ctx)
+		if err != nil {
+			conn.closeWithError(E.Cause(err, "handle uni stream"))
+			return
+		}
+		go func() {
+			hErr := c.handleUniStream(conn, stream)
+			if hErr != nil {
+				conn.closeWithError(hErr)
+			}
+		}()
+	}
+}
+
+func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.ReceiveStream) error {
+	defer stream.CancelRead(0)
+	buffer := buf.NewPacket()
+	defer buffer.Release()
+	_, err := buffer.ReadAtLeastFrom(stream, 2)
+	if err != nil {
+		return err
+	}
+	version, _ := buffer.ReadByte()
+	if version != Version {
+		return E.New("unknown version ", version)
+	}
+	command, _ := buffer.ReadByte()
+	if command != CommandPacket {
+		return E.New("unknown command ", command)
+	}
+	reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream)
+	message := udpMessagePool.Get().(*udpMessage)
+	err = readUDPMessage(message, reader)
+	if err != nil {
+		message.release()
+		return err
+	}
+	conn.handleUDPMessage(message)
+	return nil
+}
+
+func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) {
+	c.udpAccess.RLock()
+	udpConn, loaded := c.udpConnMap[message.sessionID]
+	c.udpAccess.RUnlock()
+	if !loaded {
+		message.releaseMessage()
+		return
+	}
+	select {
+	case <-udpConn.ctx.Done():
+		message.releaseMessage()
+		return
+	default:
+	}
+	udpConn.inputPacket(message)
+}

+ 46 - 0
transport/tuic/congestion.go

@@ -0,0 +1,46 @@
+package tuic
+
+import (
+	"context"
+	"time"
+
+	"github.com/sagernet/quic-go"
+	"github.com/sagernet/sing-box/transport/tuic/congestion"
+	"github.com/sagernet/sing/common/ntp"
+)
+
+func setCongestion(ctx context.Context, connection quic.Connection, congestionName string) {
+	timeFunc := ntp.TimeFuncFromContext(ctx)
+	if timeFunc == nil {
+		timeFunc = time.Now
+	}
+	switch congestionName {
+	case "cubic":
+		connection.SetCongestionControl(
+			congestion.NewCubicSender(
+				congestion.DefaultClock{TimeFunc: timeFunc},
+				congestion.GetInitialPacketSize(connection.RemoteAddr()),
+				false,
+				nil,
+			),
+		)
+	case "new_reno":
+		connection.SetCongestionControl(
+			congestion.NewCubicSender(
+				congestion.DefaultClock{TimeFunc: timeFunc},
+				congestion.GetInitialPacketSize(connection.RemoteAddr()),
+				true,
+				nil,
+			),
+		)
+	case "bbr":
+		connection.SetCongestionControl(
+			congestion.NewBBRSender(
+				congestion.DefaultClock{},
+				congestion.GetInitialPacketSize(connection.RemoteAddr()),
+				congestion.InitialCongestionWindow*congestion.InitialMaxDatagramSize,
+				congestion.DefaultBBRMaxCongestionWindow*congestion.InitialMaxDatagramSize,
+			),
+		)
+	}
+}

+ 3 - 0
transport/tuic/congestion/README.md

@@ -0,0 +1,3 @@
+# congestion
+
+mod from https://github.com/MetaCubeX/Clash.Meta/tree/53f9e1ee7104473da2b4ff5da29965563084482d/transport/tuic/congestion

+ 25 - 0
transport/tuic/congestion/bandwidth.go

@@ -0,0 +1,25 @@
+package congestion
+
+import (
+	"math"
+	"time"
+
+	"github.com/sagernet/quic-go/congestion"
+)
+
+// Bandwidth of a connection
+type Bandwidth uint64
+
+const infBandwidth Bandwidth = math.MaxUint64
+
+const (
+	// BitsPerSecond is 1 bit per second
+	BitsPerSecond Bandwidth = 1
+	// BytesPerSecond is 1 byte per second
+	BytesPerSecond = 8 * BitsPerSecond
+)
+
+// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
+func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth {
+	return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
+}

+ 374 - 0
transport/tuic/congestion/bandwidth_sampler.go

@@ -0,0 +1,374 @@
+package congestion
+
+import (
+	"math"
+	"time"
+
+	"github.com/sagernet/quic-go/congestion"
+)
+
+var InfiniteBandwidth = Bandwidth(math.MaxUint64)
+
+// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned
+// to the caller when the packet is acked or lost.
+type SendTimeState struct {
+	// Whether other states in this object is valid.
+	isValid bool
+	// Whether the sender is app limited at the time the packet was sent.
+	// App limited bandwidth sample might be artificially low because the sender
+	// did not have enough data to send in order to saturate the link.
+	isAppLimited bool
+	// Total number of sent bytes at the time the packet was sent.
+	// Includes the packet itself.
+	totalBytesSent congestion.ByteCount
+	// Total number of acked bytes at the time the packet was sent.
+	totalBytesAcked congestion.ByteCount
+	// Total number of lost bytes at the time the packet was sent.
+	totalBytesLost congestion.ByteCount
+}
+
+// ConnectionStateOnSentPacket represents the information about a sent packet
+// and the state of the connection at the moment the packet was sent,
+// specifically the information about the most recently acknowledged packet at
+// that moment.
+type ConnectionStateOnSentPacket struct {
+	packetNumber congestion.PacketNumber
+	// Time at which the packet is sent.
+	sendTime time.Time
+	// Size of the packet.
+	size congestion.ByteCount
+	// The value of |totalBytesSentAtLastAckedPacket| at the time the
+	// packet was sent.
+	totalBytesSentAtLastAckedPacket congestion.ByteCount
+	// The value of |lastAckedPacketSentTime| at the time the packet was
+	// sent.
+	lastAckedPacketSentTime time.Time
+	// The value of |lastAckedPacketAckTime| at the time the packet was
+	// sent.
+	lastAckedPacketAckTime time.Time
+	// Send time states that are returned to the congestion controller when the
+	// packet is acked or lost.
+	sendTimeState SendTimeState
+}
+
+// BandwidthSample
+type BandwidthSample struct {
+	// The bandwidth at that particular sample. Zero if no valid bandwidth sample
+	// is available.
+	bandwidth Bandwidth
+	// The RTT measurement at this particular sample.  Zero if no RTT sample is
+	// available.  Does not correct for delayed ack time.
+	rtt time.Duration
+	// States captured when the packet was sent.
+	stateAtSend SendTimeState
+}
+
+func NewBandwidthSample() *BandwidthSample {
+	return &BandwidthSample{
+		// FIXME: the default value of original code is zero.
+		rtt: InfiniteRTT,
+	}
+}
+
+// BandwidthSampler keeps track of sent and acknowledged packets and outputs a
+// bandwidth sample for every packet acknowledged. The samples are taken for
+// individual packets, and are not filtered; the consumer has to filter the
+// bandwidth samples itself. In certain cases, the sampler will locally severely
+// underestimate the bandwidth, hence a maximum filter with a size of at least
+// one RTT is recommended.
+//
+// This class bases its samples on the slope of two curves: the number of bytes
+// sent over time, and the number of bytes acknowledged as received over time.
+// It produces a sample of both slopes for every packet that gets acknowledged,
+// based on a slope between two points on each of the corresponding curves. Note
+// that due to the packet loss, the number of bytes on each curve might get
+// further and further away from each other, meaning that it is not feasible to
+// compare byte values coming from different curves with each other.
+//
+// The obvious points for measuring slope sample are the ones corresponding to
+// the packet that was just acknowledged. Let us denote them as S_1 (point at
+// which the current packet was sent) and A_1 (point at which the current packet
+// was acknowledged). However, taking a slope requires two points on each line,
+// so estimating bandwidth requires picking a packet in the past with respect to
+// which the slope is measured.
+//
+// For that purpose, BandwidthSampler always keeps track of the most recently
+// acknowledged packet, and records it together with every outgoing packet.
+// When a packet gets acknowledged (A_1), it has not only information about when
+// it itself was sent (S_1), but also the information about the latest
+// acknowledged packet right before it was sent (S_0 and A_0).
+//
+// Based on that data, send and ack rate are estimated as:
+//
+//	send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0))
+//	ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0))
+//
+// Here, the ack rate is intuitively the rate we want to treat as bandwidth.
+// However, in certain cases (e.g. ack compression) the ack rate at a point may
+// end up higher than the rate at which the data was originally sent, which is
+// not indicative of the real bandwidth. Hence, we use the send rate as an upper
+// bound, and the sample value is
+//
+//	rate_sample = min(send_rate, ack_rate)
+//
+// An important edge case handled by the sampler is tracking the app-limited
+// samples. There are multiple meaning of "app-limited" used interchangeably,
+// hence it is important to understand and to be able to distinguish between
+// them.
+//
+// Meaning 1: connection state. The connection is said to be app-limited when
+// there is no outstanding data to send. This means that certain bandwidth
+// samples in the future would not be an accurate indication of the link
+// capacity, and it is important to inform consumer about that. Whenever
+// connection becomes app-limited, the sampler is notified via OnAppLimited()
+// method.
+//
+// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth
+// sampler becomes notified about the connection being app-limited, it enters
+// app-limited phase. In that phase, all *sent* packets are marked as
+// app-limited. Note that the connection itself does not have to be
+// app-limited during the app-limited phase, and in fact it will not be
+// (otherwise how would it send packets?). The boolean flag below indicates
+// whether the sampler is in that phase.
+//
+// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is
+// sent during the app-limited phase, the resulting sample related to the
+// packet will be marked as app-limited.
+//
+// With the terminology issue out of the way, let us consider the question of
+// what kind of situation it addresses.
+//
+// Consider a scenario where we first send packets 1 to 20 at a regular
+// bandwidth, and then immediately run out of data. After a few seconds, we send
+// packets 21 to 60, and only receive ack for 21 between sending packets 40 and
+// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0
+// we use to compute the slope is going to be packet 20, a few seconds apart
+// from the current packet, hence the resulting estimate would be extremely low
+// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21,
+// meaning that the bandwidth sample would exclude the quiescence.
+//
+// Based on the analysis of that scenario, we implement the following rule: once
+// OnAppLimited() is called, all sent packets will produce app-limited samples
+// up until an ack for a packet that was sent after OnAppLimited() was called.
+// Note that while the scenario above is not the only scenario when the
+// connection is app-limited, the approach works in other cases too.
+type BandwidthSampler struct {
+	// The total number of congestion controlled bytes sent during the connection.
+	totalBytesSent congestion.ByteCount
+	// The total number of congestion controlled bytes which were acknowledged.
+	totalBytesAcked congestion.ByteCount
+	// The total number of congestion controlled bytes which were lost.
+	totalBytesLost congestion.ByteCount
+	// The value of |totalBytesSent| at the time the last acknowledged packet
+	// was sent. Valid only when |lastAckedPacketSentTime| is valid.
+	totalBytesSentAtLastAckedPacket congestion.ByteCount
+	// The time at which the last acknowledged packet was sent. Set to
+	// QuicTime::Zero() if no valid timestamp is available.
+	lastAckedPacketSentTime time.Time
+	// The time at which the most recent packet was acknowledged.
+	lastAckedPacketAckTime time.Time
+	// The most recently sent packet.
+	lastSendPacket congestion.PacketNumber
+	// Indicates whether the bandwidth sampler is currently in an app-limited
+	// phase.
+	isAppLimited bool
+	// The packet that will be acknowledged after this one will cause the sampler
+	// to exit the app-limited phase.
+	endOfAppLimitedPhase congestion.PacketNumber
+	// Record of the connection state at the point where each packet in flight was
+	// sent, indexed by the packet number.
+	connectionStats *ConnectionStates
+}
+
+func NewBandwidthSampler() *BandwidthSampler {
+	return &BandwidthSampler{
+		connectionStats: &ConnectionStates{
+			stats: make(map[congestion.PacketNumber]*ConnectionStateOnSentPacket),
+		},
+	}
+}
+
+// OnPacketSent Inputs the sent packet information into the sampler. Assumes that all
+// packets are sent in order. The information about the packet will not be
+// released from the sampler until it the packet is either acknowledged or
+// declared lost.
+func (s *BandwidthSampler) OnPacketSent(sentTime time.Time, lastSentPacket congestion.PacketNumber, sentBytes, bytesInFlight congestion.ByteCount, hasRetransmittableData bool) {
+	s.lastSendPacket = lastSentPacket
+
+	if !hasRetransmittableData {
+		return
+	}
+
+	s.totalBytesSent += sentBytes
+
+	// If there are no packets in flight, the time at which the new transmission
+	// opens can be treated as the A_0 point for the purpose of bandwidth
+	// sampling. This underestimates bandwidth to some extent, and produces some
+	// artificially low samples for most packets in flight, but it provides with
+	// samples at important points where we would not have them otherwise, most
+	// importantly at the beginning of the connection.
+	if bytesInFlight == 0 {
+		s.lastAckedPacketAckTime = sentTime
+		s.totalBytesSentAtLastAckedPacket = s.totalBytesSent
+
+		// In this situation ack compression is not a concern, set send rate to
+		// effectively infinite.
+		s.lastAckedPacketSentTime = sentTime
+	}
+
+	s.connectionStats.Insert(lastSentPacket, sentTime, sentBytes, s)
+}
+
+// OnPacketAcked Notifies the sampler that the |lastAckedPacket| is acknowledged. Returns a
+// bandwidth sample. If no bandwidth sample is available,
+// QuicBandwidth::Zero() is returned.
+func (s *BandwidthSampler) OnPacketAcked(ackTime time.Time, lastAckedPacket congestion.PacketNumber) *BandwidthSample {
+	sentPacketState := s.connectionStats.Get(lastAckedPacket)
+	if sentPacketState == nil {
+		return NewBandwidthSample()
+	}
+
+	sample := s.onPacketAckedInner(ackTime, lastAckedPacket, sentPacketState)
+	s.connectionStats.Remove(lastAckedPacket)
+
+	return sample
+}
+
+// onPacketAckedInner Handles the actual bandwidth calculations, whereas the outer method handles
+// retrieving and removing |sentPacket|.
+func (s *BandwidthSampler) onPacketAckedInner(ackTime time.Time, lastAckedPacket congestion.PacketNumber, sentPacket *ConnectionStateOnSentPacket) *BandwidthSample {
+	s.totalBytesAcked += sentPacket.size
+
+	s.totalBytesSentAtLastAckedPacket = sentPacket.sendTimeState.totalBytesSent
+	s.lastAckedPacketSentTime = sentPacket.sendTime
+	s.lastAckedPacketAckTime = ackTime
+
+	// Exit app-limited phase once a packet that was sent while the connection is
+	// not app-limited is acknowledged.
+	if s.isAppLimited && lastAckedPacket > s.endOfAppLimitedPhase {
+		s.isAppLimited = false
+	}
+
+	// There might have been no packets acknowledged at the moment when the
+	// current packet was sent. In that case, there is no bandwidth sample to
+	// make.
+	if sentPacket.lastAckedPacketSentTime.IsZero() {
+		return NewBandwidthSample()
+	}
+
+	// Infinite rate indicates that the sampler is supposed to discard the
+	// current send rate sample and use only the ack rate.
+	sendRate := InfiniteBandwidth
+	if sentPacket.sendTime.After(sentPacket.lastAckedPacketSentTime) {
+		sendRate = BandwidthFromDelta(sentPacket.sendTimeState.totalBytesSent-sentPacket.totalBytesSentAtLastAckedPacket, sentPacket.sendTime.Sub(sentPacket.lastAckedPacketSentTime))
+	}
+
+	// During the slope calculation, ensure that ack time of the current packet is
+	// always larger than the time of the previous packet, otherwise division by
+	// zero or integer underflow can occur.
+	if !ackTime.After(sentPacket.lastAckedPacketAckTime) {
+		// TODO(wub): Compare this code count before and after fixing clock jitter
+		// issue.
+		// if sentPacket.lastAckedPacketAckTime.Equal(sentPacket.sendTime) {
+		// This is the 1st packet after quiescense.
+		// QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 1, 2);
+		// } else {
+		//   QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 2, 2);
+		// }
+
+		return NewBandwidthSample()
+	}
+
+	ackRate := BandwidthFromDelta(s.totalBytesAcked-sentPacket.sendTimeState.totalBytesAcked,
+		ackTime.Sub(sentPacket.lastAckedPacketAckTime))
+
+	// Note: this sample does not account for delayed acknowledgement time.  This
+	// means that the RTT measurements here can be artificially high, especially
+	// on low bandwidth connections.
+	sample := &BandwidthSample{
+		bandwidth: minBandwidth(sendRate, ackRate),
+		rtt:       ackTime.Sub(sentPacket.sendTime),
+	}
+
+	SentPacketToSendTimeState(sentPacket, &sample.stateAtSend)
+	return sample
+}
+
+// OnPacketLost Informs the sampler that a packet is considered lost and it should no
+// longer keep track of it.
+func (s *BandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber) SendTimeState {
+	ok, sentPacket := s.connectionStats.Remove(packetNumber)
+	sendTimeState := SendTimeState{
+		isValid: ok,
+	}
+	if sentPacket != nil {
+		s.totalBytesLost += sentPacket.size
+		SentPacketToSendTimeState(sentPacket, &sendTimeState)
+	}
+
+	return sendTimeState
+}
+
+// OnAppLimited Informs the sampler that the connection is currently app-limited, causing
+// the sampler to enter the app-limited phase.  The phase will expire by
+// itself.
+func (s *BandwidthSampler) OnAppLimited() {
+	s.isAppLimited = true
+	s.endOfAppLimitedPhase = s.lastSendPacket
+}
+
+// SentPacketToSendTimeState Copy a subset of the (private) ConnectionStateOnSentPacket to the (public)
+// SendTimeState. Always set send_time_state->is_valid to true.
+func SentPacketToSendTimeState(sentPacket *ConnectionStateOnSentPacket, sendTimeState *SendTimeState) {
+	sendTimeState.isAppLimited = sentPacket.sendTimeState.isAppLimited
+	sendTimeState.totalBytesSent = sentPacket.sendTimeState.totalBytesSent
+	sendTimeState.totalBytesAcked = sentPacket.sendTimeState.totalBytesAcked
+	sendTimeState.totalBytesLost = sentPacket.sendTimeState.totalBytesLost
+	sendTimeState.isValid = true
+}
+
+// ConnectionStates Record of the connection state at the point where each packet in flight was
+// sent, indexed by the packet number.
+// FIXME: using LinkedList replace map to fast remove all the packets lower than the specified packet number.
+type ConnectionStates struct {
+	stats map[congestion.PacketNumber]*ConnectionStateOnSentPacket
+}
+
+func (s *ConnectionStates) Insert(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) bool {
+	if _, ok := s.stats[packetNumber]; ok {
+		return false
+	}
+
+	s.stats[packetNumber] = NewConnectionStateOnSentPacket(packetNumber, sentTime, bytes, sampler)
+	return true
+}
+
+func (s *ConnectionStates) Get(packetNumber congestion.PacketNumber) *ConnectionStateOnSentPacket {
+	return s.stats[packetNumber]
+}
+
+func (s *ConnectionStates) Remove(packetNumber congestion.PacketNumber) (bool, *ConnectionStateOnSentPacket) {
+	state, ok := s.stats[packetNumber]
+	if ok {
+		delete(s.stats, packetNumber)
+	}
+	return ok, state
+}
+
+func NewConnectionStateOnSentPacket(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) *ConnectionStateOnSentPacket {
+	return &ConnectionStateOnSentPacket{
+		packetNumber:                    packetNumber,
+		sendTime:                        sentTime,
+		size:                            bytes,
+		lastAckedPacketSentTime:         sampler.lastAckedPacketSentTime,
+		lastAckedPacketAckTime:          sampler.lastAckedPacketAckTime,
+		totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket,
+		sendTimeState: SendTimeState{
+			isValid:         true,
+			isAppLimited:    sampler.isAppLimited,
+			totalBytesSent:  sampler.totalBytesSent,
+			totalBytesAcked: sampler.totalBytesAcked,
+			totalBytesLost:  sampler.totalBytesLost,
+		},
+	}
+}

+ 1000 - 0
transport/tuic/congestion/bbr_sender.go

@@ -0,0 +1,1000 @@
+package congestion
+
+// src from https://quiche.googlesource.com/quiche.git/+/66dea072431f94095dfc3dd2743cb94ef365f7ef/quic/core/congestion_control/bbr_sender.cc
+
+import (
+	"fmt"
+	"math"
+	"math/rand"
+	"net"
+	"time"
+
+	"github.com/sagernet/quic-go/congestion"
+)
+
+const (
+	// InitialMaxDatagramSize is the default maximum packet size used in QUIC for congestion window computations in bytes.
+	InitialMaxDatagramSize        = 1252
+	InitialPacketSizeIPv4         = 1252
+	InitialPacketSizeIPv6         = 1232
+	InitialCongestionWindow       = 32
+	DefaultBBRMaxCongestionWindow = 10000
+)
+
+func GetInitialPacketSize(addr net.Addr) congestion.ByteCount {
+	maxSize := congestion.ByteCount(1200)
+	// If this is not a UDP address, we don't know anything about the MTU.
+	// Use the minimum size of an Initial packet as the max packet size.
+	if udpAddr, ok := addr.(*net.UDPAddr); ok {
+		if udpAddr.IP.To4() != nil {
+			maxSize = InitialPacketSizeIPv4
+		} else {
+			maxSize = InitialPacketSizeIPv6
+		}
+	}
+	return congestion.ByteCount(maxSize)
+}
+
+var (
+
+	// Default initial rtt used before any samples are received.
+	InitialRtt = 100 * time.Millisecond
+
+	// The gain used for the STARTUP, equal to  4*ln(2).
+	DefaultHighGain = 2.77
+
+	// The gain used in STARTUP after loss has been detected.
+	// 1.5 is enough to allow for 25% exogenous loss and still observe a 25% growth
+	// in measured bandwidth.
+	StartupAfterLossGain = 1.5
+
+	// The cycle of gains used during the PROBE_BW stage.
+	PacingGain = []float64{1.25, 0.75, 1, 1, 1, 1, 1, 1}
+
+	// The length of the gain cycle.
+	GainCycleLength = len(PacingGain)
+
+	// The size of the bandwidth filter window, in round-trips.
+	BandwidthWindowSize = GainCycleLength + 2
+
+	// The time after which the current min_rtt value expires.
+	MinRttExpiry = 10 * time.Second
+
+	// The minimum time the connection can spend in PROBE_RTT mode.
+	ProbeRttTime = time.Millisecond * 200
+
+	// If the bandwidth does not increase by the factor of |kStartupGrowthTarget|
+	// within |kRoundTripsWithoutGrowthBeforeExitingStartup| rounds, the connection
+	// will exit the STARTUP mode.
+	StartupGrowthTarget                         = 1.25
+	RoundTripsWithoutGrowthBeforeExitingStartup = int64(3)
+
+	// Coefficient of target congestion window to use when basing PROBE_RTT on BDP.
+	ModerateProbeRttMultiplier = 0.75
+
+	// Coefficient to determine if a new RTT is sufficiently similar to min_rtt that
+	// we don't need to enter PROBE_RTT.
+	SimilarMinRttThreshold = 1.125
+
+	// Congestion window gain for QUIC BBR during PROBE_BW phase.
+	DefaultCongestionWindowGainConst = 2.0
+)
+
+type bbrMode int
+
+const (
+	// Startup phase of the connection.
+	STARTUP = iota
+	// After achieving the highest possible bandwidth during the startup, lower
+	// the pacing rate in order to drain the queue.
+	DRAIN
+	// Cruising mode.
+	PROBE_BW
+	// Temporarily slow down sending in order to empty the buffer and measure
+	// the real minimum RTT.
+	PROBE_RTT
+)
+
+type bbrRecoveryState int
+
+const (
+	// Do not limit.
+	NOT_IN_RECOVERY = iota
+
+	// Allow an extra outstanding byte for each byte acknowledged.
+	CONSERVATION
+
+	// Allow two extra outstanding bytes for each byte acknowledged (slow
+	// start).
+	GROWTH
+)
+
+type bbrSender struct {
+	mode          bbrMode
+	clock         Clock
+	rttStats      congestion.RTTStatsProvider
+	bytesInFlight congestion.ByteCount
+	// return total bytes of unacked packets.
+	// GetBytesInFlight func() congestion.ByteCount
+	// Bandwidth sampler provides BBR with the bandwidth measurements at
+	// individual points.
+	sampler *BandwidthSampler
+	// The number of the round trips that have occurred during the connection.
+	roundTripCount int64
+	// The packet number of the most recently sent packet.
+	lastSendPacket congestion.PacketNumber
+	// Acknowledgement of any packet after |current_round_trip_end_| will cause
+	// the round trip counter to advance.
+	currentRoundTripEnd congestion.PacketNumber
+	// The filter that tracks the maximum bandwidth over the multiple recent
+	// round-trips.
+	maxBandwidth *WindowedFilter
+	// Tracks the maximum number of bytes acked faster than the sending rate.
+	maxAckHeight *WindowedFilter
+	// The time this aggregation started and the number of bytes acked during it.
+	aggregationEpochStartTime time.Time
+	aggregationEpochBytes     congestion.ByteCount
+	// Minimum RTT estimate.  Automatically expires within 10 seconds (and
+	// triggers PROBE_RTT mode) if no new value is sampled during that period.
+	minRtt time.Duration
+	// The time at which the current value of |min_rtt_| was assigned.
+	minRttTimestamp time.Time
+	// The maximum allowed number of bytes in flight.
+	congestionWindow congestion.ByteCount
+	// The initial value of the |congestion_window_|.
+	initialCongestionWindow congestion.ByteCount
+	// The largest value the |congestion_window_| can achieve.
+	initialMaxCongestionWindow congestion.ByteCount
+	// The smallest value the |congestion_window_| can achieve.
+	// minCongestionWindow congestion.ByteCount
+	// The pacing gain applied during the STARTUP phase.
+	highGain float64
+	// The CWND gain applied during the STARTUP phase.
+	highCwndGain float64
+	// The pacing gain applied during the DRAIN phase.
+	drainGain float64
+	// The current pacing rate of the connection.
+	pacingRate Bandwidth
+	// The gain currently applied to the pacing rate.
+	pacingGain float64
+	// The gain currently applied to the congestion window.
+	congestionWindowGain float64
+	// The gain used for the congestion window during PROBE_BW.  Latched from
+	// quic_bbr_cwnd_gain flag.
+	congestionWindowGainConst float64
+	// The number of RTTs to stay in STARTUP mode.  Defaults to 3.
+	numStartupRtts int64
+	// If true, exit startup if 1RTT has passed with no bandwidth increase and
+	// the connection is in recovery.
+	exitStartupOnLoss bool
+	// Number of round-trips in PROBE_BW mode, used for determining the current
+	// pacing gain cycle.
+	cycleCurrentOffset int
+	// The time at which the last pacing gain cycle was started.
+	lastCycleStart time.Time
+	// Indicates whether the connection has reached the full bandwidth mode.
+	isAtFullBandwidth bool
+	// Number of rounds during which there was no significant bandwidth increase.
+	roundsWithoutBandwidthGain int64
+	// The bandwidth compared to which the increase is measured.
+	bandwidthAtLastRound Bandwidth
+	// Set to true upon exiting quiescence.
+	exitingQuiescence bool
+	// Time at which PROBE_RTT has to be exited.  Setting it to zero indicates
+	// that the time is yet unknown as the number of packets in flight has not
+	// reached the required value.
+	exitProbeRttAt time.Time
+	// Indicates whether a round-trip has passed since PROBE_RTT became active.
+	probeRttRoundPassed bool
+	// Indicates whether the most recent bandwidth sample was marked as
+	// app-limited.
+	lastSampleIsAppLimited bool
+	// Indicates whether any non app-limited samples have been recorded.
+	hasNoAppLimitedSample bool
+	// Indicates app-limited calls should be ignored as long as there's
+	// enough data inflight to see more bandwidth when necessary.
+	flexibleAppLimited bool
+	// Current state of recovery.
+	recoveryState bbrRecoveryState
+	// Receiving acknowledgement of a packet after |end_recovery_at_| will cause
+	// BBR to exit the recovery mode.  A value above zero indicates at least one
+	// loss has been detected, so it must not be set back to zero.
+	endRecoveryAt congestion.PacketNumber
+	// A window used to limit the number of bytes in flight during loss recovery.
+	recoveryWindow congestion.ByteCount
+	// If true, consider all samples in recovery app-limited.
+	isAppLimitedRecovery bool
+	// When true, pace at 1.5x and disable packet conservation in STARTUP.
+	slowerStartup bool
+	// When true, disables packet conservation in STARTUP.
+	rateBasedStartup bool
+	// When non-zero, decreases the rate in STARTUP by the total number of bytes
+	// lost in STARTUP divided by CWND.
+	startupRateReductionMultiplier int64
+	// Sum of bytes lost in STARTUP.
+	startupBytesLost congestion.ByteCount
+	// When true, add the most recent ack aggregation measurement during STARTUP.
+	enableAckAggregationDuringStartup bool
+	// When true, expire the windowed ack aggregation values in STARTUP when
+	// bandwidth increases more than 25%.
+	expireAckAggregationInStartup bool
+	// If true, will not exit low gain mode until bytes_in_flight drops below BDP
+	// or it's time for high gain mode.
+	drainToTarget bool
+	// If true, use a CWND of 0.75*BDP during probe_rtt instead of 4 packets.
+	probeRttBasedOnBdp bool
+	// If true, skip probe_rtt and update the timestamp of the existing min_rtt to
+	// now if min_rtt over the last cycle is within 12.5% of the current min_rtt.
+	// Even if the min_rtt is 12.5% too low, the 25% gain cycling and 2x CWND gain
+	// should overcome an overly small min_rtt.
+	probeRttSkippedIfSimilarRtt bool
+	// If true, disable PROBE_RTT entirely as long as the connection was recently
+	// app limited.
+	probeRttDisabledIfAppLimited bool
+	appLimitedSinceLastProbeRtt  bool
+	minRttSinceLastProbeRtt      time.Duration
+	// Latched value of --quic_always_get_bw_sample_when_acked.
+	alwaysGetBwSampleWhenAcked bool
+
+	pacer *pacer
+
+	maxDatagramSize congestion.ByteCount
+}
+
+func NewBBRSender(
+	clock Clock,
+	initialMaxDatagramSize,
+	initialCongestionWindow,
+	initialMaxCongestionWindow congestion.ByteCount,
+) *bbrSender {
+	b := &bbrSender{
+		mode:                      STARTUP,
+		clock:                     clock,
+		sampler:                   NewBandwidthSampler(),
+		maxBandwidth:              NewWindowedFilter(int64(BandwidthWindowSize), MaxFilter),
+		maxAckHeight:              NewWindowedFilter(int64(BandwidthWindowSize), MaxFilter),
+		congestionWindow:          initialCongestionWindow,
+		initialCongestionWindow:   initialCongestionWindow,
+		highGain:                  DefaultHighGain,
+		highCwndGain:              DefaultHighGain,
+		drainGain:                 1.0 / DefaultHighGain,
+		pacingGain:                1.0,
+		congestionWindowGain:      1.0,
+		congestionWindowGainConst: DefaultCongestionWindowGainConst,
+		numStartupRtts:            RoundTripsWithoutGrowthBeforeExitingStartup,
+		recoveryState:             NOT_IN_RECOVERY,
+		recoveryWindow:            initialMaxCongestionWindow,
+		minRttSinceLastProbeRtt:   InfiniteRTT,
+		maxDatagramSize:           initialMaxDatagramSize,
+	}
+	b.pacer = newPacer(b.BandwidthEstimate)
+	return b
+}
+
+func (b *bbrSender) maxCongestionWindow() congestion.ByteCount {
+	return b.maxDatagramSize * DefaultBBRMaxCongestionWindow
+}
+
+func (b *bbrSender) minCongestionWindow() congestion.ByteCount {
+	return b.maxDatagramSize * b.initialCongestionWindow
+}
+
+func (b *bbrSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
+	b.rttStats = provider
+}
+
+func (b *bbrSender) GetBytesInFlight() congestion.ByteCount {
+	return b.bytesInFlight
+}
+
+// TimeUntilSend returns when the next packet should be sent.
+func (b *bbrSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
+	b.bytesInFlight = bytesInFlight
+	return b.pacer.TimeUntilSend()
+}
+
+func (b *bbrSender) HasPacingBudget(now time.Time) bool {
+	return b.pacer.Budget(now) >= b.maxDatagramSize
+}
+
+func (b *bbrSender) SetMaxDatagramSize(s congestion.ByteCount) {
+	if s < b.maxDatagramSize {
+		panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", b.maxDatagramSize, s))
+	}
+	cwndIsMinCwnd := b.congestionWindow == b.minCongestionWindow()
+	b.maxDatagramSize = s
+	if cwndIsMinCwnd {
+		b.congestionWindow = b.minCongestionWindow()
+	}
+	b.pacer.SetMaxDatagramSize(s)
+}
+
+func (b *bbrSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool) {
+	b.pacer.SentPacket(sentTime, bytes)
+	b.lastSendPacket = packetNumber
+
+	b.bytesInFlight = bytesInFlight
+	if bytesInFlight == 0 && b.sampler.isAppLimited {
+		b.exitingQuiescence = true
+	}
+
+	if b.aggregationEpochStartTime.IsZero() {
+		b.aggregationEpochStartTime = sentTime
+	}
+
+	b.sampler.OnPacketSent(sentTime, packetNumber, bytes, bytesInFlight, isRetransmittable)
+}
+
+func (b *bbrSender) CanSend(bytesInFlight congestion.ByteCount) bool {
+	b.bytesInFlight = bytesInFlight
+	return bytesInFlight < b.GetCongestionWindow()
+}
+
+func (b *bbrSender) GetCongestionWindow() congestion.ByteCount {
+	if b.mode == PROBE_RTT {
+		return b.ProbeRttCongestionWindow()
+	}
+
+	if b.InRecovery() && !(b.rateBasedStartup && b.mode == STARTUP) {
+		return minByteCount(b.congestionWindow, b.recoveryWindow)
+	}
+
+	return b.congestionWindow
+}
+
+func (b *bbrSender) MaybeExitSlowStart() {
+}
+
+func (b *bbrSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, priorInFlight congestion.ByteCount, eventTime time.Time) {
+	totalBytesAckedBefore := b.sampler.totalBytesAcked
+	isRoundStart, minRttExpired := false, false
+	lastAckedPacket := number
+
+	isRoundStart = b.UpdateRoundTripCounter(lastAckedPacket)
+	minRttExpired = b.UpdateBandwidthAndMinRtt(eventTime, number, ackedBytes)
+	b.UpdateRecoveryState(false, isRoundStart)
+	bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore
+	excessAcked := b.UpdateAckAggregationBytes(eventTime, bytesAcked)
+
+	// Handle logic specific to STARTUP and DRAIN modes.
+	if isRoundStart && !b.isAtFullBandwidth {
+		b.CheckIfFullBandwidthReached()
+	}
+	b.MaybeExitStartupOrDrain(eventTime)
+
+	// Handle logic specific to PROBE_RTT.
+	b.MaybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired)
+
+	// After the model is updated, recalculate the pacing rate and congestion
+	// window.
+	b.CalculatePacingRate()
+	b.CalculateCongestionWindow(bytesAcked, excessAcked)
+	b.CalculateRecoveryWindow(bytesAcked, congestion.ByteCount(0))
+}
+
+func (b *bbrSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount, priorInFlight congestion.ByteCount) {
+	eventTime := time.Now()
+	totalBytesAckedBefore := b.sampler.totalBytesAcked
+	isRoundStart, minRttExpired := false, false
+
+	b.DiscardLostPackets(number, lostBytes)
+
+	// Input the new data into the BBR model of the connection.
+	var excessAcked congestion.ByteCount
+
+	// Handle logic specific to PROBE_BW mode.
+	if b.mode == PROBE_BW {
+		b.UpdateGainCyclePhase(time.Now(), priorInFlight, true)
+	}
+
+	// Handle logic specific to STARTUP and DRAIN modes.
+	b.MaybeExitStartupOrDrain(eventTime)
+
+	// Handle logic specific to PROBE_RTT.
+	b.MaybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired)
+
+	// Calculate number of packets acked and lost.
+	bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore
+	bytesLost := lostBytes
+
+	// After the model is updated, recalculate the pacing rate and congestion
+	// window.
+	b.CalculatePacingRate()
+	b.CalculateCongestionWindow(bytesAcked, excessAcked)
+	b.CalculateRecoveryWindow(bytesAcked, bytesLost)
+}
+
+//func (b *bbrSender) OnCongestionEvent(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets, lostPackets []*congestion.Packet) {
+//	totalBytesAckedBefore := b.sampler.totalBytesAcked
+//	isRoundStart, minRttExpired := false, false
+//
+//	if lostPackets != nil {
+//		b.DiscardLostPackets(lostPackets)
+//	}
+//
+//	// Input the new data into the BBR model of the connection.
+//	var excessAcked congestion.ByteCount
+//	if len(ackedPackets) > 0 {
+//		lastAckedPacket := ackedPackets[len(ackedPackets)-1].PacketNumber
+//		isRoundStart = b.UpdateRoundTripCounter(lastAckedPacket)
+//		minRttExpired = b.UpdateBandwidthAndMinRtt(eventTime, ackedPackets)
+//		b.UpdateRecoveryState(lastAckedPacket, len(lostPackets) > 0, isRoundStart)
+//		bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore
+//		excessAcked = b.UpdateAckAggregationBytes(eventTime, bytesAcked)
+//	}
+//
+//	// Handle logic specific to PROBE_BW mode.
+//	if b.mode == PROBE_BW {
+//		b.UpdateGainCyclePhase(eventTime, priorInFlight, len(lostPackets) > 0)
+//	}
+//
+//	// Handle logic specific to STARTUP and DRAIN modes.
+//	if isRoundStart && !b.isAtFullBandwidth {
+//		b.CheckIfFullBandwidthReached()
+//	}
+//	b.MaybeExitStartupOrDrain(eventTime)
+//
+//	// Handle logic specific to PROBE_RTT.
+//	b.MaybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired)
+//
+//	// Calculate number of packets acked and lost.
+//	bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore
+//	bytesLost := congestion.ByteCount(0)
+//	for _, packet := range lostPackets {
+//		bytesLost += packet.Length
+//	}
+//
+//	// After the model is updated, recalculate the pacing rate and congestion
+//	// window.
+//	b.CalculatePacingRate()
+//	b.CalculateCongestionWindow(bytesAcked, excessAcked)
+//	b.CalculateRecoveryWindow(bytesAcked, bytesLost)
+//}
+
+//func (b *bbrSender) SetNumEmulatedConnections(n int) {
+//
+//}
+
+func (b *bbrSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
+}
+
+//func (b *bbrSender) OnConnectionMigration() {
+//
+//}
+
+//// Experiments
+//func (b *bbrSender) SetSlowStartLargeReduction(enabled bool) {
+//
+//}
+
+//func (b *bbrSender) BandwidthEstimate() Bandwidth {
+//	return Bandwidth(b.maxBandwidth.GetBest())
+//}
+
+// BandwidthEstimate returns the current bandwidth estimate
+func (b *bbrSender) BandwidthEstimate() Bandwidth {
+	if b.rttStats == nil {
+		return infBandwidth
+	}
+	srtt := b.rttStats.SmoothedRTT()
+	if srtt == 0 {
+		// If we haven't measured an rtt, the bandwidth estimate is unknown.
+		return infBandwidth
+	}
+	return BandwidthFromDelta(b.GetCongestionWindow(), srtt)
+}
+
+//func (b *bbrSender) HybridSlowStart() *HybridSlowStart {
+//	return nil
+//}
+
+//func (b *bbrSender) SlowstartThreshold() congestion.ByteCount {
+//	return 0
+//}
+
+//func (b *bbrSender) RenoBeta() float32 {
+//	return 0.0
+//}
+
+func (b *bbrSender) InRecovery() bool {
+	return b.recoveryState != NOT_IN_RECOVERY
+}
+
+func (b *bbrSender) InSlowStart() bool {
+	return b.mode == STARTUP
+}
+
+//func (b *bbrSender) ShouldSendProbingPacket() bool {
+//	if b.pacingGain <= 1 {
+//		return false
+//	}
+//	// TODO(b/77975811): If the pipe is highly under-utilized, consider not
+//	// sending a probing transmission, because the extra bandwidth is not needed.
+//	// If flexible_app_limited is enabled, check if the pipe is sufficiently full.
+//	if b.flexibleAppLimited {
+//		return !b.IsPipeSufficientlyFull()
+//	} else {
+//		return true
+//	}
+//}
+
+//func (b *bbrSender) IsPipeSufficientlyFull() bool {
+//	// See if we need more bytes in flight to see more bandwidth.
+//	if b.mode == STARTUP {
+//		// STARTUP exits if it doesn't observe a 25% bandwidth increase, so the CWND
+//		// must be more than 25% above the target.
+//		return b.GetBytesInFlight() >= b.GetTargetCongestionWindow(1.5)
+//	}
+//	if b.pacingGain > 1 {
+//		// Super-unity PROBE_BW doesn't exit until 1.25 * BDP is achieved.
+//		return b.GetBytesInFlight() >= b.GetTargetCongestionWindow(b.pacingGain)
+//	}
+//	// If bytes_in_flight are above the target congestion window, it should be
+//	// possible to observe the same or more bandwidth if it's available.
+//	return b.GetBytesInFlight() >= b.GetTargetCongestionWindow(1.1)
+//}
+
+//func (b *bbrSender) SetFromConfig() {
+//	// TODO: not impl.
+//}
+
+func (b *bbrSender) UpdateRoundTripCounter(lastAckedPacket congestion.PacketNumber) bool {
+	if b.currentRoundTripEnd == 0 || lastAckedPacket > b.currentRoundTripEnd {
+		b.currentRoundTripEnd = lastAckedPacket
+		b.roundTripCount++
+		// if b.rttStats != nil && b.InSlowStart() {
+		// TODO: ++stats_->slowstart_num_rtts;
+		// }
+		return true
+	}
+	return false
+}
+
+func (b *bbrSender) UpdateBandwidthAndMinRtt(now time.Time, number congestion.PacketNumber, ackedBytes congestion.ByteCount) bool {
+	sampleMinRtt := InfiniteRTT
+
+	if !b.alwaysGetBwSampleWhenAcked && ackedBytes == 0 {
+		// Skip acked packets with 0 in flight bytes when updating bandwidth.
+		return false
+	}
+	bandwidthSample := b.sampler.OnPacketAcked(now, number)
+	if b.alwaysGetBwSampleWhenAcked && !bandwidthSample.stateAtSend.isValid {
+		// From the sampler's perspective, the packet has never been sent, or the
+		// packet has been acked or marked as lost previously.
+		return false
+	}
+	b.lastSampleIsAppLimited = bandwidthSample.stateAtSend.isAppLimited
+	//     has_non_app_limited_sample_ |=
+	//        !bandwidth_sample.state_at_send.is_app_limited;
+	if !bandwidthSample.stateAtSend.isAppLimited {
+		b.hasNoAppLimitedSample = true
+	}
+	if bandwidthSample.rtt > 0 {
+		sampleMinRtt = minRtt(sampleMinRtt, bandwidthSample.rtt)
+	}
+	if !bandwidthSample.stateAtSend.isAppLimited || bandwidthSample.bandwidth > b.BandwidthEstimate() {
+		b.maxBandwidth.Update(int64(bandwidthSample.bandwidth), b.roundTripCount)
+	}
+
+	// If none of the RTT samples are valid, return immediately.
+	if sampleMinRtt == InfiniteRTT {
+		return false
+	}
+
+	b.minRttSinceLastProbeRtt = minRtt(b.minRttSinceLastProbeRtt, sampleMinRtt)
+	// Do not expire min_rtt if none was ever available.
+	minRttExpired := b.minRtt > 0 && (now.After(b.minRttTimestamp.Add(MinRttExpiry)))
+	if minRttExpired || sampleMinRtt < b.minRtt || b.minRtt == 0 {
+		if minRttExpired && b.ShouldExtendMinRttExpiry() {
+			minRttExpired = false
+		} else {
+			b.minRtt = sampleMinRtt
+		}
+		b.minRttTimestamp = now
+		// Reset since_last_probe_rtt fields.
+		b.minRttSinceLastProbeRtt = InfiniteRTT
+		b.appLimitedSinceLastProbeRtt = false
+	}
+
+	return minRttExpired
+}
+
+func (b *bbrSender) ShouldExtendMinRttExpiry() bool {
+	if b.probeRttDisabledIfAppLimited && b.appLimitedSinceLastProbeRtt {
+		// Extend the current min_rtt if we've been app limited recently.
+		return true
+	}
+
+	minRttIncreasedSinceLastProbe := b.minRttSinceLastProbeRtt > time.Duration(float64(b.minRtt)*SimilarMinRttThreshold)
+	if b.probeRttSkippedIfSimilarRtt && b.appLimitedSinceLastProbeRtt && !minRttIncreasedSinceLastProbe {
+		// Extend the current min_rtt if we've been app limited recently and an rtt
+		// has been measured in that time that's less than 12.5% more than the
+		// current min_rtt.
+		return true
+	}
+
+	return false
+}
+
+func (b *bbrSender) DiscardLostPackets(number congestion.PacketNumber, lostBytes congestion.ByteCount) {
+	b.sampler.OnPacketLost(number)
+	if b.mode == STARTUP {
+		// if b.rttStats != nil {
+		// TODO: slow start.
+		// }
+		if b.startupRateReductionMultiplier != 0 {
+			b.startupBytesLost += lostBytes
+		}
+	}
+}
+
+func (b *bbrSender) UpdateRecoveryState(hasLosses, isRoundStart bool) {
+	// Exit recovery when there are no losses for a round.
+	if !hasLosses {
+		b.endRecoveryAt = b.lastSendPacket
+	}
+	switch b.recoveryState {
+	case NOT_IN_RECOVERY:
+		// Enter conservation on the first loss.
+		if hasLosses {
+			b.recoveryState = CONSERVATION
+			// This will cause the |recovery_window_| to be set to the correct
+			// value in CalculateRecoveryWindow().
+			b.recoveryWindow = 0
+			// Since the conservation phase is meant to be lasting for a whole
+			// round, extend the current round as if it were started right now.
+			b.currentRoundTripEnd = b.lastSendPacket
+			if false && b.lastSampleIsAppLimited {
+				b.isAppLimitedRecovery = true
+			}
+		}
+	case CONSERVATION:
+		if isRoundStart {
+			b.recoveryState = GROWTH
+		}
+		fallthrough
+	case GROWTH:
+		// Exit recovery if appropriate.
+		if !hasLosses && b.lastSendPacket > b.endRecoveryAt {
+			b.recoveryState = NOT_IN_RECOVERY
+			b.isAppLimitedRecovery = false
+		}
+	}
+
+	if b.recoveryState != NOT_IN_RECOVERY && b.isAppLimitedRecovery {
+		b.sampler.OnAppLimited()
+	}
+}
+
+func (b *bbrSender) UpdateAckAggregationBytes(ackTime time.Time, ackedBytes congestion.ByteCount) congestion.ByteCount {
+	// Compute how many bytes are expected to be delivered, assuming max bandwidth
+	// is correct.
+	expectedAckedBytes := congestion.ByteCount(b.maxBandwidth.GetBest()) *
+		congestion.ByteCount((ackTime.Sub(b.aggregationEpochStartTime)))
+	// Reset the current aggregation epoch as soon as the ack arrival rate is less
+	// than or equal to the max bandwidth.
+	if b.aggregationEpochBytes <= expectedAckedBytes {
+		// Reset to start measuring a new aggregation epoch.
+		b.aggregationEpochBytes = ackedBytes
+		b.aggregationEpochStartTime = ackTime
+		return 0
+	}
+	// Compute how many extra bytes were delivered vs max bandwidth.
+	// Include the bytes most recently acknowledged to account for stretch acks.
+	b.aggregationEpochBytes += ackedBytes
+	b.maxAckHeight.Update(int64(b.aggregationEpochBytes-expectedAckedBytes), b.roundTripCount)
+	return b.aggregationEpochBytes - expectedAckedBytes
+}
+
+func (b *bbrSender) UpdateGainCyclePhase(now time.Time, priorInFlight congestion.ByteCount, hasLosses bool) {
+	bytesInFlight := b.GetBytesInFlight()
+	// In most cases, the cycle is advanced after an RTT passes.
+	shouldAdvanceGainCycling := now.Sub(b.lastCycleStart) > b.GetMinRtt()
+
+	// If the pacing gain is above 1.0, the connection is trying to probe the
+	// bandwidth by increasing the number of bytes in flight to at least
+	// pacing_gain * BDP.  Make sure that it actually reaches the target, as long
+	// as there are no losses suggesting that the buffers are not able to hold
+	// that much.
+	if b.pacingGain > 1.0 && !hasLosses && priorInFlight < b.GetTargetCongestionWindow(b.pacingGain) {
+		shouldAdvanceGainCycling = false
+	}
+	// If pacing gain is below 1.0, the connection is trying to drain the extra
+	// queue which could have been incurred by probing prior to it.  If the number
+	// of bytes in flight falls down to the estimated BDP value earlier, conclude
+	// that the queue has been successfully drained and exit this cycle early.
+	if b.pacingGain < 1.0 && bytesInFlight <= b.GetTargetCongestionWindow(1.0) {
+		shouldAdvanceGainCycling = true
+	}
+
+	if shouldAdvanceGainCycling {
+		b.cycleCurrentOffset = (b.cycleCurrentOffset + 1) % GainCycleLength
+		b.lastCycleStart = now
+		// Stay in low gain mode until the target BDP is hit.
+		// Low gain mode will be exited immediately when the target BDP is achieved.
+		if b.drainToTarget && b.pacingGain < 1.0 && PacingGain[b.cycleCurrentOffset] == 1.0 &&
+			bytesInFlight > b.GetTargetCongestionWindow(1.0) {
+			return
+		}
+		b.pacingGain = PacingGain[b.cycleCurrentOffset]
+	}
+}
+
+func (b *bbrSender) GetTargetCongestionWindow(gain float64) congestion.ByteCount {
+	bdp := congestion.ByteCount(b.GetMinRtt()) * congestion.ByteCount(b.BandwidthEstimate())
+	congestionWindow := congestion.ByteCount(gain * float64(bdp))
+
+	// BDP estimate will be zero if no bandwidth samples are available yet.
+	if congestionWindow == 0 {
+		congestionWindow = congestion.ByteCount(gain * float64(b.initialCongestionWindow))
+	}
+
+	return maxByteCount(congestionWindow, b.minCongestionWindow())
+}
+
+func (b *bbrSender) CheckIfFullBandwidthReached() {
+	if b.lastSampleIsAppLimited {
+		return
+	}
+
+	target := Bandwidth(float64(b.bandwidthAtLastRound) * StartupGrowthTarget)
+	if b.BandwidthEstimate() >= target {
+		b.bandwidthAtLastRound = b.BandwidthEstimate()
+		b.roundsWithoutBandwidthGain = 0
+		if b.expireAckAggregationInStartup {
+			// Expire old excess delivery measurements now that bandwidth increased.
+			b.maxAckHeight.Reset(0, b.roundTripCount)
+		}
+		return
+	}
+	b.roundsWithoutBandwidthGain++
+	if b.roundsWithoutBandwidthGain >= b.numStartupRtts || (b.exitStartupOnLoss && b.InRecovery()) {
+		b.isAtFullBandwidth = true
+	}
+}
+
+func (b *bbrSender) MaybeExitStartupOrDrain(now time.Time) {
+	if b.mode == STARTUP && b.isAtFullBandwidth {
+		b.OnExitStartup(now)
+		b.mode = DRAIN
+		b.pacingGain = b.drainGain
+		b.congestionWindowGain = b.highCwndGain
+	}
+	if b.mode == DRAIN && b.GetBytesInFlight() <= b.GetTargetCongestionWindow(1) {
+		b.EnterProbeBandwidthMode(now)
+	}
+}
+
+func (b *bbrSender) EnterProbeBandwidthMode(now time.Time) {
+	b.mode = PROBE_BW
+	b.congestionWindowGain = b.congestionWindowGainConst
+
+	// Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is
+	// excluded because in that case increased gain and decreased gain would not
+	// follow each other.
+	b.cycleCurrentOffset = rand.Int() % (GainCycleLength - 1)
+	if b.cycleCurrentOffset >= 1 {
+		b.cycleCurrentOffset += 1
+	}
+
+	b.lastCycleStart = now
+	b.pacingGain = PacingGain[b.cycleCurrentOffset]
+}
+
+func (b *bbrSender) MaybeEnterOrExitProbeRtt(now time.Time, isRoundStart, minRttExpired bool) {
+	if minRttExpired && !b.exitingQuiescence && b.mode != PROBE_RTT {
+		if b.InSlowStart() {
+			b.OnExitStartup(now)
+		}
+		b.mode = PROBE_RTT
+		b.pacingGain = 1.0
+		// Do not decide on the time to exit PROBE_RTT until the |bytes_in_flight|
+		// is at the target small value.
+		b.exitProbeRttAt = time.Time{}
+	}
+
+	if b.mode == PROBE_RTT {
+		b.sampler.OnAppLimited()
+		if b.exitProbeRttAt.IsZero() {
+			// If the window has reached the appropriate size, schedule exiting
+			// PROBE_RTT.  The CWND during PROBE_RTT is kMinimumCongestionWindow, but
+			// we allow an extra packet since QUIC checks CWND before sending a
+			// packet.
+			if b.GetBytesInFlight() < b.ProbeRttCongestionWindow()+b.maxDatagramSize {
+				b.exitProbeRttAt = now.Add(ProbeRttTime)
+				b.probeRttRoundPassed = false
+			}
+		} else {
+			if isRoundStart {
+				b.probeRttRoundPassed = true
+			}
+			if !now.Before(b.exitProbeRttAt) && b.probeRttRoundPassed {
+				b.minRttTimestamp = now
+				if !b.isAtFullBandwidth {
+					b.EnterStartupMode(now)
+				} else {
+					b.EnterProbeBandwidthMode(now)
+				}
+			}
+		}
+	}
+	b.exitingQuiescence = false
+}
+
+func (b *bbrSender) ProbeRttCongestionWindow() congestion.ByteCount {
+	if b.probeRttBasedOnBdp {
+		return b.GetTargetCongestionWindow(ModerateProbeRttMultiplier)
+	} else {
+		return b.minCongestionWindow()
+	}
+}
+
+func (b *bbrSender) EnterStartupMode(now time.Time) {
+	// if b.rttStats != nil {
+	// TODO: slow start.
+	// }
+	b.mode = STARTUP
+	b.pacingGain = b.highGain
+	b.congestionWindowGain = b.highCwndGain
+}
+
+func (b *bbrSender) OnExitStartup(now time.Time) {
+	if b.rttStats == nil {
+		return
+	}
+	// TODO: slow start.
+}
+
+func (b *bbrSender) CalculatePacingRate() {
+	if b.BandwidthEstimate() == 0 {
+		return
+	}
+
+	targetRate := Bandwidth(b.pacingGain * float64(b.BandwidthEstimate()))
+	if b.isAtFullBandwidth {
+		b.pacingRate = targetRate
+		return
+	}
+
+	// Pace at the rate of initial_window / RTT as soon as RTT measurements are
+	// available.
+	if b.pacingRate == 0 && b.rttStats.MinRTT() > 0 {
+		b.pacingRate = BandwidthFromDelta(b.initialCongestionWindow, b.rttStats.MinRTT())
+		return
+	}
+	// Slow the pacing rate in STARTUP once loss has ever been detected.
+	hasEverDetectedLoss := b.endRecoveryAt > 0
+	if b.slowerStartup && hasEverDetectedLoss && b.hasNoAppLimitedSample {
+		b.pacingRate = Bandwidth(StartupAfterLossGain * float64(b.BandwidthEstimate()))
+		return
+	}
+
+	// Slow the pacing rate in STARTUP by the bytes_lost / CWND.
+	if b.startupRateReductionMultiplier != 0 && hasEverDetectedLoss && b.hasNoAppLimitedSample {
+		b.pacingRate = Bandwidth((1.0 - (float64(b.startupBytesLost) * float64(b.startupRateReductionMultiplier) / float64(b.congestionWindow))) * float64(targetRate))
+		// Ensure the pacing rate doesn't drop below the startup growth target times
+		// the bandwidth estimate.
+		b.pacingRate = maxBandwidth(b.pacingRate, Bandwidth(StartupGrowthTarget*float64(b.BandwidthEstimate())))
+		return
+	}
+
+	// Do not decrease the pacing rate during startup.
+	b.pacingRate = maxBandwidth(b.pacingRate, targetRate)
+}
+
+func (b *bbrSender) CalculateCongestionWindow(ackedBytes, excessAcked congestion.ByteCount) {
+	if b.mode == PROBE_RTT {
+		return
+	}
+
+	targetWindow := b.GetTargetCongestionWindow(b.congestionWindowGain)
+	if b.isAtFullBandwidth {
+		// Add the max recently measured ack aggregation to CWND.
+		targetWindow += congestion.ByteCount(b.maxAckHeight.GetBest())
+	} else if b.enableAckAggregationDuringStartup {
+		// Add the most recent excess acked.  Because CWND never decreases in
+		// STARTUP, this will automatically create a very localized max filter.
+		targetWindow += excessAcked
+	}
+
+	// Instead of immediately setting the target CWND as the new one, BBR grows
+	// the CWND towards |target_window| by only increasing it |bytes_acked| at a
+	// time.
+	addBytesAcked := true || !b.InRecovery()
+	if b.isAtFullBandwidth {
+		b.congestionWindow = minByteCount(targetWindow, b.congestionWindow+ackedBytes)
+	} else if addBytesAcked && (b.congestionWindow < targetWindow || b.sampler.totalBytesAcked < b.initialCongestionWindow) {
+		// If the connection is not yet out of startup phase, do not decrease the
+		// window.
+		b.congestionWindow += ackedBytes
+	}
+
+	// Enforce the limits on the congestion window.
+	b.congestionWindow = maxByteCount(b.congestionWindow, b.minCongestionWindow())
+	b.congestionWindow = minByteCount(b.congestionWindow, b.maxCongestionWindow())
+}
+
+func (b *bbrSender) CalculateRecoveryWindow(ackedBytes, lostBytes congestion.ByteCount) {
+	if b.rateBasedStartup && b.mode == STARTUP {
+		return
+	}
+
+	if b.recoveryState == NOT_IN_RECOVERY {
+		return
+	}
+
+	// Set up the initial recovery window.
+	if b.recoveryWindow == 0 {
+		b.recoveryWindow = maxByteCount(b.GetBytesInFlight()+ackedBytes, b.minCongestionWindow())
+		return
+	}
+
+	// Remove losses from the recovery window, while accounting for a potential
+	// integer underflow.
+	if b.recoveryWindow >= lostBytes {
+		b.recoveryWindow -= lostBytes
+	} else {
+		b.recoveryWindow = congestion.ByteCount(b.maxDatagramSize)
+	}
+	// In CONSERVATION mode, just subtracting losses is sufficient.  In GROWTH,
+	// release additional |bytes_acked| to achieve a slow-start-like behavior.
+	if b.recoveryState == GROWTH {
+		b.recoveryWindow += ackedBytes
+	}
+	// Sanity checks.  Ensure that we always allow to send at least an MSS or
+	// |bytes_acked| in response, whichever is larger.
+	b.recoveryWindow = maxByteCount(b.recoveryWindow, b.GetBytesInFlight()+ackedBytes)
+	b.recoveryWindow = maxByteCount(b.recoveryWindow, b.minCongestionWindow())
+}
+
+var _ congestion.CongestionControl = (*bbrSender)(nil)
+
+func (b *bbrSender) GetMinRtt() time.Duration {
+	if b.minRtt > 0 {
+		return b.minRtt
+	} else {
+		return InitialRtt
+	}
+}
+
+func minRtt(a, b time.Duration) time.Duration {
+	if a < b {
+		return a
+	} else {
+		return b
+	}
+}
+
+func minBandwidth(a, b Bandwidth) Bandwidth {
+	if a < b {
+		return a
+	} else {
+		return b
+	}
+}
+
+func maxBandwidth(a, b Bandwidth) Bandwidth {
+	if a > b {
+		return a
+	} else {
+		return b
+	}
+}
+
+func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
+	if a > b {
+		return a
+	} else {
+		return b
+	}
+}
+
+func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
+	if a < b {
+		return a
+	} else {
+		return b
+	}
+}
+
+var InfiniteRTT = time.Duration(math.MaxInt64)

+ 20 - 0
transport/tuic/congestion/clock.go

@@ -0,0 +1,20 @@
+package congestion
+
+import "time"
+
+// A Clock returns the current time
+type Clock interface {
+	Now() time.Time
+}
+
+// DefaultClock implements the Clock interface using the Go stdlib clock.
+type DefaultClock struct {
+	TimeFunc func() time.Time
+}
+
+var _ Clock = DefaultClock{}
+
+// Now gets the current time
+func (c DefaultClock) Now() time.Time {
+	return c.TimeFunc()
+}

+ 213 - 0
transport/tuic/congestion/cubic.go

@@ -0,0 +1,213 @@
+package congestion
+
+import (
+	"math"
+	"time"
+
+	"github.com/sagernet/quic-go/congestion"
+)
+
+// This cubic implementation is based on the one found in Chromiums's QUIC
+// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
+
+// Constants based on TCP defaults.
+// The following constants are in 2^10 fractions of a second instead of ms to
+// allow a 10 shift right to divide.
+
+// 1024*1024^3 (first 1024 is from 0.100^3)
+// where 0.100 is 100 ms which is the scaling round trip time.
+const (
+	cubeScale                                      = 40
+	cubeCongestionWindowScale                      = 410
+	cubeFactor                congestion.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
+	// TODO: when re-enabling cubic, make sure to use the actual packet size here
+	maxDatagramSize = congestion.ByteCount(InitialPacketSizeIPv4)
+)
+
+const defaultNumConnections = 1
+
+// Default Cubic backoff factor
+const beta float32 = 0.7
+
+// Additional backoff factor when loss occurs in the concave part of the Cubic
+// curve. This additional backoff factor is expected to give up bandwidth to
+// new concurrent flows and speed up convergence.
+const betaLastMax float32 = 0.85
+
+// Cubic implements the cubic algorithm from TCP
+type Cubic struct {
+	clock Clock
+
+	// Number of connections to simulate.
+	numConnections int
+
+	// Time when this cycle started, after last loss event.
+	epoch time.Time
+
+	// Max congestion window used just before last loss event.
+	// Note: to improve fairness to other streams an additional back off is
+	// applied to this value if the new value is below our latest value.
+	lastMaxCongestionWindow congestion.ByteCount
+
+	// Number of acked bytes since the cycle started (epoch).
+	ackedBytesCount congestion.ByteCount
+
+	// TCP Reno equivalent congestion window in packets.
+	estimatedTCPcongestionWindow congestion.ByteCount
+
+	// Origin point of cubic function.
+	originPointCongestionWindow congestion.ByteCount
+
+	// Time to origin point of cubic function in 2^10 fractions of a second.
+	timeToOriginPoint uint32
+
+	// Last congestion window in packets computed by cubic function.
+	lastTargetCongestionWindow congestion.ByteCount
+}
+
+// NewCubic returns a new Cubic instance
+func NewCubic(clock Clock) *Cubic {
+	c := &Cubic{
+		clock:          clock,
+		numConnections: defaultNumConnections,
+	}
+	c.Reset()
+	return c
+}
+
+// Reset is called after a timeout to reset the cubic state
+func (c *Cubic) Reset() {
+	c.epoch = time.Time{}
+	c.lastMaxCongestionWindow = 0
+	c.ackedBytesCount = 0
+	c.estimatedTCPcongestionWindow = 0
+	c.originPointCongestionWindow = 0
+	c.timeToOriginPoint = 0
+	c.lastTargetCongestionWindow = 0
+}
+
+func (c *Cubic) alpha() float32 {
+	// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
+	// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
+	// We derive the equivalent alpha for an N-connection emulation as:
+	b := c.beta()
+	return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
+}
+
+func (c *Cubic) beta() float32 {
+	// kNConnectionBeta is the backoff factor after loss for our N-connection
+	// emulation, which emulates the effective backoff of an ensemble of N
+	// TCP-Reno connections on a single loss event. The effective multiplier is
+	// computed as:
+	return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
+}
+
+func (c *Cubic) betaLastMax() float32 {
+	// betaLastMax is the additional backoff factor after loss for our
+	// N-connection emulation, which emulates the additional backoff of
+	// an ensemble of N TCP-Reno connections on a single loss event. The
+	// effective multiplier is computed as:
+	return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
+}
+
+// OnApplicationLimited is called on ack arrival when sender is unable to use
+// the available congestion window. Resets Cubic state during quiescence.
+func (c *Cubic) OnApplicationLimited() {
+	// When sender is not using the available congestion window, the window does
+	// not grow. But to be RTT-independent, Cubic assumes that the sender has been
+	// using the entire window during the time since the beginning of the current
+	// "epoch" (the end of the last loss recovery period). Since
+	// application-limited periods break this assumption, we reset the epoch when
+	// in such a period. This reset effectively freezes congestion window growth
+	// through application-limited periods and allows Cubic growth to continue
+	// when the entire window is being used.
+	c.epoch = time.Time{}
+}
+
+// CongestionWindowAfterPacketLoss computes a new congestion window to use after
+// a loss event. Returns the new congestion window in packets. The new
+// congestion window is a multiplicative decrease of our current window.
+func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow congestion.ByteCount) congestion.ByteCount {
+	if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow {
+		// We never reached the old max, so assume we are competing with another
+		// flow. Use our extra back off factor to allow the other flow to go up.
+		c.lastMaxCongestionWindow = congestion.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
+	} else {
+		c.lastMaxCongestionWindow = currentCongestionWindow
+	}
+	c.epoch = time.Time{} // Reset time.
+	return congestion.ByteCount(float32(currentCongestionWindow) * c.beta())
+}
+
+// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
+// Returns the new congestion window in packets. The new congestion window
+// follows a cubic function that depends on the time passed since last
+// packet loss.
+func (c *Cubic) CongestionWindowAfterAck(
+	ackedBytes congestion.ByteCount,
+	currentCongestionWindow congestion.ByteCount,
+	delayMin time.Duration,
+	eventTime time.Time,
+) congestion.ByteCount {
+	c.ackedBytesCount += ackedBytes
+
+	if c.epoch.IsZero() {
+		// First ACK after a loss event.
+		c.epoch = eventTime            // Start of epoch.
+		c.ackedBytesCount = ackedBytes // Reset count.
+		// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
+		c.estimatedTCPcongestionWindow = currentCongestionWindow
+		if c.lastMaxCongestionWindow <= currentCongestionWindow {
+			c.timeToOriginPoint = 0
+			c.originPointCongestionWindow = currentCongestionWindow
+		} else {
+			c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
+			c.originPointCongestionWindow = c.lastMaxCongestionWindow
+		}
+	}
+
+	// Change the time unit from microseconds to 2^10 fractions per second. Take
+	// the round trip time in account. This is done to allow us to use shift as a
+	// divide operator.
+	elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
+
+	// Right-shifts of negative, signed numbers have implementation-dependent
+	// behavior, so force the offset to be positive, as is done in the kernel.
+	offset := int64(c.timeToOriginPoint) - elapsedTime
+	if offset < 0 {
+		offset = -offset
+	}
+
+	deltaCongestionWindow := congestion.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
+	var targetCongestionWindow congestion.ByteCount
+	if elapsedTime > int64(c.timeToOriginPoint) {
+		targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
+	} else {
+		targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
+	}
+	// Limit the CWND increase to half the acked bytes.
+	targetCongestionWindow = Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
+
+	// Increase the window by approximately Alpha * 1 MSS of bytes every
+	// time we ack an estimated tcp window of bytes.  For small
+	// congestion windows (less than 25), the formula below will
+	// increase slightly slower than linearly per estimated tcp window
+	// of bytes.
+	c.estimatedTCPcongestionWindow += congestion.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
+	c.ackedBytesCount = 0
+
+	// We have a new cubic congestion window.
+	c.lastTargetCongestionWindow = targetCongestionWindow
+
+	// Compute target congestion_window based on cubic target and estimated TCP
+	// congestion_window, use highest (fastest).
+	if targetCongestionWindow < c.estimatedTCPcongestionWindow {
+		targetCongestionWindow = c.estimatedTCPcongestionWindow
+	}
+	return targetCongestionWindow
+}
+
+// SetNumConnections sets the number of emulated connections
+func (c *Cubic) SetNumConnections(n int) {
+	c.numConnections = n
+}

+ 318 - 0
transport/tuic/congestion/cubic_sender.go

@@ -0,0 +1,318 @@
+package congestion
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/sagernet/quic-go/congestion"
+	"github.com/sagernet/quic-go/logging"
+)
+
+const (
+	maxBurstPackets            = 3
+	renoBeta                   = 0.7 // Reno backoff factor.
+	minCongestionWindowPackets = 2
+	initialCongestionWindow    = 32
+)
+
+const (
+	InvalidPacketNumber        congestion.PacketNumber = -1
+	MaxCongestionWindowPackets                         = 20000
+	MaxByteCount                                       = congestion.ByteCount(1<<62 - 1)
+)
+
+type cubicSender struct {
+	hybridSlowStart HybridSlowStart
+	rttStats        congestion.RTTStatsProvider
+	cubic           *Cubic
+	pacer           *pacer
+	clock           Clock
+
+	reno bool
+
+	// Track the largest packet that has been sent.
+	largestSentPacketNumber congestion.PacketNumber
+
+	// Track the largest packet that has been acked.
+	largestAckedPacketNumber congestion.PacketNumber
+
+	// Track the largest packet number outstanding when a CWND cutback occurs.
+	largestSentAtLastCutback congestion.PacketNumber
+
+	// Whether the last loss event caused us to exit slowstart.
+	// Used for stats collection of slowstartPacketsLost
+	lastCutbackExitedSlowstart bool
+
+	// Congestion window in bytes.
+	congestionWindow congestion.ByteCount
+
+	// Slow start congestion window in bytes, aka ssthresh.
+	slowStartThreshold congestion.ByteCount
+
+	// ACK counter for the Reno implementation.
+	numAckedPackets uint64
+
+	initialCongestionWindow    congestion.ByteCount
+	initialMaxCongestionWindow congestion.ByteCount
+
+	maxDatagramSize congestion.ByteCount
+
+	lastState logging.CongestionState
+	tracer    logging.ConnectionTracer
+}
+
+var _ congestion.CongestionControl = &cubicSender{}
+
+// NewCubicSender makes a new cubic sender
+func NewCubicSender(
+	clock Clock,
+	initialMaxDatagramSize congestion.ByteCount,
+	reno bool,
+	tracer logging.ConnectionTracer,
+) *cubicSender {
+	return newCubicSender(
+		clock,
+		reno,
+		initialMaxDatagramSize,
+		initialCongestionWindow*initialMaxDatagramSize,
+		MaxCongestionWindowPackets*initialMaxDatagramSize,
+		tracer,
+	)
+}
+
+func newCubicSender(
+	clock Clock,
+	reno bool,
+	initialMaxDatagramSize,
+	initialCongestionWindow,
+	initialMaxCongestionWindow congestion.ByteCount,
+	tracer logging.ConnectionTracer,
+) *cubicSender {
+	c := &cubicSender{
+		largestSentPacketNumber:    InvalidPacketNumber,
+		largestAckedPacketNumber:   InvalidPacketNumber,
+		largestSentAtLastCutback:   InvalidPacketNumber,
+		initialCongestionWindow:    initialCongestionWindow,
+		initialMaxCongestionWindow: initialMaxCongestionWindow,
+		congestionWindow:           initialCongestionWindow,
+		slowStartThreshold:         MaxByteCount,
+		cubic:                      NewCubic(clock),
+		clock:                      clock,
+		reno:                       reno,
+		tracer:                     tracer,
+		maxDatagramSize:            initialMaxDatagramSize,
+	}
+	c.pacer = newPacer(c.BandwidthEstimate)
+	if c.tracer != nil {
+		c.lastState = logging.CongestionStateSlowStart
+		c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
+	}
+	return c
+}
+
+func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
+	c.rttStats = provider
+}
+
+// TimeUntilSend returns when the next packet should be sent.
+func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time {
+	return c.pacer.TimeUntilSend()
+}
+
+func (c *cubicSender) HasPacingBudget(now time.Time) bool {
+	return c.pacer.Budget(now) >= c.maxDatagramSize
+}
+
+func (c *cubicSender) maxCongestionWindow() congestion.ByteCount {
+	return c.maxDatagramSize * MaxCongestionWindowPackets
+}
+
+func (c *cubicSender) minCongestionWindow() congestion.ByteCount {
+	return c.maxDatagramSize * minCongestionWindowPackets
+}
+
+func (c *cubicSender) OnPacketSent(
+	sentTime time.Time,
+	_ congestion.ByteCount,
+	packetNumber congestion.PacketNumber,
+	bytes congestion.ByteCount,
+	isRetransmittable bool,
+) {
+	c.pacer.SentPacket(sentTime, bytes)
+	if !isRetransmittable {
+		return
+	}
+	c.largestSentPacketNumber = packetNumber
+	c.hybridSlowStart.OnPacketSent(packetNumber)
+}
+
+func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool {
+	return bytesInFlight < c.GetCongestionWindow()
+}
+
+func (c *cubicSender) InRecovery() bool {
+	return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
+}
+
+func (c *cubicSender) InSlowStart() bool {
+	return c.GetCongestionWindow() < c.slowStartThreshold
+}
+
+func (c *cubicSender) GetCongestionWindow() congestion.ByteCount {
+	return c.congestionWindow
+}
+
+func (c *cubicSender) MaybeExitSlowStart() {
+	if c.InSlowStart() &&
+		c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
+		// exit slow start
+		c.slowStartThreshold = c.congestionWindow
+		c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
+	}
+}
+
+func (c *cubicSender) OnPacketAcked(
+	ackedPacketNumber congestion.PacketNumber,
+	ackedBytes congestion.ByteCount,
+	priorInFlight congestion.ByteCount,
+	eventTime time.Time,
+) {
+	c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber)
+	if c.InRecovery() {
+		return
+	}
+	c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
+	if c.InSlowStart() {
+		c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
+	}
+}
+
+func (c *cubicSender) OnPacketLost(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
+	// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
+	// already sent should be treated as a single loss event, since it's expected.
+	if packetNumber <= c.largestSentAtLastCutback {
+		return
+	}
+	c.lastCutbackExitedSlowstart = c.InSlowStart()
+	c.maybeTraceStateChange(logging.CongestionStateRecovery)
+
+	if c.reno {
+		c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta)
+	} else {
+		c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
+	}
+	if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
+		c.congestionWindow = minCwnd
+	}
+	c.slowStartThreshold = c.congestionWindow
+	c.largestSentAtLastCutback = c.largestSentPacketNumber
+	// reset packet count from congestion avoidance mode. We start
+	// counting again when we're out of recovery.
+	c.numAckedPackets = 0
+}
+
+// Called when we receive an ack. Normal TCP tracks how many packets one ack
+// represents, but quic has a separate ack for each packet.
+func (c *cubicSender) maybeIncreaseCwnd(
+	_ congestion.PacketNumber,
+	ackedBytes congestion.ByteCount,
+	priorInFlight congestion.ByteCount,
+	eventTime time.Time,
+) {
+	// Do not increase the congestion window unless the sender is close to using
+	// the current window.
+	if !c.isCwndLimited(priorInFlight) {
+		c.cubic.OnApplicationLimited()
+		c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
+		return
+	}
+	if c.congestionWindow >= c.maxCongestionWindow() {
+		return
+	}
+	if c.InSlowStart() {
+		// TCP slow start, exponential growth, increase by one for each ACK.
+		c.congestionWindow += c.maxDatagramSize
+		c.maybeTraceStateChange(logging.CongestionStateSlowStart)
+		return
+	}
+	// Congestion avoidance
+	c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
+	if c.reno {
+		// Classic Reno congestion avoidance.
+		c.numAckedPackets++
+		if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
+			c.congestionWindow += c.maxDatagramSize
+			c.numAckedPackets = 0
+		}
+	} else {
+		c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
+	}
+}
+
+func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool {
+	congestionWindow := c.GetCongestionWindow()
+	if bytesInFlight >= congestionWindow {
+		return true
+	}
+	availableBytes := congestionWindow - bytesInFlight
+	slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
+	return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
+}
+
+// BandwidthEstimate returns the current bandwidth estimate
+func (c *cubicSender) BandwidthEstimate() Bandwidth {
+	if c.rttStats == nil {
+		return infBandwidth
+	}
+	srtt := c.rttStats.SmoothedRTT()
+	if srtt == 0 {
+		// If we haven't measured an rtt, the bandwidth estimate is unknown.
+		return infBandwidth
+	}
+	return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
+}
+
+// OnRetransmissionTimeout is called on an retransmission timeout
+func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
+	c.largestSentAtLastCutback = InvalidPacketNumber
+	if !packetsRetransmitted {
+		return
+	}
+	c.hybridSlowStart.Restart()
+	c.cubic.Reset()
+	c.slowStartThreshold = c.congestionWindow / 2
+	c.congestionWindow = c.minCongestionWindow()
+}
+
+// OnConnectionMigration is called when the connection is migrated (?)
+func (c *cubicSender) OnConnectionMigration() {
+	c.hybridSlowStart.Restart()
+	c.largestSentPacketNumber = InvalidPacketNumber
+	c.largestAckedPacketNumber = InvalidPacketNumber
+	c.largestSentAtLastCutback = InvalidPacketNumber
+	c.lastCutbackExitedSlowstart = false
+	c.cubic.Reset()
+	c.numAckedPackets = 0
+	c.congestionWindow = c.initialCongestionWindow
+	c.slowStartThreshold = c.initialMaxCongestionWindow
+}
+
+func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
+	if c.tracer == nil || new == c.lastState {
+		return
+	}
+	c.tracer.UpdatedCongestionState(new)
+	c.lastState = new
+}
+
+func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) {
+	if s < c.maxDatagramSize {
+		panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
+	}
+	cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
+	c.maxDatagramSize = s
+	if cwndIsMinCwnd {
+		c.congestionWindow = c.minCongestionWindow()
+	}
+	c.pacer.SetMaxDatagramSize(s)
+}

+ 112 - 0
transport/tuic/congestion/hybrid_slow_start.go

@@ -0,0 +1,112 @@
+package congestion
+
+import (
+	"time"
+
+	"github.com/sagernet/quic-go/congestion"
+)
+
+// Note(pwestin): the magic clamping numbers come from the original code in
+// tcp_cubic.c.
+const hybridStartLowWindow = congestion.ByteCount(16)
+
+// Number of delay samples for detecting the increase of delay.
+const hybridStartMinSamples = uint32(8)
+
+// Exit slow start if the min rtt has increased by more than 1/8th.
+const hybridStartDelayFactorExp = 3 // 2^3 = 8
+// The original paper specifies 2 and 8ms, but those have changed over time.
+const (
+	hybridStartDelayMinThresholdUs = int64(4000)
+	hybridStartDelayMaxThresholdUs = int64(16000)
+)
+
+// HybridSlowStart implements the TCP hybrid slow start algorithm
+type HybridSlowStart struct {
+	endPacketNumber      congestion.PacketNumber
+	lastSentPacketNumber congestion.PacketNumber
+	started              bool
+	currentMinRTT        time.Duration
+	rttSampleCount       uint32
+	hystartFound         bool
+}
+
+// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
+func (s *HybridSlowStart) StartReceiveRound(lastSent congestion.PacketNumber) {
+	s.endPacketNumber = lastSent
+	s.currentMinRTT = 0
+	s.rttSampleCount = 0
+	s.started = true
+}
+
+// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
+func (s *HybridSlowStart) IsEndOfRound(ack congestion.PacketNumber) bool {
+	return s.endPacketNumber < ack
+}
+
+// ShouldExitSlowStart should be called on every new ack frame, since a new
+// RTT measurement can be made then.
+// rtt: the RTT for this ack packet.
+// minRTT: is the lowest delay (RTT) we have seen during the session.
+// congestionWindow: the congestion window in packets.
+func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow congestion.ByteCount) bool {
+	if !s.started {
+		// Time to start the hybrid slow start.
+		s.StartReceiveRound(s.lastSentPacketNumber)
+	}
+	if s.hystartFound {
+		return true
+	}
+	// Second detection parameter - delay increase detection.
+	// Compare the minimum delay (s.currentMinRTT) of the current
+	// burst of packets relative to the minimum delay during the session.
+	// Note: we only look at the first few(8) packets in each burst, since we
+	// only want to compare the lowest RTT of the burst relative to previous
+	// bursts.
+	s.rttSampleCount++
+	if s.rttSampleCount <= hybridStartMinSamples {
+		if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
+			s.currentMinRTT = latestRTT
+		}
+	}
+	// We only need to check this once per round.
+	if s.rttSampleCount == hybridStartMinSamples {
+		// Divide minRTT by 8 to get a rtt increase threshold for exiting.
+		minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
+		// Ensure the rtt threshold is never less than 2ms or more than 16ms.
+		minRTTincreaseThresholdUs = Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
+		minRTTincreaseThreshold := time.Duration(Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
+
+		if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
+			s.hystartFound = true
+		}
+	}
+	// Exit from slow start if the cwnd is greater than 16 and
+	// increasing delay is found.
+	return congestionWindow >= hybridStartLowWindow && s.hystartFound
+}
+
+// OnPacketSent is called when a packet was sent
+func (s *HybridSlowStart) OnPacketSent(packetNumber congestion.PacketNumber) {
+	s.lastSentPacketNumber = packetNumber
+}
+
+// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
+// the round when the final packet of the burst is received and start it on
+// the next incoming ack.
+func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber congestion.PacketNumber) {
+	if s.IsEndOfRound(ackedPacketNumber) {
+		s.started = false
+	}
+}
+
+// Started returns true if started
+func (s *HybridSlowStart) Started() bool {
+	return s.started
+}
+
+// Restart the slow start phase
+func (s *HybridSlowStart) Restart() {
+	s.started = false
+	s.hystartFound = false
+}

+ 72 - 0
transport/tuic/congestion/minmax.go

@@ -0,0 +1,72 @@
+package congestion
+
+import (
+	"math"
+	"time"
+
+	"golang.org/x/exp/constraints"
+)
+
+// InfDuration is a duration of infinite length
+const InfDuration = time.Duration(math.MaxInt64)
+
+func Max[T constraints.Ordered](a, b T) T {
+	if a < b {
+		return b
+	}
+	return a
+}
+
+func Min[T constraints.Ordered](a, b T) T {
+	if a < b {
+		return a
+	}
+	return b
+}
+
+// MinNonZeroDuration return the minimum duration that's not zero.
+func MinNonZeroDuration(a, b time.Duration) time.Duration {
+	if a == 0 {
+		return b
+	}
+	if b == 0 {
+		return a
+	}
+	return Min(a, b)
+}
+
+// AbsDuration returns the absolute value of a time duration
+func AbsDuration(d time.Duration) time.Duration {
+	if d >= 0 {
+		return d
+	}
+	return -d
+}
+
+// MinTime returns the earlier time
+func MinTime(a, b time.Time) time.Time {
+	if a.After(b) {
+		return b
+	}
+	return a
+}
+
+// MinNonZeroTime returns the earlist time that is not time.Time{}
+// If both a and b are time.Time{}, it returns time.Time{}
+func MinNonZeroTime(a, b time.Time) time.Time {
+	if a.IsZero() {
+		return b
+	}
+	if b.IsZero() {
+		return a
+	}
+	return MinTime(a, b)
+}
+
+// MaxTime returns the later time
+func MaxTime(a, b time.Time) time.Time {
+	if a.After(b) {
+		return a
+	}
+	return b
+}

+ 81 - 0
transport/tuic/congestion/pacer.go

@@ -0,0 +1,81 @@
+package congestion
+
+import (
+	"math"
+	"time"
+
+	"github.com/sagernet/quic-go/congestion"
+)
+
+const (
+	initialMaxDatagramSize = congestion.ByteCount(1252)
+	MinPacingDelay         = time.Millisecond
+	TimerGranularity       = time.Millisecond
+	maxBurstSizePackets    = 10
+)
+
+// The pacer implements a token bucket pacing algorithm.
+type pacer struct {
+	budgetAtLastSent     congestion.ByteCount
+	maxDatagramSize      congestion.ByteCount
+	lastSentTime         time.Time
+	getAdjustedBandwidth func() uint64 // in bytes/s
+}
+
+func newPacer(getBandwidth func() Bandwidth) *pacer {
+	p := &pacer{
+		maxDatagramSize: initialMaxDatagramSize,
+		getAdjustedBandwidth: func() uint64 {
+			// Bandwidth is in bits/s. We need the value in bytes/s.
+			bw := uint64(getBandwidth() / BytesPerSecond)
+			// Use a slightly higher value than the actual measured bandwidth.
+			// RTT variations then won't result in under-utilization of the congestion window.
+			// Ultimately, this will  result in sending packets as acknowledgments are received rather than when timers fire,
+			// provided the congestion window is fully utilized and acknowledgments arrive at regular intervals.
+			return bw * 5 / 4
+		},
+	}
+	p.budgetAtLastSent = p.maxBurstSize()
+	return p
+}
+
+func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
+	budget := p.Budget(sendTime)
+	if size > budget {
+		p.budgetAtLastSent = 0
+	} else {
+		p.budgetAtLastSent = budget - size
+	}
+	p.lastSentTime = sendTime
+}
+
+func (p *pacer) Budget(now time.Time) congestion.ByteCount {
+	if p.lastSentTime.IsZero() {
+		return p.maxBurstSize()
+	}
+	budget := p.budgetAtLastSent + (congestion.ByteCount(p.getAdjustedBandwidth())*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
+	return Min(p.maxBurstSize(), budget)
+}
+
+func (p *pacer) maxBurstSize() congestion.ByteCount {
+	return Max(
+		congestion.ByteCount(uint64((MinPacingDelay+TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9,
+		maxBurstSizePackets*p.maxDatagramSize,
+	)
+}
+
+// TimeUntilSend returns when the next packet should be sent.
+// It returns the zero value of time.Time if a packet can be sent immediately.
+func (p *pacer) TimeUntilSend() time.Time {
+	if p.budgetAtLastSent >= p.maxDatagramSize {
+		return time.Time{}
+	}
+	return p.lastSentTime.Add(Max(
+		MinPacingDelay,
+		time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond,
+	))
+}
+
+func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
+	p.maxDatagramSize = s
+}

+ 132 - 0
transport/tuic/congestion/windowed_filter.go

@@ -0,0 +1,132 @@
+package congestion
+
+// WindowedFilter Use the following to construct a windowed filter object of type T.
+// For example, a min filter using QuicTime as the time type:
+//
+//	WindowedFilter<T, MinFilter<T>, QuicTime, QuicTime::Delta> ObjectName;
+//
+// A max filter using 64-bit integers as the time type:
+//
+//	WindowedFilter<T, MaxFilter<T>, uint64_t, int64_t> ObjectName;
+//
+// Specifically, this template takes four arguments:
+//  1. T -- type of the measurement that is being filtered.
+//  2. Compare -- MinFilter<T> or MaxFilter<T>, depending on the type of filter
+//     desired.
+//  3. TimeT -- the type used to represent timestamps.
+//  4. TimeDeltaT -- the type used to represent continuous time intervals between
+//     two timestamps.  Has to be the type of (a - b) if both |a| and |b| are
+//     of type TimeT.
+type WindowedFilter struct {
+	// Time length of window.
+	windowLength int64
+	estimates    []Sample
+	comparator   func(int64, int64) bool
+}
+
+type Sample struct {
+	sample int64
+	time   int64
+}
+
+// Compares two values and returns true if the first is greater than or equal
+// to the second.
+func MaxFilter(a, b int64) bool {
+	return a >= b
+}
+
+// Compares two values and returns true if the first is less than or equal
+// to the second.
+func MinFilter(a, b int64) bool {
+	return a <= b
+}
+
+func NewWindowedFilter(windowLength int64, comparator func(int64, int64) bool) *WindowedFilter {
+	return &WindowedFilter{
+		windowLength: windowLength,
+		estimates:    make([]Sample, 3),
+		comparator:   comparator,
+	}
+}
+
+// Changes the window length.  Does not update any current samples.
+func (f *WindowedFilter) SetWindowLength(windowLength int64) {
+	f.windowLength = windowLength
+}
+
+func (f *WindowedFilter) GetBest() int64 {
+	return f.estimates[0].sample
+}
+
+func (f *WindowedFilter) GetSecondBest() int64 {
+	return f.estimates[1].sample
+}
+
+func (f *WindowedFilter) GetThirdBest() int64 {
+	return f.estimates[2].sample
+}
+
+func (f *WindowedFilter) Update(sample int64, time int64) {
+	if f.estimates[0].time == 0 || f.comparator(sample, f.estimates[0].sample) || (time-f.estimates[2].time) > f.windowLength {
+		f.Reset(sample, time)
+		return
+	}
+
+	if f.comparator(sample, f.estimates[1].sample) {
+		f.estimates[1].sample = sample
+		f.estimates[1].time = time
+		f.estimates[2].sample = sample
+		f.estimates[2].time = time
+	} else if f.comparator(sample, f.estimates[2].sample) {
+		f.estimates[2].sample = sample
+		f.estimates[2].time = time
+	}
+
+	// Expire and update estimates as necessary.
+	if time-f.estimates[0].time > f.windowLength {
+		// The best estimate hasn't been updated for an entire window, so promote
+		// second and third best estimates.
+		f.estimates[0].sample = f.estimates[1].sample
+		f.estimates[0].time = f.estimates[1].time
+		f.estimates[1].sample = f.estimates[2].sample
+		f.estimates[1].time = f.estimates[2].time
+		f.estimates[2].sample = sample
+		f.estimates[2].time = time
+		// Need to iterate one more time. Check if the new best estimate is
+		// outside the window as well, since it may also have been recorded a
+		// long time ago. Don't need to iterate once more since we cover that
+		// case at the beginning of the method.
+		if time-f.estimates[0].time > f.windowLength {
+			f.estimates[0].sample = f.estimates[1].sample
+			f.estimates[0].time = f.estimates[1].time
+			f.estimates[1].sample = f.estimates[2].sample
+			f.estimates[1].time = f.estimates[2].time
+		}
+		return
+	}
+	if f.estimates[1].sample == f.estimates[0].sample && time-f.estimates[1].time > f.windowLength>>2 {
+		// A quarter of the window has passed without a better sample, so the
+		// second-best estimate is taken from the second quarter of the window.
+		f.estimates[1].sample = sample
+		f.estimates[1].time = time
+		f.estimates[2].sample = sample
+		f.estimates[2].time = time
+		return
+	}
+
+	if f.estimates[2].sample == f.estimates[1].sample && time-f.estimates[2].time > f.windowLength>>1 {
+		// We've passed a half of the window without a better estimate, so take
+		// a third-best estimate from the second half of the window.
+		f.estimates[2].sample = sample
+		f.estimates[2].time = time
+	}
+}
+
+func (f *WindowedFilter) Reset(newSample int64, newTime int64) {
+	f.estimates[0].sample = newSample
+	f.estimates[0].time = newTime
+	f.estimates[1].sample = newSample
+	f.estimates[1].time = newTime
+	f.estimates[2].sample = newSample
+	f.estimates[2].time = newTime
+}

+ 497 - 0
transport/tuic/packet.go

@@ -0,0 +1,497 @@
+package tuic
+
+import (
+	"bytes"
+	"context"
+	"encoding/binary"
+	"errors"
+	"io"
+	"math"
+	"net"
+	"os"
+	"sync"
+	"time"
+
+	"github.com/sagernet/quic-go"
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/atomic"
+	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/cache"
+	M "github.com/sagernet/sing/common/metadata"
+)
+
+var udpMessagePool = sync.Pool{
+	New: func() interface{} {
+		return new(udpMessage)
+	},
+}
+
+func releaseMessages(messages []*udpMessage) {
+	for _, message := range messages {
+		if message != nil {
+			*message = udpMessage{}
+			udpMessagePool.Put(message)
+		}
+	}
+}
+
+type udpMessage struct {
+	sessionID     uint16
+	packetID      uint16
+	fragmentTotal uint8
+	fragmentID    uint8
+	destination   M.Socksaddr
+	dataLength    uint16
+	data          *buf.Buffer
+}
+
+func (m *udpMessage) release() {
+	*m = udpMessage{}
+	udpMessagePool.Put(m)
+}
+
+func (m *udpMessage) releaseMessage() {
+	m.data.Release()
+	m.release()
+}
+
+func (m *udpMessage) pack() *buf.Buffer {
+	buffer := buf.NewSize(m.headerSize() + m.data.Len())
+	common.Must(
+		buffer.WriteByte(Version),
+		buffer.WriteByte(CommandPacket),
+		binary.Write(buffer, binary.BigEndian, m.sessionID),
+		binary.Write(buffer, binary.BigEndian, m.packetID),
+		binary.Write(buffer, binary.BigEndian, m.fragmentTotal),
+		binary.Write(buffer, binary.BigEndian, m.fragmentID),
+		binary.Write(buffer, binary.BigEndian, uint16(m.data.Len())),
+		addressSerializer.WriteAddrPort(buffer, m.destination),
+		common.Error(buffer.Write(m.data.Bytes())),
+	)
+	return buffer
+}
+
+func (m *udpMessage) headerSize() int {
+	return 2 + 10 + addressSerializer.AddrPortLen(m.destination)
+}
+
+func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
+	if message.data.Len() <= maxPacketSize {
+		return []*udpMessage{message}
+	}
+	var fragments []*udpMessage
+	originPacket := message.data.Bytes()
+	udpMTU := maxPacketSize - message.headerSize()
+	for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
+		fragment := udpMessagePool.Get().(*udpMessage)
+		*fragment = *message
+		if remaining > udpMTU {
+			fragment.data = buf.As(originPacket[:udpMTU])
+			originPacket = originPacket[udpMTU:]
+		} else {
+			fragment.data = buf.As(originPacket)
+			originPacket = nil
+		}
+		fragments = append(fragments, fragment)
+	}
+	fragmentTotal := uint16(len(fragments))
+	for index, fragment := range fragments {
+		fragment.fragmentID = uint8(index)
+		fragment.fragmentTotal = uint8(fragmentTotal)
+		if index > 0 {
+			fragment.destination = M.Socksaddr{}
+		}
+	}
+	return fragments
+}
+
+type udpPacketConn struct {
+	ctx       context.Context
+	cancel    common.ContextCancelCauseFunc
+	sessionID uint16
+	quicConn  quic.Connection
+	data      chan *udpMessage
+	udpStream bool
+	udpMTU    int
+	packetId  atomic.Uint32
+	closeOnce sync.Once
+	isServer  bool
+	defragger *udpDefragger
+	onDestroy func()
+}
+
+func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream bool, isServer bool, onDestroy func()) *udpPacketConn {
+	ctx, cancel := common.ContextWithCancelCause(ctx)
+	return &udpPacketConn{
+		ctx:       ctx,
+		cancel:    cancel,
+		quicConn:  quicConn,
+		data:      make(chan *udpMessage, 64),
+		udpStream: udpStream,
+		isServer:  isServer,
+		defragger: newUDPDefragger(),
+		onDestroy: onDestroy,
+	}
+}
+
+func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
+	select {
+	case p := <-c.data:
+		buffer = p.data
+		destination = p.destination
+		p.release()
+		return
+	case <-c.ctx.Done():
+		return nil, M.Socksaddr{}, io.ErrClosedPipe
+	}
+}
+
+func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
+	select {
+	case p := <-c.data:
+		_, err = buffer.ReadOnceFrom(p.data)
+		destination = p.destination
+		p.releaseMessage()
+		return
+	case <-c.ctx.Done():
+		return M.Socksaddr{}, io.ErrClosedPipe
+	}
+}
+
+func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
+	select {
+	case p := <-c.data:
+		_, err = newBuffer().ReadOnceFrom(p.data)
+		destination = p.destination
+		p.releaseMessage()
+		return
+	case <-c.ctx.Done():
+		return M.Socksaddr{}, io.ErrClosedPipe
+	}
+}
+
+func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+	select {
+	case pkt := <-c.data:
+		n = copy(p, pkt.data.Bytes())
+		if pkt.destination.IsFqdn() {
+			addr = pkt.destination
+		} else {
+			addr = pkt.destination.UDPAddr()
+		}
+		pkt.releaseMessage()
+		return n, addr, nil
+	case <-c.ctx.Done():
+		return 0, nil, io.ErrClosedPipe
+	}
+}
+
+func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
+	defer buffer.Release()
+	select {
+	case <-c.ctx.Done():
+		return net.ErrClosed
+	default:
+	}
+	if buffer.Len() > 0xffff {
+		return quic.ErrMessageTooLarge(0xffff)
+	}
+	packetId := c.packetId.Add(1)
+	if packetId > math.MaxUint16 {
+		c.packetId.Store(0)
+		packetId = 0
+	}
+	message := udpMessagePool.Get().(*udpMessage)
+	*message = udpMessage{
+		sessionID:     c.sessionID,
+		packetID:      uint16(packetId),
+		fragmentTotal: 1,
+		destination:   destination,
+		data:          buffer,
+	}
+	defer message.releaseMessage()
+	var err error
+	if !c.udpStream && c.udpMTU > 0 && buffer.Len() > c.udpMTU {
+		err = c.writePackets(fragUDPMessage(message, c.udpMTU))
+	} else {
+		err = c.writePacket(message)
+	}
+	if err == nil {
+		return nil
+	}
+	var tooLargeErr quic.ErrMessageTooLarge
+	if !errors.As(err, &tooLargeErr) {
+		return err
+	}
+	c.udpMTU = int(tooLargeErr)
+	return c.writePackets(fragUDPMessage(message, c.udpMTU))
+}
+
+func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+	select {
+	case <-c.ctx.Done():
+		return 0, net.ErrClosed
+	default:
+	}
+	if len(p) > 0xffff {
+		return 0, quic.ErrMessageTooLarge(0xffff)
+	}
+	packetId := c.packetId.Add(1)
+	if packetId > math.MaxUint16 {
+		c.packetId.Store(0)
+		packetId = 0
+	}
+	message := udpMessagePool.Get().(*udpMessage)
+	*message = udpMessage{
+		sessionID:     c.sessionID,
+		packetID:      uint16(packetId),
+		fragmentTotal: 1,
+		destination:   M.SocksaddrFromNet(addr),
+		data:          buf.As(p),
+	}
+	if c.udpMTU > 0 && len(p) > c.udpMTU {
+		err = c.writePackets(fragUDPMessage(message, c.udpMTU))
+		if err == nil {
+			return len(p), nil
+		}
+	} else {
+		err = c.writePacket(message)
+	}
+	if err == nil {
+		return len(p), nil
+	}
+	var tooLargeErr quic.ErrMessageTooLarge
+	if !errors.As(err, &tooLargeErr) {
+		return
+	}
+	c.udpMTU = int(tooLargeErr)
+	err = c.writePackets(fragUDPMessage(message, c.udpMTU))
+	if err == nil {
+		return len(p), nil
+	}
+	return
+}
+
+func (c *udpPacketConn) inputPacket(message *udpMessage) {
+	if message.fragmentTotal <= 1 {
+		select {
+		case c.data <- message:
+		default:
+		}
+	} else {
+		newMessage := c.defragger.feed(message)
+		if newMessage != nil {
+			select {
+			case c.data <- newMessage:
+			default:
+			}
+		}
+	}
+}
+
+func (c *udpPacketConn) writePackets(messages []*udpMessage) error {
+	defer releaseMessages(messages)
+	for _, message := range messages {
+		err := c.writePacket(message)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (c *udpPacketConn) writePacket(message *udpMessage) error {
+	if !c.udpStream {
+		buffer := message.pack()
+		err := c.quicConn.SendMessage(buffer.Bytes())
+		buffer.Release()
+		if err != nil {
+			return err
+		}
+	} else {
+		stream, err := c.quicConn.OpenUniStream()
+		if err != nil {
+			return err
+		}
+		buffer := message.pack()
+		_, err = stream.Write(buffer.Bytes())
+		buffer.Release()
+		stream.Close()
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (c *udpPacketConn) Close() error {
+	c.closeOnce.Do(func() {
+		c.closeWithError(os.ErrClosed)
+		c.onDestroy()
+	})
+	return nil
+}
+
+func (c *udpPacketConn) closeWithError(err error) {
+	c.cancel(err)
+	if !c.isServer {
+		buffer := buf.NewSize(4)
+		defer buffer.Release()
+		buffer.WriteByte(Version)
+		buffer.WriteByte(CommandDissociate)
+		binary.Write(buffer, binary.BigEndian, c.sessionID)
+		sendStream, openErr := c.quicConn.OpenUniStream()
+		if openErr != nil {
+			return
+		}
+		defer sendStream.Close()
+		sendStream.Write(buffer.Bytes())
+	}
+}
+
+func (c *udpPacketConn) LocalAddr() net.Addr {
+	return c.quicConn.LocalAddr()
+}
+
+func (c *udpPacketConn) SetDeadline(t time.Time) error {
+	return os.ErrInvalid
+}
+
+func (c *udpPacketConn) SetReadDeadline(t time.Time) error {
+	return os.ErrInvalid
+}
+
+func (c *udpPacketConn) SetWriteDeadline(t time.Time) error {
+	return os.ErrInvalid
+}
+
+type udpDefragger struct {
+	packetMap *cache.LruCache[uint16, *packetItem]
+}
+
+func newUDPDefragger() *udpDefragger {
+	return &udpDefragger{
+		packetMap: cache.New(
+			cache.WithAge[uint16, *packetItem](10),
+			cache.WithUpdateAgeOnGet[uint16, *packetItem](),
+			cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) {
+				releaseMessages(value.messages)
+			}),
+		),
+	}
+}
+
+type packetItem struct {
+	access   sync.Mutex
+	messages []*udpMessage
+	count    uint8
+}
+
+func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
+	if m.fragmentTotal <= 1 {
+		return m
+	}
+	if m.fragmentID >= m.fragmentTotal {
+		return nil
+	}
+	item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem)
+	item.access.Lock()
+	defer item.access.Unlock()
+	if int(m.fragmentTotal) != len(item.messages) {
+		releaseMessages(item.messages)
+		item.messages = make([]*udpMessage, m.fragmentTotal)
+		item.count = 1
+		item.messages[m.fragmentID] = m
+		return nil
+	}
+	if item.messages[m.fragmentID] != nil {
+		return nil
+	}
+	item.messages[m.fragmentID] = m
+	item.count++
+	if int(item.count) != len(item.messages) {
+		return nil
+	}
+	newMessage := udpMessagePool.Get().(*udpMessage)
+	*newMessage = *item.messages[0]
+	if m.dataLength > 0 {
+		newMessage.data = buf.NewSize(int(m.dataLength))
+		for _, message := range item.messages {
+			newMessage.data.Write(message.data.Bytes())
+			message.releaseMessage()
+		}
+		item.messages = nil
+		return newMessage
+	}
+	return nil
+}
+
+func newPacketItem() *packetItem {
+	return new(packetItem)
+}
+
+func readUDPMessage(message *udpMessage, reader io.Reader) error {
+	err := binary.Read(reader, binary.BigEndian, &message.sessionID)
+	if err != nil {
+		return err
+	}
+	err = binary.Read(reader, binary.BigEndian, &message.packetID)
+	if err != nil {
+		return err
+	}
+	err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
+	if err != nil {
+		return err
+	}
+	err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
+	if err != nil {
+		return err
+	}
+	err = binary.Read(reader, binary.BigEndian, &message.dataLength)
+	if err != nil {
+		return err
+	}
+	message.destination, err = addressSerializer.ReadAddrPort(reader)
+	if err != nil {
+		return err
+	}
+	message.data = buf.NewSize(int(message.dataLength))
+	_, err = message.data.ReadFullFrom(reader, message.data.FreeLen())
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func decodeUDPMessage(message *udpMessage, data []byte) error {
+	reader := bytes.NewReader(data)
+	err := binary.Read(reader, binary.BigEndian, &message.sessionID)
+	if err != nil {
+		return err
+	}
+	err = binary.Read(reader, binary.BigEndian, &message.packetID)
+	if err != nil {
+		return err
+	}
+	err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
+	if err != nil {
+		return err
+	}
+	err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
+	if err != nil {
+		return err
+	}
+	err = binary.Read(reader, binary.BigEndian, &message.dataLength)
+	if err != nil {
+		return err
+	}
+	message.destination, err = addressSerializer.ReadAddrPort(reader)
+	if err != nil {
+		return err
+	}
+	if reader.Len() != int(message.dataLength) {
+		return io.ErrUnexpectedEOF
+	}
+	message.data = buf.As(data[len(data)-reader.Len():])
+	return nil
+}

+ 15 - 0
transport/tuic/protocol.go

@@ -0,0 +1,15 @@
+package tuic
+
+const (
+	Version = 5
+)
+
+const (
+	CommandAuthenticate = iota
+	CommandConnect
+	CommandPacket
+	CommandDissociate
+	CommandHeartbeat
+)
+
+const AuthenticateLen = 2 + 16 + 32

+ 434 - 0
transport/tuic/server.go

@@ -0,0 +1,434 @@
+package tuic
+
+import (
+	"bytes"
+	"context"
+	"crypto/tls"
+	"encoding/binary"
+	"io"
+	"net"
+	"runtime"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/sagernet/quic-go"
+	"github.com/sagernet/sing-box/common/baderror"
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/auth"
+	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/bufio"
+	E "github.com/sagernet/sing/common/exceptions"
+	"github.com/sagernet/sing/common/logger"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	"github.com/gofrs/uuid/v5"
+)
+
+type ServerOptions struct {
+	Context           context.Context
+	Logger            logger.Logger
+	TLSConfig         *tls.Config
+	Users             []User
+	CongestionControl string
+	AuthTimeout       time.Duration
+	ZeroRTTHandshake  bool
+	Heartbeat         time.Duration
+	Handler           ServerHandler
+}
+
+type User struct {
+	Name     string
+	UUID     uuid.UUID
+	Password string
+}
+
+type ServerHandler interface {
+	N.TCPConnectionHandler
+	N.UDPConnectionHandler
+}
+
+type Server struct {
+	ctx               context.Context
+	logger            logger.Logger
+	tlsConfig         *tls.Config
+	heartbeat         time.Duration
+	quicConfig        *quic.Config
+	userMap           map[uuid.UUID]User
+	congestionControl string
+	authTimeout       time.Duration
+	handler           ServerHandler
+
+	quicListener io.Closer
+}
+
+func NewServer(options ServerOptions) (*Server, error) {
+	if options.AuthTimeout == 0 {
+		options.AuthTimeout = 3 * time.Second
+	}
+	if options.Heartbeat == 0 {
+		options.Heartbeat = 10 * time.Second
+	}
+	quicConfig := &quic.Config{
+		DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
+		MaxDatagramFrameSize:    1400,
+		EnableDatagrams:         true,
+		Allow0RTT:               options.ZeroRTTHandshake,
+		MaxIncomingStreams:      1 << 60,
+		MaxIncomingUniStreams:   1 << 60,
+	}
+	switch options.CongestionControl {
+	case "":
+		options.CongestionControl = "cubic"
+	case "cubic", "new_reno", "bbr":
+	default:
+		return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
+	}
+	if len(options.Users) == 0 {
+		return nil, E.New("missing users")
+	}
+	userMap := make(map[uuid.UUID]User)
+	for _, user := range options.Users {
+		userMap[user.UUID] = user
+	}
+	return &Server{
+		ctx:               options.Context,
+		logger:            options.Logger,
+		tlsConfig:         options.TLSConfig,
+		heartbeat:         options.Heartbeat,
+		quicConfig:        quicConfig,
+		userMap:           userMap,
+		congestionControl: options.CongestionControl,
+		authTimeout:       options.AuthTimeout,
+		handler:           options.Handler,
+	}, nil
+}
+
+func (s *Server) Start(conn net.PacketConn) error {
+	if !s.quicConfig.Allow0RTT {
+		listener, err := quic.Listen(conn, s.tlsConfig, s.quicConfig)
+		if err != nil {
+			return err
+		}
+		s.quicListener = listener
+		go func() {
+			for {
+				connection, hErr := listener.Accept(s.ctx)
+				if hErr != nil {
+					if strings.Contains(hErr.Error(), "server closed") {
+						s.logger.Debug(E.Cause(hErr, "listener closed"))
+					} else {
+						s.logger.Error(E.Cause(hErr, "listener closed"))
+					}
+					return
+				}
+				go s.handleConnection(connection)
+			}
+		}()
+	} else {
+		listener, err := quic.ListenEarly(conn, s.tlsConfig, s.quicConfig)
+		if err != nil {
+			return err
+		}
+		s.quicListener = listener
+		go func() {
+			for {
+				connection, hErr := listener.Accept(s.ctx)
+				if hErr != nil {
+					if strings.Contains(hErr.Error(), "server closed") {
+						s.logger.Debug(E.Cause(hErr, "listener closed"))
+					} else {
+						s.logger.Error(E.Cause(hErr, "listener closed"))
+					}
+					return
+				}
+				go s.handleConnection(connection)
+			}
+		}()
+	}
+	return nil
+}
+
+func (s *Server) Close() error {
+	return common.Close(
+		s.quicListener,
+	)
+}
+
+func (s *Server) handleConnection(connection quic.Connection) {
+	setCongestion(s.ctx, connection, s.congestionControl)
+	session := &serverSession{
+		Server:     s,
+		ctx:        s.ctx,
+		quicConn:   connection,
+		source:     M.SocksaddrFromNet(connection.RemoteAddr()),
+		connDone:   make(chan struct{}),
+		authDone:   make(chan struct{}),
+		udpConnMap: make(map[uint16]*udpPacketConn),
+	}
+	session.handle()
+}
+
+type serverSession struct {
+	*Server
+	ctx        context.Context
+	quicConn   quic.Connection
+	source     M.Socksaddr
+	connAccess sync.Mutex
+	connDone   chan struct{}
+	connErr    error
+	authDone   chan struct{}
+	authUser   *User
+	udpAccess  sync.RWMutex
+	udpConnMap map[uint16]*udpPacketConn
+}
+
+func (s *serverSession) handle() {
+	if s.ctx.Done() != nil {
+		go func() {
+			select {
+			case <-s.ctx.Done():
+				s.closeWithError(s.ctx.Err())
+			case <-s.connDone:
+			}
+		}()
+	}
+	go s.loopUniStreams()
+	go s.loopStreams()
+	go s.loopMessages()
+	go s.handleAuthTimeout()
+	go s.loopHeartbeats()
+}
+
+func (s *serverSession) loopUniStreams() {
+	for {
+		uniStream, err := s.quicConn.AcceptUniStream(s.ctx)
+		if err != nil {
+			return
+		}
+		go func() {
+			err = s.handleUniStream(uniStream)
+			if err != nil {
+				s.closeWithError(E.Cause(err, "handle uni stream"))
+			}
+		}()
+	}
+}
+
+func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
+	defer stream.CancelRead(0)
+	buffer := buf.New()
+	defer buffer.Release()
+	_, err := buffer.ReadAtLeastFrom(stream, 2)
+	if err != nil {
+		return E.Cause(err, "read request")
+	}
+	version := buffer.Byte(0)
+	if version != Version {
+		return E.New("unknown version ", buffer.Byte(0))
+	}
+	command := buffer.Byte(1)
+	switch command {
+	case CommandAuthenticate:
+		select {
+		case <-s.authDone:
+			return E.New("authentication: multiple authentication requests")
+		default:
+		}
+		if buffer.Len() < AuthenticateLen {
+			_, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len())
+			if err != nil {
+				return E.Cause(err, "authentication: read request")
+			}
+		}
+		userUUID := uuid.FromBytesOrNil(buffer.Range(2, 2+16))
+		user, loaded := s.userMap[userUUID]
+		if !loaded {
+			return E.New("authentication: unknown user ", userUUID)
+		}
+		handshakeState := s.quicConn.ConnectionState().TLS
+		tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32)
+		if err != nil {
+			return E.Cause(err, "authentication: export keying material")
+		}
+		if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) {
+			return E.New("authentication: token mismatch")
+		}
+		s.authUser = &user
+		close(s.authDone)
+		return nil
+	case CommandPacket:
+		select {
+		case <-s.connDone:
+			return s.connErr
+		case <-s.authDone:
+		}
+		message := udpMessagePool.Get().(*udpMessage)
+		err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream))
+		if err != nil {
+			message.release()
+			return err
+		}
+		s.handleUDPMessage(message, true)
+		return nil
+	case CommandDissociate:
+		select {
+		case <-s.connDone:
+			return s.connErr
+		case <-s.authDone:
+		}
+		if buffer.Len() > 4 {
+			return E.New("invalid dissociate message")
+		}
+		var sessionID uint16
+		err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID)
+		if err != nil {
+			return err
+		}
+		s.udpAccess.RLock()
+		udpConn, loaded := s.udpConnMap[sessionID]
+		s.udpAccess.RUnlock()
+		if loaded {
+			udpConn.closeWithError(E.New("remote closed"))
+			s.udpAccess.Lock()
+			delete(s.udpConnMap, sessionID)
+			s.udpAccess.Unlock()
+		}
+		return nil
+	default:
+		return E.New("unknown command ", command)
+	}
+}
+
+func (s *serverSession) handleAuthTimeout() {
+	select {
+	case <-s.connDone:
+	case <-s.authDone:
+	case <-time.After(s.authTimeout):
+		s.closeWithError(E.New("authentication timeout"))
+	}
+}
+
+func (s *serverSession) loopStreams() {
+	for {
+		stream, err := s.quicConn.AcceptStream(s.ctx)
+		if err != nil {
+			return
+		}
+		go func() {
+			err = s.handleStream(stream)
+			if err != nil {
+				stream.CancelRead(0)
+				stream.Close()
+				s.logger.Error(E.Cause(err, "handle stream request"))
+			}
+		}()
+	}
+}
+
+func (s *serverSession) handleStream(stream quic.Stream) error {
+	buffer := buf.NewSize(2 + M.MaxSocksaddrLength)
+	defer buffer.Release()
+	_, err := buffer.ReadAtLeastFrom(stream, 2)
+	if err != nil {
+		return E.Cause(err, "read request")
+	}
+	version, _ := buffer.ReadByte()
+	if version != Version {
+		return E.New("unknown version ", buffer.Byte(0))
+	}
+	command, _ := buffer.ReadByte()
+	if command != CommandConnect {
+		return E.New("unsupported stream command ", command)
+	}
+	destination, err := addressSerializer.ReadAddrPort(io.MultiReader(buffer, stream))
+	if err != nil {
+		return E.Cause(err, "read request destination")
+	}
+	select {
+	case <-s.connDone:
+		return s.connErr
+	case <-s.authDone:
+	}
+	var conn net.Conn = &serverConn{
+		Stream:      stream,
+		destination: destination,
+	}
+	if buffer.IsEmpty() {
+		buffer.Release()
+	} else {
+		conn = bufio.NewCachedConn(conn, buffer)
+	}
+	ctx := s.ctx
+	if s.authUser.Name != "" {
+		ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
+	}
+	_ = s.handler.NewConnection(ctx, conn, M.Metadata{
+		Source:      s.source,
+		Destination: destination,
+	})
+	return nil
+}
+
+func (s *serverSession) loopHeartbeats() {
+	ticker := time.NewTicker(s.heartbeat)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-s.connDone:
+			return
+		case <-ticker.C:
+			err := s.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
+			if err != nil {
+				s.closeWithError(E.Cause(err, "send heartbeat"))
+			}
+		}
+	}
+}
+
+func (s *serverSession) closeWithError(err error) {
+	s.connAccess.Lock()
+	defer s.connAccess.Unlock()
+	select {
+	case <-s.connDone:
+		return
+	default:
+		s.connErr = err
+		close(s.connDone)
+	}
+	if E.IsClosedOrCanceled(err) {
+		s.logger.Debug(E.Cause(err, "connection failed"))
+	} else {
+		s.logger.Error(E.Cause(err, "connection failed"))
+	}
+	_ = s.quicConn.CloseWithError(0, "")
+}
+
+type serverConn struct {
+	quic.Stream
+	destination M.Socksaddr
+}
+
+func (c *serverConn) Read(p []byte) (n int, err error) {
+	n, err = c.Stream.Read(p)
+	return n, baderror.WrapQUIC(err)
+}
+
+func (c *serverConn) Write(p []byte) (n int, err error) {
+	n, err = c.Stream.Write(p)
+	return n, baderror.WrapQUIC(err)
+}
+
+func (c *serverConn) LocalAddr() net.Addr {
+	return c.destination
+}
+
+func (c *serverConn) RemoteAddr() net.Addr {
+	return M.Socksaddr{}
+}
+
+func (c *serverConn) Close() error {
+	c.Stream.CancelRead(0)
+	return c.Stream.Close()
+}

+ 73 - 0
transport/tuic/server_packet.go

@@ -0,0 +1,73 @@
+package tuic
+
+import (
+	"github.com/sagernet/sing/common"
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+)
+
+func (s *serverSession) loopMessages() {
+	select {
+	case <-s.connDone:
+		return
+	case <-s.authDone:
+	}
+	for {
+		message, err := s.quicConn.ReceiveMessage(s.ctx)
+		if err != nil {
+			s.closeWithError(E.Cause(err, "receive message"))
+			return
+		}
+		hErr := s.handleMessage(message)
+		if hErr != nil {
+			s.closeWithError(E.Cause(hErr, "handle message"))
+			return
+		}
+	}
+}
+
+func (s *serverSession) handleMessage(data []byte) error {
+	if len(data) < 2 {
+		return E.New("invalid message")
+	}
+	if data[0] != Version {
+		return E.New("unknown version ", data[0])
+	}
+	switch data[1] {
+	case CommandPacket:
+		message := udpMessagePool.Get().(*udpMessage)
+		err := decodeUDPMessage(message, data[2:])
+		if err != nil {
+			message.release()
+			return E.Cause(err, "decode UDP message")
+		}
+		s.handleUDPMessage(message, false)
+		return nil
+	case CommandHeartbeat:
+		return nil
+	default:
+		return E.New("unknown command ", data[0])
+	}
+}
+
+func (s *serverSession) handleUDPMessage(message *udpMessage, udpStream bool) {
+	s.udpAccess.RLock()
+	udpConn, loaded := s.udpConnMap[message.sessionID]
+	s.udpAccess.RUnlock()
+	if !loaded || common.Done(udpConn.ctx) {
+		udpConn = newUDPPacketConn(s.ctx, s.quicConn, udpStream, true, func() {
+			s.udpAccess.Lock()
+			delete(s.udpConnMap, message.sessionID)
+			s.udpAccess.Unlock()
+		})
+		udpConn.sessionID = message.sessionID
+		s.udpAccess.Lock()
+		s.udpConnMap[message.sessionID] = udpConn
+		s.udpAccess.Unlock()
+		go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{
+			Source:      s.source,
+			Destination: message.destination,
+		})
+	}
+	udpConn.inputPacket(message)
+}