diff --git a/include/polarssl/pk.h b/include/polarssl/pk.h index 6135bb013..7c674ab1f 100644 --- a/include/polarssl/pk.h +++ b/include/polarssl/pk.h @@ -244,6 +244,17 @@ int pk_init_ctx_rsa_alt( pk_context *ctx, void * key, */ size_t pk_get_size( const pk_context *ctx ); +/** + * \brief Get the length in bytes of the underlying key + * \param ctx Context to use + * + * \return Key lenght in bytes, or 0 on error + */ +static size_t pk_get_len( const pk_context *ctx ) +{ + return( ( pk_get_size( ctx ) + 7 ) / 8 ); +} + /** * \brief Tell if a context can do the operation given by type * diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index 442dba28f..dfd649081 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -202,6 +202,7 @@ #define SSL_HASH_SHA384 5 #define SSL_HASH_SHA512 6 +#define SSL_SIG_ANON 0 #define SSL_SIG_RSA 1 #define SSL_SIG_ECDSA 3 @@ -580,13 +581,6 @@ struct _ssl_context */ pk_context *pk_key; /*!< own private key */ int pk_key_own_alloc; /*!< did we allocate pk_key? */ -#if defined(POLARSSL_RSA_C) - int rsa_use_alt; /*out_msg[4] = SSL_HASH_SHA256; } - /* SIG added later */ + ssl->out_msg[5] = ssl_sig_from_pk( ssl->pk_key ); if( ( md_info = md_info_from_type( md_alg ) ) == NULL ) { @@ -2036,40 +2036,13 @@ static int ssl_write_certificate_verify( ssl_context *ssl ) offset = 2; } -#if defined(POLARSSL_RSA_C) - if( ssl->rsa_key != NULL ) + if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, + ssl->out_msg + 6 + offset, &n, + ssl->f_rng, ssl->p_rng ) ) != 0 ) { - if( ssl->minor_ver == SSL_MINOR_VERSION_3 ) - ssl->out_msg[5] = SSL_SIG_RSA; - - if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, - ssl->out_msg + 6 + offset, &n, - ssl->f_rng, ssl->p_rng ) ) != 0 ) - { - SSL_DEBUG_RET( 1, "pk_sign", ret ); - return( ret ); - } + SSL_DEBUG_RET( 1, "pk_sign", ret ); + return( ret ); } - else -#endif /* POLARSSL_RSA_C */ -#if defined(POLARSSL_ECDSA_C) - if( pk_can_do( ssl->pk_key, POLARSSL_PK_ECDSA ) ) - { - if( ssl->minor_ver == SSL_MINOR_VERSION_3 ) - ssl->out_msg[5] = SSL_SIG_ECDSA; - - if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, - ssl->out_msg + 6 + offset, &n, - ssl->f_rng, ssl->p_rng ) ) != 0 ) - { - SSL_DEBUG_RET( 1, "pk_sign", ret ); - return( ret ); - } - } - else -#endif /* POLARSSL_ECDSA_C */ - /* should never happen */ - return( POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE ); ssl->out_msg[4 + offset] = (unsigned char)( n >> 8 ); ssl->out_msg[5 + offset] = (unsigned char)( n ); diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 0fa4f66f0..6c4bdf086 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -2069,50 +2069,21 @@ static int ssl_write_server_key_exchange( ssl_context *ssl ) return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED ); } -#if defined(POLARSSL_RSA_C) - if( ssl->rsa_key != NULL ) + if( ssl->minor_ver == SSL_MINOR_VERSION_3 ) { - if( ssl->minor_ver == SSL_MINOR_VERSION_3 ) - { - *(p++) = ssl->handshake->sig_alg; - *(p++) = SSL_SIG_RSA; + *(p++) = ssl->handshake->sig_alg; + *(p++) = ssl_sig_from_pk( ssl->pk_key ); - n += 2; - } - - if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, - p + 2 , &signature_len, - ssl->f_rng, ssl->p_rng ) ) != 0 ) - { - SSL_DEBUG_RET( 1, "pk_sign", ret ); - return( ret ); - } + n += 2; } - else -#endif /* POLARSSL_RSA_C */ -#if defined(POLARSSL_ECDSA_C) - if( pk_can_do( ssl->pk_key, POLARSSL_PK_ECDSA ) ) + + if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, + p + 2 , &signature_len, + ssl->f_rng, ssl->p_rng ) ) != 0 ) { - if( ssl->minor_ver == SSL_MINOR_VERSION_3 ) - { - *(p++) = ssl->handshake->sig_alg; - *(p++) = SSL_SIG_ECDSA; - - n += 2; - } - - if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, - p + 2 , &signature_len, - ssl->f_rng, ssl->p_rng ) ) != 0 ) - { - SSL_DEBUG_RET( 1, "pk_sign", ret ); - return( ret ); - } + SSL_DEBUG_RET( 1, "pk_sign", ret ); + return( ret ); } - else -#endif /* POLARSSL_ECDSA_C */ - /* should never happen */ - return( POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE ); *(p++) = (unsigned char)( signature_len >> 8 ); *(p++) = (unsigned char)( signature_len ); @@ -2254,7 +2225,7 @@ static int ssl_parse_encrypted_pms_secret( ssl_context *ssl ) * Decrypt the premaster using own private RSA key */ i = 4; - n = ssl->rsa_key_len( ssl->rsa_key ); + n = pk_get_len( ssl->pk_key ); ssl->handshake->pmslen = 48; if( ssl->minor_ver != SSL_MINOR_VERSION_0 ) diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 9e446f613..527b333e6 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -131,30 +131,6 @@ int (*ssl_hw_record_read)(ssl_context *ssl) = NULL; int (*ssl_hw_record_finish)(ssl_context *ssl) = NULL; #endif -#if defined(POLARSSL_RSA_C) -static int ssl_rsa_decrypt( void *ctx, int mode, size_t *olen, - const unsigned char *input, unsigned char *output, - size_t output_max_len ) -{ - return rsa_pkcs1_decrypt( (rsa_context *) ctx, mode, olen, input, output, - output_max_len ); -} - -static int ssl_rsa_sign( void *ctx, - int (*f_rng)(void *, unsigned char *, size_t), void *p_rng, - int mode, int hash_id, unsigned int hashlen, - const unsigned char *hash, unsigned char *sig ) -{ - return rsa_pkcs1_sign( (rsa_context *) ctx, f_rng, p_rng, mode, hash_id, - hashlen, hash, sig ); -} - -static size_t ssl_rsa_key_len( void *ctx ) -{ - return ( (rsa_context *) ctx )->len; -} -#endif /* POLARSSL_RSA_C */ - /* * Key material generation */ @@ -2858,12 +2834,6 @@ int ssl_init( ssl_context *ssl ) /* * Sane defaults */ -#if defined(POLARSSL_RSA_C) - ssl->rsa_decrypt = ssl_rsa_decrypt; - ssl->rsa_sign = ssl_rsa_sign; - ssl->rsa_key_len = ssl_rsa_key_len; -#endif - ssl->min_major_ver = SSL_MAJOR_VERSION_3; ssl->min_minor_ver = SSL_MINOR_VERSION_0; ssl->max_major_ver = SSL_MAJOR_VERSION_3; @@ -3147,18 +3117,31 @@ void ssl_set_own_cert( ssl_context *ssl, x509_cert *own_cert, { ssl->own_cert = own_cert; ssl->pk_key = pk_key; - - /* Temporary, until everything is moved to PK */ - if( pk_key->pk_info->type == POLARSSL_PK_RSA ) - ssl->rsa_key = pk_key->pk_ctx; } #if defined(POLARSSL_RSA_C) -void ssl_set_own_cert_rsa( ssl_context *ssl, x509_cert *own_cert, +int ssl_set_own_cert_rsa( ssl_context *ssl, x509_cert *own_cert, rsa_context *rsa_key ) { + int ret; + ssl->own_cert = own_cert; - ssl->rsa_key = rsa_key; + + if( ( ssl->pk_key = polarssl_malloc( sizeof( pk_context ) ) ) == NULL ) + return( POLARSSL_ERR_SSL_MALLOC_FAILED ); + + ssl->pk_key_own_alloc = 1; + + pk_init( ssl->pk_key ); + + ret = pk_init_ctx( ssl->pk_key, pk_info_from_type( POLARSSL_PK_RSA ) ); + if( ret != 0 ) + return( ret ); + + if( ( ret = rsa_copy( ssl->pk_key->pk_ctx, rsa_key ) ) != 0 ) + return( ret ); + + return( 0 ); } #endif /* POLARSSL_RSA_C */ @@ -3168,14 +3151,7 @@ int ssl_set_own_cert_alt_rsa( ssl_context *ssl, x509_cert *own_cert, rsa_sign_func rsa_sign, rsa_key_len_func rsa_key_len ) { - int ret; - ssl->own_cert = own_cert; - ssl->rsa_use_alt = 1; - ssl->rsa_key = rsa_key; - ssl->rsa_decrypt = rsa_decrypt; - ssl->rsa_sign = rsa_sign; - ssl->rsa_key_len = rsa_key_len; if( ( ssl->pk_key = polarssl_malloc( sizeof( pk_context ) ) ) == NULL ) return( POLARSSL_ERR_SSL_MALLOC_FAILED ); @@ -3812,4 +3788,20 @@ void ssl_free( ssl_context *ssl ) memset( ssl, 0, sizeof( ssl_context ) ); } +/* + * Get the SSL_SIG_* constant corresponding to a public key + */ +unsigned char ssl_sig_from_pk( pk_context *pk ) +{ +#if defined(POLARSSL_RSA_C) + if( pk_can_do( pk, POLARSSL_PK_RSA ) ) + return( SSL_SIG_RSA ); +#endif +#if defined(POLARSSL_ECDSA_C) + if( pk_can_do( pk, POLARSSL_PK_ECDSA ) ) + return( SSL_SIG_ECDSA ); +#endif + return( SSL_SIG_ANON ); +} + #endif