Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion kasa/cli/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from kasa.discover import DiscoveryResult

from .common import echo
from .common import echo, error


@click.group(invoke_without_command=True)
Expand Down Expand Up @@ -145,6 +145,41 @@ async def _discover(ctx, print_discovered, print_unsupported, *, do_echo=True):
return discovered_devices


@discover.command()
@click.pass_context
async def config(ctx):
"""Bypass udp discovery and try to show connection config for a device.

Bypasses udp discovery and shows the parameters required to connect
directly to the device.
"""
params = ctx.parent.parent.params
username = params["username"]
password = params["password"]
timeout = params["timeout"]
host = params["host"]
port = params["port"]

if not host:
error("--host option must be supplied to discover config")

credentials = Credentials(username, password) if username and password else None

dev = await Discover.try_connect_all(
host, credentials=credentials, timeout=timeout, port=port
)
if dev:
cparams = dev.config.connection_type
echo("Managed to connect, cli options to connect are:")
echo(
f"--device-family {cparams.device_family.value} "
f"--encrypt-type {cparams.encryption_type.value} "
f"{'--https' if cparams.https else '--no-https'}"
)
else:
error(f"Unable to connect to {host}")


def _echo_dictionary(discovery_info: dict):
echo("\t[bold]== Discovery information ==[/bold]")
for key, value in discovery_info.items():
Expand Down
6 changes: 5 additions & 1 deletion kasa/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
]

ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in DeviceEncryptionType]
DEFAULT_TARGET = "255.255.255.255"


