From 8e3f299895a8bcd234eab350ddf2c1ae6e6b2b08 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 2 Sep 2019 19:00:44 -0700 Subject: [PATCH 1/2] Improve connection lock handling; always use context manager --- kafka/conn.py | 277 +++++++++++++++++++++++++++----------------------- 1 file changed, 151 insertions(+), 126 deletions(-) diff --git a/kafka/conn.py b/kafka/conn.py index 5ef141c65..6d0496be5 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -593,21 +593,30 @@ def _try_authenticate_plain(self, future): self.config['sasl_plain_username'], self.config['sasl_plain_password']]).encode('utf-8')) size = Int32.encode(len(msg)) - try: - with self._lock: - if not self._can_send_recv(): - return future.failure(Errors.NodeNotReadyError(str(self))) - self._send_bytes_blocking(size + msg) - # The server will send a zero sized message (that is Int32(0)) on success. - # The connection is closed on failure - data = self._recv_bytes_blocking(4) + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + try: + self._send_bytes_blocking(size + msg) + + # The server will send a zero sized message (that is Int32(0)) on success. + # The connection is closed on failure + data = self._recv_bytes_blocking(4) - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - error = Errors.KafkaConnectionError("%s: %s" % (self, e)) - self.close(error=error) - return future.failure(error) + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError("%s: %s" % (self, e)) + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) if data != b'\x00\x00\x00\x00': error = Errors.AuthenticationFailedError('Unrecognized response during authentication') @@ -625,61 +634,67 @@ def _try_authenticate_gssapi(self, future): ).canonicalize(gssapi.MechType.kerberos) log.debug('%s: GSSAPI name: %s', self, gssapi_name) - self._lock.acquire() - if not self._can_send_recv(): - return future.failure(Errors.NodeNotReadyError(str(self))) - # Establish security context and negotiate protection level - # For reference RFC 2222, section 7.2.1 - try: - # Exchange tokens until authentication either succeeds or fails - client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate') - received_token = None - while not client_ctx.complete: - # calculate an output token from kafka token (or None if first iteration) - output_token = client_ctx.step(received_token) - - # pass output token to kafka, or send empty response if the security - # context is complete (output token is None in that case) - if output_token is None: - self._send_bytes_blocking(Int32.encode(0)) - else: - msg = output_token + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + # Establish security context and negotiate protection level + # For reference RFC 2222, section 7.2.1 + try: + # Exchange tokens until authentication either succeeds or fails + client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate') + received_token = None + while not client_ctx.complete: + # calculate an output token from kafka token (or None if first iteration) + output_token = client_ctx.step(received_token) + + # pass output token to kafka, or send empty response if the security + # context is complete (output token is None in that case) + if output_token is None: + self._send_bytes_blocking(Int32.encode(0)) + else: + msg = output_token + size = Int32.encode(len(msg)) + self._send_bytes_blocking(size + msg) + + # The server will send a token back. Processing of this token either + # establishes a security context, or it needs further token exchange. + # The gssapi will be able to identify the needed next step. + # The connection is closed on failure. + header = self._recv_bytes_blocking(4) + (token_size,) = struct.unpack('>i', header) + received_token = self._recv_bytes_blocking(token_size) + + # Process the security layer negotiation token, sent by the server + # once the security context is established. + + # unwraps message containing supported protection levels and msg size + msg = client_ctx.unwrap(received_token).message + # Kafka currently doesn't support integrity or confidentiality security layers, so we + # simply set QoP to 'auth' only (first octet). We reuse the max message size proposed + # by the server + msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:] + # add authorization identity to the response, GSS-wrap and send it + msg = client_ctx.wrap(msg + auth_id.encode(), False).message size = Int32.encode(len(msg)) self._send_bytes_blocking(size + msg) - # The server will send a token back. Processing of this token either - # establishes a security context, or it needs further token exchange. - # The gssapi will be able to identify the needed next step. - # The connection is closed on failure. - header = self._recv_bytes_blocking(4) - (token_size,) = struct.unpack('>i', header) - received_token = self._recv_bytes_blocking(token_size) - - # Process the security layer negotiation token, sent by the server - # once the security context is established. - - # unwraps message containing supported protection levels and msg size - msg = client_ctx.unwrap(received_token).message - # Kafka currently doesn't support integrity or confidentiality security layers, so we - # simply set QoP to 'auth' only (first octet). We reuse the max message size proposed - # by the server - msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:] - # add authorization identity to the response, GSS-wrap and send it - msg = client_ctx.wrap(msg + auth_id.encode(), False).message - size = Int32.encode(len(msg)) - self._send_bytes_blocking(size + msg) + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError("%s: %s" % (self, e)) + close = True + except Exception as e: + err = e + close = True - except (ConnectionError, TimeoutError) as e: - self._lock.release() - log.exception("%s: Error receiving reply from server", self) - error = Errors.KafkaConnectionError("%s: %s" % (self, e)) - self.close(error=error) - return future.failure(error) - except Exception as e: - self._lock.release() - return future.failure(e) + if err is not None: + if close: + self.close(error=err) + return future.failure(err) - self._lock.release() log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name) return future.success(True) @@ -688,25 +703,31 @@ def _try_authenticate_oauth(self, future): msg = bytes(self._build_oauth_client_request().encode("utf-8")) size = Int32.encode(len(msg)) - self._lock.acquire() - if not self._can_send_recv(): - return future.failure(Errors.NodeNotReadyError(str(self))) - try: - # Send SASL OAuthBearer request with OAuth token - self._send_bytes_blocking(size + msg) - # The server will send a zero sized message (that is Int32(0)) on success. - # The connection is closed on failure - data = self._recv_bytes_blocking(4) + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + try: + # Send SASL OAuthBearer request with OAuth token + self._send_bytes_blocking(size + msg) - except (ConnectionError, TimeoutError) as e: - self._lock.release() - log.exception("%s: Error receiving reply from server", self) - error = Errors.KafkaConnectionError("%s: %s" % (self, e)) - self.close(error=error) - return future.failure(error) + # The server will send a zero sized message (that is Int32(0)) on success. + # The connection is closed on failure + data = self._recv_bytes_blocking(4) - self._lock.release() + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError("%s: %s" % (self, e)) + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) if data != b'\x00\x00\x00\x00': error = Errors.AuthenticationFailedError('Unrecognized response during authentication') @@ -857,6 +878,9 @@ def _send(self, request, blocking=True): future = Future() with self._lock: if not self._can_send_recv(): + # In this case, since we created the future above, + # we know there are no callbacks/errbacks that could fire w/ + # lock. So failing + returning inline should be safe return future.failure(Errors.NodeNotReadyError(str(self))) correlation_id = self._protocol.send_request(request) @@ -935,56 +959,57 @@ def recv(self): def _recv(self): """Take all available bytes from socket, return list of any responses from parser""" recvd = [] - self._lock.acquire() - if not self._can_send_recv(): - log.warning('%s cannot recv: socket not connected', self) - self._lock.release() - return () - - while len(recvd) < self.config['sock_chunk_buffer_count']: - try: - data = self._sock.recv(self.config['sock_chunk_bytes']) - # We expect socket.recv to raise an exception if there are no - # bytes available to read from the socket in non-blocking mode. - # but if the socket is disconnected, we will get empty data - # without an exception raised - if not data: - log.error('%s: socket disconnected', self) - self._lock.release() - self.close(error=Errors.KafkaConnectionError('socket disconnected')) - return [] - else: - recvd.append(data) + exc = None + with self._lock: + if not self._can_send_recv(): + log.warning('%s cannot recv: socket not connected', self) + return () - except SSLWantReadError: - break - except (ConnectionError, TimeoutError) as e: - if six.PY2 and e.errno == errno.EWOULDBLOCK: + while len(recvd) < self.config['sock_chunk_buffer_count']: + try: + data = self._sock.recv(self.config['sock_chunk_bytes']) + # We expect socket.recv to raise an exception if there are no + # bytes available to read from the socket in non-blocking mode. + # but if the socket is disconnected, we will get empty data + # without an exception raised + if not data: + log.error('%s: socket disconnected', self) + exc = Errors.KafkaConnectionError('socket disconnected') + break + else: + recvd.append(data) + + except SSLWantReadError: break - log.exception('%s: Error receiving network data' - ' closing socket', self) - self._lock.release() - self.close(error=Errors.KafkaConnectionError(e)) - return [] - except BlockingIOError: - if six.PY3: + except (ConnectionError, TimeoutError) as e: + if six.PY2 and e.errno == errno.EWOULDBLOCK: + break + log.exception('%s: Error receiving network data' + ' closing socket', self) + exc = Errors.KafkaConnectionError(e) break - self._lock.release() - raise - - recvd_data = b''.join(recvd) - if self._sensors: - self._sensors.bytes_received.record(len(recvd_data)) - - try: - responses = self._protocol.receive_bytes(recvd_data) - except Errors.KafkaProtocolError as e: - self._lock.release() - self.close(e) - return [] - else: - self._lock.release() - return responses + except BlockingIOError: + if six.PY3: + break + # For PY2 this is a catchall and should be re-raised + raise + + # Only process bytes if there was no connection exception + if exc is None: + recvd_data = b''.join(recvd) + if self._sensors: + self._sensors.bytes_received.record(len(recvd_data)) + + # We need to keep the lock through protocol receipt + # so that we ensure that the processed byte order is the + # same as the received byte order + try: + return self._protocol.receive_bytes(recvd_data) + except Errors.KafkaProtocolError as e: + exc = e + + self.close(error=exc) + return () def requests_timed_out(self): with self._lock: From 3d6c7d6d6dcbdd9ca9ae0ba553b1db0e8a505ff2 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 2 Sep 2019 22:06:21 -0700 Subject: [PATCH 2/2] fixup var name for consistency --- kafka/conn.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/kafka/conn.py b/kafka/conn.py index 6d0496be5..99466d90f 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -959,7 +959,7 @@ def recv(self): def _recv(self): """Take all available bytes from socket, return list of any responses from parser""" recvd = [] - exc = None + err = None with self._lock: if not self._can_send_recv(): log.warning('%s cannot recv: socket not connected', self) @@ -974,7 +974,7 @@ def _recv(self): # without an exception raised if not data: log.error('%s: socket disconnected', self) - exc = Errors.KafkaConnectionError('socket disconnected') + err = Errors.KafkaConnectionError('socket disconnected') break else: recvd.append(data) @@ -986,7 +986,7 @@ def _recv(self): break log.exception('%s: Error receiving network data' ' closing socket', self) - exc = Errors.KafkaConnectionError(e) + err = Errors.KafkaConnectionError(e) break except BlockingIOError: if six.PY3: @@ -995,7 +995,7 @@ def _recv(self): raise # Only process bytes if there was no connection exception - if exc is None: + if err is None: recvd_data = b''.join(recvd) if self._sensors: self._sensors.bytes_received.record(len(recvd_data)) @@ -1006,9 +1006,9 @@ def _recv(self): try: return self._protocol.receive_bytes(recvd_data) except Errors.KafkaProtocolError as e: - exc = e + err = e - self.close(error=exc) + self.close(error=err) return () def requests_timed_out(self):