diff --git a/include/mbedtls/aes.h b/include/mbedtls/aes.h index 0f8934f72..1bfa434c0 100644 --- a/include/mbedtls/aes.h +++ b/include/mbedtls/aes.h @@ -325,6 +325,7 @@ int mbedtls_aes_crypt_cbc( mbedtls_aes_context *ctx, * returns #MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH. * * \param ctx The AES XTS context to use for AES XTS operations. + * It must be initialized and bound to a key. * \param mode The AES operation: #MBEDTLS_AES_ENCRYPT or * #MBEDTLS_AES_DECRYPT. * \param length The length of a data unit in bytes. This can be any diff --git a/library/aes.c b/library/aes.c index 2da86c713..c15022b91 100644 --- a/library/aes.c +++ b/library/aes.c @@ -1182,6 +1182,12 @@ int mbedtls_aes_crypt_xts( mbedtls_aes_xts_context *ctx, unsigned char prev_tweak[16]; unsigned char tmp[16]; + AES_VALIDATE_RET( ctx != NULL ); + AES_VALIDATE_RET( mode == MBEDTLS_AES_ENCRYPT || + mode == MBEDTLS_AES_DECRYPT ); + AES_VALIDATE_RET( input != NULL ); + AES_VALIDATE_RET( output != NULL ); + /* Data units must be at least 16 bytes long. */ if( length < 16 ) return MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH; diff --git a/tests/suites/test_suite_aes.function b/tests/suites/test_suite_aes.function index d21a41dd5..bcffe37b6 100644 --- a/tests/suites/test_suite_aes.function +++ b/tests/suites/test_suite_aes.function @@ -194,8 +194,8 @@ exit: void aes_crypt_xts_size( int size, int retval ) { mbedtls_aes_xts_context ctx; - const unsigned char *src = NULL; - unsigned char *output = NULL; + const unsigned char src[16] = { 0 }; + unsigned char output[16]; unsigned char data_unit[16]; size_t length = size; @@ -203,10 +203,8 @@ void aes_crypt_xts_size( int size, int retval ) memset( data_unit, 0x00, sizeof( data_unit ) ); - /* Note that this function will most likely crash on failure, as NULL - * parameters will be used. In the passing case, the length check in - * mbedtls_aes_crypt_xts() will prevent any accesses to parameters by - * exiting the function early. */ + /* Valid pointers are passed for builds with MBEDTLS_CHECK_PARAMS, as + * otherwise we wouldn't get to the size check we're interested in. */ TEST_ASSERT( mbedtls_aes_crypt_xts( &ctx, MBEDTLS_AES_ENCRYPT, length, data_unit, src, output ) == retval ); } /* END_CASE */ @@ -445,6 +443,29 @@ void aes_check_params( ) MBEDTLS_AES_ENCRYPT, 16, out, in, NULL ) ); #endif /* MBEDTLS_CIPHER_MODE_CBC */ + +#if defined(MBEDTLS_CIPHER_MODE_XTS) + TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA, + mbedtls_aes_crypt_xts( NULL, + MBEDTLS_AES_ENCRYPT, 16, + in, in, out ) ); + TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA, + mbedtls_aes_crypt_xts( &xts_ctx, + 42, 16, + in, in, out ) ); + TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA, + mbedtls_aes_crypt_xts( &xts_ctx, + MBEDTLS_AES_ENCRYPT, 16, + NULL, in, out ) ); + TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA, + mbedtls_aes_crypt_xts( &xts_ctx, + MBEDTLS_AES_ENCRYPT, 16, + in, NULL, out ) ); + TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA, + mbedtls_aes_crypt_xts( &xts_ctx, + MBEDTLS_AES_ENCRYPT, 16, + in, in, NULL ) ); +#endif /* MBEDTLS_CIPHER_MODE_XTS */ } /* END_CASE */ @@ -452,6 +473,9 @@ void aes_check_params( ) void aes_misc_params( ) { mbedtls_aes_context aes_ctx; +#if defined(MBEDTLS_CIPHER_MODE_XTS) + mbedtls_aes_xts_context xts_ctx; +#endif const unsigned char in[16] = { 0 }; unsigned char out[16]; @@ -463,13 +487,25 @@ void aes_misc_params( ) #if defined(MBEDTLS_CIPHER_MODE_CBC) TEST_ASSERT( mbedtls_aes_crypt_cbc( &aes_ctx, MBEDTLS_AES_ENCRYPT, - 15, out, in, out ) + 15, + out, in, out ) == MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH ); TEST_ASSERT( mbedtls_aes_crypt_cbc( &aes_ctx, MBEDTLS_AES_ENCRYPT, - 17, out, in, out ) + 17, + out, in, out ) == MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH ); #endif +#if defined(MBEDTLS_CIPHER_MODE_XTS) + TEST_ASSERT( mbedtls_aes_crypt_xts( &xts_ctx, MBEDTLS_AES_ENCRYPT, + 15, + in, in, out ) + == MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH ); + TEST_ASSERT( mbedtls_aes_crypt_xts( &xts_ctx, MBEDTLS_AES_ENCRYPT, + (1 << 24) + 1, + in, in, out ) + == MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH ); +#endif } /* END_CASE */