pyhOn/pyhon/connection/mqtt.py
Kuba Sawulski f9602a2b1c Migrate to aiomqtt
Refresh AWS tokens on connect after expiration
2024-08-13 18:41:12 +02:00

137 lines
4.5 KiB
Python

import asyncio
from functools import cached_property, partial
import json
import logging
import pprint
import secrets
import ssl
from typing import Any, TYPE_CHECKING, cast
from urllib.parse import urlencode
from aiomqtt import Client as aiomqttClient, MqttError, ProtocolVersion, Topic
from pyhon import const
if TYPE_CHECKING:
from aiomqtt import Message
from pyhon import Hon, HonAPI
from pyhon.appliance import HonAppliance
# Module-level logger: received MQTT payloads (debug) and reconnect warnings are emitted here.
_LOGGER = logging.getLogger(__name__)
class _Payload(dict[Any, Any]):
def __str__(self) -> str:
return pprint.pformat(self)
class Client(aiomqttClient):
    """``aiomqtt`` client whose username can be replaced between connections."""

    def set_username(self, username: str) -> None:
        """Replace the username used for subsequent connection attempts.

        Delegates to ``username_pw_set`` on aiomqtt's private ``_client``
        (the underlying paho-mqtt client); no password argument is passed.
        """
        paho_client = self._client
        paho_client.username_pw_set(username=username)
class MQTTClient:
    """Keep an MQTT connection to the hOn AWS endpoint and dispatch its messages.

    Messages on the appliances' subscribed topics update the matching
    ``HonAppliance`` objects, after which ``Hon.notify()`` is called. Lost
    connections are re-established with exponential backoff.
    """

    def __init__(self, hon: "Hon", mobile_id: str = const.MOBILE_ID) -> None:
        self._task: asyncio.Task[None] | None = None
        self._hon = hon
        self._mobile_id = mobile_id
        # Set while connected to the broker; create() waits for it.
        self._connected_event = asyncio.Event()

    @property
    def _appliances(self) -> list["HonAppliance"]:
        return self._hon.appliances

    @property
    def _api(self) -> "HonAPI":
        return self._hon.api

    async def create(self) -> "MQTTClient":
        """Start the background connection loop and wait for the first connect.

        Returns:
            ``self`` once the loop has connected to the broker.

        Raises:
            The exception that terminated the connection loop, if it stopped
            before ever connecting (otherwise this would wait forever).
        """
        self._task = asyncio.create_task(self.loop())
        connected = asyncio.create_task(self._connected_event.wait())
        try:
            await asyncio.wait(
                {self._task, connected}, return_when=asyncio.FIRST_COMPLETED
            )
        finally:
            connected.cancel()
        if not self._connected_event.is_set() and self._task.done():
            # The loop ended before connecting: surface its error to the caller.
            self._task.result()
        return self

    @cached_property
    def _subscription_handlers(self) -> dict[Topic, partial[None]]:
        """Map every subscribed topic to the callback that handles it.

        A topic is routed by whichever of ``appliancestatus``, ``disconnected``
        or ``connected`` occurs as one of its ``/``-separated segments (the
        last match in that order wins). Topics with no matching segment get
        no handler and are therefore not subscribed. Computed once per instance.
        """
        handlers = {}
        for appliance in self._appliances:
            handler_protos = {
                "appliancestatus": partial(self._status_handler, appliance),
                "disconnected": partial(self._connection_handler, appliance, False),
                "connected": partial(self._connection_handler, appliance, True),
            }
            for topic in appliance.info.get("topics", {}).get("subscribe", []):
                topic_parts = topic.split("/")
                for topic_part, callback in handler_protos.items():
                    if topic_part in topic_parts:
                        handlers[Topic(topic)] = callback
        return handlers

    async def _get_mqtt_username(self) -> str:
        """Build the custom-authorizer query string sent as the MQTT username.

        Calls ``load_aws_token()`` on every invocation so that each connection
        attempt carries a freshly requested signature.
        """
        query_params = {
            "x-amz-customauthorizer-name": const.AWS_AUTHORIZER,
            "x-amz-customauthorizer-signature": await self._api.load_aws_token(),
            "token": self._api.auth.id_token,
        }
        return "?" + urlencode(query_params)

    @staticmethod
    def _status_handler(appliance: "HonAppliance", message: "Message") -> None:
        """Apply an ``appliancestatus`` JSON payload to the appliance's parameters."""
        payload = _Payload(json.loads(cast(str | bytes | bytearray, message.payload)))
        for parameter in payload["parameters"]:
            appliance.attributes["parameters"][parameter["parName"]].update(parameter)
        appliance.sync_params_to_command("settings")
        _LOGGER.debug("On topic '%s' received: \n %s", message.topic, payload)

    @staticmethod
    def _connection_handler(
        appliance: "HonAppliance", connection_status: bool, __message: "Message"
    ) -> None:
        """Record the appliance's connected/disconnected state; payload is ignored."""
        appliance.connection = connection_status

    def _handle_message(self, message: "Message") -> None:
        """Route one incoming message to its topic handler, then notify ``Hon``.

        Messages on topics without a handler, and messages whose handler
        raises, are logged and skipped so they cannot terminate the loop.
        """
        handler = self._subscription_handlers.get(message.topic)
        if handler is None:
            _LOGGER.debug("Ignoring message on unhandled topic '%s'", message.topic)
            return
        try:
            handler(message)
        except Exception:  # per-message boundary: keep the connection loop alive
            _LOGGER.exception("Failed to handle message on topic '%s'", message.topic)
            return
        self._hon.notify()

    async def loop(self) -> None:
        """Connect, subscribe and process messages forever, reconnecting on errors.

        The username (carrying the auth token) is rebuilt before every
        connection attempt. After an ``MqttError`` the loop sleeps with
        exponential backoff from ``delay_min`` up to ``delay_max`` seconds;
        the backoff resets once a connection succeeds.
        """
        delay_min, delay_max = 5, 120
        tls_context = ssl.create_default_context()
        tls_context.set_alpn_protocols([const.ALPN_PROTOCOL])
        client = Client(
            hostname=const.AWS_ENDPOINT,
            port=const.AWS_PORT,
            identifier=f"{self._mobile_id}_{secrets.token_hex(8)}",
            protocol=ProtocolVersion.V5,
            logger=logging.getLogger(f"{__name__}.paho"),
            tls_context=tls_context,
        )
        reconnect_interval = delay_min
        while True:
            # NOTE(review): errors raised while fetching the token are not
            # MqttError, so they end this loop rather than being retried.
            client.set_username(await self._get_mqtt_username())
            try:
                async with client:
                    self._connected_event.set()
                    reconnect_interval = delay_min
                    for topic in self._subscription_handlers:
                        await client.subscribe(str(topic))
                    async for message in client.messages:
                        self._handle_message(message)
            except MqttError:
                self._connected_event.clear()
                _LOGGER.warning(
                    "Connection to MQTT broker lost. Reconnecting in %s seconds",
                    reconnect_interval,
                )
                await asyncio.sleep(reconnect_interval)
                reconnect_interval = min(reconnect_interval * 2, delay_max)