Browse Source

net/ktimeout: add a package to set TCP user timeout

Setting a user timeout will be a more practical tuning knob for a number
of endpoints, this provides a way to set it.

Updates tailscale/corp#17587

Signed-off-by: James Tucker <[email protected]>
James Tucker 2 years ago
parent
commit
8fe504241d

+ 36 - 0
net/ktimeout/ktimeout.go

@@ -0,0 +1,36 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package ktimeout configures kernel TCP stack timeouts via the provided
+// control functions. Platform support varies; on unsupported platforms control
+// functions may be entirely no-ops.
+package ktimeout
+
+import (
+	"fmt"
+	"syscall"
+	"time"
+)
+
+// UserTimeout returns a control function that sets the TCP user timeout
+// (TCP_USER_TIMEOUT on linux). A user timeout specifies the maximum age of
+// unacknowledged data on the connection (either in buffer, or sent but not
+// acknowledged) before the connection is terminated. This timer has no effect
+// on limiting the lifetime of idle connections. This may be entirely local to
+// the network stack or may also apply RFC 5482 options to packets.
+func UserTimeout(timeout time.Duration) func(network, address string, c syscall.RawConn) error {
+	return func(network, address string, c syscall.RawConn) error {
+		switch network {
+		case "tcp", "tcp4", "tcp6":
+		default:
+			return fmt.Errorf("ktimeout.UserTimeout: unsupported network: %s", network)
+		}
+		var err error
+		if e := c.Control(func(fd uintptr) {
+			err = SetUserTimeout(fd, timeout)
+		}); e != nil {
+			return e
+		}
+		return err
+	}
+}

+ 15 - 0
net/ktimeout/ktimeout_default.go

@@ -0,0 +1,15 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux
+
+package ktimeout
+
+import (
+	"time"
+)
+
+// SetUserTimeout is a no-op on this platform.
+func SetUserTimeout(fd uintptr, timeout time.Duration) error {
+	return nil
+}

+ 15 - 0
net/ktimeout/ktimeout_linux.go

@@ -0,0 +1,15 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ktimeout
+
+import (
+	"time"
+
+	"golang.org/x/sys/unix"
+)
+
+// SetUserTimeout sets the TCP_USER_TIMEOUT option on the given file descriptor.
+func SetUserTimeout(fd uintptr, timeout time.Duration) error {
+	return unix.SetsockoptInt(int(fd), unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int(timeout/time.Millisecond))
+}

+ 46 - 0
net/ktimeout/ktimeout_linux_test.go

@@ -0,0 +1,46 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ktimeout
+
+import (
+	"net"
+	"testing"
+	"time"
+
+	"golang.org/x/net/nettest"
+	"golang.org/x/sys/unix"
+	"tailscale.com/util/must"
+)
+
+func TestSetUserTimeout(t *testing.T) {
+	l := must.Get(nettest.NewLocalListener("tcp"))
+	defer l.Close()
+
+	var err error
+	if e := must.Get(l.(*net.TCPListener).SyscallConn()).Control(func(fd uintptr) {
+		err = SetUserTimeout(fd, 0)
+	}); e != nil {
+		t.Fatal(e)
+	}
+	if err != nil {
+		t.Fatal(err)
+	}
+	v := must.Get(unix.GetsockoptInt(int(must.Get(l.(*net.TCPListener).File()).Fd()), unix.SOL_TCP, unix.TCP_USER_TIMEOUT))
+	if v != 0 {
+		t.Errorf("TCP_USER_TIMEOUT: got %v; want 0", v)
+	}
+
+	if e := must.Get(l.(*net.TCPListener).SyscallConn()).Control(func(fd uintptr) {
+		err = SetUserTimeout(fd, 30*time.Second)
+	}); e != nil {
+		t.Fatal(e)
+	}
+	if err != nil {
+		t.Fatal(err)
+	}
+	v = must.Get(unix.GetsockoptInt(int(must.Get(l.(*net.TCPListener).File()).Fd()), unix.SOL_TCP, unix.TCP_USER_TIMEOUT))
+	if v != 30000 {
+		t.Errorf("TCP_USER_TIMEOUT: got %v; want 30000", v)
+	}
+}

+ 24 - 0
net/ktimeout/ktimeout_test.go

@@ -0,0 +1,24 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ktimeout
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"time"
+)
+
+func ExampleUserTimeout() {
+	lc := net.ListenConfig{
+		Control: UserTimeout(30 * time.Second),
+	}
+	l, err := lc.Listen(context.TODO(), "tcp", "127.0.0.1:0")
+	if err != nil {
+		fmt.Printf("error: %v", err)
+		return
+	}
+	l.Close()
+	// Output:
+}