Skip to content

Commit d36ec96

Browse files
committed
Additional BrokerConnection locks to synchronize protocol/IFR state
1 parent 227a946 commit d36ec96

File tree

1 file changed

+73
-50
lines changed

1 file changed

+73
-50
lines changed

kafka/conn.py

Lines changed: 73 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -586,11 +586,12 @@ def _try_authenticate_plain(self, future):
586586
self.config['sasl_plain_password']]).encode('utf-8'))
587587
size = Int32.encode(len(msg))
588588
try:
589-
self._send_bytes_blocking(size + msg)
589+
with self._lock:
590+
self._send_bytes_blocking(size + msg)
590591

591-
# The server will send a zero sized message (that is Int32(0)) on success.
592-
# The connection is closed on failure
593-
data = self._recv_bytes_blocking(4)
592+
# The server will send a zero sized message (that is Int32(0)) on success.
593+
# The connection is closed on failure
594+
data = self._recv_bytes_blocking(4)
594595

595596
except ConnectionError as e:
596597
log.exception("%s: Error receiving reply from server", self)
@@ -614,6 +615,7 @@ def _try_authenticate_gssapi(self, future):
614615
).canonicalize(gssapi.MechType.kerberos)
615616
log.debug('%s: GSSAPI name: %s', self, gssapi_name)
616617

618+
self._lock.acquire()
617619
# Establish security context and negotiate protection level
618620
# For reference RFC 2222, section 7.2.1
619621
try:
@@ -656,13 +658,16 @@ def _try_authenticate_gssapi(self, future):
656658
self._send_bytes_blocking(size + msg)
657659

658660
except ConnectionError as e:
661+
self._lock.release()
659662
log.exception("%s: Error receiving reply from server", self)
660663
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
661664
self.close(error=error)
662665
return future.failure(error)
663666
except Exception as e:
667+
self._lock.release()
664668
return future.failure(e)
665669

670+
self._lock.release()
666671
log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name)
667672
return future.success(True)
668673

@@ -671,6 +676,7 @@ def _try_authenticate_oauth(self, future):
671676

672677
msg = bytes(self._build_oauth_client_request().encode("utf-8"))
673678
size = Int32.encode(len(msg))
679+
self._lock.acquire()
674680
try:
675681
# Send SASL OAuthBearer request with OAuth token
676682
self._send_bytes_blocking(size + msg)
@@ -680,11 +686,14 @@ def _try_authenticate_oauth(self, future):
680686
data = self._recv_bytes_blocking(4)
681687

682688
except ConnectionError as e:
689+
self._lock.release()
683690
log.exception("%s: Error receiving reply from server", self)
684691
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
685692
self.close(error=error)
686693
return future.failure(error)
687694

