Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions kasa/cli/hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Hub-specific commands."""

import asyncio

import asyncclick as click

from kasa import DeviceType, Module, SmartDevice
from kasa.smart import SmartChildDevice

from .common import (
echo,
error,
pass_dev,
)


def pretty_category(cat: str):
"""Return pretty category for paired devices."""
return SmartChildDevice.CHILD_DEVICE_TYPE_MAP.get(cat)


@click.group()
@pass_dev
async def hub(dev: SmartDevice):
"""Commands controlling hub child device pairing."""
if dev.device_type is not DeviceType.Hub:
error(f"{dev} is not a hub.")

if dev.modules.get(Module.ChildSetup) is None:
error(f"{dev} does not have child setup module.")


@hub.command(name="list")
@pass_dev
async def hub_list(dev: SmartDevice):
"""List hub paired child devices."""
for c in dev.children:
echo(f"{c.device_id}: {c}")


@hub.command(name="supported")
@pass_dev
async def hub_supported(dev: SmartDevice):
"""List supported hub child device categories."""
cs = dev.modules[Module.ChildSetup]

cats = [cat["category"] for cat in await cs.get_supported_device_categories()]
for cat in cats:
echo(f"Supports: {cat}")


@hub.command(name="pair")
@click.option("--timeout", default=10)
@pass_dev
async def hub_pair(dev: SmartDevice, timeout: int):
"""Pair all pairable device.

This will pair any child devices currently in pairing mode.
"""
cs = dev.modules[Module.ChildSetup]

echo(f"Finding new devices for {timeout} seconds...")

pair_res = await cs.pair(timeout=timeout)
if not pair_res:
echo("No devices found.")

for child in pair_res:
echo(
f'Paired {child["name"]} ({child["device_model"]}, '
f'{pretty_category(child["category"])}) with id {child["device_id"]}'
)


@hub.command(name="unpair")
@click.argument("device_id")
@pass_dev
async def hub_unpair(dev, device_id: str):
"""Unpair given device."""
cs = dev.modules[Module.ChildSetup]

# Accessing private here, as the property exposes only values
if device_id not in dev._children:
error(f"{dev} does not have children with identifier {device_id}")

res = await cs.unpair(device_id=device_id)
# Give the device some time to update its internal state, just in case.
await asyncio.sleep(1)
await dev.update()

if device_id not in dev._children:
echo(f"Unpaired {device_id}")
else:
error(f"Failed to unpair {device_id}")

return res
1 change: 1 addition & 0 deletions kasa/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def _legacy_type_to_class(_type: str) -> Any:
"hsv": "light",
"temperature": "light",
"effect": "light",
"hub": "hub",
},
result_callback=json_formatter_cb,
)
Expand Down
3 changes: 2 additions & 1 deletion kasa/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@

if TYPE_CHECKING:
from .device import Device
from .module import Module

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -142,7 +143,7 @@ class Category(Enum):
#: Callable coroutine or name of the method that allows changing the value
attribute_setter: str | Callable[..., Coroutine[Any, Any, Any]] | None = None
#: Container storing the data, this overrides 'device' for getters
container: Any = None
container: Device | Module | None = None
#: Icon suggestion
icon: str | None = None
#: Attribute containing the name of the unit getter property.
Expand Down
1 change: 1 addition & 0 deletions kasa/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class Module(ABC):
)
ChildLock: Final[ModuleName[smart.ChildLock]] = ModuleName("ChildLock")
TriggerLogs: Final[ModuleName[smart.TriggerLogs]] = ModuleName("TriggerLogs")
ChildSetup: Final[ModuleName[smart.ChildSetup]] = ModuleName("ChildSetup")

HomeKit: Final[ModuleName[smart.HomeKit]] = ModuleName("HomeKit")
Matter: Final[ModuleName[smart.Matter]] = ModuleName("Matter")
Expand Down
2 changes: 2 additions & 0 deletions kasa/smart/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .childdevice import ChildDevice
from .childlock import ChildLock
from .childprotection import ChildProtection
from .childsetup import ChildSetup
from .clean import Clean
from .cloud import Cloud
from .color import Color
Expand Down Expand Up @@ -47,6 +48,7 @@
"DeviceModule",
"ChildDevice",
"ChildLock",
"ChildSetup",
"BatterySensor",
"HumiditySensor",
"TemperatureSensor",
Expand Down
84 changes: 84 additions & 0 deletions kasa/smart/modules/childsetup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Implementation for child device setup.

This module allows pairing and disconnecting child devices.
"""

