Migrate to aiomqtt

Refresh AWS tokens on connect after expiration
This commit is contained in:
Kuba Sawulski 2024-08-13 18:41:12 +02:00
parent d15976d9f7
commit f9602a2b1c
5 changed files with 100 additions and 103 deletions

View file

@ -1,29 +1,40 @@
import functools import asyncio
from functools import cached_property, partial
import json import json
import logging import logging
import pprint
import secrets import secrets
import ssl import ssl
from typing import TYPE_CHECKING from typing import Any, TYPE_CHECKING, cast
from urllib.parse import urlencode from urllib.parse import urlencode
from paho.mqtt.client import Client, MQTTv5 from aiomqtt import Client as aiomqttClient, MqttError, ProtocolVersion, Topic
from pyhon import const from pyhon import const
if TYPE_CHECKING: if TYPE_CHECKING:
from paho.mqtt.client import MQTTMessage, _UserData, ReasonCodes, Properties from aiomqtt import Message
from pyhon import Hon, HonAPI from pyhon import Hon, HonAPI
from pyhon.appliance import HonAppliance from pyhon.appliance import HonAppliance
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class _Payload(dict[Any, Any]):
def __str__(self) -> str:
return pprint.pformat(self)
class Client(aiomqttClient):
def set_username(self, username: str) -> None:
self._client.username_pw_set(username=username)
class MQTTClient: class MQTTClient:
def __init__(self, hon: "Hon", mobile_id: str) -> None: def __init__(self, hon: "Hon", mobile_id: str = const.MOBILE_ID) -> None:
self._client: Client | None = None self._task: asyncio.Task[None] | None = None
self._hon = hon self._hon = hon
self._mobile_id = mobile_id or const.MOBILE_ID self._mobile_id = mobile_id
self._connected_event = asyncio.Event()
@property @property
def _appliances(self) -> list["HonAppliance"]: def _appliances(self) -> list["HonAppliance"]:
@ -33,105 +44,93 @@ class MQTTClient:
def _api(self) -> "HonAPI": def _api(self) -> "HonAPI":
return self._hon.api return self._hon.api
@property
def client(self) -> Client:
if self._client is not None:
return self._client
raise AttributeError("Client is not set")
async def create(self) -> "MQTTClient": async def create(self) -> "MQTTClient":
await self._start() self._task = asyncio.create_task(self.loop())
await self._connected_event.wait()
return self return self
async def _start(self) -> None: @cached_property
self._client = Client( def _subscription_handlers(self) -> dict[Topic, partial[None]]:
client_id=f"{self._mobile_id}_{secrets.token_hex(8)}",
protocol=MQTTv5,
reconnect_on_failure=True,
)
ssl_context = ssl.create_default_context() handlers = {}
ssl_context.set_alpn_protocols([const.ALPN_PROTOCOL])
self.client.tls_set_context(ssl_context)
self.client.enable_logger(_LOGGER)
self.client.on_connect = self._subscribe_appliances
query_params = urlencode(
{
"x-amz-customauthorizer-name": const.AWS_AUTHORIZER,
"x-amz-customauthorizer-signature": await self._api.load_aws_token(),
"token": self._api.auth.id_token,
}
)
self.client.username_pw_set("?" + query_params)
self.client.connect_async(const.AWS_ENDPOINT, 443)
self.client.loop_start()
def _subscribe_appliances(
self,
client: Client,
userdata: "_UserData",
flags: dict[str, int],
rc: "ReasonCodes",
properties: "Properties|None",
) -> None:
del client, userdata, flags, rc, properties
for appliance in self._appliances: for appliance in self._appliances:
self._subscribe(appliance)
def _appliance_status_callback( handler_protos = {
self, "appliancestatus": partial(self._status_handler, appliance),
appliance: "HonAppliance", "disconnected": partial(self._connection_handler, appliance, False),
client: Client, "connected": partial(self._connection_handler, appliance, True),
userdata: "_UserData", }
message: "MQTTMessage",
) -> None: for topic in appliance.info.get("topics", {}).get("subscribe", []):
del client, userdata topic_parts = topic.split("/")
payload = json.loads(message.payload) for topic_part, callback in handler_protos.items():
if topic_part in topic_parts:
handlers[Topic(topic)] = callback
return handlers
async def _get_mqtt_username(self) -> str:
query_params = {
"x-amz-customauthorizer-name": const.AWS_AUTHORIZER,
"x-amz-customauthorizer-signature": await self._api.load_aws_token(),
"token": self._api.auth.id_token,
}
return "?" + urlencode(query_params)
@staticmethod
def _status_handler(appliance: "HonAppliance", message: "Message") -> None:
payload = _Payload(json.loads(cast(str | bytes | bytearray, message.payload)))
for parameter in payload["parameters"]: for parameter in payload["parameters"]:
appliance.attributes["parameters"][parameter["parName"]].update(parameter) appliance.attributes["parameters"][parameter["parName"]].update(parameter)
appliance.sync_params_to_command("settings") appliance.sync_params_to_command("settings")
self._hon.notify() _LOGGER.debug("On topic '%s' received: \n %s", message.topic, payload)
def _appliance_disconnected_callback( @staticmethod
self, def _connection_handler(
appliance: "HonAppliance", appliance: "HonAppliance", connection_status: bool, __message: "Message"
client: Client,
userdata: "_UserData",
message: "MQTTMessage",
) -> None: ) -> None:
del client, userdata, message appliance.connection = connection_status
appliance.connection = False
self._hon.notify() async def loop(self) -> None:
delay_min, delay_max = 5, 120
def _appliance_connected_callback( tls_context = ssl.create_default_context()
self, tls_context.set_alpn_protocols([const.ALPN_PROTOCOL])
appliance: "HonAppliance",
client: Client,
userdata: "_UserData",
message: "MQTTMessage",
) -> None:
del client, userdata, message
appliance.connection = True
self._hon.notify() client = Client(
hostname=const.AWS_ENDPOINT,
port=const.AWS_PORT,
identifier=f"{self._mobile_id}_{secrets.token_hex(8)}",
protocol=ProtocolVersion.V5,
logger=logging.getLogger(f"{__name__}.paho"),
tls_context=tls_context,
)
def _subscribe(self, appliance: "HonAppliance") -> None: reconnect_interval = delay_min
topic_part_to_callback_mapping = {
"appliancestatus": self._appliance_status_callback, while True:
"disconnected": self._appliance_disconnected_callback, client.set_username(await self._get_mqtt_username())
"connected": self._appliance_connected_callback, try:
} async with client:
for topic in appliance.info.get("topics", {}).get("subscribe", []): self._connected_event.set()
for topic_part, callback in topic_part_to_callback_mapping.items(): reconnect_interval = delay_min
if topic_part in topic:
self.client.message_callback_add( for topic in self._subscription_handlers:
topic, functools.partial(callback, appliance) await client.subscribe(str(topic))
)
self.client.subscribe(topic) async for message in client.messages:
_LOGGER.info("Subscribed to topic %s", topic) handler = self._subscription_handlers[message.topic]
handler(message)
self._hon.notify()
except MqttError:
self._connected_event.clear()
_LOGGER.warning(
"Connection to MQTT broker lost. Reconnecting in %s seconds",
reconnect_interval,
)
await asyncio.sleep(reconnect_interval)
reconnect_interval = min(reconnect_interval * 2, delay_max)

