diff --git a/pyhon/connection/auth.py b/pyhon/connection/auth.py index 3ab491c..ba89afd 100644 --- a/pyhon/connection/auth.py +++ b/pyhon/connection/auth.py @@ -3,9 +3,10 @@ import logging import re import secrets import urllib +import base64 from contextlib import suppress from dataclasses import dataclass -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Dict, Optional, Any, List from urllib import parse from urllib.parse import quote @@ -37,10 +38,26 @@ class HonAuthData: cognito_token: str = "" id_token: str = "" +def decode_jwt(token: str) -> dict[str, str]: + if token == "": + return {} + + def base64url_decode(input_str: str) -> bytes: + # Add padding if necessary + input_str += '=' * (4 - len(input_str) % 4) + return base64.urlsafe_b64decode(input_str) + + # Split the token into parts + _, payload_b64, _ = token.split('.') + + if not payload_b64: + raise Exception("Invalid JWT token!") + + return json.loads(base64url_decode(payload_b64)) + class HonAuth: - _TOKEN_EXPIRES_AFTER_HOURS = 8 - _TOKEN_EXPIRE_WARNING_HOURS = 7 + _TOKEN_EXPIRE_WARNING_HOURS = 1 def __init__( self, @@ -55,9 +72,15 @@ class HonAuth: self._login_data.email = email self._login_data.password = password self._device = device - self._expires: datetime = datetime.utcnow() + self._expires: datetime = datetime.now(timezone.utc) self._auth = HonAuthData() + @property + def expires(self) -> datetime: + if self.id_token == "": + return datetime.fromtimestamp(0, timezone.utc) + return datetime.fromtimestamp(float(decode_jwt(self.id_token).get("exp", "0")), timezone.utc) + @property def cognito_token(self) -> str: return self._auth.cognito_token @@ -74,12 +97,12 @@ class HonAuth: def refresh_token(self) -> str: return self._auth.refresh_token - def _check_token_expiration(self, hours: int) -> bool: - return datetime.utcnow() >= self._expires + timedelta(hours=hours) + def _check_token_expiration(self, hours: int = 0) -> bool: + return datetime.now(timezone.utc) >= self.expires - timedelta(hours=hours) @property def token_is_expired(self) -> bool: - return self._check_token_expiration(self._TOKEN_EXPIRES_AFTER_HOURS) + return self._check_token_expiration() @property def token_expires_soon(self) -> bool: @@ -119,7 +142,7 @@ class HonAuth: url = f"{const.AUTH_API}/services/oauth2/authorize/expid_Login?{params_encode}" async with self._request.get(url) as response: text = await response.text() - self._expires = datetime.utcnow() + self._expires = datetime.now(timezone.utc) login_url: List[str] = re.findall("(?:url|href) ?= ?'(.+?)'", text) if not login_url: if "oauth/done#access_token=" in text: diff --git a/pyhon/connection/handler/hon.py b/pyhon/connection/handler/hon.py index 9cce8ee..0f5060a 100644 --- a/pyhon/connection/handler/hon.py +++ b/pyhon/connection/handler/hon.py @@ -58,7 +58,7 @@ class HonConnectionHandler(ConnectionHandler): return self async def _check_headers(self, headers: Dict[str, str]) -> Dict[str, str]: - if self._refresh_token: + if self._refresh_token and self.auth.token_expires_soon: await self.auth.refresh(self._refresh_token) if not (self.auth.cognito_token and self.auth.id_token): await self.auth.authenticate()