浏览代码

Fix atomic pointer usages

世界 2 月之前
父节点
当前提交
fdc181106d
共有 7 个文件被更改,包括 99 次插入77 次删除
  1. 45 33
      common/dialer/tfo.go
  2. 1 1
      experimental/v2rayapi/stats.go
  3. 2 2
      go.mod
  4. 4 4
      go.sum
  5. 8 9
      transport/v2raygrpc/client.go
  6. 8 7
      transport/v2rayquic/client.go
  7. 31 21
      transport/v2raywebsocket/conn.go

+ 45 - 33
common/dialer/tfo.go

@@ -10,6 +10,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/atomic"
 	"github.com/sagernet/sing/common/bufio"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
@@ -22,7 +24,7 @@ type slowOpenConn struct {
 	ctx         context.Context
 	network     string
 	destination M.Socksaddr
-	conn        net.Conn
+	conn        atomic.Pointer[net.TCPConn]
 	create      chan struct{}
 	done        chan struct{}
 	access      sync.Mutex
@@ -50,22 +52,25 @@ func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, des
 }
 
 func (c *slowOpenConn) Read(b []byte) (n int, err error) {
-	if c.conn == nil {
-		select {
-		case <-c.create:
-			if c.err != nil {
-				return 0, c.err
-			}
-		case <-c.done:
-			return 0, os.ErrClosed
+	conn := c.conn.Load()
+	if conn != nil {
+		return conn.Read(b)
+	}
+	select {
+	case <-c.create:
+		if c.err != nil {
+			return 0, c.err
 		}
+		return c.conn.Load().Read(b)
+	case <-c.done:
+		return 0, os.ErrClosed
 	}
-	return c.conn.Read(b)
 }
 
 func (c *slowOpenConn) Write(b []byte) (n int, err error) {
-	if c.conn != nil {
-		return c.conn.Write(b)
+	tcpConn := c.conn.Load()
+	if tcpConn != nil {
+		return tcpConn.Write(b)
 	}
 	c.access.Lock()
 	defer c.access.Unlock()
@@ -74,7 +79,7 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) {
 		if c.err != nil {
 			return 0, c.err
 		}
-		return c.conn.Write(b)
+		return c.conn.Load().Write(b)
 	case <-c.done:
 		return 0, os.ErrClosed
 	default:
@@ -83,7 +88,7 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) {
 	if err != nil {
 		c.err = err
 	} else {
-		c.conn = conn
+		c.conn.Store(conn.(*net.TCPConn))
 	}
 	n = len(b)
 	close(c.create)
@@ -93,70 +98,77 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) {
 func (c *slowOpenConn) Close() error {
 	c.closeOnce.Do(func() {
 		close(c.done)
-		if c.conn != nil {
-			c.conn.Close()
+		conn := c.conn.Load()
+		if conn != nil {
+			conn.Close()
 		}
 	})
 	return nil
 }
 
 func (c *slowOpenConn) LocalAddr() net.Addr {
-	if c.conn == nil {
+	conn := c.conn.Load()
+	if conn == nil {
 		return M.Socksaddr{}
 	}
-	return c.conn.LocalAddr()
+	return conn.LocalAddr()
 }
 
 func (c *slowOpenConn) RemoteAddr() net.Addr {
-	if c.conn == nil {
+	conn := c.conn.Load()
+	if conn == nil {
 		return M.Socksaddr{}
 	}
-	return c.conn.RemoteAddr()
+	return conn.RemoteAddr()
 }
 
 func (c *slowOpenConn) SetDeadline(t time.Time) error {
-	if c.conn == nil {
+	conn := c.conn.Load()
+	if conn == nil {
 		return os.ErrInvalid
 	}
-	return c.conn.SetDeadline(t)
+	return conn.SetDeadline(t)
 }
 
 func (c *slowOpenConn) SetReadDeadline(t time.Time) error {
-	if c.conn == nil {
+	conn := c.conn.Load()
+	if conn == nil {
 		return os.ErrInvalid
 	}
-	return c.conn.SetReadDeadline(t)
+	return conn.SetReadDeadline(t)
 }
 
 func (c *slowOpenConn) SetWriteDeadline(t time.Time) error {
-	if c.conn == nil {
+	conn := c.conn.Load()
+	if conn == nil {
 		return os.ErrInvalid
 	}
-	return c.conn.SetWriteDeadline(t)
+	return conn.SetWriteDeadline(t)
 }
 
 func (c *slowOpenConn) Upstream() any {
-	return c.conn
+	return common.PtrOrNil(c.conn.Load())
 }
 
 func (c *slowOpenConn) ReaderReplaceable() bool {
-	return c.conn != nil
+	return c.conn.Load() != nil
 }
 
 func (c *slowOpenConn) WriterReplaceable() bool {
-	return c.conn != nil
+	return c.conn.Load() != nil
 }
 
 func (c *slowOpenConn) LazyHeadroom() bool {
-	return c.conn == nil
+	return c.conn.Load() == nil
 }
 
 func (c *slowOpenConn) NeedHandshake() bool {
-	return c.conn == nil
+	return c.conn.Load() == nil
 }
 
 func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) {
-	if c.conn == nil {
+	conn := c.conn.Load()
+	if conn == nil {
 		select {
 		case <-c.create:
 			if c.err != nil {
@@ -166,5 +178,5 @@ func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) {
 			return 0, c.err
 		}
 	}
-	return bufio.Copy(w, c.conn)
+	return bufio.Copy(w, c.conn.Load())
 }

+ 1 - 1
experimental/v2rayapi/stats.go

@@ -115,7 +115,7 @@ func (s *StatsService) RoutedPacketConnection(ctx context.Context, conn N.Packet
 		writeCounter = append(writeCounter, s.loadOrCreateCounter("user>>>"+user+">>>traffic>>>downlink"))
 	}
 	s.access.Unlock()
-	return bufio.NewInt64CounterPacketConn(conn, readCounter, writeCounter)
+	return bufio.NewInt64CounterPacketConn(conn, readCounter, nil, writeCounter, nil)
 }
 
 func (s *StatsService) GetStats(ctx context.Context, request *GetStatsRequest) (*GetStatsResponse, error) {

+ 2 - 2
go.mod

@@ -27,9 +27,9 @@ require (
 	github.com/sagernet/gomobile v0.1.8
 	github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb
 	github.com/sagernet/quic-go v0.52.0-beta.1
-	github.com/sagernet/sing v0.7.5
+	github.com/sagernet/sing v0.7.6-0.20250815070458-d33ece7a184f
 	github.com/sagernet/sing-mux v0.3.3
-	github.com/sagernet/sing-quic v0.5.0-beta.3
+	github.com/sagernet/sing-quic v0.5.0
 	github.com/sagernet/sing-shadowsocks v0.2.8
 	github.com/sagernet/sing-shadowsocks2 v0.2.1
 	github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11

+ 4 - 4
go.sum

@@ -167,12 +167,12 @@ github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/l
 github.com/sagernet/quic-go v0.52.0-beta.1 h1:hWkojLg64zjV+MJOvJU/kOeWndm3tiEfBLx5foisszs=
 github.com/sagernet/quic-go v0.52.0-beta.1/go.mod h1:OV+V5kEBb8kJS7k29MzDu6oj9GyMc7HA07sE1tedxz4=
 github.com/sagernet/sing v0.6.9/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
-github.com/sagernet/sing v0.7.5 h1:gNMwZCLPqR+4e0g6dwi0sSsrvOmoMjpZgqxKsuJZatc=
-github.com/sagernet/sing v0.7.5/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
+github.com/sagernet/sing v0.7.6-0.20250815070458-d33ece7a184f h1:HIBo8l+tsS3wLwuI1E56uRTQw46QytXSUpZTP3vwG/U=
+github.com/sagernet/sing v0.7.6-0.20250815070458-d33ece7a184f/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
 github.com/sagernet/sing-mux v0.3.3 h1:YFgt9plMWzH994BMZLmyKL37PdIVaIilwP0Jg+EcLfw=
 github.com/sagernet/sing-mux v0.3.3/go.mod h1:pht8iFY4c9Xltj7rhVd208npkNaeCxzyXCgulDPLUDA=
-github.com/sagernet/sing-quic v0.5.0-beta.3 h1:X/acRNsqQNfDlmwE7SorHfaZiny5e67hqIzM/592ric=
-github.com/sagernet/sing-quic v0.5.0-beta.3/go.mod h1:SAv/qdeDN+75msGG5U5ZIwG+3Ua50jVIKNrRSY8pkx0=
+github.com/sagernet/sing-quic v0.5.0 h1:jNLIyVk24lFPvu8A4x+ZNEnZdI+Tg1rp7eCJ6v0Csak=
+github.com/sagernet/sing-quic v0.5.0/go.mod h1:SAv/qdeDN+75msGG5U5ZIwG+3Ua50jVIKNrRSY8pkx0=
 github.com/sagernet/sing-shadowsocks v0.2.8 h1:PURj5PRoAkqeHh2ZW205RWzN9E9RtKCVCzByXruQWfE=
 github.com/sagernet/sing-shadowsocks v0.2.8/go.mod h1:lo7TWEMDcN5/h5B8S0ew+r78ZODn6SwVaFhvB6H+PTI=
 github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnqqs2gQ2/Qioo=

+ 8 - 9
transport/v2raygrpc/client.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
@@ -29,7 +30,7 @@ type Client struct {
 	serverAddr  string
 	serviceName string
 	dialOptions []grpc.DialOption
-	conn        *grpc.ClientConn
+	conn        atomic.Pointer[grpc.ClientConn]
 	connAccess  sync.Mutex
 }
 
@@ -74,13 +75,13 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 }
 
 func (c *Client) connect() (*grpc.ClientConn, error) {
-	conn := c.conn
+	conn := c.conn.Load()
 	if conn != nil && conn.GetState() != connectivity.Shutdown {
 		return conn, nil
 	}
 	c.connAccess.Lock()
 	defer c.connAccess.Unlock()
-	conn = c.conn
+	conn = c.conn.Load()
 	if conn != nil && conn.GetState() != connectivity.Shutdown {
 		return conn, nil
 	}
@@ -89,7 +90,7 @@ func (c *Client) connect() (*grpc.ClientConn, error) {
 	if err != nil {
 		return nil, err
 	}
-	c.conn = conn
+	c.conn.Store(conn)
 	return conn, nil
 }
 
@@ -109,11 +110,9 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 }
 
 func (c *Client) Close() error {
-	c.connAccess.Lock()
-	defer c.connAccess.Unlock()
-	if c.conn != nil {
-		c.conn.Close()
-		c.conn = nil
+	conn := c.conn.Swap(nil)
+	if conn != nil {
+		conn.Close()
 	}
 	return nil
 }

+ 8 - 7
transport/v2rayquic/client.go

@@ -15,6 +15,7 @@ import (
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-quic"
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/atomic"
 	"github.com/sagernet/sing/common/bufio"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
@@ -29,7 +30,7 @@ type Client struct {
 	tlsConfig  tls.Config
 	quicConfig *quic.Config
 	connAccess sync.Mutex
-	conn       quic.Connection
+	conn       atomic.TypedValue[quic.Connection]
 	rawConn    net.Conn
 }
 
@@ -50,13 +51,13 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 }
 
 func (c *Client) offer() (quic.Connection, error) {
-	conn := c.conn
+	conn := c.conn.Load()
 	if conn != nil && !common.Done(conn.Context()) {
 		return conn, nil
 	}
 	c.connAccess.Lock()
 	defer c.connAccess.Unlock()
-	conn = c.conn
+	conn = c.conn.Load()
 	if conn != nil && !common.Done(conn.Context()) {
 		return conn, nil
 	}
@@ -78,7 +79,7 @@ func (c *Client) offerNew() (quic.Connection, error) {
 		packetConn.Close()
 		return nil, err
 	}
-	c.conn = quicConn
+	c.conn.Store(quicConn)
 	c.rawConn = udpConn
 	return quicConn, nil
 }
@@ -98,13 +99,13 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 func (c *Client) Close() error {
 	c.connAccess.Lock()
 	defer c.connAccess.Unlock()
-	if c.conn != nil {
-		c.conn.CloseWithError(0, "")
+	conn := c.conn.Swap(nil)
+	if conn != nil {
+		conn.CloseWithError(0, "")
 	}
 	if c.rawConn != nil {
 		c.rawConn.Close()
 	}
-	c.conn = nil
 	c.rawConn = nil
 	return nil
 }

+ 31 - 21
transport/v2raywebsocket/conn.go

@@ -8,6 +8,7 @@ import (
 	"net"
 	"os"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	C "github.com/sagernet/sing-box/constant"
@@ -135,20 +136,22 @@ func (c *WebsocketConn) Upstream() any {
 type EarlyWebsocketConn struct {
 	*Client
 	ctx    context.Context
-	conn   *WebsocketConn
+	conn   atomic.Pointer[WebsocketConn]
 	access sync.Mutex
 	create chan struct{}
 	err    error
 }
 
 func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
-	if c.conn == nil {
+	conn := c.conn.Load()
+	if conn == nil {
 		<-c.create
 		if c.err != nil {
 			return 0, c.err
 		}
+		conn = c.conn.Load()
 	}
-	return wrapWsError0(c.conn.Read(b))
+	return wrapWsError0(conn.Read(b))
 }
 
 func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
@@ -187,21 +190,23 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
 			return err
 		}
 	}
-	c.conn = conn
+	c.conn.Store(conn)
 	return nil
 }
 
 func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
-	if c.conn != nil {
-		return wrapWsError0(c.conn.Write(b))
+	conn := c.conn.Load()
+	if conn != nil {
+		return wrapWsError0(conn.Write(b))
 	}
 	c.access.Lock()
 	defer c.access.Unlock()
+	conn = c.conn.Load()
 	if c.err != nil {
 		return 0, c.err
 	}
-	if c.conn != nil {
-		return wrapWsError0(c.conn.Write(b))
+	if conn != nil {
+		return wrapWsError0(conn.Write(b))
 	}
 	err = c.writeRequest(b)
 	c.err = err
@@ -213,17 +218,19 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
 }
 
 func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
-	if c.conn != nil {
-		return wrapWsError(c.conn.WriteBuffer(buffer))
+	conn := c.conn.Load()
+	if conn != nil {
+		return wrapWsError(conn.WriteBuffer(buffer))
 	}
 	c.access.Lock()
 	defer c.access.Unlock()
-	if c.conn != nil {
-		return wrapWsError(c.conn.WriteBuffer(buffer))
-	}
 	if c.err != nil {
 		return c.err
 	}
+	conn = c.conn.Load()
+	if conn != nil {
+		return wrapWsError(conn.WriteBuffer(buffer))
+	}
 	err := c.writeRequest(buffer.Bytes())
 	c.err = err
 	close(c.create)
@@ -231,24 +238,27 @@ func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
 }
 
 func (c *EarlyWebsocketConn) Close() error {
-	if c.conn == nil {
+	conn := c.conn.Load()
+	if conn == nil {
 		return nil
 	}
-	return c.conn.Close()
+	return conn.Close()
 }
 
 func (c *EarlyWebsocketConn) LocalAddr() net.Addr {
-	if c.conn == nil {
+	conn := c.conn.Load()
+	if conn == nil {
 		return M.Socksaddr{}
 	}
-	return c.conn.LocalAddr()
+	return conn.LocalAddr()
 }
 
 func (c *EarlyWebsocketConn) RemoteAddr() net.Addr {
-	if c.conn == nil {
+	conn := c.conn.Load()
+	if conn == nil {
 		return M.Socksaddr{}
 	}
-	return c.conn.RemoteAddr()
+	return conn.RemoteAddr()
 }
 
 func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error {
@@ -268,11 +278,11 @@ func (c *EarlyWebsocketConn) NeedAdditionalReadDeadline() bool {
 }
 
 func (c *EarlyWebsocketConn) Upstream() any {
-	return common.PtrOrNil(c.conn)
+	return common.PtrOrNil(c.conn.Load())
 }
 
 func (c *EarlyWebsocketConn) LazyHeadroom() bool {
-	return c.conn == nil
+	return c.conn.Load() == nil
 }
 
 func wrapWsError(err error) error {