diff --git a/pyhon/connection/mqtt.py b/pyhon/connection/mqtt.py index d6a04da..e1809ce 100644 --- a/pyhon/connection/mqtt.py +++ b/pyhon/connection/mqtt.py @@ -1,29 +1,40 @@ -import functools +import asyncio +from functools import cached_property, partial import json import logging +import pprint import secrets import ssl -from typing import TYPE_CHECKING +from typing import Any, TYPE_CHECKING, cast from urllib.parse import urlencode -from paho.mqtt.client import Client, MQTTv5 - +from aiomqtt import Client as aiomqttClient, MqttError, ProtocolVersion, Topic from pyhon import const if TYPE_CHECKING: - from paho.mqtt.client import MQTTMessage, _UserData, ReasonCodes, Properties - + from aiomqtt import Message from pyhon import Hon, HonAPI from pyhon.appliance import HonAppliance _LOGGER = logging.getLogger(__name__) +class _Payload(dict[Any, Any]): + def __str__(self) -> str: + return pprint.pformat(self) + + +class Client(aiomqttClient): + def set_username(self, username: str) -> None: + self._client.username_pw_set(username=username) + + class MQTTClient: - def __init__(self, hon: "Hon", mobile_id: str) -> None: - self._client: Client | None = None + def __init__(self, hon: "Hon", mobile_id: str = const.MOBILE_ID) -> None: + self._task: asyncio.Task[None] | None = None self._hon = hon - self._mobile_id = mobile_id or const.MOBILE_ID + self._mobile_id = mobile_id + self._connected_event = asyncio.Event() @property def _appliances(self) -> list["HonAppliance"]: @@ -33,105 +44,93 @@ class MQTTClient: def _api(self) -> "HonAPI": return self._hon.api - @property - def client(self) -> Client: - if self._client is not None: - return self._client - raise AttributeError("Client is not set") - async def create(self) -> "MQTTClient": - await self._start() + self._task = asyncio.create_task(self.loop()) + await self._connected_event.wait() return self - async def _start(self) -> None: - self._client = Client( - client_id=f"{self._mobile_id}_{secrets.token_hex(8)}", - protocol=MQTTv5, - reconnect_on_failure=True, - ) + @cached_property + def _subscription_handlers(self) -> dict[Topic, partial[None]]: - ssl_context = ssl.create_default_context() - ssl_context.set_alpn_protocols([const.ALPN_PROTOCOL]) + handlers = {} - self.client.tls_set_context(ssl_context) - self.client.enable_logger(_LOGGER) - self.client.on_connect = self._subscribe_appliances - - query_params = urlencode( - { - "x-amz-customauthorizer-name": const.AWS_AUTHORIZER, - "x-amz-customauthorizer-signature": await self._api.load_aws_token(), - "token": self._api.auth.id_token, - } - ) - - self.client.username_pw_set("?" + query_params) - - self.client.connect_async(const.AWS_ENDPOINT, 443) - self.client.loop_start() - - def _subscribe_appliances( - self, - client: Client, - userdata: "_UserData", - flags: dict[str, int], - rc: "ReasonCodes", - properties: "Properties|None", - ) -> None: - del client, userdata, flags, rc, properties for appliance in self._appliances: - self._subscribe(appliance) - def _appliance_status_callback( - self, - appliance: "HonAppliance", - client: Client, - userdata: "_UserData", - message: "MQTTMessage", - ) -> None: - del client, userdata - payload = json.loads(message.payload) + handler_protos = { + "appliancestatus": partial(self._status_handler, appliance), + "disconnected": partial(self._connection_handler, appliance, False), + "connected": partial(self._connection_handler, appliance, True), + } + + for topic in appliance.info.get("topics", {}).get("subscribe", []): + topic_parts = topic.split("/") + for topic_part, callback in handler_protos.items(): + if topic_part in topic_parts: + handlers[Topic(topic)] = callback + + return handlers + + async def _get_mqtt_username(self) -> str: + query_params = { + "x-amz-customauthorizer-name": const.AWS_AUTHORIZER, + "x-amz-customauthorizer-signature": await self._api.load_aws_token(), + "token": self._api.auth.id_token, + } + return "?" + urlencode(query_params) + + @staticmethod + def _status_handler(appliance: "HonAppliance", message: "Message") -> None: + payload = _Payload(json.loads(cast(str | bytes | bytearray, message.payload))) for parameter in payload["parameters"]: appliance.attributes["parameters"][parameter["parName"]].update(parameter) appliance.sync_params_to_command("settings") - self._hon.notify() + _LOGGER.debug("On topic '%s' received: \n %s", message.topic, payload) - def _appliance_disconnected_callback( - self, - appliance: "HonAppliance", - client: Client, - userdata: "_UserData", - message: "MQTTMessage", + @staticmethod + def _connection_handler( + appliance: "HonAppliance", connection_status: bool, __message: "Message" ) -> None: - del client, userdata, message - appliance.connection = False + appliance.connection = connection_status - self._hon.notify() + async def loop(self) -> None: + delay_min, delay_max = 5, 120 - def _appliance_connected_callback( - self, - appliance: "HonAppliance", - client: Client, - userdata: "_UserData", - message: "MQTTMessage", - ) -> None: - del client, userdata, message - appliance.connection = True + tls_context = ssl.create_default_context() + tls_context.set_alpn_protocols([const.ALPN_PROTOCOL]) - self._hon.notify() + client = Client( + hostname=const.AWS_ENDPOINT, + port=const.AWS_PORT, + identifier=f"{self._mobile_id}_{secrets.token_hex(8)}", + protocol=ProtocolVersion.V5, + logger=logging.getLogger(f"{__name__}.paho"), + tls_context=tls_context, + ) - def _subscribe(self, appliance: "HonAppliance") -> None: - topic_part_to_callback_mapping = { - "appliancestatus": self._appliance_status_callback, - "disconnected": self._appliance_disconnected_callback, - "connected": self._appliance_connected_callback, - } - for topic in appliance.info.get("topics", {}).get("subscribe", []): - for topic_part, callback in topic_part_to_callback_mapping.items(): - if topic_part in topic: - self.client.message_callback_add( - topic, functools.partial(callback, appliance) - ) - self.client.subscribe(topic) - _LOGGER.info("Subscribed to topic %s", topic) + reconnect_interval = delay_min + + while True: + client.set_username(await self._get_mqtt_username()) + try: + async with client: + self._connected_event.set() + reconnect_interval = delay_min + + for topic in self._subscription_handlers: + await client.subscribe(str(topic)) + + async for message in client.messages: + handler = self._subscription_handlers[message.topic] + + handler(message) + + self._hon.notify() + except MqttError: + self._connected_event.clear() + _LOGGER.warning( + "Connection to MQTT broker lost. Reconnecting in %s seconds", + reconnect_interval, + ) + await asyncio.sleep(reconnect_interval) + reconnect_interval = min(reconnect_interval * 2, delay_max) diff --git a/pyhon/const.py b/pyhon/const.py index 8cc8c7f..727dd16 100644 --- a/pyhon/const.py +++ b/pyhon/const.py @@ -3,6 +3,7 @@ AUTH_API = "https://account2.hon-smarthome.com" API_URL = "https://api-iot.he.services" API_KEY = "GRCqFhC6Gk@ikWXm1RmnSmX1cm,MxY-configuration" AWS_ENDPOINT = "a30f6tqw0oh1x0-ats.iot.eu-west-1.amazonaws.com" +AWS_PORT = 443 AWS_AUTHORIZER = "candy-iot-authorizer" APP = "hon" CLIENT_ID = ( diff --git a/requirements.txt b/requirements.txt index b40476a..f909246 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ aiohttp>=3.8.6 yarl>=1.8 typing-extensions>=4.8 -paho-mqtt==1.6.1 \ No newline at end of file +aiomqtt==2.0.1 \ No newline at end of file diff --git a/requirements_dev.txt b/requirements_dev.txt index 991f88d..09b0696 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -2,5 +2,4 @@ black>=22.12 flake8>=6.0 mypy>=0.991 pylint>=2.15 -setuptools>=62.3 -types-paho-mqtt \ No newline at end of file +setuptools>=62.3 \ No newline at end of file diff --git a/setup.py b/setup.py index 83610d9..8b4b3d6 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,9 @@ from setuptools import setup, find_packages with open("README.md", "r", encoding="utf-8") as f: long_description = f.read() +with open("requirements.txt", "r", encoding="utf-8") as f: + install_requires = f.read().splitlines() + setup( name="pyhOn", version="0.17.5", @@ -21,12 +24,7 @@ setup( packages=find_packages(), include_package_data=True, python_requires=">=3.10", - install_requires=[ - "aiohttp>=3.8.6", - "typing-extensions>=4.8", - "yarl>=1.8", - "paho-mqtt==1.6.1", - ], + install_requires=install_requires, classifiers=[ "Development Status :: 4 - Beta", "Environment :: Console",