diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h index fbf3e70e8..660173401 100644 --- a/include/mbedtls/ssl_internal.h +++ b/include/mbedtls/ssl_internal.h @@ -322,6 +322,13 @@ struct mbedtls_ssl_handshake_params unsigned char *data; } hs[MBEDTLS_SSL_MAX_BUFFERED_HS]; + struct + { + unsigned char *data; + size_t len; + unsigned epoch; + } future_record; + } buffering; #endif /* MBEDTLS_SSL_PROTO_DTLS */ diff --git a/library/ssl_tls.c b/library/ssl_tls.c index b6e2c0edb..85ed1e51c 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -4097,7 +4097,16 @@ static int ssl_parse_record_header( mbedtls_ssl_context *ssl ) } else #endif /* MBEDTLS_SSL_DTLS_CLIENT_PORT_REUSE && MBEDTLS_SSL_SRV_C */ + { + /* Consider buffering the record. */ + if( rec_epoch == (unsigned int) ssl->in_epoch + 1 ) + { + MBEDTLS_SSL_DEBUG_MSG( 2, ( "Consider record for buffering" ) ); + return( MBEDTLS_ERR_SSL_EARLY_MESSAGE ); + } + return( MBEDTLS_ERR_SSL_UNEXPECTED_RECORD ); + } } #if defined(MBEDTLS_SSL_DTLS_ANTI_REPLAY) @@ -4254,7 +4263,9 @@ static int ssl_record_is_in_progress( mbedtls_ssl_context *ssl ); #if defined(MBEDTLS_SSL_PROTO_DTLS) static int ssl_load_buffered_message( mbedtls_ssl_context *ssl ); +static int ssl_load_buffered_record( mbedtls_ssl_context *ssl ); static int ssl_buffer_message( mbedtls_ssl_context *ssl ); +static int ssl_buffer_future_record( mbedtls_ssl_context *ssl ); static int ssl_another_record_in_datagram( mbedtls_ssl_context *ssl ); #endif /* MBEDTLS_SSL_PROTO_DTLS */ @@ -4689,13 +4700,133 @@ static int ssl_record_is_in_progress( mbedtls_ssl_context *ssl ) return( 0 ); } +#if defined(MBEDTLS_SSL_PROTO_DTLS) + +static void ssl_free_buffered_record( mbedtls_ssl_context *ssl ) +{ + mbedtls_ssl_handshake_params * const hs = ssl->handshake; + if( hs == NULL ) + return; + + mbedtls_free( hs->buffering.future_record.data ); + hs->buffering.future_record.data = NULL; +} + +static int ssl_load_buffered_record( mbedtls_ssl_context *ssl ) +{ + mbedtls_ssl_handshake_params * const hs = ssl->handshake; + unsigned char * rec; + size_t rec_len; + unsigned rec_epoch; + + if( ssl->conf->transport != MBEDTLS_SSL_TRANSPORT_DATAGRAM ) + return( 0 ); + + if( hs == NULL ) + return( 0 ); + + /* Only consider loading future records if the + * input buffer is empty. */ + if( ssl_another_record_in_datagram( ssl ) == 1 ) + return( 0 ); + + rec = hs->buffering.future_record.data; + rec_len = hs->buffering.future_record.len; + rec_epoch = hs->buffering.future_record.epoch; + + if( rec == NULL ) + return( 0 ); + + MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> ssl_load_buffered_record" ) ); + + if( rec_epoch != ssl->in_epoch ) + { + MBEDTLS_SSL_DEBUG_MSG( 2, ( "Buffered record not from current epoch." ) ); + goto exit; + } + + MBEDTLS_SSL_DEBUG_MSG( 2, ( "Found buffered record from current epoch - load" ) ); + + /* Double-check that the record is not too large */ + if( rec_len > MBEDTLS_SSL_IN_BUFFER_LEN - + (size_t)( ssl->in_hdr - ssl->in_buf ) ) + { + MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) ); + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); + } + + memcpy( ssl->in_hdr, rec, rec_len ); + ssl->in_left = rec_len; + ssl->next_record_offset = 0; + + ssl_free_buffered_record( ssl ); + +exit: + MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= ssl_load_buffered_record" ) ); + return( 0 ); +} + +static int ssl_buffer_future_record( mbedtls_ssl_context *ssl ) +{ + mbedtls_ssl_handshake_params * const hs = ssl->handshake; + size_t const rec_hdr_len = 13; + + /* Don't buffer future records outside handshakes. */ + if( hs == NULL ) + return( 0 ); + + /* Only buffer handshake records (we are only interested + * in Finished messages). */ + if( ssl->in_msgtype != MBEDTLS_SSL_MSG_HANDSHAKE ) + return( 0 ); + + /* Don't buffer more than one future epoch record. */ + if( hs->buffering.future_record.data != NULL ) + return( 0 ); + + /* Buffer record */ + MBEDTLS_SSL_DEBUG_MSG( 2, ( "Buffer record from epoch %u", + ssl->in_epoch + 1 ) ); + MBEDTLS_SSL_DEBUG_BUF( 3, "Buffered record", ssl->in_hdr, + rec_hdr_len + ssl->in_msglen ); + + /* ssl_parse_record_header() only considers records + * of the next epoch as candidates for buffering. */ + hs->buffering.future_record.epoch = ssl->in_epoch + 1; + hs->buffering.future_record.len = rec_hdr_len + ssl->in_msglen; + + hs->buffering.future_record.data = + mbedtls_calloc( 1, hs->buffering.future_record.len ); + if( hs->buffering.future_record.data == NULL ) + { + /* If we run out of RAM trying to buffer a + * record from the next epoch, just ignore. */ + return( 0 ); + } + + memcpy( hs->buffering.future_record.data, + ssl->in_hdr, rec_hdr_len + ssl->in_msglen ); + + return( 0 ); +} + +#endif /* MBEDTLS_SSL_PROTO_DTLS */ + static int ssl_get_next_record( mbedtls_ssl_context *ssl ) { int ret; - /* - * Fetch and decode new record - */ +#if defined(MBEDTLS_SSL_PROTO_DTLS) + /* We might have buffered a future record; if so, + * and if the epoch matches now, load it. + * On success, this call will set ssl->in_left to + * the length of the buffered record, so that + * the calls to ssl_fetch_input() below will + * essentially be no-ops. */ + ret = ssl_load_buffered_record( ssl ); + if( ret != 0 ) + return( ret ); +#endif /* MBEDTLS_SSL_PROTO_DTLS */ if( ( ret = mbedtls_ssl_fetch_input( ssl, mbedtls_ssl_hdr_len( ssl ) ) ) != 0 ) { @@ -4709,6 +4840,16 @@ static int ssl_get_next_record( mbedtls_ssl_context *ssl ) if( ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM && ret != MBEDTLS_ERR_SSL_CLIENT_RECONNECT ) { + if( ret == MBEDTLS_ERR_SSL_EARLY_MESSAGE ) + { + ret = ssl_buffer_future_record( ssl ); + if( ret != 0 ) + return( ret ); + + /* Fall through to handling of unexpected records */ + ret = MBEDTLS_ERR_SSL_UNEXPECTED_RECORD; + } + if( ret == MBEDTLS_ERR_SSL_UNEXPECTED_RECORD ) { /* Skip unexpected record (but not whole datagram) */ @@ -8489,6 +8630,7 @@ void mbedtls_ssl_handshake_free( mbedtls_ssl_context *ssl ) mbedtls_free( handshake->verify_cookie ); ssl_flight_free( handshake->flight ); ssl_buffering_free( ssl ); + ssl_free_buffered_record( ssl ); #endif mbedtls_platform_zeroize( handshake,