世界 4 hónapja
szülő
commit
d511698f3f
1 módosított fájl, 20 hozzáadás és 10 törlés
  1. 20 10
      common/dialer/tfo.go

+ 20 - 10
common/dialer/tfo.go

@@ -10,9 +10,7 @@ import (
 	"sync"
 	"time"
 
-	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/bufio"
-	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 
@@ -26,7 +24,9 @@ type slowOpenConn struct {
 	destination M.Socksaddr
 	conn        net.Conn
 	create      chan struct{}
+	done        chan struct{}
 	access      sync.Mutex
+	closeOnce   sync.Once
 	err         error
 }
 
@@ -45,6 +45,7 @@ func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, des
 		network:     network,
 		destination: destination,
 		create:      make(chan struct{}),
+		done:        make(chan struct{}),
 	}, nil
 }
 
@@ -55,8 +56,8 @@ func (c *slowOpenConn) Read(b []byte) (n int, err error) {
 			if c.err != nil {
 				return 0, c.err
 			}
-		case <-c.ctx.Done():
-			return 0, c.ctx.Err()
+		case <-c.done:
+			return 0, os.ErrClosed
 		}
 	}
 	return c.conn.Read(b)
@@ -74,12 +75,15 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) {
 			return 0, c.err
 		}
 		return c.conn.Write(b)
+	case <-c.done:
+		return 0, os.ErrClosed
 	default:
 	}
-	c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b)
+	conn, err := c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b)
 	if err != nil {
-		c.conn = nil
-		c.err = E.Cause(err, "dial tcp fast open")
+		c.err = err
+	} else {
+		c.conn = conn
 	}
 	n = len(b)
 	close(c.create)
@@ -87,7 +91,13 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) {
 }
 
 func (c *slowOpenConn) Close() error {
-	return common.Close(c.conn)
+	c.closeOnce.Do(func() {
+		close(c.done)
+		if c.conn != nil {
+			c.conn.Close()
+		}
+	})
+	return nil
 }
 
 func (c *slowOpenConn) LocalAddr() net.Addr {
@@ -152,8 +162,8 @@ func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) {
 			if c.err != nil {
 				return 0, c.err
 			}
-		case <-c.ctx.Done():
-			return 0, c.ctx.Err()
+		case <-c.done:
+			return 0, c.err
 		}
 	}
 	return bufio.Copy(w, c.conn)