Simplify retaining of messages for future processing

There are situations in which it is not clear what message to expect
next. For example, the message following the ServerHello might be
either a Certificate, a ServerKeyExchange or a CertificateRequest. We
deal with this situation in the following way: Initially, the message
processing function for one of the allowed message types is called,
which fetches and decodes a new message. If that message is not the
expected one, the function returns successfully (instead of throwing
an error as usual for unexpected messages), and the handshake
continues to the processing function for the next possible message. To
not have this function fetch a new message, a flag in the SSL context
structure is used to indicate that the last message was retained for
further processing, and if that's set, the following processing
function will not fetch a new record.

This commit simplifies the usage of this message-retaining parameter
by doing the check within the record-fetching routine instead of the
specific message-processing routines. The code gets cleaner this way
and allows retaining messages to be used in other situations as well
without much effort. This will be used in the next commits.
This commit is contained in:
Hanno Becker 2017-06-08 13:08:45 +01:00
parent 431c2afe3e
commit 704f493730
3 changed files with 41 additions and 37 deletions

View file

@ -714,7 +714,9 @@ struct mbedtls_ssl_context
size_t in_hslen; /*!< current handshake message length, size_t in_hslen; /*!< current handshake message length,
including the handshake header */ including the handshake header */
int nb_zero; /*!< # of 0-length encrypted messages */ int nb_zero; /*!< # of 0-length encrypted messages */
int record_read; /*!< record is already present */
int keep_current_message; /*!< drop or reuse current message
on next call to record layer? */
/* /*
* Record layer (outgoing data) * Record layer (outgoing data)

View file

@ -2102,11 +2102,14 @@ static int ssl_parse_server_key_exchange( mbedtls_ssl_context *ssl )
if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_PSK || if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_PSK ||
ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_RSA_PSK ) ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_RSA_PSK )
{ {
ssl->record_read = 1; /* Current message is probably either
* CertificateRequest or ServerHelloDone */
ssl->keep_current_message = 1;
goto exit; goto exit;
} }
MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad server key exchange message" ) ); MBEDTLS_SSL_DEBUG_MSG( 1, ( "server key exchange message must "
"not be skipped" ) );
return( MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE ); return( MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE );
} }
@ -2389,36 +2392,30 @@ static int ssl_parse_certificate_request( mbedtls_ssl_context *ssl )
return( 0 ); return( 0 );
} }
if( ssl->record_read == 0 ) if( ( ret = mbedtls_ssl_read_record( ssl ) ) != 0 )
{ {
if( ( ret = mbedtls_ssl_read_record( ssl ) ) != 0 ) MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_read_record", ret );
{ return( ret );
MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_read_record", ret );
return( ret );
}
if( ssl->in_msgtype != MBEDTLS_SSL_MSG_HANDSHAKE )
{
MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad certificate request message" ) );
return( MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE );
}
ssl->record_read = 1;
} }
ssl->client_auth = 0; if( ssl->in_msgtype != MBEDTLS_SSL_MSG_HANDSHAKE )
ssl->state++; {
MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad certificate request message" ) );
return( MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE );
}
if( ssl->in_msg[0] == MBEDTLS_SSL_HS_CERTIFICATE_REQUEST ) ssl->state++;
ssl->client_auth++; ssl->client_auth = ( ssl->in_msg[0] == MBEDTLS_SSL_HS_CERTIFICATE_REQUEST );
MBEDTLS_SSL_DEBUG_MSG( 3, ( "got %s certificate request", MBEDTLS_SSL_DEBUG_MSG( 3, ( "got %s certificate request",
ssl->client_auth ? "a" : "no" ) ); ssl->client_auth ? "a" : "no" ) );
if( ssl->client_auth == 0 ) if( ssl->client_auth == 0 )
{
/* Current message is probably the ServerHelloDone */
ssl->keep_current_message = 1;
goto exit; goto exit;
}
ssl->record_read = 0;
// TODO: handshake_failure alert for an anonymous server to request // TODO: handshake_failure alert for an anonymous server to request
// client authentication // client authentication
@ -2517,21 +2514,17 @@ static int ssl_parse_server_hello_done( mbedtls_ssl_context *ssl )
MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> parse server hello done" ) ); MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> parse server hello done" ) );
if( ssl->record_read == 0 ) if( ( ret = mbedtls_ssl_read_record( ssl ) ) != 0 )
{ {
if( ( ret = mbedtls_ssl_read_record( ssl ) ) != 0 ) MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_read_record", ret );
{ return( ret );
MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_read_record", ret ); }
return( ret );
} if( ssl->in_msgtype != MBEDTLS_SSL_MSG_HANDSHAKE )
{
if( ssl->in_msgtype != MBEDTLS_SSL_MSG_HANDSHAKE ) MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad server hello done message" ) );
{ return( MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE );
MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad server hello done message" ) );
return( MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE );
}
} }
ssl->record_read = 0;
if( ssl->in_hslen != mbedtls_ssl_hs_hdr_len( ssl ) || if( ssl->in_hslen != mbedtls_ssl_hs_hdr_len( ssl ) ||
ssl->in_msg[0] != MBEDTLS_SSL_HS_SERVER_HELLO_DONE ) ssl->in_msg[0] != MBEDTLS_SSL_HS_SERVER_HELLO_DONE )

View file

@ -3716,6 +3716,15 @@ int mbedtls_ssl_read_record( mbedtls_ssl_context *ssl )
MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> read record" ) ); MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> read record" ) );
if( ssl->keep_current_message == 1 )
{
MBEDTLS_SSL_DEBUG_MSG( 2, ( "reuse previously read message" ) );
MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= read record" ) );
ssl->keep_current_message = 0;
return( 0 );
}
if( ssl->in_hslen != 0 && ssl->in_hslen < ssl->in_msglen ) if( ssl->in_hslen != 0 && ssl->in_hslen < ssl->in_msglen )
{ {
/* /*
@ -5452,7 +5461,7 @@ static int ssl_session_reset_int( mbedtls_ssl_context *ssl, int partial )
ssl->in_hslen = 0; ssl->in_hslen = 0;
ssl->nb_zero = 0; ssl->nb_zero = 0;
ssl->record_read = 0; ssl->keep_current_message = 0;
ssl->out_msg = ssl->out_buf + 13; ssl->out_msg = ssl->out_buf + 13;
ssl->out_msgtype = 0; ssl->out_msgtype = 0;