Add jwt token parsing to auth handler to get expiration date

This should increase the realiability of re-using tokens in case the expiration date changes on Haier's side
This commit is contained in:
Niek Schoemaker 2024-12-10 20:36:25 +01:00
parent cb216eab59
commit 392f366c25
No known key found for this signature in database
GPG key ID: BDF9404CFECB0006
2 changed files with 32 additions and 9 deletions

View file

@ -3,9 +3,10 @@ import logging
import re import re
import secrets import secrets
import urllib import urllib
import base64
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Any, List from typing import Dict, Optional, Any, List
from urllib import parse from urllib import parse
from urllib.parse import quote from urllib.parse import quote
@ -37,10 +38,26 @@ class HonAuthData:
cognito_token: str = "" cognito_token: str = ""
id_token: str = "" id_token: str = ""
def decode_jwt(token: str) -> dict[str, str]:
if token == "":
return {}
def base64url_decode(input_str: str) -> bytes:
# Add padding if necessary
input_str += '=' * (4 - len(input_str) % 4)
return base64.urlsafe_b64decode(input_str)
# Split the token into parts
_, payload_b64, _ = token.split('.')
if not payload_b64:
raise Exception("Invalid JWT token!")
return json.loads(base64url_decode(payload_b64))
class HonAuth: class HonAuth:
_TOKEN_EXPIRES_AFTER_HOURS = 8 _TOKEN_EXPIRE_WARNING_HOURS = 1
_TOKEN_EXPIRE_WARNING_HOURS = 7
def __init__( def __init__(
self, self,
@ -55,9 +72,15 @@ class HonAuth:
self._login_data.email = email self._login_data.email = email
self._login_data.password = password self._login_data.password = password
self._device = device self._device = device
self._expires: datetime = datetime.utcnow() self._expires: datetime = datetime.now(timezone.utc)
self._auth = HonAuthData() self._auth = HonAuthData()
@property
def expires(self) -> datetime:
if self.id_token == "":
return datetime.fromtimestamp(0, timezone.utc)
return datetime.fromtimestamp(float(decode_jwt(self.id_token).get("exp", "0")), timezone.utc)
@property @property
def cognito_token(self) -> str: def cognito_token(self) -> str:
return self._auth.cognito_token return self._auth.cognito_token
@ -74,12 +97,12 @@ class HonAuth:
def refresh_token(self) -> str: def refresh_token(self) -> str:
return self._auth.refresh_token return self._auth.refresh_token
def _check_token_expiration(self, hours: int) -> bool: def _check_token_expiration(self, hours: int = 0) -> bool:
return datetime.utcnow() >= self._expires + timedelta(hours=hours) return datetime.now(timezone.utc) >= self.expires - timedelta(hours=hours)
@property @property
def token_is_expired(self) -> bool: def token_is_expired(self) -> bool:
return self._check_token_expiration(self._TOKEN_EXPIRES_AFTER_HOURS) return self._check_token_expiration()
@property @property
def token_expires_soon(self) -> bool: def token_expires_soon(self) -> bool:
@ -119,7 +142,7 @@ class HonAuth:
url = f"{const.AUTH_API}/services/oauth2/authorize/expid_Login?{params_encode}" url = f"{const.AUTH_API}/services/oauth2/authorize/expid_Login?{params_encode}"
async with self._request.get(url) as response: async with self._request.get(url) as response:
text = await response.text() text = await response.text()
self._expires = datetime.utcnow() self._expires = datetime.now(timezone.utc)
login_url: List[str] = re.findall("(?:url|href) ?= ?'(.+?)'", text) login_url: List[str] = re.findall("(?:url|href) ?= ?'(.+?)'", text)
if not login_url: if not login_url:
if "oauth/done#access_token=" in text: if "oauth/done#access_token=" in text:

View file

@ -58,7 +58,7 @@ class HonConnectionHandler(ConnectionHandler):
return self return self
async def _check_headers(self, headers: Dict[str, str]) -> Dict[str, str]: async def _check_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
if self._refresh_token: if self._refresh_token and self.auth.token_expires_soon:
await self.auth.refresh(self._refresh_token) await self.auth.refresh(self._refresh_token)
if not (self.auth.cognito_token and self.auth.id_token): if not (self.auth.cognito_token and self.auth.id_token):
await self.auth.authenticate() await self.auth.authenticate()