diff --git a/library/psa_crypto.c b/library/psa_crypto.c index d730bd821..cc996a01c 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -2451,18 +2451,19 @@ psa_status_t psa_cipher_finish( psa_cipher_operation_t *operation, size_t output_size, size_t *output_length ) { - int ret = MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE; + psa_status_t status = PSA_ERROR_UNKNOWN_ERROR; + int cipher_ret = MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE; uint8_t temp_output_buffer[MBEDTLS_MAX_BLOCK_LENGTH]; if( ! operation->key_set ) { - psa_cipher_abort( operation ); - return( PSA_ERROR_BAD_STATE ); + status = PSA_ERROR_BAD_STATE; + goto error; } if( operation->iv_required && ! operation->iv_set ) { - psa_cipher_abort( operation ); - return( PSA_ERROR_BAD_STATE ); + status = PSA_ERROR_BAD_STATE; + goto error; } if( operation->ctx.cipher.operation == MBEDTLS_ENCRYPT && PSA_ALG_IS_BLOCK_CIPHER( operation->alg ) ) @@ -2471,37 +2472,51 @@ psa_status_t psa_cipher_finish( psa_cipher_operation_t *operation, operation->alg & PSA_ALG_BLOCK_CIPHER_PADDING_MASK; if( operation->ctx.cipher.unprocessed_len >= operation->block_size ) { - psa_cipher_abort( operation ); - return( PSA_ERROR_TAMPERING_DETECTED ); + status = PSA_ERROR_TAMPERING_DETECTED; + goto error; } if( padding_mode == PSA_ALG_BLOCK_CIPHER_PAD_NONE ) { if( operation->ctx.cipher.unprocessed_len != 0 ) { - psa_cipher_abort( operation ); - return( PSA_ERROR_INVALID_ARGUMENT ); + status = PSA_ERROR_INVALID_ARGUMENT; + goto error; } } } - ret = mbedtls_cipher_finish( &operation->ctx.cipher, temp_output_buffer, - output_length ); - if( ret != 0 ) + cipher_ret = mbedtls_cipher_finish( &operation->ctx.cipher, + temp_output_buffer, + output_length ); + if( cipher_ret != 0 ) { - psa_cipher_abort( operation ); - return( mbedtls_to_psa_error( ret ) ); + status = mbedtls_to_psa_error( cipher_ret ); + goto error; } + if( *output_length == 0 ) - /* Nothing to copy. Note that output may be NULL in this case. */ ; + ; /* Nothing to copy. Note that output may be NULL in this case. */ else if( output_size >= *output_length ) memcpy( output, temp_output_buffer, *output_length ); else { - psa_cipher_abort( operation ); - return( PSA_ERROR_BUFFER_TOO_SMALL ); + status = PSA_ERROR_BUFFER_TOO_SMALL; + goto error; } - return( PSA_SUCCESS ); + mbedtls_zeroize( temp_output_buffer, sizeof( temp_output_buffer ) ); + status = psa_cipher_abort( operation ); + + return( status ); + +error: + + *output_length = 0; + + mbedtls_zeroize( temp_output_buffer, sizeof( temp_output_buffer ) ); + (void) psa_cipher_abort( operation ); + + return( status ); } psa_status_t psa_cipher_abort( psa_cipher_operation_t *operation ) diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function index 977222bbf..3681a2ee1 100644 --- a/tests/suites/test_suite_psa_crypto.function +++ b/tests/suites/test_suite_psa_crypto.function @@ -1628,7 +1628,7 @@ void cipher_verify_output( int alg_arg, int key_type_arg, output2_length += function_output_length; - TEST_ASSERT( psa_cipher_abort( &operation1 ) == PSA_SUCCESS ); + TEST_ASSERT( psa_cipher_abort( &operation2 ) == PSA_SUCCESS ); TEST_ASSERT( input->len == output2_length ); TEST_ASSERT( memcmp( input->x, output2, input->len ) == 0 ); @@ -1739,7 +1739,7 @@ void cipher_verify_output_multipart( int alg_arg, &function_output_length ) == PSA_SUCCESS ); output2_length += function_output_length; - TEST_ASSERT( psa_cipher_abort( &operation1 ) == PSA_SUCCESS ); + TEST_ASSERT( psa_cipher_abort( &operation2 ) == PSA_SUCCESS ); TEST_ASSERT( input->len == output2_length ); TEST_ASSERT( memcmp( input->x, output2, input->len ) == 0 );