123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- //go:build go1.20
- package dialer
- import (
- "context"
- "io"
- "net"
- "os"
- "sync"
- "time"
- "github.com/sagernet/sing/common"
- "github.com/sagernet/sing/common/bufio"
- E "github.com/sagernet/sing/common/exceptions"
- M "github.com/sagernet/sing/common/metadata"
- N "github.com/sagernet/sing/common/network"
- "github.com/metacubex/tfo-go"
- )
- type slowOpenConn struct {
- dialer *tfo.Dialer
- ctx context.Context
- network string
- destination M.Socksaddr
- conn net.Conn
- create chan struct{}
- access sync.Mutex
- err error
- }
- func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
- if dialer.DisableTFO || N.NetworkName(network) != N.NetworkTCP {
- switch N.NetworkName(network) {
- case N.NetworkTCP, N.NetworkUDP:
- return dialer.Dialer.DialContext(ctx, network, destination.String())
- default:
- return dialer.Dialer.DialContext(ctx, network, destination.AddrString())
- }
- }
- return &slowOpenConn{
- dialer: dialer,
- ctx: ctx,
- network: network,
- destination: destination,
- create: make(chan struct{}),
- }, nil
- }
- func (c *slowOpenConn) Read(b []byte) (n int, err error) {
- if c.conn == nil {
- select {
- case <-c.create:
- if c.err != nil {
- return 0, c.err
- }
- case <-c.ctx.Done():
- return 0, c.ctx.Err()
- }
- }
- return c.conn.Read(b)
- }
- func (c *slowOpenConn) Write(b []byte) (n int, err error) {
- if c.conn != nil {
- return c.conn.Write(b)
- }
- c.access.Lock()
- defer c.access.Unlock()
- select {
- case <-c.create:
- if c.err != nil {
- return 0, c.err
- }
- return c.conn.Write(b)
- default:
- }
- c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b)
- if err != nil {
- c.conn = nil
- c.err = E.Cause(err, "dial tcp fast open")
- }
- n = len(b)
- close(c.create)
- return
- }
- func (c *slowOpenConn) Close() error {
- return common.Close(c.conn)
- }
- func (c *slowOpenConn) LocalAddr() net.Addr {
- if c.conn == nil {
- return M.Socksaddr{}
- }
- return c.conn.LocalAddr()
- }
- func (c *slowOpenConn) RemoteAddr() net.Addr {
- if c.conn == nil {
- return M.Socksaddr{}
- }
- return c.conn.RemoteAddr()
- }
- func (c *slowOpenConn) SetDeadline(t time.Time) error {
- if c.conn == nil {
- return os.ErrInvalid
- }
- return c.conn.SetDeadline(t)
- }
- func (c *slowOpenConn) SetReadDeadline(t time.Time) error {
- if c.conn == nil {
- return os.ErrInvalid
- }
- return c.conn.SetReadDeadline(t)
- }
- func (c *slowOpenConn) SetWriteDeadline(t time.Time) error {
- if c.conn == nil {
- return os.ErrInvalid
- }
- return c.conn.SetWriteDeadline(t)
- }
- func (c *slowOpenConn) Upstream() any {
- return c.conn
- }
- func (c *slowOpenConn) ReaderReplaceable() bool {
- return c.conn != nil
- }
- func (c *slowOpenConn) WriterReplaceable() bool {
- return c.conn != nil
- }
- func (c *slowOpenConn) LazyHeadroom() bool {
- return c.conn == nil
- }
- func (c *slowOpenConn) NeedHandshake() bool {
- return c.conn == nil
- }
- func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) {
- if c.conn == nil {
- select {
- case <-c.create:
- if c.err != nil {
- return 0, c.err
- }
- case <-c.ctx.Done():
- return 0, c.ctx.Err()
- }
- }
- return bufio.Copy(w, c.conn)
- }
|