Browse Source

ssh/tailssh: send audit messages on SSH login (Linux)

Send LOGIN audit messages to the kernel audit subsystem on Linux
when users successfully authenticate to Tailscale SSH. This provides
administrators with audit trail integration via auditd or journald,
recording details about both the Tailscale user (whois) and the
mapped local user account.

The implementation uses raw netlink sockets to send AUDIT_USER_LOGIN
messages to the kernel audit subsystem. It requires CAP_AUDIT_WRITE
capability, which is checked at runtime. If the capability is not
present, audit logging is silently skipped.

Audit messages are sent to the kernel (pid 0) and consumed by either
auditd (written to /var/log/audit/audit.log) or journald (available
via journalctl _TRANSPORT=audit), depending on system configuration.

Note: This may result in duplicate messages on a system where
auditd/journald audit logs are enabled and the system has and supports
`login -h`. Sadly Linux login code paths are still an inconsistent wild
west so we accept the potential duplication rather than trying to avoid
it.

Fixes #18332

Signed-off-by: James Tucker <[email protected]>
James Tucker 2 months ago
parent
commit
39a61888b8
3 changed files with 366 additions and 0 deletions
  1. 176 0
      ssh/tailssh/auditd_linux.go
  2. 180 0
      ssh/tailssh/auditd_linux_test.go
  3. 10 0
      ssh/tailssh/tailssh.go

+ 176 - 0
ssh/tailssh/auditd_linux.go

