Salsa20.hpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. /*
  2. * Based on public domain code available at: http://cr.yp.to/snuffle.html
  3. *
  4. * This therefore is public domain.
  5. */
  6. #ifndef ZT_SALSA20_HPP
  7. #define ZT_SALSA20_HPP
  8. #include <stdio.h>
  9. #include <stdint.h>
  10. #include <stdlib.h>
  11. #include <string.h>
  12. #include "Constants.hpp"
  13. #include "Utils.hpp"
  14. #if (!defined(ZT_SALSA20_SSE)) && (defined(__SSE2__) || defined(__WINDOWS__))
  15. #define ZT_SALSA20_SSE 1
  16. #endif
  17. #ifdef ZT_SALSA20_SSE
  18. #include <emmintrin.h>
  19. #endif // ZT_SALSA20_SSE
  20. namespace ZeroTier {
  21. /**
  22. * Salsa20 stream cipher
  23. */
  24. class Salsa20
  25. {
  26. public:
  /** Default constructor: the body is empty, so _state is not initialized here */
  27. inline Salsa20() {}
  /** Destructor: hands the entire _state union to Utils::burn() (presumably wiping key material from memory -- confirm in Utils.hpp) */
  28. inline ~Salsa20() { Utils::burn(&_state,sizeof(_state)); }
  /**
   * XOR d with s in place (d[k] ^= s[k] for every k in [0,len))
   *
   * This is done efficiently using e.g. SSE if available. It's used when
   * alternative Salsa20 implementations are used in Packet and is here
   * since this is where all the SSE stuff is already included.
   *
   * With ZT_SALSA20_SSE defined, the bulk of the buffer is processed with
   * unaligned SSE2 loads/stores: 128 bytes per iteration, then 16 bytes per
   * iteration. Without it, 16 bytes per iteration are processed as two
   * 64-bit words. Whatever is left over is finished one byte at a time.
   *
   * The 64-bit fallback moves words in and out of the buffers with memcpy()
   * rather than dereferencing reinterpret_cast<uint64_t *> pointers: the
   * byte pointers carry no alignment guarantee and the underlying objects
   * are not uint64_t, so such dereferences are undefined behavior
   * (misaligned access and a strict-aliasing violation). Since no type
   * punning remains, the fallback no longer needs to be switched off via
   * ZT_NO_TYPE_PUNNING.
   *
   * @param d Destination to XOR
   * @param s Source bytes to XOR with destination
   * @param len Length of s and d
   */
  static inline void memxor(uint8_t *d,const uint8_t *s,unsigned int len)
  {
#ifdef ZT_SALSA20_SSE
    // 128 bytes per pass: load eight source and eight destination vectors,
    // XOR them pairwise, then write all eight results back.
    while (len >= 128) {
      __m128i s0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s));
      __m128i s1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 16));
      __m128i s2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 32));
      __m128i s3 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 48));
      __m128i s4 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 64));
      __m128i s5 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 80));
      __m128i s6 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 96));
      __m128i s7 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 112));
      __m128i d0 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d));
      __m128i d1 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 16));
      __m128i d2 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 32));
      __m128i d3 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 48));
      __m128i d4 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 64));
      __m128i d5 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 80));
      __m128i d6 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 96));
      __m128i d7 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 112));
      d0 = _mm_xor_si128(d0,s0);
      d1 = _mm_xor_si128(d1,s1);
      d2 = _mm_xor_si128(d2,s2);
      d3 = _mm_xor_si128(d3,s3);
      d4 = _mm_xor_si128(d4,s4);
      d5 = _mm_xor_si128(d5,s5);
      d6 = _mm_xor_si128(d6,s6);
      d7 = _mm_xor_si128(d7,s7);
      _mm_storeu_si128(reinterpret_cast<__m128i *>(d),d0);
      _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 16),d1);
      _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 32),d2);
      _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 48),d3);
      _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 64),d4);
      _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 80),d5);
      _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 96),d6);
      _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 112),d7);
      s += 128;
      d += 128;
      len -= 128;
    }
    // One vector (16 bytes) per pass for what the 128-byte loop left behind.
    while (len >= 16) {
      _mm_storeu_si128(reinterpret_cast<__m128i *>(d),_mm_xor_si128(_mm_loadu_si128(reinterpret_cast<__m128i *>(d)),_mm_loadu_si128(reinterpret_cast<const __m128i *>(s))));
      s += 16;
      d += 16;
      len -= 16;
    }
#else
    // Portable path: two 64-bit words per pass. memcpy() is the well-defined
    // way to read/write a word at an arbitrary byte address.
    while (len >= 16) {
      uint64_t dw;
      uint64_t sw;

      memcpy(&dw,d,sizeof(dw));
      memcpy(&sw,s,sizeof(sw));
      dw ^= sw;
      memcpy(d,&dw,sizeof(dw));
      s += 8;
      d += 8;

      memcpy(&dw,d,sizeof(dw));
      memcpy(&sw,s,sizeof(sw));
      dw ^= sw;
      memcpy(d,&dw,sizeof(dw));
      s += 8;
      d += 8;

      len -= 16;
    }
#endif
    // Tail: fewer than 16 bytes remain; finish byte by byte.
    while (len) {
      --len;
      *(d++) ^= *(s++);
    }
  }
  104. /**
  * Construct a cipher and initialize it in one step by forwarding both
  * arguments to init()
  *
  105. * @param key 256-bit (32 byte) key
  106. * @param iv 64-bit initialization vector
  107. */
  108. inline Salsa20(const void *key,const void *iv) { init(key,iv); }
  109. /**
  110. * Initialize cipher
  111. *
  * Declaration only: no definition appears in this header. The two-argument
  * constructor forwards its arguments to this method.
  *
  112. * @param key Key bits (the forwarding constructor documents this as a 256-bit / 32 byte key)
  113. * @param iv 64-bit initialization vector
  114. */
  115. void init(const void *key,const void *iv);
  116. /**
  117. * Encrypt/decrypt data using Salsa20/12
  118. *
  * Declaration only: no definition appears in this header. One entry point
  * is documented for both encryption and decryption.
  * NOTE(review): whether in and out may point to the same buffer is not
  * stated here -- confirm against the implementation.
  *
  119. * @param in Input data
  120. * @param out Output buffer
  121. * @param bytes Length of data (in bytes, going by the parameter name)
  122. */
  123. void crypt12(const void *in,void *out,unsigned int bytes);
  124. /**
  125. * Encrypt/decrypt data using Salsa20/20
  126. *
  * Declaration only: no definition appears in this header. One entry point
  * is documented for both encryption and decryption.
  * NOTE(review): whether in and out may point to the same buffer is not
  * stated here -- confirm against the implementation.
  *
  127. * @param in Input data
  128. * @param out Output buffer
  129. * @param bytes Length of data (in bytes, going by the parameter name)
  130. */
  131. void crypt20(const void *in,void *out,unsigned int bytes);
  132. private:
  /**
  * Cipher state storage: sixteen 32-bit words (64 bytes), overlaid with four
  * 128-bit vectors covering the same 64 bytes when ZT_SALSA20_SSE is defined.
  * The destructor passes this whole union to Utils::burn().
  */
  133. union {
  134. #ifdef ZT_SALSA20_SSE
  135. __m128i v[4];
  136. #endif // ZT_SALSA20_SSE
  137. uint32_t i[16];
  138. } _state;
  139. };
  140. } // namespace ZeroTier
  141. #endif