@@ -586,11 +586,12 @@ def _try_authenticate_plain(self, future):
586
586
self .config ['sasl_plain_password' ]]).encode ('utf-8' ))
587
587
size = Int32 .encode (len (msg ))
588
588
try :
589
- self ._send_bytes_blocking (size + msg )
589
+ with self ._lock :
590
+ self ._send_bytes_blocking (size + msg )
590
591
591
- # The server will send a zero sized message (that is Int32(0)) on success.
592
- # The connection is closed on failure
593
- data = self ._recv_bytes_blocking (4 )
592
+ # The server will send a zero sized message (that is Int32(0)) on success.
593
+ # The connection is closed on failure
594
+ data = self ._recv_bytes_blocking (4 )
594
595
595
596
except ConnectionError as e :
596
597
log .exception ("%s: Error receiving reply from server" , self )
@@ -614,6 +615,7 @@ def _try_authenticate_gssapi(self, future):
614
615
).canonicalize (gssapi .MechType .kerberos )
615
616
log .debug ('%s: GSSAPI name: %s' , self , gssapi_name )
616
617
618
+ self ._lock .acquire ()
617
619
# Establish security context and negotiate protection level
618
620
# For reference RFC 2222, section 7.2.1
619
621
try :
@@ -656,13 +658,16 @@ def _try_authenticate_gssapi(self, future):
656
658
self ._send_bytes_blocking (size + msg )
657
659
658
660
except ConnectionError as e :
661
+ self ._lock .release ()
659
662
log .exception ("%s: Error receiving reply from server" , self )
660
663
error = Errors .KafkaConnectionError ("%s: %s" % (self , e ))
661
664
self .close (error = error )
662
665
return future .failure (error )
663
666
except Exception as e :
667
+ self ._lock .release ()
664
668
return future .failure (e )
665
669
670
+ self ._lock .release ()
666
671
log .info ('%s: Authenticated as %s via GSSAPI' , self , gssapi_name )
667
672
return future .success (True )
668
673
@@ -671,6 +676,7 @@ def _try_authenticate_oauth(self, future):
671
676
672
677
msg = bytes (self ._build_oauth_client_request ().encode ("utf-8" ))
673
678
size = Int32 .encode (len (msg ))
679
+ self ._lock .acquire ()
674
680
try :
675
681
# Send SASL OAuthBearer request with OAuth token
676
682
self ._send_bytes_blocking (size + msg )
@@ -680,11 +686,14 @@ def _try_authenticate_oauth(self, future):
680
686
data = self ._recv_bytes_blocking (4 )
681
687
682
688
except ConnectionError as e :
689
+ self ._lock .release ()
683
690
log .exception ("%s: Error receiving reply from server" , self )
684
691
error = Errors .KafkaConnectionError ("%s: %s" % (self , e ))
685
692
self .close (error = error )
686
693
return future .failure (error )
687
694
695
+ self ._lock .release ()
696
+
688
697
if data != b'\x00 \x00 \x00 \x00 ' :
689
698
error = Errors .AuthenticationFailedError ('Unrecognized response during authentication' )
690
699
return future .failure (error )
@@ -784,26 +793,28 @@ def close(self, error=None):
784
793
will be failed with this exception.
785
794
Default: kafka.errors.KafkaConnectionError.
786
795
"""
787
- if self .state is ConnectionStates .DISCONNECTED :
788
- if error is not None :
789
- log .warning ('%s: Duplicate close() with error: %s' , self , error )
790
- return
791
- log .info ('%s: Closing connection. %s' , self , error or '' )
792
- self .state = ConnectionStates .DISCONNECTING
793
- self .config ['state_change_callback' ](self )
794
- self ._update_reconnect_backoff ()
795
- self ._close_socket ()
796
- self .state = ConnectionStates .DISCONNECTED
797
- self ._sasl_auth_future = None
798
- self ._protocol = KafkaProtocol (
799
- client_id = self .config ['client_id' ],
800
- api_version = self .config ['api_version' ])
801
- if error is None :
802
- error = Errors .Cancelled (str (self ))
803
- while self .in_flight_requests :
804
- (_correlation_id , (future , _timestamp )) = self .in_flight_requests .popitem ()
796
+ with self ._lock :
797
+ if self .state is ConnectionStates .DISCONNECTED :
798
+ return
799
+ log .info ('%s: Closing connection. %s' , self , error or '' )
800
+ self .state = ConnectionStates .DISCONNECTING
801
+ self .config ['state_change_callback' ](self )
802
+ self ._update_reconnect_backoff ()
803
+ self ._close_socket ()
804
+ self .state = ConnectionStates .DISCONNECTED
805
+ self ._sasl_auth_future = None
806
+ self ._protocol = KafkaProtocol (
807
+ client_id = self .config ['client_id' ],
808
+ api_version = self .config ['api_version' ])
809
+ if error is None :
810
+ error = Errors .Cancelled (str (self ))
811
+ ifrs = list (self .in_flight_requests .items ())
812
+ self .in_flight_requests .clear ()
813
+ self .config ['state_change_callback' ](self )
814
+
815
+ # drop lock before processing futures
816
+ for (_correlation_id , (future , _timestamp )) in ifrs :
805
817
future .failure (error )
806
- self .config ['state_change_callback' ](self )
807
818
808
819
def send (self , request , blocking = True ):
809
820
"""Queue request for async network send, return Future()"""
@@ -817,18 +828,21 @@ def send(self, request, blocking=True):
817
828
return self ._send (request , blocking = blocking )
818
829
819
830
def _send (self , request , blocking = True ):
820
- assert self .state in (ConnectionStates .AUTHENTICATING , ConnectionStates .CONNECTED )
821
831
future = Future ()
822
832
with self ._lock :
833
+ if self .state not in (ConnectionStates .AUTHENTICATING ,
834
+ ConnectionStates .CONNECTED ):
835
+ return future .failure (Errors .NodeNotReady (str (self )))
836
+
823
837
correlation_id = self ._protocol .send_request (request )
824
838
825
- log .debug ('%s Request %d: %s' , self , correlation_id , request )
826
- if request .expect_response ():
827
- sent_time = time .time ()
828
- assert correlation_id not in self .in_flight_requests , 'Correlation ID already in-flight!'
829
- self .in_flight_requests [correlation_id ] = (future , sent_time )
830
- else :
831
- future .success (None )
839
+ log .debug ('%s Request %d: %s' , self , correlation_id , request )
840
+ if request .expect_response ():
841
+ sent_time = time .time ()
842
+ assert correlation_id not in self .in_flight_requests , 'Correlation ID already in-flight!'
843
+ self .in_flight_requests [correlation_id ] = (future , sent_time )
844
+ else :
845
+ future .success (None )
832
846
833
847
# Attempt to replicate behavior from prior to introduction of
834
848
# send_pending_requests() / async sends
@@ -839,16 +853,16 @@ def _send(self, request, blocking=True):
839
853
840
854
def send_pending_requests (self ):
841
855
"""Can block on network if request is larger than send_buffer_bytes"""
842
- if self .state not in (ConnectionStates .AUTHENTICATING ,
843
- ConnectionStates .CONNECTED ):
844
- return Errors .NodeNotReadyError (str (self ))
845
- with self ._lock :
846
- data = self ._protocol .send_bytes ()
847
856
try :
848
- # In the future we might manage an internal write buffer
849
- # and send bytes asynchronously. For now, just block
850
- # sending each request payload
851
- total_bytes = self ._send_bytes_blocking (data )
857
+ with self ._lock :
858
+ if self .state not in (ConnectionStates .AUTHENTICATING ,
859
+ ConnectionStates .CONNECTED ):
860
+ return Errors .NodeNotReadyError (str (self ))
861
+ # In the future we might manage an internal write buffer
862
+ # and send bytes asynchronously. For now, just block
863
+ # sending each request payload
864
+ data = self ._protocol .send_bytes ()
865
+ total_bytes = self ._send_bytes_blocking (data )
852
866
if self ._sensors :
853
867
self ._sensors .bytes_sent .record (total_bytes )
854
868
return total_bytes
@@ -868,7 +882,8 @@ def recv(self):
868
882
869
883
Return list of (response, future) tuples
870
884
"""
871
- if not self .connected () and not self .state is ConnectionStates .AUTHENTICATING :
885
+ if self .state not in (ConnectionStates .AUTHENTICATING ,
886
+ ConnectionStates .CONNECTED ):
872
887
log .warning ('%s cannot recv: socket not connected' , self )
873
888
# If requests are pending, we should close the socket and
874
889
# fail all the pending request futures
@@ -892,7 +907,8 @@ def recv(self):
892
907
# augment respones w/ correlation_id, future, and timestamp
893
908
for i , (correlation_id , response ) in enumerate (responses ):
894
909
try :
895
- (future , timestamp ) = self .in_flight_requests .pop (correlation_id )
910
+ with self ._lock :
911
+ (future , timestamp ) = self .in_flight_requests .pop (correlation_id )
896
912
except KeyError :
897
913
self .close (Errors .KafkaConnectionError ('Received unrecognized correlation id' ))
898
914
return ()
@@ -908,6 +924,7 @@ def recv(self):
908
924
def _recv (self ):
909
925
"""Take all available bytes from socket, return list of any responses from parser"""
910
926
recvd = []
927
+ self ._lock .acquire ()
911
928
while len (recvd ) < self .config ['sock_chunk_buffer_count' ]:
912
929
try :
913
930
data = self ._sock .recv (self .config ['sock_chunk_bytes' ])
@@ -917,6 +934,7 @@ def _recv(self):
917
934
# without an exception raised
918
935
if not data :
919
936
log .error ('%s: socket disconnected' , self )
937
+ self ._lock .release ()
920
938
self .close (error = Errors .KafkaConnectionError ('socket disconnected' ))
921
939
return []
922
940
else :
@@ -929,11 +947,13 @@ def _recv(self):
929
947
break
930
948
log .exception ('%s: Error receiving network data'
931
949
' closing socket' , self )
950
+ self ._lock .release ()
932
951
self .close (error = Errors .KafkaConnectionError (e ))
933
952
return []
934
953
except BlockingIOError :
935
954
if six .PY3 :
936
955
break
956
+ self ._lock .release ()
937
957
raise
938
958
939
959
recvd_data = b'' .join (recvd )
@@ -943,20 +963,23 @@ def _recv(self):
943
963
try :
944
964
responses = self ._protocol .receive_bytes (recvd_data )
945
965
except Errors .KafkaProtocolError as e :
966
+ self ._lock .release ()
946
967
self .close (e )
947
968
return []
948
969
else :
970
+ self ._lock .release ()
949
971
return responses
950
972
951
973
def requests_timed_out (self ):
952
- if self .in_flight_requests :
953
- get_timestamp = lambda v : v [1 ]
954
- oldest_at = min (map (get_timestamp ,
955
- self .in_flight_requests .values ()))
956
- timeout = self .config ['request_timeout_ms' ] / 1000.0
957
- if time .time () >= oldest_at + timeout :
958
- return True
959
- return False
974
+ with self ._lock :
975
+ if self .in_flight_requests :
976
+ get_timestamp = lambda v : v [1 ]
977
+ oldest_at = min (map (get_timestamp ,
978
+ self .in_flight_requests .values ()))
979
+ timeout = self .config ['request_timeout_ms' ] / 1000.0
980
+ if time .time () >= oldest_at + timeout :
981
+ return True
982
+ return False
960
983
961
984
def _handle_api_version_response (self , response ):
962
985
error_type = Errors .for_code (response .error_code )
0 commit comments