diff --git a/library/psa_crypto.c b/library/psa_crypto.c index 6e9e191af..552c45f27 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -3116,24 +3116,33 @@ static psa_status_t psa_rsa_decode_md_type( psa_algorithm_t alg, return( PSA_SUCCESS ); } -static psa_status_t psa_rsa_sign( mbedtls_rsa_context *rsa, - psa_algorithm_t alg, - const uint8_t *hash, - size_t hash_length, - uint8_t *signature, - size_t signature_size, - size_t *signature_length ) +static psa_status_t psa_rsa_sign( + const psa_key_attributes_t *attributes, + const uint8_t *key_buffer, size_t key_buffer_size, + psa_algorithm_t alg, const uint8_t *hash, size_t hash_length, + uint8_t *signature, size_t signature_size, size_t *signature_length ) { - psa_status_t status; + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + mbedtls_rsa_context *rsa = NULL; int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; mbedtls_md_type_t md_alg; - status = psa_rsa_decode_md_type( alg, hash_length, &md_alg ); + status = mbedtls_psa_rsa_load_representation( attributes->core.type, + key_buffer, + key_buffer_size, + &rsa ); if( status != PSA_SUCCESS ) return( status ); + status = psa_rsa_decode_md_type( alg, hash_length, &md_alg ); + if( status != PSA_SUCCESS ) + goto exit; + if( signature_size < mbedtls_rsa_get_len( rsa ) ) - return( PSA_ERROR_BUFFER_TOO_SMALL ); + { + status = PSA_ERROR_BUFFER_TOO_SMALL; + goto exit; + } #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) if( PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg ) ) @@ -3167,31 +3176,48 @@ static psa_status_t psa_rsa_sign( mbedtls_rsa_context *rsa, else #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */ { - return( PSA_ERROR_INVALID_ARGUMENT ); + status = PSA_ERROR_INVALID_ARGUMENT; + goto exit; } if( ret == 0 ) *signature_length = mbedtls_rsa_get_len( rsa ); - return( mbedtls_to_psa_error( ret ) ); + status = mbedtls_to_psa_error( ret ); + +exit: + mbedtls_rsa_free( rsa ); + mbedtls_free( rsa ); + + return( status ); } -static psa_status_t psa_rsa_verify( mbedtls_rsa_context *rsa, - psa_algorithm_t alg, - const uint8_t *hash, - size_t hash_length, - const uint8_t *signature, - size_t signature_length ) +static psa_status_t psa_rsa_verify( + const psa_key_attributes_t *attributes, + const uint8_t *key_buffer, size_t key_buffer_size, + psa_algorithm_t alg, const uint8_t *hash, size_t hash_length, + const uint8_t *signature, size_t signature_length ) { - psa_status_t status; + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + mbedtls_rsa_context *rsa = NULL; int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; mbedtls_md_type_t md_alg; + status = mbedtls_psa_rsa_load_representation( attributes->core.type, + key_buffer, + key_buffer_size, + &rsa ); + if( status != PSA_SUCCESS ) + goto exit; + status = psa_rsa_decode_md_type( alg, hash_length, &md_alg ); if( status != PSA_SUCCESS ) - return( status ); + goto exit; if( signature_length != mbedtls_rsa_get_len( rsa ) ) - return( PSA_ERROR_INVALID_SIGNATURE ); + { + status = PSA_ERROR_INVALID_SIGNATURE; + goto exit; + } #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) if( PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg ) ) @@ -3225,16 +3251,24 @@ static psa_status_t psa_rsa_verify( mbedtls_rsa_context *rsa, else #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */ { - return( PSA_ERROR_INVALID_ARGUMENT ); + status = PSA_ERROR_INVALID_ARGUMENT; + goto exit; } /* Mbed TLS distinguishes "invalid padding" from "valid padding but * the rest of the signature is invalid". This has little use in * practice and PSA doesn't report this distinction. */ - if( ret == MBEDTLS_ERR_RSA_INVALID_PADDING ) - return( PSA_ERROR_INVALID_SIGNATURE ); - return( mbedtls_to_psa_error( ret ) ); + status = ( ret == MBEDTLS_ERR_RSA_INVALID_PADDING ) ? + PSA_ERROR_INVALID_SIGNATURE : + mbedtls_to_psa_error( ret ); + +exit: + mbedtls_rsa_free( rsa ); + mbedtls_free( rsa ); + + return( status ); } + #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) */ @@ -3353,23 +3387,10 @@ psa_status_t psa_sign_hash_internal( defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) if( attributes->core.type == PSA_KEY_TYPE_RSA_KEY_PAIR ) { - mbedtls_rsa_context *rsa = NULL; - - status = mbedtls_psa_rsa_load_representation( attributes->core.type, - key_buffer, - key_buffer_size, - &rsa ); - if( status != PSA_SUCCESS ) - goto exit; - - status = psa_rsa_sign( rsa, - alg, - hash, hash_length, - signature, signature_size, - signature_length ); - - mbedtls_rsa_free( rsa ); - mbedtls_free( rsa ); + return( psa_rsa_sign( attributes, + key_buffer, key_buffer_size, + alg, hash, hash_length, + signature, signature_size, signature_length ) ); } else #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || @@ -3489,22 +3510,10 @@ psa_status_t psa_verify_hash_internal( defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) if( PSA_KEY_TYPE_IS_RSA( attributes->core.type ) ) { - mbedtls_rsa_context *rsa = NULL; - - status = mbedtls_psa_rsa_load_representation( attributes->core.type, - key_buffer, - key_buffer_size, - &rsa ); - if( status != PSA_SUCCESS ) - goto exit; - - status = psa_rsa_verify( rsa, - alg, - hash, hash_length, - signature, signature_length ); - mbedtls_rsa_free( rsa ); - mbedtls_free( rsa ); - goto exit; + return( psa_rsa_verify( attributes, + key_buffer, key_buffer_size, + alg, hash, hash_length, + signature, signature_length ) ); } else #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||