diff --git a/datadog_lambda/tracing.py b/datadog_lambda/tracing.py index 7352ff55..efe153b2 100644 --- a/datadog_lambda/tracing.py +++ b/datadog_lambda/tracing.py @@ -165,7 +165,20 @@ def extract_context_from_sqs_event_or_context(event, lambda_context): return extract_context_from_lambda_context(lambda_context) -def extract_dd_trace_context(event, lambda_context): +def extract_context_custom_extractor(extractor, event, lambda_context): + """ + Extract Datadog trace context using a custom trace extractor function + """ + try: + (trace_id, parent_id, sampling_priority,) = extractor(event, lambda_context) + return trace_id, parent_id, sampling_priority + except Exception as e: + logger.debug("The trace extractor returned with error %s", e) + + return None, None, None + + +def extract_dd_trace_context(event, lambda_context, extractor=None): """ Extract Datadog trace context from the Lambda `event` object. @@ -175,7 +188,11 @@ def extract_dd_trace_context(event, lambda_context): global dd_trace_context trace_context_source = None - if "headers" in event: + if extractor is not None: + (trace_id, parent_id, sampling_priority,) = extract_context_custom_extractor( + extractor, event, lambda_context + ) + elif "headers" in event: ( trace_id, parent_id, diff --git a/datadog_lambda/wrapper.py b/datadog_lambda/wrapper.py index 2b969f43..52390dd0 100644 --- a/datadog_lambda/wrapper.py +++ b/datadog_lambda/wrapper.py @@ -6,6 +6,7 @@ import os import logging import traceback +from importlib import import_module from datadog_lambda.extension import should_use_extension, flush_extension from datadog_lambda.cold_start import set_cold_start, is_cold_start @@ -15,6 +16,7 @@ submit_invocations_metric, submit_errors_metric, ) +from datadog_lambda.module_name import modify_module_name from datadog_lambda.patch import patch_all from datadog_lambda.tracing import ( extract_dd_trace_context, @@ -91,6 +93,16 @@ def __init__(self, func): os.environ.get("DD_MERGE_XRAY_TRACES", "false").lower() == "true" ) self.function_name = os.environ.get("AWS_LAMBDA_FUNCTION_NAME", "function") + self.extractor_env = os.environ.get("DD_TRACE_EXTRACTOR", None) + self.trace_extractor = None + + if self.extractor_env: + extractor_parts = self.extractor_env.rsplit(".", 1) + if len(extractor_parts) == 2: + (mod_name, extractor_name) = extractor_parts + modified_extractor_name = modify_module_name(mod_name) + extractor_module = import_module(modified_extractor_name) + self.trace_extractor = getattr(extractor_module, extractor_name) # Inject trace correlation ids to logs if self.logs_injection: @@ -125,7 +137,9 @@ def _before(self, event, context): set_cold_start() submit_invocations_metric(context) # Extract Datadog trace context and source from incoming requests - dd_context, trace_context_source = extract_dd_trace_context(event, context) + dd_context, trace_context_source = extract_dd_trace_context( + event, context, extractor=self.trace_extractor + ) # Create a Datadog X-Ray subsegment with the trace context if dd_context and trace_context_source == TraceContextSource.EVENT: create_dd_dummy_metadata_subsegment( diff --git a/tests/integration/serverless-plugin.yml b/tests/integration/serverless-plugin.yml index 2a69fa05..386429e1 100644 --- a/tests/integration/serverless-plugin.yml +++ b/tests/integration/serverless-plugin.yml @@ -6,6 +6,7 @@ provider: DD_INTEGRATION_TEST: true DD_API_KEY: ${env:DD_API_KEY} WITH_PLUGIN: true + lambdaHashingVersion: 20201221 layers: python27: diff --git a/tests/integration/serverless.yml b/tests/integration/serverless.yml index e7badc55..cc605397 100644 --- a/tests/integration/serverless.yml +++ b/tests/integration/serverless.yml @@ -8,6 +8,7 @@ provider: environment: DD_INTEGRATION_TEST: true DD_API_KEY: ${env:DD_API_KEY} + lambdaHashingVersion: 20201221 layers: python27: diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 92b6e68b..bf8ebba1 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -136,6 +136,70 @@ def test_with_complete_datadog_trace_headers(self): XraySubsegment.NAMESPACE, ) + def test_with_extractor_function(self): + def extractor_foo(event, context): + foo = event.get("foo", {}) + lowercase_foo = {k.lower(): v for k, v in foo.items()} + + trace_id = lowercase_foo.get(TraceHeader.TRACE_ID) + parent_id = lowercase_foo.get(TraceHeader.PARENT_ID) + sampling_priority = lowercase_foo.get(TraceHeader.SAMPLING_PRIORITY) + return trace_id, parent_id, sampling_priority + + lambda_ctx = get_mock_context() + ctx, ctx_source = extract_dd_trace_context( + { + "foo": { + TraceHeader.TRACE_ID: "123", + TraceHeader.PARENT_ID: "321", + TraceHeader.SAMPLING_PRIORITY: "1", + } + }, + lambda_ctx, + extractor=extractor_foo, + ) + self.assertEquals(ctx_source, "event") + self.assertDictEqual( + ctx, {"trace-id": "123", "parent-id": "321", "sampling-priority": "1",}, + ) + self.assertDictEqual( + get_dd_trace_context(), + { + TraceHeader.TRACE_ID: "123", + TraceHeader.PARENT_ID: "65535", + TraceHeader.SAMPLING_PRIORITY: "1", + }, + ) + + def test_graceful_fail_of_extractor_function(self): + def extractor_raiser(event, context): + raise Exception("kreator") + + lambda_ctx = get_mock_context() + ctx, ctx_source = extract_dd_trace_context( + { + "foo": { + TraceHeader.TRACE_ID: "123", + TraceHeader.PARENT_ID: "321", + TraceHeader.SAMPLING_PRIORITY: "1", + } + }, + lambda_ctx, + extractor=extractor_raiser, + ) + self.assertEquals(ctx_source, "xray") + self.assertDictEqual( + ctx, {"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2",}, + ) + self.assertDictEqual( + get_dd_trace_context(), + { + TraceHeader.TRACE_ID: "4369", + TraceHeader.PARENT_ID: "65535", + TraceHeader.SAMPLING_PRIORITY: "2", + }, + ) + def test_with_sqs_distributed_datadog_trace_data(self): lambda_ctx = get_mock_context() sqs_event = { diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index 2522702b..6370c47b 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -100,7 +100,7 @@ def lambda_handler(event, context): ) self.mock_wrapper_lambda_stats.flush.assert_called() self.mock_extract_dd_trace_context.assert_called_with( - lambda_event, lambda_context + lambda_event, lambda_context, extractor=None ) self.mock_set_correlation_ids.assert_called() self.mock_inject_correlation_ids.assert_called()