diff --git a/library/psa_crypto.c b/library/psa_crypto.c index 3411cc843..82af92086 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -2821,6 +2821,102 @@ psa_status_t psa_set_key_lifetime( psa_key_slot_t key, /* AEAD */ /****************************************************************/ +typedef struct +{ + key_slot_t *slot; + const mbedtls_cipher_info_t *cipher_info; + union + { +#if defined(MBEDTLS_CCM_C) + mbedtls_ccm_context ccm; +#endif /* MBEDTLS_CCM_C */ +#if defined(MBEDTLS_GCM_C) + mbedtls_gcm_context gcm; +#endif /* MBEDTLS_GCM_C */ + } ctx; + uint8_t tag_length; +} aead_operation_t; + +static void psa_aead_abort( aead_operation_t *operation, + psa_algorithm_t alg ) +{ + switch( alg ) + { +#if defined(MBEDTLS_CCM_C) + case PSA_ALG_CCM: + mbedtls_ccm_free( &operation->ctx.ccm ); + break; +#endif /* MBEDTLS_CCM_C */ +#if defined(MBEDTLS_CCM_C) + case PSA_ALG_GCM: + mbedtls_gcm_free( &operation->ctx.gcm ); + break; +#endif /* MBEDTLS_GCM_C */ + } +} + +static psa_status_t psa_aead_setup( aead_operation_t *operation, + psa_key_slot_t key, + psa_key_usage_t usage, + psa_algorithm_t alg ) +{ + psa_status_t status; + size_t key_bits; + mbedtls_cipher_id_t cipher_id; + + status = psa_get_key_from_slot( key, &operation->slot, usage, alg ); + if( status != PSA_SUCCESS ) + return( status ); + + key_bits = psa_get_key_bits( operation->slot ); + + operation->cipher_info = + mbedtls_cipher_info_from_psa( alg, operation->slot->type, key_bits, + &cipher_id ); + if( operation->cipher_info == NULL ) + return( PSA_ERROR_NOT_SUPPORTED ); + + switch( alg ) + { +#if defined(MBEDTLS_CCM_C) + case PSA_ALG_CCM: + operation->tag_length = 16; + if( PSA_BLOCK_CIPHER_BLOCK_SIZE( operation->slot->type ) != 16 ) + return( PSA_ERROR_INVALID_ARGUMENT ); + mbedtls_ccm_init( &operation->ctx.ccm ); + status = mbedtls_to_psa_error( + mbedtls_ccm_setkey( &operation->ctx.ccm, cipher_id, + operation->slot->data.raw.data, + (unsigned int) key_bits ) ); + if( status != 0 ) + goto cleanup; + break; +#endif /* MBEDTLS_CCM_C */ + +#if defined(MBEDTLS_GCM_C) + case PSA_ALG_GCM: + operation->tag_length = 16; + if( PSA_BLOCK_CIPHER_BLOCK_SIZE( operation->slot->type ) != 16 ) + return( PSA_ERROR_INVALID_ARGUMENT ); + mbedtls_gcm_init( &operation->ctx.gcm ); + status = mbedtls_to_psa_error( + mbedtls_gcm_setkey( &operation->ctx.gcm, cipher_id, + operation->slot->data.raw.data, + (unsigned int) key_bits ) ); + break; +#endif /* MBEDTLS_GCM_C */ + + default: + return( PSA_ERROR_NOT_SUPPORTED ); + } + + return( PSA_SUCCESS ); + +cleanup: + psa_aead_abort( operation, alg ); + return( status ); +} + psa_status_t psa_aead_encrypt( psa_key_slot_t key, psa_algorithm_t alg, const uint8_t *nonce, @@ -2833,113 +2929,60 @@ psa_status_t psa_aead_encrypt( psa_key_slot_t key, size_t ciphertext_size, size_t *ciphertext_length ) { - int ret; psa_status_t status; - key_slot_t *slot; - size_t key_bits; + aead_operation_t operation; uint8_t *tag; - size_t tag_length; - mbedtls_cipher_id_t cipher_id; - const mbedtls_cipher_info_t *cipher_info = NULL; *ciphertext_length = 0; - status = psa_get_key_from_slot( key, &slot, PSA_KEY_USAGE_ENCRYPT, alg ); + status = psa_aead_setup( &operation, key, PSA_KEY_USAGE_ENCRYPT, alg ); if( status != PSA_SUCCESS ) return( status ); - key_bits = psa_get_key_bits( slot ); - cipher_info = mbedtls_cipher_info_from_psa( alg, slot->type, - key_bits, &cipher_id ); - if( cipher_info == NULL ) - return( PSA_ERROR_NOT_SUPPORTED ); - - if( ( slot->type & PSA_KEY_TYPE_CATEGORY_MASK ) != - PSA_KEY_TYPE_CATEGORY_SYMMETRIC ) - return( PSA_ERROR_INVALID_ARGUMENT ); + /* For all currently supported modes, the tag is at the end of the + * ciphertext. */ + if( ciphertext_size < ( plaintext_length + operation.tag_length ) ) + { + status = PSA_ERROR_BUFFER_TOO_SMALL; + goto exit; + } + tag = ciphertext + plaintext_length; if( alg == PSA_ALG_GCM ) { - mbedtls_gcm_context gcm; - tag_length = 16; - - if( PSA_BLOCK_CIPHER_BLOCK_SIZE( slot->type ) != 16 ) - return( PSA_ERROR_INVALID_ARGUMENT ); - - //make sure we have place to hold the tag in the ciphertext buffer - if( ciphertext_size < ( plaintext_length + tag_length ) ) - return( PSA_ERROR_BUFFER_TOO_SMALL ); - - //update the tag pointer to point to the end of the ciphertext_length - tag = ciphertext + plaintext_length; - - mbedtls_gcm_init( &gcm ); - ret = mbedtls_gcm_setkey( &gcm, cipher_id, - slot->data.raw.data, - (unsigned int) key_bits ); - if( ret != 0 ) - { - mbedtls_gcm_free( &gcm ); - return( mbedtls_to_psa_error( ret ) ); - } - ret = mbedtls_gcm_crypt_and_tag( &gcm, MBEDTLS_GCM_ENCRYPT, - plaintext_length, nonce, - nonce_length, additional_data, - additional_data_length, plaintext, - ciphertext, tag_length, tag ); - mbedtls_gcm_free( &gcm ); + status = mbedtls_to_psa_error( + mbedtls_gcm_crypt_and_tag( &operation.ctx.gcm, + MBEDTLS_GCM_ENCRYPT, + plaintext_length, + nonce, nonce_length, + additional_data, additional_data_length, + plaintext, ciphertext, + operation.tag_length, tag ) ); } else if( alg == PSA_ALG_CCM ) { - mbedtls_ccm_context ccm; - tag_length = 16; - - if( PSA_BLOCK_CIPHER_BLOCK_SIZE( slot->type ) != 16 ) - return( PSA_ERROR_INVALID_ARGUMENT ); - - if( nonce_length < 7 || nonce_length > 13 ) - return( PSA_ERROR_INVALID_ARGUMENT ); - - //make sure we have place to hold the tag in the ciphertext buffer - if( ciphertext_size < ( plaintext_length + tag_length ) ) - return( PSA_ERROR_BUFFER_TOO_SMALL ); - - //update the tag pointer to point to the end of the ciphertext_length - tag = ciphertext + plaintext_length; - - mbedtls_ccm_init( &ccm ); - ret = mbedtls_ccm_setkey( &ccm, cipher_id, - slot->data.raw.data, - (unsigned int) key_bits ); - if( ret != 0 ) - { - mbedtls_ccm_free( &ccm ); - return( mbedtls_to_psa_error( ret ) ); - } - ret = mbedtls_ccm_encrypt_and_tag( &ccm, plaintext_length, - nonce, nonce_length, - additional_data, - additional_data_length, - plaintext, ciphertext, - tag, tag_length ); - mbedtls_ccm_free( &ccm ); + status = mbedtls_to_psa_error( + mbedtls_ccm_encrypt_and_tag( &operation.ctx.ccm, + plaintext_length, + nonce, nonce_length, + additional_data, + additional_data_length, + plaintext, ciphertext, + tag, operation.tag_length ) ); } else { return( PSA_ERROR_NOT_SUPPORTED ); } - if( ret != 0 ) - { - /* If ciphertext_size is 0 then ciphertext may be NULL and then the - * call to memset would have undefined behavior. */ - if( ciphertext_size != 0 ) - memset( ciphertext, 0, ciphertext_size ); - return( mbedtls_to_psa_error( ret ) ); - } + if( status != PSA_SUCCESS && ciphertext_size != 0 ) + memset( ciphertext, 0, ciphertext_size ); - *ciphertext_length = plaintext_length + tag_length; - return( PSA_SUCCESS ); +exit: + psa_aead_abort( &operation, alg ); + if( status == PSA_SUCCESS ) + *ciphertext_length = plaintext_length + operation.tag_length; + return( status ); } /* Locate the tag in a ciphertext buffer containing the encrypted data @@ -2975,108 +3018,63 @@ psa_status_t psa_aead_decrypt( psa_key_slot_t key, size_t plaintext_size, size_t *plaintext_length ) { - int ret; psa_status_t status; - key_slot_t *slot; - size_t key_bits; - const uint8_t *tag; - size_t tag_length; - mbedtls_cipher_id_t cipher_id; - const mbedtls_cipher_info_t *cipher_info = NULL; + aead_operation_t operation; + const uint8_t *tag = NULL; *plaintext_length = 0; - status = psa_get_key_from_slot( key, &slot, PSA_KEY_USAGE_DECRYPT, alg ); + status = psa_aead_setup( &operation, key, PSA_KEY_USAGE_DECRYPT, alg ); if( status != PSA_SUCCESS ) return( status ); - key_bits = psa_get_key_bits( slot ); - - cipher_info = mbedtls_cipher_info_from_psa( alg, slot->type, - key_bits, &cipher_id ); - if( cipher_info == NULL ) - return( PSA_ERROR_NOT_SUPPORTED ); - - if( ( slot->type & PSA_KEY_TYPE_CATEGORY_MASK ) != - PSA_KEY_TYPE_CATEGORY_SYMMETRIC ) - return( PSA_ERROR_INVALID_ARGUMENT ); if( alg == PSA_ALG_GCM ) { - mbedtls_gcm_context gcm; - - tag_length = 16; - status = psa_aead_unpadded_locate_tag( tag_length, + status = psa_aead_unpadded_locate_tag( operation.tag_length, ciphertext, ciphertext_length, plaintext_size, &tag ); if( status != PSA_SUCCESS ) - return( status ); + goto exit; - mbedtls_gcm_init( &gcm ); - ret = mbedtls_gcm_setkey( &gcm, cipher_id, - slot->data.raw.data, - (unsigned int) key_bits ); - if( ret != 0 ) - { - mbedtls_gcm_free( &gcm ); - return( mbedtls_to_psa_error( ret ) ); - } - - ret = mbedtls_gcm_auth_decrypt( &gcm, - ciphertext_length - tag_length, - nonce, nonce_length, - additional_data, - additional_data_length, - tag, tag_length, - ciphertext, plaintext ); - mbedtls_gcm_free( &gcm ); + status = mbedtls_to_psa_error( + mbedtls_gcm_auth_decrypt( &operation.ctx.gcm, + ciphertext_length - operation.tag_length, + nonce, nonce_length, + additional_data, + additional_data_length, + tag, operation.tag_length, + ciphertext, plaintext ) ); } else if( alg == PSA_ALG_CCM ) { - mbedtls_ccm_context ccm; - - if( nonce_length < 7 || nonce_length > 13 ) - return( PSA_ERROR_INVALID_ARGUMENT ); - - tag_length = 16; - status = psa_aead_unpadded_locate_tag( tag_length, + status = psa_aead_unpadded_locate_tag( operation.tag_length, ciphertext, ciphertext_length, plaintext_size, &tag ); if( status != PSA_SUCCESS ) - return( status ); + goto exit; - mbedtls_ccm_init( &ccm ); - ret = mbedtls_ccm_setkey( &ccm, cipher_id, - slot->data.raw.data, - (unsigned int) key_bits ); - if( ret != 0 ) - { - mbedtls_ccm_free( &ccm ); - return( mbedtls_to_psa_error( ret ) ); - } - ret = mbedtls_ccm_auth_decrypt( &ccm, ciphertext_length - tag_length, - nonce, nonce_length, - additional_data, - additional_data_length, - ciphertext, plaintext, - tag, tag_length ); - mbedtls_ccm_free( &ccm ); + status = mbedtls_to_psa_error( + mbedtls_ccm_auth_decrypt( &operation.ctx.ccm, + ciphertext_length - operation.tag_length, + nonce, nonce_length, + additional_data, + additional_data_length, + ciphertext, plaintext, + tag, operation.tag_length ) ); } else { return( PSA_ERROR_NOT_SUPPORTED ); } - if( ret != 0 ) - { - /* If plaintext_size is 0 then plaintext may be NULL and then the - * call to memset has undefined behavior. */ - if( plaintext_size != 0 ) - memset( plaintext, 0, plaintext_size ); - } - else - *plaintext_length = ciphertext_length - tag_length; + if( status != PSA_SUCCESS && plaintext_size != 0 ) + memset( plaintext, 0, plaintext_size ); - return( mbedtls_to_psa_error( ret ) ); +exit: + psa_aead_abort( &operation, alg ); + if( status == PSA_SUCCESS ) + *plaintext_length = ciphertext_length - operation.tag_length; + return( status ); }