From 6b65141718df81ff48a2370299ab30d14f239d00 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Manuel=20P=C3=A9gouri=C3=A9-Gonnard?= <mpg@elzevir.fr>
Date: Wed, 1 Oct 2014 18:29:03 +0200
Subject: [PATCH] Implement ssl_read() timeout (DTLS only for now)

---
 library/ssl_tls.c          | 53 ++++++++++++++++++++++++++------------
 programs/ssl/ssl_client2.c |  8 +++++-
 programs/ssl/ssl_server2.c |  8 +++++-
 3 files changed, 51 insertions(+), 18 deletions(-)

diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 6b0d11e85..089d17ea7 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1930,6 +1930,8 @@ int ssl_fetch_input( ssl_context *ssl, size_t nb_want )
 #if defined(POLARSSL_SSL_PROTO_DTLS)
     if( ssl->transport == SSL_TRANSPORT_DATAGRAM )
     {
+        uint32_t timeout;
+
         /*
          * The point is, we need to always read a full datagram at once, so we
          * sometimes read more then requested, and handle the additional data.
@@ -1986,12 +1988,16 @@ int ssl_fetch_input( ssl_context *ssl, size_t nb_want )
         }
 
         len = SSL_BUFFER_LEN - ( ssl->in_hdr - ssl->in_buf );
-        if( ssl->f_recv_timeout != NULL &&
-            ssl->handshake != NULL ) /* No timeout outside handshake */
-        {
-            ret = ssl->f_recv_timeout( ssl->p_bio, ssl->in_hdr, len,
-                                       ssl->handshake->retransmit_timeout );
-        }
+
+        if( ssl->state != SSL_HANDSHAKE_OVER )
+            timeout = ssl->handshake->retransmit_timeout;
+        else
+            timeout = ssl->read_timeout;
+
+        SSL_DEBUG_MSG( 3, ( "f_recv_timeout: %u ms", timeout ) );
+
+        if( ssl->f_recv_timeout != NULL && timeout != 0 )
+            ret = ssl->f_recv_timeout( ssl->p_bio, ssl->in_hdr, len, timeout );
         else
             ret = ssl->f_recv( ssl->p_bio, ssl->in_hdr, len );
 
@@ -2006,19 +2012,24 @@ int ssl_fetch_input( ssl_context *ssl, size_t nb_want )
         {
             SSL_DEBUG_MSG( 2, ( "recv timeout" ) );
 
-            if( ssl_double_retransmit_timeout( ssl ) != 0 )
+            if( ssl->state != SSL_HANDSHAKE_OVER )
             {
-                SSL_DEBUG_MSG( 1, ( "handshake timeout" ) );
-                return( POLARSSL_ERR_NET_TIMEOUT );
+                if( ssl_double_retransmit_timeout( ssl ) != 0 )
+                {
+                    SSL_DEBUG_MSG( 1, ( "handshake timeout" ) );
+                    return( POLARSSL_ERR_NET_TIMEOUT );
+                }
+
+                if( ( ret = ssl_resend( ssl ) ) != 0 )
+                {
+                    SSL_DEBUG_RET( 1, "ssl_resend", ret );
+                    return( ret );
+                }
+
+                return( POLARSSL_ERR_NET_WANT_READ );
             }
 
-            if( ( ret = ssl_resend( ssl ) ) != 0 )
-            {
-                SSL_DEBUG_RET( 1, "ssl_resend", ret );
-                return( ret );
-            }
-
-            return( POLARSSL_ERR_NET_WANT_READ );
+            return( POLARSSL_ERR_NET_TIMEOUT );
         }
 
         if( ret < 0 )
@@ -4226,6 +4237,9 @@ void ssl_handshake_wrapup( ssl_context *ssl )
     if( ssl->transport == SSL_TRANSPORT_DATAGRAM &&
         ssl->handshake->flight != NULL )
     {
+        /* Cancel handshake timer */
+        ssl_set_timer( ssl, 0 );
+
         /* Keep last flight around in case we need to resend it:
          * we need the handshake and transform structures for that */
         SSL_DEBUG_MSG( 3, ( "skip freeing handshake and transform" ) );
@@ -5649,6 +5663,10 @@ int ssl_read( ssl_context *ssl, unsigned char *buf, size_t len )
 
     if( ssl->in_offt == NULL )
     {
+        /* Start timer if not already running */
+        if( ssl->time_limit == 0 )
+            ssl_set_timer( ssl, ssl->read_timeout );
+
         if( ! record_read )
         {
             if( ( ret = ssl_read_record( ssl ) ) != 0 )
@@ -5799,6 +5817,9 @@ int ssl_read( ssl_context *ssl, unsigned char *buf, size_t len )
         }
 
         ssl->in_offt = ssl->in_msg;
+
+        /* We're going to return something now, cancel timer */
+        ssl_set_timer( ssl, 0 );
     }
 
     n = ( len < ssl->in_msglen )
diff --git a/programs/ssl/ssl_client2.c b/programs/ssl/ssl_client2.c
index 238a1f11b..f0a2b4854 100644
--- a/programs/ssl/ssl_client2.c
+++ b/programs/ssl/ssl_client2.c
@@ -75,6 +75,7 @@ int main( int argc, char *argv[] )
 #define DFL_REQUEST_SIZE        -1
 #define DFL_DEBUG_LEVEL         0
 #define DFL_NBIO                0
+#define DFL_READ_TIMEOUT        0
 #define DFL_CA_FILE             ""
 #define DFL_CA_PATH             ""
 #define DFL_CRT_FILE            ""
@@ -112,6 +113,7 @@ struct options
     int server_port;            /* port on which the ssl service runs       */
     int debug_level;            /* level of debugging                       */
     int nbio;                   /* should I/O be blocking?                  */
+    uint32_t read_timeout;      /* timeout on ssl_read() in milliseconds    */
     const char *request_page;   /* page on server to request                */
     int request_size;           /* pad request with header to requested size */
     const char *ca_file;        /* the file with the CA certificate(s)      */
@@ -311,6 +313,7 @@ static int my_verify( void *data, x509_crt *crt, int depth, int *flags )
     "    debug_level=%%d      default: 0 (disabled)\n"      \
     "    nbio=%%d             default: 0 (blocking I/O)\n"  \
     "                        options: 1 (non-blocking), 2 (added delays)\n" \
+    "    read_timeout=%%d     default: 0 (no timeout)\n"    \
     "\n"                                                    \
     USAGE_DTLS                                              \
     "\n"                                                    \
@@ -408,6 +411,7 @@ int main( int argc, char *argv[] )
     opt.server_port         = DFL_SERVER_PORT;
     opt.debug_level         = DFL_DEBUG_LEVEL;
     opt.nbio                = DFL_NBIO;
+    opt.read_timeout        = DFL_READ_TIMEOUT;
     opt.request_page        = DFL_REQUEST_PAGE;
     opt.request_size        = DFL_REQUEST_SIZE;
     opt.ca_file             = DFL_CA_FILE;
@@ -473,6 +477,8 @@ int main( int argc, char *argv[] )
             if( opt.nbio < 0 || opt.nbio > 2 )
                 goto usage;
         }
+        else if( strcmp( p, "read_timeout" ) == 0 )
+            opt.read_timeout = atoi( q );
         else if( strcmp( p, "request_page" ) == 0 )
             opt.request_page = q;
         else if( strcmp( p, "request_size" ) == 0 )
@@ -982,7 +988,7 @@ int main( int argc, char *argv[] )
 #else
                              NULL,
 #endif
-                             0 );
+                             opt.read_timeout );
 
 #if defined(POLARSSL_SSL_SESSION_TICKETS)
     if( ( ret = ssl_set_session_tickets( &ssl, opt.tickets ) ) != 0 )
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index 2eec1ce84..8ab7baa26 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -93,6 +93,7 @@ int main( int argc, char *argv[] )
 #define DFL_SERVER_PORT         4433
 #define DFL_DEBUG_LEVEL         0
 #define DFL_NBIO                0
