Extract common code

Make code easier to maintain.
This commit is contained in:
Nir Sonnenschein 2018-06-04 16:40:31 +03:00 committed by itayzafrir
parent 717a040df5
commit 4db79eb36b

View file

@ -1203,6 +1203,30 @@ psa_status_t psa_mac_verify( psa_mac_operation_t *operation,
/* Asymmetric cryptography */ /* Asymmetric cryptography */
/****************************************************************/ /****************************************************************/
static psa_status_t verify_RSA_hash_input_and_get_md_type(psa_algorithm_t alg, size_t hash_length, mbedtls_md_type_t *md_alg)
{
psa_algorithm_t hash_alg = PSA_ALG_RSA_GET_HASH(alg);
const mbedtls_md_info_t *md_info = mbedtls_md_info_from_psa(hash_alg);
*md_alg = hash_alg == 0 ? MBEDTLS_MD_NONE : mbedtls_md_get_type(md_info);
if (*md_alg == MBEDTLS_MD_NONE)
{
#if SIZE_MAX > UINT_MAX
if (hash_length > UINT_MAX)
return(PSA_ERROR_INVALID_ARGUMENT);
#endif
}
else
{
if (mbedtls_md_get_size(md_info) != hash_length)
return(PSA_ERROR_INVALID_ARGUMENT);
if (md_info == NULL)
return(PSA_ERROR_NOT_SUPPORTED);
}
return PSA_SUCCESS;
}
psa_status_t psa_asymmetric_sign(psa_key_slot_t key, psa_status_t psa_asymmetric_sign(psa_key_slot_t key,
psa_algorithm_t alg, psa_algorithm_t alg,
const uint8_t *hash, const uint8_t *hash,
@ -1214,11 +1238,12 @@ psa_status_t psa_asymmetric_sign(psa_key_slot_t key,
size_t *signature_length) size_t *signature_length)
{ {
key_slot_t *slot; key_slot_t *slot;
psa_status_t status;
*signature_length = 0; *signature_length = 0;
(void) salt; (void) salt;
(void) salt_length; (void) salt_length;
if( key == 0 || key > MBEDTLS_PSA_KEY_SLOT_COUNT ) if( key == 0 || key > MBEDTLS_PSA_KEY_SLOT_COUNT )
return( PSA_ERROR_EMPTY_SLOT ); return( PSA_ERROR_EMPTY_SLOT );
slot = &global_data.key_slots[key]; slot = &global_data.key_slots[key];
@ -1234,24 +1259,12 @@ psa_status_t psa_asymmetric_sign(psa_key_slot_t key,
{ {
mbedtls_rsa_context *rsa = slot->data.rsa; mbedtls_rsa_context *rsa = slot->data.rsa;
int ret; int ret;
psa_algorithm_t hash_alg = PSA_ALG_RSA_GET_HASH( alg ); mbedtls_md_type_t md_alg;
const mbedtls_md_info_t *md_info = mbedtls_md_info_from_psa( hash_alg ); status = verify_RSA_hash_input_and_get_md_type( alg, hash_length,
mbedtls_md_type_t md_alg = &md_alg );
hash_alg == 0 ? MBEDTLS_MD_NONE : mbedtls_md_get_type( md_info ); if ( status != PSA_SUCCESS )
if( md_alg == MBEDTLS_MD_NONE ) return status;
{
#if SIZE_MAX > UINT_MAX
if( hash_length > UINT_MAX )
return( PSA_ERROR_INVALID_ARGUMENT );
#endif
}
else
{
if( mbedtls_md_get_size( md_info ) != hash_length )
return( PSA_ERROR_INVALID_ARGUMENT );
if( md_info == NULL )
return( PSA_ERROR_NOT_SUPPORTED );
}
if( signature_size < rsa->len ) if( signature_size < rsa->len )
return( PSA_ERROR_BUFFER_TOO_SMALL ); return( PSA_ERROR_BUFFER_TOO_SMALL );
#if defined(MBEDTLS_PKCS1_V15) #if defined(MBEDTLS_PKCS1_V15)
@ -1323,6 +1336,7 @@ psa_status_t psa_asymmetric_verify(psa_key_slot_t key,
size_t signature_size) size_t signature_size)
{ {
key_slot_t *slot; key_slot_t *slot;
psa_status_t status;
(void) salt; (void) salt;
(void) salt_length; (void) salt_length;
@ -1337,24 +1351,12 @@ psa_status_t psa_asymmetric_verify(psa_key_slot_t key,
{ {
mbedtls_rsa_context *rsa = slot->data.rsa; mbedtls_rsa_context *rsa = slot->data.rsa;
int ret; int ret;
psa_algorithm_t hash_alg = PSA_ALG_RSA_GET_HASH( alg ); mbedtls_md_type_t md_alg;
const mbedtls_md_info_t *md_info = mbedtls_md_info_from_psa( hash_alg ); status = verify_RSA_hash_input_and_get_md_type(alg, hash_length,
mbedtls_md_type_t md_alg = &md_alg);
hash_alg == 0 ? MBEDTLS_MD_NONE : mbedtls_md_get_type( md_info ); if (status != PSA_SUCCESS)
if( md_alg == MBEDTLS_MD_NONE ) return status;
{
#if SIZE_MAX > UINT_MAX
if( hash_length > UINT_MAX )
return( PSA_ERROR_INVALID_ARGUMENT );
#endif
}
else
{
if( mbedtls_md_get_size( md_info ) != hash_length )
return( PSA_ERROR_INVALID_ARGUMENT );
if( md_info == NULL )
return( PSA_ERROR_NOT_SUPPORTED );
}
if( signature_size < rsa->len ) if( signature_size < rsa->len )
return( PSA_ERROR_BUFFER_TOO_SMALL ); return( PSA_ERROR_BUFFER_TOO_SMALL );
#if defined(MBEDTLS_PKCS1_V15) #if defined(MBEDTLS_PKCS1_V15)