diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 82e65251f..2150c03d1 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -7326,6 +7326,37 @@ int mbedtls_ssl_set_hs_ecjpake_password( mbedtls_ssl_context *ssl, #endif /* MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */ #if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED) + +static void ssl_conf_remove_psk( mbedtls_ssl_config *conf ) +{ + /* Remove reference to existing PSK, if any. */ +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( conf->psk_opaque != 0 ) + { + /* The maintenance of the PSK key slot is the + * user's responsibility. */ + conf->psk_opaque = 0; + } + else +#endif /* MBEDTLS_USE_PSA_CRYPTO */ + if( conf->psk != NULL ) + { + mbedtls_platform_zeroize( conf->psk, conf->psk_len ); + + mbedtls_free( conf->psk ); + conf->psk = NULL; + conf->psk_len = 0; + } + + /* Remove reference to PSK identity, if any. */ + if( conf->psk_identity != NULL ) + { + mbedtls_free( conf->psk_identity ); + conf->psk_identity = NULL; + conf->psk_identity_len = 0; + } +} + int mbedtls_ssl_conf_psk( mbedtls_ssl_config *conf, const unsigned char *psk, size_t psk_len, const unsigned char *psk_identity, size_t psk_identity_len ) @@ -7343,20 +7374,7 @@ int mbedtls_ssl_conf_psk( mbedtls_ssl_config *conf, return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); } - if( conf->psk != NULL ) - { - mbedtls_platform_zeroize( conf->psk, conf->psk_len ); - - mbedtls_free( conf->psk ); - conf->psk = NULL; - conf->psk_len = 0; - } - if( conf->psk_identity != NULL ) - { - mbedtls_free( conf->psk_identity ); - conf->psk_identity = NULL; - conf->psk_identity_len = 0; - } + ssl_conf_remove_psk( conf ); if( ( conf->psk = mbedtls_calloc( 1, psk_len ) ) == NULL || ( conf->psk_identity = mbedtls_calloc( 1, psk_identity_len ) ) == NULL ) @@ -7377,6 +7395,24 @@ int mbedtls_ssl_conf_psk( mbedtls_ssl_config *conf, return( 0 ); } +static void ssl_remove_psk( mbedtls_ssl_context *ssl ) +{ +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( ssl->handshake->psk_opaque != 0 ) + { + ssl->handshake->psk_opaque = 0; + } + else +#endif /* MBEDTLS_USE_PSA_CRYPTO */ + if( ssl->handshake->psk != NULL ) + { + mbedtls_platform_zeroize( ssl->handshake->psk, + ssl->handshake->psk_len ); + mbedtls_free( ssl->handshake->psk ); + ssl->handshake->psk_len = 0; + } +} + int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, size_t psk_len ) { @@ -7386,13 +7422,7 @@ int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl, if( psk_len > MBEDTLS_PSK_MAX_LEN ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - if( ssl->handshake->psk != NULL ) - { - mbedtls_platform_zeroize( ssl->handshake->psk, - ssl->handshake->psk_len ); - mbedtls_free( ssl->handshake->psk ); - ssl->handshake->psk_len = 0; - } + ssl_remove_psk( ssl ); if( ( ssl->handshake->psk = mbedtls_calloc( 1, psk_len ) ) == NULL ) return( MBEDTLS_ERR_SSL_ALLOC_FAILED ); @@ -7403,6 +7433,50 @@ int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl, return( 0 ); } +#if defined(MBEDTLS_USE_PSA_CRYPTO) +int mbedtls_ssl_conf_psk_opaque( mbedtls_ssl_config *conf, + psa_key_slot_t psk_slot, + const unsigned char *psk_identity, + size_t psk_identity_len ) +{ + if( psk_slot == 0 || psk_identity == NULL ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + /* Identity len will be encoded on two bytes */ + if( ( psk_identity_len >> 16 ) != 0 || + psk_identity_len > MBEDTLS_SSL_OUT_CONTENT_LEN ) + { + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + } + + ssl_conf_remove_psk( conf ); + + if( ( conf->psk_identity = mbedtls_calloc( 1, psk_identity_len ) ) == NULL ) + { + mbedtls_free( conf->psk_identity ); + conf->psk_identity = NULL; + return( MBEDTLS_ERR_SSL_ALLOC_FAILED ); + } + + conf->psk_identity_len = psk_identity_len; + memcpy( conf->psk_identity, psk_identity, conf->psk_identity_len ); + + conf->psk_opaque = psk_slot; + return( 0 ); +} + +int mbedtls_ssl_set_hs_psk_opaque( mbedtls_ssl_context *ssl, + psa_key_slot_t psk_slot ) +{ + if( psk_slot == 0 || ssl->handshake == NULL ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + ssl_remove_psk( ssl ); + ssl->handshake->psk_opaque = psk_slot; + return( 0 ); +} +#endif /* MBEDTLS_USE_PSA_CRYPTO */ + void mbedtls_ssl_conf_psk_cb( mbedtls_ssl_config *conf, int (*f_psk)(void *, mbedtls_ssl_context *, const unsigned char *, size_t),