safe_math.h 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. /*
  2. * Copyright 2021-2022 The OpenSSL Project Authors. All Rights Reserved.
  3. *
  4. * Licensed under the Apache License 2.0 (the "License"). You may not use
  5. * this file except in compliance with the License. You can obtain a copy
  6. * in the file LICENSE in the source distribution or at
  7. * https://www.openssl.org/source/license.html
  8. */
  9. #ifndef OSSL_INTERNAL_SAFE_MATH_H
  10. # define OSSL_INTERNAL_SAFE_MATH_H
  11. # pragma once
  12. # include <openssl/e_os2.h> /* For 'ossl_inline' */
  /*
   * has(func) is a preprocessor predicate saying whether the overflow
   * checking builtin func may be used.  Unless builtin checking is disabled
   * with OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING, it defers to __has_builtin
   * when the compiler provides it, otherwise assumes the builtins exist on
   * GCC newer than release 5.  In every other case it falls back to 0.
   * NOTE(review): "has" is a very generic macro name and no #undef of it
   * is visible in this header.
   */
# if !defined(OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING)
#  if defined(__has_builtin)
#   define has(func) __has_builtin(func)
#  elif defined(__GNUC__) && __GNUC__ > 5
#   define has(func) 1
#  endif
# endif
# if !defined(has)
#  define has(func) 0
# endif
  25. /*
  26. * Safe addition helpers
  27. */
  28. # if has(__builtin_add_overflow)
  /* Signed a + b; on overflow set *err and saturate toward the sign of a */
# define OSSL_SAFE_MATH_ADDS(type_name, type, min, max)                       \
    static ossl_inline ossl_unused type                                      \
    safe_add_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        type sum;                                                            \
                                                                             \
        if (__builtin_add_overflow(a, b, &sum)) {                            \
            *err |= 1;                                                       \
            sum = a < 0 ? min : max;                                         \
        }                                                                    \
        return sum;                                                          \
    }
  /* Unsigned a + b; on overflow set *err, the result wraps around */
# define OSSL_SAFE_MATH_ADDU(type_name, type, max)                            \
    static ossl_inline ossl_unused type                                      \
    safe_add_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        type sum;                                                            \
                                                                             \
        /* The builtin stores the wrapped sum even when it reports overflow */ \
        if (__builtin_add_overflow(a, b, &sum))                              \
            *err |= 1;                                                       \
        return sum;                                                          \
    }
  53. # else /* has(__builtin_add_overflow) */
  /* Portable signed a + b; on overflow set *err and saturate */
# define OSSL_SAFE_MATH_ADDS(type_name, type, min, max)                       \
    static ossl_inline ossl_unused type                                      \
    safe_add_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        /*                                                                   \
         * Overflow needs both operands to share a sign.  max - a and        \
         * min - a are only evaluated where they cannot overflow.            \
         */                                                                  \
        if (a > 0 && b > 0 && b > max - a) {                                 \
            *err |= 1;                                                       \
            return max;                                                      \
        }                                                                    \
        if (a < 0 && b < 0 && b < min - a) {                                 \
            *err |= 1;                                                       \
            return min;                                                      \
        }                                                                    \
        return a + b;                                                        \
    }
  /* Portable unsigned a + b; on overflow set *err, the result wraps around */
# define OSSL_SAFE_MATH_ADDU(type_name, type, max)                            \
    static ossl_inline ossl_unused type                                      \
    safe_add_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        /* max - a is the headroom left above a before the sum wraps */      \
        if (max - a < b)                                                     \
            *err |= 1;                                                       \
        return a + b;                                                        \
    }
  76. # endif /* has(__builtin_add_overflow) */
  77. /*
  78. * Safe subtraction helpers
  79. */
  80. # if has(__builtin_sub_overflow)
  /* Signed a - b; on overflow set *err and saturate toward the sign of a */
# define OSSL_SAFE_MATH_SUBS(type_name, type, min, max)                       \
    static ossl_inline ossl_unused type                                      \
    safe_sub_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        type diff;                                                           \
                                                                             \
        if (__builtin_sub_overflow(a, b, &diff)) {                           \
            *err |= 1;                                                       \
            diff = a < 0 ? min : max;                                        \
        }                                                                    \
        return diff;                                                         \
    }
  93. # else /* has(__builtin_sub_overflow) */
  /* Portable signed a - b; on overflow set *err and saturate */
