Browse Source

Dokodemo TPROXY: Interrupt UDP download if upload timeouts

https://github.com/XTLS/Xray-core/issues/4194#issuecomment-2564156340

Fixes https://github.com/XTLS/Xray-core/issues/4194
RPRX 9 months ago
parent
commit
a8559a1b46
1 changed files with 18 additions and 9 deletions
  1. 18 9
      proxy/dokodemo/dokodemo.go

+ 18 - 9
proxy/dokodemo/dokodemo.go

@@ -2,6 +2,7 @@ package dokodemo
 
 import (
 	"context"
+	"runtime"
 	"sync/atomic"
 
 	"github.com/xtls/xray-core/common"
@@ -147,10 +148,6 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
 		return nil
 	}
 
-	tproxyRequest := func() error {
-		return nil
-	}
-
 	var writer buf.Writer
 	if network == net.Network_TCP {
 		writer = buf.NewWriter(conn)
@@ -180,7 +177,12 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
 				return err
 			}
 			writer = NewPacketWriter(pConn, &dest, mark, back)
-			defer writer.(*PacketWriter).Close()
+			defer func() {
+				runtime.Gosched()
+				common.Interrupt(link.Reader) // maybe duplicated
+				runtime.Gosched()
+				writer.(*PacketWriter).Close() // close fake UDP conns
+			}()
 			/*
 				sockopt := &internet.SocketConfig{
 					Tproxy: internet.SocketConfig_TProxy,
@@ -219,17 +221,24 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
 	responseDone := func() error {
 		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
 
+		if network == net.Network_UDP && destinationOverridden {
+			buf.Copy(link.Reader, writer) // respect upload's timeout
+			return nil
+		}
+
 		if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil {
 			return errors.New("failed to transport response").Base(err)
 		}
 		return nil
 	}
 
-	if err := task.Run(ctx, task.OnSuccess(func() error {
-		return task.Run(ctx, requestDone, tproxyRequest)
-	}, task.Close(link.Writer)), responseDone); err != nil {
-		common.Interrupt(link.Reader)
+	if err := task.Run(ctx,
+		task.OnSuccess(func() error { return task.Run(ctx, requestDone) }, task.Close(link.Writer)),
+		responseDone); err != nil {
+		runtime.Gosched()
 		common.Interrupt(link.Writer)
+		runtime.Gosched()
+		common.Interrupt(link.Reader)
 		return errors.New("connection ends").Base(err)
 	}