diff --git a/Dockerfile b/Dockerfile index 0e79d884..c5824528 100644 --- a/Dockerfile +++ b/Dockerfile @@ -21,7 +21,6 @@ RUN pip install --no-cache-dir . -t ./python/lib/$runtime/site-packages RUN rm -rf ./python/lib/$runtime/site-packages/botocore* RUN rm -rf ./python/lib/$runtime/site-packages/setuptools RUN rm -rf ./python/lib/$runtime/site-packages/jsonschema/tests -RUN find . -name 'libddwaf.so' -delete RUN rm -f ./python/lib/$runtime/site-packages/ddtrace/appsec/_iast/_taint_tracking/*.so RUN rm -f ./python/lib/$runtime/site-packages/ddtrace/appsec/_iast/_stacktrace*.so RUN rm -f ./python/lib/$runtime/site-packages/ddtrace/internal/datadog/profiling/libdd_wrapper*.so diff --git a/datadog_lambda/asm.py b/datadog_lambda/asm.py new file mode 100644 index 00000000..aab0f1e9 --- /dev/null +++ b/datadog_lambda/asm.py @@ -0,0 +1,174 @@ +from copy import deepcopy +import logging +from typing import Any, Dict, List, Optional, Union + +from ddtrace.contrib.internal.trace_utils import _get_request_header_client_ip +from ddtrace.internal import core +from ddtrace.trace import Span + +from datadog_lambda.trigger import ( + EventSubtypes, + EventTypes, + _EventSource, + _http_event_types, +) + +logger = logging.getLogger(__name__) + + +def _to_single_value_headers(headers: Dict[str, List[str]]) -> Dict[str, str]: + """ + Convert multi-value headers to single-value headers. + If a header has multiple values, join them with commas. + """ + single_value_headers = {} + for key, values in headers.items(): + single_value_headers[key] = ", ".join(values) + return single_value_headers + + +def _merge_single_and_multi_value_headers( + single_value_headers: Dict[str, str], + multi_value_headers: Dict[str, List[str]], +): + """ + Merge single-value headers with multi-value headers. + If a header exists in both, we merge them removing duplicates + """ + merged_headers = deepcopy(multi_value_headers) + for key, value in single_value_headers.items(): + if key not in merged_headers: + merged_headers[key] = [value] + elif value not in merged_headers[key]: + merged_headers[key].append(value) + return _to_single_value_headers(merged_headers) + + +def asm_start_request( + span: Span, + event: Dict[str, Any], + event_source: _EventSource, + trigger_tags: Dict[str, str], +): + if event_source.event_type not in _http_event_types: + return + + request_headers: Dict[str, str] = {} + peer_ip: Optional[str] = None + request_path_parameters: Optional[Dict[str, Any]] = None + route: Optional[str] = None + + if event_source.event_type == EventTypes.ALB: + headers = event.get("headers") + multi_value_request_headers = event.get("multiValueHeaders") + if multi_value_request_headers: + request_headers = _to_single_value_headers(multi_value_request_headers) + else: + request_headers = headers or {} + + raw_uri = event.get("path") + parsed_query = event.get("multiValueQueryStringParameters") or event.get( + "queryStringParameters" + ) + + elif event_source.event_type == EventTypes.LAMBDA_FUNCTION_URL: + request_headers = event.get("headers", {}) + peer_ip = event.get("requestContext", {}).get("http", {}).get("sourceIp") + raw_uri = event.get("rawPath") + parsed_query = event.get("queryStringParameters") + + elif event_source.event_type == EventTypes.API_GATEWAY: + request_context = event.get("requestContext", {}) + request_path_parameters = event.get("pathParameters") + route = trigger_tags.get("http.route") + + if event_source.subtype == EventSubtypes.API_GATEWAY: + request_headers = event.get("headers", {}) + peer_ip = request_context.get("identity", {}).get("sourceIp") + raw_uri = event.get("path") + parsed_query = event.get("multiValueQueryStringParameters") + + elif event_source.subtype == EventSubtypes.HTTP_API: + request_headers = event.get("headers", {}) + peer_ip = request_context.get("http", {}).get("sourceIp") + raw_uri = event.get("rawPath") + parsed_query = event.get("queryStringParameters") + + elif event_source.subtype == EventSubtypes.WEBSOCKET: + request_headers = _to_single_value_headers( + event.get("multiValueHeaders", {}) + ) + peer_ip = request_context.get("identity", {}).get("sourceIp") + raw_uri = event.get("path") + parsed_query = event.get("multiValueQueryStringParameters") + + else: + return + + else: + return + + body = event.get("body") + is_base64_encoded = event.get("isBase64Encoded", False) + + request_ip = _get_request_header_client_ip(request_headers, peer_ip, True) + if request_ip is not None: + span.set_tag_str("http.client_ip", request_ip) + span.set_tag_str("network.client.ip", request_ip) + + core.dispatch( + # The matching listener is registered in ddtrace.appsec._handlers + "aws_lambda.start_request", + ( + span, + request_headers, + request_ip, + body, + is_base64_encoded, + raw_uri, + route, + trigger_tags.get("http.method"), + parsed_query, + request_path_parameters, + ), + ) + + +def asm_start_response( + span: Span, + status_code: str, + event_source: _EventSource, + response: Union[Dict[str, Any], str, None], +): + if event_source.event_type not in _http_event_types: + return + + if isinstance(response, dict) and ( + "headers" in response or "multiValueHeaders" in response + ): + headers = response.get("headers", {}) + multi_value_request_headers = response.get("multiValueHeaders") + if isinstance(multi_value_request_headers, dict) and isinstance(headers, dict): + response_headers = _merge_single_and_multi_value_headers( + headers, multi_value_request_headers + ) + elif isinstance(headers, dict): + response_headers = headers + else: + response_headers = { + "content-type": "application/json", + } + else: + response_headers = { + "content-type": "application/json", + } + + core.dispatch( + # The matching listener is registered in ddtrace.appsec._handlers + "aws_lambda.start_response", + ( + span, + status_code, + response_headers, + ), + ) diff --git a/datadog_lambda/config.py b/datadog_lambda/config.py index 7a08d8a7..aaa1af5e 100644 --- a/datadog_lambda/config.py +++ b/datadog_lambda/config.py @@ -95,6 +95,7 @@ def _resolve_env(self, key, default=None, cast=None, depends_on_tracing=False): data_streams_enabled = _get_env( "DD_DATA_STREAMS_ENABLED", "false", as_bool, depends_on_tracing=True ) + appsec_enabled = _get_env("DD_APPSEC_ENABLED", "false", as_bool) is_gov_region = _get_env("AWS_REGION", "", lambda x: x.startswith("us-gov-")) diff --git a/datadog_lambda/tracing.py b/datadog_lambda/tracing.py index 3d5f671e..89a4126b 100644 --- a/datadog_lambda/tracing.py +++ b/datadog_lambda/tracing.py @@ -859,7 +859,7 @@ def create_inferred_span_from_lambda_function_url_event(event, context): InferredSpanInfo.set_tags(tags, tag_source="self", synchronicity="sync") if span: span.set_tags(tags) - span.start_ns = int(request_time_epoch) * 1e6 + span.start_ns = int(request_time_epoch * 1e6) return span diff --git a/datadog_lambda/wrapper.py b/datadog_lambda/wrapper.py index 87063411..29972bf4 100644 --- a/datadog_lambda/wrapper.py +++ b/datadog_lambda/wrapper.py @@ -9,6 +9,7 @@ from importlib import import_module from time import time_ns +from datadog_lambda.asm import asm_start_response, asm_start_request from datadog_lambda.dsm import set_dsm_context from datadog_lambda.extension import should_use_extension, flush_extension from datadog_lambda.cold_start import ( @@ -253,6 +254,8 @@ def _before(self, event, context): parent_span=self.inferred_span, span_pointers=calculate_span_pointers(event_source, event), ) + if config.appsec_enabled: + asm_start_request(self.span, event, event_source, self.trigger_tags) else: set_correlation_ids() if config.profiling_enabled and is_new_sandbox(): @@ -285,6 +288,15 @@ def _after(self, event, context): if status_code: self.span.set_tag("http.status_code", status_code) + + if config.appsec_enabled: + asm_start_response( + self.span, + status_code, + self.event_source, + response=self.response, + ) + self.span.finish() if self.inferred_span: diff --git a/scripts/check_layer_size.sh b/scripts/check_layer_size.sh index 90d5861b..626f9d31 100755 --- a/scripts/check_layer_size.sh +++ b/scripts/check_layer_size.sh @@ -8,8 +8,8 @@ # Compares layer size to threshold, and fails if below that threshold set -e -MAX_LAYER_COMPRESSED_SIZE_KB=$(expr 5 \* 1024) -MAX_LAYER_UNCOMPRESSED_SIZE_KB=$(expr 13 \* 1024) +MAX_LAYER_COMPRESSED_SIZE_KB=$(expr 6 \* 1024) +MAX_LAYER_UNCOMPRESSED_SIZE_KB=$(expr 15 \* 1024) LAYER_FILES_PREFIX="datadog_lambda_py" diff --git a/tests/event_samples/application-load-balancer-mutivalue-headers.json b/tests/event_samples/application-load-balancer-mutivalue-headers.json new file mode 100644 index 00000000..6d446d15 --- /dev/null +++ b/tests/event_samples/application-load-balancer-mutivalue-headers.json @@ -0,0 +1,31 @@ +{ + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-xyz/123abc" + } + }, + "httpMethod": "GET", + "path": "/lambda", + "queryStringParameters": { + "query": "1234ABCD" + }, + "multiValueHeaders": { + "accept": ["text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8"], + "accept-encoding": ["gzip"], + "accept-language": ["en-US,en;q=0.9"], + "connection": ["keep-alive"], + "host": ["lambda-alb-123578498.us-east-2.elb.amazonaws.com"], + "upgrade-insecure-requests": ["1"], + "user-agent": ["Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36"], + "x-amzn-trace-id": ["Root=1-5c536348-3d683b8b04734faae651f476"], + "x-forwarded-for": ["72.12.164.125"], + "x-forwarded-port": ["80"], + "x-forwarded-proto": ["http"], + "x-imforwards": ["20"], + "x-datadog-trace-id": ["12345"], + "x-datadog-parent-id": ["67890"], + "x-datadog-sampling-priority": ["2"] + }, + "body": "", + "isBase64Encoded": false +} diff --git a/tests/test_asm.py b/tests/test_asm.py new file mode 100644 index 00000000..e57c289f --- /dev/null +++ b/tests/test_asm.py @@ -0,0 +1,329 @@ +import json +import pytest +from unittest.mock import MagicMock, patch + +from datadog_lambda.asm import asm_start_request, asm_start_response +from datadog_lambda.trigger import parse_event_source, extract_trigger_tags +from tests.utils import get_mock_context + +event_samples = "tests/event_samples/" + + +# Test cases for ASM start request +ASM_START_REQUEST_TEST_CASES = [ + ( + "application_load_balancer", + "application-load-balancer.json", + "72.12.164.125", + "/lambda", + "GET", + "", + False, + {"query": "1234ABCD"}, + None, + None, + ), + ( + "application_load_balancer_multivalue_headers", + "application-load-balancer-mutivalue-headers.json", + "72.12.164.125", + "/lambda", + "GET", + "", + False, + {"query": "1234ABCD"}, + None, + None, + ), + ( + "lambda_function_url", + "lambda-url.json", + "71.195.30.42", + "/", + "GET", + None, + False, + None, + None, + None, + ), + ( + "api_gateway", + "api-gateway.json", + "127.0.0.1", + "/path/to/resource", + "POST", + "eyJ0ZXN0IjoiYm9keSJ9", + True, + {"foo": ["bar"]}, + {"proxy": "/path/to/resource"}, + "/{proxy+}", + ), + ( + "api_gateway_v2_parametrized", + "api-gateway-v2-parametrized.json", + "76.115.124.192", + "/user/42", + "GET", + None, + False, + None, + {"id": "42"}, + "/user/{id}", + ), + ( + "api_gateway_websocket", + "api-gateway-websocket-default.json", + "38.122.226.210", + None, + None, + '"What\'s good in the hood?"', + False, + None, + None, + None, + ), +] + + +# Test cases for ASM start response +ASM_START_RESPONSE_TEST_CASES = [ + ( + "application_load_balancer", + "application-load-balancer.json", + { + "statusCode": 200, + "headers": {"Content-Type": "text/html"}, + }, + "200", + {"Content-Type": "text/html"}, + True, + ), + ( + "application_load_balancer_multivalue_headers", + "application-load-balancer-mutivalue-headers.json", + { + "statusCode": 404, + "multiValueHeaders": { + "Content-Type": ["text/plain"], + "X-Error": ["Not Found"], + }, + }, + "404", + { + "Content-Type": "text/plain", + "X-Error": "Not Found", + }, + True, + ), + ( + "lambda_function_url", + "lambda-url.json", + { + "statusCode": 201, + "headers": { + "Location": "/user/123", + "Content-Type": "application/json", + }, + }, + "201", + { + "Location": "/user/123", + "Content-Type": "application/json", + }, + True, + ), + ( + "api_gateway", + "api-gateway.json", + { + "statusCode": 200, + "headers": { + "Content-Type": "application/json", + "X-Custom-Header": "test-value", + }, + "body": '{"message": "success"}', + }, + "200", + { + "Content-Type": "application/json", + "X-Custom-Header": "test-value", + }, + True, + ), + ( + "api_gateway_v2_parametrized", + "api-gateway-v2-parametrized.json", + { + "statusCode": 200, + "headers": {"Content-Type": "application/json"}, + }, + "200", + {"Content-Type": "application/json"}, + True, + ), + ( + "api_gateway_websocket", + "api-gateway-websocket-default.json", + { + "statusCode": 200, + "headers": {"Content-Type": "text/plain"}, + }, + "200", + {"Content-Type": "text/plain"}, + True, + ), + ( + "non_http_event_s3", + "s3.json", + {"statusCode": 200}, + "200", + {}, + False, # Should not dispatch for non-HTTP events + ), + ( + "api_gateway_v2_string_response", + "api-gateway-v2-parametrized.json", + "Hello, World!", + "200", + {"content-type": "application/json"}, + True, + ), + ( + "api_gateway_v2_dict_response", + "api-gateway-v2-parametrized.json", + {"message": "Hello, World!"}, + "200", + {"content-type": "application/json"}, + True, + ), +] + + +@pytest.mark.parametrize( + "name,file,expected_ip,expected_uri,expected_method,expected_body,expected_base64,expected_query,expected_path_params,expected_route", + ASM_START_REQUEST_TEST_CASES, +) +@patch("datadog_lambda.asm.core") +def test_asm_start_request_parametrized( + mock_core, + name, + file, + expected_ip, + expected_uri, + expected_method, + expected_body, + expected_base64, + expected_query, + expected_path_params, + expected_route, +): + """Test ASM start request for various HTTP event types using parametrization""" + mock_span = MagicMock() + ctx = get_mock_context() + + # Reset mock for each test + mock_core.reset_mock() + mock_span.reset_mock() + + test_file = event_samples + file + with open(test_file, "r") as f: + event = json.load(f) + + event_source = parse_event_source(event) + trigger_tags = extract_trigger_tags(event, ctx) + + asm_start_request(mock_span, event, event_source, trigger_tags) + + # Verify core.dispatch was called + mock_core.dispatch.assert_called_once() + call_args = mock_core.dispatch.call_args + dispatch_args = call_args[0][1] + ( + span, + request_headers, + request_ip, + body, + is_base64_encoded, + raw_uri, + http_route, + http_method, + parsed_query, + request_path_parameters, + ) = dispatch_args + + # Common assertions + assert span == mock_span + assert isinstance(request_headers, dict) + + # Specific assertions based on test case + assert request_ip == expected_ip + assert raw_uri == expected_uri + assert http_method == expected_method + assert body == expected_body + assert is_base64_encoded == expected_base64 + + if expected_query is not None: + assert parsed_query == expected_query + else: + assert parsed_query is None + + if expected_path_params is not None: + assert request_path_parameters == expected_path_params + else: + assert request_path_parameters is None + + # Check route is correctly extracted and passed + assert http_route == expected_route + + # Check IP tags were set if IP is present + if expected_ip: + mock_span.set_tag_str.assert_any_call("http.client_ip", expected_ip) + mock_span.set_tag_str.assert_any_call("network.client.ip", expected_ip) + + +@pytest.mark.parametrize( + "name,event_file,response,status_code,expected_headers,should_dispatch", + ASM_START_RESPONSE_TEST_CASES, +) +@patch("datadog_lambda.asm.core") +def test_asm_start_response_parametrized( + mock_core, + name, + event_file, + response, + status_code, + expected_headers, + should_dispatch, +): + """Test ASM start response for various HTTP event types using parametrization""" + mock_span = MagicMock() + + # Reset mock for each test + mock_core.reset_mock() + mock_span.reset_mock() + + test_file = event_samples + event_file + with open(test_file, "r") as f: + event = json.load(f) + + event_source = parse_event_source(event) + + asm_start_response(mock_span, status_code, event_source, response) + + if should_dispatch: + # Verify core.dispatch was called + mock_core.dispatch.assert_called_once() + call_args = mock_core.dispatch.call_args + assert call_args[0][0] == "aws_lambda.start_response" + + # Extract the dispatched arguments + dispatch_args = call_args[0][1] + span, response_status_code, response_headers = dispatch_args + + assert span == mock_span + assert response_status_code == status_code + assert response_headers == expected_headers + else: + # Verify core.dispatch was not called for non-HTTP events + mock_core.dispatch.assert_not_called()