diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h index 6913dc0f6..f50cf9ff5 100644 --- a/include/mbedtls/ssl_internal.h +++ b/include/mbedtls/ssl_internal.h @@ -1112,6 +1112,7 @@ int mbedtls_ssl_set_calc_verify_md( mbedtls_ssl_context *ssl, int md ); #if defined(MBEDTLS_ECP_C) int mbedtls_ssl_check_curve( const mbedtls_ssl_context *ssl, mbedtls_ecp_group_id grp_id ); +int mbedtls_ssl_check_curve_tls_id( const mbedtls_ssl_context *ssl, uint16_t tls_id ); #endif #if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED) diff --git a/library/ssl_tls.c b/library/ssl_tls.c index c7265f108..bd0eb10ec 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -7326,6 +7326,18 @@ int mbedtls_ssl_check_curve( const mbedtls_ssl_context *ssl, mbedtls_ecp_group_i return( -1 ); } + +/* + * Same as mbedtls_ssl_check_curve() but takes a TLS ID for the curve. + */ +int mbedtls_ssl_check_curve_tls_id( const mbedtls_ssl_context *ssl, uint16_t tls_id ) +{ + const mbedtls_ecp_curve_info *curve_info = + mbedtls_ecp_curve_info_from_tls_id( tls_id ); + if( curve_info == NULL ) + return( -1 ); + return( mbedtls_ssl_check_curve( ssl, curve_info->grp_id ) ); +} #endif /* MBEDTLS_ECP_C */ #if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)