Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3eaa7e2
Keep connection open and lock to prevent duplicate requests
bdraco Sep 24, 2021
67b40ec
option to not update children
bdraco Sep 24, 2021
d85aa01
tweaks
bdraco Sep 24, 2021
8fb60a7
typing
bdraco Sep 24, 2021
b51a655
tweaks
bdraco Sep 24, 2021
42bf715
run tests in the same event loop
bdraco Sep 24, 2021
ef8259e
memorize model
bdraco Sep 24, 2021
c116d46
Update kasa/protocol.py
bdraco Sep 24, 2021
56b5883
Update kasa/protocol.py
bdraco Sep 24, 2021
1d9cb92
Update kasa/protocol.py
bdraco Sep 24, 2021
fbb7268
Update kasa/protocol.py
bdraco Sep 24, 2021
37d0745
dry
bdraco Sep 24, 2021
5ada950
tweaks
bdraco Sep 24, 2021
fd2f388
warn when the event loop gets switched out from under us
bdraco Sep 24, 2021
5b9616d
raise on unable to connect multiple times
bdraco Sep 24, 2021
e617d16
fix patch target
bdraco Sep 24, 2021
7a4df1a
tweaks
bdraco Sep 24, 2021
07438d2
isrot
bdraco Sep 24, 2021
0d1c2a3
reconnect test
bdraco Sep 24, 2021
58fda47
prune
bdraco Sep 24, 2021
c4c533e
fix mocking
bdraco Sep 24, 2021
4a9279b
fix mocking
bdraco Sep 24, 2021
a2b38b4
fix test under python 3.7
bdraco Sep 24, 2021
053ccaa
fix test under python 3.7
bdraco Sep 24, 2021
3c453cf
less patching
bdraco Sep 24, 2021
2ba83c0
isort
bdraco Sep 24, 2021
0d690f6
use mocker to patch
bdraco Sep 24, 2021
22ba405
disable on old python since mocking doesnt work
bdraco Sep 24, 2021
4b369c7
avoid disconnect/reconnect cycles
bdraco Sep 24, 2021
fc2b637
isort
bdraco Sep 24, 2021
e11f813
Fix hue validation
bdraco Sep 24, 2021
a4b92b2
Fix latitude_i/longitude_i units
bdraco Sep 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions devtools/dump_devinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,17 @@ def cli(host, debug):
),
]

protocol = TPLinkSmartHomeProtocol()

successes = []

for test_call in items:

async def _run_query():
protocol = TPLinkSmartHomeProtocol(host)
return await protocol.query({test_call.module: {test_call.method: None}})

try:
click.echo(f"Testing {test_call}..", nl=False)
info = asyncio.run(
protocol.query(host, {test_call.module: {test_call.method: None}})
)
info = asyncio.run(_run_query())
resp = info[test_call.module]
except Exception as ex:
click.echo(click.style(f"FAIL {ex}", fg="red"))
Expand All @@ -107,8 +108,12 @@ def cli(host, debug):

final = default_to_regular(final)

async def _run_final_query():
protocol = TPLinkSmartHomeProtocol(host)
return await protocol.query(final_query)

