Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions kasa/aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Based on the work of https://github.com/petretiandrea/plugp100
under compatible GNU GPL3 license.
"""

import asyncio
import base64
import hashlib
import logging
Expand Down Expand Up @@ -39,6 +39,7 @@

ONE_DAY_SECONDS = 86400
SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20
BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1


def _sha1(payload: bytes) -> str:
Expand Down Expand Up @@ -184,8 +185,24 @@ async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
assert self._encryption_session is not None

raw_response: str = resp_dict["result"]["response"]
response = self._encryption_session.decrypt(raw_response.encode())
return json_loads(response) # type: ignore[return-value]

try:
response = self._encryption_session.decrypt(raw_response.encode())
ret_val = json_loads(response)
except Exception as ex:
try:
ret_val = json_loads(raw_response)
_LOGGER.debug(
"Received unencrypted response over secure passthrough from %s",
self._host,
)
except Exception:
raise SmartDeviceException(
f"Unable to decrypt response from {self._host}, "
+ f"error: {ex}, response: {raw_response}",
ex,
) from ex
return ret_val # type: ignore[return-value]

async def perform_login(self):
"""Login to the device."""
Expand All @@ -199,6 +216,7 @@ async def perform_login(self):
self._default_credentials = get_default_credentials(
DEFAULT_CREDENTIALS["TAPO"]
)
await asyncio.sleep(BACKOFF_SECONDS_AFTER_LOGIN_ERROR)
await self.perform_handshake()
await self.try_login(self._get_login_params(self._default_credentials))
_LOGGER.debug(
Expand Down
12 changes: 9 additions & 3 deletions kasa/smart/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def update(self, update_children: bool = True):
resp = await self.protocol.query("component_nego")
self._components_raw = resp["component_nego"]
self._components = {
comp["id"]: comp["ver_code"]
comp["id"]: int(comp["ver_code"])
for comp in self._components_raw["component_list"]
}
await self._initialize_modules()
Expand All @@ -86,18 +86,24 @@ async def update(self, update_children: bool = True):
"get_current_power": None,
}

if self._components["device"] >= 2:
extra_reqs = {
**extra_reqs,
"get_device_usage": None,
}

req = {
"get_device_info": None,
"get_device_usage": None,
"get_device_time": None,
**extra_reqs,
}

resp = await self.protocol.query(req)

self._info = resp["get_device_info"]
self._usage = resp["get_device_usage"]
self._time = resp["get_device_time"]
# Device usage is not available on older firmware versions
self._usage = resp.get("get_device_usage", {})
# Emeter is not always available, but we set them still for now.
self._energy = resp.get("get_energy_usage", {})
self._emeter = resp.get("get_current_power", {})
Expand Down
1 change: 1 addition & 0 deletions kasa/smartprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except TimeoutException as ex:
await self._transport.reset()
Expand Down
94 changes: 86 additions & 8 deletions kasa/tests/test_aestransport.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import json
import logging
import random
import string
import time
Expand Down Expand Up @@ -180,6 +181,67 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
assert "result" in res


async def test_unencrypted_response(mocker, caplog):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, 0, 0, do_not_encrypt_response=True)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)

transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._state = TransportState.ESTABLISHED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)

request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
caplog.set_level(logging.DEBUG)
res = await transport.send(json_dumps(request))
assert "result" in res
assert (
"Received unencrypted response over secure passthrough from 127.0.0.1"
in caplog.text
)


async def test_unencrypted_response_invalid_json(mocker, caplog):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(
host, 200, 0, 0, do_not_encrypt_response=True, send_response=b"Foobar"
)
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)

transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._state = TransportState.ESTABLISHED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._token_url = transport._app_url.with_query(
f"token={mock_aes_device.token}"
)

request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
caplog.set_level(logging.DEBUG)
msg = f"Unable to decrypt response from {host}, error: Incorrect padding, response: Foobar"
with pytest.raises(SmartDeviceException, match=msg):
await transport.send(json_dumps(request))


ERRORS = [e for e in SmartErrorCode if e != 0]


Expand Down Expand Up @@ -233,15 +295,28 @@ async def __aexit__(self, exc_t, exc_v, exc_tb):
pass

async def read(self):
return json_dumps(self._json).encode()
if isinstance(self._json, dict):
return json_dumps(self._json).encode()
return self._json

encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:])

def __init__(self, host, status_code=200, error_code=0, inner_error_code=0):
def __init__(
self,
host,
status_code=200,
error_code=0,
inner_error_code=0,
*,
do_not_encrypt_response=False,
send_response=None,
):
self.host = host
self.status_code = status_code
self.error_code = error_code
self._inner_error_code = inner_error_code
self.do_not_encrypt_response = do_not_encrypt_response
self.send_response = send_response
self.http_client = HttpClient(DeviceConfig(self.host))
self.inner_call_count = 0
self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311
Expand Down Expand Up @@ -289,13 +364,15 @@ async def _return_secure_passthrough_response(self, url: URL, json: Dict[str, An
decrypted_request_dict = json_loads(decrypted_request)
decrypted_response = await self._post(url, decrypted_request_dict)
async with decrypted_response:
response_data = await decrypted_response.read()
decrypted_response_dict = json_loads(response_data.decode())
encrypted_response = self.encryption_session.encrypt(
json_dumps(decrypted_response_dict).encode()
decrypted_response_data = await decrypted_response.read()
encrypted_response = self.encryption_session.encrypt(decrypted_response_data)
response = (
decrypted_response_data
if self.do_not_encrypt_response
else encrypted_response
)
result = {
"result": {"response": encrypted_response.decode()},
"result": {"response": response.decode()},
"error_code": self.error_code,
}
return self._mock_response(self.status_code, result)
Expand All @@ -310,5 +387,6 @@ async def _return_login_response(self, url: URL, json: Dict[str, Any]):

async def _return_send_response(self, url: URL, json: Dict[str, Any]):
result = {"result": {"method": None}, "error_code": self.inner_error_code}
response = self.send_response if self.send_response else result
self.inner_call_count += 1
return self._mock_response(self.status_code, result)
return self._mock_response(self.status_code, response)