Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions kasa/smart/modules/childdevicemodule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Implementation for child devices."""
from typing import Dict

from ..smartmodule import SmartModule

Expand All @@ -8,12 +7,4 @@ class ChildDeviceModule(SmartModule):
"""Implementation for child devices."""

REQUIRED_COMPONENT = "child_device"

def query(self) -> Dict:
"""Query to execute during the update cycle."""
# TODO: There is no need to fetch the component list every time,
# so this should be optimized only for the init.
return {
"get_child_device_list": None,
"get_child_device_component_list": None,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_child_device_component_list might be needed back in if we want to support device pairing w/o restarts in the future.

}
QUERY_GETTER_NAME = "get_child_device_list"
43 changes: 29 additions & 14 deletions kasa/smart/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,18 @@ def __init__(
self.modules: Dict[str, "SmartModule"] = {}
self._parent: Optional["SmartDevice"] = None
self._children: Mapping[str, "SmartDevice"] = {}
self._last_update = {}

async def _initialize_children(self):
"""Initialize children for power strips."""
children = self.internal_state["child_info"]["child_device_list"]
child_info_query = {
"get_child_device_component_list": None,
"get_child_device_list": None,
}
resp = await self.protocol.query(child_info_query)
self.internal_state.update(resp)

children = self.internal_state["get_child_device_list"]["child_device_list"]
children_components = {
child["device_id"]: {
comp["id"]: int(comp["ver_code"]) for comp in child["component_list"]
Expand Down Expand Up @@ -88,13 +96,30 @@ def _try_get_response(self, responses: dict, request: str, default=None) -> dict
)

async def _negotiate(self):
resp = await self.protocol.query("component_nego")
"""Perform initialization.

We fetch the device info and the available components as early as possible.
If the device reports supporting child devices, they are also initialized.
"""
initial_query = {"component_nego": None, "get_device_info": None}
resp = await self.protocol.query(initial_query)

# Save the initial state to allow modules access the device info already
# during the initialization, which is necessary as some information like the
# supported color temperature range is contained within the response.
self._last_update.update(resp)
self._info = self._try_get_response(resp, "get_device_info")

# Create our internal presentation of available components
self._components_raw = resp["component_nego"]
self._components = {
comp["id"]: int(comp["ver_code"])
for comp in self._components_raw["component_list"]
}

if "child_device" in self._components and not self.children:
await self._initialize_children()

async def update(self, update_children: bool = True):
"""Update the device."""
if self.credentials is None and self.credentials_hash is None:
Expand All @@ -110,20 +135,10 @@ async def update(self, update_children: bool = True):
for module in self.modules.values():
req.update(module.query())

resp = await self.protocol.query(req)
self._last_update = resp = await self.protocol.query(req)

self._info = self._try_get_response(resp, "get_device_info")

self._last_update = {
"components": self._components_raw,
**resp,
"child_info": self._try_get_response(resp, "get_child_device_list", {}),
}

if child_info := self._last_update.get("child_info"):
if not self.children:
await self._initialize_children()

if child_info := self._try_get_response(resp, "get_child_device_list", {}):
# TODO: we don't currently perform queries on children based on modules,
# but just update the information that is returned in the main query.
for info in child_info["child_device_list"]:
Expand Down
2 changes: 1 addition & 1 deletion kasa/tests/test_childdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_childdevice_init(dev, dummy_protocol, mocker):
@strip_smart
async def test_childdevice_update(dev, dummy_protocol, mocker):
"""Test that parent update updates children."""
child_info = dev._last_update["child_info"]
child_info = dev.internal_state["get_child_device_list"]
child_list = child_info["child_device_list"]

assert len(dev.children) == child_info["sum"]
Expand Down
79 changes: 73 additions & 6 deletions kasa/tests/test_smartdevice.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Tests for SMART devices."""
import logging
from unittest.mock import patch
from typing import Any, Dict

import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
import pytest
from pytest_mock import MockerFixture

from kasa import KasaException
from kasa.exceptions import SmartErrorCode
Expand All @@ -25,13 +26,79 @@ async def test_try_get_response(dev: SmartDevice, caplog):


@device_smart
async def test_update_no_device_info(dev: SmartDevice):
async def test_update_no_device_info(dev: SmartDevice, mocker: MockerFixture):
mock_response: dict = {
"get_device_usage": {},
"get_device_time": {},
}
msg = f"get_device_info not found in {mock_response} for device 127.0.0.123"
with patch.object(dev.protocol, "query", return_value=mock_response), pytest.raises(
KasaException, match=msg
):
with mocker.patch.object(
dev.protocol, "query", return_value=mock_response
), pytest.raises(KasaException, match=msg):
await dev.update()


@device_smart
async def test_initial_update(dev: SmartDevice, mocker: MockerFixture):
"""Test the initial update cycle."""
# As the fixture data is already initialized, we reset the state for testing
dev._components_raw = None
dev._features = {}

negotiate = mocker.spy(dev, "_negotiate")
initialize_modules = mocker.spy(dev, "_initialize_modules")
initialize_features = mocker.spy(dev, "_initialize_features")

# Perform two updates and verify that initialization is only done once
await dev.update()
await dev.update()

negotiate.assert_called_once()
assert dev._components_raw is not None
initialize_modules.assert_called_once()
assert dev.modules
initialize_features.assert_called_once()
assert dev.features


@device_smart
async def test_negotiate(dev: SmartDevice, mocker: MockerFixture):
"""Test that the initial negotiation performs expected steps."""
# As the fixture data is already initialized, we reset the state for testing
dev._components_raw = None
dev._children = {}

query = mocker.spy(dev.protocol, "query")
initialize_children = mocker.spy(dev, "_initialize_children")
await dev._negotiate()

# Check that we got the initial negotiation call
query.assert_any_call({"component_nego": None, "get_device_info": None})
assert dev._components_raw

# Check the children are created, if device supports them
if "child_device" in dev._components:
initialize_children.assert_called_once()
query.assert_any_call(
{
"get_child_device_component_list": None,
"get_child_device_list": None,
}
)
assert len(dev.children) == dev.internal_state["get_child_device_list"]["sum"]


@device_smart
async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
"""Test that the regular update uses queries from all supported modules."""
query = mocker.spy(dev.protocol, "query")

# We need to have some modules initialized by now
assert dev.modules

await dev.update()
full_query: Dict[str, Any] = {}
for mod in dev.modules.values():
full_query |= mod.query()

query.assert_called_with(full_query)