try:
final = asyncio.run(protocol.query(host, final_query))
final = asyncio.run(_run_final_query())
except Exception as ex:
click.echo(
click.style(
Expand Down
9 changes: 4 additions & 5 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(
self.discovery_packets = discovery_packets
self.interface = interface
self.on_discovered = on_discovered
self.protocol = TPLinkSmartHomeProtocol()
self.target = (target, Discover.DISCOVERY_PORT)
self.discovered_devices = {}

Expand All @@ -61,7 +60,7 @@ def do_discover(self) -> None:
"""Send number of discovery datagrams."""
req = json.dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = self.protocol.encrypt(req)
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
for i in range(self.discovery_packets):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore

Expand All @@ -71,7 +70,7 @@ def datagram_received(self, data, addr) -> None:
if ip in self.discovered_devices:
return

info = json.loads(self.protocol.decrypt(data))
info = json.loads(TPLinkSmartHomeProtocol.decrypt(data))
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)

device_class = Discover._get_device_class(info)
Expand Down Expand Up @@ -190,9 +189,9 @@ async def discover_single(host: str) -> SmartDevice:
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
protocol = TPLinkSmartHomeProtocol()
protocol = TPLinkSmartHomeProtocol(host)

info = await protocol.query(host, Discover.DISCOVERY_QUERY)
info = await protocol.query(Discover.DISCOVERY_QUERY)

device_class = Discover._get_device_class(info)
dev = device_class(host)
Expand Down
140 changes: 104 additions & 36 deletions kasa/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
http://www.apache.org/licenses/LICENSE-2.0
"""
import asyncio
import contextlib
import json
import logging
import struct
from pprint import pformat as pf
from typing import Dict, Union
from typing import Dict, Optional, Union

from .exceptions import SmartDeviceException

Expand All @@ -28,8 +29,26 @@ class TPLinkSmartHomeProtocol:
DEFAULT_PORT = 9999
DEFAULT_TIMEOUT = 5

@staticmethod
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
BLOCK_SIZE = 4

def __init__(self, host: str) -> None:
"""Create a protocol object."""
self.host = host
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self.query_lock: Optional[asyncio.Lock] = None
self.loop: Optional[asyncio.AbstractEventLoop] = None

def _detect_event_loop_change(self) -> None:
"""Check if this object has been reused betwen event loops."""
loop = asyncio.get_running_loop()
if not self.loop:
self.loop = loop
elif self.loop != loop:
_LOGGER.warning("Detected protocol reuse between different event loop")
self._reset()

async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Request information from a TP-Link SmartHome Device.

:param str host: host name or ip address of the device
Expand All @@ -38,57 +57,106 @@ async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> D
:param retry_count: how many retries to do in case of failure
:return: response dict
"""
self._detect_event_loop_change()

if not self.query_lock:
self.query_lock = asyncio.Lock()

if isinstance(request, dict):
request = json.dumps(request)
assert isinstance(request, str)

timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
writer = None

async with self.query_lock:
return await self._query(request, retry_count, timeout)

async def _connect(self, timeout: int) -> bool:
"""Try to connect or reconnect to the device."""
if self.writer:
return True

with contextlib.suppress(Exception):
self.reader = self.writer = None
task = asyncio.open_connection(
self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
return True

return False

async def _execute_query(self, request: str) -> Dict:
"""Execute a query on the device and wait for the response."""
assert self.writer is not None
assert self.reader is not None

_LOGGER.debug("> (%i) %s", len(request), request)
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await self.writer.drain()

packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
length = struct.unpack(">I", packed_block_size)[0]

buffer = await self.reader.readexactly(length)
response = TPLinkSmartHomeProtocol.decrypt(buffer)
json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
return json_payload

async def close(self):
"""Close the connection."""
writer = self.writer
self._reset()
if writer:
writer.close()
with contextlib.suppress(Exception):
await writer.wait_closed()

def _reset(self):
"""Clear any varibles that should not survive between loops."""
self.writer = None
self.reader = None
self.query_lock = None
self.loop = None

async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
"""Try to query a device."""
for retry in range(retry_count + 1):
if not await self._connect(timeout):
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}"
)
continue

try:
task = asyncio.open_connection(
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
assert self.reader is not None
assert self.writer is not None
return await asyncio.wait_for(
self._execute_query(request), timeout=timeout
)
reader, writer = await asyncio.wait_for(task, timeout=timeout)
_LOGGER.debug("> (%i) %s", len(request), request)
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await writer.drain()

buffer = bytes()
# Some devices send responses with a length header of 0 and
# terminate with a zero size chunk. Others send the length and
# will hang if we attempt to read more data.
length = -1
while True:
chunk = await reader.read(4096)
if length == -1:
length = struct.unpack(">I", chunk[0:4])[0]
buffer += chunk
if (length > 0 and len(buffer) >= length + 4) or not chunk:
break

response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))

return json_payload