695+
self._lock.release()
696+
688697
if data != b'\x00\x00\x00\x00':
689698
error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
690699
return future.failure(error)
@@ -784,26 +793,28 @@ def close(self, error=None):
784793
will be failed with this exception.
785794
Default: kafka.errors.KafkaConnectionError.
786795
"""
787-
if self.state is ConnectionStates.DISCONNECTED:
788-
if error is not None:
789-
log.warning('%s: Duplicate close() with error: %s', self, error)
790-
return
791-
log.info('%s: Closing connection. %s', self, error or '')
792-
self.state = ConnectionStates.DISCONNECTING
793-
self.config['state_change_callback'](self)
794-
self._update_reconnect_backoff()
795-
self._close_socket()
796-
self.state = ConnectionStates.DISCONNECTED
797-
self._sasl_auth_future = None
798-
self._protocol = KafkaProtocol(
799-
client_id=self.config['client_id'],
800-
api_version=self.config['api_version'])
801-
if error is None:
802-
error = Errors.Cancelled(str(self))
803-
while self.in_flight_requests:
804-
(_correlation_id, (future, _timestamp)) = self.in_flight_requests.popitem()
796+
with self._lock:
797+
if self.state is ConnectionStates.DISCONNECTED:
798+
return
799+
log.info('%s: Closing connection. %s', self, error or '')
800+
self.state = ConnectionStates.DISCONNECTING
801+
self.config['state_change_callback'](self)
802+
self._update_reconnect_backoff()
803+
self._close_socket()
804+
self.state = ConnectionStates.DISCONNECTED
805+
self._sasl_auth_future = None
806+
self._protocol = KafkaProtocol(
807+
client_id=self.config['client_id'],
808+
api_version=self.config['api_version'])
809+
if error is None:
810+
error = Errors.Cancelled(str(self))
811+
ifrs = list(self.in_flight_requests.items())
812+
self.in_flight_requests.clear()
813+
self.config['state_change_callback'](self)
814+
815+
# drop lock before processing futures
816+
for (_correlation_id, (future, _timestamp)) in ifrs:
805817
future.failure(error)
806-
self.config['state_change_callback'](self)
807818

808819
def send(self, request, blocking=True):
809820
"""Queue request for async network send, return Future()"""
@@ -817,18 +828,21 @@ def send(self, request, blocking=True):
817828
return self._send(request, blocking=blocking)
818829

819830
def _send(self, request, blocking=True):
820-
assert self.state in (ConnectionStates.AUTHENTICATING, ConnectionStates.CONNECTED)
821831
future = Future()
822832
with self._lock:
833+
if self.state not in (ConnectionStates.AUTHENTICATING,
834+
ConnectionStates.CONNECTED):
835+
return future.failure(Errors.NodeNotReady(str(self)))
836+
823837
correlation_id = self._protocol.send_request(request)
824838

825-
log.debug('%s Request %d: %s', self, correlation_id, request)
826-
if request.expect_response():
827-
sent_time = time.time()
828-
assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
829-
self.in_flight_requests[correlation_id] = (future, sent_time)
830-
else:
831-
future.success(None)
839+
log.debug('%s Request %d: %s', self, correlation_id, request)
840+
if request.expect_response():
841+
sent_time = time.time()
842+
assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
843+
self.in_flight_requests[correlation_id] = (future, sent_time)
844+
else:
845+
future.success(None)
832846

833847
# Attempt to replicate behavior from prior to introduction of
834848
# send_pending_requests() / async sends
@@ -839,16 +853,16 @@ def _send(self, request, blocking=True):
839853

840854
def send_pending_requests(self):
841855
"""Can block on network if request is larger than send_buffer_bytes"""
842-
if self.state not in (ConnectionStates.AUTHENTICATING,
843-
ConnectionStates.CONNECTED):
844-
return Errors.NodeNotReadyError(str(self))
845-
with self._lock:
846-
data = self._protocol.send_bytes()
847856
try:
848-
# In the future we might manage an internal write buffer
849-
# and send bytes asynchronously. For now, just block
850-
# sending each request payload
851-
total_bytes = self._send_bytes_blocking(data)
857+
with self._lock:
858+
if self.state not in (ConnectionStates.AUTHENTICATING,
859+
ConnectionStates.CONNECTED):
860+
return Errors.NodeNotReadyError(str(self))
861+
# In the future we might manage an internal write buffer
862+
# and send bytes asynchronously. For now, just block
863+
# sending each request payload
864+
data = self._protocol.send_bytes()
865+
total_bytes = self._send_bytes_blocking(data)
852866
if self._sensors:
853867
self._sensors.bytes_sent.record(total_bytes)
854868
return total_bytes
@@ -868,7 +882,8 @@ def recv(self):
868882
869883
Return list of (response, future) tuples
870884
"""
871-
if not self.connected() and not self.state is ConnectionStates.AUTHENTICATING:
885+
if self.state not in (ConnectionStates.AUTHENTICATING,
886+
ConnectionStates.CONNECTED):
872887
log.warning('%s cannot recv: socket not connected', self)
873888
# If requests are pending, we should close the socket and
874889
# fail all the pending request futures
@@ -892,7 +907,8 @@ def recv(self):
892907
# augment responses w/ correlation_id, future, and timestamp
893908
for i, (correlation_id, response) in enumerate(responses):
894909
try:
895-
(future, timestamp) = self.in_flight_requests.pop(correlation_id)
910+
with self._lock:
911+
(future, timestamp) = self.in_flight_requests.pop(correlation_id)
896912
except KeyError:
897913
self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
898914
return ()
@@ -908,6 +924,7 @@ def recv(self):
908924
def _recv(self):
909925
"""Take all available bytes from socket, return list of any responses from parser"""
910926
recvd = []
927+
self._lock.acquire()
911928
while len(recvd) < self.config['sock_chunk_buffer_count']:
912929
try:
913930
data = self._sock.recv(self.config['sock_chunk_bytes'])
@@ -917,6 +934,7 @@ def _recv(self):
917934
# without an exception raised
918935
if not data:
919936
log.error('%s: socket disconnected', self)
937+
self._lock.release()
920938
self.close(error=Errors.KafkaConnectionError('socket disconnected'))
921939
return []
922940
else:
@@ -929,11 +947,13 @@ def _recv(self):
929947
break
930948
log.exception('%s: Error receiving network data'
931949
' closing socket', self)
950+
self._lock.release()
932951
self.close(error=Errors.KafkaConnectionError(e))
933952
return []
934953
except BlockingIOError:
935954
if six.PY3:
936955
break
956+
self._lock.release()
937957
raise
938958

939959
recvd_data = b''.join(recvd)
@@ -943,20 +963,23 @@ def _recv(self):
943963
try:
944964
responses = self._protocol.receive_bytes(recvd_data)
945965
except Errors.KafkaProtocolError as e:
966+
self._lock.release()
946967
self.close(e)
947968
return []
948969
else:
970+
self._lock.release()
949971
return responses
950972

951973
def requests_timed_out(self):
952-
if self.in_flight_requests:
953-
get_timestamp = lambda v: v[1]
954-
oldest_at = min(map(get_timestamp,
955-
self.in_flight_requests.values()))
956-
timeout = self.config['request_timeout_ms'] / 1000.0
957-
if time.time() >= oldest_at + timeout:
958-
return True
959-
return False
974+
with self._lock:
975+
if self.in_flight_requests:
976+
get_timestamp = lambda v: v[1]
977+
oldest_at = min(map(get_timestamp,
978+
self.in_flight_requests.values()))
979+
timeout = self.config['request_timeout_ms'] / 1000.0
980+
if time.time() >= oldest_at + timeout:
981+
return True
982+
return False
960983

961984
def _handle_api_version_response(self, response):
962985
error_type = Errors.for_code(response.error_code)

0 commit comments

Comments
 (0)