diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index 1cfb606c9..b5a033d8e 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -1974,6 +1974,23 @@ static inline int safer_memcmp( const void *a, const void *b, size_t n ) return( diff ); } +/* + * Temporary function while transitionning away from memmove() + * on received DTLS handshake messages + */ +static inline void ssl_hs_rm_dtls_hdr( ssl_context *ssl ) +{ +#if defined(POLARSSL_SSL_PROTO_DTLS) + if( ssl->transport == SSL_TRANSPORT_DATAGRAM ) + { + memmove( ssl->in_msg + 4, ssl->in_msg + 12, ssl->in_hslen - 12 ); + ssl->in_hslen -= 8; + } +#else + (void) ssl; +#endif +} + #ifdef __cplusplus } #endif diff --git a/library/ssl_cli.c b/library/ssl_cli.c index 12a8ff5dc..22787fd33 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -1069,6 +1069,8 @@ static int ssl_parse_server_hello( ssl_context *ssl ) return( POLARSSL_ERR_SSL_UNEXPECTED_MESSAGE ); } + ssl_hs_rm_dtls_hdr( ssl ); + #if defined(POLARSSL_SSL_PROTO_DTLS) if( ssl->transport == SSL_TRANSPORT_DATAGRAM ) { @@ -1774,6 +1776,8 @@ static int ssl_parse_server_key_exchange( ssl_context *ssl ) return( ret ); } + ssl_hs_rm_dtls_hdr( ssl ); + if( ssl->in_msgtype != SSL_MSG_HANDSHAKE ) { SSL_DEBUG_MSG( 1, ( "bad server key exchange message" ) ); @@ -2098,6 +2102,8 @@ static int ssl_parse_certificate_request( ssl_context *ssl ) return( ret ); } + ssl_hs_rm_dtls_hdr( ssl ); + if( ssl->in_msgtype != SSL_MSG_HANDSHAKE ) { SSL_DEBUG_MSG( 1, ( "bad certificate request message" ) ); @@ -2222,6 +2228,8 @@ static int ssl_parse_server_hello_done( ssl_context *ssl ) return( ret ); } + ssl_hs_rm_dtls_hdr( ssl ); + if( ssl->in_msgtype != SSL_MSG_HANDSHAKE ) { SSL_DEBUG_MSG( 1, ( "bad server hello done message" ) ); @@ -2648,6 +2656,8 @@ static int ssl_parse_new_session_ticket( ssl_context *ssl ) return( ret ); } + ssl_hs_rm_dtls_hdr( ssl ); + if( ssl->in_msgtype != SSL_MSG_HANDSHAKE ) { SSL_DEBUG_MSG( 1, ( "bad new session ticket message" ) ); diff --git a/library/ssl_srv.c b/library/ssl_srv.c index a1c2c4df5..81e3d0d0d 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -3018,6 +3018,8 @@ static int ssl_parse_client_key_exchange( ssl_context *ssl ) return( ret ); } + ssl_hs_rm_dtls_hdr( ssl ); + if( ssl->in_msgtype != SSL_MSG_HANDSHAKE ) { SSL_DEBUG_MSG( 1, ( "bad client key exchange message" ) ); @@ -3310,6 +3312,8 @@ static int ssl_parse_certificate_verify( ssl_context *ssl ) return( ret ); } + ssl_hs_rm_dtls_hdr( ssl ); + ssl->state++; if( ssl->in_msgtype != SSL_MSG_HANDSHAKE ) diff --git a/library/ssl_tls.c b/library/ssl_tls.c index b8946cda0..08b26b313 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -2408,22 +2408,6 @@ static int ssl_prepare_handshake_record( ssl_context *ssl ) } #endif - /* - * For DTLS, we move data so that is looks like TLS handshake format to - * other functions. - * Except on server after the initial handshake (wait until after - * update_checksum() in ssl_parse_client_hello()). - */ -#if defined(POLARSSL_SSL_PROTO_DTLS) - if( ssl->transport == SSL_TRANSPORT_DATAGRAM && - ! ( ssl->endpoint == SSL_IS_SERVER && - ssl->state == SSL_HANDSHAKE_OVER ) ) - { - memmove( ssl->in_msg + 4, ssl->in_msg + 12, ssl->in_hslen - 12 ); - ssl->in_hslen -= 8; - } -#endif /* POLARSSL_SSL_PROTO_DTLS */ - return( 0 ); } @@ -2625,16 +2609,8 @@ int ssl_read_record( ssl_context *ssl ) SSL_DEBUG_MSG( 2, ( "=> read record" ) ); - /* - * With DTLS, we cheated on in_hslen to make the handshake message look - * like TLS format, restore the truth now - */ -#if defined(POLARSSL_SSL_PROTO_DTLS) - if( ssl->in_hslen != 0 && ssl->transport == SSL_TRANSPORT_DATAGRAM ) - ssl->in_hslen += 8; -#endif - - if( ssl->in_hslen != 0 && ssl->in_hslen < ssl->in_msglen ) + /* Temporarily disabled */ + if( ( 0 ) && ssl->in_hslen != 0 && ssl->in_hslen < ssl->in_msglen ) { /* * Get next Handshake message in the current record @@ -3001,6 +2977,8 @@ int ssl_parse_certificate( ssl_context *ssl ) return( ret ); } + ssl_hs_rm_dtls_hdr( ssl ); + ssl->state++; #if defined(POLARSSL_SSL_PROTO_SSL3) @@ -3813,6 +3791,8 @@ int ssl_parse_finished( ssl_context *ssl ) return( ret ); } + ssl_hs_rm_dtls_hdr( ssl ); + if( ssl->in_msgtype != SSL_MSG_HANDSHAKE ) { SSL_DEBUG_MSG( 1, ( "bad finished message" ) ); @@ -5041,7 +5021,7 @@ int ssl_read( ssl_context *ssl, unsigned char *buf, size_t len ) if( ssl->endpoint == SSL_IS_CLIENT && ( ssl->in_msg[0] != SSL_HS_HELLO_REQUEST || - ssl->in_hslen != 4 ) ) + ssl->in_hslen != ssl_hs_hdr_len( ssl ) ) ) { SSL_DEBUG_MSG( 1, ( "handshake received (not HelloRequest)" ) );