except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
"Unable to query the device: %s" % ex
f"Unable to query the device: {ex}"
) from ex

_LOGGER.debug("Unable to query the device, retrying: %s", ex)

finally:
if writer:
writer.close()
await writer.wait_closed()

# make mypy happy, this should never be reached..
await self.close()
raise SmartDeviceException("Query reached somehow to unreachable")

def __del__(self):
if self.writer and self.loop and self.loop.is_running():
self.writer.close()
self._reset()

@staticmethod
def _xor_payload(unencrypted):
key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR
Expand Down
14 changes: 7 additions & 7 deletions kasa/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def __init__(self, host: str) -> None:
"""
self.host = host

self.protocol = TPLinkSmartHomeProtocol()
self.protocol = TPLinkSmartHomeProtocol(host)
self.emeter_type = "emeter"
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
self._device_type = DeviceType.Unknown
Expand Down Expand Up @@ -234,7 +234,7 @@ async def _query_helper(
request = self._create_request(target, cmd, arg, child_ids)

try:
response = await self.protocol.query(host=self.host, request=request)
response = await self.protocol.query(request=request)
except Exception as ex:
raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex

Expand Down Expand Up @@ -272,7 +272,7 @@ async def get_sys_info(self) -> Dict[str, Any]:
"""Retrieve system information."""
return await self._query_helper("system", "get_sysinfo")

async def update(self):
async def update(self, update_children: bool = True):
"""Query the device to update the data.

Needed for properties that are decorated with `requires_update`.
Expand All @@ -285,7 +285,7 @@ async def update(self):
# See #105, #120, #161
if self._last_update is None:
_LOGGER.debug("Performing the initial update to obtain sysinfo")
self._last_update = await self.protocol.query(self.host, req)
self._last_update = await self.protocol.query(req)
self._sys_info = self._last_update["system"]["get_sysinfo"]
# If the device has no emeter, we are done for the initial update
# Otherwise we will follow the regular code path to also query
Expand All @@ -299,7 +299,7 @@ async def update(self):
)
req.update(self._create_emeter_request())

self._last_update = await self.protocol.query(self.host, req)
self._last_update = await self.protocol.query(req)
self._sys_info = self._last_update["system"]["get_sysinfo"]

def update_from_discover_info(self, info):
Expand Down Expand Up @@ -383,8 +383,8 @@ def location(self) -> Dict:
loc["latitude"] = sys_info["latitude"]
loc["longitude"] = sys_info["longitude"]
elif "latitude_i" in sys_info and "longitude_i" in sys_info:
loc["latitude"] = sys_info["latitude_i"]
loc["longitude"] = sys_info["longitude_i"]
loc["latitude"] = sys_info["latitude_i"] / 10000
loc["longitude"] = sys_info["longitude_i"] / 10000
else:
_LOGGER.warning("Unsupported device location.")

Expand Down
10 changes: 5 additions & 5 deletions kasa/smartstrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ def is_on(self) -> bool:
"""Return if any of the outlets are on."""
return any(plug.is_on for plug in self.children)

async def update(self):
async def update(self, update_children: bool = True):
"""Update some of the attributes.

Needed for methods that are decorated with `requires_update`.
"""
await super().update()
await super().update(update_children)

# Initialize the child devices during the first update.
if not self.children:
Expand All @@ -103,7 +103,7 @@ async def update(self):
SmartStripPlug(self.host, parent=self, child_id=child["id"])
)

if self.has_emeter:
if update_children and self.has_emeter:
for plug in self.children:
await plug.update()

Expand Down Expand Up @@ -243,13 +243,13 @@ def __init__(self, host: str, parent: "SmartStrip", child_id: str) -> None:
self._sys_info = parent._sys_info
self._device_type = DeviceType.StripSocket

async def update(self):
async def update(self, update_children: bool = True):
"""Query the device to update the data.

Needed for properties that are decorated with `requires_update`.
"""
self._last_update = await self.parent.protocol.query(
self.host, self._create_emeter_request()
self._create_emeter_request()
)

def _create_request(
Expand Down
Loading