| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- // Copyright (c) Tailscale Inc & AUTHORS
- // SPDX-License-Identifier: BSD-3-Clause
- //go:build linux && !android
- package tailssh
- import (
- "bufio"
- "bytes"
- "context"
- "encoding/binary"
- "fmt"
- "os"
- "os/exec"
- "strings"
- "syscall"
- "testing"
- "time"
- )
// maybeWithSudo builds an *exec.Cmd bound to ctx that runs name with args.
// When the effective UID is 0 the command is invoked directly; otherwise it
// is invoked through sudo, with name passed as sudo's first argument.
func maybeWithSudo(ctx context.Context, name string, args ...string) *exec.Cmd {
	if os.Geteuid() != 0 {
		// Not root: hand the whole command line to sudo.
		argv := make([]string, 0, len(args)+1)
		argv = append(argv, name)
		argv = append(argv, args...)
		return exec.CommandContext(ctx, "sudo", argv...)
	}
	return exec.CommandContext(ctx, name, args...)
}
- func TestBuildAuditNetlinkMessage(t *testing.T) {
- testCases := []struct {
- name string
- msgType uint16
- message string
- wantType uint16
- }{
- {
- name: "simple-message",
- msgType: auditUserLogin,
- message: "op=login acct=test",
- wantType: auditUserLogin,
- },
- {
- name: "message-with-quoted-fields",
- msgType: auditUserLogin,
- message: `op=login hostname="test-host" exe="/usr/bin/tailscaled" ts_user="[email protected]" ts_node="node.tail-scale.ts.net"`,
- wantType: auditUserLogin,
- },
- {
- name: "message-with-special-chars",
- msgType: auditUserLogin,
- message: `op=login hostname="host with spaces" ts_user="user [email protected]" ts_display_name="User \"Quote\" Name"`,
- wantType: auditUserLogin,
- },
- {
- name: "long-message-truncated",
- msgType: auditUserLogin,
- message: string(make([]byte, 2000)),
- wantType: auditUserLogin,
- },
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- msg, err := buildAuditNetlinkMessage(tc.msgType, tc.message)
- if err != nil {
- t.Fatalf("buildAuditNetlinkMessage failed: %v", err)
- }
- if len(msg) < syscall.NLMSG_HDRLEN {
- t.Fatalf("message too short: got %d bytes, want at least %d", len(msg), syscall.NLMSG_HDRLEN)
- }
- var nlh syscall.NlMsghdr
- buf := bytes.NewReader(msg[:syscall.NLMSG_HDRLEN])
- if err := binary.Read(buf, binary.NativeEndian, &nlh); err != nil {
- t.Fatalf("failed to parse netlink header: %v", err)
- }
- if nlh.Type != tc.wantType {
- t.Errorf("message type: got %d, want %d", nlh.Type, tc.wantType)
- }
- if nlh.Flags != nlmFRequest {
- t.Errorf("flags: got 0x%x, want 0x%x", nlh.Flags, nlmFRequest)
- }
- if len(msg)%syscall.NLMSG_ALIGNTO != 0 {
- t.Errorf("message not aligned: len=%d, alignment=%d", len(msg), syscall.NLMSG_ALIGNTO)
- }
- payloadLen := int(nlh.Len) - syscall.NLMSG_HDRLEN
- if payloadLen < 0 {
- t.Fatalf("invalid payload length: %d", payloadLen)
- }
- payload := msg[syscall.NLMSG_HDRLEN : syscall.NLMSG_HDRLEN+payloadLen]
- expectedMsg := tc.message
- if len(expectedMsg) > maxAuditMessageLength {
- expectedMsg = expectedMsg[:maxAuditMessageLength]
- }
- if string(payload) != expectedMsg {
- t.Errorf("payload mismatch:\ngot: %q\nwant: %q", string(payload), expectedMsg)
- }
- expectedLen := syscall.NLMSG_HDRLEN + len(payload)
- if int(nlh.Len) != expectedLen {
- t.Errorf("length field: got %d, want %d", nlh.Len, expectedLen)
- }
- })
- }
- }
- func TestAuditIntegration(t *testing.T) {
- if !hasAuditWriteCap() {
- t.Skip("skipping: CAP_AUDIT_WRITE not in effective capability set")
- }
- if _, err := exec.LookPath("journalctl"); err != nil {
- t.Skip("skipping: journalctl not available")
- }
- ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
- defer cancel()
- checkCmd := maybeWithSudo(ctx, "journalctl", "--field", "_TRANSPORT")
- var out bytes.Buffer
- checkCmd.Stdout = &out
- if err := checkCmd.Run(); err != nil {
- t.Skipf("skipping: cannot query journalctl transports: %v", err)
- }
- if !strings.Contains(out.String(), "audit") {
- t.Skip("skipping: journald not configured for audit messages, try: systemctl enable systemd-journald-audit.socket && systemctl restart systemd-journald")
- }
- testID := fmt.Sprintf("tailscale-test-%d", time.Now().UnixNano())
- testMsg := fmt.Sprintf("op=test-audit test_id=%s res=success", testID)
- followCmd := maybeWithSudo(ctx, "journalctl", "-f", "_TRANSPORT=audit", "--no-pager")
- stdout, err := followCmd.StdoutPipe()
- if err != nil {
- t.Fatalf("failed to get stdout pipe: %v", err)
- }
- if err := followCmd.Start(); err != nil {
- t.Fatalf("failed to start journalctl: %v", err)
- }
- defer followCmd.Process.Kill()
- testLogf := func(format string, args ...any) {
- t.Logf(format, args...)
- }
- sendAuditMessage(testLogf, auditUserLogin, testMsg)
- bs := bufio.NewScanner(stdout)
- found := false
- for bs.Scan() {
- line := bs.Text()
- if strings.Contains(line, testID) {
- t.Logf("found audit log entry: %s", line)
- found = true
- break
- }
- }
- if err := bs.Err(); err != nil && ctx.Err() == nil {
- t.Fatalf("error reading journalctl output: %v", err)
- }
- if !found {
- if ctx.Err() == context.DeadlineExceeded {
- t.Errorf("timeout waiting for audit message with test_id=%s", testID)
- } else {
- t.Errorf("audit message with test_id=%s not found in journald audit log", testID)
- }
- }
- }
|