diff --git a/kafka/sasl/msk.py b/kafka/sasl/msk.py index db56b4801..7ec03215d 100644 --- a/kafka/sasl/msk.py +++ b/kafka/sasl/msk.py @@ -4,6 +4,7 @@ import hashlib import hmac import json +import logging import string # needed for AWS_MSK_IAM authentication: @@ -13,10 +14,14 @@ # no botocore available, will disable AWS_MSK_IAM mechanism BotoSession = None +from kafka.errors import KafkaConfigurationError from kafka.sasl.abc import SaslMechanism from kafka.vendor.six.moves import urllib +log = logging.getLogger(__name__) + + class SaslMechanismAwsMskIam(SaslMechanism): def __init__(self, **config): assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package' @@ -27,22 +32,28 @@ def __init__(self, **config): self._is_done = False self._is_authenticated = False - def auth_bytes(self): + def _build_client(self): session = BotoSession() credentials = session.get_credentials().get_frozen_credentials() - client = AwsMskIamClient( + if not session.get_config_variable('region'): + raise KafkaConfigurationError('Unable to determine region for AWS MSK cluster. Is AWS_DEFAULT_REGION set?') + return AwsMskIamClient( host=self.host, access_key=credentials.access_key, secret_key=credentials.secret_key, region=session.get_config_variable('region'), token=credentials.token, ) + + def auth_bytes(self): + client = self._build_client() + log.debug("Generating auth token for MSK scope: %s", client._scope) return client.first_message() def receive(self, auth_bytes): self._is_done = True self._is_authenticated = auth_bytes != b'' - self._auth = auth_bytes.deode('utf-8') + self._auth = auth_bytes.decode('utf-8') def is_done(self): return self._is_done diff --git a/test/sasl/test_msk.py b/test/sasl/test_msk.py index e9f1325f3..f3cc46ce8 100644 --- a/test/sasl/test_msk.py +++ b/test/sasl/test_msk.py @@ -2,7 +2,7 @@ import json import sys -from kafka.sasl.msk import AwsMskIamClient +from kafka.sasl.msk import AwsMskIamClient, SaslMechanismAwsMskIam try: from unittest import mock @@ -69,3 +69,17 @@ def test_aws_msk_iam_client_temporary_credentials(): 'x-amz-security-token': 'XXXXX', } assert actual == expected + + +def test_aws_msk_iam_sasl_mechanism(): + with mock.patch('kafka.sasl.msk.BotoSession'): + sasl = SaslMechanismAwsMskIam(security_protocol='SASL_SSL', host='localhost') + with mock.patch.object(sasl, '_build_client', return_value=client_factory(token=None)): + assert sasl.auth_bytes() != b'' + assert not sasl.is_done() + assert not sasl.is_authenticated() + sasl.receive(b'foo') + assert sasl._auth == 'foo' + assert sasl.is_done() + assert sasl.is_authenticated() + assert sasl.auth_details()