diff --git a/include/mbedtls/bignum.h b/include/mbedtls/bignum.h index 3f6cdd1f9..d4aedfc39 100644 --- a/include/mbedtls/bignum.h +++ b/include/mbedtls/bignum.h @@ -595,23 +595,22 @@ int mbedtls_mpi_cmp_abs( const mbedtls_mpi *X, const mbedtls_mpi *Y ); int mbedtls_mpi_cmp_mpi( const mbedtls_mpi *X, const mbedtls_mpi *Y ); /** - * \brief Compare two MPIs in constant time. + * \brief Check if an MPI is less than the other in constant time. * * \param X The left-hand MPI. This must point to an initialized MPI * with the same allocated length as Y. * \param Y The right-hand MPI. This must point to an initialized MPI * with the same allocated length as X. * \param ret The result of the comparison: - * \c 1 if \p X is greater than \p Y. - * \c -1 if \p X is lesser than \p Y. - * \c 0 if \p X is equal to \p Y. + * \c 1 if \p X is less than \p Y. + * \c 0 if \p X is greater than or equal to \p Y. * * \return 0 on success. * \return MBEDTLS_ERR_MPI_BAD_INPUT_DATA if the allocated length of * the two input MPIs is not the same. */ -int mbedtls_mpi_cmp_mpi_ct( const mbedtls_mpi *X, const mbedtls_mpi *Y, - int *ret ); +int mbedtls_mpi_lt_mpi_ct( const mbedtls_mpi *X, const mbedtls_mpi *Y, + unsigned *ret ); /** * \brief Compare an MPI with an integer. diff --git a/library/bignum.c b/library/bignum.c index b90404512..65696470d 100644 --- a/library/bignum.c +++ b/library/bignum.c @@ -1148,7 +1148,8 @@ int mbedtls_mpi_cmp_mpi( const mbedtls_mpi *X, const mbedtls_mpi *Y ) return( 0 ); } -static int ct_lt_mpi_uint( const mbedtls_mpi_uint x, const mbedtls_mpi_uint y ) +static unsigned ct_lt_mpi_uint( const mbedtls_mpi_uint x, + const mbedtls_mpi_uint y ) { mbedtls_mpi_uint ret; mbedtls_mpi_uint cond; @@ -1175,16 +1176,11 @@ static int ct_lt_mpi_uint( const mbedtls_mpi_uint x, const mbedtls_mpi_uint y ) return ret; } -static int ct_bool_get_mask( unsigned int b ) -{ - return ~( b - 1 ); -} - /* * Compare signed values in constant time */ -int mbedtls_mpi_cmp_mpi_ct( const mbedtls_mpi *X, const mbedtls_mpi *Y, - int *ret ) +int mbedtls_mpi_lt_mpi_ct( const mbedtls_mpi *X, const mbedtls_mpi *Y, + unsigned *ret ) { size_t i; unsigned int cond, done, sign_X, sign_Y; @@ -1197,45 +1193,49 @@ int mbedtls_mpi_cmp_mpi_ct( const mbedtls_mpi *X, const mbedtls_mpi *Y, return MBEDTLS_ERR_MPI_BAD_INPUT_DATA; /* - * if( X->s > 0 && Y->s < 0 ) - * { - * *ret = 1; - * done = 1; - * } - * else if( Y->s > 0 && X->s < 0 ) - * { - * *ret = -1; - * done = 1; - * } + * Get sign bits of the signs. */ sign_X = X->s; + sign_X = sign_X >> ( sizeof( unsigned int ) * 8 - 1 ); sign_Y = Y->s; - cond = ( ( sign_X ^ sign_Y ) >> ( sizeof( unsigned int ) * 8 - 1 ) ); - *ret = ct_bool_get_mask( cond ) & X->s; + sign_Y = sign_Y >> ( sizeof( unsigned int ) * 8 - 1 ); + + /* + * If the signs are different, then the positive operand is the bigger. + * That is if X is negative (sign bit 1), then X < Y is true and it is false + * if X is positive (sign bit 0). + */ + cond = ( sign_X ^ sign_Y ); + *ret = cond & sign_X; + + /* + * This is a constant time function, we might have the result, but we still + * need to go through the loop. Record if we have the result already. + */ done = cond; for( i = X->n; i > 0; i-- ) { /* - * if( ( X->p[i - 1] > Y->p[i - 1] ) && !done ) - * { - * done = 1; - * *ret = X->s; - * } + * If Y->p[i - 1] < X->p[i - 1] and both X and Y are negative, then + * X < Y. + * + * Again even if we can make a decision, we just mark the result and + * the fact that we are done and continue looping. */ - cond = ct_lt_mpi_uint( Y->p[i - 1], X->p[i - 1] ); - *ret |= ct_bool_get_mask( cond & ( 1 - done ) ) & X->s; + cond = ct_lt_mpi_uint( Y->p[i - 1], X->p[i - 1] ) & sign_X; + *ret |= cond & ( 1 - done ); done |= cond & ( 1 - done ); /* - * if( ( X->p[i - 1] < Y->p[i - 1] ) && !done ) - * { - * done = 1; - * *ret = -X->s; - * } + * If X->p[i - 1] < Y->p[i - 1] and both X and Y are positive, then + * X < Y. + * + * Again even if we can make a decision, we just mark the result and + * the fact that we are done and continue looping. */ - cond = ct_lt_mpi_uint( X->p[i - 1], Y->p[i - 1] ); - *ret |= ct_bool_get_mask( cond & ( 1 - done ) ) & -X->s; + cond = ct_lt_mpi_uint( X->p[i - 1], Y->p[i - 1] ) & ( 1 - sign_X ); + *ret |= cond & ( 1 - done ); done |= cond & ( 1 - done ); } diff --git a/library/ecp.c b/library/ecp.c index b0ef3ca47..a58e8a6e0 100644 --- a/library/ecp.c +++ b/library/ecp.c @@ -2803,7 +2803,7 @@ int mbedtls_ecp_gen_privkey( const mbedtls_ecp_group *grp, { /* SEC1 3.2.1: Generate d such that 1 <= n < N */ int count = 0; - int cmp = 0; + unsigned cmp = 0; /* * Match the procedure given in RFC 6979 (deterministic ECDSA): @@ -2829,13 +2829,13 @@ int mbedtls_ecp_gen_privkey( const mbedtls_ecp_group *grp, if( ++count > 30 ) return( MBEDTLS_ERR_ECP_RANDOM_FAILED ); - ret = mbedtls_mpi_cmp_mpi_ct( d, &grp->N, &cmp ); + ret = mbedtls_mpi_lt_mpi_ct( d, &grp->N, &cmp ); if( ret != 0 ) { goto cleanup; } } - while( mbedtls_mpi_cmp_int( d, 1 ) < 0 || cmp >= 0 ); + while( mbedtls_mpi_cmp_int( d, 1 ) < 0 || cmp != 1 ); } #endif /* ECP_SHORTWEIERSTRASS */ diff --git a/tests/suites/test_suite_mpi.data b/tests/suites/test_suite_mpi.data index efcb06041..89aa4d51f 100644 --- a/tests/suites/test_suite_mpi.data +++ b/tests/suites/test_suite_mpi.data @@ -175,38 +175,38 @@ mbedtls_mpi_cmp_mpi:10:"2":10:"-3":1 Base test mbedtls_mpi_cmp_mpi (Mixed values) #6 mbedtls_mpi_cmp_mpi:10:"-2":10:"31231231289798":-1 -Base test mbedtls_mpi_cmp_mpi_ct #1 -mbedtls_mpi_cmp_mpi_ct:1:10:"693":1:10:"693":0:0 +Base test mbedtls_mpi_lt_mpi_ct #1 +mbedtls_mpi_lt_mpi_ct:1:10:"693":1:10:"693":0:0 -Base test mbedtls_mpi_cmp_mpi_ct #2 -mbedtls_mpi_cmp_mpi_ct:1:10:"693":1:10:"692":1:0 +Base test mbedtls_mpi_lt_mpi_ct #2 +mbedtls_mpi_lt_mpi_ct:1:10:"693":1:10:"692":0:0 -Base test mbedtls_mpi_cmp_mpi_ct #3 -mbedtls_mpi_cmp_mpi_ct:1:10:"693":1:10:"694":-1:0 +Base test mbedtls_mpi_lt_mpi_ct #3 +mbedtls_mpi_lt_mpi_ct:1:10:"693":1:10:"694":1:0 -Base test mbedtls_mpi_cmp_mpi_ct (Negative values) #1 -mbedtls_mpi_cmp_mpi_ct:1:10:"-2":1:10:"-2":0:0 +Base test mbedtls_mpi_lt_mpi_ct (Negative values) #1 +mbedtls_mpi_lt_mpi_ct:1:10:"-2":1:10:"-2":0:0 -Base test mbedtls_mpi_cmp_mpi_ct (Negative values) #2 -mbedtls_mpi_cmp_mpi_ct:1:10:"-2":1:10:"-3":1:0 +Base test mbedtls_mpi_lt_mpi_ct (Negative values) #2 +mbedtls_mpi_lt_mpi_ct:1:10:"-2":1:10:"-3":0:0 -Base test mbedtls_mpi_cmp_mpi_ct (Negative values) #3 -mbedtls_mpi_cmp_mpi_ct:1:10:"-2":1:10:"-1":-1:0 +Base test mbedtls_mpi_lt_mpi_ct (Negative values) #3 +mbedtls_mpi_lt_mpi_ct:1:10:"-2":1:10:"-1":1:0 -Base test mbedtls_mpi_cmp_mpi_ct (Mixed values) #4 -mbedtls_mpi_cmp_mpi_ct:1:10:"-3":1:10:"2":-1:0 +Base test mbedtls_mpi_lt_mpi_ct (Mixed values) #4 +mbedtls_mpi_lt_mpi_ct:1:10:"-3":1:10:"2":1:0 -Base test mbedtls_mpi_cmp_mpi_ct (Mixed values) #5 -mbedtls_mpi_cmp_mpi_ct:1:10:"2":1:10:"-3":1:0 +Base test mbedtls_mpi_lt_mpi_ct (Mixed values) #5 +mbedtls_mpi_lt_mpi_ct:1:10:"2":1:10:"-3":0:0 -Base test mbedtls_mpi_cmp_mpi_ct (Mixed values) #6 -mbedtls_mpi_cmp_mpi_ct:2:10:"-2":2:10:"31231231289798":-1:0 +Base test mbedtls_mpi_lt_mpi_ct (Mixed values) #6 +mbedtls_mpi_lt_mpi_ct:2:10:"-2":2:10:"31231231289798":1:0 -Base test mbedtls_mpi_cmp_mpi_ct (X is longer in storage) #7 -mbedtls_mpi_cmp_mpi_ct:3:10:"693":2:10:"693":0:MBEDTLS_ERR_MPI_BAD_INPUT_DATA +Base test mbedtls_mpi_lt_mpi_ct (X is longer in storage) #7 +mbedtls_mpi_lt_mpi_ct:3:10:"693":2:10:"693":0:MBEDTLS_ERR_MPI_BAD_INPUT_DATA -Base test mbedtls_mpi_cmp_mpi_ct (Y is longer in storage) #8 -mbedtls_mpi_cmp_mpi_ct:3:10:"693":4:10:"693":0:MBEDTLS_ERR_MPI_BAD_INPUT_DATA +Base test mbedtls_mpi_lt_mpi_ct (Y is longer in storage) #8 +mbedtls_mpi_lt_mpi_ct:3:10:"693":4:10:"693":0:MBEDTLS_ERR_MPI_BAD_INPUT_DATA Base test mbedtls_mpi_cmp_abs #1 mbedtls_mpi_cmp_abs:10:"693":10:"693":0 diff --git a/tests/suites/test_suite_mpi.function b/tests/suites/test_suite_mpi.function index 97fd7b983..617f4615c 100644 --- a/tests/suites/test_suite_mpi.function +++ b/tests/suites/test_suite_mpi.function @@ -588,10 +588,12 @@ exit: /* END_CASE */ /* BEGIN_CASE */ -void mbedtls_mpi_cmp_mpi_ct( int size_X, int radix_X, char * input_X, int size_Y, - int radix_Y, char * input_Y, int input_ret, int input_err ) +void mbedtls_mpi_lt_mpi_ct( int size_X, int radix_X, char * input_X, + int size_Y, int radix_Y, char * input_Y, + int input_ret, int input_err ) { - int ret; + unsigned ret; + unsigned input_uret = input_ret; mbedtls_mpi X, Y; mbedtls_mpi_init( &X ); mbedtls_mpi_init( &Y ); @@ -601,9 +603,9 @@ void mbedtls_mpi_cmp_mpi_ct( int size_X, int radix_X, char * input_X, int size_Y mbedtls_mpi_grow( &X, size_X ); mbedtls_mpi_grow( &Y, size_Y ); - TEST_ASSERT( mbedtls_mpi_cmp_mpi_ct( &X, &Y, &ret ) == input_err ); + TEST_ASSERT( mbedtls_mpi_lt_mpi_ct( &X, &Y, &ret ) == input_err ); if( input_err == 0 ) - TEST_ASSERT( ret == input_ret ); + TEST_ASSERT( ret == input_uret ); exit: mbedtls_mpi_free( &X ); mbedtls_mpi_free( &Y );