diff --git a/src/TLSClient.cpp b/src/TLSClient.cpp index d2823c3da..f3ef90942 100644 --- a/src/TLSClient.cpp +++ b/src/TLSClient.cpp @@ -45,6 +45,7 @@ #include #include #include +#include #define MAX_BUF 16384 @@ -68,11 +69,13 @@ TLSClient::TLSClient () : _ca ("") , _cert ("") , _key ("") +, _host ("") +, _port ("") , _session(0) , _socket (0) , _limit (0) , _debug (false) -, _trust(false) +, _trust(strict) { } @@ -109,13 +112,15 @@ void TLSClient::debug (int level) } //////////////////////////////////////////////////////////////////////////////// -void TLSClient::trust (bool value) +void TLSClient::trust (const enum trust_level value) { _trust = value; if (_debug) { - if (_trust) + if (_trust == allow_all) std::cout << "c: INFO Server certificate trusted automatically.\n"; + else if (_trust == ignore_hostname) + std::cout << "c: INFO Server certificate trust verified but hostname ignored.\n"; else std::cout << "c: INFO Server certificate trust verified.\n"; } @@ -179,6 +184,9 @@ void TLSClient::init ( //////////////////////////////////////////////////////////////////////////////// void TLSClient::connect (const std::string& host, const std::string& port) { + _host = host; + _port = port; + // Store the TLSClient instance, so that the verification callback can access // it during the handshake below and call the verifcation method. gnutls_session_set_ptr (_session, (void*) this); @@ -273,19 +281,55 @@ void TLSClient::bye () //////////////////////////////////////////////////////////////////////////////// int TLSClient::verify_certificate () const { - if (_trust) + if (_trust == TLSClient::allow_all) return 0; // This verification function uses the trusted CAs in the credentials // structure. So you must have installed one or more CA certificates. unsigned int status = 0; + + const char* hostname = _host.c_str(); #if GNUTLS_VERSION_NUMBER >= 0x030104 - int ret = gnutls_certificate_verify_peers3 (_session, NULL, &status); -#else - int ret = gnutls_certificate_verify_peers2 (_session, &status); -#endif + if (_trust == TLSClient::ignore_hostname) + hostname = NULL; + + int ret = gnutls_certificate_verify_peers3 (_session, hostname, &status); if (ret < 0) return GNUTLS_E_CERTIFICATE_ERROR; +#else + int ret = gnutls_certificate_verify_peers2 (_session, &status); + if (ret < 0) + return GNUTLS_E_CERTIFICATE_ERROR; + + if ((status == 0) && (_trust != TLSClient::ignore_hostname)) + { + if (gnutls_certificate_type_get (_session) == GNUTLS_CRT_X509) + { + const gnutls_datum* cert_list; + unsigned int cert_list_size; + gnutls_x509_crt cert; + + cert_list = gnutls_certificate_get_peers (_session, &cert_list_size); + if (cert_list_size == 0) + return GNUTLS_E_CERTIFICATE_ERROR; + + ret = gnutls_x509_crt_init (&cert); + if (ret < 0) + return GNUTLS_E_CERTIFICATE_ERROR; + + ret = gnutls_x509_crt_import (cert, &cert_list[0], GNUTLS_X509_FMT_DER); + if (ret < 0) + gnutls_x509_crt_deinit(cert); + status = GNUTLS_E_CERTIFICATE_ERROR; + + if (gnutls_x509_crt_check_hostname (cert, hostname) == 0) + gnutls_x509_crt_deinit(cert); + return GNUTLS_E_CERTIFICATE_ERROR; + } + else + return GNUTLS_E_CERTIFICATE_ERROR; + } +#endif #if GNUTLS_VERSION_NUMBER >= 0x030105 gnutls_certificate_type_t type = gnutls_certificate_type_get (_session); diff --git a/src/TLSClient.h b/src/TLSClient.h index 98bc7d54a..ca99a82a3 100644 --- a/src/TLSClient.h +++ b/src/TLSClient.h @@ -34,11 +34,13 @@ class TLSClient { public: + enum trust_level { strict, ignore_hostname, allow_all }; + TLSClient (); ~TLSClient (); void limit (int); void debug (int); - void trust (bool); + void trust (const enum trust_level); void ciphers (const std::string&); void init (const std::string&, const std::string&, const std::string&); void connect (const std::string&, const std::string&); @@ -53,12 +55,14 @@ private: std::string _cert; std::string _key; std::string _ciphers; + std::string _host; + std::string _port; gnutls_certificate_credentials_t _credentials; gnutls_session_t _session; int _socket; int _limit; bool _debug; - bool _trust; + enum trust_level _trust; }; #endif diff --git a/src/commands/CmdDiagnostics.cpp b/src/commands/CmdDiagnostics.cpp index 5dd764fba..f00a4b8b8 100644 --- a/src/commands/CmdDiagnostics.cpp +++ b/src/commands/CmdDiagnostics.cpp @@ -232,8 +232,12 @@ int CmdDiagnostics::execute (std::string& output) ? " (readable)" : " (not readable)") << "\n"; - if (context.config.get ("taskd.trust") != "") - out << " Trust: override\n"; + if (context.config.get ("taskd.trust") == "allow all") + out << " Trust: allow all\n"; + else if (context.config.get ("taskd.trust") == "ignore hostname") + out << " Trust: ignore hostanme\n"; + else + out << " Trust: strict\n"; out << " Cert: " << context.config.get ("taskd.certificate") diff --git a/src/commands/CmdSync.cpp b/src/commands/CmdSync.cpp index ecaa98ff8..618d2279c 100644 --- a/src/commands/CmdSync.cpp +++ b/src/commands/CmdSync.cpp @@ -29,7 +29,6 @@ #include #include #include -#include #include #include #include @@ -87,14 +86,18 @@ int CmdSync::execute (std::string& output) if (credentials.size () != 3) throw std::string (STRING_CMD_SYNC_BAD_CRED); - bool trust = context.config.getBoolean ("taskd.trust"); + enum TLSClient::trust_level trust = TLSClient::strict; + if (context.config.get ("taskd.trust") == "allow all") + trust = TLSClient::allow_all; + else if (context.config.get ("taskd.trust") == "ignore hostname") + trust = TLSClient::ignore_hostname; // CA must exist, if provided. File ca (context.config.get ("taskd.ca")); if (ca._data != "" && ! ca.exists ()) throw std::string (STRING_CMD_SYNC_BAD_CA); - if (trust && ca._data != "") + if (trust == TLSClient::allow_all && ca._data != "") throw std::string (STRING_CMD_SYNC_TRUST_CA); File certificate (context.config.get ("taskd.certificate")); @@ -319,7 +322,7 @@ bool CmdSync::send ( const std::string& ca, const std::string& certificate, const std::string& key, - bool trust, + const enum TLSClient::trust_level trust, const Msg& request, Msg& response) { diff --git a/src/commands/CmdSync.h b/src/commands/CmdSync.h index 9bfc90561..9dd50e373 100644 --- a/src/commands/CmdSync.h +++ b/src/commands/CmdSync.h @@ -30,6 +30,7 @@ #include #include #include +#include class CmdSync : public Command { @@ -38,7 +39,7 @@ public: int execute (std::string&); private: - bool send (const std::string&, const std::string&, const std::string&, const std::string&, bool, const Msg&, Msg&); + bool send (const std::string&, const std::string&, const std::string&, const std::string&, const enum TLSClient::trust_level, const Msg&, Msg&); }; #endif