From 8df2769178781aa4442beb43a27d627297a5b217 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20P=C3=A9gouri=C3=A9-Gonnard?= Date: Wed, 21 Aug 2013 10:34:38 +0200 Subject: [PATCH] Introduce pk_sign() and use it in ssl --- include/polarssl/pk.h | 26 ++++++++++++++++++++ include/polarssl/ssl.h | 1 + library/pk.c | 21 ++++++++++++++-- library/pk_wrap.c | 54 +++++++++++++++++++++++++++++++++++++++++- library/ssl_cli.c | 46 ++++++++++++++++++----------------- library/ssl_srv.c | 46 ++++++++++++++++++----------------- library/ssl_tls.c | 1 + 7 files changed, 148 insertions(+), 47 deletions(-) diff --git a/include/polarssl/pk.h b/include/polarssl/pk.h index fb0e92ec5..cc8a2fcfb 100644 --- a/include/polarssl/pk.h +++ b/include/polarssl/pk.h @@ -129,6 +129,13 @@ typedef struct const unsigned char *hash, size_t hash_len, const unsigned char *sig, size_t sig_len ); + /** Make signature */ + int (*sign_func)( void *ctx, md_type_t md_alg, + const unsigned char *hash, size_t hash_len, + unsigned char *sig, size_t *sig_len, + int (*f_rng)(void *, unsigned char *, size_t), + void *p_rng ); + /** Allocate a new context */ void * (*ctx_alloc_func)( void ); @@ -218,6 +225,25 @@ int pk_verify( pk_context *ctx, md_type_t md_alg, const unsigned char *hash, size_t hash_len, const unsigned char *sig, size_t sig_len ); +/** + * \brief Make signature + * + * \param ctx PK context to use + * \param md_alg Hash algorithm used + * \param hash Hash of the message to sign + * \param hash_len Hash length + * \param sig Place to write the signature + * \param sig_len Number of bytes written + * \param f_rng RNG function + * \param p_rng RNG parameter + * + * \return 0 on success, or a specific error code. + */ +int pk_sign( pk_context *ctx, md_type_t md_alg, + const unsigned char *hash, size_t hash_len, + unsigned char *sig, size_t *sig_len, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ); + /** * \brief Export debug information * diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index b98551b4a..9a1d22044 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -580,6 +580,7 @@ struct _ssl_context */ pk_context *pk_key; /*!< own private key */ #if defined(POLARSSL_RSA_C) + int rsa_use_alt; /*pk_info->verify_func == NULL ) return( POLARSSL_ERR_PK_TYPE_MISMATCH ); - return( ctx->pk_info->verify_func( ctx->pk_ctx, md_alg, - hash, hash_len, + return( ctx->pk_info->verify_func( ctx->pk_ctx, md_alg, hash, hash_len, sig, sig_len ) ); } +/* + * Make a signature + */ +int pk_sign( pk_context *ctx, md_type_t md_alg, + const unsigned char *hash, size_t hash_len, + unsigned char *sig, size_t *sig_len, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) +{ + if( ctx == NULL || ctx->pk_info == NULL ) + return( POLARSSL_ERR_PK_BAD_INPUT_DATA ); + + if( ctx->pk_info->sign_func == NULL ) + return( POLARSSL_ERR_PK_TYPE_MISMATCH ); + + return( ctx->pk_info->sign_func( ctx->pk_ctx, md_alg, hash, hash_len, + sig, sig_len, f_rng, p_rng ) ); +} + /* * Get key size in bits */ diff --git a/library/pk_wrap.c b/library/pk_wrap.c index 249f7bd0c..eb91d895f 100644 --- a/library/pk_wrap.c +++ b/library/pk_wrap.c @@ -69,6 +69,17 @@ static int rsa_verify_wrap( void *ctx, md_type_t md_alg, RSA_PUBLIC, md_alg, hash_len, hash, sig ) ); } +static int rsa_sign_wrap( void *ctx, md_type_t md_alg, + const unsigned char *hash, size_t hash_len, + unsigned char *sig, size_t *sig_len, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) +{ + *sig_len = ((rsa_context *) ctx)->len; + + return( rsa_pkcs1_sign( (rsa_context *) ctx, f_rng, p_rng, RSA_PRIVATE, + md_alg, hash_len, hash, sig ) ); +} + static void *rsa_alloc_wrap( void ) { void *ctx = polarssl_malloc( sizeof( rsa_context ) ); @@ -104,6 +115,7 @@ const pk_info_t rsa_info = { rsa_get_size, rsa_can_do, rsa_verify_wrap, + rsa_sign_wrap, rsa_alloc_wrap, rsa_free_wrap, rsa_debug, @@ -127,11 +139,16 @@ static size_t eckey_get_size( const void *ctx ) } #if defined(POLARSSL_ECDSA_C) -/* Forward declaration */ +/* Forward declarations */ static int ecdsa_verify_wrap( void *ctx, md_type_t md_alg, const unsigned char *hash, size_t hash_len, const unsigned char *sig, size_t sig_len ); +static int ecdsa_sign_wrap( void *ctx, md_type_t md_alg, + const unsigned char *hash, size_t hash_len, + unsigned char *sig, size_t *sig_len, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ); + static int eckey_verify_wrap( void *ctx, md_type_t md_alg, const unsigned char *hash, size_t hash_len, const unsigned char *sig, size_t sig_len ) @@ -148,6 +165,26 @@ static int eckey_verify_wrap( void *ctx, md_type_t md_alg, return( ret ); } + +static int eckey_sign_wrap( void *ctx, md_type_t md_alg, + const unsigned char *hash, size_t hash_len, + unsigned char *sig, size_t *sig_len, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) +{ + int ret; + ecdsa_context ecdsa; + + ecdsa_init( &ecdsa ); + + if( ( ret = ecdsa_from_keypair( &ecdsa, ctx ) ) == 0 ) + ret = ecdsa_sign_wrap( &ecdsa, md_alg, hash, hash_len, sig, sig_len, + f_rng, p_rng ); + + ecdsa_free( &ecdsa ); + + return( ret ); +} + #endif /* POLARSSL_ECDSA_C */ static void *eckey_alloc_wrap( void ) @@ -180,8 +217,10 @@ const pk_info_t eckey_info = { eckey_can_do, #if defined(POLARSSL_ECDSA_C) eckey_verify_wrap, + eckey_sign_wrap, #else NULL, + NULL, #endif eckey_alloc_wrap, eckey_free_wrap, @@ -203,6 +242,7 @@ const pk_info_t eckeydh_info = { eckey_get_size, /* Same underlying key structure */ eckeydh_can_do, NULL, + NULL, eckey_alloc_wrap, /* Same underlying key structure */ eckey_free_wrap, /* Same underlying key structure */ eckey_debug, /* Same underlying key structure */ @@ -225,6 +265,17 @@ static int ecdsa_verify_wrap( void *ctx, md_type_t md_alg, hash, hash_len, sig, sig_len ) ); } +static int ecdsa_sign_wrap( void *ctx, md_type_t md_alg, + const unsigned char *hash, size_t hash_len, + unsigned char *sig, size_t *sig_len, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) +{ + ((void) md_alg); + + return( ecdsa_write_signature( (ecdsa_context *) ctx, + hash, hash_len, sig, sig_len, f_rng, p_rng ) ); +} + static void *ecdsa_alloc_wrap( void ) { void *ctx = polarssl_malloc( sizeof( ecdsa_context ) ); @@ -247,6 +298,7 @@ const pk_info_t ecdsa_info = { eckey_get_size, /* Compatible key structures */ ecdsa_can_do, ecdsa_verify_wrap, + ecdsa_sign_wrap, ecdsa_alloc_wrap, ecdsa_free_wrap, eckey_debug, /* Compatible key structures */ diff --git a/library/ssl_cli.c b/library/ssl_cli.c index 274cb3ae3..829e46b75 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -2044,40 +2044,42 @@ static int ssl_write_certificate_verify( ssl_context *ssl ) if( ssl->minor_ver == SSL_MINOR_VERSION_3 ) ssl->out_msg[5] = SSL_SIG_RSA; - if( ( ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng, - RSA_PRIVATE, md_alg, - hashlen, hash, ssl->out_msg + 6 + offset ) ) != 0 ) + if( ssl->rsa_use_alt ) { - SSL_DEBUG_RET( 1, "pkcs1_sign", ret ); - return( ret ); - } + if( ( ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng, + RSA_PRIVATE, md_alg, + hashlen, hash, ssl->out_msg + 6 + offset ) ) != 0 ) + { + SSL_DEBUG_RET( 1, "rsa_sign", ret ); + return( ret ); + } - n = ssl->rsa_key_len ( ssl->rsa_key ); + n = ssl->rsa_key_len ( ssl->rsa_key ); + } + else + { + if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, + ssl->out_msg + 6 + offset, &n, + ssl->f_rng, ssl->p_rng ) ) != 0 ) + { + SSL_DEBUG_RET( 1, "pk_sign", ret ); + return( ret ); + } + } } else #endif /* POLARSSL_RSA_C */ #if defined(POLARSSL_ECDSA_C) if( pk_can_do( ssl->pk_key, POLARSSL_PK_ECDSA ) ) { - ecdsa_context ecdsa; - if( ssl->minor_ver == SSL_MINOR_VERSION_3 ) ssl->out_msg[5] = SSL_SIG_ECDSA; - ecdsa_init( &ecdsa ); - - if( ( ret = ecdsa_from_keypair( &ecdsa, ssl->pk_key->pk_ctx ) ) == 0 ) + if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, + ssl->out_msg + 6 + offset, &n, + ssl->f_rng, ssl->p_rng ) ) != 0 ) { - ret = ecdsa_write_signature( &ecdsa, hash, hashlen, - ssl->out_msg + 6 + offset, &n, - ssl->f_rng, ssl->p_rng ); - } - - ecdsa_free( &ecdsa ); - - if( ret != 0 ) - { - SSL_DEBUG_RET( 1, "ecdsa_sign", ret ); + SSL_DEBUG_RET( 1, "pk_sign", ret ); return( ret ); } } diff --git a/library/ssl_srv.c b/library/ssl_srv.c index e3f604fe9..ffd754e36 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -2080,22 +2080,34 @@ static int ssl_write_server_key_exchange( ssl_context *ssl ) n += 2; } - if( ( ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng, - RSA_PRIVATE, md_alg, hashlen, hash, p + 2 ) ) != 0 ) + if( ssl->rsa_use_alt ) { - SSL_DEBUG_RET( 1, "rsa_sign", ret ); - return( ret ); - } + if( ( ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, + ssl->p_rng, RSA_PRIVATE, md_alg, hashlen, + hash, p + 2 ) ) != 0 ) + { + SSL_DEBUG_RET( 1, "rsa_sign", ret ); + return( ret ); + } - signature_len = ssl->rsa_key_len( ssl->rsa_key ); + signature_len = ssl->rsa_key_len( ssl->rsa_key ); + } + else + { + if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, + p + 2 , &signature_len, + ssl->f_rng, ssl->p_rng ) ) != 0 ) + { + SSL_DEBUG_RET( 1, "pk_sign", ret ); + return( ret ); + } + } } else #endif /* POLARSSL_RSA_C */ #if defined(POLARSSL_ECDSA_C) if( pk_can_do( ssl->pk_key, POLARSSL_PK_ECDSA ) ) { - ecdsa_context ecdsa; - if( ssl->minor_ver == SSL_MINOR_VERSION_3 ) { *(p++) = ssl->handshake->sig_alg; @@ -2104,21 +2116,11 @@ static int ssl_write_server_key_exchange( ssl_context *ssl ) n += 2; } - ecdsa_init( &ecdsa ); - - ret = ecdsa_from_keypair( &ecdsa, ssl->pk_key->pk_ctx ); - if( ret == 0 ) + if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen, + p + 2 , &signature_len, + ssl->f_rng, ssl->p_rng ) ) != 0 ) { - ret = ecdsa_write_signature( &ecdsa, hash, hashlen, - p + 2, &signature_len, - ssl->f_rng, ssl->p_rng ); - } - - ecdsa_free( &ecdsa ); - - if( ret != 0 ) - { - SSL_DEBUG_RET( 1, "ecdsa_sign", ret ); + SSL_DEBUG_RET( 1, "pk_sign", ret ); return( ret ); } } diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 4e5b3e6ae..d4723d759 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -3169,6 +3169,7 @@ void ssl_set_own_cert_alt_rsa( ssl_context *ssl, x509_cert *own_cert, rsa_key_len_func rsa_key_len ) { ssl->own_cert = own_cert; + ssl->rsa_use_alt = 1; ssl->rsa_key = rsa_key; ssl->rsa_decrypt = rsa_decrypt; ssl->rsa_sign = rsa_sign;