Browse Source

First step of upcoming refactor for Xray-core: Add TimeoutWrapperReader; Use DispatchLink() in Tunnel/Socks/HTTP inbounds

https://github.com/XTLS/Xray-core/pull/5067#issuecomment-3236833240

Fixes https://github.com/XTLS/Xray-core/pull/4952#issuecomment-3229878125 for client's Xray-core
RPRX 3 months ago
parent
commit
56a45ad578
7 changed files with 93 additions and 206 deletions
  1. 5 3
      app/dispatcher/default.go
  2. 37 0
      common/buf/io.go
  3. 5 1
      common/mux/client.go
  4. 8 1
      common/mux/server.go
  5. 13 99
      proxy/dokodemo/dokodemo.go
  6. 14 47
      proxy/http/server.go
  7. 11 55
      proxy/socks/server.go

+ 5 - 3
app/dispatcher/default.go

@@ -29,7 +29,7 @@ var errSniffingTimeout = errors.New("timeout on sniffing")
 
 type cachedReader struct {
 	sync.Mutex
-	reader *pipe.Reader
+	reader buf.TimeoutReader // *pipe.Reader or *buf.TimeoutWrapperReader
 	cache  buf.MultiBuffer
 }
 
@@ -87,7 +87,9 @@ func (r *cachedReader) Interrupt() {
 		r.cache = buf.ReleaseMulti(r.cache)
 	}
 	r.Unlock()
-	r.reader.Interrupt()
+	if p, ok := r.reader.(*pipe.Reader); ok {
+		p.Interrupt()
+	}
 }
 
 // DefaultDispatcher is a default implementation of Dispatcher.
@@ -319,7 +321,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 		d.routedDispatch(ctx, outbound, destination)
 	} else {
 		cReader := &cachedReader{
-			reader: outbound.Reader.(*pipe.Reader),
+			reader: outbound.Reader.(buf.TimeoutReader),
 		}
 		outbound.Reader = cReader
 		result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)

+ 37 - 0
common/buf/io.go

@@ -24,9 +24,46 @@ var ErrReadTimeout = errors.New("IO timeout")
 
 // TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
 type TimeoutReader interface {
+	Reader
 	ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error)
 }
 
+type TimeoutWrapperReader struct {
+	Reader
+	mb   MultiBuffer
+	err  error
+	done chan struct{}
+}
+
+func (r *TimeoutWrapperReader) ReadMultiBuffer() (MultiBuffer, error) {
+	if r.done != nil {
+		<-r.done
+		r.done = nil
+		return r.mb, r.err
+	}
+	r.mb = nil
+	r.err = nil
+	return r.Reader.ReadMultiBuffer()
+}
+
+func (r *TimeoutWrapperReader) ReadMultiBufferTimeout(duration time.Duration) (MultiBuffer, error) {
+	if r.done == nil {
+		r.done = make(chan struct{})
+		go func() {
+			r.mb, r.err = r.Reader.ReadMultiBuffer()
+			close(r.done)
+		}()
+	}
+	time.Sleep(duration)
+	select {
+	case <-r.done:
+		r.done = nil
+		return r.mb, r.err
+	default:
+		return nil, nil
+	}
+}
+
 // Writer extends io.Writer with MultiBuffer.
 type Writer interface {
 	// WriteMultiBuffer writes a MultiBuffer into underlying writer.

+ 5 - 1
common/mux/client.go

@@ -307,7 +307,11 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
 	}
 	s.input = link.Reader
 	s.output = link.Writer
-	go fetchInput(ctx, s, m.link.Writer)
+	if _, ok := link.Reader.(*pipe.Reader); ok {
+		go fetchInput(ctx, s, m.link.Writer)
+	} else {
+		fetchInput(ctx, s, m.link.Writer)
+	}
 	return true
 }
 

+ 8 - 1
common/mux/server.go

@@ -87,7 +87,14 @@ func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.
 		link:           link,
 		sessionManager: NewSessionManager(),
 	}
-	go worker.run(ctx)
+	if inbound := session.InboundFromContext(ctx); inbound != nil {
+		inbound.CanSpliceCopy = 3
+	}
+	if _, ok := link.Reader.(*pipe.Reader); ok {
+		go worker.run(ctx)
+	} else {
+		worker.run(ctx)
+	}
 	return worker, nil
 }
 

+ 13 - 99
proxy/dokodemo/dokodemo.go

@@ -2,10 +2,8 @@ package dokodemo
 
 import (
 	"context"
-	"runtime"
 	"strconv"
 	"strings"
-	"sync/atomic"
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
@@ -14,11 +12,10 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/session"
-	"github.com/xtls/xray-core/common/signal"
-	"github.com/xtls/xray-core/common/task"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/features/routing"
+	"github.com/xtls/xray-core/transport"
 	"github.com/xtls/xray-core/transport/internet/stat"
 	"github.com/xtls/xray-core/transport/internet/tls"
 )
