Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ async def discover_single(
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
update_parent_devices: bool = True,
) -> SmartDevice:
"""Discover a single device by the given IP address.

Expand All @@ -271,8 +272,9 @@ async def discover_single(
:param host: Hostname of device to query
:param port: Optionally set a different port for the device
:param timeout: Timeout for discovery
:param credentials: Optionally provide credentials for
devices requiring them
:param credentials: Credentials for devices that require authentication
:param update_parent_devices: Automatically call device.update() on
devices that have children
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
Expand Down Expand Up @@ -330,7 +332,9 @@ async def discover_single(
if ip in protocol.discovered_devices:
dev = protocol.discovered_devices[ip]
dev.host = host
await dev.update()
# Call device update on devices that have children
if update_parent_devices and dev.has_children:
await dev.update()
return dev
elif ip in protocol.unsupported_devices:
raise UnsupportedDeviceException(
Expand Down
9 changes: 9 additions & 0 deletions kasa/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,15 @@ def model(self) -> str:
sys_info = self._sys_info
return str(sys_info["model"])

@property
def has_children(self) -> bool:
"""Return true if the device has children devices."""
# Ideally we would check for the 'child_num' key in sys_info,
# but devices that speak klap do not populate this key via
# update_from_discover_info so we check for the devices
# we know have children instead.
return self.is_strip

@property # type: ignore
@requires_update
def alias(self) -> str:
Expand Down
32 changes: 21 additions & 11 deletions kasa/tests/newfakes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import logging
import re

Expand Down Expand Up @@ -289,7 +290,7 @@ def __init__(self, info):
self.discovery_data = info
self.writer = None
self.reader = None
proto = FakeTransportProtocol.baseproto
proto = copy.deepcopy(FakeTransportProtocol.baseproto)

for target in info:
# print("target %s" % target)
Expand All @@ -298,16 +299,23 @@ def __init__(self, info):
proto[target][cmd] = info[target][cmd]
# if we have emeter support, we need to add the missing pieces
for module in ["emeter", "smartlife.iot.common.emeter"]:
for etype in ["get_realtime", "get_daystat", "get_monthstat"]:
if (
module in info and etype in info[module]
): # if the fixture has the data, use it
# print("got %s %s from fixture: %s" % (module, etype, info[module][etype]))
proto[module][etype] = info[module][etype]
else: # otherwise fall back to the static one
dummy_data = emeter_commands[module][etype]
# print("got %s %s from dummy: %s" % (module, etype, dummy_data))
proto[module][etype] = dummy_data
if (
module in info
and "err_code" in info[module]
and info[module]["err_code"] != 0
):
proto[module] = info[module]
else:
for etype in ["get_realtime", "get_daystat", "get_monthstat"]:
if (
module in info and etype in info[module]
): # if the fixture has the data, use it
# print("got %s %s from fixture: %s" % (module, etype, info[module][etype]))
proto[module][etype] = info[module][etype]
else: # otherwise fall back to the static one
dummy_data = emeter_commands[module][etype]
# print("got %s %s from dummy: %s" % (module, etype, dummy_data))
proto[module][etype] = dummy_data

# print("initialized: %s" % proto[module])

Expand Down Expand Up @@ -471,6 +479,8 @@ async def query(self, request, port=9999):
def get_response_for_module(target):
if target not in proto:
return error(msg="target not found")
if "err_code" in proto[target] and proto[target]["err_code"] != 0:
return {target: proto[target]}

def get_response_for_command(cmd):
if cmd not in proto[target]:
Expand Down
33 changes: 24 additions & 9 deletions kasa/tests/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@

import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342

from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
from kasa import (
DeviceType,
Discover,
SmartDevice,
SmartDeviceException,
SmartStrip,
protocol,
)
from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps
from kasa.exceptions import AuthenticationException, UnsupportedDeviceException

Expand Down Expand Up @@ -59,41 +66,45 @@ async def test_type_unknown():
async def test_discover_single(discovery_data: dict, mocker, custom_port):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info)

def mock_discover(self):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:],
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:],
(host, custom_port or 9999),
)

mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)

x = await Discover.discover_single(host, port=custom_port)
assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None
assert x.port == custom_port or x.port == 9999
assert (query_mock.call_count > 0) == isinstance(x, SmartStrip)


async def test_discover_single_hostname(discovery_data: dict, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "foobar"
ip = "127.0.0.1"
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info)

def mock_discover(self):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:],
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:],
(ip, 9999),
)

mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
mocker.patch("socket.getaddrinfo", return_value=[(None, None, None, None, (ip, 0))])

x = await Discover.discover_single(host)
assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None
assert x.host == host
assert (query_mock.call_count > 0) == isinstance(x, SmartStrip)

mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
with pytest.raises(SmartDeviceException):
Expand All @@ -104,14 +115,15 @@ def mock_discover(self):
async def test_connect_single(discovery_data: dict, mocker, custom_port):
"""Make sure that connect_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info)

dev = await Discover.connect_single(host, port=custom_port)
assert issubclass(dev.__class__, SmartDevice)
assert dev.port == custom_port or dev.port == 9999


async def test_connect_single_query_fails(discovery_data: dict, mocker):
async def test_connect_single_query_fails(mocker):
"""Make sure that connect_single fails when query fails."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException)
Expand Down Expand Up @@ -211,7 +223,8 @@ async def test_discover_send(mocker):
async def test_discover_datagram_received(mocker, discovery_data):
"""Verify that datagram received fills discovered_devices."""
proto = _DiscoverProtocol()
mocker.patch("kasa.discover.json_loads", return_value=discovery_data)
info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}}
mocker.patch("kasa.discover.json_loads", return_value=info)
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")

Expand Down Expand Up @@ -287,10 +300,12 @@ def mock_discover(self):
AuthenticationException,
match="Failed to authenticate",
):
await Discover.discover_single(host)
device = await Discover.discover_single(host)
await device.update()

mocker.patch.object(SmartDevice, "update")
device = await Discover.discover_single(host)
await device.update()
assert device.device_type == DeviceType.Plug


Expand Down
19 changes: 19 additions & 0 deletions kasa/tests/test_smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ async def test_childrens(dev):
assert len(dev.children) == 0


async def test_children(dev):
"""Make sure that children property is exposed by every device."""
if dev.is_strip:
assert len(dev.children) > 0
assert dev.has_children is True
else:
assert len(dev.children) == 0
assert dev.has_children is False


async def test_internal_state(dev):
"""Make sure the internal state returns the last update results."""
assert dev.internal_state == dev._last_update
Expand Down Expand Up @@ -203,3 +213,12 @@ async def test_create_smart_device_with_timeout():
"""Make sure timeout is passed to the protocol."""
dev = SmartDevice(host="127.0.0.1", timeout=100)
assert dev.protocol.timeout == 100


async def test_modules_not_supported(dev: SmartDevice):
"""Test that unsupported modules do not break the device."""
for module in dev.modules.values():
assert module.is_supported is not None
await dev.update()
for module in dev.modules.values():
assert module.is_supported is not None