Subnet.kt 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. package io.nekohasekai.sagernet.utils
  2. import io.nekohasekai.sagernet.ktx.parseNumericAddress
  3. import java.net.InetAddress
  4. import java.util.*
  5. class Subnet(val address: InetAddress, val prefixSize: Int) : Comparable<Subnet> {
  6. companion object {
  7. fun fromString(value: String, lengthCheck: Int = -1): Subnet? {
  8. val parts = value.split('/', limit = 2)
  9. val addr = parts[0].parseNumericAddress() ?: return null
  10. check(lengthCheck < 0 || addr.address.size == lengthCheck)
  11. return if (parts.size == 2) try {
  12. val prefixSize = parts[1].toInt()
  13. if (prefixSize < 0 || prefixSize > addr.address.size shl 3) null else Subnet(addr,
  14. prefixSize)
  15. } catch (_: NumberFormatException) {
  16. null
  17. } else Subnet(addr, addr.address.size shl 3)
  18. }
  19. }
  20. private val addressLength get() = address.address.size shl 3
  21. init {
  22. require(prefixSize in 0..addressLength) { "prefixSize $prefixSize not in 0..$addressLength" }
  23. }
  24. class Immutable(private val a: ByteArray, private val prefixSize: Int = 0) {
  25. companion object : Comparator<Immutable> {
  26. override fun compare(a: Immutable, b: Immutable): Int {
  27. check(a.a.size == b.a.size)
  28. for (i in a.a.indices) {
  29. val result = a.a[i].compareTo(b.a[i])
  30. if (result != 0) return result
  31. }
  32. return 0
  33. }
  34. }
  35. fun matches(b: Immutable) = matches(b.a)
  36. fun matches(b: ByteArray): Boolean {
  37. if (a.size != b.size) return false
  38. var i = 0
  39. while (i * 8 < prefixSize && i * 8 + 8 <= prefixSize) {
  40. if (a[i] != b[i]) return false
  41. ++i
  42. }
  43. return i * 8 == prefixSize || a[i] == (b[i].toInt() and -(1 shl i * 8 + 8 - prefixSize)).toByte()
  44. }
  45. }
  46. fun toImmutable() = Immutable(address.address.also {
  47. var i = prefixSize / 8
  48. if (prefixSize % 8 > 0) {
  49. it[i] = (it[i].toInt() and -(1 shl i * 8 + 8 - prefixSize)).toByte()
  50. ++i
  51. }
  52. while (i < it.size) it[i++] = 0
  53. }, prefixSize)
  54. override fun toString(): String =
  55. if (prefixSize == addressLength) address.hostAddress else address.hostAddress + '/' + prefixSize
  56. private fun Byte.unsigned() = toInt() and 0xFF
  57. override fun compareTo(other: Subnet): Int {
  58. val addrThis = address.address
  59. val addrThat = other.address.address
  60. var result =
  61. addrThis.size.compareTo(addrThat.size) // IPv4 address goes first
  62. if (result != 0) return result
  63. for (i in addrThis.indices) {
  64. result = addrThis[i].unsigned()
  65. .compareTo(addrThat[i].unsigned()) // undo sign extension of signed byte
  66. if (result != 0) return result
  67. }
  68. return prefixSize.compareTo(other.prefixSize)
  69. }
  70. override fun equals(other: Any?): Boolean {
  71. val that = other as? Subnet
  72. return address == that?.address && prefixSize == that.prefixSize
  73. }
  74. override fun hashCode(): Int = Objects.hash(address, prefixSize)
  75. }