# define OSSL_SAFE_MATH_SUBS(type_name, type, min, max)                       \
    static ossl_inline ossl_unused type                                      \
    safe_sub_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        /*                                                                   \
         * Overflow needs operands of opposite sign.  min + b and max + b    \
         * are only evaluated where they cannot overflow.                    \
         */                                                                  \
        if (b > 0 && a < 0 && a < min + b) {                                 \
            *err |= 1;                                                       \
            return min;                                                      \
        }                                                                    \
        if (b < 0 && a >= 0 && a > max + b) {                                \
            *err |= 1;                                                       \
            return max;                                                      \
        }                                                                    \
        return a - b;                                                        \
    }
  107. # endif /* has(__builtin_sub_overflow) */
  /* Unsigned a - b; a borrow (a < b) sets *err, the result wraps around */
# define OSSL_SAFE_MATH_SUBU(type_name, type)                                 \
    static ossl_inline ossl_unused type                                      \
    safe_sub_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        if (a < b)                                                           \
            *err |= 1;                                                       \
        return a - b;                                                        \
    }
  117. /*
  118. * Safe multiplication helpers
  119. */
  120. # if has(__builtin_mul_overflow)
  /* Signed a * b; on overflow set *err and saturate toward the product's sign */
# define OSSL_SAFE_MATH_MULS(type_name, type, min, max)                       \
    static ossl_inline ossl_unused type                                      \
    safe_mul_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        type prod;                                                           \
                                                                             \
        if (__builtin_mul_overflow(a, b, &prod)) {                           \
            *err |= 1;                                                       \
            prod = (a < 0) != (b < 0) ? min : max;                           \
        }                                                                    \
        return prod;                                                         \
    }
  /* Unsigned a * b; on overflow set *err, the result wraps around */
# define OSSL_SAFE_MATH_MULU(type_name, type, max)                            \
    static ossl_inline ossl_unused type                                      \
    safe_mul_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        type prod;                                                           \
                                                                             \
        /* The builtin stores the wrapped product even on overflow */        \
        if (__builtin_mul_overflow(a, b, &prod))                             \
            *err |= 1;                                                       \
        return prod;                                                         \
    }
  145. # else /* has(__builtin_mul_overflow) */
  /* Portable signed a * b; on overflow set *err and saturate */
# define OSSL_SAFE_MATH_MULS(type_name, type, min, max)                       \
    static ossl_inline ossl_unused type                                      \
    safe_mul_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        type abs_a, abs_b;                                                   \
                                                                             \
        /* Trivial factors can never overflow */                             \
        if (a == 0 || b == 0)                                                \
            return 0;                                                        \
        if (a == 1 || b == 1)                                                \
            return a == 1 ? b : a;                                           \
        /*                                                                   \
         * -min is unrepresentable, so min times any remaining factor is     \
         * treated as overflow.  Otherwise the magnitudes are checked        \
         * against max.  This is conservative: a product exactly equal to    \
         * min (e.g. -2 * (max / 2 + 1)) is also reported as an overflow.    \
         */                                                                  \
        if (a != min && b != min) {                                          \
            abs_a = a < 0 ? -a : a;                                          \
            abs_b = b < 0 ? -b : b;                                          \
            if (abs_a <= max / abs_b)                                        \
                return a * b;                                                \
        }                                                                    \
        *err |= 1;                                                           \
        return (a < 0) != (b < 0) ? min : max;                               \
    }
  /* Portable unsigned a * b; on overflow set *err, the result wraps around */
# define OSSL_SAFE_MATH_MULU(type_name, type, max)                            \
    static ossl_inline ossl_unused type                                      \
    safe_mul_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        /* b == 0 cannot overflow and must not reach the division */        \
        if (b == 0 || a <= max / b)                                          \
            return a * b;                                                    \
        *err |= 1;                                                           \
        return a * b;                                                        \
    }
  176. # endif /* has(__builtin_mul_overflow) */
  177. /*
  178. * Safe division helpers
  179. */
  /*
   * Signed a / b.  Division by zero and min / -1 both set *err; the former
   * saturates toward the sign of a, the latter returns max.
   */
