diff --git a/src/TLSClient.cpp b/src/TLSClient.cpp index ab80c5647..7c9598c97 100644 --- a/src/TLSClient.cpp +++ b/src/TLSClient.cpp @@ -30,6 +30,7 @@ #include #include +#include #include #include #include @@ -44,6 +45,7 @@ #include #include +#define HEADER_SIZE 4 #define MAX_BUF 16384 #if GNUTLS_VERSION_NUMBER < 0x030406 @@ -469,24 +471,15 @@ void TLSClient::send (const std::string& data) packet[3] = l; unsigned int total = 0; - unsigned int remaining = packet.length (); - while (total < packet.length ()) + int status; + do { - int status; - do - { - status = gnutls_record_send (_session, packet.c_str () + total, remaining); // All - } - while (errno == GNUTLS_E_INTERRUPTED || - errno == GNUTLS_E_AGAIN); - - if (status == -1) - break; - - total += (unsigned int) status; - remaining -= (unsigned int) status; + status = gnutls_record_send (_session, packet.c_str () + total, packet.length () - total); // All } + while ((status > 0 && (total += status) < packet.length ()) || + status == GNUTLS_E_INTERRUPTED || + status == GNUTLS_E_AGAIN); if (_debug) std::cout << "c: INFO Sending 'XXXX" @@ -500,18 +493,22 @@ void TLSClient::recv (std::string& data) { data = ""; // No appending of data. int received = 0; + int total = 0; // Get the encoded length. - unsigned char header[4] {}; + unsigned char header[HEADER_SIZE] {}; do { - received = gnutls_record_recv (_session, header, 4); // All + received = gnutls_record_recv (_session, header + total, HEADER_SIZE - total); // All } - while (received > 0 && - (errno == GNUTLS_E_INTERRUPTED || - errno == GNUTLS_E_AGAIN)); + while ((received > 0 && (total += received) < HEADER_SIZE) || + received == GNUTLS_E_INTERRUPTED || + received == GNUTLS_E_AGAIN); - int total = received; + if (total < HEADER_SIZE) { + throw std::string ("Failed to receive header: ") + + (received < 0 ? gnutls_strerror(received) : "connection lost?"); + } // Decode the length. unsigned long expected = (header[0]<<24) | @@ -521,7 +518,11 @@ void TLSClient::recv (std::string& data) if (_debug) std::cout << "c: INFO expecting " << expected << " bytes.\n"; - // TODO This would be a good place to assert 'expected < _limit'. + if (_limit && expected >= (unsigned long) _limit) { + std::ostringstream err_str; + err_str << "Expected message size " << expected << " is larger than allowed limit " << _limit; + throw err_str.str (); + } // Arbitrary buffer size. char buffer[MAX_BUF]; @@ -531,13 +532,18 @@ void TLSClient::recv (std::string& data) // fits in the buffer. do { + int chunk_size = 0; do { - received = gnutls_record_recv (_session, buffer, MAX_BUF - 1); // All + received = gnutls_record_recv (_session, buffer + chunk_size, MAX_BUF - chunk_size); // All + if (received > 0) { + total += received; + chunk_size += received; + } } - while (received > 0 && - (errno == GNUTLS_E_INTERRUPTED || - errno == GNUTLS_E_AGAIN)); + while ((received > 0 && (unsigned long) total < expected && chunk_size < MAX_BUF) || + received == GNUTLS_E_INTERRUPTED || + received == GNUTLS_E_AGAIN); // Other end closed the connection. if (received == 0) @@ -548,17 +554,10 @@ void TLSClient::recv (std::string& data) } // Something happened. - if (received < 0 && gnutls_error_is_fatal (received) == 0) // All - { - if (_debug) - std::cout << "c: WARNING " << gnutls_strerror (received) << '\n'; // All - } - else if (received < 0) + if (received < 0) throw std::string (gnutls_strerror (received)); // All - buffer [received] = '\0'; - data += buffer; - total += received; + data.append (buffer, chunk_size); // Stop at defined limit. if (_limit && total > _limit)