diff --git a/ChangeLog b/ChangeLog index d662bcf54..bfe342126 100644 --- a/ChangeLog +++ b/ChangeLog @@ -40,6 +40,8 @@ Changes POLARSSL_MODE_CFB, to also handle different block size CFB modes. * Removed handling for SSLv2 Client Hello (as per RFC 5246 recommendation) * Revamped session resumption handling + * Generalized external private key implementation handling (like PKCS#11) + in SSL/TLS Bugfix * Fixed handling error in mpi_cmp_mpi() on longer B values (found by diff --git a/include/polarssl/config.h b/include/polarssl/config.h index 538ef817d..543b96c8c 100644 --- a/include/polarssl/config.h +++ b/include/polarssl/config.h @@ -612,7 +612,7 @@ /** * \def POLARSSL_PKCS11_C * - * Enable support for PKCS#11 smartcard support. + * Enable wrapper for PKCS#11 smartcard support. * * Module: library/ssl_srv.c * Caller: library/ssl_cli.c @@ -620,7 +620,7 @@ * * Requires: POLARSSL_SSL_TLS_C * - * This module is required for SSL/TLS PKCS #11 smartcard support. + * This module enables SSL/TLS PKCS #11 smartcard support. * Requires the presence of the PKCS#11 helper library (libpkcs11-helper) #define POLARSSL_PKCS11_C */ diff --git a/include/polarssl/pkcs11.h b/include/polarssl/pkcs11.h index a65a72e81..ddfae3017 100644 --- a/include/polarssl/pkcs11.h +++ b/include/polarssl/pkcs11.h @@ -37,6 +37,14 @@ #include +#if defined(_MSC_VER) && !defined(inline) +#define inline _inline +#else +#if defined(__ARMCC_VERSION) && !defined(inline) +#define inline __inline +#endif /* __ARMCC_VERSION */ +#endif /*_MSC_VER */ + /** * Context for PKCS #11 private keys. */ @@ -121,6 +129,33 @@ int pkcs11_sign( pkcs11_context *ctx, const unsigned char *hash, unsigned char *sig ); +/** + * SSL/TLS wrappers for PKCS#11 functions + */ +static inline int ssl_pkcs11_decrypt( void *ctx, int mode, size_t *olen, + const unsigned char *input, unsigned char *output, + unsigned int output_max_len ) +{ + return pkcs11_decrypt( (pkcs11_context *) ctx, mode, olen, input, output, + output_max_len ); +} + +static inline int ssl_pkcs11_sign( void *ctx, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng, + int mode, int hash_id, unsigned int hashlen, + const unsigned char *hash, unsigned char *sig ) +{ + ((void) f_rng); + ((void) p_rng); + return pkcs11_sign( (pkcs11_context *) ctx, mode, hash_id, + hashlen, hash, sig ); +} + +static inline size_t ssl_pkcs11_key_len( void *ctx ) +{ + return ( (pkcs11_context *) ctx )->len; +} + #endif /* POLARSSL_PKCS11_C */ #endif /* POLARSSL_PKCS11_H */ diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index fcf8a8ffe..62ffba2d3 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -42,10 +42,6 @@ #include "dhm.h" #endif -#if defined(POLARSSL_PKCS11_C) -#include "pkcs11.h" -#endif - #if defined(POLARSSL_ZLIB_SUPPORT) #include "zlib.h" #endif @@ -253,6 +249,20 @@ #define TLS_EXT_RENEGOTIATION_INFO 0xFF01 + +/* + * Generic function pointers for allowing external RSA private key + * implementations. + */ +typedef int (*rsa_decrypt_func)( void *ctx, int mode, size_t *olen, + const unsigned char *input, unsigned char *output, + size_t output_max_len ); +typedef int (*rsa_sign_func)( void *ctx, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng, + int mode, int hash_id, unsigned int hashlen, + const unsigned char *hash, unsigned char *sig ); +typedef size_t (*rsa_key_len_func)( void *ctx ); + /* * SSL state machine */ @@ -446,10 +456,11 @@ struct _ssl_context /* * PKI layer */ - rsa_context *rsa_key; /*!< own RSA private key */ -#if defined(POLARSSL_PKCS11_C) - pkcs11_context *pkcs11_key; /*!< own PKCS#11 RSA private key */ -#endif + void *rsa_key; /*!< own RSA private key */ + rsa_decrypt_func rsa_decrypt; /*!< function for RSA decrypt*/ + rsa_sign_func rsa_sign; /*!< function for RSA sign */ + rsa_key_len_func rsa_key_len; /*!< function for RSA key len*/ + x509_cert *own_cert; /*!< own X.509 certificate */ x509_cert *ca_chain; /*!< own trusted CA chain */ x509_crl *ca_crl; /*!< trusted CA CRLs */ @@ -722,17 +733,26 @@ void ssl_set_ca_chain( ssl_context *ssl, x509_cert *ca_chain, void ssl_set_own_cert( ssl_context *ssl, x509_cert *own_cert, rsa_context *rsa_key ); -#if defined(POLARSSL_PKCS11_C) /** - * \brief Set own certificate and PKCS#11 private key + * \brief Set own certificate and alternate non-PolarSSL private + * key and handling callbacks, such as the PKCS#11 wrappers + * or any other external private key handler. + * (see the respective RSA functions in rsa.h for documentation + * of the callback parameters, with the only change being + * that the rsa_context * is a void * in the callbacks) * * \param ssl SSL context * \param own_cert own public certificate - * \param pkcs11_key own PKCS#11 RSA key + * \param rsa_key alternate implementation private RSA key + * \param rsa_decrypt_func alternate implementation of \c rsa_pkcs1_decrypt() + * \param rsa_sign_func alternate implementation of \c rsa_pkcs1_sign() + * \param rsa_key_len_func function returning length of RSA key in bytes */ -void ssl_set_own_cert_pkcs11( ssl_context *ssl, x509_cert *own_cert, - pkcs11_context *pkcs11_key ); -#endif +void ssl_set_own_cert_alt( ssl_context *ssl, x509_cert *own_cert, + void *rsa_key, + rsa_decrypt_func rsa_decrypt, + rsa_sign_func rsa_sign, + rsa_key_len_func rsa_key_len ); #if defined(POLARSSL_DHM_C) /** diff --git a/library/ssl_cli.c b/library/ssl_cli.c index b44af2ba3..3e1b0569f 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -30,10 +30,6 @@ #include "polarssl/debug.h" #include "polarssl/ssl.h" -#if defined(POLARSSL_PKCS11_C) -#include "polarssl/pkcs11.h" -#endif /* defined(POLARSSL_PKCS11_C) */ - #include #include #include @@ -1115,15 +1111,8 @@ static int ssl_write_certificate_verify( ssl_context *ssl ) if( ssl->rsa_key == NULL ) { -#if defined(POLARSSL_PKCS11_C) - if( ssl->pkcs11_key == NULL ) - { -#endif /* defined(POLARSSL_PKCS11_C) */ - SSL_DEBUG_MSG( 1, ( "got no private key" ) ); - return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED ); -#if defined(POLARSSL_PKCS11_C) - } -#endif /* defined(POLARSSL_PKCS11_C) */ + SSL_DEBUG_MSG( 1, ( "got no private key" ) ); + return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED ); } /* @@ -1132,11 +1121,7 @@ static int ssl_write_certificate_verify( ssl_context *ssl ) ssl->handshake->calc_verify( ssl, hash ); if ( ssl->rsa_key ) - n = ssl->rsa_key->len; -#if defined(POLARSSL_PKCS11_C) - else - n = ssl->pkcs11_key->len; -#endif /* defined(POLARSSL_PKCS11_C) */ + n = ssl->rsa_key_len ( ssl->rsa_key ); if( ssl->minor_ver == SSL_MINOR_VERSION_3 ) { @@ -1164,14 +1149,9 @@ static int ssl_write_certificate_verify( ssl_context *ssl ) if( ssl->rsa_key ) { - ret = rsa_pkcs1_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng, - RSA_PRIVATE, hash_id, - hashlen, hash, ssl->out_msg + 6 + offset ); - } else { -#if defined(POLARSSL_PKCS11_C) - ret = pkcs11_sign( ssl->pkcs11_key, RSA_PRIVATE, hash_id, - hashlen, hash, ssl->out_msg + 6 + offset ); -#endif /* defined(POLARSSL_PKCS11_C) */ + ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng, + RSA_PRIVATE, hash_id, + hashlen, hash, ssl->out_msg + 6 + offset ); } if (ret != 0) diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 64b0d2df4..e31145864 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -30,10 +30,6 @@ #include "polarssl/debug.h" #include "polarssl/ssl.h" -#if defined(POLARSSL_PKCS11_C) -#include "polarssl/pkcs11.h" -#endif /* defined(POLARSSL_PKCS11_C) */ - #include #include #include @@ -644,15 +640,8 @@ static int ssl_write_server_key_exchange( ssl_context *ssl ) if( ssl->rsa_key == NULL ) { -#if defined(POLARSSL_PKCS11_C) - if( ssl->pkcs11_key == NULL ) - { -#endif /* defined(POLARSSL_PKCS11_C) */ - SSL_DEBUG_MSG( 1, ( "got no private key" ) ); - return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED ); -#if defined(POLARSSL_PKCS11_C) - } -#endif /* defined(POLARSSL_PKCS11_C) */ + SSL_DEBUG_MSG( 1, ( "got no private key" ) ); + return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED ); } /* @@ -738,11 +727,7 @@ static int ssl_write_server_key_exchange( ssl_context *ssl ) SSL_DEBUG_BUF( 3, "parameters hash", hash, hashlen ); if ( ssl->rsa_key ) - rsa_key_len = ssl->rsa_key->len; -#if defined(POLARSSL_PKCS11_C) - else - rsa_key_len = ssl->pkcs11_key->len; -#endif /* defined(POLARSSL_PKCS11_C) */ + rsa_key_len = ssl->rsa_key_len( ssl->rsa_key ); if( ssl->minor_ver == SSL_MINOR_VERSION_3 ) { @@ -758,16 +743,11 @@ static int ssl_write_server_key_exchange( ssl_context *ssl ) if ( ssl->rsa_key ) { - ret = rsa_pkcs1_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng, - RSA_PRIVATE, - hash_id, hashlen, hash, ssl->out_msg + 6 + n ); + ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng, + RSA_PRIVATE, + hash_id, hashlen, hash, + ssl->out_msg + 6 + n ); } -#if defined(POLARSSL_PKCS11_C) - else { - ret = pkcs11_sign( ssl->pkcs11_key, RSA_PRIVATE, - hash_id, hashlen, hash, ssl->out_msg + 6 + n ); - } -#endif /* defined(POLARSSL_PKCS11_C) */ if( ret != 0 ) { @@ -898,15 +878,8 @@ static int ssl_parse_client_key_exchange( ssl_context *ssl ) { if( ssl->rsa_key == NULL ) { -#if defined(POLARSSL_PKCS11_C) - if( ssl->pkcs11_key == NULL ) - { -#endif - SSL_DEBUG_MSG( 1, ( "got no private key" ) ); - return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED ); -#if defined(POLARSSL_PKCS11_C) - } -#endif + SSL_DEBUG_MSG( 1, ( "got no private key" ) ); + return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED ); } /* @@ -914,11 +887,7 @@ static int ssl_parse_client_key_exchange( ssl_context *ssl ) */ i = 4; if( ssl->rsa_key ) - n = ssl->rsa_key->len; -#if defined(POLARSSL_PKCS11_C) - else - n = ssl->pkcs11_key->len; -#endif + n = ssl->rsa_key_len( ssl->rsa_key ); ssl->handshake->pmslen = 48; if( ssl->minor_ver != SSL_MINOR_VERSION_0 ) @@ -939,21 +908,12 @@ static int ssl_parse_client_key_exchange( ssl_context *ssl ) } if( ssl->rsa_key ) { - ret = rsa_pkcs1_decrypt( ssl->rsa_key, RSA_PRIVATE, - &ssl->handshake->pmslen, - ssl->in_msg + i, - ssl->handshake->premaster, - sizeof(ssl->handshake->premaster) ); + ret = ssl->rsa_decrypt( ssl->rsa_key, RSA_PRIVATE, + &ssl->handshake->pmslen, + ssl->in_msg + i, + ssl->handshake->premaster, + sizeof(ssl->handshake->premaster) ); } -#if defined(POLARSSL_PKCS11_C) - else { - ret = pkcs11_decrypt( ssl->pkcs11_key, RSA_PRIVATE, - &ssl->handshake->pmslen, - ssl->in_msg + i, - ssl->handshake->premaster, - sizeof(ssl->handshake->premaster) ); - } -#endif /* defined(POLARSSL_PKCS11_C) */ if( ret != 0 || ssl->handshake->pmslen != 48 || ssl->handshake->premaster[0] != ssl->max_major_ver || diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 61920042b..cc0f65c57 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -65,6 +65,28 @@ int (*ssl_hw_record_read)(ssl_context *ssl) = NULL; int (*ssl_hw_record_finish)(ssl_context *ssl) = NULL; #endif +static int ssl_rsa_decrypt( void *ctx, int mode, size_t *olen, + const unsigned char *input, unsigned char *output, + size_t output_max_len ) +{ + return rsa_pkcs1_decrypt( (rsa_context *) ctx, mode, olen, input, output, + output_max_len ); +} + +static int ssl_rsa_sign( void *ctx, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng, + int mode, int hash_id, unsigned int hashlen, + const unsigned char *hash, unsigned char *sig ) +{ + return rsa_pkcs1_sign( (rsa_context *) ctx, f_rng, p_rng, mode, hash_id, + hashlen, hash, sig ); +} + +static size_t ssl_rsa_key_len( void *ctx ) +{ + return ( (rsa_context *) ctx )->len; +} + /* * Key material generation */ @@ -2826,6 +2848,10 @@ int ssl_init( ssl_context *ssl ) memset( ssl, 0, sizeof( ssl_context ) ); + ssl->rsa_decrypt = ssl_rsa_decrypt; + ssl->rsa_sign = ssl_rsa_sign; + ssl->rsa_key_len = ssl_rsa_key_len; + ssl->in_ctr = (unsigned char *) malloc( len ); ssl->in_hdr = ssl->in_ctr + 8; ssl->in_msg = ssl->in_ctr + 13; @@ -3002,14 +3028,19 @@ void ssl_set_own_cert( ssl_context *ssl, x509_cert *own_cert, ssl->rsa_key = rsa_key; } -#if defined(POLARSSL_PKCS11_C) -void ssl_set_own_cert_pkcs11( ssl_context *ssl, x509_cert *own_cert, - pkcs11_context *pkcs11_key ) +void ssl_set_own_cert_alt( ssl_context *ssl, x509_cert *own_cert, + void *rsa_key, + rsa_decrypt_func rsa_decrypt, + rsa_sign_func rsa_sign, + rsa_key_len_func rsa_key_len ) { ssl->own_cert = own_cert; - ssl->pkcs11_key = pkcs11_key; + ssl->rsa_key = rsa_key; + ssl->rsa_decrypt = rsa_decrypt; + ssl->rsa_sign = rsa_sign; + ssl->rsa_key_len = rsa_key_len; } -#endif + #if defined(POLARSSL_DHM_C) int ssl_set_dh_param( ssl_context *ssl, const char *dhm_P, const char *dhm_G )