# define OSSL_SAFE_MATH_DIVS(type_name, type, min, max)                       \
    static ossl_inline ossl_unused type                                      \
    safe_div_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        if (b != 0 && !(b == -1 && a == min))                                \
            return a / b;                                                    \
        *err |= 1;                                                           \
        if (b == 0)                                                          \
            return a < 0 ? min : max;                                        \
        return max; /* min / -1 would be -min, one past max */               \
    }
  /* Unsigned a / b; division by zero sets *err and returns max */
# define OSSL_SAFE_MATH_DIVU(type_name, type, max)                            \
    static ossl_inline ossl_unused type                                      \
    safe_div_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return max;                                                      \
        }                                                                    \
        return a / b;                                                        \
    }
  205. /*
  206. * Safe modulus helpers
  207. */
  /*
   * Signed a % b.  A zero divisor sets *err and returns 0.
   *
   * Every value modulo -1 is 0, so that case is answered directly: the
   * result is representable, but evaluating min % -1 is undefined behaviour
   * in C because the matching quotient min / -1 is not.  (Previously this
   * reported an error and returned max, which can never be a remainder.)
   * min and max are still accepted so the macro signature is unchanged.
   */
# define OSSL_SAFE_MATH_MODS(type_name, type, min, max)                       \
    static ossl_inline ossl_unused type                                      \
    safe_mod_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return 0;                                                        \
        }                                                                    \
        if (b == -1)                                                         \
            return 0;                                                        \
        return a % b;                                                        \
    }
  /* Unsigned a % b; a zero divisor sets *err and returns 0 */
# define OSSL_SAFE_MATH_MODU(type_name, type)                                 \
    static ossl_inline ossl_unused type                                      \
    safe_mod_ ## type_name(type a, type b, int *err)                         \
    {                                                                        \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return 0;                                                        \
        }                                                                    \
        return a % b;                                                        \
    }
  233. /*
  234. * Safe negation helpers
  235. */
  /* Signed -a; -min is unrepresentable, so it sets *err and returns min */
# define OSSL_SAFE_MATH_NEGS(type_name, type, min)                            \
    static ossl_inline ossl_unused type                                      \
    safe_neg_ ## type_name(type a, int *err)                                 \
    {                                                                        \
        if (a == min) {                                                      \
            *err |= 1;                                                       \
            return min;                                                      \
        }                                                                    \
        return -a;                                                           \
    }
  /*
   * Unsigned -a.  Only 0 has a representable negation; any other value sets
   * *err and yields the modular (two's complement) negation ~a + 1.
   */
# define OSSL_SAFE_MATH_NEGU(type_name, type)                                 \
    static ossl_inline ossl_unused type                                      \
    safe_neg_ ## type_name(type a, int *err)                                 \
    {                                                                        \
        if (a != 0) {                                                        \
            *err |= 1;                                                       \
            return ~a + 1;                                                   \
        }                                                                    \
        return 0;                                                            \
    }
  254. /*
  255. * Safe absolute value helpers
  256. */
  /* Signed |a|; |min| is unrepresentable, so it sets *err and returns min */
# define OSSL_SAFE_MATH_ABSS(type_name, type, min)                            \
    static ossl_inline ossl_unused type                                      \
    safe_abs_ ## type_name(type a, int *err)                                 \
    {                                                                        \
        if (a == min) {                                                      \
            *err |= 1;                                                       \
            return min;                                                      \
        }                                                                    \
        return a < 0 ? -a : a;                                               \
    }
  /* Unsigned |a| is a itself, so this never reports an error */
# define OSSL_SAFE_MATH_ABSU(type_name, type)                                 \
    static ossl_inline ossl_unused type                                      \
    safe_abs_ ## type_name(type a, int *err)                                 \
    {                                                                        \
        (void)err; /* kept for a uniform signature; silences -Wunused-parameter */ \
        return a;                                                            \
    }
  /*
   * Safe fused multiply-divide helpers: compute a * b / c.
   *
   * These are a bit obscure:
   * . They begin by checking the denominator for zero, disposing of that
   *   corner case.
   *
   * . Next they attempt the multiplication directly.  If it does not
   *   overflow, the product is divided by c and that quotient is returned
   *   (for signed values this division has a potential problem, e.g.
   *   min / -1, which is not present for unsigned values).
   *
   * . Otherwise the expression is rearranged so that the larger of the two
   *   numerator factors is divided by c first.  This requires a remainder
   *   correction:
   *
   *       a * b / c = (a / c) * b + (a % c) * b / c,  where a is the larger
   *
   *   Each individual operation must be overflow checked (again, signed
   *   values are more problematic).
   *
   * The algorithm is not perfect, but it should be "good enough".
   */
  /*
   * Signed a * b / c.  A zero c sets *err and returns 0 if a or b is 0,
   * otherwise max.  When a * b overflows, the factor of larger magnitude is
   * divided by c first and the remainder is corrected for; every step is
   * overflow checked and accumulates into *err.
   */
