Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions kasa/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""
import asyncio
import contextlib
import errno
import json
import logging
import struct
Expand All @@ -20,6 +21,7 @@
from .exceptions import SmartDeviceException

_LOGGER = logging.getLogger(__name__)
_NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED}


class TPLinkSmartHomeProtocol:
Expand Down Expand Up @@ -115,9 +117,30 @@ def _reset(self) -> None:

async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
"""Try to query a device."""
#
# Most of the time we will already be connected if the device is online
# and the connect call will do nothing and return right away
#
# However, if we get an unrecoverable error (_NO_RETRY_ERRORS and ConnectionRefusedError)
# we do not want to keep trying since many connection open/close operations
# in the same time frame can block the event loop. This is especially
# import when there are multiple tplink devices being polled.
#
for retry in range(retry_count + 1):
try:
await self._connect(timeout)
except ConnectionRefusedError as ex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
)
except OSError as ex:
await self.close()
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
)
continue
except Exception as ex:
await self.close()
if retry >= retry_count:
Expand Down
34 changes: 34 additions & 0 deletions kasa/tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import errno
import json
import logging
import struct
Expand Down Expand Up @@ -29,6 +30,39 @@ def aio_mock_writer(_, __):
assert conn.call_count == retry_count + 1


async def test_protocol_no_retry_on_unreachable(mocker):
conn = mocker.patch(
"asyncio.open_connection",
side_effect=OSError(errno.EHOSTUNREACH, "No route to host"),
)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)

assert conn.call_count == 1


async def test_protocol_no_retry_connection_refused(mocker):
conn = mocker.patch(
"asyncio.open_connection",
side_effect=ConnectionRefusedError,
)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)

assert conn.call_count == 1


async def test_protocol_retry_recoverable_error(mocker):
conn = mocker.patch(
"asyncio.open_connection",
side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"),
)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=5)

assert conn.call_count == 6


@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
@pytest.mark.parametrize("retry_count", [1, 3, 5])
async def test_protocol_reconnect(mocker, retry_count):
Expand Down