Rework AES countermeasures implementation

Use control bytes to instruct AES calculation rounds. Each
calculation round has a control byte that indicates what data
(real/fake) is used and if any offset is required for AES data
positions.

First and last AES calculation round are calculated with SCA CM data
included. The calculation order is randomized by the control bytes.

Calculations between the first and last rounds contains 3 SCA CMs
in randomized positions.
This commit is contained in:
Arto Kinnunen 2019-11-28 13:34:13 +02:00
parent b2be92e2c7
commit 172836a281

View file

@ -85,6 +85,21 @@
}
#endif
/*
* Data structure for AES round data
*/
typedef struct _aes_r_data_s {
uint32_t *rk_ptr; /* Round Key */
uint32_t xy_values[8]; /* X0, X1, X2, X3, Y0, U1, Y2, Y3 */
} aes_r_data_t;
#if defined(MBEDTLS_AES_SCA_COUNTERMEASURES)
/* Number of additional AES calculation rounds added for SCA CM */
#define AES_SCA_CM_ROUNDS 3
#else /* MBEDTLS_AES_SCA_COUNTERMEASURES */
#define AES_SCA_CM_ROUNDS 0
#endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
#if defined(MBEDTLS_PADLOCK_C) && \
( defined(MBEDTLS_HAVE_X86) || defined(MBEDTLS_PADLOCK_ALIGN16) )
static int aes_padlock_ace = -1;
@ -497,6 +512,81 @@ static void aes_gen_tables( void )
#endif /* MBEDTLS_AES_ROM_TABLES */
/**
* Randomize positions when to use AES SCA countermeasures.
* Each byte indicates one AES round as follows:
* first ( tbl_len - 2 ) bytes are reserved for AES rounds
* -4 high bit = table to use 0x10 for SCA CM data, 0 otherwise
* -4 low bits = offset based on order, 4 for even position, 0 otherwise
* Last 2 bytes for first/final round calculation
* -4 high bit = table to use, 0x10 for SCA CM data, otherwise real data
* -4 low bits = not used
*
*/
static void aes_sca_cm_data_randomize( uint8_t *tbl, uint8_t tbl_len )
{
int i, is_even_pos;
#if AES_SCA_CM_ROUNDS != 0
int is_unique_number;
int num;
#endif
memset( tbl, 0, tbl_len );
#if AES_SCA_CM_ROUNDS != 0
// Randomize SCA CM positions to tbl
for ( i = 0; i < AES_SCA_CM_ROUNDS; i++ )
{
do {
is_unique_number = 1;
/* TODO - Use proper random. This is now ONLY FOR TESTING as mbedtls_platform_random_in_range is alwyays returning 0 */
num = /* mbedtls_platform_random_in_range( tbl_len - 1 ) */rand() % (tbl_len - 2);
if ( tbl[num] == 0 )
{
is_unique_number = 0;
tbl[num] = 0x10;
}
} while ( is_unique_number == 1 );
}
// Fill start/final round control data
if ( AES_SCA_CM_ROUNDS != 0 )
{
num = /* mbedtls_platform_random_in_range( tbl_len - 1 ) */rand() % 0xff;
if ( ( num % 2 ) == 0 )
{
tbl[tbl_len - 2] = 0x10;
tbl[tbl_len - 1] = 0x0;
}
else
{
tbl[tbl_len - 2] = 0x00;
tbl[tbl_len - 1] = 0x10;
}
}
#endif /* AES_SCA_CM_ROUNDS != 0 */
// Fill real AES round data to the remaining places
is_even_pos = 1;
for ( i = 0; i < tbl_len - 2; i++ )
{
if ( tbl[i] == 0 )
{
if ( is_even_pos == 1 )
{
tbl[i] = 0x04; // real data, offset 0
is_even_pos = 0;
}
else
{
tbl[i] = 0x00; // real data, offset 0
is_even_pos = 1;
}
}
}
}
#if defined(MBEDTLS_AES_FEWER_TABLES)
#define ROTL8(x) ( (uint32_t)( ( x ) << 8 ) + (uint32_t)( ( x ) >> 24 ) )
@ -527,57 +617,6 @@ static void aes_gen_tables( void )
#endif /* MBEDTLS_AES_FEWER_TABLES */
#if defined(MBEDTLS_AES_SCA_COUNTERMEASURES)
/*
* SCA CM table position check
*/
#define SCA_CM_TBL_MATCH(tbl, n) ( tbl[0] == ( n ) || \
tbl[1] == ( n ) || \
tbl[2] == ( n ) )
/*
* SCA CM always true check
*/
#define SCA_CM_ALWAYS_TRUE(tbl, n) ( tbl[0] != ( n ) || \
tbl[1] != ( n ) || \
tbl[2] != tbl[0] )
/*
* Number of SCA CM dummy rounds.
*/
#define SCA_CM_DUMMY_ROUND_COUNT 3
#endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
#if defined(MBEDTLS_AES_SCA_COUNTERMEASURES)
static void aes_sca_rand_tbl_fill( uint8_t *tbl, uint8_t tbl_len, uint8_t max_num )
{
int i, j, is_unique_number;
uint8_t *cur_num;
uint8_t num;
cur_num = tbl;
for ( i = 0; i < tbl_len; i++ )
{
do {
is_unique_number = 1;
num = mbedtls_platform_random_in_range( max_num + 1 );
for ( j = 0; j < i; j++ )
{
if ( num == tbl[j] )
{
is_unique_number = 0;
break;
}
}
} while ( is_unique_number == 0 );
*cur_num++ = num;
}
}
#endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
void mbedtls_aes_init( mbedtls_aes_context *ctx )
{
AES_VALIDATE( ctx != NULL );
@ -936,6 +975,7 @@ int mbedtls_aes_xts_setkey_dec( mbedtls_aes_xts_context *ctx,
( (uint32_t) FSb[ ( (Y2) >> 24 ) & 0xFF ] << 24 ); \
} while ( 0 )
#define AES_RROUND(R,X0,X1,X2,X3,Y0,Y1,Y2,Y3) \
do \
{ \
@ -992,88 +1032,76 @@ int mbedtls_internal_aes_encrypt( mbedtls_aes_context *ctx,
const unsigned char input[16],
unsigned char output[16] )
{
int i;
uint32_t *RK, X0, X1, X2, X3, Y0 = 0, Y1 = 0, Y2 = 0, Y3 = 0;
int i, j, offset, start_fin_loops = 1;
aes_r_data_t aes_data_real; // real data
#if AES_SCA_CM_ROUNDS != 0
aes_r_data_t aes_data_fake; // fake data
#endif /* AES_SCA_CM_ROUNDS != 0 */
aes_r_data_t *aes_data_ptr; // pointer to aes_data_real or aes_data_fake
aes_r_data_t *aes_data_table[2]; // pointers to real and fake data
int round_ctrl_table_len = ctx->nr - 1 + AES_SCA_CM_ROUNDS + 2;
// control bytes for AES rounds, reserve based on max ctx->nr
uint8_t round_ctrl_table[ 14 - 1 + AES_SCA_CM_ROUNDS + 2 ];
#ifdef MBEDTLS_AES_SCA_COUNTERMEASURES
uint32_t *RK_SCA, X0_SCA, X1_SCA, X2_SCA, X3_SCA, Y0_SCA, Y1_SCA, Y2_SCA, Y3_SCA;
uint8_t sca_cm_pos_tbl[SCA_CM_DUMMY_ROUND_COUNT]; // position for SCA countermeasure dummy rounds, not in any order
aes_data_real.rk_ptr = ctx->rk;
aes_data_table[0] = &aes_data_real;
aes_sca_rand_tbl_fill( sca_cm_pos_tbl, SCA_CM_DUMMY_ROUND_COUNT, ctx->nr );
X0_SCA = mbedtls_platform_random_in_range( 0xffffffff );
X1_SCA = mbedtls_platform_random_in_range( 0xffffffff );
X2_SCA = mbedtls_platform_random_in_range( 0xffffffff );
X3_SCA = mbedtls_platform_random_in_range( 0xffffffff );
#endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
RK = ctx->rk;
#ifdef MBEDTLS_AES_SCA_COUNTERMEASURES
RK_SCA = RK;
if ( SCA_CM_TBL_MATCH( sca_cm_pos_tbl, ctx->nr ) )
{
/* LE conversions to Xn, Xn_SCA randomized */
GET_UINT32_LE( X0, input, 0 ); X0_SCA ^= *RK_SCA++;
GET_UINT32_LE( X1, input, 4 ); X1_SCA ^= *RK_SCA++;
GET_UINT32_LE( X2, input, 8 ); X2_SCA ^= *RK_SCA++;
GET_UINT32_LE( X3, input, 12 ); X3_SCA ^= *RK_SCA++;
}
#endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
GET_UINT32_LE( X0, input, 0 ); X0 ^= *RK++;
GET_UINT32_LE( X1, input, 4 ); X1 ^= *RK++;
GET_UINT32_LE( X2, input, 8 ); X2 ^= *RK++;
GET_UINT32_LE( X3, input, 12 ); X3 ^= *RK++;
for( i = ( ctx->nr >> 1 ) - 1; i > 0; i-- )
{
#ifdef MBEDTLS_AES_SCA_COUNTERMEASURES
// Would random delay before each round be necessary?
//
if ( SCA_CM_TBL_MATCH( sca_cm_pos_tbl, i * 2 ) )
AES_FROUND( RK_SCA, Y0_SCA, Y1_SCA, Y2_SCA, Y3_SCA,
X0_SCA, X1_SCA, X2_SCA, X3_SCA );
if ( SCA_CM_ALWAYS_TRUE( sca_cm_pos_tbl, i* 2 ) )
AES_FROUND( RK, Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
if ( SCA_CM_TBL_MATCH( sca_cm_pos_tbl, i * 2 + 1 ) )
AES_FROUND( RK_SCA, X0_SCA, X1_SCA, X2_SCA, X3_SCA,
Y0_SCA, Y1_SCA, Y2_SCA, Y3_SCA);
if ( SCA_CM_ALWAYS_TRUE( sca_cm_pos_tbl, i * 2 + 1 ) )
AES_FROUND( RK, X0, X1, X2, X3, Y0, Y1, Y2, Y3 );
#else /* MBEDTLS_AES_SCA_COUNTERMEASURES */
AES_FROUND( RK, Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
AES_FROUND( RK, X0, X1, X2, X3, Y0, Y1, Y2, Y3 );
#endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
}
#ifdef MBEDTLS_AES_SCA_COUNTERMEASURES
if ( SCA_CM_TBL_MATCH( sca_cm_pos_tbl, 1 ) )
AES_FROUND( RK_SCA, Y0_SCA, Y1_SCA, Y2_SCA, Y3_SCA,
X0_SCA, X1_SCA, X2_SCA, X3_SCA );
if ( SCA_CM_ALWAYS_TRUE ( sca_cm_pos_tbl, 1 ) )
AES_FROUND( RK, Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
if ( SCA_CM_TBL_MATCH( sca_cm_pos_tbl, 0 ) )
AES_FROUND_F( RK_SCA, X0_SCA, X1_SCA, X2_SCA, X3_SCA,
Y0_SCA, Y1_SCA, Y2_SCA, Y3_SCA );
if ( SCA_CM_ALWAYS_TRUE ( sca_cm_pos_tbl, 0 ) )
AES_FROUND_F( RK, X0, X1, X2, X3, Y0, Y1, Y2, Y3 );
#else
AES_FROUND( RK, Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
AES_FROUND_F( RK, X0, X1, X2, X3, Y0, Y1, Y2, Y3 );
#if AES_SCA_CM_ROUNDS != 0
aes_data_table[1] = &aes_data_fake;
aes_data_fake.rk_ptr = ctx->rk;
start_fin_loops = 2;
for (i = 0; i < 4; i++ )
aes_data_fake.xy_values[i] = mbedtls_platform_random_in_range( 0xffffffff );
#endif
PUT_UINT32_LE( X0, output, 0 );
PUT_UINT32_LE( X1, output, 4 );
PUT_UINT32_LE( X2, output, 8 );
PUT_UINT32_LE( X3, output, 12 );
// Get randomized AES calculation control bytes
aes_sca_cm_data_randomize( round_ctrl_table, round_ctrl_table_len );
for (i = 0; i < 4; i++ )
{
GET_UINT32_LE( aes_data_real.xy_values[i], input, ( i * 4 ) );
for (j = 0; j < start_fin_loops; j++ )
{
aes_data_ptr = aes_data_table[round_ctrl_table[ round_ctrl_table_len - 2 + j ] >> 4];
aes_data_ptr->xy_values[i] ^= *aes_data_ptr->rk_ptr++;
}
}
for( i = 0; i < ( ctx->nr - 1 + AES_SCA_CM_ROUNDS ); i++ )
{
// Read AES control data
aes_data_ptr = aes_data_table[round_ctrl_table[i] >> 4];
offset = round_ctrl_table[i] & 0x0f;
AES_FROUND( aes_data_ptr->rk_ptr,
aes_data_ptr->xy_values[0 + offset],
aes_data_ptr->xy_values[1 + offset],
aes_data_ptr->xy_values[2 + offset],
aes_data_ptr->xy_values[3 + offset],
aes_data_ptr->xy_values[4 - offset],
aes_data_ptr->xy_values[5 - offset],
aes_data_ptr->xy_values[6 - offset],
aes_data_ptr->xy_values[7 - offset] );
}
for ( j = 0; j < start_fin_loops; j++ )
{
aes_data_ptr = aes_data_table[round_ctrl_table[ i + j ] >> 4];
AES_FROUND_F( aes_data_ptr->rk_ptr,
aes_data_ptr->xy_values[0],
aes_data_ptr->xy_values[1],
aes_data_ptr->xy_values[2],
aes_data_ptr->xy_values[3],
aes_data_ptr->xy_values[4],
aes_data_ptr->xy_values[5],
aes_data_ptr->xy_values[6],
aes_data_ptr->xy_values[7] );
}
for ( i = 0; i < 4; i++ )
{
PUT_UINT32_LE( aes_data_real.xy_values[i], output, ( i * 4 ) );
}
return( 0 );
}
@ -1098,95 +1126,76 @@ int mbedtls_internal_aes_decrypt( mbedtls_aes_context *ctx,
const unsigned char input[16],
unsigned char output[16] )
{
int i;
uint32_t *RK, X0, X1, X2, X3, Y0 = 0, Y1 = 0, Y2 = 0, Y3 = 0;
int i, j, offset, start_fin_loops = 1;
aes_r_data_t aes_data_real; // real data
#if AES_SCA_CM_ROUNDS != 0
aes_r_data_t aes_data_fake; // fake data
#endif /* AES_SCA_CM_ROUNDS != 0 */
aes_r_data_t *aes_data_ptr; // pointer to aes_data_real or aes_data_fake
aes_r_data_t *aes_data_table[2]; // pointers to real and fake data
int round_ctrl_table_len = ctx->nr - 1 + AES_SCA_CM_ROUNDS + 2;
// control bytes for AES rounds, reserve based on max ctx->nr
uint8_t round_ctrl_table[ 14 - 1 + AES_SCA_CM_ROUNDS + 2 ];
#ifdef MBEDTLS_AES_SCA_COUNTERMEASURES
uint32_t *RK_SCA, X0_SCA, X1_SCA, X2_SCA, X3_SCA, Y0_SCA, Y1_SCA, Y2_SCA, Y3_SCA;
uint8_t sca_cm_pos_tbl[SCA_CM_DUMMY_ROUND_COUNT]; // position for SCA countermeasure dummy rounds, not in any order
aes_data_real.rk_ptr = ctx->rk;
aes_data_table[0] = &aes_data_real;
aes_sca_rand_tbl_fill( sca_cm_pos_tbl, SCA_CM_DUMMY_ROUND_COUNT, ctx->nr );
#if AES_SCA_CM_ROUNDS != 0
aes_data_table[1] = &aes_data_fake;
aes_data_fake.rk_ptr = ctx->rk;
start_fin_loops = 2;
for (i = 0; i < 4; i++ )
aes_data_fake.xy_values[i] = mbedtls_platform_random_in_range( 0xffffffff );
#endif
X0_SCA = mbedtls_platform_random_in_range( 0xffffffff );
X1_SCA = mbedtls_platform_random_in_range( 0xffffffff );
X2_SCA = mbedtls_platform_random_in_range( 0xffffffff );
X3_SCA = mbedtls_platform_random_in_range( 0xffffffff );
#endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
// Get randomized AES calculation control bytes
aes_sca_cm_data_randomize( round_ctrl_table, round_ctrl_table_len );
RK = ctx->rk;
#ifdef MBEDTLS_AES_SCA_COUNTERMEASURES
RK_SCA = RK;
if ( SCA_CM_TBL_MATCH( sca_cm_pos_tbl, ctx->nr ) )
for (i = 0; i < 4; i++ )
{
GET_UINT32_LE( X0, input, 0 ); X0_SCA ^= *RK_SCA++;
GET_UINT32_LE( X1, input, 4 ); X1_SCA ^= *RK_SCA++;
GET_UINT32_LE( X2, input, 8 ); X2_SCA ^= *RK_SCA++;
GET_UINT32_LE( X3, input, 12 ); X3_SCA ^= *RK_SCA++;
GET_UINT32_LE( aes_data_real.xy_values[i], input, ( i * 4 ) );
for (j = 0; j < start_fin_loops; j++ )
{
aes_data_ptr = aes_data_table[round_ctrl_table[ round_ctrl_table_len - 2 + j ] >> 4];
aes_data_ptr->xy_values[i] ^= *aes_data_ptr->rk_ptr++;
}
}
if ( SCA_CM_ALWAYS_TRUE( sca_cm_pos_tbl, ctx->nr ) )
for( i = 0; i < ( ctx->nr - 1 + AES_SCA_CM_ROUNDS ); i++ )
{
GET_UINT32_LE( X0, input, 0 ); X0 ^= *RK++;
GET_UINT32_LE( X1, input, 4 ); X1 ^= *RK++;
GET_UINT32_LE( X2, input, 8 ); X2 ^= *RK++;
GET_UINT32_LE( X3, input, 12 ); X3 ^= *RK++;
}
#else /* MBEDTLS_AES_SCA_COUNTERMEASURES */
GET_UINT32_LE( X0, input, 0 ); X0 ^= *RK++;
GET_UINT32_LE( X1, input, 4 ); X1 ^= *RK++;
GET_UINT32_LE( X2, input, 8 ); X2 ^= *RK++;
GET_UINT32_LE( X3, input, 12 ); X3 ^= *RK++;
#endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
// Read AES control data
aes_data_ptr = aes_data_table[round_ctrl_table[i] >> 4];
offset = round_ctrl_table[i] & 0x0f;
for( i = ( ctx->nr >> 1 ) - 1; i > 0; i-- )
{
#ifdef MBEDTLS_AES_SCA_COUNTERMEASURES
// Would random delay before each round be necessary?
//
if ( SCA_CM_TBL_MATCH( sca_cm_pos_tbl, i * 2 ) )
AES_RROUND( RK_SCA, Y0_SCA, Y1_SCA, Y2_SCA, Y3_SCA,
X0_SCA, X1_SCA, X2_SCA, X3_SCA );
if ( SCA_CM_ALWAYS_TRUE( sca_cm_pos_tbl, i* 2 ) )
AES_RROUND( RK, Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
if ( SCA_CM_TBL_MATCH( sca_cm_pos_tbl, i * 2 + 1 ) )
AES_RROUND( RK_SCA, X0_SCA, X1_SCA, X2_SCA, X3_SCA,
Y0_SCA, Y1_SCA, Y2_SCA, Y3_SCA);
if ( SCA_CM_ALWAYS_TRUE( sca_cm_pos_tbl, i * 2 + 1 ) )
AES_RROUND( RK, X0, X1, X2, X3, Y0, Y1, Y2, Y3 );
#else /* MBEDTLS_AES_SCA_COUNTERMEASURES */
AES_RROUND( RK, Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
AES_RROUND( RK, X0, X1, X2, X3, Y0, Y1, Y2, Y3 );
#endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
AES_RROUND( aes_data_ptr->rk_ptr,
aes_data_ptr->xy_values[0 + offset],
aes_data_ptr->xy_values[1 + offset],
aes_data_ptr->xy_values[2 + offset],
aes_data_ptr->xy_values[3 + offset],
aes_data_ptr->xy_values[4 - offset],
aes_data_ptr->xy_values[5 - offset],
aes_data_ptr->xy_values[6 - offset],
aes_data_ptr->xy_values[7 - offset] );
}
#ifdef MBEDTLS_AES_SCA_COUNTERMEASURES
if ( SCA_CM_TBL_MATCH( sca_cm_pos_tbl, 1 ) )
AES_RROUND( RK_SCA, Y0_SCA, Y1_SCA, Y2_SCA, Y3_SCA,
X0_SCA, X1_SCA, X2_SCA, X3_SCA );
for ( j = 0; j < start_fin_loops; j++ )
{
aes_data_ptr = aes_data_table[round_ctrl_table[ i + j ] >> 4];
AES_RROUND_F( aes_data_ptr->rk_ptr,
aes_data_ptr->xy_values[0],
aes_data_ptr->xy_values[1],
aes_data_ptr->xy_values[2],
aes_data_ptr->xy_values[3],
aes_data_ptr->xy_values[4],
aes_data_ptr->xy_values[5],
aes_data_ptr->xy_values[6],
aes_data_ptr->xy_values[7] );
}
if ( SCA_CM_ALWAYS_TRUE ( sca_cm_pos_tbl, 1 ) )
AES_RROUND( RK, Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
if ( SCA_CM_TBL_MATCH( sca_cm_pos_tbl, 0 ) )
AES_RROUND_F( RK_SCA, X0_SCA, X1_SCA, X2_SCA, X3_SCA,
Y0_SCA, Y1_SCA, Y2_SCA, Y3_SCA );
if ( SCA_CM_ALWAYS_TRUE ( sca_cm_pos_tbl, 0 ) )
AES_RROUND_F( RK, X0, X1, X2, X3, Y0, Y1, Y2, Y3 );
#else /* MBEDTLS_AES_SCA_COUNTERMEASURES */
AES_RROUND( RK, Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
AES_RROUND_F( RK, X0, X1, X2, X3, Y0, Y1, Y2, Y3 );
#endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
PUT_UINT32_LE( X0, output, 0 );
PUT_UINT32_LE( X1, output, 4 );
PUT_UINT32_LE( X2, output, 8 );
PUT_UINT32_LE( X3, output, 12 );
for ( i = 0; i < 4; i++ )
{
PUT_UINT32_LE( aes_data_real.xy_values[i], output, ( i * 4 ) );
}
return( 0 );
}