# define OSSL_SAFE_MATH_MULDIVS(type_name, type, max)                         \
    static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a,    \
                                                                  type b,    \
                                                                  type c,    \
                                                                  int *err)  \
    {                                                                        \
        int e2 = 0;                                                          \
        type q, r, x, y;                                                     \
                                                                             \
        if (c == 0) {                                                        \
            *err |= 1;                                                       \
            return a == 0 || b == 0 ? 0 : max;                               \
        }                                                                    \
        x = safe_mul_ ## type_name(a, b, &e2);                               \
        if (!e2)                                                             \
            return safe_div_ ## type_name(x, c, err);                        \
        /*                                                                   \
         * Make a the factor of larger magnitude.  Ordering by value would   \
         * leave a large negative factor undivided and cause spurious        \
         * overflows in the products below.  Magnitudes are compared as      \
         * non-positive numbers because negating a positive value cannot     \
         * overflow, whereas negating min can.                               \
         */                                                                  \
        if ((b > 0 ? -b : b) < (a > 0 ? -a : a)) {                           \
            x = b;                                                           \
            b = a;                                                           \
            a = x;                                                           \
        }                                                                    \
        q = safe_div_ ## type_name(a, c, err);                               \
        r = safe_mod_ ## type_name(a, c, err);                               \
        x = safe_mul_ ## type_name(r, b, err);                               \
        y = safe_mul_ ## type_name(q, b, err);                               \
        q = safe_div_ ## type_name(x, c, err);                               \
        return safe_add_ ## type_name(y, q, err);                            \
    }
  /*
   * Unsigned a * b / c, avoiding overflow of the intermediate product.
   * A zero c sets *err and returns 0 if a or b is 0, otherwise max.
   */
# define OSSL_SAFE_MATH_MULDIVU(type_name, type, max)                         \
    static ossl_inline ossl_unused type                                      \
    safe_muldiv_ ## type_name(type a, type b, type c, int *err)              \
    {                                                                        \
        int mul_err = 0;                                                     \
        type prod, hi, lo, whole;                                            \
                                                                             \
        if (c == 0) {                                                        \
            *err |= 1;                                                       \
            return a == 0 || b == 0 ? 0 : max;                               \
        }                                                                    \
        /* Direct route: exact whenever a * b fits */                        \
        prod = safe_mul_ ## type_name(a, b, &mul_err);                       \
        if (mul_err == 0)                                                    \
            return prod / c;                                                 \
        /* Split the larger factor: hi * lo / c = (hi / c) * lo + (hi % c) * lo / c */ \
        hi = a < b ? b : a;                                                  \
        lo = a < b ? a : b;                                                  \
        prod = safe_mul_ ## type_name(hi % c, lo, err);                      \
        whole = safe_mul_ ## type_name(hi / c, lo, err);                     \
        return safe_add_ ## type_name(whole, prod / c, err);                 \
    }
  /*
   * Calculate a / b rounded up, i.e. towards positive infinity (the ceiling
   * of the exact quotient).  For non-negative operands this is the familiar
   * (a + b - 1) / b, computed here without that expression's risk of
   * overflow.  C division truncates towards zero, so the truncated quotient
   * is incremented only when the division is inexact AND the exact quotient
   * is positive: for a negative quotient truncation has already rounded up
   * (e.g. -7 / 2 gives -3, the ceiling of -3.5).
   *
   * errp may be NULL, in which case errors are dropped.  Errors arise from a
   * zero divisor (the result is then 0 when a == 0, otherwise max) and, for
   * signed types, from min / -1.  If you *know* that b != 0 and, for signed
   * types, that (a, b) != (min, -1), then it's safe to ignore errors.
   */
