diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..23ee7ef --- /dev/null +++ b/.coveragerc @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +[report] +exclude_lines = + # pragma: no cover + if not sys.implementation.name == "circuitpython": diff --git a/.gitignore b/.gitignore index db3d538..a06dc67 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,10 @@ _build .idea .vscode *~ + +# tox-specific files +.tox +build + +# coverage-specific files +.coverage diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 70ade69..6699562 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,9 +4,14 @@ repos: - repo: https://github.com/python/black - rev: 23.3.0 + rev: 24.2.0 hooks: - id: black + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] - repo: https://github.com/fsfe/reuse-tool rev: v1.1.2 hooks: @@ -32,11 +37,11 @@ repos: types: [python] files: "^examples/" args: - - --disable=missing-docstring,invalid-name,consider-using-f-string,duplicate-code + - --disable=consider-using-f-string,duplicate-code,missing-docstring,invalid-name, - id: pylint name: pylint (test code) description: Run pylint rules on "tests/*.py" files types: [python] files: "^tests/" args: - - --disable=missing-docstring,consider-using-f-string,duplicate-code + - --disable=consider-using-f-string,duplicate-code,missing-docstring,invalid-name,protected-access diff --git a/LICENSE b/LICENSE index fa6ee38..87fc65e 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2023 Justin Myers for Adafruit Industries +Copyright (c) 2024 Justin Myers for Adafruit Industries Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.rst b/README.rst index 6d81dd9..25a7836 100644 --- a/README.rst +++ b/README.rst @@ -36,11 +36,6 @@ This is easily achieved by downloading or individual libraries can be installed using `circup `_. - - -.. todo:: Describe the Adafruit product this library works with. For PCBs, you can also add the -image from the assets folder in the PCB's GitHub repo. - `Purchase one from the Adafruit shop `_ Installing from PyPI @@ -48,8 +43,6 @@ Installing from PyPI .. note:: This library is not available on PyPI yet. Install documentation is included as a standard element. Stay tuned for PyPI availability! -.. todo:: Remove the above note if PyPI version is/will be available at time of release. - On supported GNU/Linux systems like the Raspberry Pi, you can install the driver locally `from PyPI `_. To install for current user: @@ -99,8 +92,11 @@ Or the following command to update an existing version: Usage Example ============= -.. todo:: Add a quick, simple example. It and other examples should live in the -examples folder and be included in docs/examples.rst. +This library is used internally by libraries like `Adafruit_CircuitPython_Requests +`_ and `Adafruit_CircuitPython_MiniMQTT +`_ + +Usage examples are within the `examples` subfolder of this library. Documentation ============= diff --git a/README.rst.license b/README.rst.license index 5963138..f69990d 100644 --- a/README.rst.license +++ b/README.rst.license @@ -1,3 +1,3 @@ SPDX-FileCopyrightText: 2017 Scott Shawcroft, written for Adafruit Industries -SPDX-FileCopyrightText: Copyright (c) 2023 Justin Myers for Adafruit Industries +SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries SPDX-License-Identifier: MIT diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py new file mode 100644 index 0000000..cc70f3f --- /dev/null +++ b/adafruit_connection_manager.py @@ -0,0 +1,300 @@ +# SPDX-FileCopyrightText: 2017 Scott Shawcroft, written for Adafruit Industries +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: MIT +""" +`adafruit_connection_manager` +================================================================================ + +A urllib3.poolmanager/urllib3.connectionpool-like library for managing sockets and connections + + +* Author(s): Justin Myers + +Implementation Notes +-------------------- + +**Software and Dependencies:** + +* Adafruit CircuitPython firmware for the supported boards: + https://circuitpython.org/downloads + +""" + +# imports + +__version__ = "0.0.0+auto.0" +__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager.git" + +import errno +import sys + +# typing + + +if not sys.implementation.name == "circuitpython": + from typing import Optional, Tuple + + from circuitpython_typing.socket import ( + CircuitPythonSocketType, + InterfaceType, + SocketpoolModuleType, + SocketType, + SSLContextType, + ) + + +# ssl and pool helpers + + +class _FakeSSLSocket: + def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None: + self._socket = socket + self._mode = tls_mode + self.settimeout = socket.settimeout + self.send = socket.send + self.recv = socket.recv + self.close = socket.close + self.recv_into = socket.recv_into + + def connect(self, address: Tuple[str, int]) -> None: + """Connect wrapper to add non-standard mode parameter""" + try: + return self._socket.connect(address, self._mode) + except RuntimeError as error: + raise OSError(errno.ENOMEM) from error + + +class _FakeSSLContext: + def __init__(self, iface: InterfaceType) -> None: + self._iface = iface + + # pylint: disable=unused-argument + def wrap_socket( + self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None + ) -> _FakeSSLSocket: + """Return the same socket""" + if hasattr(self._iface, "TLS_MODE"): + return _FakeSSLSocket(socket, self._iface.TLS_MODE) + + raise AttributeError("This radio does not support TLS/HTTPS") + + +def create_fake_ssl_context( + socket_pool: SocketpoolModuleType, iface: InterfaceType +) -> _FakeSSLContext: + """Method to return a fake SSL context for when ssl isn't available to import + + For example when using a: + + * `Adafruit Ethernet FeatherWing `_ + * `Adafruit AirLift – ESP32 WiFi Co-Processor Breakout Board + `_ + * `Adafruit AirLift FeatherWing – ESP32 WiFi Co-Processor + `_ + """ + socket_pool.set_interface(iface) + return _FakeSSLContext(iface) + + +_global_socketpool = {} +_global_ssl_contexts = {} + + +def get_radio_socketpool(radio): + """Helper to get a socket pool for common boards + + Currently supported: + + * Boards with onboard WiFi (ESP32S2, ESP32S3, Pico W, etc) + * Using the ESP32 WiFi Co-Processor (like the Adafruit AirLift) + * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing) + """ + class_name = radio.__class__.__name__ + if class_name not in _global_socketpool: + if class_name == "Radio": + import ssl # pylint: disable=import-outside-toplevel + + import socketpool # pylint: disable=import-outside-toplevel + + pool = socketpool.SocketPool(radio) + ssl_context = ssl.create_default_context() + + elif class_name == "ESP_SPIcontrol": + import adafruit_esp32spi.adafruit_esp32spi_socket as pool # pylint: disable=import-outside-toplevel + + ssl_context = create_fake_ssl_context(pool, radio) + + elif class_name == "WIZNET5K": + import adafruit_wiznet5k.adafruit_wiznet5k_socket as pool # pylint: disable=import-outside-toplevel + + # Note: SSL/TLS connections are not supported by the Wiznet5k library at this time + ssl_context = create_fake_ssl_context(pool, radio) + + else: + raise AttributeError(f"Unsupported radio class: {class_name}") + + _global_socketpool[class_name] = pool + _global_ssl_contexts[class_name] = ssl_context + + return _global_socketpool[class_name] + + +def get_radio_ssl_context(radio): + """Helper to get ssl_contexts for common boards + + Currently supported: + + * Boards with onboard WiFi (ESP32S2, ESP32S3, Pico W, etc) + * Using the ESP32 WiFi Co-Processor (like the Adafruit AirLift) + * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing) + """ + class_name = radio.__class__.__name__ + get_radio_socketpool(radio) + return _global_ssl_contexts[class_name] + + +# main class + + +class ConnectionManager: + """Connection manager for sharing open sockets (aka connections).""" + + def __init__( + self, + socket_pool: SocketpoolModuleType, + ) -> None: + self._socket_pool = socket_pool + # Hang onto open sockets so that we can reuse them. + self._available_socket = {} + self._open_sockets = {} + + def _free_sockets(self) -> None: + available_sockets = [] + for socket, free in self._available_socket.items(): + if free: + available_sockets.append(socket) + + for socket in available_sockets: + self.close_socket(socket) + + def _get_key_for_socket(self, socket): + try: + return next( + key for key, value in self._open_sockets.items() if value == socket + ) + except StopIteration: + return None + + def close_socket(self, socket: SocketType) -> None: + """Close a previously opened socket.""" + if socket not in self._open_sockets.values(): + raise RuntimeError("Socket not managed") + key = self._get_key_for_socket(socket) + socket.close() + del self._available_socket[socket] + del self._open_sockets[key] + + def free_socket(self, socket: SocketType) -> None: + """Mark a previously opened socket as available so it can be reused if needed.""" + if socket not in self._open_sockets.values(): + raise RuntimeError("Socket not managed") + self._available_socket[socket] = True + + # pylint: disable=too-many-branches,too-many-locals,too-many-statements + def get_socket( + self, + host: str, + port: int, + proto: str, + session_id: Optional[str] = None, + *, + timeout: float = 1, + is_ssl: bool = False, + ssl_context: Optional[SSLContextType] = None, + ) -> CircuitPythonSocketType: + """Get a new socket and connect""" + if session_id: + session_id = str(session_id) + key = (host, port, proto, session_id) + if key in self._open_sockets: + socket = self._open_sockets[key] + if self._available_socket[socket]: + self._available_socket[socket] = False + return socket + + raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}") + + if proto == "https:": + is_ssl = True + if is_ssl and not ssl_context: + raise AttributeError( + "ssl_context must be set before using adafruit_requests for https" + ) + + addr_info = self._socket_pool.getaddrinfo( + host, port, 0, self._socket_pool.SOCK_STREAM + )[0] + + try_count = 0 + socket = None + last_exc = None + while try_count < 2 and socket is None: + try_count += 1 + if try_count > 1: + if any( + socket + for socket, free in self._available_socket.items() + if free is True + ): + self._free_sockets() + else: + break + + try: + socket = self._socket_pool.socket(addr_info[0], addr_info[1]) + except OSError as exc: + last_exc = exc + continue + except RuntimeError as exc: + last_exc = exc + continue + + if is_ssl: + socket = ssl_context.wrap_socket(socket, server_hostname=host) + connect_host = host + else: + connect_host = addr_info[-1][0] + socket.settimeout(timeout) # socket read timeout + + try: + socket.connect((connect_host, port)) + except MemoryError as exc: + last_exc = exc + socket.close() + socket = None + except OSError as exc: + last_exc = exc + socket.close() + socket = None + + if socket is None: + raise RuntimeError(f"Error connecting socket: {last_exc}") from last_exc + + self._available_socket[socket] = False + self._open_sockets[key] = socket + return socket + + +# global helpers + + +_global_connection_manager = None # pylint: disable=invalid-name + + +def get_connection_manager(socket_pool: SocketpoolModuleType) -> None: + """Get the ConnectionManager singleton""" + global _global_connection_manager # pylint: disable=global-statement + if _global_connection_manager is None: + _global_connection_manager = ConnectionManager(socket_pool) + return _global_connection_manager diff --git a/adafruit_connectionmanager.py b/adafruit_connectionmanager.py deleted file mode 100644 index b7fd70e..0000000 --- a/adafruit_connectionmanager.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-FileCopyrightText: 2017 Scott Shawcroft, written for Adafruit Industries -# SPDX-FileCopyrightText: Copyright (c) 2023 Justin Myers for Adafruit Industries -# -# SPDX-License-Identifier: MIT -""" -`adafruit_connectionmanager` -================================================================================ - -A urllib3.poolmanager/urllib3.connectionpool-like library for managing sockets and connections - - -* Author(s): Justin Myers - -Implementation Notes --------------------- - -**Hardware:** - -.. todo:: Add links to any specific hardware product page(s), or category page(s). - Use unordered list & hyperlink rST inline format: "* `Link Text `_" - -**Software and Dependencies:** - -* Adafruit CircuitPython firmware for the supported boards: - https://circuitpython.org/downloads - -.. todo:: Uncomment or remove the Bus Device and/or the Register library dependencies - based on the library's use of either. - -# * Adafruit's Bus Device library: https://github.com/adafruit/Adafruit_CircuitPython_BusDevice -# * Adafruit's Register library: https://github.com/adafruit/Adafruit_CircuitPython_Register -""" - -# imports - -__version__ = "0.0.0+auto.0" -__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager.git" diff --git a/docs/api.rst b/docs/api.rst index 9916599..1112255 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4,5 +4,5 @@ .. If your library file(s) are nested in a directory (e.g. /adafruit_foo/foo.py) .. use this format as the module name: "adafruit_foo.foo" -.. automodule:: adafruit_connectionmanager +.. automodule:: adafruit_connection_manager :members: diff --git a/docs/api.rst.license b/docs/api.rst.license index ddc59df..95c6363 100644 --- a/docs/api.rst.license +++ b/docs/api.rst.license @@ -1,4 +1,4 @@ SPDX-FileCopyrightText: 2017 Scott Shawcroft, written for Adafruit Industries -SPDX-FileCopyrightText: Copyright (c) 2023 Justin Myers for Adafruit Industries +SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries SPDX-License-Identifier: MIT diff --git a/docs/conf.py b/docs/conf.py index 56d656c..b184b10 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,9 +4,9 @@ # # SPDX-License-Identifier: MIT +import datetime import os import sys -import datetime sys.path.insert(0, os.path.abspath("..")) @@ -50,7 +50,7 @@ # General information about the project. project = "Adafruit CircuitPython ConnectionManager Library" -creation_year = "2023" +creation_year = "2024" current_year = str(datetime.datetime.now().year) year_duration = ( current_year diff --git a/docs/examples.rst b/docs/examples.rst index 8e31fae..0e1e53b 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -1,8 +1,25 @@ -Simple test ------------- +Examples +======== -Ensure your device works with this simple test. +Below are a few examples, there may be more in the examples folder of the library -.. literalinclude:: ../examples/connectionmanager_simpletest.py - :caption: examples/connectionmanager_simpletest.py +Helper example +-------------- + +This example shows you how to use the ``adafruit_connection_manager`` helpers to help +simplify code when writing it for multiple different boards + +.. literalinclude:: ../examples/connectionmanager_helpers.py + :caption: examples/connectionmanager_helpers.py + :linenos: + +SSL Test +-------- + +This test runs across the common hosts found in the +`Adafruit Learning System Guides `_ +as well as tests created by `badssl.com `_ + +.. literalinclude:: ../examples/connectionmanager_ssltest.py + :caption: examples/connectionmanager_ssltest.py :linenos: diff --git a/docs/examples.rst.license b/docs/examples.rst.license index ddc59df..95c6363 100644 --- a/docs/examples.rst.license +++ b/docs/examples.rst.license @@ -1,4 +1,4 @@ SPDX-FileCopyrightText: 2017 Scott Shawcroft, written for Adafruit Industries -SPDX-FileCopyrightText: Copyright (c) 2023 Justin Myers for Adafruit Industries +SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries SPDX-License-Identifier: MIT diff --git a/docs/index.rst b/docs/index.rst index 78525b6..235fd37 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -24,15 +24,9 @@ Table of Contents .. toctree:: :caption: Tutorials -.. todo:: Add any Learn guide links here. If there are none, then simply delete this todo and leave - the toctree above for use later. - .. toctree:: :caption: Related Products -.. todo:: Add any product links here. If there are none, then simply delete this todo and leave - the toctree above for use later. - .. toctree:: :caption: Other Links diff --git a/docs/index.rst.license b/docs/index.rst.license index ddc59df..95c6363 100644 --- a/docs/index.rst.license +++ b/docs/index.rst.license @@ -1,4 +1,4 @@ SPDX-FileCopyrightText: 2017 Scott Shawcroft, written for Adafruit Industries -SPDX-FileCopyrightText: Copyright (c) 2023 Justin Myers for Adafruit Industries +SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries SPDX-License-Identifier: MIT diff --git a/examples/connectionmanager_helpers.py b/examples/connectionmanager_helpers.py new file mode 100644 index 0000000..36f4af6 --- /dev/null +++ b/examples/connectionmanager_helpers.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +import os + +import adafruit_requests +import wifi + +import adafruit_connection_manager + +TEXT_URL = "http://wifitest.adafruit.com/testwifi/index.html" + +wifi_ssid = os.getenv("CIRCUITPY_WIFI_SSID") +wifi_password = os.getenv("CIRCUITPY_WIFI_PASSWORD") + +radio = wifi.radio +while not radio.connected: + radio.connect(wifi_ssid, wifi_password) + +# get the pool and ssl_context from the helpers: +pool = adafruit_connection_manager.get_radio_socketpool(radio) +ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + +# get request session +requests = adafruit_requests.Session(pool, ssl_context) + +# make request +print("-" * 40) +print(f"Fetching from {TEXT_URL}") + +response = requests.get(TEXT_URL) +response_text = response.text +response.close() + +print(f"Text Response {response_text}") +print("-" * 40) diff --git a/examples/connectionmanager_simpletest.py b/examples/connectionmanager_simpletest.py deleted file mode 100644 index 8110262..0000000 --- a/examples/connectionmanager_simpletest.py +++ /dev/null @@ -1,4 +0,0 @@ -# SPDX-FileCopyrightText: 2017 Scott Shawcroft, written for Adafruit Industries -# SPDX-FileCopyrightText: Copyright (c) 2023 Justin Myers for Adafruit Industries -# -# SPDX-License-Identifier: Unlicense diff --git a/examples/connectionmanager_ssltest.py b/examples/connectionmanager_ssltest.py new file mode 100644 index 0000000..6fa707d --- /dev/null +++ b/examples/connectionmanager_ssltest.py @@ -0,0 +1,314 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +import os +import time + +import adafruit_connection_manager + +try: + import wifi + + radio = wifi.radio + onboard_wifi = True +except ImportError: + import board + import busio + from adafruit_esp32spi import adafruit_esp32spi + from digitalio import DigitalInOut + + # esp32spi pins set based on Adafruit AirLift FeatherWing + # if using a different setup, please change appropriately + spi = busio.SPI(board.SCK, board.MOSI, board.MISO) + esp32_cs = DigitalInOut(board.D13) + esp32_ready = DigitalInOut(board.D11) + esp32_reset = DigitalInOut(board.D12) + radio = adafruit_esp32spi.ESP_SPIcontrol(spi, esp32_cs, esp32_ready, esp32_reset) + onboard_wifi = False + + +# built from: +# https://github.com/adafruit/Adafruit_Learning_System_Guides +ADAFRUIT_GROUPS = [ + { + "heading": "API hosts", + "description": "These are common API hosts users hit.", + "success": "yes", + "fail": "no", + "subdomains": [ + {"host": "api.coindesk.com"}, + {"host": "api.covidtracking.com"}, + {"host": "api.developer.lifx.com"}, + {"host": "api.fitbit.com"}, + {"host": "api.github.com"}, + {"host": "api.hackaday.io"}, + {"host": "api.hackster.io"}, + {"host": "api.met.no"}, + {"host": "api.nasa.gov"}, + {"host": "api.nytimes.com"}, + {"host": "api.open-meteo.com"}, + {"host": "api.openai.com"}, + {"host": "api.openweathermap.org"}, + {"host": "api.purpleair.com"}, + {"host": "api.spacexdata.com"}, + {"host": "api.thecatapi.com"}, + {"host": "api.thingiverse.com"}, + {"host": "api.thingspeak.com"}, + {"host": "api.tidesandcurrents.noaa.gov"}, + {"host": "api.twitter.com"}, + {"host": "api.wordnik.com"}, + ], + }, + { + "heading": "Common hosts", + "description": "These are other common hosts users hit.", + "success": "yes", + "fail": "no", + "subdomains": [ + {"host": "admiraltyapi.azure-api.net"}, + {"host": "aeroapi.flightaware.com"}, + {"host": "airnowapi.org"}, + {"host": "certification.oshwa.org"}, + {"host": "certificationapi.oshwa.org"}, + {"host": "chat.openai.com"}, + {"host": "covidtracking.com"}, + {"host": "discord.com"}, + {"host": "enviro.epa.gov"}, + {"host": "flightaware.com"}, + {"host": "hosted.weblate.org"}, + {"host": "id.twitch.tv"}, + {"host": "io.adafruit.com"}, + {"host": "jwst.nasa.gov"}, + {"host": "management.azure.com"}, + {"host": "na1.api.riotgames.com"}, + {"host": "oauth2.googleapis.com"}, + {"host": "opensky-network.org"}, + {"host": "opentdb.com"}, + {"host": "raw.githubusercontent.com"}, + {"host": "site.api.espn.com"}, + {"host": "spreadsheets.google.com"}, + {"host": "twitrss.me"}, + {"host": "www.adafruit.com"}, + {"host": "www.alphavantage.co"}, + {"host": "www.googleapis.com"}, + {"host": "www.nhc.noaa.gov"}, + {"host": "www.reddit.com"}, + {"host": "youtube.googleapis.com"}, + ], + }, + { + "heading": "Known problem hosts", + "description": "These are hosts we have run into problems in the past.", + "success": "yes", + "fail": "no", + "subdomains": [ + {"host": "valid-isrgrootx2.letsencrypt.org"}, + {"host": "openaccess-api.clevelandart.org"}, + ], + }, +] + +# pulled from: +# https://github.com/chromium/badssl.com/blob/master/domains/misc/badssl.com/dashboard/sets.js +BADSSL_GROUPS = [ + { + "heading": "Certificate Validation (High Risk)", + "description": ( + "If your browser connects to one of these sites, it could be very easy for an attacker " + "to see and modify everything on web sites that you visit." + ), + "success": "no", + "fail": "yes", + "subdomains": [ + {"subdomain": "expired"}, + {"subdomain": "wrong.host"}, + {"subdomain": "self-signed"}, + {"subdomain": "untrusted-root"}, + ], + }, + { + "heading": "Interception Certificates (High Risk)", + "description": ( + "If your browser connects to one of these sites, it could be very easy for an attacker " + "to see and modify everything on web sites that you visit. This may be due to " + "interception software installed on your device." + ), + "success": "no", + "fail": "yes", + "subdomains": [ + {"subdomain": "superfish"}, + {"subdomain": "edellroot"}, + {"subdomain": "dsdtestprovider"}, + {"subdomain": "preact-cli"}, + {"subdomain": "webpack-dev-server"}, + ], + }, + { + "heading": "Broken Cryptography (Medium Risk)", + "description": ( + "If your browser connects to one of these sites, an attacker with enough resources may " + "be able to see and/or modify everything on web sites that you visit. This is because " + "your browser supports connections settings that are outdated and known to have " + "significant security flaws." + ), + "success": "no", + "fail": "yes", + "subdomains": [ + {"subdomain": "rc4"}, + {"subdomain": "rc4-md5"}, + {"subdomain": "dh480"}, + {"subdomain": "dh512"}, + {"subdomain": "dh1024"}, + {"subdomain": "null"}, + ], + }, + { + "heading": "Legacy Cryptography (Moderate Risk)", + "description": ( + "If your browser connects to one of these sites, your web traffic is probably safe " + "from attackers in the near future. However, your connections to some sites might " + "not be using the strongest possible security. Your browser may use these settings in " + "order to connect to some older sites." + ), + "success": "maybe", + "fail": "yes", + "subdomains": [ + {"subdomain": "tls-v1-0", "port": 1010}, + {"subdomain": "tls-v1-1", "port": 1011}, + {"subdomain": "cbc"}, + {"subdomain": "3des"}, + {"subdomain": "dh2048"}, + ], + }, + { + "heading": "Domain Security Policies", + "description": ( + "These are special tests for some specific browsers. These tests may be able to tell " + "whether your browser uses advanced domain security policy mechanisms (HSTS, HPKP, SCT" + ") to detect illegitimate certificates." + ), + "success": "maybe", + "fail": "yes", + "subdomains": [ + {"subdomain": "revoked"}, + {"subdomain": "pinning-test"}, + {"subdomain": "no-sct"}, + ], + }, + { + "heading": "Secure (Uncommon)", + "description": ( + "These settings are secure. However, they are less common and even if your browser " + "doesn't support them you probably won't have issues with most sites." + ), + "success": "yes", + "fail": "maybe", + "subdomains": [ + {"subdomain": "1000-sans"}, + {"subdomain": "10000-sans"}, + {"subdomain": "sha384"}, + {"subdomain": "sha512"}, + {"subdomain": "rsa8192"}, + {"subdomain": "no-subject"}, + {"subdomain": "no-common-name"}, + {"subdomain": "incomplete-chain"}, + ], + }, + { + "heading": "Secure (Common)", + "description": ( + "These settings are secure and commonly used by sites. Your browser will need to " + "support most of these in order to connect to sites securely." + ), + "success": "yes", + "fail": "no", + "subdomains": [ + {"subdomain": "tls-v1-2", "port": 1012}, + {"subdomain": "sha256"}, + {"subdomain": "rsa2048"}, + {"subdomain": "ecc256"}, + {"subdomain": "ecc384"}, + {"subdomain": "extended-validation"}, + {"subdomain": "mozilla-modern"}, + ], + }, +] + +COMMON_FAILURE_CODES = [ + "Expected 01 but got 00", # AirLift + "Failed SSL handshake", # Espressif + "MBEDTLS_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANG", # mbedtls + "MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE", # mbedtls + "MBEDTLS_ERR_X509_CERT_VERIFY_FAILED", # mbedtls + "MBEDTLS_ERR_X509_FATAL_ERROR", # mbedtls +] + + +pool = adafruit_connection_manager.get_radio_socketpool(radio) +ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) +connection_manager = adafruit_connection_manager.get_connection_manager(pool) + +wifi_ssid = os.getenv("CIRCUITPY_WIFI_SSID") +wifi_password = os.getenv("CIRCUITPY_WIFI_PASSWORD") + +if onboard_wifi: + while not radio.connected: + radio.connect(wifi_ssid, wifi_password) +else: + while not radio.is_connected: + try: + radio.connect_AP(wifi_ssid, wifi_password) + except OSError as os_exc: + print(f"could not connect to AP, retrying: {os_exc}") + continue + + +def common_failure(exc): + text_value = str(exc) + for common_failures_code in COMMON_FAILURE_CODES: + if common_failures_code in text_value: + return True + return False + + +def check_group(groups, group_name): + print(f"\nRunning {group_name}") + for group in groups: + print(f'\n - {group["heading"]}') + success = group["success"] + fail = group["fail"] + for subdomain in group["subdomains"]: + if "host" in subdomain: + host = subdomain["host"] + else: + host = f'{subdomain["subdomain"]}.badssl.com' + port = subdomain.get("port", 443) + exc = None + start_time = time.monotonic() + try: + socket = connection_manager.get_socket( + host, + port, + "https:", + is_ssl=True, + ssl_context=ssl_context, + timeout=10, + ) + connection_manager.close_socket(socket) + except RuntimeError as e: + exc = e + duration = time.monotonic() - start_time + + if fail == "yes" and exc and common_failure(exc): + result = "passed" + elif success == "yes" and exc is None: + result = "passed" + else: + result = f"error - success:{success}, fail:{fail}, exc:{exc}" + + print(f" - {host}:{port} took {duration:.2f} seconds | {result}") + + +check_group(ADAFRUIT_GROUPS, "Adafruit") +check_group(BADSSL_GROUPS, "BadSSL") diff --git a/pyproject.toml b/pyproject.toml index 9a52e4c..cb04b0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: 2022 Alec Delaney, written for Adafruit Industries -# SPDX-FileCopyrightText: Copyright (c) 2023 Justin Myers for Adafruit Industries +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries # # SPDX-License-Identifier: MIT @@ -42,7 +42,7 @@ dynamic = ["dependencies", "optional-dependencies"] [tool.setuptools] # TODO: IF LIBRARY FILES ARE A PACKAGE FOLDER, # CHANGE `py_modules = ['...']` TO `packages = ['...']` -py-modules = ["adafruit_connectionmanager"] +py-modules = ["adafruit_connection_manager"] [tool.setuptools.dynamic] dependencies = {file = ["requirements.txt"]} diff --git a/requirements.txt b/requirements.txt index c3539cf..e64002a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: 2017 Scott Shawcroft, written for Adafruit Industries -# SPDX-FileCopyrightText: Copyright (c) 2023 Justin Myers for Adafruit Industries +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries # # SPDX-License-Identifier: MIT diff --git a/tests/close_socket_test.py b/tests/close_socket_test.py new file mode 100644 index 0000000..957cb94 --- /dev/null +++ b/tests/close_socket_test.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" Close Socket Tests """ + +import mocket +import pytest + +import adafruit_connection_manager + + +def test_close_socket(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_pool.socket.return_value = mock_socket_1 + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # validate socket is tracked + socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + key = (mocket.MOCK_HOST_1, 80, "http:", None) + assert socket == mock_socket_1 + assert socket in connection_manager._available_socket + assert key in connection_manager._open_sockets + + # validate socket is no longer tracked + connection_manager.close_socket(socket) + assert socket not in connection_manager._available_socket + assert key not in connection_manager._open_sockets + + +def test_close_socket_not_managed(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # validate not managed socket errors + with pytest.raises(RuntimeError) as context: + connection_manager.close_socket(mock_socket_1) + assert "Socket not managed" in str(context) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..2d9bb0a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" Setup Tests """ + +import sys + +import mocket + + +# pylint: disable=unused-argument +def set_interface(iface): + """Helper to set the global internet interface""" + + +socketpool_module = type(sys)("socketpool") +socketpool_module.SocketPool = mocket.MocketPool +sys.modules["socketpool"] = socketpool_module + +esp32spi_module = type(sys)("adafruit_esp32spi") +esp32spi_socket_module = type(sys)("adafruit_esp32spi_socket") +esp32spi_socket_module.set_interface = set_interface +sys.modules["adafruit_esp32spi"] = esp32spi_module +sys.modules["adafruit_esp32spi.adafruit_esp32spi_socket"] = esp32spi_socket_module + +wiznet5k_module = type(sys)("adafruit_wiznet5k") +wiznet5k_socket_module = type(sys)("adafruit_wiznet5k_socket") +wiznet5k_socket_module.set_interface = set_interface +sys.modules["adafruit_wiznet5k"] = wiznet5k_module +sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] = wiznet5k_socket_module diff --git a/tests/fake_ssl_context_test.py b/tests/fake_ssl_context_test.py new file mode 100644 index 0000000..fc566ea --- /dev/null +++ b/tests/fake_ssl_context_test.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" FakeSLLSocket Tests """ + +import mocket +import pytest + +import adafruit_connection_manager + + +def test_connect_https(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_pool.socket.return_value = mock_socket_1 + + radio = mocket.MockRadio.ESP_SPIcontrol() + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # verify a HTTPS call for a board without built in WiFi gets a _FakeSSLSocket + socket = connection_manager.get_socket( + mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context + ) + assert socket != mock_socket_1 + assert socket._socket == mock_socket_1 + assert isinstance(socket, adafruit_connection_manager._FakeSSLSocket) + + +def test_connect_https_not_supported(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_pool.socket.return_value = mock_socket_1 + + radio = mocket.MockRadio.WIZNET5K() + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # verify a HTTPS call for a board without built in WiFi and SSL support errors + with pytest.raises(AttributeError) as context: + connection_manager.get_socket( + mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context + ) + assert "This radio does not support TLS/HTTPS" in str(context) diff --git a/tests/free_socket_test.py b/tests/free_socket_test.py new file mode 100644 index 0000000..93f34eb --- /dev/null +++ b/tests/free_socket_test.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" Free Socket Tests """ + +import mocket +import pytest + +import adafruit_connection_manager + + +def test_free_socket(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_pool.socket.return_value = mock_socket_1 + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # validate socket is tracked and not available + socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + key = (mocket.MOCK_HOST_1, 80, "http:", None) + assert socket == mock_socket_1 + assert socket in connection_manager._available_socket + assert connection_manager._available_socket[socket] is False + assert key in connection_manager._open_sockets + + # validate socket is tracked and is available + connection_manager.free_socket(socket) + assert socket in connection_manager._available_socket + assert connection_manager._available_socket[socket] is True + assert key in connection_manager._open_sockets + + +def test_free_socket_not_managed(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # validate not managed socket errors + with pytest.raises(RuntimeError) as context: + connection_manager.free_socket(mock_socket_1) + assert "Socket not managed" in str(context) + + +def test_free_sockets(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_socket_2 = mocket.Mocket() + mock_pool.socket.side_effect = [ + mock_socket_1, + mock_socket_2, + ] + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # validate socket is tracked and not available + socket_1 = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert socket_1 == mock_socket_1 + assert socket_1 in connection_manager._available_socket + assert connection_manager._available_socket[socket_1] is False + + socket_2 = connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") + assert socket_2 == mock_socket_2 + + # validate socket is tracked and is available + connection_manager.free_socket(socket_1) + assert socket_1 in connection_manager._available_socket + assert connection_manager._available_socket[socket_1] is True + + # validate socket is no longer tracked + connection_manager._free_sockets() + assert socket_1 not in connection_manager._available_socket + assert socket_2 in connection_manager._available_socket + mock_socket_1.close.assert_called_once() + + +def test_get_key_for_socket(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_pool.socket.return_value = mock_socket_1 + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # validate tracked socket has correct key + socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + key = (mocket.MOCK_HOST_1, 80, "http:", None) + assert connection_manager._get_key_for_socket(socket) == key + + +def test_get_key_for_socket_not_managed(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # validate untracked socket has no key + assert connection_manager._get_key_for_socket(mock_socket_1) is None diff --git a/tests/get_connection_manager_test.py b/tests/get_connection_manager_test.py new file mode 100644 index 0000000..0efdbfd --- /dev/null +++ b/tests/get_connection_manager_test.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" Get Connection Manager Tests """ + +import mocket + +import adafruit_connection_manager + + +def test_get_connection_manager(): + mock_pool = mocket.MocketPool() + + connection_manager_1 = adafruit_connection_manager.get_connection_manager(mock_pool) + connection_manager_2 = adafruit_connection_manager.get_connection_manager(mock_pool) + + assert connection_manager_1 == connection_manager_2 diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py new file mode 100644 index 0000000..ea80f7e --- /dev/null +++ b/tests/get_radio_test.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" Get socketpool and ssl_context Tests """ + +import ssl + +import mocket +import pytest + +import adafruit_connection_manager + + +def test_get_radio_socketpool_wifi(): + radio = mocket.MockRadio.Radio() + socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) + assert isinstance(socket_pool, mocket.MocketPool) + + +def test_get_radio_socketpool_esp32spi(): + radio = mocket.MockRadio.ESP_SPIcontrol() + socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) + assert socket_pool.__name__ == "adafruit_esp32spi_socket" + + +def test_get_radio_socketpool_wiznet5k(): + radio = mocket.MockRadio.WIZNET5K() + socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) + assert socket_pool.__name__ == "adafruit_wiznet5k_socket" + + +def test_get_radio_socketpool_unsupported(): + radio = mocket.MockRadio.Unsupported() + with pytest.raises(AttributeError) as context: + adafruit_connection_manager.get_radio_socketpool(radio) + assert "Unsupported radio class" in str(context) + + +def test_get_radio_socketpool_returns_same_one(): + radio = mocket.MockRadio.Radio() + socket_pool_1 = adafruit_connection_manager.get_radio_socketpool(radio) + socket_pool_2 = adafruit_connection_manager.get_radio_socketpool(radio) + assert socket_pool_1 == socket_pool_2 + + +def test_get_radio_ssl_context_wifi(): + radio = mocket.MockRadio.Radio() + ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_contexts, ssl.SSLContext) + + +def test_get_radio_ssl_context_esp32spi(): + radio = mocket.MockRadio.ESP_SPIcontrol() + ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_contexts, adafruit_connection_manager._FakeSSLContext) + + +def test_get_radio_ssl_context_wiznet5k(): + radio = mocket.MockRadio.WIZNET5K() + ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_contexts, adafruit_connection_manager._FakeSSLContext) + + +def test_get_radio_ssl_context_unsupported(): + radio = mocket.MockRadio.Unsupported() + with pytest.raises(AttributeError) as context: + adafruit_connection_manager.get_radio_ssl_context(radio) + assert "Unsupported radio class" in str(context) + + +def test_get_radio_ssl_context_returns_same_one(): + radio = mocket.MockRadio.Radio() + ssl_contexts_1 = adafruit_connection_manager.get_radio_ssl_context(radio) + ssl_contexts_2 = adafruit_connection_manager.get_radio_ssl_context(radio) + assert ssl_contexts_1 == ssl_contexts_2 diff --git a/tests/get_socket_test.py b/tests/get_socket_test.py new file mode 100644 index 0000000..cee223d --- /dev/null +++ b/tests/get_socket_test.py @@ -0,0 +1,251 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" Get Socket Tests """ + +from unittest import mock + +import mocket +import pytest + +import adafruit_connection_manager + + +def test_get_socket(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_socket_2 = mocket.Mocket() + mock_pool.socket.side_effect = [ + mock_socket_1, + mock_socket_2, + ] + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # get socket + socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert socket == mock_socket_1 + + +def test_get_socket_different_session(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_socket_2 = mocket.Mocket() + mock_pool.socket.side_effect = [ + mock_socket_1, + mock_socket_2, + ] + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # get socket + socket = connection_manager.get_socket( + mocket.MOCK_HOST_1, 80, "http:", session_id="1" + ) + assert socket == mock_socket_1 + + # get socket on different session + socket = connection_manager.get_socket( + mocket.MOCK_HOST_1, 80, "http:", session_id="2" + ) + assert socket == mock_socket_2 + + +def test_get_socket_flagged_free(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_socket_2 = mocket.Mocket() + mock_pool.socket.side_effect = [ + mock_socket_1, + mock_socket_2, + ] + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # get a socket and then mark as free + socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert socket == mock_socket_1 + connection_manager.free_socket(socket) + + # get a socket for the same host, should be the same one + socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert socket == mock_socket_1 + + +def test_get_socket_not_flagged_free(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_socket_2 = mocket.Mocket() + mock_pool.socket.side_effect = [ + mock_socket_1, + mock_socket_2, + ] + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # get a socket but don't mark as free + socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert socket == mock_socket_1 + + # get a socket for the same host, should be a different one + with pytest.raises(RuntimeError) as context: + socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "Socket already connected" in str(context) + + +def test_get_socket_os_error(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_pool.socket.side_effect = [ + OSError("OSError"), + mock_socket_1, + ] + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # try to get a socket that returns a OSError + with pytest.raises(RuntimeError) as context: + connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "Error connecting socket: OSError" in str(context) + + +def test_get_socket_runtime_error(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_pool.socket.side_effect = [ + RuntimeError("RuntimeError"), + mock_socket_1, + ] + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # try to get a socket that returns a RuntimeError + with pytest.raises(RuntimeError) as context: + connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "Error connecting socket: RuntimeError" in str(context) + + +def test_get_socket_connect_memory_error(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_socket_2 = mocket.Mocket() + mock_pool.socket.side_effect = [ + mock_socket_1, + mock_socket_2, + ] + mock_socket_1.connect.side_effect = MemoryError("MemoryError") + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # try to connect a socket that returns a MemoryError + with pytest.raises(RuntimeError) as context: + connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "Error connecting socket: MemoryError" in str(context) + + +def test_get_socket_connect_os_error(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_socket_2 = mocket.Mocket() + mock_pool.socket.side_effect = [ + mock_socket_1, + mock_socket_2, + ] + mock_socket_1.connect.side_effect = OSError("OSError") + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # try to connect a socket that returns a OSError + with pytest.raises(RuntimeError) as context: + connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "Error connecting socket: OSError" in str(context) + + +def test_get_socket_runtime_error_ties_again_at_least_one_free(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_socket_2 = mocket.Mocket() + mock_pool.socket.side_effect = [ + mock_socket_1, + RuntimeError(), + mock_socket_2, + ] + + free_sockets_mock = mock.Mock() + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + connection_manager._free_sockets = free_sockets_mock + + # get a socket and then mark as free + socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert socket == mock_socket_1 + connection_manager.free_socket(socket) + free_sockets_mock.assert_not_called() + + # try to get a socket that returns a RuntimeError and at least one is flagged as free + socket = connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") + assert socket == mock_socket_2 + free_sockets_mock.assert_called_once() + + +def test_get_socket_runtime_error_ties_again_only_once(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_socket_2 = mocket.Mocket() + mock_pool.socket.side_effect = [ + mock_socket_1, + RuntimeError("error 1"), + RuntimeError("error 2"), + RuntimeError("error 3"), + mock_socket_2, + ] + + free_sockets_mock = mock.Mock() + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + connection_manager._free_sockets = free_sockets_mock + + # get a socket and then mark as free + socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert socket == mock_socket_1 + connection_manager.free_socket(socket) + free_sockets_mock.assert_not_called() + + # try to get a socket that returns a RuntimeError twice + with pytest.raises(RuntimeError) as context: + connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") + assert "Error connecting socket: error 2" in str(context) + free_sockets_mock.assert_called_once() + + +def test_fake_ssl_context_connect(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_pool.socket.return_value = mock_socket_1 + + radio = mocket.MockRadio.ESP_SPIcontrol() + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # verify a HTTPS call gets a _FakeSSLSocket + socket = connection_manager.get_socket( + mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context + ) + assert socket != mock_socket_1 + socket._socket.connect.assert_called_once() + + +def test_fake_ssl_context_connect_error(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_pool.socket.return_value = mock_socket_1 + mock_socket_1.connect.side_effect = RuntimeError("RuntimeError ") + + radio = mocket.MockRadio.ESP_SPIcontrol() + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + with pytest.raises(RuntimeError) as context: + connection_manager.get_socket( + mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context + ) + assert "Error connecting socket: 12" in str(context) diff --git a/tests/mocket.py b/tests/mocket.py new file mode 100644 index 0000000..6740a1a --- /dev/null +++ b/tests/mocket.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: 2021 ladyada for Adafruit Industries +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" Mock Socket """ + +from unittest import mock + +MOCK_POOL_IP = "10.10.10.10" +MOCK_HOST_1 = "wifitest.adafruit.com" +MOCK_HOST_2 = "wifitest2.adafruit.com" + + +class MocketPool: # pylint: disable=too-few-public-methods + """Mock SocketPool""" + + SOCK_STREAM = 0 + + # pylint: disable=unused-argument + def __init__(self, radio=None): + self.getaddrinfo = mock.Mock() + self.getaddrinfo.return_value = ((None, None, None, None, (MOCK_POOL_IP, 80)),) + self.socket = mock.Mock() + + +class Mocket: # pylint: disable=too-few-public-methods + """Mock Socket""" + + def __init__(self, response=None): + self.settimeout = mock.Mock() + self.close = mock.Mock() + self.connect = mock.Mock() + self.send = mock.Mock(side_effect=self._send) + self.readline = mock.Mock(side_effect=self._readline) + self.recv = mock.Mock(side_effect=self._recv) + self.recv_into = mock.Mock(side_effect=self._recv_into) + self._response = response + self._position = 0 + self.fail_next_send = False + + def _send(self, data): + if self.fail_next_send: + self.fail_next_send = False + return 0 + return len(data) + + def _readline(self): + i = self._response.find(b"\r\n", self._position) + response = self._response[self._position : i + 2] + self._position = i + 2 + return response + + def _recv(self, count): + end = self._position + count + response = self._response[self._position : end] + self._position = end + return response + + def _recv_into(self, buf, nbytes=0): + assert isinstance(nbytes, int) and nbytes >= 0 + read = nbytes if nbytes > 0 else len(buf) + remaining = len(self._response) - self._position + read = min(read, remaining) + end = self._position + read + buf[:read] = self._response[self._position : end] + self._position = end + return read + + +class SSLContext: # pylint: disable=too-few-public-methods + """Mock SSL Context""" + + def __init__(self): + self.wrap_socket = mock.Mock(side_effect=self._wrap_socket) + + def _wrap_socket( + self, sock, server_hostname=None + ): # pylint: disable=no-self-use,unused-argument + return sock + + +# pylint: disable=too-few-public-methods +class MockRadio: + class Radio: + pass + + class ESP_SPIcontrol: + TLS_MODE = 2 + + class WIZNET5K: + pass + + class Unsupported: + pass diff --git a/tests/protocol_test.py b/tests/protocol_test.py new file mode 100644 index 0000000..98b5296 --- /dev/null +++ b/tests/protocol_test.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: 2021 ladyada for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" Protocol Tests """ + +import mocket +import pytest + +import adafruit_connection_manager + + +def test_get_https_no_ssl(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_pool.socket.return_value = mock_socket_1 + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # verify not sending in a SSL context for a HTTPS call errors + with pytest.raises(AttributeError) as context: + connection_manager.get_socket(mocket.MOCK_HOST_1, 443, "https:") + assert "ssl_context must be set" in str(context) + + +def test_connect_https(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_pool.socket.return_value = mock_socket_1 + + mock_ssl_context = mocket.SSLContext() + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # verify a HTTPS call changes the port to 443 + connection_manager.get_socket( + mocket.MOCK_HOST_1, 443, "https:", ssl_context=mock_ssl_context + ) + mock_socket_1.connect.assert_called_once_with((mocket.MOCK_HOST_1, 443)) + mock_ssl_context.wrap_socket.assert_called_once() + + +def test_connect_http(): + mock_pool = mocket.MocketPool() + mock_socket_1 = mocket.Mocket() + mock_pool.socket.return_value = mock_socket_1 + + connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + + # verify a HTTP call does not change the port to 443 + connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + mock_socket_1.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80)) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..74ae4fe --- /dev/null +++ b/tox.ini @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: MIT + +[tox] +envlist = py311 + +[testenv] +description = run tests +deps = + pytest==7.4.3 +commands = pytest + +[testenv:coverage] +description = run coverage +deps = + pytest==7.4.3 + pytest-cov==4.1.0 +package = editable +commands = + coverage run --source=. --omit=tests/* --branch {posargs} -m pytest + coverage report + coverage html + +[testenv:lint] +description = run linters +deps = + pre-commit==3.6.0 +skip_install = true +commands = pre-commit run {posargs} + +[testenv:docs] +description = build docs +deps = + -r requirements.txt + -r docs/requirements.txt +skip_install = true +commands = sphinx-build -E -W -b html docs/. _build/html