diff --git a/datadog_lambda/tracing.py b/datadog_lambda/tracing.py index 19ef8c04..67836e46 100644 --- a/datadog_lambda/tracing.py +++ b/datadog_lambda/tracing.py @@ -2,7 +2,7 @@ # under the Apache License Version 2.0. # This product includes software developed at Datadog (https://www.datadoghq.com/). # Copyright 2019 Datadog, Inc. - +import hashlib import logging import os import json @@ -328,6 +328,39 @@ def extract_context_from_kinesis_event(event, lambda_context): return extract_context_from_lambda_context(lambda_context) +def _deterministic_md5_hash(s: str) -> str: + """MD5 here is to generate trace_id, not for any encryption.""" + hex_number = hashlib.md5(s.encode("ascii")).hexdigest() + binary = bin(int(hex_number, 16)) + binary_str = str(binary) + binary_str_remove_0b = binary_str[2:].rjust(128, "0") + most_significant_64_bits_without_leading_1 = "0" + binary_str_remove_0b[1:-64] + result = str(int(most_significant_64_bits_without_leading_1, 2)) + if result == "0" * 64: + return "1" + return result + + +def extract_context_from_step_functions(event, lambda_context): + """ + Only extract datadog trace context when Step Functions Context Object is injected + into lambda's event dict. + """ + try: + execution_id = event.get("Execution").get("Id") + state_name = event.get("State").get("Name") + state_entered_time = event.get("State").get("EnteredTime") + trace_id = _deterministic_md5_hash(execution_id) + parent_id = _deterministic_md5_hash( + execution_id + "#" + state_name + "#" + state_entered_time + ) + sampling_priority = SamplingPriority.AUTO_KEEP + return trace_id, parent_id, sampling_priority + except Exception as e: + logger.debug("The Step Functions trace extractor returned with error %s", e) + return extract_context_from_lambda_context(lambda_context) + + def extract_context_custom_extractor(extractor, event, lambda_context): """ Extract Datadog trace context using a custom trace extractor function @@ -440,6 +473,12 @@ def extract_dd_trace_context( parent_id, sampling_priority, ) = extract_context_from_kinesis_event(event, lambda_context) + elif event_source.equals(EventTypes.STEPFUNCTIONS): + ( + trace_id, + parent_id, + sampling_priority, + ) = extract_context_from_step_functions(event, lambda_context) else: trace_id, parent_id, sampling_priority = extract_context_from_lambda_context( lambda_context diff --git a/datadog_lambda/trigger.py b/datadog_lambda/trigger.py index 0576e3f9..bbb44b30 100644 --- a/datadog_lambda/trigger.py +++ b/datadog_lambda/trigger.py @@ -34,12 +34,13 @@ class EventTypes(_stringTypedEnum): CLOUDWATCH_EVENTS = "cloudwatch-events" CLOUDFRONT = "cloudfront" DYNAMODB = "dynamodb" + EVENTBRIDGE = "eventbridge" KINESIS = "kinesis" LAMBDA_FUNCTION_URL = "lambda-function-url" S3 = "s3" SNS = "sns" SQS = "sqs" - EVENTBRIDGE = "eventbridge" + STEPFUNCTIONS = "states" class EventSubtypes(_stringTypedEnum): @@ -145,6 +146,9 @@ def parse_event_source(event: dict) -> _EventSource: if event.get("source") == "aws.events" or has_event_categories: event_source = _EventSource(EventTypes.CLOUDWATCH_EVENTS) + if "Execution" in event and "StateMachine" in event and "State" in event: + event_source = _EventSource(EventTypes.STEPFUNCTIONS) + event_record = get_first_record(event) if event_record: aws_event_source = event_record.get( diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 25865d5e..e19c66aa 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, Mock, patch, call import ddtrace -from ddtrace.constants import ERROR_MSG, ERROR_TYPE + from ddtrace import tracer from ddtrace.context import Context @@ -15,6 +15,7 @@ XraySubsegment, ) from datadog_lambda.tracing import ( + _deterministic_md5_hash, create_inferred_span, extract_dd_trace_context, create_dd_dummy_metadata_subsegment, @@ -1334,9 +1335,7 @@ def test_create_inferred_span_from_api_gateway_event_no_apiid(self): event = json.load(event) ctx = get_mock_context() ctx.aws_request_id = "123" - print(event) span = create_inferred_span(event, ctx) - print(span) self.assertEqual(span.get_tag("operation_name"), "aws.apigateway.rest") self.assertEqual( span.service, @@ -1389,3 +1388,24 @@ def test_no_error_with_nonetype_headers(self): lambda_ctx, ) self.assertEqual(ctx, None) + + +class TestStepFunctionsTraceContext(unittest.TestCase): + def test_deterministic_m5_hash(self): + result = _deterministic_md5_hash("some_testing_random_string") + self.assertEqual("2251275791555400689", result) + + def test_deterministic_m5_hash__result_the_same_as_backend(self): + result = _deterministic_md5_hash( + "arn:aws:states:sa-east-1:601427271234:express:DatadogStateMachine:acaf1a67-336a-e854-1599-2a627eb2dd8a" + ":c8baf081-31f1-464d-971f-70cb17d01111#step-one#2022-12-08T21:08:19.224Z" + ) + self.assertEqual("8034507082463708833", result) + + def test_deterministic_m5_hash__always_leading_with_zero(self): + for i in range(100): + result = _deterministic_md5_hash(str(i)) + result_in_binary = bin(int(result)) + # Leading zeros will be omitted, so only test for full 64 bits present + if len(result_in_binary) == 66: # "0b" + 64 bits. + self.assertTrue(result_in_binary.startswith("0b0"))