| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445 | /* * Copyright 2021-2022 The OpenSSL Project Authors. All Rights Reserved. * * Licensed under the Apache License 2.0 (the "License").  You may not use * this file except in compliance with the License.  You can obtain a copy * in the file LICENSE in the source distribution or at * https://www.openssl.org/source/license.html */#ifndef OSSL_INTERNAL_SAFE_MATH_H# define OSSL_INTERNAL_SAFE_MATH_H# pragma once# include <openssl/e_os2.h>              /* For 'ossl_inline' */# ifndef OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING#  ifdef __has_builtin#   define has(func) __has_builtin(func)#  elif defined(__GNUC__)#   if __GNUC__ > 5#    define has(func) 1#   endif#  endif# endif /* OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING */# ifndef has#  define has(func) 0# endif/* * Safe addition helpers */# if has(__builtin_add_overflow)#  define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \    static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        type r;                                                              \                                                                             \        if (!__builtin_add_overflow(a, b, &r))                               \            return r;                                                        \        *err |= 1;                                                           \        return a < 0 ? min : max;                                            \    }#  define OSSL_SAFE_MATH_ADDU(type_name, type, max) \    static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        type r;                                                              \                                                                             \        if (!__builtin_add_overflow(a, b, &r))                               \            return r;                                                        \        *err |= 1;                                                           \        return a + b;                                                            \    }# else  /* has(__builtin_add_overflow) */#  define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \    static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        if ((a < 0) ^ (b < 0)                                                \                || (a > 0 && b <= max - a)                                   \                || (a < 0 && b >= min - a)                                   \                || a == 0)                                                   \            return a + b;                                                    \        *err |= 1;                                                           \        return a < 0 ? min : max;                                            \    }#  define OSSL_SAFE_MATH_ADDU(type_name, type, max) \    static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        if (b > max - a)                                                     \            *err |= 1;                                                       \        return a + b;                                                        \    }# endif /* has(__builtin_add_overflow) *//* * Safe subtraction helpers */# if has(__builtin_sub_overflow)#  define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \    static ossl_inline ossl_unused type safe_sub_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        type r;                                                              \                                                                             \        if (!__builtin_sub_overflow(a, b, &r))                               \            return r;                                                        \        *err |= 1;                                                           \        return a < 0 ? min : max;                                            \    }# else  /* has(__builtin_sub_overflow) */#  define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \    static ossl_inline ossl_unused type safe_sub_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        if (!((a < 0) ^ (b < 0))                                             \                || (b > 0 && a >= min + b)                                   \                || (b < 0 && a <= max + b)                                   \                || b == 0)                                                   \            return a - b;                                                    \        *err |= 1;                                                           \        return a < 0 ? min : max;                                            \    }# endif /* has(__builtin_sub_overflow) */# define OSSL_SAFE_MATH_SUBU(type_name, type) \    static ossl_inline ossl_unused type safe_sub_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        if (b > a)                                                           \            *err |= 1;                                                       \        return a - b;                                                        \    }/* * Safe multiplication helpers */# if has(__builtin_mul_overflow)#  define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \    static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        type r;                                                              \                                                                             \        if (!__builtin_mul_overflow(a, b, &r))                               \            return r;                                                        \        *err |= 1;                                                           \        return (a < 0) ^ (b < 0) ? min : max;                                \    }#  define OSSL_SAFE_MATH_MULU(type_name, type, max) \    static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        type r;                                                              \                                                                             \        if (!__builtin_mul_overflow(a, b, &r))                               \            return r;                                                        \        *err |= 1;                                                           \        return a * b;                                                          \    }# else  /* has(__builtin_mul_overflow) */#  define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \    static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        if (a == 0 || b == 0)                                                \            return 0;                                                        \        if (a == 1)                                                          \            return b;                                                        \        if (b == 1)                                                          \            return a;                                                        \        if (a != min && b != min) {                                          \            const type x = a < 0 ? -a : a;                                   \            const type y = b < 0 ? -b : b;                                   \                                                                             \            if (x <= max / y)                                                \                return a * b;                                                \        }                                                                    \        *err |= 1;                                                           \        return (a < 0) ^ (b < 0) ? min : max;                                \    }#  define OSSL_SAFE_MATH_MULU(type_name, type, max) \    static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        if (b != 0 && a > max / b)                                           \            *err |= 1;                                                       \        return a * b;                                                        \    }# endif /* has(__builtin_mul_overflow) *//* * Safe division helpers */# define OSSL_SAFE_MATH_DIVS(type_name, type, min, max) \    static ossl_inline ossl_unused type safe_div_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        if (b == 0) {                                                        \            *err |= 1;                                                       \            return a < 0 ? min : max;                                        \        }                                                                    \        if (b == -1 && a == min) {                                           \            *err |= 1;                                                       \            return max;                                                      \        }                                                                    \        return a / b;                                                        \    }# define OSSL_SAFE_MATH_DIVU(type_name, type, max) \    static ossl_inline ossl_unused type safe_div_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        if (b != 0)                                                          \            return a / b;                                                    \        *err |= 1;                                                           \        return max;                                                        \    }/* * Safe modulus helpers */# define OSSL_SAFE_MATH_MODS(type_name, type, min, max) \    static ossl_inline ossl_unused type safe_mod_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        if (b == 0) {                                                        \            *err |= 1;                                                       \            return 0;                                                        \        }                                                                    \        if (b == -1 && a == min) {                                           \            *err |= 1;                                                       \            return max;                                                      \        }                                                                    \        return a % b;                                                        \    }# define OSSL_SAFE_MATH_MODU(type_name, type) \    static ossl_inline ossl_unused type safe_mod_ ## type_name(type a,       \                                                               type b,       \                                                               int *err)     \    {                                                                        \        if (b != 0)                                                          \            return a % b;                                                    \        *err |= 1;                                                           \        return 0;                                                            \    }/* * Safe negation helpers */# define OSSL_SAFE_MATH_NEGS(type_name, type, min) \    static ossl_inline ossl_unused type safe_neg_ ## type_name(type a,       \                                                               int *err)     \    {                                                                        \        if (a != min)                                                        \            return -a;                                                       \        *err |= 1;                                                           \        return min;                                                          \    }# define OSSL_SAFE_MATH_NEGU(type_name, type) \    static ossl_inline ossl_unused type safe_neg_ ## type_name(type a,       \                                                               int *err)     \    {                                                                        \        if (a == 0)                                                          \            return a;                                                        \        *err |= 1;                                                           \        return 1 + ~a;                                                       \    }/* * Safe absolute value helpers */# define OSSL_SAFE_MATH_ABSS(type_name, type, min) \    static ossl_inline ossl_unused type safe_abs_ ## type_name(type a,       \                                                               int *err)     \    {                                                                        \        if (a != min)                                                        \            return a < 0 ? -a : a;                                           \        *err |= 1;                                                           \        return min;                                                          \    }# define OSSL_SAFE_MATH_ABSU(type_name, type) \    static ossl_inline ossl_unused type safe_abs_ ## type_name(type a,       \                                                               int *err)     \    {                                                                        \        return a;                                                            \    }/* * Safe fused multiply divide helpers * * These are a bit obscure: *    . They begin by checking the denominator for zero and getting rid of this *      corner case. * *    . Second is an attempt to do the multiplication directly, if it doesn't *      overflow, the quotient is returned (for signed values there is a *      potential problem here which isn't present for unsigned). * *    . Finally, the multiplication/division is transformed so that the larger *      of the numerators is divided first.  This requires a remainder *      correction: * *          a b / c = (a / c) b + (a mod c) b / c, where a > b * *      The individual operations need to be overflow checked (again signed *      being more problematic). * * The algorithm used is not perfect but it should be "good enough". */# define OSSL_SAFE_MATH_MULDIVS(type_name, type, max) \    static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a,    \                                                                  type b,    \                                                                  type c,    \                                                                  int *err)  \    {                                                                        \        int e2 = 0;                                                          \        type q, r, x, y;                                                     \                                                                             \        if (c == 0) {                                                        \            *err |= 1;                                                       \            return a == 0 || b == 0 ? 0 : max;                               \        }                                                                    \        x = safe_mul_ ## type_name(a, b, &e2);                               \        if (!e2)                                                             \            return safe_div_ ## type_name(x, c, err);                        \        if (b > a) {                                                         \            x = b;                                                           \            b = a;                                                           \            a = x;                                                           \        }                                                                    \        q = safe_div_ ## type_name(a, c, err);                               \        r = safe_mod_ ## type_name(a, c, err);                               \        x = safe_mul_ ## type_name(r, b, err);                               \        y = safe_mul_ ## type_name(q, b, err);                               \        q = safe_div_ ## type_name(x, c, err);                               \        return safe_add_ ## type_name(y, q, err);                            \    }# define OSSL_SAFE_MATH_MULDIVU(type_name, type, max) \    static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a,    \                                                                  type b,    \                                                                  type c,    \                                                                  int *err)  \    {                                                                        \        int e2 = 0;                                                          \        type x, y;                                                           \                                                                             \        if (c == 0) {                                                        \            *err |= 1;                                                       \            return a == 0 || b == 0 ? 0 : max;                               \        }                                                                    \        x = safe_mul_ ## type_name(a, b, &e2);                               \        if (!e2)                                                             \            return x / c;                                                    \        if (b > a) {                                                         \            x = b;                                                           \            b = a;                                                           \            a = x;                                                           \        }                                                                    \        x = safe_mul_ ## type_name(a % c, b, err);                           \        y = safe_mul_ ## type_name(a / c, b, err);                           \        return safe_add_ ## type_name(y, x / c, err);                        \    }/* * Calculate a / b rounding up: *     i.e. a / b + (a % b != 0) * Which is usually (less safely) converted to (a + b - 1) / b * If you *know* that b != 0, then it's safe to ignore err. */#define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max) \    static ossl_inline ossl_unused type safe_div_round_up_ ## type_name      \        (type a, type b, int *errp)                                          \    {                                                                        \        type x;                                                              \        int *err, err_local = 0;                                             \                                                                             \        /* Allow errors to be ignored by callers */                          \        err = errp != NULL ? errp : &err_local;                              \        /* Fast path, both positive */                                       \        if (b > 0 && a > 0) {                                                \            /* Faster path: no overflow concerns */                          \            if (a < max - b)                                                 \                return (a + b - 1) / b;                                      \            return a / b + (a % b != 0);                                     \        }                                                                    \        if (b == 0) {                                                        \            *err |= 1;                                                       \            return a == 0 ? 0 : max;                                         \        }                                                                    \        if (a == 0)                                                          \            return 0;                                                        \        /* Rather slow path because there are negatives involved */          \        x = safe_mod_ ## type_name(a, b, err);                               \        return safe_add_ ## type_name(safe_div_ ## type_name(a, b, err),     \                                      x != 0, err);                          \    }/* Calculate ranges of types */# define OSSL_SAFE_MATH_MINS(type) ((type)1 << (sizeof(type) * 8 - 1))# define OSSL_SAFE_MATH_MAXS(type) (~OSSL_SAFE_MATH_MINS(type))# define OSSL_SAFE_MATH_MAXU(type) (~(type)0)/* * Wrapper macros to create all the functions of a given type */# define OSSL_SAFE_MATH_SIGNED(type_name, type)                         \    OSSL_SAFE_MATH_ADDS(type_name, type, OSSL_SAFE_MATH_MINS(type),     \                        OSSL_SAFE_MATH_MAXS(type))                      \    OSSL_SAFE_MATH_SUBS(type_name, type, OSSL_SAFE_MATH_MINS(type),     \                        OSSL_SAFE_MATH_MAXS(type))                      \    OSSL_SAFE_MATH_MULS(type_name, type, OSSL_SAFE_MATH_MINS(type),     \                        OSSL_SAFE_MATH_MAXS(type))                      \    OSSL_SAFE_MATH_DIVS(type_name, type, OSSL_SAFE_MATH_MINS(type),     \                        OSSL_SAFE_MATH_MAXS(type))                      \    OSSL_SAFE_MATH_MODS(type_name, type, OSSL_SAFE_MATH_MINS(type),     \                        OSSL_SAFE_MATH_MAXS(type))                      \    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type,                        \                                OSSL_SAFE_MATH_MAXS(type))              \    OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type))  \    OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type))     \    OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type))# define OSSL_SAFE_MATH_UNSIGNED(type_name, type) \    OSSL_SAFE_MATH_ADDU(type_name, type, OSSL_SAFE_MATH_MAXU(type))     \    OSSL_SAFE_MATH_SUBU(type_name, type)                                \    OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type))     \    OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))     \    OSSL_SAFE_MATH_MODU(type_name, type)                                \    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type,                        \                                OSSL_SAFE_MATH_MAXU(type))              \    OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))  \    OSSL_SAFE_MATH_NEGU(type_name, type)                                \    OSSL_SAFE_MATH_ABSU(type_name, type)#endif                          /* OSSL_INTERNAL_SAFE_MATH_H */
 |