def _legacy_type_to_class(_type):
Expand Down Expand Up @@ -115,7 +116,7 @@ def _legacy_type_to_class(_type):
@click.option(
"--target",
envvar="KASA_TARGET",
default="255.255.255.255",
default=DEFAULT_TARGET,
required=False,
show_default=True,
help="The broadcast address to be used for discovery.",
Expand Down Expand Up @@ -256,6 +257,9 @@ async def cli(
ctx.obj = object()
return

if target != DEFAULT_TARGET and host:
error("--target is not a valid option for single host discovery")

if experimental:
from kasa.experimental.enabled import Enabled

Expand Down
60 changes: 60 additions & 0 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,66 @@ async def discover_single(
else:
raise TimeoutError(f"Timed out getting discovery response for {host}")

@staticmethod
async def try_connect_all(
host: str,
*,
port: int | None = None,
timeout: int | None = None,
credentials: Credentials | None = None,
) -> Device | None:
"""Try to connect directly to a device with all possible parameters.

This method can be used when udp is not working due to network issues.
After succesfully connecting use the device config and
:meth:`Device.connect()` for future connections.

:param host: Hostname of device to query
:param port: Optionally set a different port for legacy devices using port 9999
:param timeout: Timeout in seconds device for devices queries
:param credentials: Credentials for devices that require authentication.
username and password are ignored if provided.
"""
from .device_factory import _connect

candidates = {
(type(protocol), type(protocol._transport), device_class): (
protocol,
config,
)
for encrypt in Device.EncryptionType
for device_family in Device.Family
for https in (True, False)
if (
conn_params := DeviceConnectionParameters(
device_family=device_family,
encryption_type=encrypt,
https=https,
)
)
and (
config := DeviceConfig(
host=host,
connection_type=conn_params,
timeout=timeout,
port_override=port,
credentials=credentials,
)
)
and (protocol := get_protocol(config))
and (device_class := get_device_class_from_family(device_family.value))
}
for protocol, config in candidates.values():
try:
dev = await _connect(config, protocol)
except Exception:
_LOGGER.debug("Unable to connect with %s", protocol)
else:
return dev
finally:
await protocol.close()
return None

@staticmethod
def _get_device_class(info: dict) -> type[Device]:
"""Find SmartDevice subclass for device described by passed data."""
Expand Down
75 changes: 75 additions & 0 deletions kasa/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,3 +1158,78 @@ async def test_cli_child_commands(
assert res.exit_code == 0
parent_update_spy.assert_called_once()
assert dev.children[0].update == child_update_method


async def test_discover_config(dev: Device, mocker, runner):
"""Test that device config is returned."""
host = "127.0.0.1"
mocker.patch("kasa.discover.Discover.try_connect_all", return_value=dev)

res = await runner.invoke(
cli,
[
"--username",
"foo",
"--password",
"bar",
"--host",
host,
"discover",
"config",
],
catch_exceptions=False,
)
assert res.exit_code == 0
cparam = dev.config.connection_type
expected = f"--device-family {cparam.device_family.value} --encrypt-type {cparam.encryption_type.value} {'--https' if cparam.https else '--no-https'}"
assert expected in res.output


async def test_discover_config_invalid(mocker, runner):
"""Test the device config command with invalids."""
host = "127.0.0.1"
mocker.patch("kasa.discover.Discover.try_connect_all", return_value=None)

res = await runner.invoke(
cli,
[
"--username",
"foo",
"--password",
"bar",
"--host",
host,
"discover",
"config",
],
catch_exceptions=False,
)
assert res.exit_code == 1
assert f"Unable to connect to {host}" in res.output

res = await runner.invoke(
cli,
["--username", "foo", "--password", "bar", "discover", "config"],
catch_exceptions=False,
)
assert res.exit_code == 1
assert "--host option must be supplied to discover config" in res.output

res = await runner.invoke(
cli,
[
"--username",
"foo",
"--password",
"bar",
"--host",
host,
"--target",
"127.0.0.2",
"discover",
"config",
],
catch_exceptions=False,
)
assert res.exit_code == 1
assert "--target is not a valid option for single host discovery" in res.output
56 changes: 55 additions & 1 deletion kasa/tests/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,15 @@
Device,
DeviceType,
Discover,
IotProtocol,
KasaException,
)
from kasa.aestransport import AesEncyptionSession
from kasa.device_factory import (
get_device_class_from_family,
get_device_class_from_sys_info,
get_protocol,
)
from kasa.deviceconfig import (
DeviceConfig,
DeviceConnectionParameters,
Expand All @@ -35,7 +41,7 @@
)
from kasa.exceptions import AuthenticationError, UnsupportedDeviceError
from kasa.iot import IotDevice
from kasa.xortransport import XorEncryption
from kasa.xortransport import XorEncryption, XorTransport

from .conftest import (
bulb_iot,
Expand Down Expand Up @@ -647,3 +653,51 @@ async def test_discovery_decryption():
dr = DiscoveryResult(**info)
Discover._decrypt_discovery_data(dr)
assert dr.decrypted_data == data_dict


async def test_discover_try_connect_all(discovery_mock, mocker):
"""Test that device update is called on main."""
if "result" in discovery_mock.discovery_data:
dev_class = get_device_class_from_family(discovery_mock.device_type)
cparams = DeviceConnectionParameters.from_values(
discovery_mock.device_type,
discovery_mock.encrypt_type,
discovery_mock.login_version,
False,
)
protocol = get_protocol(
DeviceConfig(discovery_mock.ip, connection_type=cparams)
)
protocol_class = protocol.__class__
transport_class = protocol._transport.__class__
else:
dev_class = get_device_class_from_sys_info(discovery_mock.discovery_data)
protocol_class = IotProtocol
transport_class = XorTransport

async def _query(self, *args, **kwargs):
if (
self.__class__ is protocol_class
and self._transport.__class__ is transport_class
):
return discovery_mock.query_data
raise KasaException()

async def _update(self, *args, **kwargs):
if (
self.protocol.__class__ is protocol_class
and self.protocol._transport.__class__ is transport_class
):
return
raise KasaException()

mocker.patch("kasa.IotProtocol.query", new=_query)
mocker.patch("kasa.SmartProtocol.query", new=_query)
mocker.patch.object(dev_class, "update", new=_update)

dev = await Discover.try_connect_all(discovery_mock.ip)

assert dev
assert isinstance(dev, dev_class)
assert isinstance(dev.protocol, protocol_class)
assert isinstance(dev.protocol._transport, transport_class)