@@ -0,0 +1,176 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux && !android
+
+package tailssh
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"os"
+	"syscall"
+
+	"golang.org/x/sys/unix"
+	"tailscale.com/types/logger"
+)
+
+const (
+	auditUserLogin = 1112 // audit message type for user login (from linux/audit.h)
+	netlinkAudit   = 9    // AF_NETLINK protocol number for audit (from linux/netlink.h)
+	nlmFRequest    = 0x01 // netlink message flag: request (from linux/netlink.h)
+
+	// maxAuditMessageLength is the maximum length of an audit message payload.
+	// This is derived from MAX_AUDIT_MESSAGE_LENGTH (8970) in the Linux kernel
+	// (linux/audit.h), minus overhead for the netlink header and safety margin.
+	maxAuditMessageLength = 8192
+)
+
+// hasAuditWriteCap checks if the process has CAP_AUDIT_WRITE in its effective capability set.
+func hasAuditWriteCap() bool {
+	var hdr unix.CapUserHeader
+	var data [2]unix.CapUserData
+
+	hdr.Version = unix.LINUX_CAPABILITY_VERSION_3
+	hdr.Pid = int32(os.Getpid())
+
+	if err := unix.Capget(&hdr, &data[0]); err != nil {
+		return false
+	}
+
+	const capBit = uint32(1 << (unix.CAP_AUDIT_WRITE % 32))
+	const capIdx = unix.CAP_AUDIT_WRITE / 32
+	return (data[capIdx].Effective & capBit) != 0
+}
+
+// buildAuditNetlinkMessage constructs a netlink audit message.
+// This is separated from sendAuditMessage to allow testing the message format
+// without requiring CAP_AUDIT_WRITE or a netlink socket.
+func buildAuditNetlinkMessage(msgType uint16, message string) ([]byte, error) {
+	msgBytes := []byte(message)
+	if len(msgBytes) > maxAuditMessageLength {
+		msgBytes = msgBytes[:maxAuditMessageLength]
+	}
+	msgLen := len(msgBytes)
+
+	totalLen := syscall.NLMSG_HDRLEN + msgLen
+	alignedLen := (totalLen + syscall.NLMSG_ALIGNTO - 1) & ^(syscall.NLMSG_ALIGNTO - 1)
+
+	nlh := syscall.NlMsghdr{
+		Len:   uint32(totalLen),
+		Type:  msgType,
+		Flags: nlmFRequest,
+		Seq:   1,
+		Pid:   uint32(os.Getpid()),
+	}
+
+	buf := bytes.NewBuffer(make([]byte, 0, alignedLen))
+	if err := binary.Write(buf, binary.NativeEndian, nlh); err != nil {
+		return nil, err
+	}
+	buf.Write(msgBytes)
+
+	for buf.Len() < alignedLen {
+		buf.WriteByte(0)
+	}
+
+	return buf.Bytes(), nil
+}
+
+// sendAuditMessage sends a message to the audit subsystem using raw netlink.
+// It logs errors but does not return them.
+func sendAuditMessage(logf logger.Logf, msgType uint16, message string) {
+	if !hasAuditWriteCap() {
+		return
+	}
+
+	fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, netlinkAudit)
+	if err != nil {
+		logf("auditd: failed to create netlink socket: %v", err)
+		return
+	}
+	defer syscall.Close(fd)
+
+	bindAddr := &syscall.SockaddrNetlink{
+		Family: syscall.AF_NETLINK,
+		Pid:    uint32(os.Getpid()),
+		Groups: 0,
+	}
+
+	if err := syscall.Bind(fd, bindAddr); err != nil {
+		logf("auditd: failed to bind netlink socket: %v", err)
+		return
+	}
+
+	kernelAddr := &syscall.SockaddrNetlink{
+		Family: syscall.AF_NETLINK,
+		Pid:    0,
+		Groups: 0,
+	}
+
+	msgBytes, err := buildAuditNetlinkMessage(msgType, message)
+	if err != nil {
+		logf("auditd: failed to build audit message: %v", err)
+		return
+	}
+
+	if err := syscall.Sendto(fd, msgBytes, 0, kernelAddr); err != nil {
+		logf("auditd: failed to send audit message: %v", err)
+		return
+	}
+}
+
+// logSSHLogin logs an SSH login event to auditd with whois information.
+func logSSHLogin(logf logger.Logf, c *conn) {
+	if c == nil || c.info == nil || c.localUser == nil {
+		return
+	}
+
+	exePath := c.srv.tailscaledPath
+	if exePath == "" {
+		exePath = "tailscaled"
+	}
+
+	srcIP := c.info.src.Addr().String()
+	srcPort := c.info.src.Port()
+	dstIP := c.info.dst.Addr().String()
+	dstPort := c.info.dst.Port()
+
+	tailscaleUser := c.info.uprof.LoginName
+	tailscaleUserID := c.info.uprof.ID
+	tailscaleDisplayName := c.info.uprof.DisplayName
+	nodeName := c.info.node.Name()
+	nodeID := c.info.node.ID()
+
+	localUser := c.localUser.Username
+	localUID := c.localUser.Uid
+	localGID := c.localUser.Gid
+
+	hostname, err := os.Hostname()
+	if err != nil {
+		hostname = "unknown"
+	}
+
+	// use principally the same format as ssh / PAM, which come from the audit userspace, i.e.
+	// https://github.com/linux-audit/audit-userspace/blob/b6f8c208435038df113a9795e3e202720aee6b70/lib/audit_logging.c#L515
+	msg := fmt.Sprintf(
+		"op=login acct=%s uid=%s gid=%s "+
+			"src=%s src_port=%d dst=%s dst_port=%d "+
+			"hostname=%q exe=%q terminal=ssh res=success "+
+			"ts_user=%q ts_user_id=%d ts_display_name=%q ts_node=%q ts_node_id=%d",
+		localUser, localUID, localGID,
+		srcIP, srcPort, dstIP, dstPort,
+		hostname, exePath,
+		tailscaleUser, tailscaleUserID, tailscaleDisplayName, nodeName, nodeID,
+	)
+
+	sendAuditMessage(logf, auditUserLogin, msg)
+
+	logf("audit: SSH login: user=%s uid=%s from=%s ts_user=%s node=%s",
+		localUser, localUID, srcIP, tailscaleUser, nodeName)
+}
+
+func init() {
+	hookSSHLoginSuccess.Set(logSSHLogin)
+}

