diff --git a/include/polarssl/ssl_ciphersuites.h b/include/polarssl/ssl_ciphersuites.h index 714cdcdfa..85392c177 100644 --- a/include/polarssl/ssl_ciphersuites.h +++ b/include/polarssl/ssl_ciphersuites.h @@ -27,6 +27,7 @@ #ifndef POLARSSL_SSL_CIPHERSUITES_H #define POLARSSL_SSL_CIPHERSUITES_H +#include "pk.h" #include "cipher.h" #include "md.h" @@ -197,6 +198,8 @@ const int *ssl_ciphersuites_list( void ); const ssl_ciphersuite_t *ssl_ciphersuite_from_string( const char *ciphersuite_name ); const ssl_ciphersuite_t *ssl_ciphersuite_from_id( int ciphersuite_id ); +pk_type_t ssl_get_ciphersuite_sig_pk_alg( const ssl_ciphersuite_t *info ); + #ifdef __cplusplus } #endif diff --git a/library/ssl_ciphersuites.c b/library/ssl_ciphersuites.c index 63601f66f..759845ee9 100644 --- a/library/ssl_ciphersuites.c +++ b/library/ssl_ciphersuites.c @@ -916,4 +916,20 @@ int ssl_get_ciphersuite_id( const char *ciphersuite_name ) return( cur->id ); } +pk_type_t ssl_get_ciphersuite_sig_pk_alg( const ssl_ciphersuite_t *info ) +{ + switch( info->key_exchange ) + { + case POLARSSL_KEY_EXCHANGE_DHE_RSA: + case POLARSSL_KEY_EXCHANGE_ECDHE_RSA: + return( POLARSSL_PK_RSA ); + + case POLARSSL_KEY_EXCHANGE_ECDHE_ECDSA: + return( POLARSSL_PK_ECDSA ); + + default: + return( POLARSSL_PK_NONE ); + } +} + #endif diff --git a/library/ssl_cli.c b/library/ssl_cli.c index 267e38595..605d4668d 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -1394,6 +1394,19 @@ static int ssl_parse_server_key_exchange( ssl_context *ssl ) return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE ); } + if( pk_alg != POLARSSL_PK_NONE ) + { + if( pk_alg != ssl_get_ciphersuite_sig_pk_alg( ciphersuite_info ) ) + { + SSL_DEBUG_MSG( 1, ( "bad server key exchange message" ) ); + return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE ); + } + } + else + { + pk_alg = ssl_get_ciphersuite_sig_pk_alg( ciphersuite_info ); + } + sig_len = ( p[0] << 8 ) | p[1]; p += 2;