from __future__ import annotations

import asyncio
import logging

from ...feature import Feature
from ..smartmodule import SmartModule

_LOGGER = logging.getLogger(__name__)


class ChildSetup(SmartModule):
"""Implementation for child device setup."""

REQUIRED_COMPONENT = "child_quick_setup"
QUERY_GETTER_NAME = "get_support_child_device_category"

def _initialize_features(self) -> None:
"""Initialize features."""
self._add_feature(
Feature(
self._device,
id="pair",
name="Pair",
container=self,
attribute_setter="pair",
category=Feature.Category.Config,
type=Feature.Type.Action,
)
)

async def get_supported_device_categories(self) -> list[dict]:
"""Get supported device categories."""
categories = await self.call("get_support_child_device_category")
return categories["get_support_child_device_category"]["device_category_list"]

async def pair(self, *, timeout: int = 10) -> list[dict]:
"""Scan for new devices and pair after discovering first new device."""
await self.call("begin_scanning_child_device")

_LOGGER.info("Waiting %s seconds for discovering new devices", timeout)
await asyncio.sleep(timeout)
detected = await self._get_detected_devices()

if not detected["child_device_list"]:
_LOGGER.info("No devices found.")
return []

_LOGGER.info(
"Discovery done, found %s devices: %s",
len(detected["child_device_list"]),
detected,
)

await self._add_devices(detected)

return detected["child_device_list"]

async def unpair(self, device_id: str) -> dict:
"""Remove device from the hub."""
_LOGGER.debug("Going to unpair %s from %s", device_id, self)

payload = {"child_device_list": [{"device_id": device_id}]}
return await self.call("remove_child_device_list", payload)

async def _add_devices(self, devices: dict) -> dict:
"""Add devices based on get_detected_device response.

Pass the output from :ref:_get_detected_devices: as a parameter.
"""
res = await self.call("add_child_device_list", devices)
return res

async def _get_detected_devices(self) -> dict:
"""Return list of devices detected during scanning."""
param = {"scan_list": await self.get_supported_device_categories()}
res = await self.call("get_scan_child_device_list", param)
_LOGGER.debug("Scan status: %s", res)
return res["get_scan_child_device_list"]
15 changes: 15 additions & 0 deletions kasa/smart/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,21 @@ async def _initialize_features(self) -> None:
)
)

if self.parent is not None and (
cs := self.parent.modules.get(Module.ChildSetup)
):
self._add_feature(
Feature(
device=self,
id="unpair",
name="Unpair device",
container=cs,
attribute_setter=lambda: cs.unpair(self.device_id),
category=Feature.Category.Debug,
type=Feature.Type.Action,
)
)

for module in self.modules.values():
module._initialize_features()
for feat in module._module_features.values():
Expand Down
Empty file added tests/cli/__init__.py
Empty file.
53 changes: 53 additions & 0 deletions tests/cli/test_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
from pytest_mock import MockerFixture

from kasa import DeviceType, Module
from kasa.cli.hub import hub

from ..device_fixtures import HUBS_SMART, hubs_smart, parametrize, plug_iot


@hubs_smart
async def test_hub_pair(dev, mocker: MockerFixture, runner, caplog):
"""Test that pair calls the expected methods."""
cs = dev.modules.get(Module.ChildSetup)
# Patch if the device supports the module
if cs is not None:
mock_pair = mocker.patch.object(cs, "pair")

