Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion kasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
"""
from importlib_metadata import version # type: ignore
from kasa.discover import Discover
from kasa.exceptions import SmartDeviceException
from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb
from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice
from kasa.smartdimmer import SmartDimmer
from kasa.smartplug import SmartPlug
from kasa.smartstrip import SmartStrip
Expand Down
5 changes: 5 additions & 0 deletions kasa/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""python-kasa exceptions."""


class SmartDeviceException(Exception):
"""Base exception for device errors."""
79 changes: 49 additions & 30 deletions kasa/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from pprint import pformat as pf
from typing import Dict, Union

from .exceptions import SmartDeviceException

_LOGGER = logging.getLogger(__name__)


Expand All @@ -27,48 +29,65 @@ class TPLinkSmartHomeProtocol:
DEFAULT_TIMEOUT = 5

@staticmethod
async def query(host: str, request: Union[str, Dict]) -> Dict:
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Request information from a TP-Link SmartHome Device.

:param str host: host name or ip address of the device
:param request: command to send to the device (can be either dict or
json string)
:param retry_count: how many retries to do in case of failure
:return: response dict
"""
if isinstance(request, dict):
request = json.dumps(request)

timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
writer = None
try:
task = asyncio.open_connection(host, TPLinkSmartHomeProtocol.DEFAULT_PORT)
reader, writer = await asyncio.wait_for(task, timeout=timeout)
_LOGGER.debug("> (%i) %s", len(request), request)
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await writer.drain()

buffer = bytes()
# Some devices send responses with a length header of 0 and
# terminate with a zero size chunk. Others send the length and
# will hang if we attempt to read more data.
length = -1
while True:
chunk = await reader.read(4096)
if length == -1:
length = struct.unpack(">I", chunk[0:4])[0]
buffer += chunk
if (length > 0 and len(buffer) >= length + 4) or not chunk:
break
finally:
if writer:
writer.close()
await writer.wait_closed()

response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))

return json_payload
for retry in range(retry_count + 1):
try:
task = asyncio.open_connection(
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
reader, writer = await asyncio.wait_for(task, timeout=timeout)
_LOGGER.debug("> (%i) %s", len(request), request)
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await writer.drain()

buffer = bytes()
# Some devices send responses with a length header of 0 and
# terminate with a zero size chunk. Others send the length and
# will hang if we attempt to read more data.
length = -1
while True:
chunk = await reader.read(4096)
if length == -1:
length = struct.unpack(">I", chunk[0:4])[0]
buffer += chunk
if (length > 0 and len(buffer) >= length + 4) or not chunk:
break

response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))

return json_payload

except Exception as ex:
if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
"Unable to query the device: %s" % ex
) from ex

_LOGGER.debug("Unable to query the device, retrying: %s", ex)

finally:
if writer:
writer.close()
await writer.wait_closed()

# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")

@staticmethod
def encrypt(request: str) -> bytes:
Expand Down
7 changes: 2 additions & 5 deletions kasa/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from enum import Enum
from typing import Any, Dict, List, Optional

from kasa.protocol import TPLinkSmartHomeProtocol
from .exceptions import SmartDeviceException
from .protocol import TPLinkSmartHomeProtocol

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -47,10 +48,6 @@ class WifiNetwork:
rssi: Optional[int] = None


class SmartDeviceException(Exception):
"""Base exception for device errors."""


class EmeterStatus(dict):
"""Container for converting different representations of emeter data.

Expand Down
12 changes: 12 additions & 0 deletions kasa/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import os
from os.path import basename
from unittest.mock import MagicMock

import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342

Expand Down Expand Up @@ -151,3 +152,14 @@ def pytest_collection_modifyitems(config, items):
return
else:
print("Running against ip %s" % config.getoption("--ip"))


# allow mocks to be awaited
# https://stackoverflow.com/questions/51394411/python-object-magicmock-cant-be-used-in-await-expression/51399767#51399767


async def async_magic():
pass


MagicMock.__await__ = lambda x: async_magic().__await__()
150 changes: 86 additions & 64 deletions kasa/tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,95 @@
import json
from unittest import TestCase

import pytest

from ..exceptions import SmartDeviceException
from ..protocol import TPLinkSmartHomeProtocol


class TestTPLinkSmartHomeProtocol(TestCase):
def test_encrypt(self):
d = json.dumps({"foo": 1, "bar": 2})
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
# encrypt adds a 4 byte header
encrypted = encrypted[4:]
self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(encrypted))

def test_encrypt_unicode(self):
d = "{'snowman': '\u2603'}"

e = bytes(
[
208,
247,
132,
234,
133,
242,
159,
254,
144,
183,
141,
173,
138,
104,
240,
115,
84,
41,
]
)
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_retries(mocker, retry_count):
def aio_mock_writer(_, __):
reader = mocker.patch("asyncio.StreamReader")
writer = mocker.patch("asyncio.StreamWriter")

encrypted = TPLinkSmartHomeProtocol.encrypt(d)
# encrypt adds a 4 byte header
encrypted = encrypted[4:]

self.assertEqual(e, encrypted)

def test_decrypt_unicode(self):
e = bytes(
[
208,
247,
132,
234,
133,
242,
159,
254,
144,
183,
141,
173,
138,
104,
240,
115,
84,
41,
]
mocker.patch(
"asyncio.StreamWriter.write", side_effect=Exception("dummy exception")
)

d = "{'snowman': '\u2603'}"
return reader, writer

conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count)

assert conn.call_count == retry_count + 1


def test_encrypt():
d = json.dumps({"foo": 1, "bar": 2})
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
# encrypt adds a 4 byte header
encrypted = encrypted[4:]
assert d == TPLinkSmartHomeProtocol.decrypt(encrypted)


def test_encrypt_unicode():
d = "{'snowman': '\u2603'}"

e = bytes(
[
208,
247,
132,
234,
133,
242,
159,
254,
144,
183,
141,
173,
138,
104,
240,
115,
84,
41,
]
)

encrypted = TPLinkSmartHomeProtocol.encrypt(d)
# encrypt adds a 4 byte header
encrypted = encrypted[4:]

assert e == encrypted


def test_decrypt_unicode():
e = bytes(
[
208,
247,
132,
234,
133,
242,
159,
254,
144,
183,
141,
173,
138,
104,
240,
115,
84,
41,
]
)

d = "{'snowman': '\u2603'}"

self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(e))
assert d == TPLinkSmartHomeProtocol.decrypt(e)