Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion kasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
_deprecated_TPLinkSmartHomeProtocol, # noqa: F401
)
from kasa.module import Module
from kasa.protocol import BaseProtocol
from kasa.protocol import BaseProtocol, BaseTransport
from kasa.smartprotocol import SmartProtocol

__version__ = version("python-kasa")
Expand All @@ -50,6 +50,7 @@
__all__ = [
"Discover",
"BaseProtocol",
"BaseTransport",
"IotProtocol",
"SmartProtocol",
"LightState",
Expand Down
13 changes: 11 additions & 2 deletions kasa/cli/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Discover,
UnsupportedDeviceError,
)
from kasa.discover import DiscoveryResult
from kasa.discover import ConnectAttempt, DiscoveryResult

from .common import echo, error

Expand Down Expand Up @@ -165,8 +165,17 @@ async def config(ctx):

credentials = Credentials(username, password) if username and password else None

host_port = host + (f":{port}" if port else "")

def on_attempt(connect_attempt: ConnectAttempt, success: bool) -> None:
prot, tran, dev = connect_attempt
key_str = f"{prot.__name__} + {tran.__name__} + {dev.__name__}"
result = "succeeded" if success else "failed"
msg = f"Attempt to connect to {host_port} with {key_str} {result}"
echo(msg)

dev = await Discover.try_connect_all(
host, credentials=credentials, timeout=timeout, port=port
host, credentials=credentials, timeout=timeout, port=port, on_attempt=on_attempt
)
if dev:
cparams = dev.config.connection_type
Expand Down
8 changes: 5 additions & 3 deletions kasa/device_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def get_device_class_from_sys_info(sysinfo: dict[str, Any]) -> type[IotDevice]:


def get_device_class_from_family(
device_type: str, *, https: bool
device_type: str, *, https: bool, require_exact: bool = False
) -> type[Device] | None:
"""Return the device class from the type name."""
supported_device_types: dict[str, type[Device]] = {
Expand All @@ -185,8 +185,10 @@ def get_device_class_from_family(
}
lookup_key = f"{device_type}{'.HTTPS' if https else ''}"
if (
cls := supported_device_types.get(lookup_key)
) is None and device_type.startswith("SMART."):
(cls := supported_device_types.get(lookup_key)) is None
and device_type.startswith("SMART.")
and not require_exact
):
_LOGGER.warning("Unknown SMART device with %s, using SmartDevice", device_type)
cls = SmartDevice

Expand Down
46 changes: 38 additions & 8 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
import struct
from collections.abc import Awaitable
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Optional, Type, cast

from aiohttp import ClientSession

Expand All @@ -118,6 +118,7 @@
TimeoutError,
UnsupportedDeviceError,
)
from kasa.experimental import Experimental
from kasa.iot.iotdevice import IotDevice
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.json import dumps as json_dumps
Expand All @@ -127,9 +128,21 @@

_LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
from kasa import BaseProtocol, BaseTransport


class ConnectAttempt(NamedTuple):
"""Try to connect attempt."""

protocol: type
transport: type
device: type


OnDiscoveredCallable = Callable[[Device], Awaitable[None]]
OnUnsupportedCallable = Callable[[UnsupportedDeviceError], Awaitable[None]]
OnConnectAttemptCallable = Callable[[ConnectAttempt, bool], None]
DeviceDict = Dict[str, Device]

NEW_DISCOVERY_REDACTORS: dict[str, Callable[[Any], Any] | None] = {
Expand Down Expand Up @@ -535,6 +548,7 @@
timeout: int | None = None,
credentials: Credentials | None = None,
http_client: ClientSession | None = None,
on_attempt: OnConnectAttemptCallable | None = None,
) -> Device | None:
"""Try to connect directly to a device with all possible parameters.

Expand All @@ -551,13 +565,22 @@
"""
from .device_factory import _connect

candidates = {
main_device_families = {
Device.Family.SmartTapoPlug,
Device.Family.IotSmartPlugSwitch,
}
if Experimental.enabled():
main_device_families.add(Device.Family.SmartIpCamera)

Check warning on line 573 in kasa/discover.py

View check run for this annotation

Codecov / codecov/patch

kasa/discover.py#L573

Added line #L573 was not covered by tests
candidates: dict[
tuple[type[BaseProtocol], type[BaseTransport], type[Device]],
tuple[BaseProtocol, DeviceConfig],
] = {
(type(protocol), type(protocol._transport), device_class): (
protocol,
config,
)
for encrypt in Device.EncryptionType
for device_family in Device.Family
for device_family in main_device_families
for https in (True, False)
if (
conn_params := DeviceConnectionParameters(
Expand All @@ -580,19 +603,26 @@
and (protocol := get_protocol(config))
and (
device_class := get_device_class_from_family(
device_family.value, https=https
device_family.value, https=https, require_exact=True
)
)
}
for protocol, config in candidates.values():
for key, val in candidates.items():
try:
dev = await _connect(config, protocol)
prot, config = val
dev = await _connect(config, prot)
except Exception:
_LOGGER.debug("Unable to connect with %s", protocol)
_LOGGER.debug("Unable to connect with %s", prot)
if on_attempt:
ca = tuple.__new__(ConnectAttempt, key)
on_attempt(ca, False)
else:
if on_attempt:
ca = tuple.__new__(ConnectAttempt, key)
on_attempt(ca, True)
return dev
finally:
await protocol.close()
await prot.close()
return None

@staticmethod
Expand Down
10 changes: 9 additions & 1 deletion kasa/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,7 @@ async def test_cli_child_commands(
async def test_discover_config(dev: Device, mocker, runner):
"""Test that device config is returned."""
host = "127.0.0.1"
mocker.patch("kasa.discover.Discover.try_connect_all", return_value=dev)
mocker.patch("kasa.device_factory._connect", side_effect=[Exception, dev])

res = await runner.invoke(
cli,
Expand All @@ -1182,6 +1182,14 @@ async def test_discover_config(dev: Device, mocker, runner):
cparam = dev.config.connection_type
expected = f"--device-family {cparam.device_family.value} --encrypt-type {cparam.encryption_type.value} {'--https' if cparam.https else '--no-https'}"
assert expected in res.output
assert re.search(
r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ failed",
res.output.replace("\n", ""),
)
assert re.search(
r"Attempt to connect to 127\.0\.0\.1 with \w+ \+ \w+ \+ \w+ succeeded",
res.output.replace("\n", ""),
)


async def test_discover_config_invalid(mocker, runner):
Expand Down
Loading