diff --git a/include/polarssl/memory.h b/include/polarssl/memory.h index 567a64a16..6a3dab94b 100644 --- a/include/polarssl/memory.h +++ b/include/polarssl/memory.h @@ -71,6 +71,8 @@ int memory_set_own( void * (*malloc_func)( size_t ), * presented buffer and does not call malloc() and free(). * It sets the global polarssl_malloc() and polarssl_free() pointers * to its own functions. + * (Provided polarssl_malloc() and polarssl_free() are thread-safe if + * POLARSSL_THREADING_C is defined) * * \note This code is not optimized and provides a straight-forward * implementation of a stack-based memory allocator. @@ -82,6 +84,11 @@ int memory_set_own( void * (*malloc_func)( size_t ), */ int memory_buffer_alloc_init( unsigned char *buf, size_t len ); +/** + * \brief Free the mutex for thread-safety and clear remaining memory + */ +void memory_buffer_alloc_free(); + /** * \brief Determine when the allocator should automatically verify the state * of the entire chain of headers / meta-data. diff --git a/library/memory_buffer_alloc.c b/library/memory_buffer_alloc.c index de2811fb2..7ec6498de 100644 --- a/library/memory_buffer_alloc.c +++ b/library/memory_buffer_alloc.c @@ -38,6 +38,10 @@ #endif #endif +#if defined(POLARSSL_THREADING_C) +#include "polarssl/threading.h" +#endif + #define MAGIC1 0xFF00AA55 #define MAGIC2 0xEE119966 #define MAX_BT 20 @@ -74,6 +78,9 @@ typedef struct size_t maximum_used; size_t header_count; #endif +#if defined(POLARSSL_THREADING_C) + threading_mutex_t mutex; +#endif } buffer_alloc_ctx; @@ -349,7 +356,6 @@ static void buffer_alloc_free( void *ptr ) memory_header *hdr, *old = NULL; unsigned char *p = (unsigned char *) ptr; - if( ptr == NULL || heap.buf == NULL || heap.first == NULL ) return; @@ -492,14 +498,38 @@ void memory_buffer_alloc_status() } #endif /* POLARSSL_MEMORY_BUFFER_ALLOC_DEBUG */ +#if defined(POLARSSL_THREADING_C) +static void *buffer_alloc_malloc_mutexed( size_t len ) +{ + void *buf; + polarssl_mutex_lock( &heap.mutex ); + buf = buffer_alloc_malloc( len ); + polarssl_mutex_unlock( &heap.mutex ); + return( buf ); +} + +static void buffer_alloc_free_mutexed( void *ptr ) +{ + polarssl_mutex_lock( &heap.mutex ); + buffer_alloc_free( ptr ); + polarssl_mutex_unlock( &heap.mutex ); +} +#endif + int memory_buffer_alloc_init( unsigned char *buf, size_t len ) { - polarssl_malloc = buffer_alloc_malloc; - polarssl_free = buffer_alloc_free; - memset( &heap, 0, sizeof(buffer_alloc_ctx) ); memset( buf, 0, len ); +#if defined(POLARSSL_THREADING_C) + polarssl_mutex_init( &heap.mutex ); + polarssl_malloc = buffer_alloc_malloc_mutexed; + polarssl_free = buffer_alloc_free_mutexed; +#else + polarssl_malloc = buffer_alloc_malloc; + polarssl_free = buffer_alloc_free; +#endif + heap.buf = buf; heap.len = len; @@ -511,4 +541,12 @@ int memory_buffer_alloc_init( unsigned char *buf, size_t len ) return( 0 ); } +void memory_buffer_alloc_free() +{ +#if defined(POLARSSL_THREADING_C) + polarssl_mutex_free( &heap.mutex ); +#endif + memset( &heap, 0, sizeof(buffer_alloc_ctx) ); +} + #endif /* POLARSSL_MEMORY_C && POLARSSL_MEMORY_BUFFER_ALLOC_C */ diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c index a6ff57fea..43d7d79b7 100644 --- a/programs/ssl/ssl_server2.c +++ b/programs/ssl/ssl_server2.c @@ -969,9 +969,12 @@ exit: ssl_cache_free( &cache ); #endif -#if defined(POLARSSL_MEMORY_BUFFER_ALLOC_C) && defined(POLARSSL_MEMORY_DEBUG) +#if defined(POLARSSL_MEMORY_BUFFER_ALLOC_C) +#if defined(POLARSSL_MEMORY_DEBUG) memory_buffer_alloc_status(); #endif + memory_buffer_alloc_free(); +#endif #if defined(_WIN32) printf( " + Press Enter to exit this program.\n" ); diff --git a/programs/test/selftest.c b/programs/test/selftest.c index 246276500..fb9a7cc72 100644 --- a/programs/test/selftest.c +++ b/programs/test/selftest.c @@ -190,6 +190,9 @@ int main( int argc, char *argv[] ) fflush( stdout ); getchar(); #endif } +#if defined(POLARSSL_MEMORY_BUFFER_ALLOC_C) + memory_buffer_alloc_free(); +#endif return( ret ); } diff --git a/tests/suites/main_test.function b/tests/suites/main_test.function index d5aa3862c..c64d9be8a 100644 --- a/tests/suites/main_test.function +++ b/tests/suites/main_test.function @@ -202,8 +202,8 @@ int main() char *params[50]; #if defined(POLARSSL_MEMORY_BUFFER_ALLOC_C) - unsigned char buf[1000000]; - memory_buffer_alloc_init( buf, sizeof(buf) ); + unsigned char alloc_buf[1000000]; + memory_buffer_alloc_init( alloc_buf, sizeof(alloc_buf) ); #endif file = fopen( filename, "r" ); @@ -288,9 +288,12 @@ int main() fprintf( stdout, " (%d / %d tests (%d skipped))\n", total_tests - total_errors, total_tests, total_skipped ); -#if defined(POLARSSL_MEMORY_BUFFER_ALLOC_C) && defined(POLARSSL_MEMORY_DEBUG) +#if defined(POLARSSL_MEMORY_BUFFER_ALLOC_C) +#if defined(POLARSSL_MEMORY_DEBUG) memory_buffer_alloc_status(); #endif + memory_buffer_alloc_free(); +#endif return( total_errors != 0 ); }