1
1
import unittest
2
2
import json
3
+ import os
3
4
4
5
from unittest .mock import MagicMock , patch
5
6
6
- from datadog_lambda .xray import get_xray_host_port , build_segment_payload , build_segment
7
+ from datadog_lambda .xray import (
8
+ get_xray_host_port ,
9
+ build_segment_payload ,
10
+ build_segment ,
11
+ send_segment ,
12
+ )
7
13
8
14
9
15
class TestXRay (unittest .TestCase ):
16
+ def tearDown (self ):
17
+ if os .environ .get ("_X_AMZN_TRACE_ID" ):
18
+ os .environ .pop ("_X_AMZN_TRACE_ID" )
19
+ if os .environ .get ("AWS_XRAY_DAEMON_ADDRESS" ):
20
+ os .environ .pop ("AWS_XRAY_DAEMON_ADDRESS" )
21
+ return super ().tearDown ()
22
+
10
23
def test_get_xray_host_port_empty_ (self ):
11
24
result = get_xray_host_port ("" )
12
25
self .assertIsNone (result )
@@ -20,6 +33,31 @@ def test_get_xray_host_port_success(self):
20
33
self .assertEqual ("mySuperHost" , result [0 ])
21
34
self .assertEqual (1000 , result [1 ])
22
35
36
+ def test_send_segment_sampled_out (self ):
37
+ os .environ ["AWS_XRAY_DAEMON_ADDRESS" ] = "fake-agent.com:8080"
38
+ os .environ [
39
+ "_X_AMZN_TRACE_ID"
40
+ ] = "Root=1-5e272390-8c398be037738dc042009320;Parent=94ae789b969f1cc5;Sampled=0"
41
+
42
+ with patch (
43
+ "datadog_lambda.xray.send" , MagicMock (return_value = None )
44
+ ) as mock_send :
45
+ # XRay trace won't be sampled according to the trace header.
46
+ send_segment ("my_key" , {"data" : "value" })
47
+ self .assertFalse (mock_send .called )
48
+
49
+ def test_send_segment_sampled (self ):
50
+ os .environ ["AWS_XRAY_DAEMON_ADDRESS" ] = "fake-agent.com:8080"
51
+ os .environ [
52
+ "_X_AMZN_TRACE_ID"
53
+ ] = "Root=1-5e272390-8c398be037738dc042009320;Parent=94ae789b969f1cc5;Sampled=1"
54
+ with patch (
55
+ "datadog_lambda.xray.send" , MagicMock (return_value = None )
56
+ ) as mock_send :
57
+ # X-Ray trace will be sampled according to the trace header.
58
+ send_segment ("my_key" , {"data" : "value" })
59
+ self .assertTrue (mock_send .called )
60
+
23
61
def test_build_segment_payload_ok (self ):
24
62
exected_text = '{"format": "json", "version": 1}\n myPayload'
25
63
self .assertEqual (exected_text , build_segment_payload ("myPayload" ))
0 commit comments