nftables_runner.go 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build linux
  4. package linuxfw
  5. import (
  6. "encoding/binary"
  7. "encoding/hex"
  8. "errors"
  9. "fmt"
  10. "net"
  11. "net/netip"
  12. "reflect"
  13. "strings"
  14. "github.com/google/nftables"
  15. "github.com/google/nftables/expr"
  16. "golang.org/x/sys/unix"
  17. "tailscale.com/net/tsaddr"
  18. "tailscale.com/types/logger"
  19. "tailscale.com/types/ptr"
  20. )
  // Names of the chains Tailscale creates and owns. All of them carry the
  // "ts-" prefix that isTSChain uses to recognize Tailscale-managed chains.
  const (
  	chainNameForward     = "ts-forward"
  	chainNameInput       = "ts-input"
  	chainNamePostrouting = "ts-postrouting"

  	// chainTypeRegular is an nftables chain that does not apply to a hook.
  	chainTypeRegular = ""
  )
  // chainInfo describes an nftables chain to look up or create (see
  // getOrCreateChain). NOTE(review): callers elsewhere in the package may build
  // this with unkeyed literals, so keep the field order stable — confirm.
  type chainInfo struct {
  	table         *nftables.Table         // table the chain belongs to
  	name          string                  // chain name; unique within table
  	chainType     nftables.ChainType      // e.g. nftables.ChainTypeNAT, or chainTypeRegular for a hook-less chain
  	chainHook     *nftables.ChainHook     // hook the chain is attached to; presumably nil for regular chains
  	chainPriority *nftables.ChainPriority // priority of the chain at its hook
  	chainPolicy   *nftables.ChainPolicy   // default policy applied when no rule issues a verdict
  }

  // nftable contains nat and filter tables for the given IP family (Proto).
  // NOTE(review): in the code visible alongside this type only Proto is
  // populated (by newNfTablesRunnerWithConn); Filter/Nat may be set elsewhere.
  type nftable struct {
  	Proto  nftables.TableFamily // IPv4 or IPv6
  	Filter *nftables.Table
  	Nat    *nftables.Table
  }
  // nftablesRunner implements a NetfilterRunner using the netlink based nftables
  // library. As nftables allows for arbitrary tables and chains, there is a need
  // to follow conventions in order to integrate well with a surrounding
  // ecosystem. The rules installed by nftablesRunner have the following
  // properties:
  //
  //   - Install rules that intend to take precedence over rules installed by
  //     other software. Tailscale provides packet filtering for tailnet traffic
  //     inside the daemon based on the tailnet ACL rules.
  //   - As nftables "accept" is not final, rules from high priority tables (low
  //     numbers) will fall through to lower priority tables (high numbers). In
  //     order to effectively be 'final', we install "jump" rules into
  //     conventional tables and chains that will reach an accept verdict inside
  //     those tables.
  //   - The table and chain conventions followed here are those used by
  //     `iptables-nft` and `ufw`, so that those tools co-exist and do not
  //     negatively affect Tailscale function.
  //   - Be mindful that:
  //     1. all chains attached to a given hook (i.e. the forward hook) will be
  //     processed in priority order until either a rule in one of the chains
  //     issues a drop verdict or there are no more chains for that hook;
  //     2. processing of individual rules within a chain will stop once one of
  //     them issues a final verdict (accept, drop).
  //
  // See https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains
  type nftablesRunner struct {
  	conn        *nftables.Conn // netlink connection used for every table/chain/rule operation
  	nft4        *nftable       // IPv4 tables, never nil
  	nft6        *nftable       // IPv6 tables or nil if the system does not support IPv6
  	v6Available bool           // whether the host supports IPv6
  }
  68. func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) {
  69. polAccept := nftables.ChainPolicyAccept
  70. table, err := n.getNFTByAddr(dst)
  71. if err != nil {
  72. return nil, nil, fmt.Errorf("error setting up nftables for IP family of %v: %w", dst, err)
  73. }
  74. nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
  75. if err != nil {
  76. return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
  77. }
  78. // ensure prerouting chain exists
  79. preroutingCh, err := getOrCreateChain(n.conn, chainInfo{
  80. table: nat,
  81. name: "PREROUTING",
  82. chainType: nftables.ChainTypeNAT,
  83. chainHook: nftables.ChainHookPrerouting,
  84. chainPriority: nftables.ChainPriorityNATDest,
  85. chainPolicy: &polAccept,
  86. })
  87. if err != nil {
  88. return nil, nil, fmt.Errorf("error ensuring prerouting chain: %w", err)
  89. }
  90. return nat, preroutingCh, nil
  91. }
  92. func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
  93. nat, preroutingCh, err := n.ensurePreroutingChain(dst)
  94. if err != nil {
  95. return err
  96. }
  97. rule := dnatRuleForChain(nat, preroutingCh, origDst, dst, nil)
  98. n.conn.InsertRule(rule)
  99. return n.conn.Flush()
  100. }
  101. func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.Addr, meta []byte) *nftables.Rule {
  102. var daddrOffset, fam, dadderLen uint32
  103. if origDst.Is4() {
  104. daddrOffset = 16
  105. dadderLen = 4
  106. fam = unix.NFPROTO_IPV4
  107. } else {
  108. daddrOffset = 24
  109. dadderLen = 16
  110. fam = unix.NFPROTO_IPV6
  111. }
  112. rule := &nftables.Rule{
  113. Table: t,
  114. Chain: ch,
  115. Exprs: []expr.Any{
  116. &expr.Payload{
  117. DestRegister: 1,
  118. Base: expr.PayloadBaseNetworkHeader,
  119. Offset: daddrOffset,
  120. Len: dadderLen,
  121. },
  122. &expr.Cmp{
  123. Op: expr.CmpOpEq,
  124. Register: 1,
  125. Data: origDst.AsSlice(),
  126. },
  127. &expr.Immediate{
  128. Register: 1,
  129. Data: dst.AsSlice(),
  130. },
  131. &expr.NAT{
  132. Type: expr.NATTypeDestNAT,
  133. Family: fam,
  134. RegAddrMin: 1,
  135. },
  136. },
  137. }
  138. if len(meta) > 0 {
  139. rule.UserData = meta
  140. }
  141. return rule
  142. }
  143. // DNATWithLoadBalancer currently just forwards all traffic destined for origDst
  144. // to the first IP address from the backend targets.
  145. // TODO (irbekrm): instead of doing this load balance traffic evenly to all
  146. // backend destinations.
  147. // https://github.com/tailscale/tailscale/commit/d37f2f508509c6c35ad724fd75a27685b90b575b#diff-a3bcbcd1ca198799f4f768dc56fea913e1945a6b3ec9dbec89325a84a19a85e7R148-R232
  148. func (n *nftablesRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error {
  149. return n.AddDNATRule(origDst, dsts[0])
  150. }
  151. func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr) error {
  152. nat, preroutingCh, err := n.ensurePreroutingChain(dst)
  153. if err != nil {
  154. return err
  155. }
  156. var famConst uint32
  157. if dst.Is4() {
  158. famConst = unix.NFPROTO_IPV4
  159. } else {
  160. famConst = unix.NFPROTO_IPV6
  161. }
  162. dnatRule := &nftables.Rule{
  163. Table: nat,
  164. Chain: preroutingCh,
  165. Exprs: []expr.Any{
  166. &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
  167. &expr.Cmp{
  168. Op: expr.CmpOpNeq,
  169. Register: 1,
  170. Data: []byte(tunname),
  171. },
  172. &expr.Immediate{
  173. Register: 1,
  174. Data: dst.AsSlice(),
  175. },
  176. &expr.NAT{
  177. Type: expr.NATTypeDestNAT,
  178. Family: famConst,
  179. RegAddrMin: 1,
  180. },
  181. },
  182. }
  183. n.conn.InsertRule(dnatRule)
  184. return n.conn.Flush()
  185. }
  186. func (n *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
  187. polAccept := nftables.ChainPolicyAccept
  188. table, err := n.getNFTByAddr(dst)
  189. if err != nil {
  190. return fmt.Errorf("error setting up nftables for IP family of %v: %w", dst, err)
  191. }
  192. nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
  193. if err != nil {
  194. return fmt.Errorf("error ensuring nat table exists: %w", err)
  195. }
  196. // ensure postrouting chain exists
  197. postRoutingCh, err := getOrCreateChain(n.conn, chainInfo{
  198. table: nat,
  199. name: "POSTROUTING",
  200. chainType: nftables.ChainTypeNAT,
  201. chainHook: nftables.ChainHookPostrouting,
  202. chainPriority: nftables.ChainPriorityNATSource,
  203. chainPolicy: &polAccept,
  204. })
  205. if err != nil {
  206. return fmt.Errorf("error ensuring postrouting chain: %w", err)
  207. }
  208. rules, err := n.conn.GetRules(nat, postRoutingCh)
  209. if err != nil {
  210. return fmt.Errorf("error listing rules: %w", err)
  211. }
  212. snatRulePrefixMatch := fmt.Sprintf("dst:%s,src:", dst.String())
  213. snatRuleFullMatch := fmt.Sprintf("%s%s", snatRulePrefixMatch, src.String())
  214. for _, rule := range rules {
  215. current := string(rule.UserData)
  216. if strings.HasPrefix(string(rule.UserData), snatRulePrefixMatch) {
  217. if strings.EqualFold(current, snatRuleFullMatch) {
  218. return nil // already exists, do nothing
  219. }
  220. if err := n.conn.DelRule(rule); err != nil {
  221. return fmt.Errorf("error deleting SNAT rule: %w", err)
  222. }
  223. }
  224. }
  225. rule := snatRule(nat, postRoutingCh, src, dst, []byte(snatRuleFullMatch))
  226. n.conn.AddRule(rule)
  227. return n.conn.Flush()
  228. }
  229. // ClampMSSToPMTU ensures that all packets with TCP flags (SYN, ACK, RST) set
  230. // being forwarded via the given interface (tun) have MSS set to <MTU of the
  231. // interface> - 40 (IP and TCP headers). This can be useful if this tailscale
  232. // instance is expected to run as a forwarding proxy, forwarding packets from an
  233. // endpoint with higher MTU in an environment where path MTU discovery is
  234. // expected to not work (such as the proxies created by the Tailscale Kubernetes
  235. // operator). ClamMSSToPMTU creates a new base-chain ts-clamp in the filter
  236. // table with accept policy and priority -150. In practice, this means that for
  237. // SYN packets the clamp rule in this chain will likely run first and accept the
  238. // packet. This is fine because 1) nftables run ALL chains with the same hook
  239. // type unless a rule in one of them drops the packet and 2) this chain does not
  240. // have functionality to drop the packet- so in practice a matching clamp rule
  241. // will always be followed by the custom tailscale filtering rules in the other
  242. // chains attached to the filter hook (FORWARD, ts-forward).
  243. // We do not want to place the clamping rule into FORWARD/ts-forward chains
  244. // because wgengine populates those chains with rules that contain accept
  245. // verdicts that would cause no further procesing within that chain. This
  246. // functionality is currently invoked from outside wgengine (containerboot), so
  247. // we don't want to race with wgengine for rule ordering within chains.
  248. func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
  249. polAccept := nftables.ChainPolicyAccept
  250. table, err := n.getNFTByAddr(addr)
  251. if err != nil {
  252. return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err)
  253. }
  254. filterTable, err := createTableIfNotExist(n.conn, table.Proto, "filter")
  255. if err != nil {
  256. return fmt.Errorf("error ensuring filter table: %w", err)
  257. }
  258. // ensure ts-clamp chain exists
  259. fwChain, err := getOrCreateChain(n.conn, chainInfo{
  260. table: filterTable,
  261. name: "ts-clamp",
  262. chainType: nftables.ChainTypeFilter,
  263. chainHook: nftables.ChainHookForward,
  264. chainPriority: nftables.ChainPriorityMangle,
  265. chainPolicy: &polAccept,
  266. })
  267. if err != nil {
  268. return fmt.Errorf("error ensuring forward chain: %w", err)
  269. }
  270. clampRule := &nftables.Rule{
  271. Table: filterTable,
  272. Chain: fwChain,
  273. Exprs: []expr.Any{
  274. &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
  275. &expr.Cmp{
  276. Op: expr.CmpOpEq,
  277. Register: 1,
  278. Data: []byte(tun),
  279. },
  280. &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
  281. &expr.Cmp{
  282. Op: expr.CmpOpEq,
  283. Register: 1,
  284. Data: []byte{unix.IPPROTO_TCP},
  285. },
  286. &expr.Payload{
  287. DestRegister: 1,
  288. Base: expr.PayloadBaseTransportHeader,
  289. Offset: 13,
  290. Len: 1,
  291. },
  292. &expr.Bitwise{
  293. DestRegister: 1,
  294. SourceRegister: 1,
  295. Len: 1,
  296. Mask: []byte{0x02},
  297. Xor: []byte{0x00},
  298. },
  299. &expr.Cmp{
  300. Op: expr.CmpOpNeq, // match any packet with a TCP flag set (SYN, ACK, RST)
  301. Register: 1,
  302. Data: []byte{0x00},
  303. },
  304. &expr.Rt{
  305. Register: 1,
  306. Key: expr.RtTCPMSS,
  307. },
  308. &expr.Byteorder{
  309. DestRegister: 1,
  310. SourceRegister: 1,
  311. Op: expr.ByteorderHton,
  312. Len: 2,
  313. Size: 2,
  314. },
  315. &expr.Exthdr{
  316. SourceRegister: 1,
  317. Type: 2,
  318. Offset: 2,
  319. Len: 2,
  320. Op: expr.ExthdrOpTcpopt,
  321. },
  322. },
  323. }
  324. n.conn.AddRule(clampRule)
  325. return n.conn.Flush()
  326. }
  327. // deleteTableIfExists deletes a nftables table via connection c if it exists
  328. // within the given family.
  329. func deleteTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) error {
  330. t, err := getTableIfExists(c, family, name)
  331. if err != nil {
  332. return fmt.Errorf("get table: %w", err)
  333. }
  334. if t == nil {
  335. // Table does not exist, so nothing to delete.
  336. return nil
  337. }
  338. c.DelTable(t)
  339. if err := c.Flush(); err != nil {
  340. if t, err = getTableIfExists(c, family, name); t == nil && err == nil {
  341. // Check if the table still exists. If it does not, then the error
  342. // is due to the table not existing, so we can ignore it. Maybe a
  343. // concurrent process deleted the table.
  344. return nil
  345. }
  346. return fmt.Errorf("del table: %w", err)
  347. }
  348. return nil
  349. }
  350. // getTableIfExists returns the table with the given name from the given family
  351. // if it exists. If none match, it returns (nil, nil).
  352. func getTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
  353. tables, err := c.ListTables()
  354. if err != nil {
  355. return nil, fmt.Errorf("get tables: %w", err)
  356. }
  357. for _, table := range tables {
  358. if table.Name == name && table.Family == family {
  359. return table, nil
  360. }
  361. }
  362. return nil, nil
  363. }
  364. // createTableIfNotExist creates a nftables table via connection c if it does
  365. // not exist within the given family.
  366. func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
  367. if t, err := getTableIfExists(c, family, name); err != nil {
  368. return nil, fmt.Errorf("get table: %w", err)
  369. } else if t != nil {
  370. return t, nil
  371. }
  372. t := c.AddTable(&nftables.Table{
  373. Family: family,
  374. Name: name,
  375. })
  376. if err := c.Flush(); err != nil {
  377. return nil, fmt.Errorf("add table: %w", err)
  378. }
  379. return t, nil
  380. }
  // errorChainNotFound is the error getChainFromTable returns when no matching
  // chain exists. It is a comparable value type, so callers detect it with
  // errors.Is against an equal literal.
  type errorChainNotFound struct {
  	chainName string
  	tableName string
  }

  // Error implements the error interface.
  func (e errorChainNotFound) Error() string {
  	const format = "chain %s not found in table %s"
  	return fmt.Sprintf(format, e.chainName, e.tableName)
  }
  388. // getChainFromTable returns the chain with the given name from the given table.
  389. // Note that a chain name is unique within a table.
  390. func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) {
  391. chains, err := c.ListChainsOfTableFamily(table.Family)
  392. if err != nil {
  393. return nil, fmt.Errorf("list chains: %w", err)
  394. }
  395. for _, chain := range chains {
  396. // Table family is already checked so table name is unique
  397. if chain.Table.Name == table.Name && chain.Name == name {
  398. return chain, nil
  399. }
  400. }
  401. return nil, errorChainNotFound{table.Name, name}
  402. }
  // isTSChain reports whether name carries the "ts-" prefix that marks a
  // Tailscale-managed chain.
  func isTSChain(name string) bool {
  	_, hasTSPrefix := strings.CutPrefix(name, "ts-")
  	return hasTSPrefix
  }
  408. // createChainIfNotExist creates a chain with the given name in the given table
  409. // if it does not exist.
  410. func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
  411. _, err := getOrCreateChain(c, cinfo)
  412. return err
  413. }
  414. func getOrCreateChain(c *nftables.Conn, cinfo chainInfo) (*nftables.Chain, error) {
  415. chain, err := getChainFromTable(c, cinfo.table, cinfo.name)
  416. if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) {
  417. return nil, fmt.Errorf("get chain: %w", err)
  418. } else if err == nil {
  419. // The chain already exists. If it is a TS chain, check the
  420. // type/hook/priority, but for "conventional chains" assume they're what
  421. // we expect (in case iptables-nft/ufw make minor behavior changes in
  422. // the future).
  423. if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || *chain.Hooknum != *cinfo.chainHook || *chain.Priority != *cinfo.chainPriority) {
  424. return nil, fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
  425. }
  426. return chain, nil
  427. }
  428. chain = c.AddChain(&nftables.Chain{
  429. Name: cinfo.name,
  430. Table: cinfo.table,
  431. Type: cinfo.chainType,
  432. Hooknum: cinfo.chainHook,
  433. Priority: cinfo.chainPriority,
  434. Policy: cinfo.chainPolicy,
  435. })
  436. if err := c.Flush(); err != nil {
  437. return nil, fmt.Errorf("add chain: %w", err)
  438. }
  439. return chain, nil
  440. }
  441. // NetfilterRunner abstracts helpers to run netfilter commands. It is
  442. // implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner.
  443. type NetfilterRunner interface {
  444. // AddLoopbackRule adds a rule to permit loopback traffic to addr. This rule
  445. // is added only if it does not already exist.
  446. AddLoopbackRule(addr netip.Addr) error
  447. // DelLoopbackRule removes the rule added by AddLoopbackRule.
  448. DelLoopbackRule(addr netip.Addr) error
  449. // AddHooks adds rules to conventional chains like "FORWARD", "INPUT" and
  450. // "POSTROUTING" to jump from those chains to tailscale chains.
  451. AddHooks() error
  452. // DelHooks deletes rules added by AddHooks.
  453. DelHooks(logf logger.Logf) error
  454. // AddChains creates custom Tailscale chains.
  455. AddChains() error
  456. // DelChains removes chains added by AddChains.
  457. DelChains() error
  458. // AddBase adds rules reused by different other rules.
  459. AddBase(tunname string) error
  460. // DelBase removes rules added by AddBase.
  461. DelBase() error
  462. // AddSNATRule adds the netfilter rule to SNAT incoming traffic over
  463. // the Tailscale interface destined for local subnets. An error is
  464. // returned if the rule already exists.
  465. AddSNATRule() error
  466. // DelSNATRule removes the rule added by AddSNATRule.
  467. DelSNATRule() error
  468. // AddStatefulRule adds a netfilter rule for stateful packet filtering
  469. // using conntrack.
  470. AddStatefulRule(tunname string) error
  471. // DelStatefulRule removes a netfilter rule for stateful packet filtering
  472. // using conntrack.
  473. DelStatefulRule(tunname string) error
  474. // HasIPV6 reports true if the system supports IPv6.
  475. HasIPV6() bool
  476. // HasIPV6NAT reports true if the system supports IPv6 NAT.
  477. HasIPV6NAT() bool
  478. // HasIPV6Filter reports true if the system supports IPv6 filter tables
  479. // This is only meaningful for iptables implementation, where hosts have
  480. // partial ipables support (i.e missing filter table). For nftables
  481. // implementation, this will default to the value of HasIPv6().
  482. HasIPV6Filter() bool
  483. // AddDNATRule adds a rule to the nat/PREROUTING chain to DNAT traffic
  484. // destined for the given original destination to the given new destination.
  485. // This is used to forward all traffic destined for the Tailscale interface
  486. // to the provided destination, as used in the Kubernetes ingress proxies.
  487. AddDNATRule(origDst, dst netip.Addr) error
  488. // DNATWithLoadBalancer adds a rule to the nat/PREROUTING chain to DNAT
  489. // traffic destined for the given original destination to the given new
  490. // destination(s) using round robin to load balance if more than one
  491. // destination is provided. This is used to forward all traffic destined
  492. // for the Tailscale interface to the provided destination(s), as used
  493. // in the Kubernetes ingress proxies.
  494. DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error
  495. // EnsureSNATForDst sets up firewall to mask the source for traffic destined for dst to src:
  496. // - creates a SNAT rule if it doesn't already exist
  497. // - deletes any pre-existing rules matching the destination
  498. // This is used to forward traffic destined for the local machine over
  499. // the Tailscale interface, as used in the Kubernetes egress proxies.
  500. EnsureSNATForDst(src, dst netip.Addr) error
  501. // DNATNonTailscaleTraffic adds a rule to the nat/PREROUTING chain to DNAT
  502. // all traffic inbound from any interface except exemptInterface to dst.
  503. // This is used to forward traffic destined for the local machine over
  504. // the Tailscale interface, as used in the Kubernetes egress proxies.
  505. DNATNonTailscaleTraffic(exemptInterface string, dst netip.Addr) error
  506. EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error
  507. DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error
  508. EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error
  509. DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error
  510. DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error
  511. // ClampMSSToPMTU adds a rule to the mangle/FORWARD chain to clamp MSS for
  512. // traffic destined for the provided tun interface.
  513. ClampMSSToPMTU(tun string, addr netip.Addr) error
  514. // AddMagicsockPortRule adds a rule to the ts-input chain to accept
  515. // incoming traffic on the specified port, to allow magicsock to
  516. // communicate.
  517. AddMagicsockPortRule(port uint16, network string) error
  518. // DelMagicsockPortRule removes the rule created by AddMagicsockPortRule,
  519. // if it exists.
  520. DelMagicsockPortRule(port uint16, network string) error
  521. }
  522. // New creates a NetfilterRunner, auto-detecting whether to use
  523. // nftables or iptables.
  524. // As nftables is still experimental, iptables will be used unless
  525. // either the TS_DEBUG_FIREWALL_MODE environment variable, or the prefHint
  526. // parameter, is set to one of "nftables" or "auto".
  527. func New(logf logger.Logf, prefHint string) (NetfilterRunner, error) {
  528. mode := detectFirewallMode(logf, prefHint)
  529. switch mode {
  530. case FirewallModeIPTables:
  531. // Note that we don't simply return an newIPTablesRunner here because it
  532. // would return a `nil` iptablesRunner which is different from returning
  533. // a nil NetfilterRunner.
  534. ipr, err := newIPTablesRunner(logf)
  535. if err != nil {
  536. return nil, err
  537. }
  538. return ipr, nil
  539. case FirewallModeNfTables:
  540. // Note that we don't simply return an newNfTablesRunner here because it
  541. // would return a `nil` nftablesRunner which is different from returning
  542. // a nil NetfilterRunner.
  543. nfr, err := newNfTablesRunner(logf)
  544. if err != nil {
  545. return nil, err
  546. }
  547. return nfr, nil
  548. default:
  549. return nil, fmt.Errorf("unknown firewall mode %v", mode)
  550. }
  551. }
  552. // newNfTablesRunner creates a new nftablesRunner without guaranteeing
  553. // the existence of the tables and chains.
  554. func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
  555. conn, err := nftables.New()
  556. if err != nil {
  557. return nil, fmt.Errorf("nftables connection: %w", err)
  558. }
  559. return newNfTablesRunnerWithConn(logf, conn), nil
  560. }
  561. func newNfTablesRunnerWithConn(logf logger.Logf, conn *nftables.Conn) *nftablesRunner {
  562. nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
  563. v6err := CheckIPv6(logf)
  564. if v6err != nil {
  565. logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
  566. }
  567. supportsV6 := v6err == nil
  568. var nft6 *nftable
  569. if supportsV6 {
  570. nft6 = &nftable{Proto: nftables.TableFamilyIPv6}
  571. }
  572. logf("netfilter running in nftables mode, v6 = %v", supportsV6)
  573. // TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables
  574. return &nftablesRunner{
  575. conn: conn,
  576. nft4: nft4,
  577. nft6: nft6,
  578. v6Available: supportsV6,
  579. }
  580. }
  581. // newLoadSaddrExpr creates a new nftables expression that loads the source
  582. // address of the packet into the given register.
  583. func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) {
  584. switch proto {
  585. case nftables.TableFamilyIPv4:
  586. return &expr.Payload{
  587. DestRegister: destReg,
  588. Base: expr.PayloadBaseNetworkHeader,
  589. Offset: 12,
  590. Len: 4,
  591. }, nil
  592. case nftables.TableFamilyIPv6:
  593. return &expr.Payload{
  594. DestRegister: destReg,
  595. Base: expr.PayloadBaseNetworkHeader,
  596. Offset: 8,
  597. Len: 16,
  598. }, nil
  599. default:
  600. return nil, fmt.Errorf("table family %v is neither IPv4 nor IPv6", proto)
  601. }
  602. }
  603. // newLoadDportExpr creates a new nftables express that loads the desination port
  604. // of a TCP/UDP packet into the given register.
  605. func newLoadDportExpr(destReg uint32) expr.Any {
  606. return &expr.Payload{
  607. DestRegister: destReg,
  608. Base: expr.PayloadBaseTransportHeader,
  609. Offset: 2,
  610. Len: 2,
  611. }
  612. }
  613. // HasIPV6 reports true if the system supports IPv6.
  614. func (n *nftablesRunner) HasIPV6() bool {
  615. return n.v6Available
  616. }
  617. // HasIPV6NAT returns true if the system supports IPv6.
  618. // Kernel support for nftables was added after support for IPv6
  619. // NAT, so no need for a separate IPv6 NAT support check like we do for iptables.
  620. // https://tldp.org/HOWTO/Linux+IPv6-HOWTO/ch18s04.html
  621. // https://wiki.nftables.org/wiki-nftables/index.php/Building_and_installing_nftables_from_sources
  622. func (n *nftablesRunner) HasIPV6NAT() bool {
  623. return n.v6Available
  624. }
  625. // HasIPV6Filter returns true if system supports IPv6. There are no known edge
  626. // cases where nftables running on a host that supports IPv6 would not support
  627. // filter table.
  628. func (n *nftablesRunner) HasIPV6Filter() bool {
  629. return n.v6Available
  630. }
  631. // findRule iterates through the rules to find the rule with matching expressions.
  632. func findRule(conn *nftables.Conn, rule *nftables.Rule) (*nftables.Rule, error) {
  633. rules, err := conn.GetRules(rule.Table, rule.Chain)
  634. if err != nil {
  635. return nil, fmt.Errorf("get nftables rules: %w", err)
  636. }
  637. if len(rules) == 0 {
  638. return nil, nil
  639. }
  640. ruleLoop:
  641. for _, r := range rules {
  642. if len(r.Exprs) != len(rule.Exprs) {
  643. continue
  644. }
  645. for i, e := range r.Exprs {
  646. // Skip counter expressions, as they will not match.
  647. if _, ok := e.(*expr.Counter); ok {
  648. continue
  649. }
  650. if !reflect.DeepEqual(e, rule.Exprs[i]) {
  651. continue ruleLoop
  652. }
  653. }
  654. return r, nil
  655. }
  656. return nil, nil
  657. }
  658. func createLoopbackRule(
  659. proto nftables.TableFamily,
  660. table *nftables.Table,
  661. chain *nftables.Chain,
  662. addr netip.Addr,
  663. ) (*nftables.Rule, error) {
  664. saddrExpr, err := newLoadSaddrExpr(proto, 1)
  665. if err != nil {
  666. return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
  667. }
  668. loopBackRule := &nftables.Rule{
  669. Table: table,
  670. Chain: chain,
  671. Exprs: []expr.Any{
  672. &expr.Meta{
  673. Key: expr.MetaKeyIIFNAME,
  674. Register: 1,
  675. },
  676. &expr.Cmp{
  677. Op: expr.CmpOpEq,
  678. Register: 1,
  679. Data: []byte("lo"),
  680. },
  681. saddrExpr,
  682. &expr.Cmp{
  683. Op: expr.CmpOpEq,
  684. Register: 1,
  685. Data: addr.AsSlice(),
  686. },
  687. &expr.Counter{},
  688. &expr.Verdict{
  689. Kind: expr.VerdictAccept,
  690. },
  691. },
  692. }
  693. return loopBackRule, nil
  694. }
  695. // insertLoopbackRule inserts the TS loop back rule into
  696. // the given chain as the first rule if it does not exist.
  697. func insertLoopbackRule(
  698. conn *nftables.Conn, proto nftables.TableFamily,
  699. table *nftables.Table, chain *nftables.Chain, addr netip.Addr) error {
  700. loopBackRule, err := createLoopbackRule(proto, table, chain, addr)
  701. if err != nil {
  702. return fmt.Errorf("create loopback rule: %w", err)
  703. }
  704. // If TestDial is set, we are running in test mode and we should not
  705. // find rule because header will mismatch.
  706. if conn.TestDial == nil {
  707. // Check if the rule already exists.
  708. rule, err := findRule(conn, loopBackRule)
  709. if err != nil {
  710. return fmt.Errorf("find rule: %w", err)
  711. }
  712. if rule != nil {
  713. // Rule already exists, no need to insert.
  714. return nil
  715. }
  716. }
  717. // This inserts the rule to the top of the chain
  718. _ = conn.InsertRule(loopBackRule)
  719. if err = conn.Flush(); err != nil {
  720. return fmt.Errorf("insert rule: %w", err)
  721. }
  722. return nil
  723. }
  724. // getNFTByAddr returns the nftables with correct IP family
  725. // that we will be using for the given address.
  726. func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) (*nftable, error) {
  727. if addr.Is6() && !n.v6Available {
  728. return nil, fmt.Errorf("nftables for IPv6 are not available on this host")
  729. }
  730. if addr.Is6() {
  731. return n.nft6, nil
  732. }
  733. return n.nft4, nil
  734. }
  735. // AddLoopbackRule adds an nftables rule to permit loopback traffic to
  736. // a local Tailscale IP. This rule is added only if it does not already exist.
  737. func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
  738. nf, err := n.getNFTByAddr(addr)
  739. if err != nil {
  740. return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err)
  741. }
  742. inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
  743. if err != nil {
  744. return fmt.Errorf("get input chain: %w", err)
  745. }
  746. if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil {
  747. return fmt.Errorf("add loopback rule: %w", err)
  748. }
  749. return nil
  750. }
  751. // DelLoopbackRule removes the nftables rule permitting loopback
  752. // traffic to a Tailscale IP.
  753. func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
  754. nf, err := n.getNFTByAddr(addr)
  755. if err != nil {
  756. return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err)
  757. }
  758. inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
  759. if err != nil {
  760. return fmt.Errorf("get input chain: %w", err)
  761. }
  762. loopBackRule, err := createLoopbackRule(nf.Proto, nf.Filter, inputChain, addr)
  763. if err != nil {
  764. return fmt.Errorf("create loopback rule: %w", err)
  765. }
  766. existingLoopBackRule, err := findRule(n.conn, loopBackRule)
  767. if err != nil {
  768. return fmt.Errorf("find loop back rule: %w", err)
  769. }
  770. if existingLoopBackRule == nil {
  771. // Rule does not exist, no need to delete.
  772. return nil
  773. }
  774. if err := n.conn.DelRule(existingLoopBackRule); err != nil {
  775. return fmt.Errorf("delete rule: %w", err)
  776. }
  777. return n.conn.Flush()
  778. }
  779. // getTables returns tables for IP families that this host was determined to
  780. // support (either IPv4 and IPv6 or just IPv4).
  781. func (n *nftablesRunner) getTables() []*nftable {
  782. if n.HasIPV6() {
  783. return []*nftable{n.nft4, n.nft6}
  784. }
  785. return []*nftable{n.nft4}
  786. }
  787. // AddChains creates custom Tailscale chains in netfilter via nftables
  788. // if the ts-chain doesn't already exist.
  789. func (n *nftablesRunner) AddChains() error {
  790. polAccept := nftables.ChainPolicyAccept
  791. for _, table := range n.getTables() {
  792. // Create the filter table if it doesn't exist, this table name is the same
  793. // as the name used by iptables-nft and ufw. We install rules into the
  794. // same conventional table so that `accept` verdicts from our jump
  795. // chains are conclusive.
  796. filter, err := createTableIfNotExist(n.conn, table.Proto, "filter")
  797. if err != nil {
  798. return fmt.Errorf("create table: %w", err)
  799. }
  800. table.Filter = filter
  801. // Adding the "conventional chains" that are used by iptables-nft and ufw.
  802. if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
  803. return fmt.Errorf("create forward chain: %w", err)
  804. }
  805. if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
  806. return fmt.Errorf("create input chain: %w", err)
  807. }
  808. // Adding the tailscale chains that contain our rules.
  809. if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
  810. return fmt.Errorf("create forward chain: %w", err)
  811. }
  812. if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
  813. return fmt.Errorf("create input chain: %w", err)
  814. }
  815. // Create the nat table if it doesn't exist, this table name is the same
  816. // as the name used by iptables-nft and ufw. We install rules into the
  817. // same conventional table so that `accept` verdicts from our jump
  818. // chains are conclusive.
  819. nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
  820. if err != nil {
  821. return fmt.Errorf("create table: %w", err)
  822. }
  823. table.Nat = nat
  824. // Adding the "conventional chains" that are used by iptables-nft and ufw.
  825. if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
  826. return fmt.Errorf("create postrouting chain: %w", err)
  827. }
  828. // Adding the tailscale chain that contains our rules.
  829. if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
  830. return fmt.Errorf("create postrouting chain: %w", err)
  831. }
  832. }
  833. return n.conn.Flush()
  834. }
  835. // These are dummy chains and tables we create to detect if nftables is
  836. // available. We create them, then delete them. If we can create and delete
  837. // them, then we can use nftables. If we can't, then we assume that we're
  838. // running on a system that doesn't support nftables. See
  839. // createDummyPostroutingChains.
  840. const (
  841. tsDummyChainName = "ts-test-postrouting"
  842. tsDummyTableName = "ts-test-nat"
  843. )
  844. // createDummyPostroutingChains creates dummy postrouting chains in netfilter
  845. // via netfilter via nftables, as a last resort measure to detect that nftables
  846. // can be used. It cleans up the dummy chains after creation.
  847. func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
  848. polAccept := ptr.To(nftables.ChainPolicyAccept)
  849. for _, table := range n.getTables() {
  850. nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
  851. if err != nil {
  852. return fmt.Errorf("create nat table: %w", err)
  853. }
  854. defer func(fm nftables.TableFamily) {
  855. if err := deleteTableIfExists(n.conn, fm, tsDummyTableName); err != nil && retErr == nil {
  856. retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err)
  857. }
  858. }(table.Proto)
  859. table.Nat = nat
  860. if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil {
  861. return fmt.Errorf("create %q chain: %w", tsDummyChainName, err)
  862. }
  863. if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil {
  864. return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err)
  865. }
  866. }
  867. return nil
  868. }
  869. // deleteChainIfExists deletes a chain if it exists.
  870. func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
  871. chain, err := getChainFromTable(c, table, name)
  872. if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) {
  873. return fmt.Errorf("get chain: %w", err)
  874. } else if err != nil {
  875. // If the chain doesn't exist, we don't need to delete it.
  876. return nil
  877. }
  878. c.FlushChain(chain)
  879. c.DelChain(chain)
  880. if err := c.Flush(); err != nil {
  881. return fmt.Errorf("flush and delete chain: %w", err)
  882. }
  883. return nil
  884. }
  885. // DelChains removes the custom Tailscale chains from netfilter via nftables.
  886. func (n *nftablesRunner) DelChains() error {
  887. for _, table := range n.getTables() {
  888. if err := deleteChainIfExists(n.conn, table.Filter, chainNameForward); err != nil {
  889. return fmt.Errorf("delete chain: %w", err)
  890. }
  891. if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil {
  892. return fmt.Errorf("delete chain: %w", err)
  893. }
  894. }
  895. if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil {
  896. return fmt.Errorf("delete chain: %w", err)
  897. }
  898. if n.HasIPV6NAT() {
  899. if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
  900. return fmt.Errorf("delete chain: %w", err)
  901. }
  902. }
  903. if err := n.conn.Flush(); err != nil {
  904. return fmt.Errorf("flush: %w", err)
  905. }
  906. return nil
  907. }
  908. // createHookRule creates a rule to jump from a hooked chain to a regular chain.
  909. func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule {
  910. exprs := []expr.Any{
  911. &expr.Counter{},
  912. &expr.Verdict{
  913. Kind: expr.VerdictJump,
  914. Chain: toChainName,
  915. },
  916. }
  917. rule := &nftables.Rule{
  918. Table: table,
  919. Chain: fromChain,
  920. Exprs: exprs,
  921. }
  922. return rule
  923. }
  924. // addHookRule adds a rule to jump from a hooked chain to a regular chain at top of the hooked chain.
  925. func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
  926. rule := createHookRule(table, fromChain, toChainName)
  927. _ = conn.InsertRule(rule)
  928. if err := conn.Flush(); err != nil {
  929. return fmt.Errorf("flush add rule: %w", err)
  930. }
  931. return nil
  932. }
  933. // AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING"
  934. // in tables and jump from those chains to tailscale chains.
  935. func (n *nftablesRunner) AddHooks() error {
  936. conn := n.conn
  937. for _, table := range n.getTables() {
  938. inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
  939. if err != nil {
  940. return fmt.Errorf("get INPUT chain: %w", err)
  941. }
  942. err = addHookRule(conn, table.Filter, inputChain, chainNameInput)
  943. if err != nil {
  944. return fmt.Errorf("Addhook: %w", err)
  945. }
  946. forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
  947. if err != nil {
  948. return fmt.Errorf("get FORWARD chain: %w", err)
  949. }
  950. err = addHookRule(conn, table.Filter, forwardChain, chainNameForward)
  951. if err != nil {
  952. return fmt.Errorf("Addhook: %w", err)
  953. }
  954. postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
  955. if err != nil {
  956. return fmt.Errorf("get INPUT chain: %w", err)
  957. }
  958. err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
  959. if err != nil {
  960. return fmt.Errorf("Addhook: %w", err)
  961. }
  962. }
  963. return nil
  964. }
  965. // delHookRule deletes a rule that jumps from a hooked chain to a regular chain.
  966. func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
  967. rule := createHookRule(table, fromChain, toChainName)
  968. existingRule, err := findRule(conn, rule)
  969. if err != nil {
  970. return fmt.Errorf("Failed to find hook rule: %w", err)
  971. }
  972. if existingRule == nil {
  973. return nil
  974. }
  975. _ = conn.DelRule(existingRule)
  976. if err := conn.Flush(); err != nil {
  977. return fmt.Errorf("flush del hook rule: %w", err)
  978. }
  979. return nil
  980. }
  981. // DelHooks is deleting the rules added to conventional chains to jump to tailscale chains.
  982. func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
  983. conn := n.conn
  984. for _, table := range n.getTables() {
  985. inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
  986. if err != nil {
  987. return fmt.Errorf("get INPUT chain: %w", err)
  988. }
  989. err = delHookRule(conn, table.Filter, inputChain, chainNameInput)
  990. if err != nil {
  991. return fmt.Errorf("delhook: %w", err)
  992. }
  993. forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
  994. if err != nil {
  995. return fmt.Errorf("get FORWARD chain: %w", err)
  996. }
  997. err = delHookRule(conn, table.Filter, forwardChain, chainNameForward)
  998. if err != nil {
  999. return fmt.Errorf("delhook: %w", err)
  1000. }
  1001. postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
  1002. if err != nil {
  1003. return fmt.Errorf("get INPUT chain: %w", err)
  1004. }
  1005. err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
  1006. if err != nil {
  1007. return fmt.Errorf("delhook: %w", err)
  1008. }
  1009. }
  1010. return nil
  1011. }
  1012. // maskof returns the mask of the given prefix in big endian bytes.
  1013. func maskof(pfx netip.Prefix) []byte {
  1014. mask := make([]byte, 4)
  1015. binary.BigEndian.PutUint32(mask, ^(uint32(0xffff_ffff) >> pfx.Bits()))
  1016. return mask
  1017. }
  1018. // createRangeRule creates a rule that matches packets with source IP from the give
  1019. // range (like CGNAT range or ChromeOSVM range) and the interface is not the tunname,
  1020. // and makes the given decision. Only IPv4 is supported.
  1021. func createRangeRule(
  1022. table *nftables.Table, chain *nftables.Chain,
  1023. tunname string, rng netip.Prefix, decision expr.VerdictKind,
  1024. ) (*nftables.Rule, error) {
  1025. if rng.Addr().Is6() {
  1026. return nil, errors.New("IPv6 is not supported")
  1027. }
  1028. saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
  1029. if err != nil {
  1030. return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
  1031. }
  1032. netip := rng.Addr().AsSlice()
  1033. mask := maskof(rng)
  1034. rule := &nftables.Rule{
  1035. Table: table,
  1036. Chain: chain,
  1037. Exprs: []expr.Any{
  1038. &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
  1039. &expr.Cmp{
  1040. Op: expr.CmpOpNeq,
  1041. Register: 1,
  1042. Data: []byte(tunname),
  1043. },
  1044. saddrExpr,
  1045. &expr.Bitwise{
  1046. SourceRegister: 1,
  1047. DestRegister: 1,
  1048. Len: 4,
  1049. Mask: mask,
  1050. Xor: []byte{0x00, 0x00, 0x00, 0x00},
  1051. },
  1052. &expr.Cmp{
  1053. Op: expr.CmpOpEq,
  1054. Register: 1,
  1055. Data: netip,
  1056. },
  1057. &expr.Counter{},
  1058. &expr.Verdict{
  1059. Kind: decision,
  1060. },
  1061. },
  1062. }
  1063. return rule, nil
  1064. }
  1065. // addReturnChromeOSVMRangeRule adds a rule to return if the source IP
  1066. // is in the ChromeOS VM range.
  1067. func addReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
  1068. rule, err := createRangeRule(table, chain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn)
  1069. if err != nil {
  1070. return fmt.Errorf("create rule: %w", err)
  1071. }
  1072. _ = c.AddRule(rule)
  1073. if err = c.Flush(); err != nil {
  1074. return fmt.Errorf("add rule: %w", err)
  1075. }
  1076. return nil
  1077. }
  1078. // addDropCGNATRangeRule adds a rule to drop if the source IP is in the
  1079. // CGNAT range.
  1080. func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
  1081. rule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop)
  1082. if err != nil {
  1083. return fmt.Errorf("create rule: %w", err)
  1084. }
  1085. _ = c.AddRule(rule)
  1086. if err = c.Flush(); err != nil {
  1087. return fmt.Errorf("add rule: %w", err)
  1088. }
  1089. return nil
  1090. }
  1091. // createSetSubnetRouteMarkRule creates a rule to set the subnet route
  1092. // mark if the packet is from the given interface.
  1093. func createSetSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
  1094. hexTsFwmarkMaskNeg := getTailscaleFwmarkMaskNeg()
  1095. hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
  1096. rule := &nftables.Rule{
  1097. Table: table,
  1098. Chain: chain,
  1099. Exprs: []expr.Any{
  1100. &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
  1101. &expr.Cmp{
  1102. Op: expr.CmpOpEq,
  1103. Register: 1,
  1104. Data: []byte(tunname),
  1105. },
  1106. &expr.Counter{},
  1107. &expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
  1108. &expr.Bitwise{
  1109. SourceRegister: 1,
  1110. DestRegister: 1,
  1111. Len: 4,
  1112. Mask: hexTsFwmarkMaskNeg,
  1113. Xor: hexTSSubnetRouteMark,
  1114. },
  1115. &expr.Meta{
  1116. Key: expr.MetaKeyMARK,
  1117. SourceRegister: true,
  1118. Register: 1,
  1119. },
  1120. },
  1121. }
  1122. return rule, nil
  1123. }
  1124. // addSetSubnetRouteMarkRule adds a rule to set the subnet route mark
  1125. // if the packet is from the given interface.
  1126. func addSetSubnetRouteMarkRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
  1127. rule, err := createSetSubnetRouteMarkRule(table, chain, tunname)
  1128. if err != nil {
  1129. return fmt.Errorf("create rule: %w", err)
  1130. }
  1131. _ = c.AddRule(rule)
  1132. if err := c.Flush(); err != nil {
  1133. return fmt.Errorf("add rule: %w", err)
  1134. }
  1135. return nil
  1136. }
  1137. // createDropOutgoingPacketFromCGNATRangeRuleWithTunname creates a rule to drop
  1138. // outgoing packets from the CGNAT range.
  1139. func createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
  1140. _, ipNet, err := net.ParseCIDR(tsaddr.CGNATRange().String())
  1141. if err != nil {
  1142. return nil, fmt.Errorf("parse cidr: %v", err)
  1143. }
  1144. mask, err := hex.DecodeString(ipNet.Mask.String())
  1145. if err != nil {
  1146. return nil, fmt.Errorf("decode mask: %v", err)
  1147. }
  1148. netip := ipNet.IP.Mask(ipNet.Mask).To4()
  1149. saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
  1150. if err != nil {
  1151. return nil, fmt.Errorf("newLoadSaddrExpr: %v", err)
  1152. }
  1153. rule := &nftables.Rule{
  1154. Table: table,
  1155. Chain: chain,
  1156. Exprs: []expr.Any{
  1157. &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
  1158. &expr.Cmp{
  1159. Op: expr.CmpOpEq,
  1160. Register: 1,
  1161. Data: []byte(tunname),
  1162. },
  1163. saddrExpr,
  1164. &expr.Bitwise{
  1165. SourceRegister: 1,
  1166. DestRegister: 1,
  1167. Len: 4,
  1168. Mask: mask,
  1169. Xor: []byte{0x00, 0x00, 0x00, 0x00},
  1170. },
  1171. &expr.Cmp{
  1172. Op: expr.CmpOpEq,
  1173. Register: 1,
  1174. Data: netip,
  1175. },
  1176. &expr.Counter{},
  1177. &expr.Verdict{
  1178. Kind: expr.VerdictDrop,
  1179. },
  1180. },
  1181. }
  1182. return rule, nil
  1183. }
  1184. // addDropOutgoingPacketFromCGNATRangeRuleWithTunname adds a rule to drop
  1185. // outgoing packets from the CGNAT range.
  1186. func addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
  1187. rule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table, chain, tunname)
  1188. if err != nil {
  1189. return fmt.Errorf("create rule: %w", err)
  1190. }
  1191. _ = conn.AddRule(rule)
  1192. if err := conn.Flush(); err != nil {
  1193. return fmt.Errorf("add rule: %w", err)
  1194. }
  1195. return nil
  1196. }
  1197. // createAcceptOutgoingPacketRule creates a rule to accept outgoing packets
  1198. // from the given interface.
  1199. func createAcceptOutgoingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
  1200. return &nftables.Rule{
  1201. Table: table,
  1202. Chain: chain,
  1203. Exprs: []expr.Any{
  1204. &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
  1205. &expr.Cmp{
  1206. Op: expr.CmpOpEq,
  1207. Register: 1,
  1208. Data: []byte(tunname),
  1209. },
  1210. &expr.Counter{},
  1211. &expr.Verdict{
  1212. Kind: expr.VerdictAccept,
  1213. },
  1214. },
  1215. }
  1216. }
  1217. // addAcceptOutgoingPacketRule adds a rule to accept outgoing packets
  1218. // from the given interface.
  1219. func addAcceptOutgoingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
  1220. rule := createAcceptOutgoingPacketRule(table, chain, tunname)
  1221. _ = conn.AddRule(rule)
  1222. if err := conn.Flush(); err != nil {
  1223. return fmt.Errorf("flush add rule: %w", err)
  1224. }
  1225. return nil
  1226. }
  1227. // createAcceptOnPortRule creates a rule to accept incoming packets to
  1228. // a given destination UDP port.
  1229. func createAcceptOnPortRule(table *nftables.Table, chain *nftables.Chain, port uint16) *nftables.Rule {
  1230. portBytes := make([]byte, 2)
  1231. binary.BigEndian.PutUint16(portBytes, port)
  1232. return &nftables.Rule{
  1233. Table: table,
  1234. Chain: chain,
  1235. Exprs: []expr.Any{
  1236. &expr.Meta{
  1237. Key: expr.MetaKeyL4PROTO,
  1238. Register: 1,
  1239. },
  1240. &expr.Cmp{
  1241. Op: expr.CmpOpEq,
  1242. Register: 1,
  1243. Data: []byte{unix.IPPROTO_UDP},
  1244. },
  1245. newLoadDportExpr(1),
  1246. &expr.Cmp{
  1247. Op: expr.CmpOpEq,
  1248. Register: 1,
  1249. Data: portBytes,
  1250. },
  1251. &expr.Counter{},
  1252. &expr.Verdict{
  1253. Kind: expr.VerdictAccept,
  1254. },
  1255. },
  1256. }
  1257. }
  1258. // addAcceptOnPortRule adds a rule to accept incoming packets to
  1259. // a given destination UDP port.
  1260. func addAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error {
  1261. rule := createAcceptOnPortRule(table, chain, port)
  1262. _ = conn.AddRule(rule)
  1263. if err := conn.Flush(); err != nil {
  1264. return fmt.Errorf("flush add rule: %w", err)
  1265. }
  1266. return nil
  1267. }
  1268. // addAcceptOnPortRule removes a rule to accept incoming packets to
  1269. // a given destination UDP port.
  1270. func removeAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error {
  1271. rule := createAcceptOnPortRule(table, chain, port)
  1272. rule, err := findRule(conn, rule)
  1273. if err != nil {
  1274. return fmt.Errorf("find rule: %v", err)
  1275. }
  1276. _ = conn.DelRule(rule)
  1277. if err := conn.Flush(); err != nil {
  1278. return fmt.Errorf("flush del rule: %w", err)
  1279. }
  1280. return nil
  1281. }
  1282. // AddMagicsockPortRule adds a rule to nftables to allow incoming traffic on
  1283. // the specified UDP port, so magicsock can accept incoming connections.
  1284. // network must be either "udp4" or "udp6" - this determines whether the rule
  1285. // is added for IPv4 or IPv6.
  1286. func (n *nftablesRunner) AddMagicsockPortRule(port uint16, network string) error {
  1287. var filterTable *nftables.Table
  1288. switch network {
  1289. case "udp4":
  1290. filterTable = n.nft4.Filter
  1291. case "udp6":
  1292. filterTable = n.nft6.Filter
  1293. default:
  1294. return fmt.Errorf("unsupported network %s", network)
  1295. }
  1296. inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput)
  1297. if err != nil {
  1298. return fmt.Errorf("get input chain: %v", err)
  1299. }
  1300. err = addAcceptOnPortRule(n.conn, filterTable, inputChain, port)
  1301. if err != nil {
  1302. return fmt.Errorf("add accept on port rule: %v", err)
  1303. }
  1304. return nil
  1305. }
  1306. // DelMagicsockPortRule removes a rule added by AddMagicsockPortRule to accept
  1307. // incoming traffic on a particular UDP port.
  1308. // network must be either "udp4" or "udp6" - this determines whether the rule
  1309. // is removed for IPv4 or IPv6.
  1310. func (n *nftablesRunner) DelMagicsockPortRule(port uint16, network string) error {
  1311. var filterTable *nftables.Table
  1312. switch network {
  1313. case "udp4":
  1314. filterTable = n.nft4.Filter
  1315. case "udp6":
  1316. filterTable = n.nft6.Filter
  1317. default:
  1318. return fmt.Errorf("unsupported network %s", network)
  1319. }
  1320. inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput)
  1321. if err != nil {
  1322. return fmt.Errorf("get input chain: %v", err)
  1323. }
  1324. err = removeAcceptOnPortRule(n.conn, filterTable, inputChain, port)
  1325. if err != nil {
  1326. return fmt.Errorf("add accept on port rule: %v", err)
  1327. }
  1328. return nil
  1329. }
  1330. // createAcceptIncomingPacketRule creates a rule to accept incoming packets to
  1331. // the given interface.
  1332. func createAcceptIncomingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
  1333. return &nftables.Rule{
  1334. Table: table,
  1335. Chain: chain,
  1336. Exprs: []expr.Any{
  1337. &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
  1338. &expr.Cmp{
  1339. Op: expr.CmpOpEq,
  1340. Register: 1,
  1341. Data: []byte(tunname),
  1342. },
  1343. &expr.Counter{},
  1344. &expr.Verdict{
  1345. Kind: expr.VerdictAccept,
  1346. },
  1347. },
  1348. }
  1349. }
  1350. func addAcceptIncomingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
  1351. rule := createAcceptIncomingPacketRule(table, chain, tunname)
  1352. _ = conn.AddRule(rule)
  1353. if err := conn.Flush(); err != nil {
  1354. return fmt.Errorf("flush add rule: %w", err)
  1355. }
  1356. return nil
  1357. }
  1358. // AddBase adds some basic processing rules.
  1359. func (n *nftablesRunner) AddBase(tunname string) error {
  1360. if err := n.addBase4(tunname); err != nil {
  1361. return fmt.Errorf("add base v4: %w", err)
  1362. }
  1363. if n.HasIPV6() {
  1364. if err := n.addBase6(tunname); err != nil {
  1365. return fmt.Errorf("add base v6: %w", err)
  1366. }
  1367. }
  1368. return nil
  1369. }
  1370. // addBase4 adds some basic IPv4 processing rules.
  1371. func (n *nftablesRunner) addBase4(tunname string) error {
  1372. conn := n.conn
  1373. inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput)
  1374. if err != nil {
  1375. return fmt.Errorf("get input chain v4: %v", err)
  1376. }
  1377. if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
  1378. return fmt.Errorf("add return chromeos vm range rule v4: %w", err)
  1379. }
  1380. if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
  1381. return fmt.Errorf("add drop cgnat range rule v4: %w", err)
  1382. }
  1383. if err = addAcceptIncomingPacketRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
  1384. return fmt.Errorf("add accept incoming packet rule v4: %w", err)
  1385. }
  1386. forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward)
  1387. if err != nil {
  1388. return fmt.Errorf("get forward chain v4: %v", err)
  1389. }
  1390. if err = addSetSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
  1391. return fmt.Errorf("add set subnet route mark rule v4: %w", err)
  1392. }
  1393. if err = addMatchSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, Accept); err != nil {
  1394. return fmt.Errorf("add match subnet route mark rule v4: %w", err)
  1395. }
  1396. if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
  1397. return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err)
  1398. }
  1399. if err = addAcceptOutgoingPacketRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
  1400. return fmt.Errorf("add accept outgoing packet rule v4: %w", err)
  1401. }
  1402. if err = conn.Flush(); err != nil {
  1403. return fmt.Errorf("flush base v4: %w", err)
  1404. }
  1405. return nil
  1406. }
  1407. // addBase6 adds some basic IPv6 processing rules.
  1408. func (n *nftablesRunner) addBase6(tunname string) error {
  1409. conn := n.conn
  1410. inputChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameInput)
  1411. if err != nil {
  1412. return fmt.Errorf("get input chain v4: %v", err)
  1413. }
  1414. if err = addAcceptIncomingPacketRule(conn, n.nft6.Filter, inputChain, tunname); err != nil {
  1415. return fmt.Errorf("add accept incoming packet rule v6: %w", err)
  1416. }
  1417. forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward)
  1418. if err != nil {
  1419. return fmt.Errorf("get forward chain v6: %w", err)
  1420. }
  1421. if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
  1422. return fmt.Errorf("add set subnet route mark rule v6: %w", err)
  1423. }
  1424. if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil {
  1425. return fmt.Errorf("add match subnet route mark rule v6: %w", err)
  1426. }
  1427. if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
  1428. return fmt.Errorf("add accept outgoing packet rule v6: %w", err)
  1429. }
  1430. if err = conn.Flush(); err != nil {
  1431. return fmt.Errorf("flush base v6: %w", err)
  1432. }
  1433. return nil
  1434. }
  1435. // DelBase empties, but does not remove, custom Tailscale chains from
  1436. // netfilter via iptables.
  1437. func (n *nftablesRunner) DelBase() error {
  1438. conn := n.conn
  1439. for _, table := range n.getTables() {
  1440. inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput)
  1441. if err != nil {
  1442. return fmt.Errorf("get input chain: %v", err)
  1443. }
  1444. conn.FlushChain(inputChain)
  1445. forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward)
  1446. if err != nil {
  1447. return fmt.Errorf("get forward chain: %v", err)
  1448. }
  1449. conn.FlushChain(forwardChain)
  1450. postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
  1451. if err != nil {
  1452. return fmt.Errorf("get postrouting chain v4: %v", err)
  1453. }
  1454. conn.FlushChain(postrouteChain)
  1455. }
  1456. return conn.Flush()
  1457. }
  1458. // createMatchSubnetRouteMarkRule creates a rule that matches packets
  1459. // with the subnet route mark and takes the specified action.
  1460. func createMatchSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, action MatchDecision) (*nftables.Rule, error) {
  1461. hexTSFwmarkMask := getTailscaleFwmarkMask()
  1462. hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
  1463. var endAction expr.Any
  1464. endAction = &expr.Verdict{Kind: expr.VerdictAccept}
  1465. if action == Masq {
  1466. endAction = &expr.Masq{}
  1467. }
  1468. exprs := []expr.Any{
  1469. &expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
  1470. &expr.Bitwise{
  1471. SourceRegister: 1,
  1472. DestRegister: 1,
  1473. Len: 4,
  1474. Mask: hexTSFwmarkMask,
  1475. Xor: []byte{0x00, 0x00, 0x00, 0x00},
  1476. },
  1477. &expr.Cmp{
  1478. Op: expr.CmpOpEq,
  1479. Register: 1,
  1480. Data: hexTSSubnetRouteMark,
  1481. },
  1482. &expr.Counter{},
  1483. endAction,
  1484. }
  1485. rule := &nftables.Rule{
  1486. Table: table,
  1487. Chain: chain,
  1488. Exprs: exprs,
  1489. }
  1490. return rule, nil
  1491. }
  1492. // addMatchSubnetRouteMarkRule adds a rule that matches packets with
  1493. // the subnet route mark and takes the specified action.
  1494. func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, action MatchDecision) error {
  1495. rule, err := createMatchSubnetRouteMarkRule(table, chain, action)
  1496. if err != nil {
  1497. return fmt.Errorf("create match subnet route mark rule: %w", err)
  1498. }
  1499. _ = conn.AddRule(rule)
  1500. if err := conn.Flush(); err != nil {
  1501. return fmt.Errorf("flush add rule: %w", err)
  1502. }
  1503. return nil
  1504. }
  1505. // AddSNATRule adds a netfilter rule to SNAT traffic destined for
  1506. // local subnets.
  1507. func (n *nftablesRunner) AddSNATRule() error {
  1508. conn := n.conn
  1509. for _, table := range n.getTables() {
  1510. chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
  1511. if err != nil {
  1512. return fmt.Errorf("get postrouting chain v4: %w", err)
  1513. }
  1514. if err = addMatchSubnetRouteMarkRule(conn, table.Nat, chain, Masq); err != nil {
  1515. return fmt.Errorf("add match subnet route mark rule v4: %w", err)
  1516. }
  1517. }
  1518. if err := conn.Flush(); err != nil {
  1519. return fmt.Errorf("flush add SNAT rule: %w", err)
  1520. }
  1521. return nil
  1522. }
  1523. func delMatchSubnetRouteMarkMasqRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) error {
  1524. rule, err := createMatchSubnetRouteMarkRule(table, chain, Masq)
  1525. if err != nil {
  1526. return fmt.Errorf("create match subnet route mark rule: %w", err)
  1527. }
  1528. SNATRule, err := findRule(conn, rule)
  1529. if err != nil {
  1530. return fmt.Errorf("find SNAT rule v4: %w", err)
  1531. }
  1532. if SNATRule != nil {
  1533. _ = conn.DelRule(SNATRule)
  1534. }
  1535. if err := conn.Flush(); err != nil {
  1536. return fmt.Errorf("flush del SNAT rule: %w", err)
  1537. }
  1538. return nil
  1539. }
  1540. // DelSNATRule removes the netfilter rule to SNAT traffic destined for
  1541. // local subnets. An error is returned if the rule does not exist.
  1542. func (n *nftablesRunner) DelSNATRule() error {
  1543. conn := n.conn
  1544. for _, table := range n.getTables() {
  1545. chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
  1546. if err != nil {
  1547. return fmt.Errorf("get postrouting chain: %w", err)
  1548. }
  1549. err = delMatchSubnetRouteMarkMasqRule(conn, table.Nat, chain)
  1550. if err != nil {
  1551. return err
  1552. }
  1553. }
  1554. return nil
  1555. }
  1556. func nativeUint32(v uint32) []byte {
  1557. b := make([]byte, 4)
  1558. binary.NativeEndian.PutUint32(b, v)
  1559. return b
  1560. }