res = await runner.invoke(hub, ["pair"], obj=dev, catch_exceptions=False)
if cs is None:
assert "is not a hub" in res.output
return

mock_pair.assert_awaited()
assert "Finding new devices for 10 seconds" in res.output
assert res.exit_code == 0


@parametrize("hubs smart", model_filter=HUBS_SMART, protocol_filter={"SMART"})
async def test_hub_unpair(dev, mocker: MockerFixture, runner):
"""Test that unpair calls the expected method."""
if not dev.children:
pytest.skip("Cannot test without child devices")

id_ = next(iter(dev.children)).device_id

cs = dev.modules.get(Module.ChildSetup)
mock_unpair = mocker.spy(cs, "unpair")

res = await runner.invoke(hub, ["unpair", id_], obj=dev, catch_exceptions=False)

mock_unpair.assert_awaited()
assert f"Unpaired {id_}" in res.output
assert res.exit_code == 0


@plug_iot
async def test_non_hub(dev, mocker: MockerFixture, runner):
"""Test that hub commands return an error if executed on a non-hub."""
assert dev.device_type is not DeviceType.Hub
res = await runner.invoke(
hub, ["unpair", "dummy_id"], obj=dev, catch_exceptions=False
)
assert "is not a hub" in res.output
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

import asyncio
import os
import sys
import warnings
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

# TODO: this and runner fixture could be moved to tests/cli/conftest.py
from asyncclick.testing import CliRunner

from kasa import (
DeviceConfig,
SmartProtocol,
Expand Down Expand Up @@ -149,3 +153,12 @@ async def _create_datagram_endpoint(protocol_factory, *_, **__):
side_effect=_create_datagram_endpoint,
):
yield


@pytest.fixture
def runner():
"""Runner fixture that unsets the KASA_ environment variables for tests."""
KASA_VARS = {k: None for k, v in os.environ.items() if k.startswith("KASA_")}
runner = CliRunner(env=KASA_VARS)

return runner
32 changes: 30 additions & 2 deletions tests/fakeprotocol_smart.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ def credentials_hash(self):
"setup_payload": "00:0000000-0000.00.000",
},
),
# child setup
"get_support_child_device_category": (
"child_quick_setup",
{"device_category_list": [{"category": "subg.trv"}]},
),
# no devices found
"get_scan_child_device_list": (
"child_quick_setup",
{"child_device_list": [{"dummy": "response"}], "scan_status": "idle"},
),
}

def _missing_result(self, method):
Expand Down Expand Up @@ -548,6 +558,17 @@ def _update_sysinfo_key(self, info: dict, key: str, value: str) -> dict:

return {"error_code": 0}

def _hub_remove_device(self, info, params):
"""Remove hub device."""
items_to_remove = [dev["device_id"] for dev in params["child_device_list"]]
children = info["get_child_device_list"]["child_device_list"]
new_children = [
dev for dev in children if dev["device_id"] not in items_to_remove
]
info["get_child_device_list"]["child_device_list"] = new_children

return {"error_code": 0}

def get_child_device_queries(self, method, params):
return self._get_method_from_info(method, params)

Expand Down Expand Up @@ -658,8 +679,15 @@ async def _send_request(self, request_dict: dict):
return self._set_on_off_gradually_info(info, params)
elif method == "set_child_protection":
return self._update_sysinfo_key(info, "child_protection", params["enable"])
# Vacuum special actions
elif method in ["playSelectAudio"]:
elif method == "remove_child_device_list":
return self._hub_remove_device(info, params)
# actions
elif method in [
"begin_scanning_child_device", # hub pairing
"add_child_device_list", # hub pairing
"remove_child_device_list", # hub pairing
"playSelectAudio", # vacuum special actions
]:
return {"error_code": 0}
elif method[:3] == "set":
target_method = f"get{method[3:]}"
Expand Down
Loading
Loading