diff --git a/include/mbedtls/aes.h b/include/mbedtls/aes.h
index 11edc0fab..197d4db10 100644
--- a/include/mbedtls/aes.h
+++ b/include/mbedtls/aes.h
@@ -197,8 +197,10 @@ int mbedtls_aes_setkey_dec( mbedtls_aes_context *ctx, const unsigned char *key,
* sets the encryption key.
*
* \param ctx The AES XTS context to which the key should be bound.
+ * It must be initialized.
* \param key The encryption key. This is comprised of the XTS key1
* concatenated with the XTS key2.
+ * This must be a readable buffer of size \p keybits bits.
* \param keybits The size of \p key passed in bits. Valid options are:
*
- 256 bits (each of key1 and key2 is a 128-bit key)
* - 512 bits (each of key1 and key2 is a 256-bit key)
@@ -215,8 +217,10 @@ int mbedtls_aes_xts_setkey_enc( mbedtls_aes_xts_context *ctx,
* sets the decryption key.
*
* \param ctx The AES XTS context to which the key should be bound.
+ * It must be initialized.
* \param key The decryption key. This is comprised of the XTS key1
* concatenated with the XTS key2.
+ * This must be a readable buffer of size \p keybits bits.
* \param keybits The size of \p key passed in bits. Valid options are:
* - 256 bits (each of key1 and key2 is a 128-bit key)
* - 512 bits (each of key1 and key2 is a 256-bit key)
diff --git a/library/aes.c b/library/aes.c
index cc1e5ceb4..4d9a56a5c 100644
--- a/library/aes.c
+++ b/library/aes.c
@@ -771,6 +771,9 @@ int mbedtls_aes_xts_setkey_enc( mbedtls_aes_xts_context *ctx,
const unsigned char *key1, *key2;
unsigned int key1bits, key2bits;
+ AES_VALIDATE_RET( ctx != NULL );
+ AES_VALIDATE_RET( key != NULL );
+
ret = mbedtls_aes_xts_decode_keys( key, keybits, &key1, &key1bits,
&key2, &key2bits );
if( ret != 0 )
@@ -793,6 +796,9 @@ int mbedtls_aes_xts_setkey_dec( mbedtls_aes_xts_context *ctx,
const unsigned char *key1, *key2;
unsigned int key1bits, key2bits;
+ AES_VALIDATE_RET( ctx != NULL );
+ AES_VALIDATE_RET( key != NULL );
+
ret = mbedtls_aes_xts_decode_keys( key, keybits, &key1, &key1bits,
&key2, &key2bits );
if( ret != 0 )
diff --git a/tests/suites/test_suite_aes.function b/tests/suites/test_suite_aes.function
index 131565060..576e5be08 100644
--- a/tests/suites/test_suite_aes.function
+++ b/tests/suites/test_suite_aes.function
@@ -215,7 +215,7 @@ void aes_crypt_xts_size( int size, int retval )
void aes_crypt_xts_keysize( int size, int retval )
{
mbedtls_aes_xts_context ctx;
- const unsigned char *key = NULL;
+ const unsigned char key[] = { 0x01, 0x02, 0x03, 0x04, 0x05, 0x06 };
size_t key_len = size;
mbedtls_aes_xts_init( &ctx );
@@ -374,26 +374,38 @@ exit:
/* BEGIN_CASE depends_on:MBEDTLS_CHECK_PARAMS:!MBEDTLS_PARAM_FAILED_ALT */
void aes_invalid_param( )
{
- mbedtls_aes_context dummy_ctx;
+ mbedtls_aes_context aes_ctx;
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+ mbedtls_aes_xts_context xts_ctx;
+#endif
const unsigned char key[] = { 0x01, 0x02, 0x03, 0x04, 0x05, 0x06 };
TEST_INVALID_PARAM( mbedtls_aes_init( NULL ) );
-
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
TEST_INVALID_PARAM( mbedtls_aes_xts_init( NULL ) );
+#endif
- /* mbedtls_aes_setkey_enc() */
TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
mbedtls_aes_setkey_enc( NULL, key, 128 ) );
-
TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
- mbedtls_aes_setkey_enc( &dummy_ctx, NULL, 128 ) );
+ mbedtls_aes_setkey_enc( &aes_ctx, NULL, 128 ) );
- /* mbedtls_aes_setkey_dec() */
TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
mbedtls_aes_setkey_dec( NULL, key, 128 ) );
+ TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+ mbedtls_aes_setkey_dec( &aes_ctx, NULL, 128 ) );
+
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+ TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+ mbedtls_aes_xts_setkey_enc( NULL, key, 128 ) );
+ TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+ mbedtls_aes_xts_setkey_enc( &xts_ctx, NULL, 128 ) );
TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
- mbedtls_aes_setkey_dec( &dummy_ctx, NULL, 128 ) );
+ mbedtls_aes_xts_setkey_dec( NULL, key, 128 ) );
+ TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+ mbedtls_aes_xts_setkey_dec( &xts_ctx, NULL, 128 ) );
+#endif
}
/* END_CASE */