TLSClient: fix multiple issues with error handling

- do not check errno on successful function calls (it might not be
  cleared after previous failed one)
- GNUTLS_E_* are not passed through errno but as function return value
- therefore there's more error spectrum than -1
- do not assume whole header is received, check number of bytes fetched

small additional improvements:
- read as many bytes into buffer as possible before appending to data
- skip writing nul byte at the end of buffer and use append() instead
- additional sanity checks
This commit is contained in:
Jan Palus 2022-01-28 23:03:41 +01:00 committed by Tomas Babej
parent 59a1729a05
commit d541e0da65

View file

@ -30,6 +30,7 @@
#include <TLSClient.h> #include <TLSClient.h>
#include <iostream> #include <iostream>
#include <sstream>
#include <unistd.h> #include <unistd.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
@ -44,6 +45,7 @@
#include <shared.h> #include <shared.h>
#include <format.h> #include <format.h>
#define HEADER_SIZE 4
#define MAX_BUF 16384 #define MAX_BUF 16384
#if GNUTLS_VERSION_NUMBER < 0x030406 #if GNUTLS_VERSION_NUMBER < 0x030406
@ -469,24 +471,15 @@ void TLSClient::send (const std::string& data)
packet[3] = l; packet[3] = l;
unsigned int total = 0; unsigned int total = 0;
unsigned int remaining = packet.length ();
while (total < packet.length ())
{
int status; int status;
do do
{ {
status = gnutls_record_send (_session, packet.c_str () + total, remaining); // All status = gnutls_record_send (_session, packet.c_str () + total, packet.length () - total); // All
}
while (errno == GNUTLS_E_INTERRUPTED ||
errno == GNUTLS_E_AGAIN);
if (status == -1)
break;
total += (unsigned int) status;
remaining -= (unsigned int) status;
} }
while ((status > 0 && (total += status) < packet.length ()) ||
status == GNUTLS_E_INTERRUPTED ||
status == GNUTLS_E_AGAIN);
if (_debug) if (_debug)
std::cout << "c: INFO Sending 'XXXX" std::cout << "c: INFO Sending 'XXXX"
@ -500,18 +493,22 @@ void TLSClient::recv (std::string& data)
{ {
data = ""; // No appending of data. data = ""; // No appending of data.
int received = 0; int received = 0;
int total = 0;
// Get the encoded length. // Get the encoded length.
unsigned char header[4] {}; unsigned char header[HEADER_SIZE] {};
do do
{ {
received = gnutls_record_recv (_session, header, 4); // All received = gnutls_record_recv (_session, header + total, HEADER_SIZE - total); // All
} }
while (received > 0 && while ((received > 0 && (total += received) < HEADER_SIZE) ||
(errno == GNUTLS_E_INTERRUPTED || received == GNUTLS_E_INTERRUPTED ||
errno == GNUTLS_E_AGAIN)); received == GNUTLS_E_AGAIN);
int total = received; if (total < HEADER_SIZE) {
throw std::string ("Failed to receive header: ") +
(received < 0 ? gnutls_strerror(received) : "connection lost?");
}
// Decode the length. // Decode the length.
unsigned long expected = (header[0]<<24) | unsigned long expected = (header[0]<<24) |
@ -521,7 +518,11 @@ void TLSClient::recv (std::string& data)
if (_debug) if (_debug)
std::cout << "c: INFO expecting " << expected << " bytes.\n"; std::cout << "c: INFO expecting " << expected << " bytes.\n";
// TODO This would be a good place to assert 'expected < _limit'. if (_limit && expected >= (unsigned long) _limit) {
std::ostringstream err_str;
err_str << "Expected message size " << expected << " is larger than allowed limit " << _limit;
throw err_str.str ();
}
// Arbitrary buffer size. // Arbitrary buffer size.
char buffer[MAX_BUF]; char buffer[MAX_BUF];
@ -531,13 +532,18 @@ void TLSClient::recv (std::string& data)
// fits in the buffer. // fits in the buffer.
do do
{ {
int chunk_size = 0;
do do
{ {
received = gnutls_record_recv (_session, buffer, MAX_BUF - 1); // All received = gnutls_record_recv (_session, buffer + chunk_size, MAX_BUF - chunk_size); // All
if (received > 0) {
total += received;
chunk_size += received;
} }
while (received > 0 && }
(errno == GNUTLS_E_INTERRUPTED || while ((received > 0 && (unsigned long) total < expected && chunk_size < MAX_BUF) ||
errno == GNUTLS_E_AGAIN)); received == GNUTLS_E_INTERRUPTED ||
received == GNUTLS_E_AGAIN);
// Other end closed the connection. // Other end closed the connection.
if (received == 0) if (received == 0)
@ -548,17 +554,10 @@ void TLSClient::recv (std::string& data)
} }
// Something happened. // Something happened.
if (received < 0 && gnutls_error_is_fatal (received) == 0) // All if (received < 0)
{
if (_debug)
std::cout << "c: WARNING " << gnutls_strerror (received) << '\n'; // All
}
else if (received < 0)
throw std::string (gnutls_strerror (received)); // All throw std::string (gnutls_strerror (received)); // All
buffer [received] = '\0'; data.append (buffer, chunk_size);
data += buffer;
total += received;
// Stop at defined limit. // Stop at defined limit.
if (_limit && total > _limit) if (_limit && total > _limit)