Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions kasa/httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,23 @@
ssl=ssl,
)
async with resp:
if resp.status == 200:
response_data = await resp.read()
if return_json:
response_data = await resp.read()

if resp.status == 200:
if return_json:
response_data = json_loads(response_data.decode())
else:
_LOGGER.debug(
"Device %s received status code %s with response %s",
self._config.host,
resp.status,
str(response_data),
)
if response_data and return_json:
try:
response_data = json_loads(response_data.decode())
except Exception:
_LOGGER.debug("Device %s response could not be parsed as json")

Check warning on line 132 in kasa/httpclient.py

View check run for this annotation

Codecov / codecov/patch

kasa/httpclient.py#L131-L132

Added lines #L131 - L132 were not covered by tests

except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
if not self._wait_between_requests:
Expand Down
27 changes: 26 additions & 1 deletion kasa/transports/sslaestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
import secrets
import ssl
from contextlib import suppress
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, cast

Expand Down Expand Up @@ -229,6 +230,31 @@ async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
ssl=await self._get_ssl_context(),
)

if TYPE_CHECKING:
assert self._encryption_session is not None

# Devices can respond with 500 if another session is created from
# the same host. Decryption may not succeed after that
if status_code == 500:
msg = (
f"Device {self._host} replied with status 500 after handshake, "
f"response: "
)
decrypted = None
if isinstance(resp_dict, dict) and (
response := resp_dict.get("result", {}).get("response")
):
with suppress(Exception):
decrypted = self._encryption_session.decrypt(response.encode())

if decrypted:
msg += decrypted
else:
msg += str(resp_dict)

_LOGGER.debug(msg)
raise _RetryableError(msg)

if status_code != 200:
raise KasaException(
f"{self._host} responded with an unexpected "
Expand All @@ -241,7 +267,6 @@ async def send_secure_passthrough(self, request: str) -> dict[str, Any]:

if TYPE_CHECKING:
resp_dict = cast(dict[str, Any], resp_dict)
assert self._encryption_session is not None

if "result" in resp_dict and "response" in resp_dict["result"]:
raw_response: str = resp_dict["result"]["response"]
Expand Down
69 changes: 65 additions & 4 deletions tests/transports/test_sslaestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
DeviceError,
KasaException,
SmartErrorCode,
_RetryableError,
)
from kasa.httpclient import HttpClient
from kasa.transports.aestransport import AesEncyptionSession
Expand Down Expand Up @@ -217,6 +218,48 @@ async def test_device_blocked_response(mocker):
await transport.perform_handshake()


@pytest.mark.parametrize(
("response", "expected_msg"),
[
pytest.param(
{"error_code": -1, "msg": "Check tapo tag failed"},
'{"error_code": -1, "msg": "Check tapo tag failed"}',
id="can-decrypt",
),
pytest.param(
b"12345678",
str({"result": {"response": "12345678"}, "error_code": 0}),
id="cannot-decrypt",
),
],
)
async def test_device_500_error(mocker, response, expected_msg):
"""Test 500 error raises retryable exception."""
host = "127.0.0.1"
mock_ssl_aes_device = MockSslAesDevice(host)
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
)

transport = SslAesTransport(
config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD))
)

request = {
"method": "getDeviceInfo",
"params": None,
}

await transport.perform_handshake()

mock_ssl_aes_device.put_next_response(response)
mock_ssl_aes_device.status_code = 500

msg = f"Device 127.0.0.1 replied with status 500 after handshake, response: {expected_msg}"
with pytest.raises(_RetryableError, match=msg):
await transport.send(json_dumps(request))


async def test_port_override():
"""Test that port override sets the app_url."""
host = "127.0.0.1"
Expand Down Expand Up @@ -302,6 +345,8 @@ def __init__(
self.digest_password_fail = digest_password_fail
self.device_blocked = device_blocked

self._next_responses: list[dict | bytes] = []

async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
if data:
json = json_loads(data)
Expand Down Expand Up @@ -386,11 +431,24 @@ async def _return_secure_passthrough_response(self, url: URL, json: dict[str, An
assert self.encryption_session
decrypted_request = self.encryption_session.decrypt(encrypted_request.encode())
decrypted_request_dict = json_loads(decrypted_request)
decrypted_response = await self._post(url, decrypted_request_dict)
async with decrypted_response:
decrypted_response_data = await decrypted_response.read()

encrypted_response = self.encryption_session.encrypt(decrypted_response_data)
if self._next_responses:
next_response = self._next_responses.pop(0)
if isinstance(next_response, dict):
decrypted_response_data = json_dumps(next_response).encode()
encrypted_response = self.encryption_session.encrypt(
decrypted_response_data
)
else:
encrypted_response = next_response
else:
decrypted_response = await self._post(url, decrypted_request_dict)
async with decrypted_response:
decrypted_response_data = await decrypted_response.read()
encrypted_response = self.encryption_session.encrypt(
decrypted_response_data
)

response = (
decrypted_response_data
if self.do_not_encrypt_response
Expand All @@ -405,3 +463,6 @@ async def _return_secure_passthrough_response(self, url: URL, json: dict[str, An
async def _return_send_response(self, url: URL, json: dict[str, Any]):
result = {"result": {"method": None}, "error_code": self.send_error_code}
return self._mock_response(self.status_code, result)

def put_next_response(self, request: dict | bytes) -> None:
self._next_responses.append(request)
Loading