Skip to content
80 changes: 44 additions & 36 deletions kasa/aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import hashlib
import logging
import time
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, Tuple, cast

from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
Expand Down Expand Up @@ -41,6 +42,14 @@ def _sha1(payload: bytes) -> str:
return sha1_algo.hexdigest()


class TransportState(Enum):
"""Enum for AES state."""

HANDSHAKE_REQUIRED = auto() # Handshake needed
LOGIN_REQUIRED = auto() # Login needed
ESTABLISHED = auto() # Ready to send requests


class AesTransport(BaseTransport):
"""Implementation of the AES encryption protocol.

Expand Down Expand Up @@ -79,21 +88,21 @@ def __init__(
self._default_credentials: Optional[Credentials] = None
self._http_client: HttpClient = HttpClient(config)

self._handshake_done = False
self._state = TransportState.HANDSHAKE_REQUIRED

self._encryption_session: Optional[AesEncyptionSession] = None
self._session_expire_at: Optional[float] = None

self._session_cookie: Optional[Dict[str, str]] = None

self._login_token = None
self._login_token: Optional[str] = None

self._key_pair: Optional[KeyPair] = None

_LOGGER.debug("Created AES transport for %s", self._host)

@property
def default_port(self):
def default_port(self) -> int:
"""Default port for the transport."""
return self.DEFAULT_PORT

Expand All @@ -102,30 +111,25 @@ def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode()

def _get_login_params(self, credentials):
def _get_login_params(self, credentials: Credentials) -> Dict[str, str]:
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2, credentials)
password_field_name = "password2" if self._login_version == 2 else "password"
return {password_field_name: pw, "username": un}

@staticmethod
def hash_credentials(login_v2, credentials):
def hash_credentials(login_v2: bool, credentials: Credentials) -> Tuple[str, str]:
"""Hash the credentials."""
un = base64.b64encode(_sha1(credentials.username.encode()).encode()).decode()
if login_v2:
un = base64.b64encode(
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(
_sha1(credentials.password.encode()).encode()
).decode()
else:
un = base64.b64encode(
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw

def _handle_response_error_code(self, resp_dict: dict, msg: str):
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
return
Expand All @@ -135,12 +139,11 @@ def _handle_response_error_code(self, resp_dict: dict, msg: str):
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg, error_code=error_code)
if error_code in SMART_AUTHENTICATION_ERRORS:
self._handshake_done = False
self._login_token = None
self._state = TransportState.HANDSHAKE_REQUIRED
raise AuthenticationException(msg, error_code=error_code)
raise SmartDeviceException(msg, error_code=error_code)

async def send_secure_passthrough(self, request: str):
async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
"""Send encrypted message as passthrough."""
url = f"http://{self._host}/app"
if self._login_token:
Expand All @@ -165,24 +168,25 @@ async def send_secure_passthrough(self, request: str):
+ f"status code {status_code} to passthrough"
)

resp_dict = cast(Dict, resp_dict)
self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)

response = self._encryption_session.decrypt( # type: ignore
resp_dict["result"]["response"].encode()
)
resp_dict = json_loads(response)
return resp_dict
if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
assert self._encryption_session is not None

raw_response: str = resp_dict["result"]["response"]
response = self._encryption_session.decrypt(raw_response.encode())
return json_loads(response) # type: ignore[return-value]

async def perform_login(self):
"""Login to the device."""
try:
await self.try_login(self._login_params)
except AuthenticationException as aex:
try:
if aex.error_code != SmartErrorCode.LOGIN_ERROR:
if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
raise aex
if self._default_credentials is None:
self._default_credentials = get_default_credentials(
Expand All @@ -203,9 +207,8 @@ async def perform_login(self):
ex,
) from ex

async def try_login(self, login_params):
async def try_login(self, login_params: Dict[str, Any]) -> None:
"""Try to login with supplied login_params."""
self._login_token = None
login_request = {
"method": "login_device",
"params": login_params,
Expand All @@ -216,6 +219,7 @@ async def try_login(self, login_params):
resp_dict = await self.send_secure_passthrough(request)
self._handle_response_error_code(resp_dict, "Error logging in")
self._login_token = resp_dict["result"]["token"]
self._state = TransportState.ESTABLISHED

async def _generate_key_pair_payload(self) -> AsyncGenerator:
"""Generate the request body and return an ascyn_generator.
Expand All @@ -236,12 +240,11 @@ async def _generate_key_pair_payload(self) -> AsyncGenerator:
_LOGGER.debug(f"Request {request_body}")
yield json_dumps(request_body).encode()

