Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions kasa/aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
"""Send encrypted message as passthrough."""
url = f"http://{self._host}/app"
if self._login_token:
if self._state is TransportState.ESTABLISHED and self._login_token:
url += f"?token={self._login_token}"

encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore
Expand Down Expand Up @@ -250,6 +250,7 @@ async def perform_handshake(self) -> None:
_LOGGER.debug("Will perform handshaking...")

self._key_pair = None
self._login_token = None
self._session_expire_at = None
self._session_cookie = None

Expand Down Expand Up @@ -284,9 +285,7 @@ async def perform_handshake(self) -> None:
handshake_key = resp_dict["result"]["key"]

if (
cookie := http_client.get_cookie( # type: ignore
self.SESSION_COOKIE_NAME
)
cookie := http_client.get_cookie(self.SESSION_COOKIE_NAME) # type: ignore
) or (
cookie := http_client.get_cookie("SESSIONID") # type: ignore
):
Expand Down
18 changes: 12 additions & 6 deletions kasa/tests/test_aestransport.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import base64
import json
import random
import string
import time
from contextlib import nullcontext as does_not_raise
from json import dumps as json_dumps
from json import loads as json_loads
from typing import Any, Dict, Optional

import aiohttp
import pytest
Expand Down Expand Up @@ -219,7 +222,6 @@ async def read(self):
return json_dumps(self._json).encode()

encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])
token = "test_token" # noqa

def __init__(self, host, status_code=200, error_code=0, inner_error_code=0):
self.host = host
Expand All @@ -228,6 +230,7 @@ def __init__(self, host, status_code=200, error_code=0, inner_error_code=0):
self._inner_error_code = inner_error_code
self.http_client = HttpClient(DeviceConfig(self.host))
self.inner_call_count = 0
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311

@property
def inner_error_code(self):
Expand All @@ -242,7 +245,7 @@ async def post(self, url, params=None, json=None, data=None, *_, **__):
json = json_loads(item.decode())
return await self._post(url, json)

async def _post(self, url, json):
async def _post(self, url: str, json: Dict[str, Any]):
if json["method"] == "handshake":
return await self._return_handshake_response(url, json)
elif json["method"] == "securePassthrough":
Expand All @@ -253,7 +256,7 @@ async def _post(self, url, json):
assert url == f"http://{self.host}/app?token={self.token}"
return await self._return_send_response(url, json)

async def _return_handshake_response(self, url, json):
async def _return_handshake_response(self, url: str, json: Dict[str, Any]):
start = len("-----BEGIN PUBLIC KEY-----\n")
end = len("\n-----END PUBLIC KEY-----\n")
client_pub_key = json["params"]["key"][start:-end]
Expand All @@ -266,7 +269,7 @@ async def _return_handshake_response(self, url, json):
self.status_code, {"result": {"key": key_64}, "error_code": self.error_code}
)

async def _return_secure_passthrough_response(self, url, json):
async def _return_secure_passthrough_response(self, url: str, json: Dict[str, Any]):
encrypted_request = json["params"]["request"]
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
decrypted_request_dict = json_loads(decrypted_request)
Expand All @@ -283,12 +286,15 @@ async def _return_secure_passthrough_response(self, url, json):
}
return self._mock_response(self.status_code, result)

async def _return_login_response(self, url, json):
async def _return_login_response(self, url: str, json: Dict[str, Any]):
if "token=" in url:
raise Exception("token should not be in url for a login request")
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
result = {"result": {"token": self.token}, "error_code": self.inner_error_code}
self.inner_call_count += 1
return self._mock_response(self.status_code, result)

async def _return_send_response(self, url, json):
async def _return_send_response(self, url: str, json: Dict[str, Any]):
result = {"result": {"method": None}, "error_code": self.inner_error_code}
self.inner_call_count += 1
return self._mock_response(self.status_code, result)