allowedips.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. /* SPDX-License-Identifier: MIT
  2. *
  3. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
  4. */
  5. package device
  6. import (
  7. "container/list"
  8. "encoding/binary"
  9. "errors"
  10. "math/bits"
  11. "net"
  12. "net/netip"
  13. "sync"
  14. "unsafe"
  15. )
  16. type parentIndirection struct {
  17. parentBit **trieEntry
  18. parentBitType uint8
  19. }
  20. type trieEntry struct {
  21. peer *Peer
  22. child [2]*trieEntry
  23. parent parentIndirection
  24. cidr uint8
  25. bitAtByte uint8
  26. bitAtShift uint8
  27. bits []byte
  28. perPeerElem *list.Element
  29. }
  30. func commonBits(ip1, ip2 []byte) uint8 {
  31. size := len(ip1)
  32. if size == net.IPv4len {
  33. a := binary.BigEndian.Uint32(ip1)
  34. b := binary.BigEndian.Uint32(ip2)
  35. x := a ^ b
  36. return uint8(bits.LeadingZeros32(x))
  37. } else if size == net.IPv6len {
  38. a := binary.BigEndian.Uint64(ip1)
  39. b := binary.BigEndian.Uint64(ip2)
  40. x := a ^ b
  41. if x != 0 {
  42. return uint8(bits.LeadingZeros64(x))
  43. }
  44. a = binary.BigEndian.Uint64(ip1[8:])
  45. b = binary.BigEndian.Uint64(ip2[8:])
  46. x = a ^ b
  47. return 64 + uint8(bits.LeadingZeros64(x))
  48. } else {
  49. panic("Wrong size bit string")
  50. }
  51. }
  52. func (node *trieEntry) addToPeerEntries() {
  53. node.perPeerElem = node.peer.trieEntries.PushBack(node)
  54. }
  55. func (node *trieEntry) removeFromPeerEntries() {
  56. if node.perPeerElem != nil {
  57. node.peer.trieEntries.Remove(node.perPeerElem)
  58. node.perPeerElem = nil
  59. }
  60. }
  61. func (node *trieEntry) choose(ip []byte) byte {
  62. return (ip[node.bitAtByte] >> node.bitAtShift) & 1
  63. }
  64. func (node *trieEntry) maskSelf() {
  65. mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
  66. for i := 0; i < len(mask); i++ {
  67. node.bits[i] &= mask[i]
  68. }
  69. }
  70. func (node *trieEntry) zeroizePointers() {
  71. // Make the garbage collector's life slightly easier
  72. node.peer = nil
  73. node.child[0] = nil
  74. node.child[1] = nil
  75. node.parent.parentBit = nil
  76. }
  77. func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
  78. for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
  79. parent = node
  80. if parent.cidr == cidr {
  81. exact = true
  82. return
  83. }
  84. bit := node.choose(ip)
  85. node = node.child[bit]
  86. }
  87. return
  88. }
  89. func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
  90. if *trie.parentBit == nil {
  91. node := &trieEntry{
  92. peer: peer,
  93. parent: trie,
  94. bits: ip,
  95. cidr: cidr,
  96. bitAtByte: cidr / 8,
  97. bitAtShift: 7 - (cidr % 8),
  98. }
  99. node.maskSelf()
  100. node.addToPeerEntries()
  101. *trie.parentBit = node
  102. return
  103. }
  104. node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
  105. if exact {
  106. node.removeFromPeerEntries()
  107. node.peer = peer
  108. node.addToPeerEntries()
  109. return
  110. }
  111. newNode := &trieEntry{
  112. peer: peer,
  113. bits: ip,
  114. cidr: cidr,
  115. bitAtByte: cidr / 8,
  116. bitAtShift: 7 - (cidr % 8),
  117. }
  118. newNode.maskSelf()
  119. newNode.addToPeerEntries()
  120. var down *trieEntry
  121. if node == nil {
  122. down = *trie.parentBit
  123. } else {
  124. bit := node.choose(ip)
  125. down = node.child[bit]
  126. if down == nil {
  127. newNode.parent = parentIndirection{&node.child[bit], bit}
  128. node.child[bit] = newNode
  129. return
  130. }
  131. }
  132. common := commonBits(down.bits, ip)
  133. if common < cidr {
  134. cidr = common
  135. }
  136. parent := node
  137. if newNode.cidr == cidr {
  138. bit := newNode.choose(down.bits)
  139. down.parent = parentIndirection{&newNode.child[bit], bit}
  140. newNode.child[bit] = down
  141. if parent == nil {
  142. newNode.parent = trie
  143. *trie.parentBit = newNode
  144. } else {
  145. bit := parent.choose(newNode.bits)
  146. newNode.parent = parentIndirection{&parent.child[bit], bit}
  147. parent.child[bit] = newNode
  148. }
  149. return
  150. }
  151. node = &trieEntry{
  152. bits: append([]byte{}, newNode.bits...),
  153. cidr: cidr,
  154. bitAtByte: cidr / 8,
  155. bitAtShift: 7 - (cidr % 8),
  156. }
  157. node.maskSelf()
  158. bit := node.choose(down.bits)
  159. down.parent = parentIndirection{&node.child[bit], bit}
  160. node.child[bit] = down
  161. bit = node.choose(newNode.bits)
  162. newNode.parent = parentIndirection{&node.child[bit], bit}
  163. node.child[bit] = newNode
  164. if parent == nil {
  165. node.parent = trie
  166. *trie.parentBit = node
  167. } else {
  168. bit := parent.choose(node.bits)
  169. node.parent = parentIndirection{&parent.child[bit], bit}
  170. parent.child[bit] = node
  171. }
  172. }
  173. func (node *trieEntry) lookup(ip []byte) *Peer {
  174. var found *Peer
  175. size := uint8(len(ip))
  176. for node != nil && commonBits(node.bits, ip) >= node.cidr {
  177. if node.peer != nil {
  178. found = node.peer
  179. }
  180. if node.bitAtByte == size {
  181. break
  182. }
  183. bit := node.choose(ip)
  184. node = node.child[bit]
  185. }
  186. return found
  187. }
  188. type AllowedIPs struct {
  189. IPv4 *trieEntry
  190. IPv6 *trieEntry
  191. mutex sync.RWMutex
  192. }
  193. func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
  194. table.mutex.RLock()
  195. defer table.mutex.RUnlock()
  196. for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
  197. node := elem.Value.(*trieEntry)
  198. a, _ := netip.AddrFromSlice(node.bits)
  199. if !cb(netip.PrefixFrom(a, int(node.cidr))) {
  200. return
  201. }
  202. }
  203. }
  204. func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
  205. table.mutex.Lock()
  206. defer table.mutex.Unlock()
  207. var next *list.Element
  208. for elem := peer.trieEntries.Front(); elem != nil; elem = next {
  209. next = elem.Next()
  210. node := elem.Value.(*trieEntry)
  211. node.removeFromPeerEntries()
  212. node.peer = nil
  213. if node.child[0] != nil && node.child[1] != nil {
  214. continue
  215. }
  216. bit := 0
  217. if node.child[0] == nil {
  218. bit = 1
  219. }
  220. child := node.child[bit]
  221. if child != nil {
  222. child.parent = node.parent
  223. }
  224. *node.parent.parentBit = child
  225. if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
  226. node.zeroizePointers()
  227. continue
  228. }
  229. parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
  230. if parent.peer != nil {
  231. node.zeroizePointers()
  232. continue
  233. }
  234. child = parent.child[node.parent.parentBitType^1]
  235. if child != nil {
  236. child.parent = parent.parent
  237. }
  238. *parent.parent.parentBit = child
  239. node.zeroizePointers()
  240. parent.zeroizePointers()
  241. }
  242. }
  243. func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
  244. table.mutex.Lock()
  245. defer table.mutex.Unlock()
  246. if prefix.Addr().Is6() {
  247. ip := prefix.Addr().As16()
  248. parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
  249. } else if prefix.Addr().Is4() {
  250. ip := prefix.Addr().As4()
  251. parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
  252. } else {
  253. panic(errors.New("inserting unknown address type"))
  254. }
  255. }
  256. func (table *AllowedIPs) Lookup(ip []byte) *Peer {
  257. table.mutex.RLock()
  258. defer table.mutex.RUnlock()
  259. switch len(ip) {
  260. case net.IPv6len:
  261. return table.IPv6.lookup(ip)
  262. case net.IPv4len:
  263. return table.IPv4.lookup(ip)
  264. default:
  265. panic(errors.New("looking up unknown address type"))
  266. }
  267. }