From 9d6e0e4c3083b66e64dabc092fb1f1bc383e9c5e Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Fri, 20 Jun 2025 16:59:04 +0100 Subject: [PATCH] Fixing Glue + Confluent + Plain protobuf deser --- .../utilities/kafka/consumer_records.py | 27 +++- .../utilities/kafka/deserializer/avro.py | 15 ++- .../utilities/kafka/deserializer/default.py | 4 + .../kafka/deserializer/deserializer.py | 24 ++-- .../utilities/kafka/deserializer/json.py | 20 ++- .../utilities/kafka/deserializer/protobuf.py | 94 +++++++------- .../utilities/kafka/exceptions.py | 6 + .../kafka/serialization/custom_dict.py | 4 + .../kafka/serialization/dataclass.py | 5 + .../utilities/kafka/serialization/pydantic.py | 4 + .../_avro/test_kafka_consumer_with_avro.py | 34 +++++ .../_protobuf/schemas/__init__.py | 0 .../schemas/complex_schema_with_confuent.py | 53 ++++++++ .../schemas/complex_schema_with_glue.py | 54 ++++++++ .../test_kafka_consumer_with_protobuf.py | 121 +++++++++--------- .../kafka_consumer/_protobuf/user_prof.proto | 21 +++ .../kafka_consumer/_protobuf/user_prof_pb2.py | 35 +++++ .../test_kafka_consumer.py | 29 +++++ 18 files changed, 431 insertions(+), 119 deletions(-) create mode 100644 tests/functional/kafka_consumer/_protobuf/schemas/__init__.py create mode 100644 tests/functional/kafka_consumer/_protobuf/schemas/complex_schema_with_confuent.py create mode 100644 tests/functional/kafka_consumer/_protobuf/schemas/complex_schema_with_glue.py create mode 100644 tests/functional/kafka_consumer/_protobuf/user_prof.proto create mode 100644 tests/functional/kafka_consumer/_protobuf/user_prof_pb2.py diff --git a/aws_lambda_powertools/utilities/kafka/consumer_records.py b/aws_lambda_powertools/utilities/kafka/consumer_records.py index 47c732136d0..6da8f9fa1fa 100644 --- a/aws_lambda_powertools/utilities/kafka/consumer_records.py +++ b/aws_lambda_powertools/utilities/kafka/consumer_records.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from functools import cached_property from typing import TYPE_CHECKING, Any @@ -13,6 +14,8 @@ from aws_lambda_powertools.utilities.kafka.schema_config import SchemaConfig +logger = logging.getLogger(__name__) + class ConsumerRecordRecords(KafkaEventRecordBase): """ @@ -31,18 +34,24 @@ def key(self) -> Any: if not key: return None + logger.debug("Deserializing key field") + # Determine schema type and schema string schema_type = None - schema_str = None + schema_value = None output_serializer = None if self.schema_config and self.schema_config.key_schema_type: schema_type = self.schema_config.key_schema_type - schema_str = self.schema_config.key_schema + schema_value = self.schema_config.key_schema output_serializer = self.schema_config.key_output_serializer # Always use get_deserializer if None it will default to DEFAULT - deserializer = get_deserializer(schema_type, schema_str) + deserializer = get_deserializer( + schema_type=schema_type, + schema_value=schema_value, + field_metadata=self.key_schema_metadata, + ) deserialized_value = deserializer.deserialize(key) # Apply output serializer if specified @@ -57,16 +66,22 @@ def value(self) -> Any: # Determine schema type and schema string schema_type = None - schema_str = None + schema_value = None output_serializer = None + logger.debug("Deserializing value field") + if self.schema_config and self.schema_config.value_schema_type: schema_type = self.schema_config.value_schema_type - schema_str = self.schema_config.value_schema + schema_value = self.schema_config.value_schema output_serializer = 
self.schema_config.value_output_serializer # Always use get_deserializer if None it will default to DEFAULT - deserializer = get_deserializer(schema_type, schema_str) + deserializer = get_deserializer( + schema_type=schema_type, + schema_value=schema_value, + field_metadata=self.value_schema_metadata, + ) deserialized_value = deserializer.deserialize(value) # Apply output serializer if specified diff --git a/aws_lambda_powertools/utilities/kafka/deserializer/avro.py b/aws_lambda_powertools/utilities/kafka/deserializer/avro.py index 89073f9e784..d3b96da9d34 100644 --- a/aws_lambda_powertools/utilities/kafka/deserializer/avro.py +++ b/aws_lambda_powertools/utilities/kafka/deserializer/avro.py @@ -1,6 +1,8 @@ from __future__ import annotations import io +import logging +from typing import Any from avro.io import BinaryDecoder, DatumReader from avro.schema import parse as parse_schema @@ -9,8 +11,11 @@ from aws_lambda_powertools.utilities.kafka.exceptions import ( KafkaConsumerAvroSchemaParserError, KafkaConsumerDeserializationError, + KafkaConsumerDeserializationFormatMismatch, ) +logger = logging.getLogger(__name__) + class AvroDeserializer(DeserializerBase): """ @@ -20,10 +25,11 @@ class AvroDeserializer(DeserializerBase): a provided Avro schema definition. """ - def __init__(self, schema_str: str): + def __init__(self, schema_str: str, field_metadata: dict[str, Any] | None = None): try: self.parsed_schema = parse_schema(schema_str) self.reader = DatumReader(self.parsed_schema) + self.field_metatada = field_metadata except Exception as e: raise KafkaConsumerAvroSchemaParserError( f"Invalid Avro schema. Please ensure the provided avro schema is valid: {type(e).__name__}: {str(e)}", @@ -60,6 +66,13 @@ def deserialize(self, data: bytes | str) -> object: ... except KafkaConsumerDeserializationError as e: ... 
print(f"Failed to deserialize: {e}") """ + data_format = self.field_metatada.get("dataFormat") if self.field_metatada else None + + if data_format and data_format != "AVRO": + raise KafkaConsumerDeserializationFormatMismatch(f"Expected data is AVRO but you sent {data_format}") + + logger.debug("Deserializing data with AVRO format") + try: value = self._decode_input(data) bytes_reader = io.BytesIO(value) diff --git a/aws_lambda_powertools/utilities/kafka/deserializer/default.py b/aws_lambda_powertools/utilities/kafka/deserializer/default.py index b889e958c08..f5c73296d90 100644 --- a/aws_lambda_powertools/utilities/kafka/deserializer/default.py +++ b/aws_lambda_powertools/utilities/kafka/deserializer/default.py @@ -1,9 +1,12 @@ from __future__ import annotations import base64 +import logging from aws_lambda_powertools.utilities.kafka.deserializer.base import DeserializerBase +logger = logging.getLogger(__name__) + class DefaultDeserializer(DeserializerBase): """ @@ -43,4 +46,5 @@ def deserialize(self, data: bytes | str) -> str: >>> result = deserializer.deserialize(bytes_data) >>> print(result == bytes_data) # Output: True """ + logger.debug("Deserializing data with primitives types") return base64.b64decode(data).decode("utf-8") diff --git a/aws_lambda_powertools/utilities/kafka/deserializer/deserializer.py b/aws_lambda_powertools/utilities/kafka/deserializer/deserializer.py index 81c34be3aa5..c1443c83b00 100644 --- a/aws_lambda_powertools/utilities/kafka/deserializer/deserializer.py +++ b/aws_lambda_powertools/utilities/kafka/deserializer/deserializer.py @@ -13,21 +13,27 @@ _deserializer_cache: dict[str, DeserializerBase] = {} -def _get_cache_key(schema_type: str | object, schema_value: Any) -> str: +def _get_cache_key(schema_type: str | object, schema_value: Any, field_metadata: dict[str, Any]) -> str: + schema_metadata = None + + if field_metadata: + schema_metadata = field_metadata.get("schemaId") + if schema_value is None: - return str(schema_type) + schema_hash = f"{str(schema_type)}_{schema_metadata}" if isinstance(schema_value, str): + hashable_value = f"{schema_value}_{schema_metadata}" # For string schemas like Avro, hash the content - schema_hash = hashlib.md5(schema_value.encode("utf-8"), usedforsecurity=False).hexdigest() + schema_hash = hashlib.md5(hashable_value.encode("utf-8"), usedforsecurity=False).hexdigest() else: # For objects like Protobuf, use the object id - schema_hash = str(id(schema_value)) + schema_hash = f"{str(id(schema_value))}_{schema_metadata}" return f"{schema_type}_{schema_hash}" -def get_deserializer(schema_type: str | object, schema_value: Any) -> DeserializerBase: +def get_deserializer(schema_type: str | object, schema_value: Any, field_metadata: Any) -> DeserializerBase: """ Factory function to get the appropriate deserializer based on schema type. 
@@ -75,7 +81,7 @@ def get_deserializer(schema_type: str | object, schema_value: Any) -> Deserializ """ # Generate a cache key based on schema type and value - cache_key = _get_cache_key(schema_type, schema_value) + cache_key = _get_cache_key(schema_type, schema_value, field_metadata) # Check if we already have this deserializer in cache if cache_key in _deserializer_cache: @@ -87,14 +93,14 @@ def get_deserializer(schema_type: str | object, schema_value: Any) -> Deserializ # Import here to avoid dependency if not used from aws_lambda_powertools.utilities.kafka.deserializer.avro import AvroDeserializer - deserializer = AvroDeserializer(schema_value) + deserializer = AvroDeserializer(schema_str=schema_value, field_metadata=field_metadata) elif schema_type == "PROTOBUF": # Import here to avoid dependency if not used from aws_lambda_powertools.utilities.kafka.deserializer.protobuf import ProtobufDeserializer - deserializer = ProtobufDeserializer(schema_value) + deserializer = ProtobufDeserializer(message_class=schema_value, field_metadata=field_metadata) elif schema_type == "JSON": - deserializer = JsonDeserializer() + deserializer = JsonDeserializer(field_metadata=field_metadata) else: # Default to no-op deserializer diff --git a/aws_lambda_powertools/utilities/kafka/deserializer/json.py b/aws_lambda_powertools/utilities/kafka/deserializer/json.py index afd8effd489..baafd3bb288 100644 --- a/aws_lambda_powertools/utilities/kafka/deserializer/json.py +++ b/aws_lambda_powertools/utilities/kafka/deserializer/json.py @@ -2,9 +2,16 @@ import base64 import json +import logging +from typing import Any from aws_lambda_powertools.utilities.kafka.deserializer.base import DeserializerBase -from aws_lambda_powertools.utilities.kafka.exceptions import KafkaConsumerDeserializationError +from aws_lambda_powertools.utilities.kafka.exceptions import ( + KafkaConsumerDeserializationError, + KafkaConsumerDeserializationFormatMismatch, +) + +logger = logging.getLogger(__name__) class JsonDeserializer(DeserializerBase): @@ -15,6 +22,9 @@ class JsonDeserializer(DeserializerBase): into Python dictionaries. """ + def __init__(self, field_metadata: dict[str, Any] | None = None): + self.field_metatada = field_metadata + def deserialize(self, data: bytes | str) -> dict: """ Deserialize JSON data to a Python dictionary. @@ -45,6 +55,14 @@ def deserialize(self, data: bytes | str) -> dict: ... except KafkaConsumerDeserializationError as e: ... 
print(f"Failed to deserialize: {e}") """ + + data_format = self.field_metatada.get("dataFormat") if self.field_metatada else None + + if data_format and data_format != "JSON": + raise KafkaConsumerDeserializationFormatMismatch(f"Expected data is JSON but you sent {data_format}") + + logger.debug("Deserializing data with JSON format") + try: return json.loads(base64.b64decode(data).decode("utf-8")) except Exception as e: diff --git a/aws_lambda_powertools/utilities/kafka/deserializer/protobuf.py b/aws_lambda_powertools/utilities/kafka/deserializer/protobuf.py index a9209330505..16bb3bbc6ec 100644 --- a/aws_lambda_powertools/utilities/kafka/deserializer/protobuf.py +++ b/aws_lambda_powertools/utilities/kafka/deserializer/protobuf.py @@ -1,15 +1,19 @@ from __future__ import annotations +import logging from typing import Any -from google.protobuf.internal.decoder import _DecodeVarint # type: ignore[attr-defined] +from google.protobuf.internal.decoder import _DecodeSignedVarint # type: ignore[attr-defined] from google.protobuf.json_format import MessageToDict from aws_lambda_powertools.utilities.kafka.deserializer.base import DeserializerBase from aws_lambda_powertools.utilities.kafka.exceptions import ( KafkaConsumerDeserializationError, + KafkaConsumerDeserializationFormatMismatch, ) +logger = logging.getLogger(__name__) + class ProtobufDeserializer(DeserializerBase): """ @@ -19,8 +23,9 @@ class ProtobufDeserializer(DeserializerBase): into Python dictionaries using the provided Protocol Buffer message class. """ - def __init__(self, message_class: Any): + def __init__(self, message_class: Any, field_metadata: dict[str, Any] | None = None): self.message_class = message_class + self.field_metatada = field_metadata def deserialize(self, data: bytes | str) -> dict: """ @@ -61,57 +66,56 @@ def deserialize(self, data: bytes | str) -> dict: ... except KafkaConsumerDeserializationError as e: ... print(f"Failed to deserialize: {e}") """ - value = self._decode_input(data) - try: - message = self.message_class() - message.ParseFromString(value) - return MessageToDict(message, preserving_proto_field_name=True) - except Exception: - return self._deserialize_with_message_index(value, self.message_class()) - def _deserialize_with_message_index(self, data: bytes, parser: Any) -> dict: - """ - Deserialize protobuf message with Confluent message index handling. 
+ data_format = self.field_metatada.get("dataFormat") if self.field_metatada else None + schema_id = self.field_metatada.get("schemaId") if self.field_metatada else None - Parameters - ---------- - data : bytes - data - parser : google.protobuf.message.Message - Protobuf message instance to parse the data into + if data_format and data_format != "PROTOBUF": + raise KafkaConsumerDeserializationFormatMismatch(f"Expected data is PROTOBUF but you sent {data_format}") - Returns - ------- - dict - Dictionary representation of the parsed protobuf message with original field names + logger.debug("Deserializing data with PROTOBUF format") - Raises - ------ - KafkaConsumerDeserializationError - If deserialization fails + try: + value = self._decode_input(data) + message = self.message_class() + if schema_id is None: + logger.debug("Plain PROTOBUF data: using default deserializer") + # Plain protobuf - direct parser + message.ParseFromString(value) + elif len(schema_id) > 20: + logger.debug("PROTOBUF data integrated with Glue SchemaRegistry: using Glue deserializer") + # Glue schema registry integration - remove the first byte + message.ParseFromString(value[1:]) + else: + logger.debug("PROTOBUF data integrated with Confluent SchemaRegistry: using Confluent deserializer") + # Confluent schema registry integration - remove message index list + message.ParseFromString(self._remove_message_index(value)) - Notes - ----- - This method handles the special case of Confluent Schema Registry's message index - format, where the message is prefixed with either a single 0 (for the first schema) - or a list of schema indexes. The actual protobuf message follows these indexes. - """ + return MessageToDict(message, preserving_proto_field_name=True) + except Exception as e: + raise KafkaConsumerDeserializationError( + f"Error trying to deserialize protobuf data - {type(e).__name__}: {str(e)}", + ) from e + def _remove_message_index(self, data): + """ + Identifies and removes Confluent Schema Registry MessageIndex from bytes. + Returns pure protobuf bytes. + """ buffer = memoryview(data) pos = 0 - try: - first_value, new_pos = _DecodeVarint(buffer, pos) - pos = new_pos + logger.debug("Removing message list bytes") - if first_value != 0: - for _ in range(first_value): - _, new_pos = _DecodeVarint(buffer, pos) - pos = new_pos + # Read first varint (index count or 0) + first_value, new_pos = _DecodeSignedVarint(buffer, pos) + pos = new_pos - parser.ParseFromString(data[pos:]) - return MessageToDict(parser, preserving_proto_field_name=True) - except Exception as e: - raise KafkaConsumerDeserializationError( - f"Error trying to deserialize protobuf data - {type(e).__name__}: {str(e)}", - ) from e + # Skip index values if present + if first_value != 0: + for _ in range(first_value): + _, new_pos = _DecodeSignedVarint(buffer, pos) + pos = new_pos + + # Return remaining bytes (pure protobuf) + return data[pos:] diff --git a/aws_lambda_powertools/utilities/kafka/exceptions.py b/aws_lambda_powertools/utilities/kafka/exceptions.py index c8b5ee810a2..aa48efcaa64 100644 --- a/aws_lambda_powertools/utilities/kafka/exceptions.py +++ b/aws_lambda_powertools/utilities/kafka/exceptions.py @@ -4,6 +4,12 @@ class KafkaConsumerAvroSchemaParserError(Exception): """ +class KafkaConsumerDeserializationFormatMismatch(Exception): + """ + Error raised when deserialization format is incompatible + """ + + class KafkaConsumerDeserializationError(Exception): """ Error raised when message deserialization fails. 
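The rewritten deserialize() chooses the wire framing from the record's schemaId: no schemaId means a plain protobuf payload, a UUID-style schemaId longer than 20 characters means Glue Schema Registry framing (a single leading byte is stripped), and a short Confluent schemaId means a leading message-index list must be removed first. Below is a self-contained sketch of that last step, assuming only the protobuf runtime's internal varint helpers; strip_confluent_message_index is a hypothetical name that mirrors _remove_message_index above.

from google.protobuf.internal.decoder import _DecodeSignedVarint  # type: ignore[attr-defined]
from google.protobuf.internal.encoder import _VarintBytes  # type: ignore[attr-defined]

def strip_confluent_message_index(data: bytes) -> bytes:
    """Drop the Confluent varint message-index list and return pure protobuf bytes."""
    buffer = memoryview(data)
    count, pos = _DecodeSignedVarint(buffer, 0)
    for _ in range(count):  # no-op when the list is the single-0 shorthand
        _, pos = _DecodeSignedVarint(buffer, pos)
    return data[pos:]

payload = b"\x08\x2a"  # a minimal protobuf message: field 1, varint value 42

# A single 0 marks the first message type in the schema; only that byte is dropped
assert strip_confluent_message_index(_VarintBytes(0) + payload) == payload

# A non-zero count is followed by that many index varints, all of which are dropped
assert strip_confluent_message_index(_VarintBytes(2) + _VarintBytes(1) + _VarintBytes(0) + payload) == payload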
diff --git a/aws_lambda_powertools/utilities/kafka/serialization/custom_dict.py b/aws_lambda_powertools/utilities/kafka/serialization/custom_dict.py index b644e5f9b68..efa5b2efd28 100644 --- a/aws_lambda_powertools/utilities/kafka/serialization/custom_dict.py +++ b/aws_lambda_powertools/utilities/kafka/serialization/custom_dict.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, Any from aws_lambda_powertools.utilities.kafka.serialization.base import OutputSerializerBase @@ -9,6 +10,8 @@ from aws_lambda_powertools.utilities.kafka.serialization.types import T +logger = logging.getLogger(__name__) + class CustomDictOutputSerializer(OutputSerializerBase): """ @@ -19,4 +22,5 @@ class CustomDictOutputSerializer(OutputSerializerBase): """ def serialize(self, data: dict[str, Any], output: type[T] | Callable | None = None) -> T | dict[str, Any]: + logger.debug("Serializing output data with CustomDictOutputSerializer") return data if output is None else output(data) # type: ignore[call-arg] diff --git a/aws_lambda_powertools/utilities/kafka/serialization/dataclass.py b/aws_lambda_powertools/utilities/kafka/serialization/dataclass.py index 2cdbfe11be2..3f601fa4674 100644 --- a/aws_lambda_powertools/utilities/kafka/serialization/dataclass.py +++ b/aws_lambda_powertools/utilities/kafka/serialization/dataclass.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from dataclasses import is_dataclass from typing import TYPE_CHECKING, Any, cast @@ -9,6 +10,8 @@ if TYPE_CHECKING: from collections.abc import Callable +logger = logging.getLogger(__name__) + class DataclassOutputSerializer(OutputSerializerBase): """ @@ -22,4 +25,6 @@ def serialize(self, data: dict[str, Any], output: type[T] | Callable | None = No if not is_dataclass(output): # pragma: no cover raise ValueError("Output class must be a dataclass") + logger.debug("Serializing output data with DataclassOutputSerializer") + return cast(T, output(**data)) diff --git a/aws_lambda_powertools/utilities/kafka/serialization/pydantic.py b/aws_lambda_powertools/utilities/kafka/serialization/pydantic.py index 63484644ba3..3fa62393d4b 100644 --- a/aws_lambda_powertools/utilities/kafka/serialization/pydantic.py +++ b/aws_lambda_powertools/utilities/kafka/serialization/pydantic.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, Any from pydantic import TypeAdapter @@ -11,6 +12,8 @@ from aws_lambda_powertools.utilities.kafka.serialization.types import T +logger = logging.getLogger(__name__) + class PydanticOutputSerializer(OutputSerializerBase): """ @@ -21,6 +24,7 @@ class PydanticOutputSerializer(OutputSerializerBase): """ def serialize(self, data: dict[str, Any], output: type[T] | Callable | None = None) -> T | dict[str, Any]: + logger.debug("Serializing output data with PydanticOutputSerializer") # Use TypeAdapter for better support of Union types and other complex types adapter: TypeAdapter = TypeAdapter(output) return adapter.validate_python(data) diff --git a/tests/functional/kafka_consumer/_avro/test_kafka_consumer_with_avro.py b/tests/functional/kafka_consumer/_avro/test_kafka_consumer_with_avro.py index 9359558605c..f22171c37af 100644 --- a/tests/functional/kafka_consumer/_avro/test_kafka_consumer_with_avro.py +++ b/tests/functional/kafka_consumer/_avro/test_kafka_consumer_with_avro.py @@ -11,6 +11,7 @@ from aws_lambda_powertools.utilities.kafka.exceptions import ( KafkaConsumerAvroSchemaParserError, 
KafkaConsumerDeserializationError, + KafkaConsumerDeserializationFormatMismatch, KafkaConsumerMissingSchemaError, ) from aws_lambda_powertools.utilities.kafka.kafka_consumer import kafka_consumer @@ -309,3 +310,36 @@ def test_kafka_consumer_without_avro_key_schema(): # Verify the error message mentions 'key_schema' assert "key_schema" in str(excinfo.value) + + +def test_kafka_consumer_avro_with_wrong_json_schema( + kafka_event_with_avro_data, + lambda_context, + avro_value_schema, + avro_key_schema, +): + # GIVEN + # A Kafka event with a null key in the record + kafka_event_wrong_metadata = deepcopy(kafka_event_with_avro_data) + kafka_event_wrong_metadata["records"]["my-topic-1"][0]["valueSchemaMetadata"] = { + "dataFormat": "JSON", + "schemaId": "123", + } + + schema_config = SchemaConfig(value_schema_type="AVRO", value_schema=avro_value_schema) + + # A Kafka consumer with no schema configuration specified + @kafka_consumer(schema_config=schema_config) + def handler(event: ConsumerRecords, context): + # Get the first record's key which should be None + record = next(event.records) + return record.value + + # WHEN + # The handler processes the Kafka event with a null key + with pytest.raises(KafkaConsumerDeserializationFormatMismatch) as excinfo: + handler(kafka_event_wrong_metadata, lambda_context) + + # THEN + # Ensure the error contains useful diagnostic information + assert "Expected data is AVRO but you sent " in str(excinfo.value) diff --git a/tests/functional/kafka_consumer/_protobuf/schemas/__init__.py b/tests/functional/kafka_consumer/_protobuf/schemas/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/functional/kafka_consumer/_protobuf/schemas/complex_schema_with_confuent.py b/tests/functional/kafka_consumer/_protobuf/schemas/complex_schema_with_confuent.py new file mode 100644 index 00000000000..b2e14b715eb --- /dev/null +++ b/tests/functional/kafka_consumer/_protobuf/schemas/complex_schema_with_confuent.py @@ -0,0 +1,53 @@ +# ruff: noqa: E501 +complex_event = { + "eventSource": "aws:kafka", + "eventSourceArn": "arn:aws:kafka:us-east-1:0123456789019:cluster/SalesCluster/abcd1234-abcd-cafe-abab-9876543210ab-4", + "bootstrapServers": ",b-1.demo-cluster-1.a1bcde.c1.kafka.us-east-1.amazonaws.com:9092", + "records": { + "mytopic-0": [ + { + "topic": "mytopic", + "partition": 0, + "offset": 15, + "timestamp": 1545084650987, + "timestampType": "CREATE_TIME", + "key": "NDI=", + "value": "CgMxMjMSBFRlc3QaDHRlc3RAZ214LmNvbSAKMgoyMDI1LTA2LTIwOgR0YWcxOgR0YWcyQQAAAAAAAChASg4KBXRoZW1lEgVsaWdodFIaCgpNeXRoZW5xdWFpEgZadXJpY2gaBDgwMDI=", + "headers": [{"headerKey": [104, 101, 97, 100, 101, 114, 86, 97, 108, 117, 101]}], + }, + { + "topic": "mytopic", + "partition": 0, + "offset": 16, + "timestamp": 1545084650988, + "timestampType": "CREATE_TIME", + "key": "NDI=", + "value": "AAoDMTIzEgRUZXN0Ggx0ZXN0QGdteC5jb20gCjIKMjAyNS0wNi0yMDoEdGFnMToEdGFnMkEAAAAAAAAoQEoOCgV0aGVtZRIFbGlnaHRSGgoKTXl0aGVucXVhaRIGWnVyaWNoGgQ4MDAy", + "headers": [{"headerKey": [104, 101, 97, 100, 101, 114, 86, 97, 108, 117, 101]}], + "valueSchemaMetadata": {"schemaId": "123", "dataFormat": "PROTOBUF"}, + }, + { + "topic": "mytopic", + "partition": 0, + "offset": 17, + "timestamp": 1545084650989, + "timestampType": "CREATE_TIME", + "key": None, + "value": "BAIACgMxMjMSBFRlc3QaDHRlc3RAZ214LmNvbSAKMgoyMDI1LTA2LTIwOgR0YWcxOgR0YWcyQQAAAAAAAChASg4KBXRoZW1lEgVsaWdodFIaCgpNeXRoZW5xdWFpEgZadXJpY2gaBDgwMDI=", + "headers": [{"headerKey": [104, 101, 97, 100, 101, 114, 86, 97, 108, 117, 101]}], + 
"valueSchemaMetadata": {"schemaId": "456", "dataFormat": "PROTOBUF"}, + }, + { + "topic": "mytopic", + "partition": 0, + "offset": 18, + "timestamp": 1545084650990, + "timestampType": "CREATE_TIME", + "key": "NDI=", + "value": "AQoDMTIzEgRUZXN0Ggx0ZXN0QGdteC5jb20gCjIKMjAyNS0wNi0yMDoEdGFnMToEdGFnMkEAAAAAAAAoQEoOCgV0aGVtZRIFbGlnaHRSGgoKTXl0aGVucXVhaRIGWnVyaWNoGgQ4MDAy", + "headers": [{"headerKey": [104, 101, 97, 100, 101, 114, 86, 97, 108, 117, 101]}], + "valueSchemaMetadata": {"schemaId": "12345678-1234-1234-1234-123456789012", "dataFormat": "PROTOBUF"}, + }, + ], + }, +} diff --git a/tests/functional/kafka_consumer/_protobuf/schemas/complex_schema_with_glue.py b/tests/functional/kafka_consumer/_protobuf/schemas/complex_schema_with_glue.py new file mode 100644 index 00000000000..59cf1400b08 --- /dev/null +++ b/tests/functional/kafka_consumer/_protobuf/schemas/complex_schema_with_glue.py @@ -0,0 +1,54 @@ +# ruff: noqa: E501 +complex_event = { + "eventSource": "aws:kafka", + "eventSourceArn": "arn:aws:kafka:us-east-1:0123456789019:cluster/SalesCluster/abcd1234-abcd-cafe-abab-9876543210ab-4", + "bootstrapServers": ",b-1.demo-cluster-1.a1bcde.c1.kafka.us-east-1.amazonaws.com:9092", + "records": { + "gsr_proto-0": [ + { + "headers": [], + "key": "dTg1OQ==", + "offset": 4130352, + "partition": 0, + "timestamp": 1750284651283, + "timestampType": "CREATE_TIME", + "topic": "gsr_proto", + "value": "AQoEdTg1ORIFQWxpY2UaEWFsaWNlQGV4YW1wbGUuY29tIDYyCjIwMjQtMDEtMDE6GgoIMTIzIE1haW4SB1NlYXR0bGUaBTk4MTAxQgR0YWcxQgR0YWcySZZFopoJWkdAUg0KBXRoZW1lEgRkYXJr", + "valueSchemaMetadata": {"dataFormat": "PROTOBUF", "schemaId": "7d55d475-2244-4485-8341-f74468c1e058"}, + }, + { + "headers": [], + "key": "dTgwOQ==", + "offset": 4130353, + "partition": 0, + "timestamp": 1750284652283, + "timestampType": "CREATE_TIME", + "topic": "gsr_proto", + "value": "AQoEdTgwORIFQWxpY2UaEWFsaWNlQGV4YW1wbGUuY29tICgyCjIwMjQtMDEtMDE6GgoIMTIzIE1haW4SB1NlYXR0bGUaBTk4MTAxQgR0YWcxQgR0YWcySTnSqQSHn0FAUg0KBXRoZW1lEgRkYXJr", + "valueSchemaMetadata": {"dataFormat": "PROTOBUF", "schemaId": "7d55d475-2244-4485-8341-f74468c1e058"}, + }, + { + "headers": [], + "key": "dTQ1Mw==", + "offset": 4130354, + "partition": 0, + "timestamp": 1750284653283, + "timestampType": "CREATE_TIME", + "topic": "gsr_proto", + "value": "AQoEdTQ1MxIFQWxpY2UaEWFsaWNlQGV4YW1wbGUuY29tIEooATIKMjAyNC0wMS0wMToaCggxMjMgTWFpbhIHU2VhdHRsZRoFOTgxMDFCBHRhZzFCBHRhZzJJRJi47bmvV0BSDQoFdGhlbWUSBGRhcms=", + "valueSchemaMetadata": {"dataFormat": "PROTOBUF", "schemaId": "7d55d475-2244-4485-8341-f74468c1e058"}, + }, + { + "headers": [], + "key": "dTcwNQ==", + "offset": 4130355, + "partition": 0, + "timestamp": 1750284654283, + "timestampType": "CREATE_TIME", + "topic": "gsr_proto", + "value": "AQoEdTcwNRIFQWxpY2UaEWFsaWNlQGV4YW1wbGUuY29tIBMyCjIwMjQtMDEtMDE6GgoIMTIzIE1haW4SB1NlYXR0bGUaBTk4MTAxQgR0YWcxQgR0YWcySUSydyF28ldAUg0KBXRoZW1lEgRkYXJr", + "valueSchemaMetadata": {"dataFormat": "PROTOBUF", "schemaId": "7d55d475-2244-4485-8341-f74468c1e058"}, + }, + ], + }, +} diff --git a/tests/functional/kafka_consumer/_protobuf/test_kafka_consumer_with_protobuf.py b/tests/functional/kafka_consumer/_protobuf/test_kafka_consumer_with_protobuf.py index 2ab38dcc4f6..a3ce3c69a51 100644 --- a/tests/functional/kafka_consumer/_protobuf/test_kafka_consumer_with_protobuf.py +++ b/tests/functional/kafka_consumer/_protobuf/test_kafka_consumer_with_protobuf.py @@ -7,16 +7,15 @@ from aws_lambda_powertools.utilities.kafka.consumer_records import ConsumerRecords from aws_lambda_powertools.utilities.kafka.exceptions 
import ( KafkaConsumerDeserializationError, + KafkaConsumerDeserializationFormatMismatch, KafkaConsumerMissingSchemaError, ) from aws_lambda_powertools.utilities.kafka.kafka_consumer import kafka_consumer from aws_lambda_powertools.utilities.kafka.schema_config import SchemaConfig -# Import confluent complex schema -from .confluent_protobuf_pb2 import ProtobufProduct - # Import the generated protobuf classes from .user_pb2 import Key, User +from .user_prof_pb2 import UserProfile @pytest.fixture @@ -340,54 +339,16 @@ def test_kafka_consumer_without_protobuf_key_schema(): assert "PROTOBUF" in str(excinfo.value) -def test_confluent_complex_schema(lambda_context): +def test_confluent_schema_registry_complex_schema(lambda_context): # GIVEN # A scenario where a complex schema is used with the PROTOBUF schema type - complex_event = { - "eventSource": "aws:kafka", - "eventSourceArn": "arn:aws:kafka:us-east-1:0123456789019:cluster/SalesCluster/abcd1234", - "bootstrapServers": "b-2.demo-cluster-1.a1bcde.c1.kafka.us-east-1.amazonaws.com:9092", - "records": { - "mytopic-0": [ - { - "topic": "mytopic", - "partition": 0, - "offset": 15, - "timestamp": 1545084650987, - "timestampType": "CREATE_TIME", - "key": "NDI=", - "value": "COkHEgZMYXB0b3AZUrgehes/j0A=", - "headers": [{"headerKey": [104, 101, 97, 100, 101, 114, 86, 97, 108, 117, 101]}], - }, - { - "topic": "mytopic", - "partition": 0, - "offset": 16, - "timestamp": 1545084650988, - "timestampType": "CREATE_TIME", - "key": "NDI=", - "value": "AAjpBxIGTGFwdG9wGVK4HoXrP49A", - "headers": [{"headerKey": [104, 101, 97, 100, 101, 114, 86, 97, 108, 117, 101]}], - }, - { - "topic": "mytopic", - "partition": 0, - "offset": 17, - "timestamp": 1545084650989, - "timestampType": "CREATE_TIME", - "key": "NDI=", - "value": "AgEACOkHEgZMYXB0b3AZUrgehes/j0A=", - "headers": [{"headerKey": [104, 101, 97, 100, 101, 114, 86, 97, 108, 117, 101]}], - }, - ], - }, - } + from tests.functional.kafka_consumer._protobuf.schemas.complex_schema_with_confuent import complex_event # GIVEN A Kafka consumer configured to deserialize Protobuf data # using the User protobuf message type as the schema schema_config = SchemaConfig( value_schema_type="PROTOBUF", - value_schema=ProtobufProduct, + value_schema=UserProfile, ) processed_records = [] @@ -396,7 +357,7 @@ def test_confluent_complex_schema(lambda_context): def handler(event: ConsumerRecords, context): for record in event.records: processed_records.append( - {"id": record.value["id"], "name": record.value["name"], "price": record.value["price"]}, + {"email": record.value["email"], "age": record.value["age"]}, ) return {"processed": len(processed_records)} @@ -406,19 +367,65 @@ def handler(event: ConsumerRecords, context): # THEN # The handler should successfully process both records # and return the correct count - assert result == {"processed": 3} + assert result == {"processed": 4} + assert len(processed_records) == 4 - # All records should be correctly deserialized with proper values - assert len(processed_records) == 3 - # First record should contain decoded values - assert processed_records[0]["id"] == 1001 - assert processed_records[0]["name"] == "Laptop" +def test_glue_schema_registry_complex_schema(lambda_context): + # GIVEN + # A scenario where a complex schema is used with the PROTOBUF schema type + from tests.functional.kafka_consumer._protobuf.schemas.complex_schema_with_glue import complex_event - # Second record should contain decoded values - assert processed_records[1]["id"] == 1001 - assert 
processed_records[1]["name"] == "Laptop" + # GIVEN A Kafka consumer configured to deserialize Protobuf data + # using the User protobuf message type as the schema + schema_config = SchemaConfig( + value_schema_type="PROTOBUF", + value_schema=UserProfile, + ) + + processed_records = [] + + @kafka_consumer(schema_config=schema_config) + def handler(event: ConsumerRecords, context): + for record in event.records: + processed_records.append( + {"email": record.value["email"], "age": record.value["age"]}, + ) + return {"processed": len(processed_records)} + + # WHEN The handler processes a Kafka event containing Protobuf-encoded data + result = handler(complex_event, lambda_context) + + # THEN + # The handler should successfully process both records + # and return the correct count + assert result == {"processed": 4} + assert len(processed_records) == 4 + + +def test_kafka_consumer_protobuf_with_wrong_avro_schema(kafka_event_with_proto_data, lambda_context): + # GIVEN + # A Kafka event with a null key in the record + kafka_event_wrong_metadata = deepcopy(kafka_event_with_proto_data) + kafka_event_wrong_metadata["records"]["my-topic-1"][0]["valueSchemaMetadata"] = { + "dataFormat": "AVRO", + "schemaId": "1234", + } - # Third record should contain decoded values - assert processed_records[2]["id"] == 1001 - assert processed_records[2]["name"] == "Laptop" + schema_config = SchemaConfig(value_schema_type="PROTOBUF", value_schema=UserProfile) + + # A Kafka consumer with no schema configuration specified + @kafka_consumer(schema_config=schema_config) + def handler(event: ConsumerRecords, context): + # Get the first record's key which should be None + record = next(event.records) + return record.value + + # WHEN + # The handler processes the Kafka event with a null key + with pytest.raises(KafkaConsumerDeserializationFormatMismatch) as excinfo: + handler(kafka_event_wrong_metadata, lambda_context) + + # THEN + # Ensure the error contains useful diagnostic information + assert "Expected data is PROTOBUF but you sent " in str(excinfo.value) diff --git a/tests/functional/kafka_consumer/_protobuf/user_prof.proto b/tests/functional/kafka_consumer/_protobuf/user_prof.proto new file mode 100644 index 00000000000..a8162b1e293 --- /dev/null +++ b/tests/functional/kafka_consumer/_protobuf/user_prof.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; +package com.example.protobuf; + +message Address { + string street = 1; + string city = 2; + string zip = 3; +} + +message UserProfile { + string userId = 1; + string name = 2; + string email = 3; + int32 age = 4; + bool isActive = 5; + string signupDate = 6; + repeated string tags = 7; + double score = 8; + map preferences = 9; + Address address = 10; +} diff --git a/tests/functional/kafka_consumer/_protobuf/user_prof_pb2.py b/tests/functional/kafka_consumer/_protobuf/user_prof_pb2.py new file mode 100644 index 00000000000..af4062ad630 --- /dev/null +++ b/tests/functional/kafka_consumer/_protobuf/user_prof_pb2.py @@ -0,0 +1,35 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE +# Protobuf Python Version: 6.30.2 +"""Generated protocol buffer code.""" + +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +_runtime_version.ValidateProtobufRuntimeVersion(_runtime_version.Domain.PUBLIC, 6, 30, 2, "", "user_prof.proto") +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0fuser_prof.proto\x12\x14\x63om.example.protobuf"4\n\x07\x41\x64\x64ress\x12\x0e\n\x06street\x18\x01 \x01(\t\x12\x0c\n\x04\x63ity\x18\x02 \x01(\t\x12\x0b\n\x03zip\x18\x03 \x01(\t"\xb7\x02\n\x0bUserProfile\x12\x0e\n\x06userId\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05\x65mail\x18\x03 \x01(\t\x12\x0b\n\x03\x61ge\x18\x04 \x01(\x05\x12\x10\n\x08isActive\x18\x05 \x01(\x08\x12\x12\n\nsignupDate\x18\x06 \x01(\t\x12\x0c\n\x04tags\x18\x07 \x03(\t\x12\r\n\x05score\x18\x08 \x01(\x01\x12G\n\x0bpreferences\x18\t \x03(\x0b\x32\x32.com.example.protobuf.UserProfile.PreferencesEntry\x12.\n\x07\x61\x64\x64ress\x18\n \x01(\x0b\x32\x1d.com.example.protobuf.Address\x1a\x32\n\x10PreferencesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x62\x06proto3', # noqa: E501 +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "user_prof_pb2", _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals["_USERPROFILE_PREFERENCESENTRY"]._loaded_options = None + _globals["_USERPROFILE_PREFERENCESENTRY"]._serialized_options = b"8\001" + _globals["_ADDRESS"]._serialized_start = 41 + _globals["_ADDRESS"]._serialized_end = 93 + _globals["_USERPROFILE"]._serialized_start = 96 + _globals["_USERPROFILE"]._serialized_end = 407 + _globals["_USERPROFILE_PREFERENCESENTRY"]._serialized_start = 357 + _globals["_USERPROFILE_PREFERENCESENTRY"]._serialized_end = 407 +# @@protoc_insertion_point(module_scope) diff --git a/tests/functional/kafka_consumer/required_dependencies/test_kafka_consumer.py b/tests/functional/kafka_consumer/required_dependencies/test_kafka_consumer.py index a5240eb4d12..657ac2cc46c 100644 --- a/tests/functional/kafka_consumer/required_dependencies/test_kafka_consumer.py +++ b/tests/functional/kafka_consumer/required_dependencies/test_kafka_consumer.py @@ -8,6 +8,7 @@ from aws_lambda_powertools.utilities.kafka.consumer_records import ConsumerRecords from aws_lambda_powertools.utilities.kafka.exceptions import ( KafkaConsumerDeserializationError, + KafkaConsumerDeserializationFormatMismatch, ) from aws_lambda_powertools.utilities.kafka.kafka_consumer import kafka_consumer from aws_lambda_powertools.utilities.kafka.schema_config import SchemaConfig @@ -295,6 +296,34 @@ def handler(event: ConsumerRecords, context): assert result is None +def test_kafka_consumer_json_with_wrong_avro_schema(kafka_event_with_json_data, lambda_context): + # GIVEN + # A Kafka event with a null key in the record + kafka_event_wrong_metadata = deepcopy(kafka_event_with_json_data) + kafka_event_wrong_metadata["records"]["my-topic-1"][0]["valueSchemaMetadata"] = { + "dataFormat": "AVRO", + "schemaId": "1234532323", + } + + schema_config = SchemaConfig(value_schema_type="JSON") + + # A Kafka 
consumer configured to deserialize record values as JSON + @kafka_consumer(schema_config=schema_config) + def handler(event: ConsumerRecords, context): + # Accessing the value triggers deserialization of the first record + record = next(event.records) + return record.value + + # WHEN + # The handler processes a record whose valueSchemaMetadata reports AVRO instead of JSON + with pytest.raises(KafkaConsumerDeserializationFormatMismatch) as excinfo: + handler(kafka_event_wrong_metadata, lambda_context) + + # THEN + # Ensure the error contains useful diagnostic information + assert "Expected data is JSON but you sent " in str(excinfo.value) + + def test_kafka_consumer_metadata_fields(kafka_event_with_json_data, lambda_context): # GIVEN # A Kafka event with specific metadata we want to verify is preserved
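With the format check in place, a handler that reads mixed or misconfigured topics can trap the new exception per record instead of failing the whole invocation. A usage sketch built only from the public pieces exercised in these tests (SchemaConfig, kafka_consumer, and the new exception); the error-handling strategy is illustrative:

from aws_lambda_powertools.utilities.kafka.consumer_records import ConsumerRecords
from aws_lambda_powertools.utilities.kafka.exceptions import KafkaConsumerDeserializationFormatMismatch
from aws_lambda_powertools.utilities.kafka.kafka_consumer import kafka_consumer
from aws_lambda_powertools.utilities.kafka.schema_config import SchemaConfig

schema_config = SchemaConfig(value_schema_type="JSON")

@kafka_consumer(schema_config=schema_config)
def handler(event: ConsumerRecords, context):
    processed = 0
    for record in event.records:
        try:
            _ = record.value  # deserialization happens lazily on access
        except KafkaConsumerDeserializationFormatMismatch:
            # valueSchemaMetadata.dataFormat disagrees with the configured JSON type;
            # skip the record (or route it elsewhere) instead of failing the batch
            continue
        processed += 1
    return {"processed": processed}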