diff --git a/library/ssl_msg.c b/library/ssl_msg.c index b4e4aea3d..075345d36 100644 --- a/library/ssl_msg.c +++ b/library/ssl_msg.c @@ -1044,6 +1044,25 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl, } #if defined(MBEDTLS_SSL_SOME_SUITES_USE_TLS_CBC) +/* + * Turn a bit into a mask: + * - if bit == 1, return the all-bits 1 mask, aka (size_t) -1 + * - if bit == 0, return the all-bits 0 mask, aka 0 + */ +static size_t mbedtls_ssl_cf_mask_from_bit( size_t bit ) +{ + /* MSVC has a warning about unary minus on unsigned integer types, + * but this is well-defined and precisely what we want to do here. */ +#if defined(_MSC_VER) +#pragma warning( push ) +#pragma warning( disable : 4146 ) +#endif + return -bit; +#if defined(_MSC_VER) +#pragma warning( pop ) +#endif +} + /* * Constant-flow mask generation for "less than" comparison: * - if x < y, return all bits 1, that is (size_t) -1 @@ -1052,7 +1071,7 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl, * Use only bit operations to avoid branches that could be used by some * compilers on some platforms to translate comparison operators. */ -static size_t mbedtls_ssl_cf_mask_lt(size_t x, size_t y) +static size_t mbedtls_ssl_cf_mask_lt( size_t x, size_t y ) { /* This has the msb set if and only if x < y */ const size_t sub = x - y; @@ -1060,17 +1079,8 @@ static size_t mbedtls_ssl_cf_mask_lt(size_t x, size_t y) /* sub1 = (x < y) in {0, 1} */ const size_t sub1 = sub >> ( sizeof( sub ) * 8 - 1 ); - /* MSVC has a warning about unary minus on unsigned integer types, - * but this is well-defined and precisely what we want to do here. */ -#if defined(_MSC_VER) -#pragma warning( push ) -#pragma warning( disable : 4146 ) -#endif /* mask = (x < y) ? 0xff... : 0x00... */ - const size_t mask = -sub1; -#if defined(_MSC_VER) -#pragma warning( pop ) -#endif + const size_t mask = mbedtls_ssl_cf_mask_from_bit( sub1 ); return( mask ); } @@ -1083,9 +1093,9 @@ static size_t mbedtls_ssl_cf_mask_lt(size_t x, size_t y) * Use only bit operations to avoid branches that could be used by some * compilers on some platforms to translate comparison operators. */ -static size_t mbedtls_ssl_cf_mask_ge(size_t x, size_t y) +static size_t mbedtls_ssl_cf_mask_ge( size_t x, size_t y ) { - return( ~mbedtls_ssl_cf_mask_lt(x, y) ); + return( ~mbedtls_ssl_cf_mask_lt( x, y ) ); } /* @@ -1095,7 +1105,7 @@ static size_t mbedtls_ssl_cf_mask_ge(size_t x, size_t y) * Use only bit operations to avoid branches that could be used by some * compilers on some platforms to translate comparison operators. */ -static size_t mbedtls_ssl_cf_bool_eq(size_t x, size_t y) +static size_t mbedtls_ssl_cf_bool_eq( size_t x, size_t y ) { /* diff = 0 if x == y, non-zero otherwise */ const size_t diff = x ^ y; @@ -1134,32 +1144,13 @@ static void mbedtls_ssl_cf_memcpy_if_eq( unsigned char *dst, size_t len, size_t c1, size_t c2 ) { - /* diff = 0 if c1 == c2, non-zero otherwise */ - const size_t diff = c1 ^ c2; - - /* MSVC has a warning about unary minus on unsigned integer types, - * but this is well-defined and precisely what we want to do here. */ -#if defined(_MSC_VER) -#pragma warning( push ) -#pragma warning( disable : 4146 ) -#endif - - /* diff_msb's most significant bit is equal to c1 != c2 */ - const size_t diff_msb = ( diff | -diff ); - - /* diff1 = (c1 != c2) in {0, 1} */ - const size_t diff1 = diff_msb >> ( sizeof( diff_msb ) * 8 - 1 ); - - /* mask = c1 != c2 ? 0xff : 0x00 */ - const unsigned char mask = (unsigned char) -diff1; - -#if defined(_MSC_VER) -#pragma warning( pop ) -#endif + /* mask = c1 == c2 ? 0xff : 0x00 */ + const size_t equal = mbedtls_ssl_cf_bool_eq( c1, c2 ); + const unsigned char mask = mbedtls_ssl_cf_mask_from_bit( equal ); /* dst[i] = c1 != c2 ? dst[i] : src[i] */ for( size_t i = 0; i < len; i++ ) - dst[i] = ( dst[i] & mask ) | ( src[i] & ~mask ); + dst[i] = ( dst[i] & ~mask ) | ( src[i] & mask ); } /*