View file

@ -3,6 +3,7 @@ AUTH_API = "https://account2.hon-smarthome.com"
API_URL = "https://api-iot.he.services" API_URL = "https://api-iot.he.services"
API_KEY = "GRCqFhC6Gk@ikWXm1RmnSmX1cm,MxY-configuration" API_KEY = "GRCqFhC6Gk@ikWXm1RmnSmX1cm,MxY-configuration"
AWS_ENDPOINT = "a30f6tqw0oh1x0-ats.iot.eu-west-1.amazonaws.com" AWS_ENDPOINT = "a30f6tqw0oh1x0-ats.iot.eu-west-1.amazonaws.com"
AWS_PORT = 443
AWS_AUTHORIZER = "candy-iot-authorizer" AWS_AUTHORIZER = "candy-iot-authorizer"
APP = "hon" APP = "hon"
CLIENT_ID = ( CLIENT_ID = (

View file

@ -1,4 +1,4 @@
aiohttp>=3.8.6 aiohttp>=3.8.6
yarl>=1.8 yarl>=1.8
typing-extensions>=4.8 typing-extensions>=4.8
paho-mqtt==1.6.1 aiomqtt==2.0.1

View file

@ -2,5 +2,4 @@ black>=22.12
flake8>=6.0 flake8>=6.0
mypy>=0.991 mypy>=0.991
pylint>=2.15 pylint>=2.15
setuptools>=62.3 setuptools>=62.3
types-paho-mqtt

View file

@ -5,6 +5,9 @@ from setuptools import setup, find_packages
with open("README.md", "r", encoding="utf-8") as f: with open("README.md", "r", encoding="utf-8") as f:
long_description = f.read() long_description = f.read()
with open("requirements.txt", "r", encoding="utf-8") as f:
install_requires = f.read().splitlines()
setup( setup(
name="pyhOn", name="pyhOn",
version="0.17.5", version="0.17.5",
@ -21,12 +24,7 @@ setup(
packages=find_packages(), packages=find_packages(),
include_package_data=True, include_package_data=True,
python_requires=">=3.10", python_requires=">=3.10",
install_requires=[ install_requires=install_requires,
"aiohttp>=3.8.6",
"typing-extensions>=4.8",
"yarl>=1.8",
"paho-mqtt==1.6.1",
],
classifiers=[ classifiers=[
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Environment :: Console", "Environment :: Console",