// MathFunctions.cxx
  1. #include <cmath>
  2. #include <format>
  3. #include <MathLogger.h>
  4. #ifdef TUTORIAL_USE_SSE2
  5. # include <emmintrin.h>
  6. #endif
  7. namespace {
  8. mathlogger::Logger Logger;
  9. #if defined(TUTORIAL_USE_GNU_BUILTIN)
  // 16-byte vector of doubles (two lanes); the operand and result type of the
  // GCC builtin used below.
  typedef double v2df __attribute__((vector_size(16)));

  // Computes the square root of x with the GCC `__builtin_ia32_sqrtsd`
  // builtin, reads the answer from lane 0, logs it through the file-local
  // Logger and returns it.
  double gnu_mysqrt(double x)
  {
    v2df input = { x, 0.0 };
    v2df output = __builtin_ia32_sqrtsd(input);
    const double value = output[0];

    Logger.Log(std::format(
      "Computed sqrt of {} to be {} with GNU-builtins\n", x, value));
    return value;
  }
  19. #elif defined(TUTORIAL_USE_SSE2)
  // Computes the square root of x with the SSE2 intrinsic `_mm_sqrt_sd`,
  // extracts the low double of the result, logs it through the file-local
  // Logger and returns it.
  double sse2_mysqrt(double x)
  {
    const __m128d operand = _mm_set_sd(x);
    const __m128d packed = _mm_sqrt_sd(_mm_setzero_pd(), operand);
    const double value = _mm_cvtsd_f64(packed);

    Logger.Log(
      std::format("Computed sqrt of {} to be {} with SSE2\n", x, value));
    return value;
  }
  28. #endif
  // Portable square-root approximation built from plain arithmetic only.
  //
  // Starts from x itself and applies ten fixed refinement steps of the form
  //   estimate += 0.5 * (x - estimate^2) / estimate
  // logging the estimate after every step. Inputs that are zero or negative
  // return 0 immediately.
  double fallback_mysqrt(double x)
  {
    if (x <= 0) {
      return 0;
    }

    double estimate = x;
    int remaining = 10;
    while (remaining-- > 0) {
      // Reset a non-positive estimate so the division below stays usable.
      if (estimate <= 0) {
        estimate = 0.1;
      }
      const double residual = x - (estimate * estimate);
      estimate = estimate + 0.5 * residual / estimate;
      Logger.Log(std::format("Computing sqrt of {} to be {}\n", x, estimate));
    }
    return estimate;
  }
  // TODO10: Replace this hardcoded sqrtTable with #include <SqrtTable.h>
  double sqrtTable[] = { 0, 1, 1, 2, 2, 2, 2, 3, 3, 3 };

  // Computes sqrt(x) by taking a starting guess from sqrtTable (indexed by x
  // truncated to an integer) and refining it with ten fixed steps of
  //   result += 0.5 * (x - result^2) / result
  // The final value is logged once through the file-local Logger and returned.
  //
  // The table is only consulted when x selects a valid slot. For any other
  // input (negative, at or beyond the table size, or NaN) x itself is used as
  // the starting guess, so the function never reads outside the array and
  // never performs an out-of-range floating-point-to-int conversion.
  double table_sqrt(double x)
  {
    constexpr int tableSize =
      static_cast<int>(sizeof(sqrtTable) / sizeof(sqrtTable[0]));

    double result = x;
    // Both comparisons are false for NaN, so NaN also skips the lookup.
    if (x >= 0 && x < tableSize) {
      result = sqrtTable[static_cast<int>(x)];
    }

    // do ten iterations
    for (int i = 0; i < 10; ++i) {
      // A zero table entry or a negative seed would break the division below.
      if (result <= 0) {
        result = 0.1;
      }
      double delta = x - (result * result);
      result = result + 0.5 * delta / result;
    }
    Logger.Log(
      std::format("Computed sqrt of {} to be {} with TableSqrt\n", x, result));
    return result;
  }
  64. double mysqrt(double x)
  65. {
  66. if (x >= 1 && x < 10) {
  67. return table_sqrt(x);
  68. }
  69. #if defined(TUTORIAL_USE_GNU_BUILTIN)
  70. return gnu_mysqrt(x);
  71. #elif defined(TUTORIAL_USE_SSE2)
  72. return sse2_mysqrt(x);
  73. #else
  74. return fallback_mysqrt(x);
  75. #endif
  76. }
  77. }
  78. namespace mathfunctions {
  79. double sqrt(double x)
  80. {
  81. #ifdef TUTORIAL_USE_STD_SQRT
  82. return std::sqrt(x);
  83. #else
  84. return mysqrt(x);
  85. #endif
  86. }
  87. }