# define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max)                    \
    static ossl_inline ossl_unused type safe_div_round_up_ ## type_name      \
        (type a, type b, int *errp)                                          \
    {                                                                        \
        type q, r;                                                           \
        int *err, err_local = 0;                                             \
                                                                             \
        /* Allow errors to be ignored by callers */                          \
        err = errp != NULL ? errp : &err_local;                              \
        /* Fast path, both positive */                                       \
        if (b > 0 && a > 0) {                                                \
            /* Faster path: a + b cannot overflow */                         \
            if (a < max - b)                                                 \
                return (a + b - 1) / b;                                      \
            return a / b + (a % b != 0);                                     \
        }                                                                    \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return a == 0 ? 0 : max;                                         \
        }                                                                    \
        if (a == 0)                                                          \
            return 0;                                                        \
        /* Slow path: a negative operand is involved (signed types only) */  \
        r = safe_mod_ ## type_name(a, b, err);                               \
        q = safe_div_ ## type_name(a, b, err);                               \
        /* Exact, or operands of opposite sign: q is already the ceiling */  \
        if (r == 0 || (a > 0) != (b > 0))                                    \
            return q;                                                        \
        return safe_add_ ## type_name(q, 1, err);                            \
    }
  /*
   * Calculate ranges of types.  N below is the width of type in bits,
   * assuming 8-bit bytes.
   *
   * The signed maximum is assembled as (2^(N-2) - 1) * 2 + 1 so that no
   * intermediate value overflows: deriving it from 1 << (N - 1) would shift
   * a 1 into the sign bit, which is undefined behaviour for signed types.
   * The signed minimum, -max - 1, assumes two's complement.  The outer casts
   * keep the results correct, and of the right type, for types narrower than
   * int, whose operands are promoted to int before << and ~ are applied.
   */
# define OSSL_SAFE_MATH_MAXS(type)                                            \
    ((type)((((type)1 << (sizeof(type) * 8 - 2)) - 1) * 2 + 1))
# define OSSL_SAFE_MATH_MINS(type) ((type)(-OSSL_SAFE_MATH_MAXS(type) - 1))
# define OSSL_SAFE_MATH_MAXU(type) ((type)~(type)0)
  384. /*
  385. * Wrapper macros to create all the functions of a given type
  386. */
  /* Instantiate the complete safe_*_<type_name> family for a signed type */
# define OSSL_SAFE_MATH_SIGNED(type_name, type)                               \
    OSSL_SAFE_MATH_ADDS(type_name, type, OSSL_SAFE_MATH_MINS(type),          \
                        OSSL_SAFE_MATH_MAXS(type))                           \
    OSSL_SAFE_MATH_SUBS(type_name, type, OSSL_SAFE_MATH_MINS(type),          \
                        OSSL_SAFE_MATH_MAXS(type))                           \
    OSSL_SAFE_MATH_MULS(type_name, type, OSSL_SAFE_MATH_MINS(type),          \
                        OSSL_SAFE_MATH_MAXS(type))                           \
    OSSL_SAFE_MATH_DIVS(type_name, type, OSSL_SAFE_MATH_MINS(type),          \
                        OSSL_SAFE_MATH_MAXS(type))                           \
    OSSL_SAFE_MATH_MODS(type_name, type, OSSL_SAFE_MATH_MINS(type),          \
                        OSSL_SAFE_MATH_MAXS(type))                           \
    OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type))          \
    OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type))          \
    /* The composite helpers call the primitives above, so they come last */ \
    OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type))       \
    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, OSSL_SAFE_MATH_MAXS(type))
  /* Instantiate the complete safe_*_<type_name> family for an unsigned type */
# define OSSL_SAFE_MATH_UNSIGNED(type_name, type)                             \
    OSSL_SAFE_MATH_ADDU(type_name, type, OSSL_SAFE_MATH_MAXU(type))          \
    OSSL_SAFE_MATH_SUBU(type_name, type)                                     \
    OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type))          \
    OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))          \
    OSSL_SAFE_MATH_MODU(type_name, type)                                     \
    OSSL_SAFE_MATH_NEGU(type_name, type)                                     \
    OSSL_SAFE_MATH_ABSU(type_name, type)                                     \
    /* The composite helpers call the primitives above, so they come last */ \
    OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))       \
    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, OSSL_SAFE_MATH_MAXU(type))
  414. #endif /* OSSL_INTERNAL_SAFE_MATH_H */