// auditd_linux_test.go
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
  3. //go:build linux && !android
  4. package tailssh
  5. import (
  6. "bufio"
  7. "bytes"
  8. "context"
  9. "encoding/binary"
  10. "fmt"
  11. "os"
  12. "os/exec"
  13. "strings"
  14. "syscall"
  15. "testing"
  16. "time"
  17. )
  18. // maybeWithSudo returns a command with context that may be prefixed with sudo if not running as root.
  19. func maybeWithSudo(ctx context.Context, name string, args ...string) *exec.Cmd {
  20. if os.Geteuid() == 0 {
  21. return exec.CommandContext(ctx, name, args...)
  22. }
  23. sudoArgs := append([]string{name}, args...)
  24. return exec.CommandContext(ctx, "sudo", sudoArgs...)
  25. }
  26. func TestBuildAuditNetlinkMessage(t *testing.T) {
  27. testCases := []struct {
  28. name string
  29. msgType uint16
  30. message string
  31. wantType uint16
  32. }{
  33. {
  34. name: "simple-message",
  35. msgType: auditUserLogin,
  36. message: "op=login acct=test",
  37. wantType: auditUserLogin,
  38. },
  39. {
  40. name: "message-with-quoted-fields",
  41. msgType: auditUserLogin,
  42. message: `op=login hostname="test-host" exe="/usr/bin/tailscaled" ts_user="[email protected]" ts_node="node.tail-scale.ts.net"`,
  43. wantType: auditUserLogin,
  44. },
  45. {
  46. name: "message-with-special-chars",
  47. msgType: auditUserLogin,
  48. message: `op=login hostname="host with spaces" ts_user="user [email protected]" ts_display_name="User \"Quote\" Name"`,
  49. wantType: auditUserLogin,
  50. },
  51. {
  52. name: "long-message-truncated",
  53. msgType: auditUserLogin,
  54. message: string(make([]byte, 2000)),
  55. wantType: auditUserLogin,
  56. },
  57. }
  58. for _, tc := range testCases {
  59. t.Run(tc.name, func(t *testing.T) {
  60. msg, err := buildAuditNetlinkMessage(tc.msgType, tc.message)
  61. if err != nil {
  62. t.Fatalf("buildAuditNetlinkMessage failed: %v", err)
  63. }
  64. if len(msg) < syscall.NLMSG_HDRLEN {
  65. t.Fatalf("message too short: got %d bytes, want at least %d", len(msg), syscall.NLMSG_HDRLEN)
  66. }
  67. var nlh syscall.NlMsghdr
  68. buf := bytes.NewReader(msg[:syscall.NLMSG_HDRLEN])
  69. if err := binary.Read(buf, binary.NativeEndian, &nlh); err != nil {
  70. t.Fatalf("failed to parse netlink header: %v", err)
  71. }
  72. if nlh.Type != tc.wantType {
  73. t.Errorf("message type: got %d, want %d", nlh.Type, tc.wantType)
  74. }
  75. if nlh.Flags != nlmFRequest {
  76. t.Errorf("flags: got 0x%x, want 0x%x", nlh.Flags, nlmFRequest)
  77. }
  78. if len(msg)%syscall.NLMSG_ALIGNTO != 0 {
  79. t.Errorf("message not aligned: len=%d, alignment=%d", len(msg), syscall.NLMSG_ALIGNTO)
  80. }
  81. payloadLen := int(nlh.Len) - syscall.NLMSG_HDRLEN
  82. if payloadLen < 0 {
  83. t.Fatalf("invalid payload length: %d", payloadLen)
  84. }
  85. payload := msg[syscall.NLMSG_HDRLEN : syscall.NLMSG_HDRLEN+payloadLen]
  86. expectedMsg := tc.message
  87. if len(expectedMsg) > maxAuditMessageLength {
  88. expectedMsg = expectedMsg[:maxAuditMessageLength]
  89. }
  90. if string(payload) != expectedMsg {
  91. t.Errorf("payload mismatch:\ngot: %q\nwant: %q", string(payload), expectedMsg)
  92. }
  93. expectedLen := syscall.NLMSG_HDRLEN + len(payload)
  94. if int(nlh.Len) != expectedLen {
  95. t.Errorf("length field: got %d, want %d", nlh.Len, expectedLen)
  96. }
  97. })
  98. }
  99. }
  100. func TestAuditIntegration(t *testing.T) {
  101. if !hasAuditWriteCap() {
  102. t.Skip("skipping: CAP_AUDIT_WRITE not in effective capability set")
  103. }
  104. if _, err := exec.LookPath("journalctl"); err != nil {
  105. t.Skip("skipping: journalctl not available")
  106. }
  107. ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
  108. defer cancel()
  109. checkCmd := maybeWithSudo(ctx, "journalctl", "--field", "_TRANSPORT")
  110. var out bytes.Buffer
  111. checkCmd.Stdout = &out
  112. if err := checkCmd.Run(); err != nil {
  113. t.Skipf("skipping: cannot query journalctl transports: %v", err)
  114. }
  115. if !strings.Contains(out.String(), "audit") {
  116. t.Skip("skipping: journald not configured for audit messages, try: systemctl enable systemd-journald-audit.socket && systemctl restart systemd-journald")
  117. }
  118. testID := fmt.Sprintf("tailscale-test-%d", time.Now().UnixNano())
  119. testMsg := fmt.Sprintf("op=test-audit test_id=%s res=success", testID)
  120. followCmd := maybeWithSudo(ctx, "journalctl", "-f", "_TRANSPORT=audit", "--no-pager")
  121. stdout, err := followCmd.StdoutPipe()
  122. if err != nil {
  123. t.Fatalf("failed to get stdout pipe: %v", err)
  124. }
  125. if err := followCmd.Start(); err != nil {
  126. t.Fatalf("failed to start journalctl: %v", err)
  127. }
  128. defer followCmd.Process.Kill()
  129. testLogf := func(format string, args ...any) {
  130. t.Logf(format, args...)
  131. }
  132. sendAuditMessage(testLogf, auditUserLogin, testMsg)
  133. bs := bufio.NewScanner(stdout)
  134. found := false
  135. for bs.Scan() {
  136. line := bs.Text()
  137. if strings.Contains(line, testID) {
  138. t.Logf("found audit log entry: %s", line)
  139. found = true
  140. break
  141. }
  142. }
  143. if err := bs.Err(); err != nil && ctx.Err() == nil {
  144. t.Fatalf("error reading journalctl output: %v", err)
  145. }
  146. if !found {
  147. if ctx.Err() == context.DeadlineExceeded {
  148. t.Errorf("timeout waiting for audit message with test_id=%s", testID)
  149. } else {
  150. t.Errorf("audit message with test_id=%s not found in journald audit log", testID)
  151. }
  152. }
  153. }