diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 6fa2e3a76..93439697e 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -7450,44 +7450,56 @@ static void ssl_conf_remove_psk( mbedtls_ssl_config *conf ) } } -int mbedtls_ssl_conf_psk( mbedtls_ssl_config *conf, - const unsigned char *psk, size_t psk_len, - const unsigned char *psk_identity, size_t psk_identity_len ) +/* This function assumes that PSK identity in the SSL config is unset. + * It checks that the provided identity is well-formed and attempts + * to make a copy of it in the SSL config. + * On failure, the PSK identity in the config remains unset. */ +static int ssl_conf_set_psk_identity( mbedtls_ssl_config *conf, + unsigned char const *psk_identity, + size_t psk_identity_len ) { - if( psk == NULL || psk_identity == NULL ) - return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - - if( psk_len > MBEDTLS_PSK_MAX_LEN ) - return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - /* Identity len will be encoded on two bytes */ - if( ( psk_identity_len >> 16 ) != 0 || + if( psk_identity == NULL || + ( psk_identity_len >> 16 ) != 0 || psk_identity_len > MBEDTLS_SSL_OUT_CONTENT_LEN ) { return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); } - ssl_conf_remove_psk( conf ); - - if( ( conf->psk = mbedtls_calloc( 1, psk_len ) ) == NULL || - ( conf->psk_identity = mbedtls_calloc( 1, psk_identity_len ) ) == NULL ) - { - mbedtls_free( conf->psk ); - mbedtls_free( conf->psk_identity ); - conf->psk = NULL; - conf->psk_identity = NULL; + conf->psk_identity = mbedtls_calloc( 1, psk_identity_len ); + if( conf->psk_identity == NULL ) return( MBEDTLS_ERR_SSL_ALLOC_FAILED ); - } - conf->psk_len = psk_len; conf->psk_identity_len = psk_identity_len; - - memcpy( conf->psk, psk, conf->psk_len ); memcpy( conf->psk_identity, psk_identity, conf->psk_identity_len ); return( 0 ); } +int mbedtls_ssl_conf_psk( mbedtls_ssl_config *conf, + const unsigned char *psk, size_t psk_len, + const unsigned char *psk_identity, size_t psk_identity_len ) +{ + int ret; + /* Remove opaque/raw PSK + PSK Identity */ + ssl_conf_remove_psk( conf ); + + /* Check and set raw PSK */ + if( psk == NULL || psk_len > MBEDTLS_PSK_MAX_LEN ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + if( ( conf->psk = mbedtls_calloc( 1, psk_len ) ) == NULL ) + return( MBEDTLS_ERR_SSL_ALLOC_FAILED ); + conf->psk_len = psk_len; + memcpy( conf->psk, psk, conf->psk_len ); + + /* Check and set PSK Identity */ + ret = ssl_conf_set_psk_identity( conf, psk_identity, psk_identity_len ); + if( ret != 0 ) + ssl_conf_remove_psk( conf ); + + return( ret ); +} + static void ssl_remove_psk( mbedtls_ssl_context *ssl ) { #if defined(MBEDTLS_USE_PSA_CRYPTO) @@ -7532,30 +7544,22 @@ int mbedtls_ssl_conf_psk_opaque( mbedtls_ssl_config *conf, const unsigned char *psk_identity, size_t psk_identity_len ) { - if( psk_slot == 0 || psk_identity == NULL ) - return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - - /* Identity len will be encoded on two bytes */ - if( ( psk_identity_len >> 16 ) != 0 || - psk_identity_len > MBEDTLS_SSL_OUT_CONTENT_LEN ) - { - return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - } - + int ret; + /* Clear opaque/raw PSK + PSK Identity, if present. */ ssl_conf_remove_psk( conf ); - if( ( conf->psk_identity = mbedtls_calloc( 1, psk_identity_len ) ) == NULL ) - { - mbedtls_free( conf->psk_identity ); - conf->psk_identity = NULL; - return( MBEDTLS_ERR_SSL_ALLOC_FAILED ); - } - - conf->psk_identity_len = psk_identity_len; - memcpy( conf->psk_identity, psk_identity, conf->psk_identity_len ); - + /* Check and set opaque PSK */ + if( psk_slot == 0 ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); conf->psk_opaque = psk_slot; - return( 0 ); + + /* Check and set PSK Identity */ + ret = ssl_conf_set_psk_identity( conf, psk_identity, + psk_identity_len ); + if( ret != 0 ) + ssl_conf_remove_psk( conf ); + + return( ret ); } int mbedtls_ssl_set_hs_psk_opaque( mbedtls_ssl_context *ssl,