async def perform_handshake(self):
async def perform_handshake(self) -> None:
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")

self._key_pair = None
self._handshake_done = False
self._session_expire_at = None
self._session_cookie = None

Expand All @@ -258,7 +261,7 @@ async def perform_handshake(self):
cookies_dict=self._session_cookie,
)

_LOGGER.debug(f"Device responded with: {resp_dict}")
_LOGGER.debug("Device responded with: %s", resp_dict)

if status_code != 200:
raise SmartDeviceException(
Expand All @@ -268,6 +271,9 @@ async def perform_handshake(self):

self._handle_response_error_code(resp_dict, "Unable to complete handshake")

if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)

handshake_key = resp_dict["result"]["key"]

if (
Expand All @@ -283,12 +289,12 @@ async def perform_handshake(self):

self._session_expire_at = time.time() + 86400
if TYPE_CHECKING:
assert self._key_pair is not None # pragma: no cover
assert self._key_pair is not None
self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, self._key_pair
)

self._handshake_done = True
self._state = TransportState.LOGIN_REQUIRED

_LOGGER.debug("Handshake with %s complete", self._host)

Expand All @@ -299,17 +305,20 @@ def _handshake_session_expired(self):
or self._session_expire_at - time.time() <= 0
)

async def send(self, request: str):
async def send(self, request: str) -> Dict[str, Any]:
"""Send the request."""
if not self._handshake_done or self._handshake_session_expired():
if (
self._state is TransportState.HANDSHAKE_REQUIRED
or self._handshake_session_expired()
):
await self.perform_handshake()
if not self._login_token:
if self._state is not TransportState.ESTABLISHED:
try:
await self.perform_login()
# After a login failure handshake needs to
# be redone or a 9999 error is received.
except AuthenticationException as ex:
self._handshake_done = False
self._state = TransportState.HANDSHAKE_REQUIRED
raise ex

return await self.send_secure_passthrough(request)
Expand All @@ -321,8 +330,7 @@ async def close(self) -> None:

async def reset(self) -> None:
"""Reset internal handshake and login state."""
self._handshake_done = False
self._login_token = None
self._state = TransportState.HANDSHAKE_REQUIRED


class AesEncyptionSession:
Expand Down
10 changes: 5 additions & 5 deletions kasa/tests/test_aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding

from ..aestransport import AesEncyptionSession, AesTransport
from ..aestransport import AesEncyptionSession, AesTransport, TransportState
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import (
Expand Down Expand Up @@ -66,11 +66,11 @@ async def test_handshake(
)

assert transport._encryption_session is None
assert transport._handshake_done is False
assert transport._state is TransportState.HANDSHAKE_REQUIRED
with expectation:
await transport.perform_handshake()
assert transport._encryption_session is not None
assert transport._handshake_done is True
assert transport._state is TransportState.LOGIN_REQUIRED


@status_parameters
Expand All @@ -82,7 +82,7 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True
transport._state = TransportState.LOGIN_REQUIRED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session

Expand Down Expand Up @@ -129,7 +129,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True
transport._state = TransportState.LOGIN_REQUIRED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session

Expand Down
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,16 @@ omit = ["kasa/tests/*"]

[tool.coverage.report]
exclude_lines = [
# ignore abstract methods
# Don't complain if tests don't hit defensive assertion code:
"raise AssertionError",
"raise NotImplementedError",
"def __repr__"
# Don't complain about missing debug-only code:
"def __repr__",
# Have to re-enable the standard pragma
"pragma: no cover",
# TYPE_CHECKING and @overload blocks are never executed during pytest run
"if TYPE_CHECKING:",
"@overload"
]

[tool.pytest.ini_options]
Expand Down