Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions kasa/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,13 @@ async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
async with self.query_lock:
return await self._query(request, retry_count, timeout)

async def _connect(self, timeout: int) -> bool:
async def _connect(self, timeout: int) -> None:
"""Try to connect or reconnect to the device."""
if self.writer:
return True

with contextlib.suppress(Exception):
self.reader = self.writer = None
task = asyncio.open_connection(
self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
return True

return False
return
self.reader = self.writer = None
task = asyncio.open_connection(self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT)
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)

async def _execute_query(self, request: str) -> Dict:
"""Execute a query on the device and wait for the response."""
Expand Down Expand Up @@ -123,12 +116,14 @@ def _reset(self) -> None:
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
"""Try to query a device."""
for retry in range(retry_count + 1):
if not await self._connect(timeout):
try:
await self._connect(timeout)
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}"
f"Unable to connect to the device: {self.host}: {ex}"
)
continue

Expand Down
8 changes: 6 additions & 2 deletions kasa/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,11 @@ async def update(self, update_children: bool = True):
self._last_update = await self.protocol.query(req)
self._sys_info = self._last_update["system"]["get_sysinfo"]

await self._modular_update(req)
self._sys_info = self._last_update["system"]["get_sysinfo"]

async def _modular_update(self, req: dict) -> None:
"""Execute an update query."""
if self.has_emeter:
_LOGGER.debug(
"The device has emeter, querying its information along sysinfo"
Expand All @@ -326,10 +331,9 @@ async def update(self, update_children: bool = True):
continue
q = module.query()
_LOGGER.debug("Adding query for %s: %s", module, q)
req = merge(req, module.query())
req = merge(req, q)

self._last_update = await self.protocol.query(req)
self._sys_info = self._last_update["system"]["get_sysinfo"]

def update_from_discover_info(self, info):
"""Update state from info from the discover call."""
Expand Down
11 changes: 6 additions & 5 deletions kasa/smartstrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, DefaultDict, Dict, Optional

import asyncio
from kasa.smartdevice import (
DeviceType,
EmeterStatus,
SmartDevice,
SmartDeviceException,
requires_update,
merge,
)
from kasa.smartplug import SmartPlug

Expand Down Expand Up @@ -250,16 +251,16 @@ def __init__(self, host: str, parent: "SmartStrip", child_id: str) -> None:
self._last_update = parent._last_update
self._sys_info = parent._sys_info
self._device_type = DeviceType.StripSocket
self.modules = {}
self.protocol = parent.protocol # Must use the same connection as the parent
self.add_module("time", Time(self, "time"))

async def update(self, update_children: bool = True):
"""Query the device to update the data.

Needed for properties that are decorated with `requires_update`.
"""
# TODO: it needs to be checked if this still works after modularization
self._last_update = await self.parent.protocol.query(
self._create_emeter_request()
)
await self._modular_update({})

def _create_emeter_request(self, year: int = None, month: int = None):
"""Create a request for requesting all emeter statistics at once."""
Expand Down
5 changes: 3 additions & 2 deletions kasa/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,14 @@ def _generate_type_class_pairs():


@pytest.mark.parametrize("type_class", _generate_type_class_pairs())
async def test_deprecated_type(dev, type_class):
async def test_deprecated_type(dev, type_class, mocker):
"""Make sure that using deprecated types yields a warning."""
type, cls = type_class
if type == "dimmer":
return
runner = CliRunner()
res = await runner.invoke(cli, ["--host", "127.0.0.2", f"--{type}"])
with mocker.patch("kasa.SmartDevice.update"):
res = await runner.invoke(cli, ["--host", "127.0.0.2", f"--{type}"])
assert "Using --bulb, --plug, --strip, and --lightstrip is deprecated" in res.output


Expand Down
4 changes: 3 additions & 1 deletion kasa/tests/test_smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ async def test_initial_update_no_emeter(dev, mocker):
dev._last_update = None
spy = mocker.spy(dev.protocol, "query")
await dev.update()
assert spy.call_count == 1
# 2 calls are necessary as some devices crash on unexpected modules
# See #105, #120, #161
assert spy.call_count == 2


async def test_query_helper(dev):
Expand Down