nftables.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. // TODO(#8502): add support for more architectures
  4. //go:build linux && (arm64 || amd64)
  5. package linuxfw
  6. import (
  7. "fmt"
  8. "sort"
  9. "strings"
  10. "github.com/google/nftables"
  11. "github.com/google/nftables/expr"
  12. "github.com/google/nftables/xt"
  13. "github.com/josharian/native"
  14. "golang.org/x/sys/unix"
  15. "tailscale.com/types/logger"
  16. "tailscale.com/util/cmpx"
  17. )
  18. // DebugNetfilter prints debug information about netfilter rules to the
  19. // provided log function.
  20. func DebugNetfilter(logf logger.Logf) error {
  21. conn, err := nftables.New()
  22. if err != nil {
  23. return err
  24. }
  25. chains, err := conn.ListChains()
  26. if err != nil {
  27. return fmt.Errorf("cannot list chains: %w", err)
  28. }
  29. if len(chains) == 0 {
  30. logf("netfilter: no chains")
  31. return nil
  32. }
  33. for _, chain := range chains {
  34. logf("netfilter: table=%s chain=%s", chain.Table.Name, chain.Name)
  35. rules, err := conn.GetRules(chain.Table, chain)
  36. if err != nil {
  37. continue
  38. }
  39. sort.Slice(rules, func(i, j int) bool {
  40. return rules[i].Position < rules[j].Position
  41. })
  42. for i, rule := range rules {
  43. logf("netfilter: rule[%d]: pos=%d flags=%d", i, rule.Position, rule.Flags)
  44. for _, ex := range rule.Exprs {
  45. switch v := ex.(type) {
  46. case *expr.Meta:
  47. key := cmpx.Or(metaKeyNames[v.Key], "UNKNOWN")
  48. logf("netfilter: Meta: key=%s source_register=%v register=%d", key, v.SourceRegister, v.Register)
  49. case *expr.Cmp:
  50. op := cmpx.Or(cmpOpNames[v.Op], "UNKNOWN")
  51. logf("netfilter: Cmp: op=%s register=%d data=%s", op, v.Register, formatMaybePrintable(v.Data))
  52. case *expr.Counter:
  53. // don't print
  54. case *expr.Verdict:
  55. kind := cmpx.Or(verdictNames[v.Kind], "UNKNOWN")
  56. logf("netfilter: Verdict: kind=%s data=%s", kind, v.Chain)
  57. case *expr.Target:
  58. logf("netfilter: Target: name=%s info=%s", v.Name, printTargetInfo(v.Name, v.Info))
  59. case *expr.Match:
  60. logf("netfilter: Match: name=%s info=%+v", v.Name, printMatchInfo(v.Name, v.Info))
  61. case *expr.Payload:
  62. logf("netfilter: Payload: op=%s src=%d dst=%d base=%s offset=%d len=%d",
  63. payloadOperationTypeNames[v.OperationType],
  64. v.SourceRegister, v.DestRegister,
  65. payloadBaseNames[v.Base],
  66. v.Offset, v.Len)
  67. // TODO(andrew): csum
  68. case *expr.Bitwise:
  69. var xor string
  70. for _, b := range v.Xor {
  71. if b != 0 {
  72. xor = fmt.Sprintf(" xor=%v", v.Xor)
  73. break
  74. }
  75. }
  76. logf("netfilter: Bitwise: src=%d dst=%d len=%d mask=%v%s",
  77. v.SourceRegister, v.DestRegister, v.Len, v.Mask, xor)
  78. default:
  79. logf("netfilter: unknown %T: %+v", v, v)
  80. }
  81. }
  82. }
  83. }
  84. return nil
  85. }
  86. // DetectNetfilter returns the number of nftables rules present in the system.
  87. func DetectNetfilter() (int, error) {
  88. conn, err := nftables.New()
  89. if err != nil {
  90. return 0, err
  91. }
  92. chains, err := conn.ListChains()
  93. if err != nil {
  94. return 0, fmt.Errorf("cannot list chains: %w", err)
  95. }
  96. var validRules int
  97. for _, chain := range chains {
  98. rules, err := conn.GetRules(chain.Table, chain)
  99. if err != nil {
  100. continue
  101. }
  102. validRules += len(rules)
  103. }
  104. return validRules, nil
  105. }
  106. func printMatchInfo(name string, info xt.InfoAny) string {
  107. var sb strings.Builder
  108. sb.WriteString(`{`)
  109. var handled bool = true
  110. switch v := info.(type) {
  111. // TODO(andrew): we should support these common types
  112. //case *xt.ConntrackMtinfo3:
  113. //case *xt.ConntrackMtinfo2:
  114. case *xt.Tcp:
  115. fmt.Fprintf(&sb, "Src:%s Dst:%s", formatPortRange(v.SrcPorts), formatPortRange(v.DstPorts))
  116. if v.Option != 0 {
  117. fmt.Fprintf(&sb, " Option:%d", v.Option)
  118. }
  119. if v.FlagsMask != 0 {
  120. fmt.Fprintf(&sb, " FlagsMask:%d", v.FlagsMask)
  121. }
  122. if v.FlagsCmp != 0 {
  123. fmt.Fprintf(&sb, " FlagsCmp:%d", v.FlagsCmp)
  124. }
  125. if v.InvFlags != 0 {
  126. fmt.Fprintf(&sb, " InvFlags:%d", v.InvFlags)
  127. }
  128. case *xt.Udp:
  129. fmt.Fprintf(&sb, "Src:%s Dst:%s", formatPortRange(v.SrcPorts), formatPortRange(v.DstPorts))
  130. if v.InvFlags != 0 {
  131. fmt.Fprintf(&sb, " InvFlags:%d", v.InvFlags)
  132. }
  133. case *xt.AddrType:
  134. var sprefix, dprefix string
  135. if v.InvertSource {
  136. sprefix = "!"
  137. }
  138. if v.InvertDest {
  139. dprefix = "!"
  140. }
  141. // TODO(andrew): translate source/dest
  142. fmt.Fprintf(&sb, "Source:%s%d Dest:%s%d", sprefix, v.Source, dprefix, v.Dest)
  143. case *xt.AddrTypeV1:
  144. // TODO(andrew): translate source/dest
  145. fmt.Fprintf(&sb, "Source:%d Dest:%d", v.Source, v.Dest)
  146. var flags []string
  147. for flag, name := range addrTypeFlagNames {
  148. if v.Flags&flag != 0 {
  149. flags = append(flags, name)
  150. }
  151. }
  152. if len(flags) > 0 {
  153. sort.Strings(flags)
  154. fmt.Fprintf(&sb, "Flags:%s", strings.Join(flags, ","))
  155. }
  156. default:
  157. handled = false
  158. }
  159. if handled {
  160. sb.WriteString(`}`)
  161. return sb.String()
  162. }
  163. unknown, ok := info.(*xt.Unknown)
  164. if !ok {
  165. return fmt.Sprintf("(%T)%+v", info, info)
  166. }
  167. data := []byte(*unknown)
  168. // Things where upstream has no type
  169. handled = true
  170. switch name {
  171. case "pkttype":
  172. if len(data) != 8 {
  173. handled = false
  174. break
  175. }
  176. pkttype := int(native.Endian.Uint32(data[0:4]))
  177. invert := int(native.Endian.Uint32(data[4:8]))
  178. var invertPrefix string
  179. if invert != 0 {
  180. invertPrefix = "!"
  181. }
  182. pkttypeName := packetTypeNames[pkttype]
  183. if pkttypeName != "" {
  184. fmt.Fprintf(&sb, "PktType:%s%s", invertPrefix, pkttypeName)
  185. } else {
  186. fmt.Fprintf(&sb, "PktType:%s%d", invertPrefix, pkttype)
  187. }
  188. default:
  189. handled = true
  190. }
  191. if !handled {
  192. return fmt.Sprintf("(%T)%+v", info, info)
  193. }
  194. sb.WriteString(`}`)
  195. return sb.String()
  196. }
  197. func printTargetInfo(name string, info xt.InfoAny) string {
  198. var sb strings.Builder
  199. sb.WriteString(`{`)
  200. unknown, ok := info.(*xt.Unknown)
  201. if !ok {
  202. return fmt.Sprintf("(%T)%+v", info, info)
  203. }
  204. data := []byte(*unknown)
  205. // Things where upstream has no type
  206. switch name {
  207. case "LOG":
  208. if len(data) != 32 {
  209. fmt.Fprintf(&sb, `Error:"bad size; want 32, got %d"`, len(data))
  210. break
  211. }
  212. level := data[0]
  213. logflags := data[1]
  214. prefix := unix.ByteSliceToString(data[2:])
  215. fmt.Fprintf(&sb, "Level:%d LogFlags:%d Prefix:%q", level, logflags, prefix)
  216. default:
  217. return fmt.Sprintf("(%T)%+v", info, info)
  218. }
  219. sb.WriteString(`}`)
  220. return sb.String()
  221. }