diff --git a/library/psa_crypto.c b/library/psa_crypto.c index 093d2e568..b64662906 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -4727,6 +4727,22 @@ static psa_status_t psa_tls12_prf_input( psa_tls12_prf_key_derivation_t *prf, return( PSA_ERROR_INVALID_ARGUMENT ); } + +static psa_status_t psa_tls12_prf_psk_to_ms_input( + psa_tls12_prf_key_derivation_t *prf, + psa_algorithm_t hash_alg, + psa_key_derivation_step_t step, + const uint8_t *data, + size_t data_length ) +{ + (void) prf; + (void) hash_alg; + (void) step; + (void) data; + (void) data_length; + + return( PSA_ERROR_INVALID_ARGUMENT ); +} #else static psa_status_t psa_tls12_prf_set_seed( psa_tls12_prf_key_derivation_t *prf, const uint8_t *data, @@ -4765,6 +4781,38 @@ static psa_status_t psa_tls12_prf_set_key( psa_tls12_prf_key_derivation_t *prf, return( PSA_SUCCESS ); } +static psa_status_t psa_tls12_prf_psk_to_ms_set_key( + psa_tls12_prf_key_derivation_t *prf, + psa_algorithm_t hash_alg, + const uint8_t *data, + size_t data_length ) +{ + psa_status_t status; + unsigned char pms[ 4 + 2 * PSA_ALG_TLS12_PSK_TO_MS_MAX_PSK_LEN ]; + + if( data_length > PSA_ALG_TLS12_PSK_TO_MS_MAX_PSK_LEN ) + return( PSA_ERROR_INVALID_ARGUMENT ); + + /* Quoting RFC 4279, Section 2: + * + * The premaster secret is formed as follows: if the PSK is N octets + * long, concatenate a uint16 with the value N, N zero octets, a second + * uint16 with the value N, and the PSK itself. + */ + + pms[0] = ( data_length >> 8 ) & 0xff; + pms[1] = ( data_length >> 0 ) & 0xff; + memset( pms + 2, 0, data_length ); + pms[2 + data_length + 0] = pms[0]; + pms[2 + data_length + 1] = pms[1]; + memcpy( pms + 4 + data_length, data, data_length ); + + status = psa_tls12_prf_set_key( prf, hash_alg, pms, 4 + 2 * data_length ); + + mbedtls_platform_zeroize( pms, sizeof( pms ) ); + return( status ); +} + static psa_status_t psa_tls12_prf_set_label( psa_tls12_prf_key_derivation_t *prf, const uint8_t *data, size_t data_length ) @@ -4802,6 +4850,20 @@ static psa_status_t psa_tls12_prf_input( psa_tls12_prf_key_derivation_t *prf, return( PSA_ERROR_INVALID_ARGUMENT ); } } + +static psa_status_t psa_tls12_prf_psk_to_ms_input( + psa_tls12_prf_key_derivation_t *prf, + psa_algorithm_t hash_alg, + psa_key_derivation_step_t step, + const uint8_t *data, + size_t data_length ) +{ + if( step == PSA_KEY_DERIVATION_INPUT_SECRET ) + return( psa_tls12_prf_psk_to_ms_set_key( prf, hash_alg, + data, data_length ) ); + + return( psa_tls12_prf_input( prf, hash_alg, step, data, data_length ) ); +} #endif /* PSA_PRE_1_0_KEY_DERIVATION */ #endif /* MBEDTLS_MD_C */ @@ -4824,15 +4886,17 @@ static psa_status_t psa_key_derivation_input_internal( else #endif /* MBEDTLS_MD_C */ #if defined(MBEDTLS_MD_C) - /* TLS-1.2 PRF and TLS-1.2 PSK-to-MS are very similar, so share code. */ - if( PSA_ALG_IS_TLS12_PRF( kdf_alg ) || - PSA_ALG_IS_TLS12_PSK_TO_MS( kdf_alg ) ) + if( PSA_ALG_IS_TLS12_PRF( kdf_alg ) ) { - // To do: implement this status = psa_tls12_prf_input( &operation->ctx.tls12_prf, PSA_ALG_HKDF_GET_HASH( kdf_alg ), step, data, data_length ); - + } + else if( PSA_ALG_IS_TLS12_PSK_TO_MS( kdf_alg ) ) + { + status = psa_tls12_prf_psk_to_ms_input( &operation->ctx.tls12_prf, + PSA_ALG_HKDF_GET_HASH( kdf_alg ), + step, data, data_length ); } else #endif /* MBEDTLS_MD_C */