mirror of
https://github.com/Andre0512/pyhOn.git
synced 2025-05-04 20:42:09 +00:00
Migrate to aiomqtt
Refresh AWS tokens on connect after expiration
This commit is contained in:
parent
d15976d9f7
commit
f9602a2b1c
|
@ -1,29 +1,40 @@
|
||||||
import functools
|
import asyncio
|
||||||
|
from functools import cached_property, partial
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import pprint
|
||||||
import secrets
|
import secrets
|
||||||
import ssl
|
import ssl
|
||||||
from typing import TYPE_CHECKING
|
from typing import Any, TYPE_CHECKING, cast
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
from paho.mqtt.client import Client, MQTTv5
|
from aiomqtt import Client as aiomqttClient, MqttError, ProtocolVersion, Topic
|
||||||
|
|
||||||
from pyhon import const
|
from pyhon import const
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from paho.mqtt.client import MQTTMessage, _UserData, ReasonCodes, Properties
|
from aiomqtt import Message
|
||||||
|
|
||||||
from pyhon import Hon, HonAPI
|
from pyhon import Hon, HonAPI
|
||||||
from pyhon.appliance import HonAppliance
|
from pyhon.appliance import HonAppliance
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _Payload(dict[Any, Any]):
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return pprint.pformat(self)
|
||||||
|
|
||||||
|
|
||||||
|
class Client(aiomqttClient):
|
||||||
|
def set_username(self, username: str) -> None:
|
||||||
|
self._client.username_pw_set(username=username)
|
||||||
|
|
||||||
|
|
||||||
class MQTTClient:
|
class MQTTClient:
|
||||||
def __init__(self, hon: "Hon", mobile_id: str) -> None:
|
def __init__(self, hon: "Hon", mobile_id: str = const.MOBILE_ID) -> None:
|
||||||
self._client: Client | None = None
|
self._task: asyncio.Task[None] | None = None
|
||||||
self._hon = hon
|
self._hon = hon
|
||||||
self._mobile_id = mobile_id or const.MOBILE_ID
|
self._mobile_id = mobile_id
|
||||||
|
self._connected_event = asyncio.Event()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _appliances(self) -> list["HonAppliance"]:
|
def _appliances(self) -> list["HonAppliance"]:
|
||||||
|
@ -33,105 +44,93 @@ class MQTTClient:
|
||||||
def _api(self) -> "HonAPI":
|
def _api(self) -> "HonAPI":
|
||||||
return self._hon.api
|
return self._hon.api
|
||||||
|
|
||||||
@property
|
|
||||||
def client(self) -> Client:
|
|
||||||
if self._client is not None:
|
|
||||||
return self._client
|
|
||||||
raise AttributeError("Client is not set")
|
|
||||||
|
|
||||||
async def create(self) -> "MQTTClient":
|
async def create(self) -> "MQTTClient":
|
||||||
await self._start()
|
self._task = asyncio.create_task(self.loop())
|
||||||
|
await self._connected_event.wait()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def _start(self) -> None:
|
@cached_property
|
||||||
self._client = Client(
|
def _subscription_handlers(self) -> dict[Topic, partial[None]]:
|
||||||
client_id=f"{self._mobile_id}_{secrets.token_hex(8)}",
|
|
||||||
protocol=MQTTv5,
|
|
||||||
reconnect_on_failure=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
ssl_context = ssl.create_default_context()
|
handlers = {}
|
||||||
ssl_context.set_alpn_protocols([const.ALPN_PROTOCOL])
|
|
||||||
|
|
||||||
self.client.tls_set_context(ssl_context)
|
|
||||||
self.client.enable_logger(_LOGGER)
|
|
||||||
self.client.on_connect = self._subscribe_appliances
|
|
||||||
|
|
||||||
query_params = urlencode(
|
|
||||||
{
|
|
||||||
"x-amz-customauthorizer-name": const.AWS_AUTHORIZER,
|
|
||||||
"x-amz-customauthorizer-signature": await self._api.load_aws_token(),
|
|
||||||
"token": self._api.auth.id_token,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
self.client.username_pw_set("?" + query_params)
|
|
||||||
|
|
||||||
self.client.connect_async(const.AWS_ENDPOINT, 443)
|
|
||||||
self.client.loop_start()
|
|
||||||
|
|
||||||
def _subscribe_appliances(
|
|
||||||
self,
|
|
||||||
client: Client,
|
|
||||||
userdata: "_UserData",
|
|
||||||
flags: dict[str, int],
|
|
||||||
rc: "ReasonCodes",
|
|
||||||
properties: "Properties|None",
|
|
||||||
) -> None:
|
|
||||||
del client, userdata, flags, rc, properties
|
|
||||||
for appliance in self._appliances:
|
for appliance in self._appliances:
|
||||||
self._subscribe(appliance)
|
|
||||||
|
|
||||||
def _appliance_status_callback(
|
handler_protos = {
|
||||||
self,
|
"appliancestatus": partial(self._status_handler, appliance),
|
||||||
appliance: "HonAppliance",
|
"disconnected": partial(self._connection_handler, appliance, False),
|
||||||
client: Client,
|
"connected": partial(self._connection_handler, appliance, True),
|
||||||
userdata: "_UserData",
|
}
|
||||||
message: "MQTTMessage",
|
|
||||||
) -> None:
|
for topic in appliance.info.get("topics", {}).get("subscribe", []):
|
||||||
del client, userdata
|
topic_parts = topic.split("/")
|
||||||
payload = json.loads(message.payload)
|
for topic_part, callback in handler_protos.items():
|
||||||
|
if topic_part in topic_parts:
|
||||||
|
handlers[Topic(topic)] = callback
|
||||||
|
|
||||||
|
return handlers
|
||||||
|
|
||||||
|
async def _get_mqtt_username(self) -> str:
|
||||||
|
query_params = {
|
||||||
|
"x-amz-customauthorizer-name": const.AWS_AUTHORIZER,
|
||||||
|
"x-amz-customauthorizer-signature": await self._api.load_aws_token(),
|
||||||
|
"token": self._api.auth.id_token,
|
||||||
|
}
|
||||||
|
return "?" + urlencode(query_params)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _status_handler(appliance: "HonAppliance", message: "Message") -> None:
|
||||||
|
payload = _Payload(json.loads(cast(str | bytes | bytearray, message.payload)))
|
||||||
for parameter in payload["parameters"]:
|
for parameter in payload["parameters"]:
|
||||||
appliance.attributes["parameters"][parameter["parName"]].update(parameter)
|
appliance.attributes["parameters"][parameter["parName"]].update(parameter)
|
||||||
appliance.sync_params_to_command("settings")
|
appliance.sync_params_to_command("settings")
|
||||||
|
|
||||||
self._hon.notify()
|
_LOGGER.debug("On topic '%s' received: \n %s", message.topic, payload)
|
||||||
|
|
||||||
def _appliance_disconnected_callback(
|
@staticmethod
|
||||||
self,
|
def _connection_handler(
|
||||||
appliance: "HonAppliance",
|
appliance: "HonAppliance", connection_status: bool, __message: "Message"
|
||||||
client: Client,
|
|
||||||
userdata: "_UserData",
|
|
||||||
message: "MQTTMessage",
|
|
||||||
) -> None:
|
) -> None:
|
||||||
del client, userdata, message
|
appliance.connection = connection_status
|
||||||
appliance.connection = False
|
|
||||||
|
|
||||||
self._hon.notify()
|
async def loop(self) -> None:
|
||||||
|
delay_min, delay_max = 5, 120
|
||||||
|
|
||||||
def _appliance_connected_callback(
|
tls_context = ssl.create_default_context()
|
||||||
self,
|
tls_context.set_alpn_protocols([const.ALPN_PROTOCOL])
|
||||||
appliance: "HonAppliance",
|
|
||||||
client: Client,
|
|
||||||
userdata: "_UserData",
|
|
||||||
message: "MQTTMessage",
|
|
||||||
) -> None:
|
|
||||||
del client, userdata, message
|
|
||||||
appliance.connection = True
|
|
||||||
|
|
||||||
self._hon.notify()
|
client = Client(
|
||||||
|
hostname=const.AWS_ENDPOINT,
|
||||||
|
port=const.AWS_PORT,
|
||||||
|
identifier=f"{self._mobile_id}_{secrets.token_hex(8)}",
|
||||||
|
protocol=ProtocolVersion.V5,
|
||||||
|
logger=logging.getLogger(f"{__name__}.paho"),
|
||||||
|
tls_context=tls_context,
|
||||||
|
)
|
||||||
|
|
||||||
def _subscribe(self, appliance: "HonAppliance") -> None:
|
reconnect_interval = delay_min
|
||||||
topic_part_to_callback_mapping = {
|
|
||||||
"appliancestatus": self._appliance_status_callback,
|
while True:
|
||||||
"disconnected": self._appliance_disconnected_callback,
|
client.set_username(await self._get_mqtt_username())
|
||||||
"connected": self._appliance_connected_callback,
|
try:
|
||||||
}
|
async with client:
|
||||||
for topic in appliance.info.get("topics", {}).get("subscribe", []):
|
self._connected_event.set()
|
||||||
for topic_part, callback in topic_part_to_callback_mapping.items():
|
reconnect_interval = delay_min
|
||||||
if topic_part in topic:
|
|
||||||
self.client.message_callback_add(
|
for topic in self._subscription_handlers:
|
||||||
topic, functools.partial(callback, appliance)
|
await client.subscribe(str(topic))
|
||||||
)
|
|
||||||
self.client.subscribe(topic)
|
async for message in client.messages:
|
||||||
_LOGGER.info("Subscribed to topic %s", topic)
|
handler = self._subscription_handlers[message.topic]
|
||||||
|
|
||||||
|
handler(message)
|
||||||
|
|
||||||
|
self._hon.notify()
|
||||||
|
except MqttError:
|
||||||
|
self._connected_event.clear()
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Connection to MQTT broker lost. Reconnecting in %s seconds",
|
||||||
|
reconnect_interval,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(reconnect_interval)
|
||||||
|
reconnect_interval = min(reconnect_interval * 2, delay_max)
|
||||||
|
|
|
@ -3,6 +3,7 @@ AUTH_API = "https://account2.hon-smarthome.com"
|
||||||
API_URL = "https://api-iot.he.services"
|
API_URL = "https://api-iot.he.services"
|
||||||
API_KEY = "GRCqFhC6Gk@ikWXm1RmnSmX1cm,MxY-configuration"
|
API_KEY = "GRCqFhC6Gk@ikWXm1RmnSmX1cm,MxY-configuration"
|
||||||
AWS_ENDPOINT = "a30f6tqw0oh1x0-ats.iot.eu-west-1.amazonaws.com"
|
AWS_ENDPOINT = "a30f6tqw0oh1x0-ats.iot.eu-west-1.amazonaws.com"
|
||||||
|
AWS_PORT = 443
|
||||||
AWS_AUTHORIZER = "candy-iot-authorizer"
|
AWS_AUTHORIZER = "candy-iot-authorizer"
|
||||||
APP = "hon"
|
APP = "hon"
|
||||||
CLIENT_ID = (
|
CLIENT_ID = (
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
aiohttp>=3.8.6
|
aiohttp>=3.8.6
|
||||||
yarl>=1.8
|
yarl>=1.8
|
||||||
typing-extensions>=4.8
|
typing-extensions>=4.8
|
||||||
paho-mqtt==1.6.1
|
aiomqtt==2.0.1
|
|
@ -2,5 +2,4 @@ black>=22.12
|
||||||
flake8>=6.0
|
flake8>=6.0
|
||||||
mypy>=0.991
|
mypy>=0.991
|
||||||
pylint>=2.15
|
pylint>=2.15
|
||||||
setuptools>=62.3
|
setuptools>=62.3
|
||||||
types-paho-mqtt
|
|
10
setup.py
10
setup.py
|
@ -5,6 +5,9 @@ from setuptools import setup, find_packages
|
||||||
with open("README.md", "r", encoding="utf-8") as f:
|
with open("README.md", "r", encoding="utf-8") as f:
|
||||||
long_description = f.read()
|
long_description = f.read()
|
||||||
|
|
||||||
|
with open("requirements.txt", "r", encoding="utf-8") as f:
|
||||||
|
install_requires = f.read().splitlines()
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="pyhOn",
|
name="pyhOn",
|
||||||
version="0.17.5",
|
version="0.17.5",
|
||||||
|
@ -21,12 +24,7 @@ setup(
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
python_requires=">=3.10",
|
python_requires=">=3.10",
|
||||||
install_requires=[
|
install_requires=install_requires,
|
||||||
"aiohttp>=3.8.6",
|
|
||||||
"typing-extensions>=4.8",
|
|
||||||
"yarl>=1.8",
|
|
||||||
"paho-mqtt==1.6.1",
|
|
||||||
],
|
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 4 - Beta",
|
"Development Status :: 4 - Beta",
|
||||||
"Environment :: Console",
|
"Environment :: Console",
|
||||||
|
|
Loading…
Reference in a new issue