diff --git a/library/rsa.c b/library/rsa.c index 5d7129a00..ee6ca0195 100644 --- a/library/rsa.c +++ b/library/rsa.c @@ -441,8 +441,6 @@ int rsa_pkcs1_encrypt( rsa_context *ctx, memset( output, 0, olen ); - md_init_ctx( &md_ctx, md_info ); - *p++ = 0; // Generate a random octet string seed @@ -460,6 +458,8 @@ int rsa_pkcs1_encrypt( rsa_context *ctx, *p++ = 1; memcpy( p, input, ilen ); + md_init_ctx( &md_ctx, md_info ); + // maskedDB: Apply dbMask to DB // mgf_mask( output + hlen + 1, olen - hlen - 1, output + 1, hlen, @@ -800,8 +800,6 @@ int rsa_pkcs1_sign( rsa_context *ctx, memset( sig, 0, olen ); - md_init_ctx( &md_ctx, md_info ); - msb = mpi_msb( &ctx->N ) - 1; // Generate salt of length slen @@ -817,6 +815,8 @@ int rsa_pkcs1_sign( rsa_context *ctx, memcpy( p, salt, slen ); p += slen; + md_init_ctx( &md_ctx, md_info ); + // Generate H = Hash( M' ) // md_starts( &md_ctx ); @@ -1016,8 +1016,6 @@ int rsa_pkcs1_verify( rsa_context *ctx, memset( zeros, 0, 8 ); - md_init_ctx( &md_ctx, md_info ); - // Note: EMSA-PSS verification is over the length of N - 1 bits // msb = mpi_msb( &ctx->N ) - 1; @@ -1032,6 +1030,8 @@ int rsa_pkcs1_verify( rsa_context *ctx, if( buf[0] >> ( 8 - siglen * 8 + msb ) ) return( POLARSSL_ERR_RSA_BAD_INPUT_DATA ); + md_init_ctx( &md_ctx, md_info ); + mgf_mask( p, siglen - hlen - 1, p + siglen - hlen - 1, hlen, &md_ctx ); buf[0] &= 0xFF >> ( siglen * 8 - msb ); @@ -1039,11 +1039,12 @@ int rsa_pkcs1_verify( rsa_context *ctx, while( *p == 0 && p < buf + siglen ) p++; - if( p == buf + siglen ) - return( POLARSSL_ERR_RSA_INVALID_PADDING ); - - if( *p++ != 0x01 ) + if( p == buf + siglen || + *p++ != 0x01 ) + { + md_free_ctx( &md_ctx ); return( POLARSSL_ERR_RSA_INVALID_PADDING ); + } slen -= p - buf;