// makeStatefulRuleExprs returns the expressions for a rule that counts and
// drops packets whose output interface name matches tunname and whose
// conntrack state has any bit set other than ESTABLISHED, RELATED or
// UNTRACKED (those three bits are masked off before the comparison).
//
// AddStatefulRule installs a rule built from these expressions, and
// DelStatefulRule rebuilds them to locate that rule for deletion, so any
// change here also changes which previously installed rules can be found.
func makeStatefulRuleExprs(tunname string) []expr.Any {
	return []expr.Any{
		// Check if the output interface is the Tailscale interface by
		// first loading the OIFNAME into register 1 and comparing it
		// against our tunname.
		//
		// 'cmp' implicitly breaks from a rule if a comparison fails,
		// so if we continue past this comparison we know that the packet
		// is going to our TUN.
		&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
		&expr.Cmp{
			Op:       expr.CmpOpEq,
			Register: 1,
			Data:     []byte(tunname),
		},
		// Store the conntrack state in register 1
		&expr.Ct{
			Register: 1,
			Key:      expr.CtKeySTATE,
		},
		// Mask the state in register 1 to "hide" the ESTABLISHED and
		// RELATED bits (which are expected and fine); if there are any
		// other bits, we want them to remain.
		//
		// This operation is, in the kernel:
		//   dst[i] = (src[i] & mask[i]) ^ xor[i]
		//
		// So, we can mask by setting the inverse of the bits we want
		// to remove; i.e. ESTABLISHED = 0b00000010, RELATED =
		// 0b00000100, so, if we assume an 8-bit state (in reality,
		// it's 32-bit), we can mask with 0b11111001 to clear those
		// bits and keep everything else (e.g. the INVALID bit which is
		// 0b00000001).
		//
		// TODO(andrew-d): for now, let's also allow
		// CtStateBitUNTRACKED, which is a state for packets that are not
		// tracked (marked so explicitly with an iptables rule using
		// --notrack); we should figure out if we want to allow this or not.
		&expr.Bitwise{
			SourceRegister: 1,
			DestRegister:   1,
			Len:            4,
			Mask: nativeUint32(^(0 |
				expr.CtStateBitESTABLISHED |
				expr.CtStateBitRELATED |
				expr.CtStateBitUNTRACKED)),
			// Xor is unused but must be specified
			Xor: nativeUint32(0),
		},
		// Compare against the expected state (0, i.e. no bits set
		// other than maybe ESTABLISHED, RELATED or UNTRACKED, which the
		// mask above cleared). We want this comparison to fail if there
		// are no bits set, so that this rule's evaluation stops and we
		// don't fall through to the "Drop" verdict.
		//
		// For example, if the state is ESTABLISHED (and we want to
		// break from this rule/accept this packet):
		//   state = ESTABLISHED
		//   register1 = 0b0 (since the bitwise operation cleared the ESTABLISHED bit)
		//
		//   compare register1 (0b0) != 0: false
		//     -> comparison implicitly breaks
		//     -> continue to the next rule
		//
		// For example, if the state is NEW (and we want to continue to
		// the next expression and thus drop this packet):
		//   state = NEW
		//   register1 = 0b1000
		//
		//   compare register1 (0b1000) != 0: true
		//     -> comparison continues to next expr
		&expr.Cmp{
			Op:       expr.CmpOpNeq,
			Register: 1,
			Data:     []byte{0, 0, 0, 0},
		},
		// If we get here, we know that this packet is going to our TUN
		// device, and has a conntrack state set other than ESTABLISHED
		// or RELATED (or UNTRACKED). We thus count and drop the packet.
		&expr.Counter{},
		&expr.Verdict{Kind: expr.VerdictDrop},
	}

	// TODO(andrew-d): iptables-nft writes a rule that dumps as:
	//
	//	match name conntrack rev 3
	//
	// I think this is using expr.Match against the following struct
	// (xt_conntrack_mtinfo3):
	//
	//	https://github.com/torvalds/linux/blob/master/include/uapi/linux/netfilter/xt_conntrack.h#L64-L77
	//
	// We could probably do something similar here, but I'm not sure if
	// there's any advantage. Below is an example Match statement if we
	// decide to do that, based on dumping the rule that iptables-nft
	// generates:
	//
	//	_ = expr.Match{
	//		Name: "conntrack",
	//		Rev:  3,
	//		Info: &xt.ConntrackMtinfo3{
	//			ConntrackMtinfo2: xt.ConntrackMtinfo2{
	//				ConntrackMtinfoBase: xt.ConntrackMtinfoBase{
	//					MatchFlags:  xt.ConntrackState,
	//					InvertFlags: xt.ConntrackState,
	//				},
	//				// Mask the state to remove ESTABLISHED and
	//				// RELATED before comparing.
	//				StateMask: expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED,
	//			},
	//		},
	//	}
}
  1673. // AddStatefulRule adds a netfilter rule for stateful packet filtering using
  1674. // conntrack.
  1675. func (n *nftablesRunner) AddStatefulRule(tunname string) error {
  1676. conn := n.conn
  1677. exprs := makeStatefulRuleExprs(tunname)
  1678. for _, table := range n.getTables() {
  1679. chain, err := getChainFromTable(conn, table.Filter, chainNameForward)
  1680. if err != nil {
  1681. return fmt.Errorf("get forward chain: %w", err)
  1682. }
  1683. // First, find the 'accept' rule that we want to insert our rule before.
  1684. acceptRule := createAcceptOutgoingPacketRule(table.Filter, chain, tunname)
  1685. rule, err := findRule(conn, acceptRule)
  1686. if err != nil {
  1687. return fmt.Errorf("find accept rule: %w", err)
  1688. }
  1689. conn.InsertRule(&nftables.Rule{
  1690. Table: table.Filter,
  1691. Chain: chain,
  1692. Exprs: exprs,
  1693. // Specifying Position in an Insert operation means to
  1694. // insert this rule before the specified rule.
  1695. Position: rule.Handle,
  1696. })
  1697. }
  1698. if err := conn.Flush(); err != nil {
  1699. return fmt.Errorf("flush add stateful rule: %w", err)
  1700. }
  1701. return nil
  1702. }
  1703. // DelStatefulRule removes the netfilter rule for stateful packet filtering
  1704. // using conntrack.
  1705. func (n *nftablesRunner) DelStatefulRule(tunname string) error {
  1706. conn := n.conn
  1707. exprs := makeStatefulRuleExprs(tunname)
  1708. for _, table := range n.getTables() {
  1709. chain, err := getChainFromTable(conn, table.Filter, chainNameForward)
  1710. if err != nil {
  1711. return fmt.Errorf("get forward chain: %w", err)
  1712. }
  1713. rule, err := findRule(conn, &nftables.Rule{
  1714. Table: table.Filter,
  1715. Chain: chain,
  1716. Exprs: exprs,
  1717. })
  1718. if err != nil {
  1719. return fmt.Errorf("find stateful rule: %w", err)
  1720. }
  1721. if rule != nil {
  1722. conn.DelRule(rule)
  1723. }
  1724. }
  1725. if err := conn.Flush(); err != nil {
  1726. return fmt.Errorf("flush del stateful rule: %w", err)
  1727. }
  1728. return nil
  1729. }
  1730. // cleanupChain removes a jump rule from hookChainName to tsChainName, and then
  1731. // the entire chain tsChainName. Errors are logged, but attempts to remove both
  1732. // the jump rule and chain continue even if one errors.
  1733. func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) {
  1734. // remove the jump first, before removing the jump destination.
  1735. defaultChain, err := getChainFromTable(conn, table, hookChainName)
  1736. if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
  1737. logf("cleanup: did not find default chain: %s", err)
  1738. }
  1739. if !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
  1740. // delete hook in convention chain
  1741. _ = delHookRule(conn, table, defaultChain, tsChainName)
  1742. }
  1743. tsChain, err := getChainFromTable(conn, table, tsChainName)
  1744. if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) {
  1745. logf("cleanup: did not find ts-chain: %s", err)
  1746. }
  1747. if tsChain != nil {
  1748. // flush and delete ts-chain
  1749. conn.FlushChain(tsChain)
  1750. conn.DelChain(tsChain)
  1751. err = conn.Flush()
  1752. logf("cleanup: delete and flush chain %s: %s", tsChainName, err)
  1753. }
  1754. }
  1755. // NfTablesCleanUp removes all Tailscale added nftables rules.
  1756. // Any errors that occur are logged to the provided logf.
  1757. func NfTablesCleanUp(logf logger.Logf) {
  1758. conn, err := nftables.New()
  1759. if err != nil {
  1760. logf("cleanup: nftables connection: %s", err)
  1761. }
  1762. tables, err := conn.ListTables() // both v4 and v6
  1763. if err != nil {
  1764. logf("cleanup: list tables: %s", err)
  1765. }
  1766. for _, table := range tables {
  1767. // These table names were used briefly in 1.48.0.
  1768. if table.Name == "ts-filter" || table.Name == "ts-nat" {
  1769. conn.DelTable(table)
  1770. if err := conn.Flush(); err != nil {
  1771. logf("cleanup: flush delete table %s: %s", table.Name, err)
  1772. }
  1773. }
  1774. if table.Name == "filter" {
  1775. cleanupChain(logf, conn, table, "INPUT", chainNameInput)
  1776. cleanupChain(logf, conn, table, "FORWARD", chainNameForward)
  1777. }
  1778. if table.Name == "nat" {
  1779. cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting)
  1780. }
  1781. }
  1782. }
  1783. func snatRule(t *nftables.Table, ch *nftables.Chain, src, dst netip.Addr, meta []byte) *nftables.Rule {
  1784. var daddrOffset, fam, daddrLen uint32
  1785. if dst.Is4() {
  1786. daddrOffset = 16
  1787. daddrLen = 4
  1788. fam = unix.NFPROTO_IPV4
  1789. } else {
  1790. daddrOffset = 24
  1791. daddrLen = 16
  1792. fam = unix.NFPROTO_IPV6
  1793. }
  1794. return &nftables.Rule{
  1795. Table: t,
  1796. Chain: ch,
  1797. Exprs: []expr.Any{
  1798. &expr.Payload{
  1799. DestRegister: 1,
  1800. Base: expr.PayloadBaseNetworkHeader,
  1801. Offset: daddrOffset,
  1802. Len: daddrLen,
  1803. },
  1804. &expr.Cmp{
  1805. Op: expr.CmpOpEq,
  1806. Register: 1,
  1807. Data: dst.AsSlice(),
  1808. },
  1809. &expr.Immediate{
  1810. Register: 1,
  1811. Data: src.AsSlice(),
  1812. },
  1813. &expr.NAT{
  1814. Type: expr.NATTypeSourceNAT,
  1815. Family: fam,
  1816. RegAddrMin: 1,
  1817. RegAddrMax: 1,
  1818. },
  1819. },
  1820. UserData: meta,
  1821. }
  1822. }