diff --git a/include/polarssl/ssl_cache.h b/include/polarssl/ssl_cache.h index 3c5ef8b1c..daa07acb6 100644 --- a/include/polarssl/ssl_cache.h +++ b/include/polarssl/ssl_cache.h @@ -29,6 +29,10 @@ #include "ssl.h" +#if defined(POLARSSL_THREADING_C) +#include "threading.h" +#endif + #if !defined(POLARSSL_CONFIG_OPTIONS) #define SSL_CACHE_DEFAULT_TIMEOUT 86400 /*!< 1 day */ #define SSL_CACHE_DEFAULT_MAX_ENTRIES 50 /*!< Maximum entries in cache */ @@ -64,6 +68,9 @@ struct _ssl_cache_context ssl_cache_entry *chain; /*!< start of the chain */ int timeout; /*!< cache entry timeout */ int max_entries; /*!< maximum entries */ +#if defined(POLARSSL_THREADING_C) + threading_mutex_t mutex; /*!< mutex */ +#endif }; /** @@ -75,6 +82,7 @@ void ssl_cache_init( ssl_cache_context *cache ); /** * \brief Cache get callback implementation + * (Thread-safe if POLARSSL_THREADING_C is enabled) * * \param data SSL cache context * \param session session to retrieve entry for @@ -83,6 +91,7 @@ int ssl_cache_get( void *data, ssl_session *session ); /** * \brief Cache set callback implementation + * (Thread-safe if POLARSSL_THREADING_C is enabled) * * \param data SSL cache context * \param session session to store entry for diff --git a/library/ssl_cache.c b/library/ssl_cache.c index 113f72440..e0847b6dc 100644 --- a/library/ssl_cache.c +++ b/library/ssl_cache.c @@ -48,16 +48,26 @@ void ssl_cache_init( ssl_cache_context *cache ) cache->timeout = SSL_CACHE_DEFAULT_TIMEOUT; cache->max_entries = SSL_CACHE_DEFAULT_MAX_ENTRIES; + +#if defined(POLARSSL_THREADING_C) + polarssl_mutex_init( &cache->mutex ); +#endif } int ssl_cache_get( void *data, ssl_session *session ) { + int ret = 1; #if defined(POLARSSL_HAVE_TIME) time_t t = time( NULL ); #endif ssl_cache_context *cache = (ssl_cache_context *) data; ssl_cache_entry *cur, *entry; +#if defined(POLARSSL_THREADING_C) + if( polarssl_mutex_lock( &cache->mutex ) != 0 ) + return( 1 ); +#endif + cur = cache->chain; entry = NULL; @@ -93,7 +103,10 @@ int ssl_cache_get( void *data, ssl_session *session ) { session->peer_cert = (x509_crt *) polarssl_malloc( sizeof(x509_crt) ); if( session->peer_cert == NULL ) - return( 1 ); + { + ret = 1; + goto exit; + } x509_crt_init( session->peer_cert ); if( x509_crt_parse( session->peer_cert, entry->peer_cert.p, @@ -101,19 +114,28 @@ int ssl_cache_get( void *data, ssl_session *session ) { polarssl_free( session->peer_cert ); session->peer_cert = NULL; - return( 1 ); + ret = 1; + goto exit; } } #endif /* POLARSSL_X509_CRT_PARSE_C */ - return( 0 ); + ret = 0; + goto exit; } - return( 1 ); +exit: +#if defined(POLARSSL_THREADING_C) + if( polarssl_mutex_unlock( &cache->mutex ) != 0 ) + ret = 1; +#endif + + return( ret ); } int ssl_cache_set( void *data, const ssl_session *session ) { + int ret = 1; #if defined(POLARSSL_HAVE_TIME) time_t t = time( NULL ), oldest = 0; ssl_cache_entry *old = NULL; @@ -122,6 +144,11 @@ int ssl_cache_set( void *data, const ssl_session *session ) ssl_cache_entry *cur, *prv; int count = 0; +#if defined(POLARSSL_THREADING_C) + if( ( ret = polarssl_mutex_lock( &cache->mutex ) ) != 0 ) + return( ret ); +#endif + cur = cache->chain; prv = NULL; @@ -179,7 +206,10 @@ int ssl_cache_set( void *data, const ssl_session *session ) if( count >= cache->max_entries ) { if( cache->chain == NULL ) - return( 1 ); + { + ret = 1; + goto exit; + } cur = cache->chain; cache->chain = cur->next; @@ -200,7 +230,10 @@ int ssl_cache_set( void *data, const ssl_session *session ) { cur = (ssl_cache_entry *) polarssl_malloc( sizeof(ssl_cache_entry) ); if( cur == NULL ) - return( 1 ); + { + ret = 1; + goto exit; + } memset( cur, 0, sizeof(ssl_cache_entry) ); @@ -225,7 +258,10 @@ int ssl_cache_set( void *data, const ssl_session *session ) { cur->peer_cert.p = (unsigned char *) polarssl_malloc( session->peer_cert->raw.len ); if( cur->peer_cert.p == NULL ) - return( 1 ); + { + ret = 1; + goto exit; + } memcpy( cur->peer_cert.p, session->peer_cert->raw.p, session->peer_cert->raw.len ); @@ -235,7 +271,15 @@ int ssl_cache_set( void *data, const ssl_session *session ) } #endif /* POLARSSL_X509_CRT_PARSE_C */ - return( 0 ); + ret = 0; + +exit: +#if defined(POLARSSL_THREADING_C) + if( polarssl_mutex_unlock( &cache->mutex ) != 0 ) + ret = 1; +#endif + + return( ret ); } #if defined(POLARSSL_HAVE_TIME) @@ -274,6 +318,10 @@ void ssl_cache_free( ssl_cache_context *cache ) polarssl_free( prv ); } + +#if defined(POLARSSL_THREADING_C) + polarssl_mutex_free( &cache->mutex ); +#endif } #endif /* POLARSSL_SSL_CACHE_C */