diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 562e77fa5..135573c01 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -29,6 +29,7 @@ from firebase_admin import exceptions from firebase_admin import _auth_utils +from firebase_admin import _http_client # ID token constants @@ -231,12 +232,37 @@ def create_session_cookie(self, id_token, expires_in): return body.get('sessionCookie') +class CertificateFetchRequest(transport.Request): + """A google-auth transport that supports HTTP cache-control. + + Also injects a timeout to each outgoing HTTP request. + """ + + def __init__(self, timeout_seconds=None): + self._session = cachecontrol.CacheControl(requests.Session()) + self._delegate = transport.requests.Request(self.session) + self._timeout_seconds = timeout_seconds + + @property + def session(self): + return self._session + + @property + def timeout_seconds(self): + return self._timeout_seconds + + def __call__(self, url, method='GET', body=None, headers=None, timeout=None, **kwargs): + timeout = timeout or self.timeout_seconds + return self._delegate( + url, method=method, body=body, headers=headers, timeout=timeout, **kwargs) + + class TokenVerifier: """Verifies ID tokens and session cookies.""" def __init__(self, app): - session = cachecontrol.CacheControl(requests.Session()) - self.request = transport.requests.Request(session=session) + timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) + self.request = CertificateFetchRequest(timeout) self.id_token_verifier = _JWTVerifier( project_id=app.project_id, short_name='ID token', operation='verify_id_token()', diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 29c70da80..d8450c59c 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -31,6 +31,7 @@ from firebase_admin import auth from firebase_admin import credentials from firebase_admin import exceptions +from firebase_admin import _http_client from firebase_admin import _token_gen from tests import testutils @@ -702,3 +703,52 @@ def test_certificate_caching(self, user_mgt_app, httpserver): assert len(httpserver.requests) == request_count verifier.verify_id_token(TEST_ID_TOKEN) assert len(httpserver.requests) == request_count + + +class TestCertificateFetchTimeout: + + timeout_configs = [ + ({'httpTimeout': 4}, 4), + ({'httpTimeout': None}, None), + ({}, _http_client.DEFAULT_TIMEOUT_SECONDS), + ] + + @pytest.mark.parametrize('options, timeout', timeout_configs) + def test_init_request(self, options, timeout): + app = firebase_admin.initialize_app(MOCK_CREDENTIAL, options=options) + + client = auth._get_client(app) + request = client._token_verifier.request + + assert isinstance(request, _token_gen.CertificateFetchRequest) + assert request.timeout_seconds == timeout + + @pytest.mark.parametrize('options, timeout', timeout_configs) + def test_verify_id_token_timeout(self, options, timeout): + app = firebase_admin.initialize_app(MOCK_CREDENTIAL, options=options) + recorder = self._instrument_session(app) + + auth.verify_id_token(TEST_ID_TOKEN) + + assert len(recorder) == 1 + assert recorder[0]._extra_kwargs['timeout'] == timeout + + @pytest.mark.parametrize('options, timeout', timeout_configs) + def test_verify_session_cookie_timeout(self, options, timeout): + app = firebase_admin.initialize_app(MOCK_CREDENTIAL, options=options) + recorder = self._instrument_session(app) + + auth.verify_session_cookie(TEST_SESSION_COOKIE) + + assert len(recorder) == 1 + assert recorder[0]._extra_kwargs['timeout'] == timeout + + def _instrument_session(self, app): + client = auth._get_client(app) + request = client._token_verifier.request + recorder = [] + request.session.mount('https://', testutils.MockAdapter(MOCK_PUBLIC_CERTS, 200, recorder)) + return recorder + + def teardown(self): + testutils.cleanup_apps()