diff --git a/library/rsa.c b/library/rsa.c index 09f3112d6..b1ba19dbe 100644 --- a/library/rsa.c +++ b/library/rsa.c @@ -693,10 +693,9 @@ int rsa_rsaes_pkcs1_v15_decrypt( rsa_context *ctx, unsigned char *output, size_t output_max_len) { - int ret, correct = 1; - size_t ilen, pad_count = 0; - unsigned char *p, *q; - unsigned char bt; + int ret; + size_t ilen, pad_count = 0, i; + unsigned char *p, bad, pad_done = 0; unsigned char buf[POLARSSL_MPI_MAX_SIZE]; if( ctx->padding != RSA_PKCS_V15 ) @@ -715,57 +714,46 @@ int rsa_rsaes_pkcs1_v15_decrypt( rsa_context *ctx, return( ret ); p = buf; + bad = 0; - if( *p++ != 0 ) - correct = 0; + /* + * Check and get padding len in "constant-time" + */ + bad |= *p++; /* First byte must be 0 */ - bt = *p++; - if( ( bt != RSA_CRYPT && mode == RSA_PRIVATE ) || - ( bt != RSA_SIGN && mode == RSA_PUBLIC ) ) + /* This test does not depend on secret data */ + if( mode == RSA_PRIVATE ) { - correct = 0; - } + bad |= *p++ ^ RSA_CRYPT; - if( bt == RSA_CRYPT ) - { - while( *p != 0 && p < buf + ilen - 1 ) - pad_count += ( *p++ != 0 ); + /* Get padding len, but always read till end of buffer + * (minus one, for the 00 byte) */ + for( i = 0; i < ilen - 3; i++ ) + { + pad_done |= ( p[i] == 0 ); + pad_count += ( pad_done == 0 ); + } - correct &= ( *p == 0 && p < buf + ilen - 1 ); - - q = p; - - // Also pass over all other bytes to reduce timing differences - // - while ( q < buf + ilen - 1 ) - pad_count += ( *q++ != 0 ); - - // Prevent compiler optimization of pad_count - // - correct |= pad_count & 0x100000; /* Always 0 unless 1M bit keys */ - p++; + p += pad_count; + bad |= *p++; /* Must be zero */ } else { - while( *p == 0xFF && p < buf + ilen - 1 ) - pad_count += ( *p++ == 0xFF ); + bad |= *p++ ^ RSA_SIGN; - correct &= ( *p == 0 && p < buf + ilen - 1 ); + /* Get padding len, but always read till end of buffer + * (minus one, for the 00 byte) */ + for( i = 0; i < ilen - 3; i++ ) + { + pad_done |= ( p[i] == 0xFF ); + pad_count += ( pad_done == 0 ); + } - q = p; - - // Also pass over all other bytes to reduce timing differences - // - while ( q < buf + ilen - 1 ) - pad_count += ( *q++ != 0 ); - - // Prevent compiler optimization of pad_count - // - correct |= pad_count & 0x100000; /* Always 0 unless 1M bit keys */ - p++; + p += pad_count; + bad |= *p++; /* Must be zero */ } - if( correct == 0 ) + if( bad ) return( POLARSSL_ERR_RSA_INVALID_PADDING ); if (ilen - (p - buf) > output_max_len)