diff --git a/datadog_lambda/dsm.py b/datadog_lambda/dsm.py deleted file mode 100644 index 427f5e47..00000000 --- a/datadog_lambda/dsm.py +++ /dev/null @@ -1,38 +0,0 @@ -from datadog_lambda import logger -from datadog_lambda.trigger import EventTypes - - -def set_dsm_context(event, event_source): - - if event_source.equals(EventTypes.SQS): - _dsm_set_sqs_context(event) - - -def _dsm_set_sqs_context(event): - from datadog_lambda.wrapper import format_err_with_traceback - from ddtrace.internal.datastreams import data_streams_processor - from ddtrace.internal.datastreams.processor import DsmPathwayCodec - from ddtrace.internal.datastreams.botocore import ( - get_datastreams_context, - calculate_sqs_payload_size, - ) - - records = event.get("Records") - if records is None: - return - processor = data_streams_processor() - - for record in records: - try: - queue_arn = record.get("eventSourceARN", "") - - contextjson = get_datastreams_context(record) - payload_size = calculate_sqs_payload_size(record) - - ctx = DsmPathwayCodec.decode(contextjson, processor) - ctx.set_checkpoint( - ["direction:in", f"topic:{queue_arn}", "type:sqs"], - payload_size=payload_size, - ) - except Exception as e: - logger.error(format_err_with_traceback(e)) diff --git a/datadog_lambda/tracing.py b/datadog_lambda/tracing.py index 89a4126b..f4057480 100644 --- a/datadog_lambda/tracing.py +++ b/datadog_lambda/tracing.py @@ -67,6 +67,24 @@ LOWER_64_BITS = "LOWER_64_BITS" +def _dsm_set_checkpoint(context_json, event_type, arn): + if not config.data_streams_enabled: + return + + if not arn: + return + + try: + from ddtrace.data_streams import set_consume_checkpoint + + carrier_get = lambda k: context_json and context_json.get(k) # noqa: E731 + set_consume_checkpoint(event_type, arn, carrier_get, manual_checkpoint=False) + except Exception as e: + logger.debug( + f"DSM:Failed to set consume checkpoint for {event_type} {arn}: {e}" + ) + + def _convert_xray_trace_id(xray_trace_id): """ Convert X-Ray trace id (hex)'s last 63 bits to a Datadog trace id (int). @@ -202,7 +220,9 @@ def create_sns_event(message): } -def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context): +def extract_context_from_sqs_or_sns_event_or_context( + event, lambda_context, event_source +): """ Extract Datadog trace context from an SQS event. @@ -214,7 +234,10 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context): Lambda Context. Falls back to lambda context if no trace data is found in the SQS message attributes. + Set a DSM checkpoint if DSM is enabled and the method for context propagation is supported. """ + source_arn = "" + event_type = "sqs" if event_source.equals(EventTypes.SQS) else "sns" # EventBridge => SQS try: @@ -226,6 +249,7 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context): try: first_record = event.get("Records")[0] + source_arn = first_record.get("eventSourceARN", "") # logic to deal with SNS => SQS event if "body" in first_record: @@ -241,6 +265,9 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context): msg_attributes = first_record.get("messageAttributes") if msg_attributes is None: sns_record = first_record.get("Sns") or {} + # SNS->SQS event would extract SNS arn without this check + if event_source.equals(EventTypes.SNS): + source_arn = sns_record.get("TopicArn", "") msg_attributes = sns_record.get("MessageAttributes") or {} dd_payload = msg_attributes.get("_datadog") if dd_payload: @@ -272,8 +299,9 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context): logger.debug( "Failed to extract Step Functions context from SQS/SNS event." ) - - return propagator.extract(dd_data) + context = propagator.extract(dd_data) + _dsm_set_checkpoint(dd_data, event_type, source_arn) + return context else: # Handle case where trace context is injected into attributes.AWSTraceHeader # example: Root=1-654321ab-000000001234567890abcdef;Parent=0123456789abcdef;Sampled=1 @@ -296,9 +324,13 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context): span_id=int(x_ray_context["parent_id"], 16), sampling_priority=float(x_ray_context["sampled"]), ) + # Still want to set a DSM checkpoint even if DSM context not propagated + _dsm_set_checkpoint(None, event_type, source_arn) return extract_context_from_lambda_context(lambda_context) except Exception as e: logger.debug("The trace extractor returned with error %s", e) + # Still want to set a DSM checkpoint even if DSM context not propagated + _dsm_set_checkpoint(None, event_type, source_arn) return extract_context_from_lambda_context(lambda_context) @@ -357,9 +389,12 @@ def extract_context_from_eventbridge_event(event, lambda_context): def extract_context_from_kinesis_event(event, lambda_context): """ Extract datadog trace context from a Kinesis Stream's base64 encoded data string + Set a DSM checkpoint if DSM is enabled and the method for context propagation is supported. """ + source_arn = "" try: record = get_first_record(event) + source_arn = record.get("eventSourceARN", "") kinesis = record.get("kinesis") if not kinesis: return extract_context_from_lambda_context(lambda_context) @@ -373,10 +408,13 @@ def extract_context_from_kinesis_event(event, lambda_context): data_obj = json.loads(data_str) dd_ctx = data_obj.get("_datadog") if dd_ctx: - return propagator.extract(dd_ctx) + context = propagator.extract(dd_ctx) + _dsm_set_checkpoint(dd_ctx, "kinesis", source_arn) + return context except Exception as e: logger.debug("The trace extractor returned with error %s", e) - + # Still want to set a DSM checkpoint even if DSM context not propagated + _dsm_set_checkpoint(None, "kinesis", source_arn) return extract_context_from_lambda_context(lambda_context) @@ -594,7 +632,7 @@ def extract_dd_trace_context( ) elif event_source.equals(EventTypes.SNS) or event_source.equals(EventTypes.SQS): context = extract_context_from_sqs_or_sns_event_or_context( - event, lambda_context + event, lambda_context, event_source ) elif event_source.equals(EventTypes.EVENTBRIDGE): context = extract_context_from_eventbridge_event(event, lambda_context) diff --git a/datadog_lambda/wrapper.py b/datadog_lambda/wrapper.py index c7474f65..305ed4b3 100644 --- a/datadog_lambda/wrapper.py +++ b/datadog_lambda/wrapper.py @@ -10,7 +10,6 @@ from time import time_ns from datadog_lambda.asm import asm_start_response, asm_start_request -from datadog_lambda.dsm import set_dsm_context from datadog_lambda.extension import should_use_extension, flush_extension from datadog_lambda.cold_start import ( set_cold_start, @@ -237,8 +236,6 @@ def _before(self, event, context): self.inferred_span = create_inferred_span( event, context, event_source, config.decode_authorizer_context ) - if config.data_streams_enabled: - set_dsm_context(event, event_source) self.span = create_function_execution_span( context=context, function_name=config.function_name, diff --git a/tests/test_dsm.py b/tests/test_dsm.py deleted file mode 100644 index 544212d8..00000000 --- a/tests/test_dsm.py +++ /dev/null @@ -1,112 +0,0 @@ -import unittest -from unittest.mock import patch, MagicMock - -from datadog_lambda.dsm import set_dsm_context, _dsm_set_sqs_context -from datadog_lambda.trigger import EventTypes, _EventSource - - -class TestDsmSQSContext(unittest.TestCase): - def setUp(self): - patcher = patch("datadog_lambda.dsm._dsm_set_sqs_context") - self.mock_dsm_set_sqs_context = patcher.start() - self.addCleanup(patcher.stop) - - patcher = patch("ddtrace.internal.datastreams.data_streams_processor") - self.mock_data_streams_processor = patcher.start() - self.addCleanup(patcher.stop) - - patcher = patch("ddtrace.internal.datastreams.botocore.get_datastreams_context") - self.mock_get_datastreams_context = patcher.start() - self.mock_get_datastreams_context.return_value = {} - self.addCleanup(patcher.stop) - - patcher = patch( - "ddtrace.internal.datastreams.botocore.calculate_sqs_payload_size" - ) - self.mock_calculate_sqs_payload_size = patcher.start() - self.mock_calculate_sqs_payload_size.return_value = 100 - self.addCleanup(patcher.stop) - - patcher = patch("ddtrace.internal.datastreams.processor.DsmPathwayCodec.decode") - self.mock_dsm_pathway_codec_decode = patcher.start() - self.addCleanup(patcher.stop) - - def test_non_sqs_event_source_does_nothing(self): - """Test that non-SQS event sources don't trigger DSM context setting""" - event = {} - # Use Unknown Event Source - event_source = _EventSource(EventTypes.UNKNOWN) - set_dsm_context(event, event_source) - - # DSM context should not be set for non-SQS events - self.mock_dsm_set_sqs_context.assert_not_called() - - def test_sqs_event_with_no_records_does_nothing(self): - """Test that events where Records is None don't trigger DSM processing""" - events_with_no_records = [ - {}, - {"Records": None}, - {"someOtherField": "value"}, - ] - - for event in events_with_no_records: - _dsm_set_sqs_context(event) - self.mock_data_streams_processor.assert_not_called() - - def test_sqs_event_triggers_dsm_sqs_context(self): - """Test that SQS event sources trigger the SQS-specific DSM context function""" - sqs_event = { - "Records": [ - { - "eventSource": "aws:sqs", - "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:my-queue", - "body": "Hello from SQS!", - } - ] - } - - event_source = _EventSource(EventTypes.SQS) - set_dsm_context(sqs_event, event_source) - - self.mock_dsm_set_sqs_context.assert_called_once_with(sqs_event) - - def test_sqs_multiple_records_process_each_record(self): - """Test that each record in an SQS event gets processed individually""" - multi_record_event = { - "Records": [ - { - "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:queue1", - "body": "Message 1", - }, - { - "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:queue2", - "body": "Message 2", - }, - { - "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:queue3", - "body": "Message 3", - }, - ] - } - - mock_context = MagicMock() - self.mock_dsm_pathway_codec_decode.return_value = mock_context - - _dsm_set_sqs_context(multi_record_event) - - self.assertEqual(mock_context.set_checkpoint.call_count, 3) - - calls = mock_context.set_checkpoint.call_args_list - expected_arns = [ - "arn:aws:sqs:us-east-1:123456789012:queue1", - "arn:aws:sqs:us-east-1:123456789012:queue2", - "arn:aws:sqs:us-east-1:123456789012:queue3", - ] - - for i, call in enumerate(calls): - args, kwargs = call - tags = args[0] - self.assertIn("direction:in", tags) - self.assertIn(f"topic:{expected_arns[i]}", tags) - self.assertIn("type:sqs", tags) - self.assertEqual(kwargs["payload_size"], 100) diff --git a/tests/test_tracing.py b/tests/test_tracing.py index a629343e..9f4547f0 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -1,3 +1,4 @@ +import base64 import copy import functools import json @@ -41,8 +42,12 @@ service_mapping as global_service_mapping, propagator, emit_telemetry_on_exception_outside_of_handler, + _dsm_set_checkpoint, + extract_context_from_kinesis_event, + extract_context_from_sqs_or_sns_event_or_context, ) +from datadog_lambda.trigger import parse_event_source from tests.utils import get_mock_context @@ -55,6 +60,7 @@ fake_xray_header_value_root_decimal = "3995693151288333088" event_samples = "tests/event_samples/" +DSM_PROPAGATION_KEY_BASE_64 = "dd-pathway-ctx-base64" def with_trace_propagation_style(style): @@ -2438,3 +2444,650 @@ def test_exception_outside_handler_tracing_disabled( mock_submit_errors_metric.assert_called_once_with(None) mock_trace.assert_not_called() + + +class TestDsmSetCheckpoint(unittest.TestCase): + def setUp(self): + checkpoint_patcher = patch("ddtrace.data_streams.set_consume_checkpoint") + self.mock_checkpoint = checkpoint_patcher.start() + self.addCleanup(checkpoint_patcher.stop) + + logger_patcher = patch("datadog_lambda.tracing.logger") + self.mock_logger = logger_patcher.start() + self.addCleanup(logger_patcher.stop) + + @patch("datadog_lambda.config.Config.data_streams_enabled", False) + def test_dsm_set_checkpoint_data_streams_disabled(self): + context_json = {DSM_PROPAGATION_KEY_BASE_64: "12345"} + event_type = "sqs" + arn = "arn:aws:sqs:us-east-1:123456789012:test-queue" + + _dsm_set_checkpoint(context_json, event_type, arn) + + self.mock_checkpoint.assert_not_called() + + @patch("datadog_lambda.config.Config.data_streams_enabled", True) + def test_dsm_set_checkpoint_data_streams_enabled_complete_context(self): + context_json = {DSM_PROPAGATION_KEY_BASE_64: "12345"} + event_type = "sqs" + arn = "arn:aws:sqs:us-east-1:123456789012:test-queue" + + _dsm_set_checkpoint(context_json, event_type, arn) + + self.mock_checkpoint.assert_called_once() + args, kwargs = self.mock_checkpoint.call_args + self.assertEqual(args[0], event_type) + self.assertEqual(args[1], arn) + self.assertTrue(callable(args[2])) + self.assertEqual(kwargs["manual_checkpoint"], False) + + @patch("datadog_lambda.config.Config.data_streams_enabled", True) + def test_dsm_set_checkpoint_exception_path(self): + context_json = {DSM_PROPAGATION_KEY_BASE_64: "12345"} + event_type = "sqs" + arn = "arn:aws:sqs:us-east-1:123456789012:test-queue" + + test_exception = Exception("Test exception") + self.mock_checkpoint.side_effect = test_exception + + _dsm_set_checkpoint(context_json, event_type, arn) + + self.mock_checkpoint.assert_called_once() + self.mock_logger.debug.assert_called_once() + + @patch("ddtrace.data_streams.set_consume_checkpoint") + def test_dsm_set_checkpoint_with_non_dict_context_does_not_set_checkpoint( + self, mock_checkpoint + ): + _dsm_set_checkpoint( + [], + "sqs", + "arn:aws:sqs:us-east-1:123456789012:test-queue", + ) + mock_checkpoint.assert_not_called() + + +class TestExtractContextFromSqsOrSnsEventWithDSMLogic(unittest.TestCase): + def setUp(self): + self.lambda_context = get_mock_context() + + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + @patch("datadog_lambda.tracing.propagator.extract") + def test_sqs_event_with_datadog_message_attributes( + self, mock_extract, mock_dsm_set_checkpoint + ): + dd_data = {DSM_PROPAGATION_KEY_BASE_64: "12345"} + dd_json_data = json.dumps(dd_data) + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "messageAttributes": { + "_datadog": {"dataType": "String", "stringValue": dd_json_data} + }, + "eventSource": "aws:sqs", + } + ] + } + + mock_context = Context(trace_id=12345, span_id=67890, sampling_priority=1) + mock_extract.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + mock_extract.assert_called_once_with(dd_data) + mock_dsm_set_checkpoint.assert_called_once_with( + dd_data, "sqs", "arn:aws:sqs:us-east-1:123456789012:test-queue" + ) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + @patch("datadog_lambda.tracing.propagator.extract") + def test_sqs_event_with_binary_datadog_message_attributes( + self, mock_extract, mock_dsm_set_checkpoint + ): + dd_data = {DSM_PROPAGATION_KEY_BASE_64: "12345"} + dd_json_data = json.dumps(dd_data) + encoded_data = base64.b64encode(dd_json_data.encode()).decode() + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "messageAttributes": { + "_datadog": {"dataType": "Binary", "binaryValue": encoded_data} + }, + "eventSource": "aws:sqs", + } + ] + } + + mock_context = Context(trace_id=12345, span_id=67890, sampling_priority=1) + mock_extract.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + mock_extract.assert_called_once_with(dd_data) + mock_dsm_set_checkpoint.assert_called_once_with( + dd_data, "sqs", "arn:aws:sqs:us-east-1:123456789012:test-queue" + ) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + @patch("datadog_lambda.tracing.propagator.extract") + def test_sns_event_with_datadog_message_attributes( + self, mock_extract, mock_dsm_set_checkpoint + ): + dd_data = {DSM_PROPAGATION_KEY_BASE_64: "12345"} + dd_json_data = json.dumps(dd_data) + + event = { + "Records": [ + { + "eventSourceARN": "", + "Sns": { + "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic", + "MessageAttributes": { + "_datadog": {"Type": "String", "Value": dd_json_data} + }, + }, + "eventSource": "aws:sns", + } + ] + } + + mock_context = Context(trace_id=12345, span_id=67890, sampling_priority=1) + mock_extract.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + mock_extract.assert_called_once_with(dd_data) + mock_dsm_set_checkpoint.assert_called_once_with( + dd_data, "sns", "arn:aws:sns:us-east-1:123456789012:test-topic" + ) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + @patch("datadog_lambda.tracing.propagator.extract") + def test_sns_to_sqs_event_detection_and_processing( + self, mock_extract, mock_dsm_set_checkpoint + ): + """Test SNS->SQS case where SQS body contains SNS notification""" + dd_data = {DSM_PROPAGATION_KEY_BASE_64: "12345"} + dd_json_data = json.dumps(dd_data) + + sns_notification = { + "Type": "Notification", + "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic", + "MessageAttributes": { + "_datadog": {"Type": "String", "Value": dd_json_data} + }, + "Message": "test message", + } + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "body": json.dumps(sns_notification), + "messageAttributes": {}, + "eventSource": "aws:sqs", + } + ] + } + + mock_context = Context(trace_id=12345, span_id=67890, sampling_priority=1) + mock_extract.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + mock_extract.assert_called_once_with(dd_data) + mock_dsm_set_checkpoint.assert_called_once_with( + dd_data, "sqs", "arn:aws:sqs:us-east-1:123456789012:test-queue" + ) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + def test_sqs_event_without_datadog_message_attributes( + self, mock_dsm_set_checkpoint, mock_extract_from_lambda_context + ): + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "messageAttributes": {}, + "eventSource": "aws:sqs", + } + ] + } + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + mock_dsm_set_checkpoint.assert_called_once_with( + None, "sqs", "arn:aws:sqs:us-east-1:123456789012:test-queue" + ) + mock_extract_from_lambda_context.assert_called_once_with(self.lambda_context) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + def test_sqs_event_with_malformed_datadog_message_attributes( + self, mock_dsm_set_checkpoint, mock_extract_from_lambda_context + ): + dd_json_data = "not-json" + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "messageAttributes": { + "_datadog": {"dataType": "String", "stringValue": dd_json_data} + }, + "eventSource": "aws:sqs", + } + ] + } + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + mock_dsm_set_checkpoint.assert_called_once_with( + None, "sqs", "arn:aws:sqs:us-east-1:123456789012:test-queue" + ) + mock_extract_from_lambda_context.assert_called_once_with(self.lambda_context) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + def test_sns_event_without_datadog_message_attributes( + self, mock_dsm_set_checkpoint, mock_extract_from_lambda_context + ): + event = { + "Records": [ + { + "Sns": { + "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic", + "MessageAttributes": {}, + }, + "eventSource": "aws:sns", + } + ] + } + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + mock_dsm_set_checkpoint.assert_called_once_with( + None, "sns", "arn:aws:sns:us-east-1:123456789012:test-topic" + ) + mock_extract_from_lambda_context.assert_called_once_with(self.lambda_context) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + def test_sns_event_with_malformed_datadog_message_attributes( + self, mock_dsm_set_checkpoint, mock_extract_from_lambda_context + ): + dd_json_data = "not-json" + + event = { + "Records": [ + { + "Sns": { + "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic", + "MessageAttributes": { + "_datadog": {"Type": "String", "Value": dd_json_data} + }, + }, + "eventSource": "aws:sns", + } + ] + } + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + mock_dsm_set_checkpoint.assert_called_once_with( + None, "sns", "arn:aws:sns:us-east-1:123456789012:test-topic" + ) + mock_extract_from_lambda_context.assert_called_once_with(self.lambda_context) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + def test_sns_to_sqs_event_with_malformed_datadog_message_attributes( + self, mock_dsm_set_checkpoint, mock_extract_from_lambda_context + ): + """Test SNS->SQS case where SQS body contains SNS notification with malformed datadog ctx""" + dd_json_data = "not-json" + + sns_notification = { + "Type": "Notification", + "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic", + "MessageAttributes": { + "_datadog": {"Type": "String", "Value": dd_json_data} + }, + "Message": "test message", + } + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "body": json.dumps(sns_notification), + "messageAttributes": {}, + "eventSource": "aws:sqs", + } + ] + } + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + mock_dsm_set_checkpoint.assert_called_once_with( + None, "sqs", "arn:aws:sqs:us-east-1:123456789012:test-queue" + ) + mock_extract_from_lambda_context.assert_called_once_with(self.lambda_context) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + @patch("ddtrace.data_streams.set_consume_checkpoint") + def test_sqs_sns_event_with_exception_accessing_first_record( + self, + mock_set_consume_checkpoint, + mock_dsm_set_checkpoint, + mock_extract_from_lambda_context, + ): + event = {"Records": None, "eventSource": "aws:sqs"} + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + mock_set_consume_checkpoint.assert_not_called() + mock_extract_from_lambda_context.assert_called_once_with(self.lambda_context) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + @patch("ddtrace.data_streams.set_consume_checkpoint") + def test_sqs_event_with_empty_arn( + self, + mock_set_consume_checkpoint, + mock_dsm_set_checkpoint, + mock_extract_from_lambda_context, + ): + """Test SQS event with empty eventSourceARN""" + event = { + "Records": [ + { + "eventSourceARN": "", + "messageAttributes": {}, + "eventSource": "aws:sqs", + } + ] + } + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + mock_dsm_set_checkpoint.assert_called_once_with(None, "sqs", "") + mock_set_consume_checkpoint.assert_not_called() + mock_extract_from_lambda_context.assert_called_once_with(self.lambda_context) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + @patch("ddtrace.data_streams.set_consume_checkpoint") + def test_sns_event_with_empty_arn( + self, + mock_set_consume_checkpoint, + mock_dsm_set_checkpoint, + mock_extract_from_lambda_context, + ): + """Test SNS event with empty TopicArn""" + event = { + "Records": [ + { + "Sns": { + "TopicArn": "", + "MessageAttributes": {}, + }, + "eventSource": "aws:sns", + } + ] + } + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + mock_dsm_set_checkpoint.assert_called_once_with(None, "sns", "") + mock_set_consume_checkpoint.assert_not_called() + mock_extract_from_lambda_context.assert_called_once_with(self.lambda_context) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + @patch("ddtrace.data_streams.set_consume_checkpoint") + def test_sns_to_sqs_event_with_empty_arn( + self, + mock_set_consume_checkpoint, + mock_dsm_set_checkpoint, + mock_extract_from_lambda_context, + ): + """Test SNS->SQS event with empty eventSourceARN""" + sns_notification = { + "Type": "Notification", + "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic", + "MessageAttributes": {}, + "Message": "test message", + } + + event = { + "Records": [ + { + "eventSourceARN": "", + "body": json.dumps(sns_notification), + "messageAttributes": {}, + "eventSource": "aws:sqs", + } + ] + } + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + mock_dsm_set_checkpoint.assert_called_once_with(None, "sqs", "") + mock_set_consume_checkpoint.assert_not_called() + mock_extract_from_lambda_context.assert_called_once_with(self.lambda_context) + self.assertEqual(result, mock_context) + + +class TestExtractContextFromKinesisEventWithDSMLogic(unittest.TestCase): + def setUp(self): + self.lambda_context = get_mock_context() + + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + @patch("datadog_lambda.tracing.propagator.extract") + def test_kinesis_event_with_datadog_data( + self, mock_extract, mock_dsm_set_checkpoint + ): + dd_data = {DSM_PROPAGATION_KEY_BASE_64: "12345"} + kinesis_data = {"_datadog": dd_data, "message": "test"} + kinesis_data_str = json.dumps(kinesis_data) + encoded_data = base64.b64encode(kinesis_data_str.encode()).decode() + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream", + "kinesis": {"data": encoded_data}, + } + ] + } + + mock_context = Context(trace_id=12345, span_id=67890, sampling_priority=1) + mock_extract.return_value = mock_context + + result = extract_context_from_kinesis_event(event, self.lambda_context) + + mock_extract.assert_called_once_with(dd_data) + mock_dsm_set_checkpoint.assert_called_once_with( + dd_data, + "kinesis", + "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream", + ) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + def test_kinesis_event_without_datadog_data( + self, mock_dsm_set_checkpoint, mock_extract_from_lambda_context + ): + kinesis_data = {"message": "test"} + kinesis_data_str = json.dumps(kinesis_data) + encoded_data = base64.b64encode(kinesis_data_str.encode()).decode() + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream", + "kinesis": {"data": encoded_data}, + } + ] + } + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_kinesis_event(event, self.lambda_context) + + mock_dsm_set_checkpoint.assert_called_once_with( + None, + "kinesis", + "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream", + ) + mock_extract_from_lambda_context.assert_called_once_with(self.lambda_context) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + def test_kinesis_event_with_malformed_data( + self, mock_dsm_set_checkpoint, mock_extract_from_lambda_context + ): + encoded_data = "not-base64-or-json" + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream", + "kinesis": {"data": encoded_data}, + } + ] + } + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_kinesis_event(event, self.lambda_context) + + mock_dsm_set_checkpoint.assert_called_once_with( + None, + "kinesis", + "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream", + ) + mock_extract_from_lambda_context.assert_called_once_with(self.lambda_context) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("datadog_lambda.tracing._dsm_set_checkpoint") + @patch("ddtrace.data_streams.set_consume_checkpoint") + def test_kinesis_event_with_empty_arn( + self, + mock_set_consume_checkpoint, + mock_dsm_set_checkpoint, + mock_extract_from_lambda_context, + ): + """Test Kinesis event with empty eventSourceARN""" + kinesis_data = {"message": "test"} + kinesis_data_str = json.dumps(kinesis_data) + encoded_data = base64.b64encode(kinesis_data_str.encode()).decode() + + event = { + "Records": [ + { + "eventSourceARN": "", + "kinesis": {"data": encoded_data}, + } + ] + } + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_kinesis_event(event, self.lambda_context) + + mock_dsm_set_checkpoint.assert_called_once_with(None, "kinesis", "") + mock_set_consume_checkpoint.assert_not_called() + mock_extract_from_lambda_context.assert_called_once_with(self.lambda_context) + self.assertEqual(result, mock_context) + + @patch("datadog_lambda.tracing.extract_context_from_lambda_context") + @patch("ddtrace.data_streams.set_consume_checkpoint") + def test_kinesis_event_with_exception_accessing_first_record( + self, mock_set_consume_checkpoint, mock_extract_from_lambda_context + ): + event = {"Records": None} + + mock_context = Context(trace_id=123, span_id=456) + mock_extract_from_lambda_context.return_value = mock_context + + result = extract_context_from_kinesis_event(event, self.lambda_context) + mock_set_consume_checkpoint.assert_not_called() + self.assertEqual(result, mock_context) diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index 09f48c8a..fc081e90 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -73,10 +73,6 @@ def setUp(self): self.mock_dd_lambda_layer_tag = patcher.start() self.addCleanup(patcher.stop) - patcher = patch("datadog_lambda.wrapper.set_dsm_context") - self.mock_set_dsm_context = patcher.start() - self.addCleanup(patcher.stop) - @patch("datadog_lambda.config.Config.trace_enabled", False) def test_datadog_lambda_wrapper(self): @wrapper.datadog_lambda_wrapper @@ -556,62 +552,6 @@ def return_type_test(event, context): self.assertEqual(result, test_result) self.assertFalse(MockPrintExc.called) - def test_set_dsm_context_called_when_DSM_and_tracing_enabled(self): - os.environ["DD_DATA_STREAMS_ENABLED"] = "true" - os.environ["DD_TRACE_ENABLED"] = "true" - - @wrapper.datadog_lambda_wrapper - def lambda_handler(event, context): - return "ok" - - result = lambda_handler({}, get_mock_context()) - self.assertEqual(result, "ok") - self.mock_set_dsm_context.assert_called_once() - - del os.environ["DD_DATA_STREAMS_ENABLED"] - - def test_set_dsm_context_not_called_when_only_DSM_enabled(self): - os.environ["DD_DATA_STREAMS_ENABLED"] = "true" - os.environ["DD_TRACE_ENABLED"] = "false" - - @wrapper.datadog_lambda_wrapper - def lambda_handler(event, context): - return "ok" - - result = lambda_handler({}, get_mock_context()) - self.assertEqual(result, "ok") - self.mock_set_dsm_context.assert_not_called() - - del os.environ["DD_DATA_STREAMS_ENABLED"] - - def test_set_dsm_context_not_called_when_only_tracing_enabled(self): - os.environ["DD_DATA_STREAMS_ENABLED"] = "false" - os.environ["DD_TRACE_ENABLED"] = "true" - - @wrapper.datadog_lambda_wrapper - def lambda_handler(event, context): - return "ok" - - result = lambda_handler({}, get_mock_context()) - self.assertEqual(result, "ok") - self.mock_set_dsm_context.assert_not_called() - - del os.environ["DD_DATA_STREAMS_ENABLED"] - - def test_set_dsm_context_not_called_when_tracing_and_DSM_disabled(self): - os.environ["DD_DATA_STREAMS_ENABLED"] = "false" - os.environ["DD_TRACE_ENABLED"] = "false" - - @wrapper.datadog_lambda_wrapper - def lambda_handler(event, context): - return "ok" - - result = lambda_handler({}, get_mock_context()) - self.assertEqual(result, "ok") - self.mock_set_dsm_context.assert_not_called() - - del os.environ["DD_DATA_STREAMS_ENABLED"] - class TestLambdaWrapperWithTraceContext(unittest.TestCase): xray_root = "1-5e272390-8c398be037738dc042009320"