+ 180 - 0
ssh/tailssh/auditd_linux_test.go

@@ -0,0 +1,180 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux && !android
+
+package tailssh
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/binary"
+	"fmt"
+	"os"
+	"os/exec"
+	"strings"
+	"syscall"
+	"testing"
+	"time"
+)
+
+// maybeWithSudo returns a command with context that may be prefixed with sudo if not running as root.
+func maybeWithSudo(ctx context.Context, name string, args ...string) *exec.Cmd {
+	if os.Geteuid() == 0 {
+		return exec.CommandContext(ctx, name, args...)
+	}
+	sudoArgs := append([]string{name}, args...)
+	return exec.CommandContext(ctx, "sudo", sudoArgs...)
+}
+
+func TestBuildAuditNetlinkMessage(t *testing.T) {
+	testCases := []struct {
+		name     string
+		msgType  uint16
+		message  string
+		wantType uint16
+	}{
+		{
+			name:     "simple-message",
+			msgType:  auditUserLogin,
+			message:  "op=login acct=test",
+			wantType: auditUserLogin,
+		},
+		{
+			name:     "message-with-quoted-fields",
+			msgType:  auditUserLogin,
+			message:  `op=login hostname="test-host" exe="/usr/bin/tailscaled" ts_user="[email protected]" ts_node="node.tail-scale.ts.net"`,
+			wantType: auditUserLogin,
+		},
+		{
+			name:     "message-with-special-chars",
+			msgType:  auditUserLogin,
+			message:  `op=login hostname="host with spaces" ts_user="user [email protected]" ts_display_name="User \"Quote\" Name"`,
+			wantType: auditUserLogin,
+		},
+		{
+			name:     "long-message-truncated",
+			msgType:  auditUserLogin,
+			message:  string(make([]byte, 2000)),
+			wantType: auditUserLogin,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			msg, err := buildAuditNetlinkMessage(tc.msgType, tc.message)
+			if err != nil {
+				t.Fatalf("buildAuditNetlinkMessage failed: %v", err)
+			}
+
+			if len(msg) < syscall.NLMSG_HDRLEN {
+				t.Fatalf("message too short: got %d bytes, want at least %d", len(msg), syscall.NLMSG_HDRLEN)
+			}
+
+			var nlh syscall.NlMsghdr
+			buf := bytes.NewReader(msg[:syscall.NLMSG_HDRLEN])
+			if err := binary.Read(buf, binary.NativeEndian, &nlh); err != nil {
+				t.Fatalf("failed to parse netlink header: %v", err)
+			}
+
+			if nlh.Type != tc.wantType {
+				t.Errorf("message type: got %d, want %d", nlh.Type, tc.wantType)
+			}
+
+			if nlh.Flags != nlmFRequest {
+				t.Errorf("flags: got 0x%x, want 0x%x", nlh.Flags, nlmFRequest)
+			}
+
+			if len(msg)%syscall.NLMSG_ALIGNTO != 0 {
+				t.Errorf("message not aligned: len=%d, alignment=%d", len(msg), syscall.NLMSG_ALIGNTO)
+			}
+
+			payloadLen := int(nlh.Len) - syscall.NLMSG_HDRLEN
+			if payloadLen < 0 {
+				t.Fatalf("invalid payload length: %d", payloadLen)
+			}
+
+			payload := msg[syscall.NLMSG_HDRLEN : syscall.NLMSG_HDRLEN+payloadLen]
+
+			expectedMsg := tc.message
+			if len(expectedMsg) > maxAuditMessageLength {
+				expectedMsg = expectedMsg[:maxAuditMessageLength]
+			}
+			if string(payload) != expectedMsg {
+				t.Errorf("payload mismatch:\ngot:  %q\nwant: %q", string(payload), expectedMsg)
+			}
+
+			expectedLen := syscall.NLMSG_HDRLEN + len(payload)
+			if int(nlh.Len) != expectedLen {
+				t.Errorf("length field: got %d, want %d", nlh.Len, expectedLen)
+			}
+		})
+	}
+}
+
+func TestAuditIntegration(t *testing.T) {
+	if !hasAuditWriteCap() {
+		t.Skip("skipping: CAP_AUDIT_WRITE not in effective capability set")
+	}
+
+	if _, err := exec.LookPath("journalctl"); err != nil {
+		t.Skip("skipping: journalctl not available")
+	}
+
+	ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
+	defer cancel()
+
+	checkCmd := maybeWithSudo(ctx, "journalctl", "--field", "_TRANSPORT")
+	var out bytes.Buffer
+	checkCmd.Stdout = &out
+	if err := checkCmd.Run(); err != nil {
+		t.Skipf("skipping: cannot query journalctl transports: %v", err)
+	}
+	if !strings.Contains(out.String(), "audit") {
+		t.Skip("skipping: journald not configured for audit messages, try: systemctl enable systemd-journald-audit.socket && systemctl restart systemd-journald")
+	}
+
+	testID := fmt.Sprintf("tailscale-test-%d", time.Now().UnixNano())
+	testMsg := fmt.Sprintf("op=test-audit test_id=%s res=success", testID)
+
+	followCmd := maybeWithSudo(ctx, "journalctl", "-f", "_TRANSPORT=audit", "--no-pager")
+
+	stdout, err := followCmd.StdoutPipe()
+	if err != nil {
+		t.Fatalf("failed to get stdout pipe: %v", err)
+	}
+
+	if err := followCmd.Start(); err != nil {
+		t.Fatalf("failed to start journalctl: %v", err)
+	}
+	defer followCmd.Process.Kill()
+
+	testLogf := func(format string, args ...any) {
+		t.Logf(format, args...)
+	}
+	sendAuditMessage(testLogf, auditUserLogin, testMsg)
+
+	bs := bufio.NewScanner(stdout)
+	found := false
+	for bs.Scan() {
+		line := bs.Text()
+		if strings.Contains(line, testID) {
+			t.Logf("found audit log entry: %s", line)
+			found = true
+			break
+		}
+	}
+
+	if err := bs.Err(); err != nil && ctx.Err() == nil {
+		t.Fatalf("error reading journalctl output: %v", err)
+	}
+
+	if !found {
+		if ctx.Err() == context.DeadlineExceeded {
+			t.Errorf("timeout waiting for audit message with test_id=%s", testID)
+		} else {
+			t.Errorf("audit message with test_id=%s not found in journald audit log", testID)
+		}
+	}
+}

+ 10 - 0
ssh/tailssh/tailssh.go

@@ -31,6 +31,7 @@ import (
 
 	gossh "golang.org/x/crypto/ssh"
 	"tailscale.com/envknob"
+	"tailscale.com/feature"
 	"tailscale.com/ipn/ipnlocal"
 	"tailscale.com/net/tsaddr"
 	"tailscale.com/net/tsdial"
@@ -56,6 +57,10 @@ var (
 	// authentication methods that may proceed), which results in the SSH
 	// server immediately disconnecting the client.
 	errTerminal = &gossh.PartialSuccessError{}
+
+	// hookSSHLoginSuccess is called after successful SSH authentication.
+	// It is set by platform-specific code (e.g., auditd_linux.go).
+	hookSSHLoginSuccess feature.Hook[func(logf logger.Logf, c *conn)]
 )
 
 const (
@@ -647,6 +652,11 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
 	ss := c.newSSHSession(s)
 	ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.Addr(), c.localUser.Username)
 	ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username)
+
+	if f, ok := hookSSHLoginSuccess.GetOk(); ok {
+		f(c.srv.logf, c)
+	}
+
 	ss.run()
 }