Skip to content

Commit d884fa3

Browse files
dpkp88manpreet
authored andcommitted
Check for disconnects during ssl handshake and sasl authentication (dpkp#1249)
1 parent c43c77e commit d884fa3

File tree

1 file changed

+42
-31
lines changed

1 file changed

+42
-31
lines changed

kafka/conn.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -307,12 +307,15 @@ def connect(self):
307307
self._sock.setsockopt(*option)
308308

309309
self._sock.setblocking(False)
310+
self.last_attempt = time.time()
311+
self.state = ConnectionStates.CONNECTING
310312
if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
311313
self._wrap_ssl()
312-
log.info('%s: connecting to %s:%d', self, self.host, self.port)
313-
self.state = ConnectionStates.CONNECTING
314-
self.last_attempt = time.time()
315-
self.config['state_change_callback'](self)
314+
# _wrap_ssl can alter the connection state -- disconnects on failure
315+
# so we need to double check that we are still connecting before
316+
if self.connecting():
317+
self.config['state_change_callback'](self)
318+
log.info('%s: connecting to %s:%d', self, self.host, self.port)
316319

317320
if self.state is ConnectionStates.CONNECTING:
318321
# in non-blocking mode, use repeated calls to socket.connect_ex
@@ -375,10 +378,12 @@ def connect(self):
375378
if self.state is ConnectionStates.AUTHENTICATING:
376379
assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')
377380
if self._try_authenticate():
378-
log.debug('%s: Connection complete.', self)
379-
self.state = ConnectionStates.CONNECTED
380-
self._reset_reconnect_backoff()
381-
self.config['state_change_callback'](self)
381+
# _try_authenticate has side-effects: possibly disconnected on socket errors
382+
if self.state is ConnectionStates.AUTHENTICATING:
383+
log.debug('%s: Connection complete.', self)
384+
self.state = ConnectionStates.CONNECTED
385+
self._reset_reconnect_backoff()
386+
self.config['state_change_callback'](self)
382387

383388
return self.state
384389

@@ -405,10 +410,7 @@ def _wrap_ssl(self):
405410
password=self.config['ssl_password'])
406411
if self.config['ssl_crlfile']:
407412
if not hasattr(ssl, 'VERIFY_CRL_CHECK_LEAF'):
408-
error = 'No CRL support with this version of Python.'
409-
log.error('%s: %s Disconnecting.', self, error)
410-
self.close(Errors.ConnectionError(error))
411-
return
413+
raise RuntimeError('This version of Python does not support ssl_crlfile!')
412414
log.info('%s: Loading SSL CRL from %s', self, self.config['ssl_crlfile'])
413415
self._ssl_context.load_verify_locations(self.config['ssl_crlfile'])
414416
# pylint: disable=no-member
@@ -451,7 +453,9 @@ def _try_authenticate(self):
451453
self._sasl_auth_future = future
452454
self._recv()
453455
if self._sasl_auth_future.failed():
454-
raise self._sasl_auth_future.exception # pylint: disable-msg=raising-bad-type
456+
ex = self._sasl_auth_future.exception
457+
if not isinstance(ex, Errors.ConnectionError):
458+
raise ex # pylint: disable-msg=raising-bad-type
455459
return self._sasl_auth_future.succeeded()
456460

457461
def _handle_sasl_handshake_response(self, future, response):
@@ -471,6 +475,19 @@ def _handle_sasl_handshake_response(self, future, response):
471475
'kafka-python does not support SASL mechanism %s' %
472476
self.config['sasl_mechanism']))
473477

478+
def _recv_bytes_blocking(self, n):
479+
self._sock.setblocking(True)
480+
try:
481+
data = b''
482+
while len(data) < n:
483+
fragment = self._sock.recv(n - len(data))
484+
if not fragment:
485+
raise ConnectionError('Connection reset during recv')
486+
data += fragment
487+
return data
488+
finally:
489+
self._sock.setblocking(False)
490+
474491
def _try_authenticate_plain(self, future):
475492
if self.config['security_protocol'] == 'SASL_PLAINTEXT':
476493
log.warning('%s: Sending username and password in the clear', self)
@@ -484,30 +501,23 @@ def _try_authenticate_plain(self, future):
484501
self.config['sasl_plain_password']]).encode('utf-8'))
485502
size = Int32.encode(len(msg))
486503
self._sock.sendall(size + msg)
504+
self._sock.setblocking(False)
487505

488506
# The server will send a zero sized message (that is Int32(0)) on success.
489507
# The connection is closed on failure
490-
while len(data) < 4:
491-
fragment = self._sock.recv(4 - len(data))
492-
if not fragment:
493-
log.error('%s: Authentication failed for user %s', self, self.config['sasl_plain_username'])
494-
error = Errors.AuthenticationFailedError(
495-
'Authentication failed for user {0}'.format(
496-
self.config['sasl_plain_username']))
497-
future.failure(error)
498-
raise error
499-
data += fragment
500-
self._sock.setblocking(False)
501-
except (AssertionError, ConnectionError) as e:
508+
self._recv_bytes_blocking(4)
509+
510+
except ConnectionError as e:
502511
log.exception("%s: Error receiving reply from server", self)
503512
error = Errors.ConnectionError("%s: %s" % (self, e))
504-
future.failure(error)
505513
self.close(error=error)
514+
return future.failure(error)
506515

507516
if data != b'\x00\x00\x00\x00':
508-
return future.failure(Errors.AuthenticationFailedError())
517+
error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
518+
return future.failure(error)
509519

510-
log.info('%s: Authenticated as %s', self, self.config['sasl_plain_username'])
520+
log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
511521
return future.success(True)
512522

513523
def _try_authenticate_gssapi(self, future):
@@ -532,14 +542,15 @@ def _try_authenticate_gssapi(self, future):
532542
msg = output_token
533543
size = Int32.encode(len(msg))
534544
self._sock.sendall(size + msg)
545+
self._sock.setblocking(False)
546+
535547
# The server will send a token back. Processing of this token either
536548
# establishes a security context, or it needs further token exchange.
537549
# The gssapi will be able to identify the needed next step.
538550
# The connection is closed on failure.
539-
header = self._sock.recv(4)
551+
header = self._recv_bytes_blocking(4)
540552
token_size = struct.unpack('>i', header)
541-
received_token = self._sock.recv(token_size)
542-
self._sock.setblocking(False)
553+
received_token = self._recv_bytes_blocking(token_size)
543554

544555
except ConnectionError as e:
545556
log.exception("%s: Error receiving reply from server", self)

0 commit comments

Comments
 (0)