diff --git a/firebase_admin/db.py b/firebase_admin/db.py index b82a327ed..d42370317 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -373,8 +373,7 @@ def listen(self, callback): Raises: FirebaseError: If an error occurs while starting the initial HTTP connection. """ - session = _sseclient.KeepAuthSession(self._client.credential) - return self._listen_with_session(callback, session) + return self._listen_with_session(callback) def transaction(self, transaction_update): """Atomically modifies the data at this location. @@ -463,8 +462,11 @@ def order_by_value(self): def _add_suffix(self, suffix='.json'): return self._pathurl + suffix - def _listen_with_session(self, callback, session): + def _listen_with_session(self, callback, session=None): url = self._client.base_url + self._add_suffix() + if not session: + session = self._client.create_listener_session() + try: sse = _sseclient.SSEClient(url, session) return ListenerRegistration(callback, sse) @@ -907,6 +909,7 @@ def __init__(self, credential, base_url, timeout, params=None): super().__init__( credential=credential, base_url=base_url, timeout=timeout, headers={'User-Agent': _USER_AGENT}) + self.credential = credential self.params = params if params else {} def request(self, method, url, **kwargs): @@ -941,6 +944,9 @@ def request(self, method, url, **kwargs): except requests.exceptions.RequestException as error: raise _Client.handle_rtdb_error(error) + def create_listener_session(self): + return _sseclient.KeepAuthSession(self.credential) + @classmethod def handle_rtdb_error(cls, error): """Converts an error encountered while calling RTDB into a FirebaseError.""" diff --git a/tests/test_db.py b/tests/test_db.py index 1743347c5..2989fc030 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -551,6 +551,17 @@ def callback(_): finally: testutils.cleanup_apps() + def test_listener_session(self): + firebase_admin.initialize_app(testutils.MockCredential(), { + 'databaseURL' : 'https://test.firebaseio.com', + }) + try: + ref = db.reference() + session = ref._client.create_listener_session() + assert isinstance(session, _sseclient.KeepAuthSession) + finally: + testutils.cleanup_apps() + def test_single_event(self): self.events = [] def callback(event):