diff --git a/programs/test/udp_proxy.c b/programs/test/udp_proxy.c index 20624d227..bb5537ff1 100644 --- a/programs/test/udp_proxy.c +++ b/programs/test/udp_proxy.c @@ -85,6 +85,7 @@ int main( void ) #define DFL_SERVER_PORT "4433" #define DFL_LISTEN_ADDR "localhost" #define DFL_LISTEN_PORT "5556" +#define DFL_PACK 0 #define USAGE \ "\n usage: udp_proxy param=<>...\n" \ @@ -108,6 +109,8 @@ int main( void ) " protect_len=%%d default: (don't protect packets of this size)\n" \ "\n" \ " seed=%%d default: (use current time)\n" \ + " pack=%%d default: 0 (don't merge)\n" \ + " options: t > 0 (merge for t milliseconds)\n" \ "\n" /* @@ -128,6 +131,8 @@ static struct options int bad_ad; /* inject corrupted ApplicationData record */ int protect_hvr; /* never drop or delay HelloVerifyRequest */ int protect_len; /* never drop/delay packet of the given size*/ + int merge; /* merge packets into single datagram for + * at most \c merge milliseconds if > 0 */ unsigned int seed; /* seed for "random" events */ } opt; @@ -152,6 +157,7 @@ static void get_options( int argc, char *argv[] ) opt.server_port = DFL_SERVER_PORT; opt.listen_addr = DFL_LISTEN_ADDR; opt.listen_port = DFL_LISTEN_PORT; + opt.merge = DFL_PACK; /* Other members default to 0 */ for( i = 1; i < argc; i++ ) @@ -193,6 +199,10 @@ static void get_options( int argc, char *argv[] ) if( opt.drop < 0 || opt.drop > 20 || opt.drop == 1 ) exit_usage( p, q ); } + else if( strcmp( p, "pack" ) == 0 ) + { + opt.merge = atoi( q ); + } else if( strcmp( p, "mtu" ) == 0 ) { opt.mtu = atoi( q ); @@ -288,6 +298,94 @@ static unsigned long ellapsed_time( void ) #endif } +typedef struct +{ + mbedtls_net_context *ctx; + + const char *description; + + unsigned long packet_lifetime; + size_t num_datagrams; + + unsigned char data[MAX_MSG_SIZE]; + unsigned len; + +} ctx_buffer; + +static ctx_buffer outbuf[2]; + +static int ctx_buffer_flush( ctx_buffer *buf ) +{ + int ret; + + mbedtls_printf( " %05lu flush %s: %u bytes, %lu datagrams, " + "last %ld ms\n", ellapsed_time(), + buf->description, buf->len, buf->num_datagrams, + ellapsed_time() - buf->packet_lifetime ); + + ret = mbedtls_net_send( buf->ctx, buf->data, buf->len ); + + buf->len = 0; + buf->num_datagrams = 0; + + return( ret ); +} + +static inline int ctx_buffer_check( ctx_buffer *buf ) +{ + if( buf->len > 0 && + ellapsed_time() - buf->packet_lifetime >= (size_t) opt.merge ) + { + return( ctx_buffer_flush( buf ) ); + } + + return( 0 ); +} + +static int ctx_buffer_append( ctx_buffer *buf, + const unsigned char * data, + size_t len ) +{ + int ret; + + if( len > sizeof( buf->data ) ) + { + mbedtls_printf( " ! buffer size %lu too large (max %lu)\n", + len, sizeof( buf->data ) ); + return( -1 ); + } + + if( sizeof( buf->data ) - buf->len < len ) + { + if( ( ret = ctx_buffer_flush( buf ) ) <= 0 ) + return( ret ); + } + + memcpy( buf->data + buf->len, data, len ); + + buf->len += len; + if( ++buf->num_datagrams == 1 ) + buf->packet_lifetime = ellapsed_time(); + + return( len ); +} + +static int dispatch_data( mbedtls_net_context *ctx, + const unsigned char * data, + size_t len ) +{ + ctx_buffer *buf = NULL; + if( outbuf[0].ctx == ctx ) + buf = &outbuf[0]; + else if( outbuf[1].ctx == ctx ) + buf = &outbuf[1]; + + if( buf == NULL ) + return( mbedtls_net_send( ctx, data, len ) ); + + return( ctx_buffer_append( buf, data, len ) ); +} + typedef struct { mbedtls_net_context *dst; @@ -301,10 +399,10 @@ typedef struct void print_packet( const packet *p, const char *why ) { if( why == NULL ) - mbedtls_printf( " %05lu %s %s (%u bytes)\n", + mbedtls_printf( " %05lu dispatch %s %s (%u bytes)\n", ellapsed_time(), p->way, p->type, p->len ); else - mbedtls_printf( " %s %s (%u bytes): %s\n", + mbedtls_printf( " dispatch %s %s (%u bytes): %s\n", p->way, p->type, p->len, why ); fflush( stdout ); } @@ -323,17 +421,17 @@ int send_packet( const packet *p, const char *why ) ++buf[p->len - 1]; print_packet( p, "corrupted" ); - if( ( ret = mbedtls_net_send( dst, buf, p->len ) ) <= 0 ) + if( ( ret = dispatch_data( dst, buf, p->len ) ) <= 0 ) { - mbedtls_printf( " ! mbedtls_net_send returned %d\n", ret ); + mbedtls_printf( " ! dispatch returned %d\n", ret ); return( ret ); } } print_packet( p, why ); - if( ( ret = mbedtls_net_send( dst, p->buf, p->len ) ) <= 0 ) + if( ( ret = dispatch_data( dst, p->buf, p->len ) ) <= 0 ) { - mbedtls_printf( " ! mbedtls_net_send returned %d\n", ret ); + mbedtls_printf( " ! dispatch returned %d\n", ret ); return( ret ); } @@ -344,9 +442,9 @@ int send_packet( const packet *p, const char *why ) { print_packet( p, "duplicated" ); - if( ( ret = mbedtls_net_send( dst, p->buf, p->len ) ) <= 0 ) + if( ( ret = dispatch_data( dst, p->buf, p->len ) ) <= 0 ) { - mbedtls_printf( " ! mbedtls_net_send returned %d\n", ret ); + mbedtls_printf( " ! dispatch returned %d\n", ret ); return( ret ); } } @@ -471,10 +569,14 @@ int main( int argc, char *argv[] ) int ret; mbedtls_net_context listen_fd, client_fd, server_fd; + struct timeval tm; int nb_fds; fd_set read_fds; + tm.tv_sec = 0; + tm.tv_usec = 0; + mbedtls_net_init( &listen_fd ); mbedtls_net_init( &client_fd ); mbedtls_net_init( &server_fd ); @@ -560,6 +662,19 @@ accept: nb_fds = listen_fd.fd; ++nb_fds; + if( opt.merge > 0 ) + { + outbuf[0].ctx = &server_fd; + outbuf[0].description = "S <- C"; + outbuf[0].num_datagrams = 0; + outbuf[0].len = 0; + + outbuf[1].ctx = &client_fd; + outbuf[1].description = "S -> C"; + outbuf[1].num_datagrams = 0; + outbuf[1].len = 0; + } + while( 1 ) { FD_ZERO( &read_fds ); @@ -567,7 +682,10 @@ accept: FD_SET( client_fd.fd, &read_fds ); FD_SET( listen_fd.fd, &read_fds ); - if( ( ret = select( nb_fds, &read_fds, NULL, NULL, NULL ) ) <= 0 ) + ctx_buffer_check( &outbuf[0] ); + ctx_buffer_check( &outbuf[1] ); + + if( ( ret = select( nb_fds, &read_fds, NULL, NULL, &tm ) ) < 0 ) { perror( "select" ); goto exit; @@ -589,6 +707,7 @@ accept: &client_fd, &server_fd ) ) != 0 ) goto accept; } + } exit: