From 4fed4553473e5f2ed154bcd90f5a937f6fa31922 Mon Sep 17 00:00:00 2001
From: Steven Cooreman <steven.cooreman@silabs.com>
Date: Mon, 3 Aug 2020 14:46:03 +0200
Subject: [PATCH] Apply review feedback

* No need to check for NULL before free'ing
* No need to reset variables that weren't touched
* Set output buffer to zero if key output fails
* Document internal functions and rearrange order of input arguments to
  better match other functions.
* Clean up Montgomery fix to be less verbose code

Signed-off-by: Steven Cooreman <steven.cooreman@silabs.com>
---
 library/psa_crypto.c | 286 ++++++++++++++++++++++---------------------
 1 file changed, 145 insertions(+), 141 deletions(-)

diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 056e67796..dfd37ae4e 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -520,17 +520,17 @@ static psa_status_t psa_check_rsa_key_byte_aligned(
 
 /** Load the contents of a key buffer into an internal RSA representation
  *
- * \param[in] buffer    The buffer from which to load the representation.
- * \param[in] size      The size in bytes of \p buffer.
- * \param[in] type      The type of key contained in the \p buffer.
- * \param[out] p_rsa    Returns a pointer to an RSA context on success.
- *                      The caller is responsible for freeing both the
- *                      contents of the context and the context itself
- *                      when done.
+ * \param[in] type          The type of key contained in \p data.
+ * \param[in] data          The buffer from which to load the representation.
+ * \param[in] data_length   The size in bytes of \p data.
+ * \param[out] p_rsa        Returns a pointer to an RSA context on success.
+ *                          The caller is responsible for freeing both the
+ *                          contents of the context and the context itself
+ *                          when done.
  */
-static psa_status_t psa_load_rsa_representation( const uint8_t *buffer,
-                                                 size_t size,
-                                                 psa_key_type_t type,
+static psa_status_t psa_load_rsa_representation( psa_key_type_t type,
+                                                 const uint8_t *data,
+                                                 size_t data_length,
                                                  mbedtls_rsa_context **p_rsa )
 {
 #if defined(MBEDTLS_PK_PARSE_C)
@@ -542,10 +542,10 @@ static psa_status_t psa_load_rsa_representation( const uint8_t *buffer,
     /* Parse the data. */
     if( PSA_KEY_TYPE_IS_KEY_PAIR( type ) )
         status = mbedtls_to_psa_error(
-            mbedtls_pk_parse_key( &ctx, buffer, size, NULL, 0 ) );
+            mbedtls_pk_parse_key( &ctx, data, data_length, NULL, 0 ) );
     else
         status = mbedtls_to_psa_error(
-            mbedtls_pk_parse_public_key( &ctx, buffer, size ) );
+            mbedtls_pk_parse_public_key( &ctx, data, data_length ) );
     if( status != PSA_SUCCESS )
         goto exit;
 
@@ -580,8 +580,8 @@ exit:
     mbedtls_pk_free( &ctx );
     return( status );
 #else
-    (void) buffer;
-    (void) size;
+    (void) data;
+    (void) data_length;
     (void) type;
     (void) rsa;
     return( PSA_ERROR_NOT_SUPPORTED );
@@ -620,7 +620,11 @@ static psa_status_t psa_export_rsa_key( psa_key_type_t type,
         ret = mbedtls_pk_write_pubkey( &pos, data, &pk );
 
     if( ret < 0 )
+    {
+        /* Clean up in case pk_write failed halfway through. */
+        memset( data, 0, data_size );
         return mbedtls_to_psa_error( ret );
+    }
 
     /* The mbedtls_pk_xxx functions write to the end of the buffer.
      * Move the data to the beginning and erase remaining data
@@ -663,9 +667,9 @@ static psa_status_t psa_import_rsa_key( psa_key_slot_t *slot,
     mbedtls_rsa_context *rsa = NULL;
 
     /* Parse input */
-    status = psa_load_rsa_representation( data,
+    status = psa_load_rsa_representation( slot->attr.type,
+                                          data,
                                           data_length,
-                                          slot->attr.type,
                                           &rsa );
     if( status != PSA_SUCCESS )
         goto exit;
@@ -700,8 +704,6 @@ exit:
     if( status != PSA_SUCCESS )
     {
         mbedtls_free( output );
-        slot->data.key.data = NULL;
-        slot->data.key.bytes = 0;
         return( status );
     }
 
@@ -716,50 +718,42 @@ exit:
 #if defined(MBEDTLS_ECP_C)
 /** Load the contents of a key buffer into an internal ECP representation
  *
- * \param[in] buffer    The buffer from which to load the representation.
- * \param[in] size      The size in bytes of \p buffer.
- * \param[in] type      The type of key contained in the \p buffer.
- * \param[out] p_ecp    Returns a pointer to an ECP context on success.
- *                      The caller is responsible for freeing both the
- *                      contents of the context and the context itself
- *                      when done.
+ * \param[in] type          The type of key contained in \p data.
+ * \param[in] data          The buffer from which to load the representation.
+ * \param[in] data_length   The size in bytes of \p data.
+ * \param[out] p_ecp        Returns a pointer to an ECP context on success.
+ *                          The caller is responsible for freeing both the
+ *                          contents of the context and the context itself
+ *                          when done.
  */
-static psa_status_t psa_load_ecp_representation( const uint8_t *buffer,
-                                                 size_t size,
-                                                 psa_key_type_t type,
+static psa_status_t psa_load_ecp_representation( psa_key_type_t type,
+                                                 const uint8_t *data,
+                                                 size_t data_length,
                                                  mbedtls_ecp_keypair **p_ecp )
 {
     mbedtls_ecp_group_id grp_id = MBEDTLS_ECP_DP_NONE;
     psa_status_t status;
     mbedtls_ecp_keypair *ecp = NULL;
-    size_t curve_size;
+    size_t curve_size = data_length;
 
-    if( PSA_KEY_TYPE_IS_PUBLIC_KEY( type ) )
+    if( PSA_KEY_TYPE_IS_PUBLIC_KEY( type ) &&
+        PSA_KEY_TYPE_ECC_GET_FAMILY( type ) != PSA_ECC_FAMILY_MONTGOMERY )
     {
-        if( PSA_KEY_TYPE_ECC_GET_FAMILY( type ) == PSA_ECC_FAMILY_MONTGOMERY )
-        {
-            /* A Montgomery public key is represented as its raw
-             * compressed public point.
-             */
-            curve_size = size;
-        }
-        else
-        {
-            /* A Weierstrass public key is represented as:
-             * - The byte 0x04;
-             * - `x_P` as a `ceiling(m/8)`-byte string, big-endian;
-             * - `y_P` as a `ceiling(m/8)`-byte string, big-endian.
-             * So its data length is 2m+1 where n is the key size in bits.
-             */
-            if( ( size & 1 ) == 0 )
-                return( PSA_ERROR_INVALID_ARGUMENT );
-            curve_size = size / 2;
-        }
-    }
-    else
-    {
-        /* Private keys are represented as the raw private value */
-        curve_size = size;
+        /* A Weierstrass public key is represented as:
+         * - The byte 0x04;
+         * - `x_P` as a `ceiling(m/8)`-byte string, big-endian;
+         * - `y_P` as a `ceiling(m/8)`-byte string, big-endian.
+         * So its data length is 2m+1 where n is the key size in bits.
+         */
+        if( ( data_length & 1 ) == 0 )
+            return( PSA_ERROR_INVALID_ARGUMENT );
+        curve_size = data_length / 2;
+
+        /* Montgomery public keys are represented in compressed format, meaning
+         * their curve_size is equal to the amount of input. */
+
+        /* Private keys are represented in uncompressed private random integer
+         * format, meaning their curve_size is equal to the amount of input. */
     }
 
     /* Allocate and initialize a key representation. */
@@ -788,8 +782,8 @@ static psa_status_t psa_load_ecp_representation( const uint8_t *buffer,
         /* Load the public value. */
         status = mbedtls_to_psa_error(
             mbedtls_ecp_point_read_binary( &ecp->grp, &ecp->Q,
-                                           buffer,
-                                           size ) );
+                                           data,
+                                           data_length ) );
         if( status != PSA_SUCCESS )
             goto exit;
 
@@ -805,8 +799,8 @@ static psa_status_t psa_load_ecp_representation( const uint8_t *buffer,
         status = mbedtls_to_psa_error(
             mbedtls_ecp_read_key( ecp->grp.id,
                                   ecp,
-                                  buffer,
-                                  size ) );
+                                  data,
+                                  data_length ) );
 
         if( status != PSA_SUCCESS )
             goto exit;
@@ -823,6 +817,14 @@ exit:
     return status;
 }
 
+/** Export an ECP key to export representation
+ *
+ * \param[in] type          The type of key (public/private) to export
+ * \param[in] ecp           The internal ECP representation from which to export
+ * \param[out] data         The buffer to export to
+ * \param[in] data_size     The length of the buffer to export to
+ * \param[out] data_length  The amount of bytes written to \p data
+ */
 static psa_status_t psa_export_ecp_key( psa_key_type_t type,
                                         mbedtls_ecp_keypair *ecp,
                                         uint8_t *data,
@@ -844,12 +846,17 @@ static psa_status_t psa_export_ecp_key( psa_key_type_t type,
                 return status;
         }
 
-        return( mbedtls_to_psa_error(
+        status = mbedtls_to_psa_error(
                     mbedtls_ecp_point_write_binary( &ecp->grp, &ecp->Q,
                                                     MBEDTLS_ECP_PF_UNCOMPRESSED,
                                                     data_length,
                                                     data,
-                                                    data_size ) ) );
+                                                    data_size ) );
+
+        if( status != PSA_SUCCESS )
+            memset( data, 0, data_size );
+
+        return status;
     }
     else
     {
@@ -869,6 +876,12 @@ static psa_status_t psa_export_ecp_key( psa_key_type_t type,
     }
 }
 
+/** Import an ECP key from import representation to a slot
+ *
+ * \param[in,out] slot      The slot where to store the export representation to
+ * \param[in] data          The buffer containing the import representation
+ * \param[in] data_length   The amount of bytes in \p data
+ */
 static psa_status_t psa_import_ecp_key( psa_key_slot_t *slot,
                                         const uint8_t *data,
                                         size_t data_length )
@@ -878,9 +891,9 @@ static psa_status_t psa_import_ecp_key( psa_key_slot_t *slot,
     mbedtls_ecp_keypair *ecp = NULL;
 
     /* Parse input */
-    status = psa_load_ecp_representation( data,
+    status = psa_load_ecp_representation( slot->attr.type,
+                                          data,
                                           data_length,
-                                          slot->attr.type,
                                           &ecp );
     if( status != PSA_SUCCESS )
         goto exit;
@@ -916,8 +929,6 @@ exit:
     if( status != PSA_SUCCESS )
     {
         mbedtls_free( output );
-        slot->data.key.data = NULL;
-        slot->data.key.bytes = 0;
         return( status );
     }
 
@@ -1193,10 +1204,6 @@ static psa_status_t psa_get_transparent_key( psa_key_handle_t handle,
 /** Wipe key data from a slot. Preserve metadata such as the policy. */
 static psa_status_t psa_remove_key_data_from_memory( psa_key_slot_t *slot )
 {
-    /* Check whether key is already clean */
-    if( slot->data.key.data == NULL )
-        return PSA_SUCCESS;
-
 #if defined(MBEDTLS_PSA_CRYPTO_SE_C)
     if( psa_key_slot_is_external( slot ) )
     {
@@ -1459,9 +1466,9 @@ psa_status_t psa_get_key_attributes( psa_key_handle_t handle,
             {
                 mbedtls_rsa_context *rsa = NULL;
 
-                status = psa_load_rsa_representation( slot->data.key.data,
+                status = psa_load_rsa_representation( slot->attr.type,
+                                                      slot->data.key.data,
                                                       slot->data.key.bytes,
-                                                      slot->attr.type,
                                                       &rsa );
                 if( status != PSA_SUCCESS )
                     break;
@@ -1575,9 +1582,9 @@ static psa_status_t psa_internal_export_key( const psa_key_slot_t *slot,
 #if defined(MBEDTLS_RSA_C)
             mbedtls_rsa_context *rsa = NULL;
             psa_status_t status = psa_load_rsa_representation(
+                                    slot->attr.type,
                                     slot->data.key.data,
                                     slot->data.key.bytes,
-                                    slot->attr.type,
                                     &rsa );
             if( status != PSA_SUCCESS )
                 return status;
@@ -1602,9 +1609,9 @@ static psa_status_t psa_internal_export_key( const psa_key_slot_t *slot,
 #if defined(MBEDTLS_ECP_C)
             mbedtls_ecp_keypair *ecp = NULL;
             psa_status_t status = psa_load_ecp_representation(
+                                    slot->attr.type,
                                     slot->data.key.data,
                                     slot->data.key.bytes,
-                                    slot->attr.type,
                                     &ecp );
             if( status != PSA_SUCCESS )
                 return status;
@@ -2034,9 +2041,9 @@ static psa_status_t psa_validate_optional_attributes(
             int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
             psa_status_t status = psa_load_rsa_representation(
+                                    slot->attr.type,
                                     slot->data.key.data,
                                     slot->data.key.bytes,
-                                    slot->attr.type,
                                     &rsa );
             if( status != PSA_SUCCESS )
                 return status;
@@ -3693,9 +3700,9 @@ psa_status_t psa_sign_hash( psa_key_handle_t handle,
     {
         mbedtls_rsa_context *rsa = NULL;
 
-        status = psa_load_rsa_representation( slot->data.key.data,
+        status = psa_load_rsa_representation( slot->attr.type,
+                                              slot->data.key.data,
                                               slot->data.key.bytes,
-                                              slot->attr.type,
                                               &rsa );
         if( status != PSA_SUCCESS )
             goto exit;
@@ -3724,9 +3731,9 @@ psa_status_t psa_sign_hash( psa_key_handle_t handle,
             )
         {
             mbedtls_ecp_keypair *ecp = NULL;
-            status = psa_load_ecp_representation( slot->data.key.data,
+            status = psa_load_ecp_representation( slot->attr.type,
+                                                  slot->data.key.data,
                                                   slot->data.key.bytes,
-                                                  slot->attr.type,
                                                   &ecp );
             if( status != PSA_SUCCESS )
                 goto exit;
@@ -3802,9 +3809,9 @@ psa_status_t psa_verify_hash( psa_key_handle_t handle,
     {
         mbedtls_rsa_context *rsa = NULL;
 
-        status = psa_load_rsa_representation( slot->data.key.data,
+        status = psa_load_rsa_representation( slot->attr.type,
+                                              slot->data.key.data,
                                               slot->data.key.bytes,
-                                              slot->attr.type,
                                               &rsa );
         if( status != PSA_SUCCESS )
             return status;
@@ -3826,9 +3833,9 @@ psa_status_t psa_verify_hash( psa_key_handle_t handle,
         if( PSA_ALG_IS_ECDSA( alg ) )
         {
             mbedtls_ecp_keypair *ecp = NULL;
-            status = psa_load_ecp_representation( slot->data.key.data,
+            status = psa_load_ecp_representation( slot->attr.type,
+                                                  slot->data.key.data,
                                                   slot->data.key.bytes,
-                                                  slot->attr.type,
                                                   &ecp );
             if( status != PSA_SUCCESS )
                 return status;
@@ -3898,31 +3905,29 @@ psa_status_t psa_asymmetric_encrypt( psa_key_handle_t handle,
     if( PSA_KEY_TYPE_IS_RSA( slot->attr.type ) )
     {
         mbedtls_rsa_context *rsa = NULL;
-        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-
-        status = psa_load_rsa_representation( slot->data.key.data,
+        status = psa_load_rsa_representation( slot->attr.type,
+                                              slot->data.key.data,
                                               slot->data.key.bytes,
-                                              slot->attr.type,
                                               &rsa );
         if( status != PSA_SUCCESS )
-            return status;
+            goto rsa_exit;
 
         if( output_size < mbedtls_rsa_get_len( rsa ) )
         {
-            mbedtls_rsa_free( rsa );
-            mbedtls_free( rsa );
-            return( PSA_ERROR_BUFFER_TOO_SMALL );
+            status = PSA_ERROR_BUFFER_TOO_SMALL;
+            goto rsa_exit;
         }
 #if defined(MBEDTLS_PKCS1_V15)
         if( alg == PSA_ALG_RSA_PKCS1V15_CRYPT )
         {
-            ret = mbedtls_rsa_pkcs1_encrypt( rsa,
-                                             mbedtls_ctr_drbg_random,
-                                             &global_data.ctr_drbg,
-                                             MBEDTLS_RSA_PUBLIC,
-                                             input_length,
-                                             input,
-                                             output );
+            status = mbedtls_to_psa_error(
+                    mbedtls_rsa_pkcs1_encrypt( rsa,
+                                               mbedtls_ctr_drbg_random,
+                                               &global_data.ctr_drbg,
+                                               MBEDTLS_RSA_PUBLIC,
+                                               input_length,
+                                               input,
+                                               output ) );
         }
         else
 #endif /* MBEDTLS_PKCS1_V15 */
@@ -3930,28 +3935,29 @@ psa_status_t psa_asymmetric_encrypt( psa_key_handle_t handle,
         if( PSA_ALG_IS_RSA_OAEP( alg ) )
         {
             psa_rsa_oaep_set_padding_mode( alg, rsa );
-            ret = mbedtls_rsa_rsaes_oaep_encrypt( rsa,
-                                                  mbedtls_ctr_drbg_random,
-                                                  &global_data.ctr_drbg,
-                                                  MBEDTLS_RSA_PUBLIC,
-                                                  salt, salt_length,
-                                                  input_length,
-                                                  input,
-                                                  output );
+            status = mbedtls_to_psa_error(
+                mbedtls_rsa_rsaes_oaep_encrypt( rsa,
+                                                mbedtls_ctr_drbg_random,
+                                                &global_data.ctr_drbg,
+                                                MBEDTLS_RSA_PUBLIC,
+                                                salt, salt_length,
+                                                input_length,
+                                                input,
+                                                output ) );
         }
         else
 #endif /* MBEDTLS_PKCS1_V21 */
         {
-            mbedtls_rsa_free( rsa );
-            mbedtls_free( rsa );
-            return( PSA_ERROR_INVALID_ARGUMENT );
+            status = PSA_ERROR_INVALID_ARGUMENT;
+            goto rsa_exit;
         }
-        if( ret == 0 )
+rsa_exit:
+        if( status == PSA_SUCCESS )
             *output_length = mbedtls_rsa_get_len( rsa );
 
         mbedtls_rsa_free( rsa );
         mbedtls_free( rsa );
-        return( mbedtls_to_psa_error( ret ) );
+        return( status );
     }
     else
 #endif /* defined(MBEDTLS_RSA_C) */
@@ -3993,34 +3999,32 @@ psa_status_t psa_asymmetric_decrypt( psa_key_handle_t handle,
 #if defined(MBEDTLS_RSA_C)
     if( slot->attr.type == PSA_KEY_TYPE_RSA_KEY_PAIR )
     {
-        mbedtls_rsa_context *rsa;
-        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-
-        status = psa_load_rsa_representation( slot->data.key.data,
+        mbedtls_rsa_context *rsa = NULL;
+        status = psa_load_rsa_representation( slot->attr.type,
+                                              slot->data.key.data,
                                               slot->data.key.bytes,
-                                              slot->attr.type,
                                               &rsa );
         if( status != PSA_SUCCESS )
             return status;
 
         if( input_length != mbedtls_rsa_get_len( rsa ) )
         {
-            mbedtls_rsa_free( rsa );
-            mbedtls_free( rsa );
-            return( PSA_ERROR_INVALID_ARGUMENT );
+            status = PSA_ERROR_INVALID_ARGUMENT;
+            goto rsa_exit;
         }
 
 #if defined(MBEDTLS_PKCS1_V15)
         if( alg == PSA_ALG_RSA_PKCS1V15_CRYPT )
         {
-            ret = mbedtls_rsa_pkcs1_decrypt( rsa,
-                                             mbedtls_ctr_drbg_random,
-                                             &global_data.ctr_drbg,
-                                             MBEDTLS_RSA_PRIVATE,
-                                             output_length,
-                                             input,
-                                             output,
-                                             output_size );
+            status = mbedtls_to_psa_error(
+                mbedtls_rsa_pkcs1_decrypt( rsa,
+                                           mbedtls_ctr_drbg_random,
+                                           &global_data.ctr_drbg,
+                                           MBEDTLS_RSA_PRIVATE,
+                                           output_length,
+                                           input,
+                                           output,
+                                           output_size ) );
         }
         else
 #endif /* MBEDTLS_PKCS1_V15 */
@@ -4028,27 +4032,27 @@ psa_status_t psa_asymmetric_decrypt( psa_key_handle_t handle,
         if( PSA_ALG_IS_RSA_OAEP( alg ) )
         {
             psa_rsa_oaep_set_padding_mode( alg, rsa );
-            ret = mbedtls_rsa_rsaes_oaep_decrypt( rsa,
-                                                  mbedtls_ctr_drbg_random,
-                                                  &global_data.ctr_drbg,
-                                                  MBEDTLS_RSA_PRIVATE,
-                                                  salt, salt_length,
-                                                  output_length,
-                                                  input,
-                                                  output,
-                                                  output_size );
+            status = mbedtls_to_psa_error(
+                mbedtls_rsa_rsaes_oaep_decrypt( rsa,
+                                                mbedtls_ctr_drbg_random,
+                                                &global_data.ctr_drbg,
+                                                MBEDTLS_RSA_PRIVATE,
+                                                salt, salt_length,
+                                                output_length,
+                                                input,
+                                                output,
+                                                output_size ) );
         }
         else
 #endif /* MBEDTLS_PKCS1_V21 */
         {
-            mbedtls_rsa_free( rsa );
-            mbedtls_free( rsa );
-            return( PSA_ERROR_INVALID_ARGUMENT );
+            status = PSA_ERROR_INVALID_ARGUMENT;
         }
 
+rsa_exit:
         mbedtls_rsa_free( rsa );
         mbedtls_free( rsa );
-        return( mbedtls_to_psa_error( ret ) );
+        return( status );
     }
     else
 #endif /* defined(MBEDTLS_RSA_C) */
@@ -5605,9 +5609,9 @@ static psa_status_t psa_key_agreement_ecdh( const uint8_t *peer_key,
     psa_ecc_family_t curve = mbedtls_ecc_group_to_psa( our_key->grp.id, &bits );
     mbedtls_ecdh_init( &ecdh );
 
-    status = psa_load_ecp_representation( peer_key,
+    status = psa_load_ecp_representation( PSA_KEY_TYPE_ECC_PUBLIC_KEY(curve),
+                                          peer_key,
                                           peer_key_length,
-                                          PSA_KEY_TYPE_ECC_PUBLIC_KEY(curve),
                                           &their_key );
     if( status != PSA_SUCCESS )
         goto exit;
@@ -5661,9 +5665,9 @@ static psa_status_t psa_key_agreement_raw_internal( psa_algorithm_t alg,
                 return( PSA_ERROR_INVALID_ARGUMENT );
             mbedtls_ecp_keypair *ecp = NULL;
             psa_status_t status = psa_load_ecp_representation(
+                                    private_key->attr.type,
                                     private_key->data.key.data,
                                     private_key->data.key.bytes,
-                                    private_key->attr.type,
                                     &ecp );
             if( status != PSA_SUCCESS )
                 return status;