mirror of
https://github.com/Andre0512/pyhOn.git
synced 2025-03-08 06:09:41 +00:00
Add jwt token parsing to auth handler to get expiration date
This should increase the realiability of re-using tokens in case the expiration date changes on Haier's side
This commit is contained in:
parent
cb216eab59
commit
392f366c25
|
@ -3,9 +3,10 @@ import logging
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
import urllib
|
import urllib
|
||||||
|
import base64
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Dict, Optional, Any, List
|
from typing import Dict, Optional, Any, List
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
@ -37,10 +38,26 @@ class HonAuthData:
|
||||||
cognito_token: str = ""
|
cognito_token: str = ""
|
||||||
id_token: str = ""
|
id_token: str = ""
|
||||||
|
|
||||||
|
def decode_jwt(token: str) -> dict[str, str]:
|
||||||
|
if token == "":
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def base64url_decode(input_str: str) -> bytes:
|
||||||
|
# Add padding if necessary
|
||||||
|
input_str += '=' * (4 - len(input_str) % 4)
|
||||||
|
return base64.urlsafe_b64decode(input_str)
|
||||||
|
|
||||||
|
# Split the token into parts
|
||||||
|
_, payload_b64, _ = token.split('.')
|
||||||
|
|
||||||
|
if not payload_b64:
|
||||||
|
raise Exception("Invalid JWT token!")
|
||||||
|
|
||||||
|
return json.loads(base64url_decode(payload_b64))
|
||||||
|
|
||||||
|
|
||||||
class HonAuth:
|
class HonAuth:
|
||||||
_TOKEN_EXPIRES_AFTER_HOURS = 8
|
_TOKEN_EXPIRE_WARNING_HOURS = 1
|
||||||
_TOKEN_EXPIRE_WARNING_HOURS = 7
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -55,9 +72,15 @@ class HonAuth:
|
||||||
self._login_data.email = email
|
self._login_data.email = email
|
||||||
self._login_data.password = password
|
self._login_data.password = password
|
||||||
self._device = device
|
self._device = device
|
||||||
self._expires: datetime = datetime.utcnow()
|
self._expires: datetime = datetime.now(timezone.utc)
|
||||||
self._auth = HonAuthData()
|
self._auth = HonAuthData()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expires(self) -> datetime:
|
||||||
|
if self.id_token == "":
|
||||||
|
return datetime.fromtimestamp(0, timezone.utc)
|
||||||
|
return datetime.fromtimestamp(float(decode_jwt(self.id_token).get("exp", "0")), timezone.utc)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cognito_token(self) -> str:
|
def cognito_token(self) -> str:
|
||||||
return self._auth.cognito_token
|
return self._auth.cognito_token
|
||||||
|
@ -74,12 +97,12 @@ class HonAuth:
|
||||||
def refresh_token(self) -> str:
|
def refresh_token(self) -> str:
|
||||||
return self._auth.refresh_token
|
return self._auth.refresh_token
|
||||||
|
|
||||||
def _check_token_expiration(self, hours: int) -> bool:
|
def _check_token_expiration(self, hours: int = 0) -> bool:
|
||||||
return datetime.utcnow() >= self._expires + timedelta(hours=hours)
|
return datetime.now(timezone.utc) >= self.expires - timedelta(hours=hours)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def token_is_expired(self) -> bool:
|
def token_is_expired(self) -> bool:
|
||||||
return self._check_token_expiration(self._TOKEN_EXPIRES_AFTER_HOURS)
|
return self._check_token_expiration()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def token_expires_soon(self) -> bool:
|
def token_expires_soon(self) -> bool:
|
||||||
|
@ -119,7 +142,7 @@ class HonAuth:
|
||||||
url = f"{const.AUTH_API}/services/oauth2/authorize/expid_Login?{params_encode}"
|
url = f"{const.AUTH_API}/services/oauth2/authorize/expid_Login?{params_encode}"
|
||||||
async with self._request.get(url) as response:
|
async with self._request.get(url) as response:
|
||||||
text = await response.text()
|
text = await response.text()
|
||||||
self._expires = datetime.utcnow()
|
self._expires = datetime.now(timezone.utc)
|
||||||
login_url: List[str] = re.findall("(?:url|href) ?= ?'(.+?)'", text)
|
login_url: List[str] = re.findall("(?:url|href) ?= ?'(.+?)'", text)
|
||||||
if not login_url:
|
if not login_url:
|
||||||
if "oauth/done#access_token=" in text:
|
if "oauth/done#access_token=" in text:
|
||||||
|
|
|
@ -58,7 +58,7 @@ class HonConnectionHandler(ConnectionHandler):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def _check_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
async def _check_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||||
if self._refresh_token:
|
if self._refresh_token and self.auth.token_expires_soon:
|
||||||
await self.auth.refresh(self._refresh_token)
|
await self.auth.refresh(self._refresh_token)
|
||||||
if not (self.auth.cognito_token and self.auth.id_token):
|
if not (self.auth.cognito_token and self.auth.id_token):
|
||||||
await self.auth.authenticate()
|
await self.auth.authenticate()
|
||||||
|
|
Loading…
Reference in a new issue