|
|
@@ -20,6 +20,7 @@ import (
|
|
|
"golang.org/x/sys/unix"
|
|
|
"tailscale.com/net/tsaddr"
|
|
|
"tailscale.com/types/logger"
|
|
|
+ "tailscale.com/types/ptr"
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
@@ -316,8 +317,33 @@ func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
|
|
|
return n.conn.Flush()
|
|
|
}
|
|
|
|
|
|
-// createTableIfNotExist creates a nftables table via connection c if it does not exist within the given family.
|
|
|
-func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
|
|
|
+// deleteTableIfExists deletes a nftables table via connection c if it exists
|
|
|
+// within the given family.
|
|
|
+func deleteTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) error {
|
|
|
+ t, err := getTableIfExists(c, family, name)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("get table: %w", err)
|
|
|
+ }
|
|
|
+ if t == nil {
|
|
|
+ // Table does not exist, so nothing to delete.
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ c.DelTable(t)
|
|
|
+ if err := c.Flush(); err != nil {
|
|
|
+ if t, err = getTableIfExists(c, family, name); t == nil && err == nil {
|
|
|
+ // Check if the table still exists. If it does not, then the error
|
|
|
+ // is due to the table not existing, so we can ignore it. Maybe a
|
|
|
+ // concurrent process deleted the table.
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return fmt.Errorf("del table: %w", err)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// getTableIfExists returns the table with the given name from the given family
|
|
|
+// if it exists. If none match, it returns (nil, nil).
|
|
|
+func getTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
|
|
|
tables, err := c.ListTables()
|
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("get tables: %w", err)
|
|
|
@@ -327,7 +353,17 @@ func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name s
|
|
|
return table, nil
|
|
|
}
|
|
|
}
|
|
|
+ return nil, nil
|
|
|
+}
|
|
|
|
|
|
+// createTableIfNotExist creates a nftables table via connection c if it does
|
|
|
+// not exist within the given family.
|
|
|
+func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
|
|
|
+ if t, err := getTableIfExists(c, family, name); err != nil {
|
|
|
+ return nil, fmt.Errorf("get table: %w", err)
|
|
|
+ } else if t != nil {
|
|
|
+ return t, nil
|
|
|
+ }
|
|
|
t := c.AddTable(&nftables.Table{
|
|
|
Family: family,
|
|
|
Name: name,
|
|
|
@@ -365,24 +401,6 @@ func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*n
|
|
|
return nil, errorChainNotFound{table.Name, name}
|
|
|
}
|
|
|
|
|
|
-// getChainsFromTable returns all chains from the given table.
|
|
|
-func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) {
|
|
|
- chains, err := c.ListChainsOfTableFamily(table.Family)
|
|
|
- if err != nil {
|
|
|
- return nil, fmt.Errorf("list chains: %w", err)
|
|
|
- }
|
|
|
-
|
|
|
- var ret []*nftables.Chain
|
|
|
- for _, chain := range chains {
|
|
|
- // Table family is already checked so table name is unique
|
|
|
- if chain.Table.Name == table.Name {
|
|
|
- ret = append(ret, chain)
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- return ret, nil
|
|
|
-}
|
|
|
-
|
|
|
// isTSChain reports whether `name` begins with "ts-" (and is thus a
|
|
|
// Tailscale-managed chain).
|
|
|
func isTSChain(name string) bool {
|
|
|
@@ -804,6 +822,43 @@ func (n *nftablesRunner) AddChains() error {
|
|
|
return n.conn.Flush()
|
|
|
}
|
|
|
|
|
|
+// These are dummy chains and tables we create to detect if nftables is
|
|
|
+// available. We create them, then delete them. If we can create and delete
|
|
|
+// them, then we can use nftables. If we can't, then we assume that we're
|
|
|
+// running on a system that doesn't support nftables. See
|
|
|
+// createDummyPostroutingChains.
|
|
|
+const (
|
|
|
+ tsDummyChainName = "ts-test-postrouting"
|
|
|
+ tsDummyTableName = "ts-test-nat"
|
|
|
+)
|
|
|
+
|
|
|
+// createDummyPostroutingChains creates dummy postrouting chains in netfilter
|
|
|
+// via netfilter via nftables, as a last resort measure to detect that nftables
|
|
|
+// can be used. It cleans up the dummy chains after creation.
|
|
|
+func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
|
|
|
+ polAccept := ptr.To(nftables.ChainPolicyAccept)
|
|
|
+ for _, table := range n.getNATTables() {
|
|
|
+ nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("create nat table: %w", err)
|
|
|
+ }
|
|
|
+ defer func(fm nftables.TableFamily) {
|
|
|
+ if err := deleteTableIfExists(n.conn, table.Proto, tsDummyTableName); err != nil && retErr == nil {
|
|
|
+ retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err)
|
|
|
+ }
|
|
|
+ }(table.Proto)
|
|
|
+
|
|
|
+ table.Nat = nat
|
|
|
+ if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil {
|
|
|
+ return fmt.Errorf("create %q chain: %w", tsDummyChainName, err)
|
|
|
+ }
|
|
|
+ if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil {
|
|
|
+ return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
// deleteChainIfExists deletes a chain if it exists.
|
|
|
func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
|
|
|
chain, err := getChainFromTable(c, table, name)
|