Implement parameter validation for MPI module

This commit is contained in:
Hanno Becker 2018-12-11 10:35:51 +00:00
parent c23483ed8c
commit 73d7d79bc1

View file

@ -59,6 +59,11 @@
#define mbedtls_free free #define mbedtls_free free
#endif #endif
#define MPI_VALIDATE_RET( cond ) \
MBEDTLS_INTERNAL_VALIDATE_RET( cond, MBEDTLS_ERR_MPI_BAD_INPUT_DATA )
#define MPI_VALIDATE( cond ) \
MBEDTLS_INTERNAL_VALIDATE( cond )
#define ciL (sizeof(mbedtls_mpi_uint)) /* chars in limb */ #define ciL (sizeof(mbedtls_mpi_uint)) /* chars in limb */
#define biL (ciL << 3) /* bits in limb */ #define biL (ciL << 3) /* bits in limb */
#define biH (ciL << 2) /* half limb size */ #define biH (ciL << 2) /* half limb size */
@ -83,8 +88,7 @@ static void mbedtls_mpi_zeroize( mbedtls_mpi_uint *v, size_t n )
*/ */
void mbedtls_mpi_init( mbedtls_mpi *X ) void mbedtls_mpi_init( mbedtls_mpi *X )
{ {
if( X == NULL ) MPI_VALIDATE( X != NULL );
return;
X->s = 1; X->s = 1;
X->n = 0; X->n = 0;
@ -116,6 +120,7 @@ void mbedtls_mpi_free( mbedtls_mpi *X )
int mbedtls_mpi_grow( mbedtls_mpi *X, size_t nblimbs ) int mbedtls_mpi_grow( mbedtls_mpi *X, size_t nblimbs )
{ {
mbedtls_mpi_uint *p; mbedtls_mpi_uint *p;
MPI_VALIDATE_RET( X != NULL );
if( nblimbs > MBEDTLS_MPI_MAX_LIMBS ) if( nblimbs > MBEDTLS_MPI_MAX_LIMBS )
return( MBEDTLS_ERR_MPI_ALLOC_FAILED ); return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
@ -147,6 +152,10 @@ int mbedtls_mpi_shrink( mbedtls_mpi *X, size_t nblimbs )
{ {
mbedtls_mpi_uint *p; mbedtls_mpi_uint *p;
size_t i; size_t i;
MPI_VALIDATE_RET( X != NULL );
if( nblimbs > MBEDTLS_MPI_MAX_LIMBS )
return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
/* Actually resize up in this case */ /* Actually resize up in this case */
if( X->n <= nblimbs ) if( X->n <= nblimbs )
@ -183,6 +192,8 @@ int mbedtls_mpi_copy( mbedtls_mpi *X, const mbedtls_mpi *Y )
{ {
int ret = 0; int ret = 0;
size_t i; size_t i;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( Y != NULL );
if( X == Y ) if( X == Y )
return( 0 ); return( 0 );
@ -222,6 +233,8 @@ cleanup:
void mbedtls_mpi_swap( mbedtls_mpi *X, mbedtls_mpi *Y ) void mbedtls_mpi_swap( mbedtls_mpi *X, mbedtls_mpi *Y )
{ {
mbedtls_mpi T; mbedtls_mpi T;
MPI_VALIDATE( X != NULL );
MPI_VALIDATE( Y != NULL );
memcpy( &T, X, sizeof( mbedtls_mpi ) ); memcpy( &T, X, sizeof( mbedtls_mpi ) );
memcpy( X, Y, sizeof( mbedtls_mpi ) ); memcpy( X, Y, sizeof( mbedtls_mpi ) );
@ -237,6 +250,8 @@ int mbedtls_mpi_safe_cond_assign( mbedtls_mpi *X, const mbedtls_mpi *Y, unsigned
{ {
int ret = 0; int ret = 0;
size_t i; size_t i;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( Y != NULL );
/* make sure assign is 0 or 1 in a time-constant manner */ /* make sure assign is 0 or 1 in a time-constant manner */
assign = (assign | (unsigned char)-assign) >> 7; assign = (assign | (unsigned char)-assign) >> 7;
@ -266,6 +281,8 @@ int mbedtls_mpi_safe_cond_swap( mbedtls_mpi *X, mbedtls_mpi *Y, unsigned char sw
int ret, s; int ret, s;
size_t i; size_t i;
mbedtls_mpi_uint tmp; mbedtls_mpi_uint tmp;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( Y != NULL );
if( X == Y ) if( X == Y )
return( 0 ); return( 0 );
@ -298,6 +315,7 @@ cleanup:
int mbedtls_mpi_lset( mbedtls_mpi *X, mbedtls_mpi_sint z ) int mbedtls_mpi_lset( mbedtls_mpi *X, mbedtls_mpi_sint z )
{ {
int ret; int ret;
MPI_VALIDATE_RET( X != NULL );
MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, 1 ) ); MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, 1 ) );
memset( X->p, 0, X->n * ciL ); memset( X->p, 0, X->n * ciL );
@ -315,6 +333,8 @@ cleanup:
*/ */
int mbedtls_mpi_get_bit( const mbedtls_mpi *X, size_t pos ) int mbedtls_mpi_get_bit( const mbedtls_mpi *X, size_t pos )
{ {
MPI_VALIDATE_RET( X != NULL );
if( X->n * biL <= pos ) if( X->n * biL <= pos )
return( 0 ); return( 0 );
@ -333,6 +353,7 @@ int mbedtls_mpi_set_bit( mbedtls_mpi *X, size_t pos, unsigned char val )
int ret = 0; int ret = 0;
size_t off = pos / biL; size_t off = pos / biL;
size_t idx = pos % biL; size_t idx = pos % biL;
MPI_VALIDATE_RET( X != NULL );
if( val != 0 && val != 1 ) if( val != 0 && val != 1 )
return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA ); return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
@ -359,6 +380,7 @@ cleanup:
size_t mbedtls_mpi_lsb( const mbedtls_mpi *X ) size_t mbedtls_mpi_lsb( const mbedtls_mpi *X )
{ {
size_t i, j, count = 0; size_t i, j, count = 0;
MPI_VALIDATE_RET( X != NULL );
for( i = 0; i < X->n; i++ ) for( i = 0; i < X->n; i++ )
for( j = 0; j < biL; j++, count++ ) for( j = 0; j < biL; j++, count++ )
@ -439,9 +461,11 @@ int mbedtls_mpi_read_string( mbedtls_mpi *X, int radix, const char *s )
size_t i, j, slen, n; size_t i, j, slen, n;
mbedtls_mpi_uint d; mbedtls_mpi_uint d;
mbedtls_mpi T; mbedtls_mpi T;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( s != NULL );
if( radix < 2 || radix > 16 ) if( radix < 2 || radix > 16 )
return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA ); return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );;
mbedtls_mpi_init( &T ); mbedtls_mpi_init( &T );
@ -539,9 +563,12 @@ int mbedtls_mpi_write_string( const mbedtls_mpi *X, int radix,
size_t n; size_t n;
char *p; char *p;
mbedtls_mpi T; mbedtls_mpi T;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( olen != NULL );
MPI_VALIDATE_RET( buflen == 0 || buf != NULL );
if( radix < 2 || radix > 16 ) if( radix < 2 || radix > 16 )
return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA ); return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );;
n = mbedtls_mpi_bitlen( X ); n = mbedtls_mpi_bitlen( X );
if( radix >= 4 ) n >>= 1; if( radix >= 4 ) n >>= 1;
@ -620,6 +647,12 @@ int mbedtls_mpi_read_file( mbedtls_mpi *X, int radix, FILE *fin )
*/ */
char s[ MBEDTLS_MPI_RW_BUFFER_SIZE ]; char s[ MBEDTLS_MPI_RW_BUFFER_SIZE ];
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( fin != NULL );
if( radix < 2 || radix > 16 )
return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
memset( s, 0, sizeof( s ) ); memset( s, 0, sizeof( s ) );
if( fgets( s, sizeof( s ) - 1, fin ) == NULL ) if( fgets( s, sizeof( s ) - 1, fin ) == NULL )
return( MBEDTLS_ERR_MPI_FILE_IO_ERROR ); return( MBEDTLS_ERR_MPI_FILE_IO_ERROR );
@ -651,6 +684,10 @@ int mbedtls_mpi_write_file( const char *p, const mbedtls_mpi *X, int radix, FILE
* newline characters and '\0' * newline characters and '\0'
*/ */
char s[ MBEDTLS_MPI_RW_BUFFER_SIZE ]; char s[ MBEDTLS_MPI_RW_BUFFER_SIZE ];
MPI_VALIDATE_RET( X != NULL );
if( radix < 2 || radix > 16 )
return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
memset( s, 0, sizeof( s ) ); memset( s, 0, sizeof( s ) );
@ -687,6 +724,9 @@ int mbedtls_mpi_read_binary( mbedtls_mpi *X, const unsigned char *buf, size_t bu
size_t i, j; size_t i, j;
size_t const limbs = CHARS_TO_LIMBS( buflen ); size_t const limbs = CHARS_TO_LIMBS( buflen );
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( buflen == 0 || buf != NULL );
/* Ensure that target MPI has exactly the necessary number of limbs */ /* Ensure that target MPI has exactly the necessary number of limbs */
if( X->n != limbs ) if( X->n != limbs )
{ {
@ -711,11 +751,16 @@ cleanup:
int mbedtls_mpi_write_binary( const mbedtls_mpi *X, int mbedtls_mpi_write_binary( const mbedtls_mpi *X,
unsigned char *buf, size_t buflen ) unsigned char *buf, size_t buflen )
{ {
size_t stored_bytes = X->n * ciL; size_t stored_bytes;
size_t bytes_to_copy; size_t bytes_to_copy;
unsigned char *p; unsigned char *p;
size_t i; size_t i;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( buflen == 0 || buf != NULL );
stored_bytes = X->n * ciL;
if( stored_bytes < buflen ) if( stored_bytes < buflen )
{ {
/* There is enough space in the output buffer. Write initial /* There is enough space in the output buffer. Write initial
@ -754,6 +799,7 @@ int mbedtls_mpi_shift_l( mbedtls_mpi *X, size_t count )
int ret; int ret;
size_t i, v0, t1; size_t i, v0, t1;
mbedtls_mpi_uint r0 = 0, r1; mbedtls_mpi_uint r0 = 0, r1;
MPI_VALIDATE_RET( X != NULL );
v0 = count / (biL ); v0 = count / (biL );
t1 = count & (biL - 1); t1 = count & (biL - 1);
@ -803,6 +849,7 @@ int mbedtls_mpi_shift_r( mbedtls_mpi *X, size_t count )
{ {
size_t i, v0, v1; size_t i, v0, v1;
mbedtls_mpi_uint r0 = 0, r1; mbedtls_mpi_uint r0 = 0, r1;
MPI_VALIDATE_RET( X != NULL );
v0 = count / biL; v0 = count / biL;
v1 = count & (biL - 1); v1 = count & (biL - 1);
@ -845,6 +892,8 @@ int mbedtls_mpi_shift_r( mbedtls_mpi *X, size_t count )
int mbedtls_mpi_cmp_abs( const mbedtls_mpi *X, const mbedtls_mpi *Y ) int mbedtls_mpi_cmp_abs( const mbedtls_mpi *X, const mbedtls_mpi *Y )
{ {
size_t i, j; size_t i, j;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( Y != NULL );
for( i = X->n; i > 0; i-- ) for( i = X->n; i > 0; i-- )
if( X->p[i - 1] != 0 ) if( X->p[i - 1] != 0 )
@ -875,6 +924,8 @@ int mbedtls_mpi_cmp_abs( const mbedtls_mpi *X, const mbedtls_mpi *Y )
int mbedtls_mpi_cmp_mpi( const mbedtls_mpi *X, const mbedtls_mpi *Y ) int mbedtls_mpi_cmp_mpi( const mbedtls_mpi *X, const mbedtls_mpi *Y )
{ {
size_t i, j; size_t i, j;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( Y != NULL );
for( i = X->n; i > 0; i-- ) for( i = X->n; i > 0; i-- )
if( X->p[i - 1] != 0 ) if( X->p[i - 1] != 0 )
@ -909,6 +960,7 @@ int mbedtls_mpi_cmp_int( const mbedtls_mpi *X, mbedtls_mpi_sint z )
{ {
mbedtls_mpi Y; mbedtls_mpi Y;
mbedtls_mpi_uint p[1]; mbedtls_mpi_uint p[1];
MPI_VALIDATE_RET( X != NULL );
*p = ( z < 0 ) ? -z : z; *p = ( z < 0 ) ? -z : z;
Y.s = ( z < 0 ) ? -1 : 1; Y.s = ( z < 0 ) ? -1 : 1;
@ -926,6 +978,9 @@ int mbedtls_mpi_add_abs( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi
int ret; int ret;
size_t i, j; size_t i, j;
mbedtls_mpi_uint *o, *p, c, tmp; mbedtls_mpi_uint *o, *p, c, tmp;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( A != NULL );
MPI_VALIDATE_RET( B != NULL );
if( X == B ) if( X == B )
{ {
@ -1003,6 +1058,9 @@ int mbedtls_mpi_sub_abs( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi
mbedtls_mpi TB; mbedtls_mpi TB;
int ret; int ret;
size_t n; size_t n;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( A != NULL );
MPI_VALIDATE_RET( B != NULL );
if( mbedtls_mpi_cmp_abs( A, B ) < 0 ) if( mbedtls_mpi_cmp_abs( A, B ) < 0 )
return( MBEDTLS_ERR_MPI_NEGATIVE_VALUE ); return( MBEDTLS_ERR_MPI_NEGATIVE_VALUE );
@ -1043,8 +1101,12 @@ cleanup:
*/ */
int mbedtls_mpi_add_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B ) int mbedtls_mpi_add_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B )
{ {
int ret, s = A->s; int ret, s;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( A != NULL );
MPI_VALIDATE_RET( B != NULL );
s = A->s;
if( A->s * B->s < 0 ) if( A->s * B->s < 0 )
{ {
if( mbedtls_mpi_cmp_abs( A, B ) >= 0 ) if( mbedtls_mpi_cmp_abs( A, B ) >= 0 )
@ -1074,8 +1136,12 @@ cleanup:
*/ */
int mbedtls_mpi_sub_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B ) int mbedtls_mpi_sub_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B )
{ {
int ret, s = A->s; int ret, s;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( A != NULL );
MPI_VALIDATE_RET( B != NULL );
s = A->s;
if( A->s * B->s > 0 ) if( A->s * B->s > 0 )
{ {
if( mbedtls_mpi_cmp_abs( A, B ) >= 0 ) if( mbedtls_mpi_cmp_abs( A, B ) >= 0 )
@ -1107,6 +1173,8 @@ int mbedtls_mpi_add_int( mbedtls_mpi *X, const mbedtls_mpi *A, mbedtls_mpi_sint
{ {
mbedtls_mpi _B; mbedtls_mpi _B;
mbedtls_mpi_uint p[1]; mbedtls_mpi_uint p[1];
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( A != NULL );
p[0] = ( b < 0 ) ? -b : b; p[0] = ( b < 0 ) ? -b : b;
_B.s = ( b < 0 ) ? -1 : 1; _B.s = ( b < 0 ) ? -1 : 1;
@ -1123,6 +1191,8 @@ int mbedtls_mpi_sub_int( mbedtls_mpi *X, const mbedtls_mpi *A, mbedtls_mpi_sint
{ {
mbedtls_mpi _B; mbedtls_mpi _B;
mbedtls_mpi_uint p[1]; mbedtls_mpi_uint p[1];
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( A != NULL );
p[0] = ( b < 0 ) ? -b : b; p[0] = ( b < 0 ) ? -b : b;
_B.s = ( b < 0 ) ? -1 : 1; _B.s = ( b < 0 ) ? -1 : 1;
@ -1212,6 +1282,9 @@ int mbedtls_mpi_mul_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi
int ret; int ret;
size_t i, j; size_t i, j;
mbedtls_mpi TA, TB; mbedtls_mpi TA, TB;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( A != NULL );
MPI_VALIDATE_RET( B != NULL );
mbedtls_mpi_init( &TA ); mbedtls_mpi_init( &TB ); mbedtls_mpi_init( &TA ); mbedtls_mpi_init( &TB );
@ -1248,6 +1321,8 @@ int mbedtls_mpi_mul_int( mbedtls_mpi *X, const mbedtls_mpi *A, mbedtls_mpi_uint
{ {
mbedtls_mpi _B; mbedtls_mpi _B;
mbedtls_mpi_uint p[1]; mbedtls_mpi_uint p[1];
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( A != NULL );
_B.s = 1; _B.s = 1;
_B.n = 1; _B.n = 1;
@ -1356,11 +1431,14 @@ static mbedtls_mpi_uint mbedtls_int_div_int( mbedtls_mpi_uint u1,
/* /*
* Division by mbedtls_mpi: A = Q * B + R (HAC 14.20) * Division by mbedtls_mpi: A = Q * B + R (HAC 14.20)
*/ */
int mbedtls_mpi_div_mpi( mbedtls_mpi *Q, mbedtls_mpi *R, const mbedtls_mpi *A, const mbedtls_mpi *B ) int mbedtls_mpi_div_mpi( mbedtls_mpi *Q, mbedtls_mpi *R, const mbedtls_mpi *A,
const mbedtls_mpi *B )
{ {
int ret; int ret;
size_t i, n, t, k; size_t i, n, t, k;
mbedtls_mpi X, Y, Z, T1, T2; mbedtls_mpi X, Y, Z, T1, T2;
MPI_VALIDATE_RET( A != NULL );
MPI_VALIDATE_RET( B != NULL );
if( mbedtls_mpi_cmp_int( B, 0 ) == 0 ) if( mbedtls_mpi_cmp_int( B, 0 ) == 0 )
return( MBEDTLS_ERR_MPI_DIVISION_BY_ZERO ); return( MBEDTLS_ERR_MPI_DIVISION_BY_ZERO );
@ -1471,10 +1549,13 @@ cleanup:
/* /*
* Division by int: A = Q * b + R * Division by int: A = Q * b + R
*/ */
int mbedtls_mpi_div_int( mbedtls_mpi *Q, mbedtls_mpi *R, const mbedtls_mpi *A, mbedtls_mpi_sint b ) int mbedtls_mpi_div_int( mbedtls_mpi *Q, mbedtls_mpi *R,
const mbedtls_mpi *A,
mbedtls_mpi_sint b )
{ {
mbedtls_mpi _B; mbedtls_mpi _B;
mbedtls_mpi_uint p[1]; mbedtls_mpi_uint p[1];
MPI_VALIDATE_RET( A != NULL );
p[0] = ( b < 0 ) ? -b : b; p[0] = ( b < 0 ) ? -b : b;
_B.s = ( b < 0 ) ? -1 : 1; _B.s = ( b < 0 ) ? -1 : 1;
@ -1490,6 +1571,9 @@ int mbedtls_mpi_div_int( mbedtls_mpi *Q, mbedtls_mpi *R, const mbedtls_mpi *A, m
int mbedtls_mpi_mod_mpi( mbedtls_mpi *R, const mbedtls_mpi *A, const mbedtls_mpi *B ) int mbedtls_mpi_mod_mpi( mbedtls_mpi *R, const mbedtls_mpi *A, const mbedtls_mpi *B )
{ {
int ret; int ret;
MPI_VALIDATE_RET( R != NULL );
MPI_VALIDATE_RET( A != NULL );
MPI_VALIDATE_RET( B != NULL );
if( mbedtls_mpi_cmp_int( B, 0 ) < 0 ) if( mbedtls_mpi_cmp_int( B, 0 ) < 0 )
return( MBEDTLS_ERR_MPI_NEGATIVE_VALUE ); return( MBEDTLS_ERR_MPI_NEGATIVE_VALUE );
@ -1514,6 +1598,8 @@ int mbedtls_mpi_mod_int( mbedtls_mpi_uint *r, const mbedtls_mpi *A, mbedtls_mpi_
{ {
size_t i; size_t i;
mbedtls_mpi_uint x, y, z; mbedtls_mpi_uint x, y, z;
MPI_VALIDATE_RET( r != NULL );
MPI_VALIDATE_RET( A != NULL );
if( b == 0 ) if( b == 0 )
return( MBEDTLS_ERR_MPI_DIVISION_BY_ZERO ); return( MBEDTLS_ERR_MPI_DIVISION_BY_ZERO );
@ -1627,7 +1713,8 @@ static int mpi_montmul( mbedtls_mpi *A, const mbedtls_mpi *B, const mbedtls_mpi
/* /*
* Montgomery reduction: A = A * R^-1 mod N * Montgomery reduction: A = A * R^-1 mod N
*/ */
static int mpi_montred( mbedtls_mpi *A, const mbedtls_mpi *N, mbedtls_mpi_uint mm, const mbedtls_mpi *T ) static int mpi_montred( mbedtls_mpi *A, const mbedtls_mpi *N,
mbedtls_mpi_uint mm, const mbedtls_mpi *T )
{ {
mbedtls_mpi_uint z = 1; mbedtls_mpi_uint z = 1;
mbedtls_mpi U; mbedtls_mpi U;
@ -1641,7 +1728,9 @@ static int mpi_montred( mbedtls_mpi *A, const mbedtls_mpi *N, mbedtls_mpi_uint m
/* /*
* Sliding-window exponentiation: X = A^E mod N (HAC 14.85) * Sliding-window exponentiation: X = A^E mod N (HAC 14.85)
*/ */
int mbedtls_mpi_exp_mod( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *E, const mbedtls_mpi *N, mbedtls_mpi *_RR ) int mbedtls_mpi_exp_mod( mbedtls_mpi *X, const mbedtls_mpi *A,
const mbedtls_mpi *E, const mbedtls_mpi *N,
mbedtls_mpi *_RR )
{ {
int ret; int ret;
size_t wbits, wsize, one = 1; size_t wbits, wsize, one = 1;
@ -1651,6 +1740,11 @@ int mbedtls_mpi_exp_mod( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi
mbedtls_mpi RR, T, W[ 2 << MBEDTLS_MPI_WINDOW_SIZE ], Apos; mbedtls_mpi RR, T, W[ 2 << MBEDTLS_MPI_WINDOW_SIZE ], Apos;
int neg; int neg;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( A != NULL );
MPI_VALIDATE_RET( E != NULL );
MPI_VALIDATE_RET( N != NULL );
if( mbedtls_mpi_cmp_int( N, 0 ) <= 0 || ( N->p[0] & 1 ) == 0 ) if( mbedtls_mpi_cmp_int( N, 0 ) <= 0 || ( N->p[0] & 1 ) == 0 )
return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA ); return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
@ -1855,6 +1949,10 @@ int mbedtls_mpi_gcd( mbedtls_mpi *G, const mbedtls_mpi *A, const mbedtls_mpi *B
size_t lz, lzt; size_t lz, lzt;
mbedtls_mpi TG, TA, TB; mbedtls_mpi TG, TA, TB;
MPI_VALIDATE_RET( G != NULL );
MPI_VALIDATE_RET( A != NULL );
MPI_VALIDATE_RET( B != NULL );
mbedtls_mpi_init( &TG ); mbedtls_mpi_init( &TA ); mbedtls_mpi_init( &TB ); mbedtls_mpi_init( &TG ); mbedtls_mpi_init( &TA ); mbedtls_mpi_init( &TB );
MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TA, A ) ); MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TA, A ) );
@ -1911,6 +2009,8 @@ int mbedtls_mpi_fill_random( mbedtls_mpi *X, size_t size,
{ {
int ret; int ret;
unsigned char buf[MBEDTLS_MPI_MAX_SIZE]; unsigned char buf[MBEDTLS_MPI_MAX_SIZE];
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( f_rng != NULL );
if( size > MBEDTLS_MPI_MAX_SIZE ) if( size > MBEDTLS_MPI_MAX_SIZE )
return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA ); return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
@ -1930,6 +2030,9 @@ int mbedtls_mpi_inv_mod( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi
{ {
int ret; int ret;
mbedtls_mpi G, TA, TU, U1, U2, TB, TV, V1, V2; mbedtls_mpi G, TA, TU, U1, U2, TB, TV, V1, V2;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( A != NULL );
MPI_VALIDATE_RET( N != NULL );
if( mbedtls_mpi_cmp_int( N, 1 ) <= 0 ) if( mbedtls_mpi_cmp_int( N, 1 ) <= 0 )
return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA ); return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
@ -2089,7 +2192,11 @@ static int mpi_miller_rabin( const mbedtls_mpi *X, size_t rounds,
size_t i, j, k, s; size_t i, j, k, s;
mbedtls_mpi W, R, T, A, RR; mbedtls_mpi W, R, T, A, RR;
mbedtls_mpi_init( &W ); mbedtls_mpi_init( &R ); mbedtls_mpi_init( &T ); mbedtls_mpi_init( &A ); MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( f_rng != NULL );
mbedtls_mpi_init( &W ); mbedtls_mpi_init( &R );
mbedtls_mpi_init( &T ); mbedtls_mpi_init( &A );
mbedtls_mpi_init( &RR ); mbedtls_mpi_init( &RR );
/* /*
@ -2161,7 +2268,8 @@ static int mpi_miller_rabin( const mbedtls_mpi *X, size_t rounds,
} }
cleanup: cleanup:
mbedtls_mpi_free( &W ); mbedtls_mpi_free( &R ); mbedtls_mpi_free( &T ); mbedtls_mpi_free( &A ); mbedtls_mpi_free( &W ); mbedtls_mpi_free( &R );
mbedtls_mpi_free( &T ); mbedtls_mpi_free( &A );
mbedtls_mpi_free( &RR ); mbedtls_mpi_free( &RR );
return( ret ); return( ret );
@ -2176,6 +2284,8 @@ int mbedtls_mpi_is_prime_ext( const mbedtls_mpi *X, int rounds,
{ {
int ret; int ret;
mbedtls_mpi XX; mbedtls_mpi XX;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( f_rng != NULL );
XX.s = 1; XX.s = 1;
XX.n = X->n; XX.n = X->n;
@ -2207,12 +2317,15 @@ int mbedtls_mpi_is_prime( const mbedtls_mpi *X,
int (*f_rng)(void *, unsigned char *, size_t), int (*f_rng)(void *, unsigned char *, size_t),
void *p_rng ) void *p_rng )
{ {
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( f_rng != NULL );
/* /*
* In the past our key generation aimed for an error rate of at most * In the past our key generation aimed for an error rate of at most
* 2^-80. Since this function is deprecated, aim for the same certainty * 2^-80. Since this function is deprecated, aim for the same certainty
* here as well. * here as well.
*/ */
return mbedtls_mpi_is_prime_ext( X, 40, f_rng, p_rng ); return( mbedtls_mpi_is_prime_ext( X, 40, f_rng, p_rng ) );
} }
#endif #endif
@ -2240,6 +2353,9 @@ int mbedtls_mpi_gen_prime( mbedtls_mpi *X, size_t nbits, int flags,
mbedtls_mpi_uint r; mbedtls_mpi_uint r;
mbedtls_mpi Y; mbedtls_mpi Y;
MPI_VALIDATE_RET( X != NULL );
MPI_VALIDATE_RET( f_rng != NULL );
if( nbits < 3 || nbits > MBEDTLS_MPI_MAX_BITS ) if( nbits < 3 || nbits > MBEDTLS_MPI_MAX_BITS )
return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA ); return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );