Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions kasa/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Mapping, Sequence
from typing import Any, Mapping, Sequence, overload

from .credentials import Credentials
from .device_type import DeviceType
Expand All @@ -15,7 +15,7 @@
from .exceptions import KasaException
from .feature import Feature
from .iotprotocol import IotProtocol
from .module import Module
from .module import Module, ModuleT
from .protocol import BaseProtocol
from .xortransport import XorTransport

Expand Down Expand Up @@ -116,6 +116,18 @@ async def disconnect(self):
def modules(self) -> Mapping[str, Module]:
"""Return the device modules."""

@overload
@abstractmethod
def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...

@overload
@abstractmethod
def get_module(self, module_type: str) -> Module | None: ...

@abstractmethod
def get_module(self, module_type: type[ModuleT] | str) -> ModuleT | Module | None:
"""Return the module from the device modules or None if not present."""

@property
@abstractmethod
def is_on(self) -> bool:
Expand Down
23 changes: 22 additions & 1 deletion kasa/iot/iotdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
import inspect
import logging
from datetime import datetime, timedelta
from typing import Any, Mapping, Sequence, cast
from typing import Any, Mapping, Sequence, cast, overload

from ..device import Device, WifiNetwork
from ..deviceconfig import DeviceConfig
from ..emeterstatus import EmeterStatus
from ..exceptions import KasaException
from ..feature import Feature
from ..module import ModuleT
from ..protocol import BaseProtocol
from .iotmodule import IotModule
from .modules import Emeter, Time
Expand Down Expand Up @@ -201,6 +202,26 @@ def modules(self) -> dict[str, IotModule]:
"""Return the device modules."""
return self._modules

@overload
def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...

@overload
def get_module(self, module_type: str) -> IotModule | None: ...

def get_module(
self, module_type: type[ModuleT] | str
) -> ModuleT | IotModule | None:
"""Return the module from the device modules or None if not present."""
if isinstance(module_type, str):
module_name = module_type.lower()
elif issubclass(module_type, IotModule):
module_name = module_type.__name__.lower()
else:
return None
if module_name in self.modules:
return self.modules[module_name]
return None

def add_module(self, name: str, module: IotModule):
"""Register a module."""
if name in self.modules:
Expand Down
7 changes: 6 additions & 1 deletion kasa/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import (
TYPE_CHECKING,
TypeVar,
)

from .exceptions import KasaException
from .feature import Feature
Expand All @@ -14,6 +17,8 @@

_LOGGER = logging.getLogger(__name__)

ModuleT = TypeVar("ModuleT", bound="Module")


class Module(ABC):
"""Base class implemention for all modules.
Expand Down
22 changes: 18 additions & 4 deletions kasa/smart/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import base64
import logging
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast
from typing import Any, Mapping, Sequence, cast, overload

from ..aestransport import AesTransport
from ..bulb import HSV, Bulb, BulbPreset, ColorTempRange
Expand All @@ -16,6 +16,7 @@
from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode
from ..fan import Fan
from ..feature import Feature
from ..module import ModuleT
from ..smartprotocol import SmartProtocol
from .modules import (
Brightness,
Expand All @@ -28,11 +29,10 @@
Firmware,
TimeModule,
)
from .smartmodule import SmartModule

_LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
from .smartmodule import SmartModule

# List of modules that wall switches with children, i.e. ks240 report on
# the child but only work on the parent. See longer note below in _initialize_modules.
Expand Down Expand Up @@ -305,8 +305,22 @@ async def _initialize_features(self):
for feat in module._module_features.values():
self._add_feature(feat)

def get_module(self, module_name) -> SmartModule | None:
@overload
def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ...

@overload
def get_module(self, module_type: str) -> SmartModule | None: ...

def get_module(
self, module_type: type[ModuleT] | str
) -> ModuleT | SmartModule | None:
"""Return the module from the device modules or None if not present."""
if isinstance(module_type, str):
module_name = module_type
elif issubclass(module_type, SmartModule):
module_name = module_type.__name__
else:
return None
if module_name in self.modules:
return self.modules[module_name]
elif self._exposes_child_modules:
Expand Down
2 changes: 1 addition & 1 deletion kasa/tests/smart/features/test_brightness.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def test_brightness_component(dev: SmartDevice):


@dimmable
async def test_brightness_dimmable(dev: SmartDevice):
async def test_brightness_dimmable(dev: IotDevice):
"""Test brightness feature."""
assert isinstance(dev, IotDevice)
assert "brightness" in dev.sys_info or bool(dev.sys_info["is_dimmable"])
Expand Down
9 changes: 4 additions & 5 deletions kasa/tests/smart/modules/test_fan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import cast

import pytest
from pytest_mock import MockerFixture

Expand All @@ -13,7 +11,7 @@
@fan
async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
"""Test fan speed feature."""
fan = cast(FanModule, dev.get_module("FanModule"))
fan = dev.get_module(FanModule)
assert fan

level_feature = fan._module_features["fan_speed_level"]
Expand All @@ -38,7 +36,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
@fan
async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
"""Test sleep mode feature."""
fan = cast(FanModule, dev.get_module("FanModule"))
fan = dev.get_module(FanModule)
assert fan
sleep_feature = fan._module_features["fan_sleep_mode"]
assert isinstance(sleep_feature.value, bool)
Expand All @@ -57,7 +55,8 @@ async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
async def test_fan_interface(dev: SmartDevice, mocker: MockerFixture):
"""Test fan speed on device interface."""
assert isinstance(dev, SmartDevice)
fan = cast(FanModule, dev.get_module("FanModule"))
fan = dev.get_module(FanModule)
assert fan
device = fan._device
assert device.is_fan

Expand Down
29 changes: 28 additions & 1 deletion kasa/tests/test_iotdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from kasa import KasaException
from kasa.iot import IotDevice

from .conftest import handle_turn_on, turn_on
from .conftest import get_device_for_fixture_protocol, handle_turn_on, turn_on
from .device_fixtures import device_iot, has_emeter_iot, no_emeter_iot
from .fakeprotocol_iot import FakeIotProtocol

Expand Down Expand Up @@ -258,3 +258,30 @@ async def test_modules_not_supported(dev: IotDevice):
await dev.update()
for module in dev.modules.values():
assert module.is_supported is not None


async def test_get_modules():
"""Test get_modules for child and parent modules."""
dummy_device = await get_device_for_fixture_protocol(
"HS100(US)_2.0_1.5.6.json", "IOT"
)
from kasa.iot.modules import Cloud
from kasa.smart.modules import CloudModule

# Modules on device
module = dummy_device.get_module("Cloud")
assert module
assert module._device == dummy_device
assert isinstance(module, Cloud)

module = dummy_device.get_module(Cloud)
assert module
assert module._device == dummy_device
assert isinstance(module, Cloud)

# Invalid modules
module = dummy_device.get_module("DummyModule")
assert module is None

module = dummy_device.get_module(CloudModule)
assert module is None
22 changes: 21 additions & 1 deletion kasa/tests/test_smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,23 +122,43 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
spies[device].assert_not_called()


async def test_get_modules(mocker):
async def test_get_modules():
"""Test get_modules for child and parent modules."""
dummy_device = await get_device_for_fixture_protocol(
"KS240(US)_1.0_1.0.5.json", "SMART"
)
from kasa.iot.modules import AmbientLight
from kasa.smart.modules import CloudModule, FanModule

# Modules on device
module = dummy_device.get_module("CloudModule")
assert module
assert module._device == dummy_device
assert isinstance(module, CloudModule)

module = dummy_device.get_module(CloudModule)
assert module
assert module._device == dummy_device
assert isinstance(module, CloudModule)

# Modules on child
module = dummy_device.get_module("FanModule")
assert module
assert module._device != dummy_device
assert module._device._parent == dummy_device

module = dummy_device.get_module(FanModule)
assert module
assert module._device != dummy_device
assert module._device._parent == dummy_device

# Invalid modules
module = dummy_device.get_module("DummyModule")
assert module is None

module = dummy_device.get_module(AmbientLight)
assert module is None


@bulb_smart
async def test_smartdevice_brightness(dev: SmartDevice):
Expand Down