1
+ from copy import deepcopy
1
2
import logging
2
3
from typing import Any , Dict , List , Optional
3
4
15
16
logger = logging .getLogger (__name__ )
16
17
17
18
18
- def _to_single_value_headers (request_headers : Dict [str , List [str ]]) -> Dict [str , str ]:
19
+ def _to_single_value_headers (headers : Dict [str , List [str ]]) -> Dict [str , str ]:
19
20
"""
20
21
Convert multi-value headers to single-value headers.
21
- If a header has multiple values, the first value is used .
22
+ If a header has multiple values, join them with commas .
22
23
"""
23
24
single_value_headers = {}
24
- for key , values in request_headers .items ():
25
- if len (values ) >= 1 :
26
- single_value_headers [key ] = values [0 ]
25
+ for key , values in headers .items ():
26
+ single_value_headers [key ] = ", " .join (values )
27
27
return single_value_headers
28
28
29
29
30
+ def _merge_single_and_multi_value_headers (
31
+ single_value_headers : Dict [str , str ],
32
+ multi_value_headers : Dict [str , List [str ]],
33
+ ):
34
+ """
35
+ Merge single-value headers with multi-value headers.
36
+ If a header exists in both, we merge them removing duplicates
37
+ """
38
+ merged_headers = deepcopy (multi_value_headers )
39
+ for key , value in single_value_headers .items ():
40
+ if key not in merged_headers :
41
+ merged_headers [key ] = [value ]
42
+ elif value not in merged_headers [key ]:
43
+ merged_headers [key ].append (value )
44
+ return _to_single_value_headers (merged_headers )
45
+
46
+
30
47
def asm_start_request (
31
48
span : Span ,
32
49
event : Dict [str , Any ],
@@ -36,6 +53,7 @@ def asm_start_request(
36
53
request_headers : Dict [str , str ] = {}
37
54
peer_ip : Optional [str ] = None
38
55
request_path_parameters : Optional [Dict [str , Any ]] = None
56
+ route : Optional [str ] = None
39
57
40
58
if event_source .event_type == EventTypes .ALB :
41
59
headers = event .get ("headers" )
@@ -59,11 +77,10 @@ def asm_start_request(
59
77
elif event_source .event_type == EventTypes .API_GATEWAY :
60
78
request_context = event .get ("requestContext" , {})
61
79
request_path_parameters = event .get ("pathParameters" )
80
+ route = trigger_tags .get ("http.route" )
62
81
63
82
if event_source .subtype == EventSubtypes .API_GATEWAY :
64
- request_headers = _to_single_value_headers (
65
- event .get ("multiValueHeaders" , {})
66
- )
83
+ request_headers = event .get ("headers" , {})
67
84
peer_ip = request_context .get ("identity" , {}).get ("sourceIp" )
68
85
raw_uri = event .get ("path" )
69
86
parsed_query = event .get ("multiValueQueryStringParameters" )
@@ -105,7 +122,7 @@ def asm_start_request(
105
122
body ,
106
123
is_base64_encoded ,
107
124
raw_uri ,
108
- trigger_tags . get ( "http. route" ) ,
125
+ route ,
109
126
trigger_tags .get ("http.method" ),
110
127
parsed_query ,
111
128
request_path_parameters ,
@@ -122,9 +139,14 @@ def asm_start_response(
122
139
if event_source .event_type not in _http_event_types :
123
140
return
124
141
125
- response_headers = response .get ("headers" , {})
126
- if not isinstance (response_headers , dict ):
127
- response_headers = {}
142
+ headers = response .get ("headers" , {})
143
+ multi_value_request_headers = response .get ("multiValueHeaders" )
144
+ if multi_value_request_headers :
145
+ response_headers = _merge_single_and_multi_value_headers (
146
+ headers , multi_value_request_headers
147
+ )
148
+ else :
149
+ response_headers = headers
128
150
129
151
core .dispatch (
130
152
"aws_lambda.start_response" ,
0 commit comments