@@ -144,39 +141,11 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
 	})
 	errors.LogInfo(ctx, "received request for ", conn.RemoteAddr())
 
-	plcy := d.policy()
-	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
-
-	if inbound != nil {
-		inbound.Timer = timer
-	}
-
-	ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
-	link, err := dispatcher.Dispatch(ctx, dest)
-	if err != nil {
-		return errors.New("failed to dispatch request").Base(err)
-	}
-
-	requestCount := int32(1)
-	requestDone := func() error {
-		defer func() {
-			if atomic.AddInt32(&requestCount, -1) == 0 {
-				timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
-			}
-		}()
-
-		var reader buf.Reader
-		if dest.Network == net.Network_UDP {
-			reader = buf.NewPacketReader(conn)
-		} else {
-			reader = buf.NewReader(conn)
-		}
-		if err := buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)); err != nil {
-			return errors.New("failed to transport request").Base(err)
-		}
-
-		return nil
+	var reader buf.Reader
+	if dest.Network == net.Network_TCP {
+		reader = buf.NewReader(conn)
+	} else {
+		reader = buf.NewPacketReader(conn)
 	}
 
 	var writer buf.Writer
@@ -208,72 +177,17 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
 				return err
 			}
 			writer = NewPacketWriter(pConn, &dest, mark, back)
-			defer func() {
-				runtime.Gosched()
-				common.Interrupt(link.Reader) // maybe duplicated
-				runtime.Gosched()
-				writer.(*PacketWriter).Close() // close fake UDP conns
-			}()
-			/*
-				sockopt := &internet.SocketConfig{
-					Tproxy: internet.SocketConfig_TProxy,
-				}
-				if dest.Address.Family().IsIP() {
-					sockopt.BindAddress = dest.Address.IP()
-					sockopt.BindPort = uint32(dest.Port)
-				}
-				if d.sockopt != nil {
-					sockopt.Mark = d.sockopt.Mark
-				}
-				tConn, err := internet.DialSystem(ctx, net.DestinationFromAddr(conn.RemoteAddr()), sockopt)
-				if err != nil {
-					return err
-				}
-				defer tConn.Close()
-
-				writer = &buf.SequentialWriter{Writer: tConn}
-				tReader := buf.NewPacketReader(tConn)
-				requestCount++
-				tproxyRequest = func() error {
-					defer func() {
-						if atomic.AddInt32(&requestCount, -1) == 0 {
-							timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
-						}
-					}()
-					if err := buf.Copy(tReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
-						return errors.New("failed to transport request (TPROXY conn)").Base(err)
-					}
-					return nil
-				}
-			*/
-		}
-	}
-
-	responseDone := func() error {
-		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
-
-		if network == net.Network_UDP && destinationOverridden {
-			buf.Copy(link.Reader, writer) // respect upload's timeout
-			return nil
-		}
-
-		if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil {
-			return errors.New("failed to transport response").Base(err)
+			defer writer.(*PacketWriter).Close() // close fake UDP conns
 		}
-		return nil
 	}
 
-	if err := task.Run(ctx,
-		task.OnSuccess(func() error { return task.Run(ctx, requestDone) }, task.Close(link.Writer)),
-		responseDone); err != nil {
-		runtime.Gosched()
-		common.Interrupt(link.Writer)
-		runtime.Gosched()
-		common.Interrupt(link.Reader)
-		return errors.New("connection ends").Base(err)
+	if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
+		Reader: &buf.TimeoutWrapperReader{Reader: reader},
+		Writer: writer},
+	); err != nil {
+		return errors.New("failed to dispatch request").Base(err)
 	}
-
-	return nil
+	return nil // Unlike Dispatch(), DispatchLink() will not return until the outbound finishes Process()
 }
 
 func NewPacketWriter(conn net.PacketConn, d *net.Destination, mark int, back *net.UDPAddr) buf.Writer {

+ 14 - 47
proxy/http/server.go

@@ -18,12 +18,12 @@ import (
 	"github.com/xtls/xray-core/common/protocol"
 	http_proto "github.com/xtls/xray-core/common/protocol/http"
 	"github.com/xtls/xray-core/common/session"
-	"github.com/xtls/xray-core/common/signal"
 	"github.com/xtls/xray-core/common/task"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/proxy"
+	"github.com/xtls/xray-core/transport"
 	"github.com/xtls/xray-core/transport/internet/stat"
 )
 
@@ -173,64 +173,31 @@ Start:
 	return err
 }
 