+#define DFL_READ_TIMEOUT        0
 #define DFL_CA_FILE             ""
 #define DFL_CA_PATH             ""
 #define DFL_CRT_FILE            ""
@@ -158,6 +159,7 @@ struct options
     int server_port;            /* port on which the ssl service runs       */
     int debug_level;            /* level of debugging                       */
     int nbio;                   /* should I/O be blocking?                  */
+    uint32_t read_timeout;      /* timeout on ssl_read() in milliseconds    */
     const char *ca_file;        /* the file with the CA certificate(s)      */
     const char *ca_path;        /* the path with the CA certificate(s) reside */
     const char *crt_file;       /* the file with the server certificate     */
@@ -345,6 +347,7 @@ static int my_send( void *ctx, const unsigned char *buf, size_t len )
     "    debug_level=%%d      default: 0 (disabled)\n"      \
     "    nbio=%%d             default: 0 (blocking I/O)\n"  \
     "                        options: 1 (non-blocking), 2 (added delays)\n" \
+    "    read_timeout=%%d     default: 0 (no timeout)\n"    \
     "\n"                                                    \
     USAGE_DTLS                                              \
     USAGE_COOKIES                                           \
@@ -736,6 +739,7 @@ int main( int argc, char *argv[] )
     opt.server_port         = DFL_SERVER_PORT;
     opt.debug_level         = DFL_DEBUG_LEVEL;
     opt.nbio                = DFL_NBIO;
+    opt.read_timeout        = DFL_READ_TIMEOUT;
     opt.ca_file             = DFL_CA_FILE;
     opt.ca_path             = DFL_CA_PATH;
     opt.crt_file            = DFL_CRT_FILE;
@@ -806,6 +810,8 @@ int main( int argc, char *argv[] )
             if( opt.nbio < 0 || opt.nbio > 2 )
                 goto usage;
         }
+        else if( strcmp( p, "read_timeout" ) == 0 )
+            opt.read_timeout = atoi( q );
         else if( strcmp( p, "ca_file" ) == 0 )
             opt.ca_file = q;
         else if( strcmp( p, "ca_path" ) == 0 )
@@ -1632,7 +1638,7 @@ reset:
 #else
                              NULL,
 #endif
-                             0 );
+                             opt.read_timeout );
 
 #if defined(POLARSSL_SSL_DTLS_HELLO_VERIFY)
     if( opt.transport == SSL_TRANSPORT_DATAGRAM )