Browse Source

Fix: TFO AsIs bug (#452)

risetechlab 4 years ago
parent
commit
b63049f404

+ 3 - 6
infra/conf/transport_internet.go

@@ -478,22 +478,19 @@ type SocketConfig struct {
 
 // Build implements Buildable.
 func (c *SocketConfig) Build() (*internet.SocketConfig, error) {
-	tfo := int32(-1)
+	tfo := int32(0) // don't invoke setsockopt() for TFO
 	if c.TFO != nil {
 		switch v := c.TFO.(type) {
 		case bool:
 			if v {
 				tfo = 256
 			} else {
-				tfo = 0
+				tfo = -1 // TFO need to be disabled
 			}
 		case float64:
-			if v < 0 {
-				return nil, newError("tcpFastOpen: only boolean and non-negative integer value is acceptable")
-			}
 			tfo = int32(math.Min(v, math.MaxInt32))
 		default:
-			return nil, newError("tcpFastOpen: only boolean and non-negative integer value is acceptable")
+			return nil, newError("tcpFastOpen: only boolean and integer value is acceptable")
 		}
 	}
 	var tproxy internet.SocketConfig_TProxyMode

+ 101 - 14
infra/conf/transport_test.go

@@ -31,6 +31,13 @@ func TestSocketConfig(t *testing.T) {
 		}
 	}
 
+	// test "tcpFastOpen": true, queue length 256 is expected. other parameters are tested here too
+	expectedOutput := &internet.SocketConfig{
+		Mark:           1,
+		Tfo:            256,
+		DomainStrategy: internet.DomainStrategy_USE_IP,
+		DialerProxy:    "tag",
+	}
 	runMultiTestCase(t, []TestCase{
 		{
 			Input: `{
@@ -40,38 +47,118 @@ func TestSocketConfig(t *testing.T) {
 				"dialerProxy": "tag"
 			}`,
 			Parser: createParser(),
-			Output: &internet.SocketConfig{
-				Mark:           1,
-				Tfo:            256,
-				DomainStrategy: internet.DomainStrategy_USE_IP,
-				DialerProxy:    "tag",
-			},
+			Output: expectedOutput,
 		},
 	})
+	if expectedOutput.ParseTFOValue() != 256 {
+		t.Fatalf("unexpected parsed TFO value, which should be 256")
+	}
+
+	// test "tcpFastOpen": false, disabled TFO is expected
+	expectedOutput = &internet.SocketConfig{
+		Mark: 0,
+		Tfo:  -1,
+	}
 	runMultiTestCase(t, []TestCase{
 		{
 			Input: `{
 				"tcpFastOpen": false
 			}`,
 			Parser: createParser(),
-			Output: &internet.SocketConfig{
-				Mark: 0,
-				Tfo:  0,
-			},
+			Output: expectedOutput,
 		},
 	})
+	if expectedOutput.ParseTFOValue() != 0 {
+		t.Fatalf("unexpected parsed TFO value, which should be 0")
+	}
+
+	// test "tcpFastOpen": 65535, queue length 65535 is expected
+	expectedOutput = &internet.SocketConfig{
+		Mark: 0,
+		Tfo:  65535,
+	}
 	runMultiTestCase(t, []TestCase{
 		{
 			Input: `{
 				"tcpFastOpen": 65535
 			}`,
 			Parser: createParser(),
-			Output: &internet.SocketConfig{
-				Mark: 0,
-				Tfo:  65535,
-			},
+			Output: expectedOutput,
 		},
 	})
+	if expectedOutput.ParseTFOValue() != 65535 {
+		t.Fatalf("unexpected parsed TFO value, which should be 65535")
+	}
+
+	// test "tcpFastOpen": -65535, disable TFO is expected
+	expectedOutput = &internet.SocketConfig{
+		Mark: 0,
+		Tfo:  -65535,
+	}
+	runMultiTestCase(t, []TestCase{
+		{
+			Input: `{
+				"tcpFastOpen": -65535
+			}`,
+			Parser: createParser(),
+			Output: expectedOutput,
+		},
+	})
+	if expectedOutput.ParseTFOValue() != 0 {
+		t.Fatalf("unexpected parsed TFO value, which should be 0")
+	}
+
+	// test "tcpFastOpen": 0, no operation is expected
+	expectedOutput = &internet.SocketConfig{
+		Mark: 0,
+		Tfo:  0,
+	}
+	runMultiTestCase(t, []TestCase{
+		{
+			Input: `{
+				"tcpFastOpen": 0
+			}`,
+			Parser: createParser(),
+			Output: expectedOutput,
+		},
+	})
+	if expectedOutput.ParseTFOValue() != -1 {
+		t.Fatalf("unexpected parsed TFO value, which should be -1")
+	}
+
+	// test omit "tcpFastOpen", no operation is expected
+	expectedOutput = &internet.SocketConfig{
+		Mark: 0,
+		Tfo:  0,
+	}
+	runMultiTestCase(t, []TestCase{
+		{
+			Input:  `{}`,
+			Parser: createParser(),
+			Output: expectedOutput,
+		},
+	})
+	if expectedOutput.ParseTFOValue() != -1 {
+		t.Fatalf("unexpected parsed TFO value, which should be -1")
+	}
+
+	// test "tcpFastOpen": null, no operation is expected
+	expectedOutput = &internet.SocketConfig{
+		Mark: 0,
+		Tfo:  0,
+	}
+	runMultiTestCase(t, []TestCase{
+		{
+			Input: `{
+				"tcpFastOpen": null
+			}`,
+			Parser: createParser(),
+			Output: expectedOutput,
+		},
+	})
+	if expectedOutput.ParseTFOValue() != -1 {
+		t.Fatalf("unexpected parsed TFO value, which should be -1")
+	}
 }
 
 func TestTransportConfig(t *testing.T) {

+ 11 - 0
transport/internet/sockopt.go

@@ -17,3 +17,14 @@ func isUDPSocket(network string) bool {
 		return false
 	}
 }
+
+func (v *SocketConfig) ParseTFOValue() int {
+	if v.Tfo == 0 {
+		return -1
+	}
+	tfo := int(v.Tfo)
+	if tfo < 0 {
+		tfo = 0
+	}
+	return tfo
+}

+ 4 - 4
transport/internet/sockopt_darwin.go

@@ -15,12 +15,12 @@ const (
 
 func applyOutboundSocketOptions(network string, address string, fd uintptr, config *SocketConfig) error {
 	if isTCPSocket(network) {
-		tfo := config.Tfo
+		tfo := config.ParseTFOValue()
 		if tfo > 0 {
 			tfo = TCP_FASTOPEN_CLIENT
 		}
 		if tfo >= 0 {
-			if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, TCP_FASTOPEN, int(tfo)); err != nil {
+			if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, TCP_FASTOPEN, tfo); err != nil {
 				return err
 			}
 		}
@@ -31,12 +31,12 @@ func applyOutboundSocketOptions(network string, address string, fd uintptr, conf
 
 func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig) error {
 	if isTCPSocket(network) {
-		tfo := config.Tfo
+		tfo := config.ParseTFOValue()
 		if tfo > 0 {
 			tfo = TCP_FASTOPEN_SERVER
 		}
 		if tfo >= 0 {
-			if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, TCP_FASTOPEN, int(tfo)); err != nil {
+			if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, TCP_FASTOPEN, tfo); err != nil {
 				return err
 			}
 		}

+ 5 - 4
transport/internet/sockopt_freebsd.go

@@ -130,7 +130,7 @@ func applyOutboundSocketOptions(network string, address string, fd uintptr, conf
 	}
 
 	if isTCPSocket(network) {
-		tfo := int(config.Tfo)
+		tfo := config.ParseTFOValue()
 		if tfo > 0 {
 			tfo = 1
 		}
@@ -163,9 +163,10 @@ func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig)
 		}
 	}
 	if isTCPSocket(network) {
-		if config.Tfo >= 0 {
-			if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_FASTOPEN, int(config.Tfo)); err != nil {
-				return newError("failed to set TCP_FASTOPEN=", config.Tfo).Base(err)
+		tfo := config.ParseTFOValue()
+		if tfo >= 0 {
+			if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_FASTOPEN, tfo); err != nil {
+				return newError("failed to set TCP_FASTOPEN=", tfo).Base(err)
 			}
 		}
 	}

+ 5 - 4
transport/internet/sockopt_linux.go

@@ -48,7 +48,7 @@ func applyOutboundSocketOptions(network string, address string, fd uintptr, conf
 	}
 
 	if isTCPSocket(network) {
-		tfo := int(config.Tfo)
+		tfo := config.ParseTFOValue()
 		if tfo > 0 {
 			tfo = 1
 		}
@@ -75,9 +75,10 @@ func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig)
 		}
 	}
 	if isTCPSocket(network) {
-		if config.Tfo >= 0 {
-			if err := syscall.SetsockoptInt(int(fd), syscall.SOL_TCP, TCP_FASTOPEN, int(config.Tfo)); err != nil {
-				return newError("failed to set TCP_FASTOPEN=", config.Tfo).Base(err)
+		tfo := config.ParseTFOValue()
+		if tfo >= 0 {
+			if err := syscall.SetsockoptInt(int(fd), syscall.SOL_TCP, TCP_FASTOPEN, tfo); err != nil {
+				return newError("failed to set TCP_FASTOPEN=", tfo).Base(err)
 			}
 		}
 	}

+ 4 - 4
transport/internet/sockopt_windows.go

@@ -8,12 +8,12 @@ const (
 	TCP_FASTOPEN = 15
 )
 
-func setTFO(fd syscall.Handle, tfo int32) error {
+func setTFO(fd syscall.Handle, tfo int) error {
 	if tfo > 0 {
 		tfo = 1
 	}
 	if tfo >= 0 {
-		if err := syscall.SetsockoptInt(fd, syscall.IPPROTO_TCP, TCP_FASTOPEN, int(tfo)); err != nil {
+		if err := syscall.SetsockoptInt(fd, syscall.IPPROTO_TCP, TCP_FASTOPEN, tfo); err != nil {
 			return err
 		}
 	}
@@ -22,7 +22,7 @@ func setTFO(fd syscall.Handle, tfo int32) error {
 
 func applyOutboundSocketOptions(network string, address string, fd uintptr, config *SocketConfig) error {
 	if isTCPSocket(network) {
-		if err := setTFO(syscall.Handle(fd), config.Tfo); err != nil {
+		if err := setTFO(syscall.Handle(fd), config.ParseTFOValue()); err != nil {
 			return err
 		}
 
@@ -33,7 +33,7 @@ func applyOutboundSocketOptions(network string, address string, fd uintptr, conf
 
 func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig) error {
 	if isTCPSocket(network) {
-		if err := setTFO(syscall.Handle(fd), config.Tfo); err != nil {
+		if err := setTFO(syscall.Handle(fd), config.ParseTFOValue()); err != nil {
 			return err
 		}
 	}