-func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
+func (s *Server) handleConnect(ctx context.Context, _ *http.Request, buffer *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
 	_, err := conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n"))
 	if err != nil {
 		return errors.New("failed to write back OK response").Base(err)
 	}
 
-	plcy := s.policy()
-	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
-
-	if inbound != nil {
-		inbound.Timer = timer
-	}
-
-	ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
-	link, err := dispatcher.Dispatch(ctx, dest)
-	if err != nil {
-		return err
-	}
-
-	if reader.Buffered() > 0 {
-		payload, err := buf.ReadFrom(io.LimitReader(reader, int64(reader.Buffered())))
+	reader := buf.NewReader(conn)
+	if buffer.Buffered() > 0 {
+		payload, err := buf.ReadFrom(io.LimitReader(buffer, int64(buffer.Buffered())))
 		if err != nil {
 			return err
 		}
-		if err := link.Writer.WriteMultiBuffer(payload); err != nil {
-			return err
-		}
-		reader = nil
-	}
-
-	requestDone := func() error {
-		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
-
-		return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
+		reader = &buf.BufferedReader{Reader: reader, Buffer: payload}
+		buffer = nil
 	}
 
-	responseDone := func() error {
-		if inbound.CanSpliceCopy == 2 {
-			inbound.CanSpliceCopy = 1
-		}
-		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
-
-		v2writer := buf.NewWriter(conn)
-		if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
-			return err
-		}
-
-		return nil
+	if inbound.CanSpliceCopy == 2 {
+		inbound.CanSpliceCopy = 1
 	}
-
-	closeWriter := task.OnSuccess(requestDone, task.Close(link.Writer))
-	if err := task.Run(ctx, closeWriter, responseDone); err != nil {
-		common.Interrupt(link.Reader)
-		common.Interrupt(link.Writer)
-		return errors.New("connection ends").Base(err)
+	if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
+		Reader: &buf.TimeoutWrapperReader{Reader: reader},
+		Writer: buf.NewWriter(conn)},
+	); err != nil {
+		return errors.New("failed to dispatch request").Base(err)
 	}
-
 	return nil
 }
 

+ 11 - 55
proxy/socks/server.go

@@ -14,13 +14,12 @@ import (
 	"github.com/xtls/xray-core/common/protocol"
 	udp_proto "github.com/xtls/xray-core/common/protocol/udp"
 	"github.com/xtls/xray-core/common/session"
-	"github.com/xtls/xray-core/common/signal"
-	"github.com/xtls/xray-core/common/task"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/proxy"
 	"github.com/xtls/xray-core/proxy/http"
+	"github.com/xtls/xray-core/transport"
 	"github.com/xtls/xray-core/transport/internet/stat"
 	"github.com/xtls/xray-core/transport/internet/udp"
 )
@@ -158,8 +157,16 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche
 				Reason: "",
 			})
 		}
-
-		return s.transport(ctx, reader, conn, dest, dispatcher, inbound)
+		if inbound.CanSpliceCopy == 2 {
+			inbound.CanSpliceCopy = 1
+		}
+		if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
+			Reader: &buf.TimeoutWrapperReader{Reader: reader},
+			Writer: buf.NewWriter(conn)},
+		); err != nil {
+			return errors.New("failed to dispatch request").Base(err)
+		}
+		return nil
 	}
 
 	if request.Command == protocol.RequestCommandUDP {
@@ -178,54 +185,6 @@ func (*Server) handleUDP(c io.Reader) error {
 	return common.Error2(io.Copy(buf.DiscardBytes, c))
 }
 
-func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
-	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, s.policy().Timeouts.ConnectionIdle)
-
-	if inbound != nil {
-		inbound.Timer = timer
-	}
-
-	plcy := s.policy()
-	ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
-	link, err := dispatcher.Dispatch(ctx, dest)
-	if err != nil {
-		return err
-	}
-
-	requestDone := func() error {
-		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
-		if err := buf.Copy(buf.NewReader(reader), link.Writer, buf.UpdateActivity(timer)); err != nil {
-			return errors.New("failed to transport all TCP request").Base(err)
-		}
-
-		return nil
-	}
-
-	responseDone := func() error {
-		if inbound.CanSpliceCopy == 2 {
-			inbound.CanSpliceCopy = 1
-		}
-		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
-
-		v2writer := buf.NewWriter(writer)
-		if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
-			return errors.New("failed to transport all TCP response").Base(err)
-		}
-
-		return nil
-	}
-
-	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
-	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
-		common.Interrupt(link.Reader)
-		common.Interrupt(link.Writer)
-		return errors.New("connection ends").Base(err)
-	}
-
-	return nil
-}
-
 func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
 	if s.udpFilter != nil && !s.udpFilter.Check(conn.RemoteAddr()) {
 		errors.LogDebug(ctx, "Unauthorized UDP access from ", conn.RemoteAddr().String())
@@ -265,9 +224,6 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
 	if inbound != nil && inbound.Source.IsValid() {
 		errors.LogInfo(ctx, "client UDP connection from ", inbound.Source)
 	}
-	if inbound.CanSpliceCopy == 2 {
-		inbound.CanSpliceCopy = 1
-	}
 
 	var dest *net.Destination