Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions kasa/cli/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from contextlib import contextmanager
from functools import singledispatch, update_wrapper, wraps
from gettext import gettext
from typing import TYPE_CHECKING, Any, Final
from typing import TYPE_CHECKING, Any, Final, NoReturn

import asyncclick as click

Expand Down Expand Up @@ -57,7 +57,7 @@
_echo(*args, **kwargs)


def error(msg: str) -> None:
def error(msg: str) -> NoReturn:
"""Print an error and exit."""
echo(f"[bold red]{msg}[/bold red]")
sys.exit(1)
Expand All @@ -68,6 +68,16 @@
if not kwargs.get("json"):
return

# Calling the discover command directly always returns a DeviceDict so if host
# was specified just format the device json
if (
(host := kwargs.get("host"))
and isinstance(result, dict)
and (dev := result.get(host))
and isinstance(dev, Device)
):
result = dev

Check warning on line 79 in kasa/cli/common.py

View check run for this annotation

Codecov / codecov/patch

kasa/cli/common.py#L79

Added line #L79 was not covered by tests

@singledispatch
def to_serializable(val):
"""Regular obj-to-string for json serialization.
Expand All @@ -85,6 +95,25 @@
print(json_content)


async def invoke_subcommand(
command: click.BaseCommand,
ctx: click.Context,
args: list[str] | None = None,
**extra: Any,
) -> Any:
"""Invoke a click subcommand.

Calling ctx.Invoke() treats the command like a simple callback and doesn't
process any result_callbacks so we use this pattern from the click docs
https://click.palletsprojects.com/en/stable/exceptions/#what-if-i-don-t-want-that.
"""
if args is None:
args = []
sub_ctx = await command.make_context(command.name, args, parent=ctx, **extra)
async with sub_ctx:
return await command.invoke(sub_ctx)


def pass_dev_or_child(wrapped_function: Callable) -> Callable:
"""Pass the device or child to the click command based on the child options."""
child_help = (
Expand Down
3 changes: 3 additions & 0 deletions kasa/cli/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from pprint import pformat as pf
from typing import TYPE_CHECKING

import asyncclick as click

Expand Down Expand Up @@ -82,6 +83,8 @@ async def state(ctx, dev: Device):
echo()
from .discover import _echo_discovery_info

if TYPE_CHECKING:
assert dev._discovery_info
_echo_discovery_info(dev._discovery_info)

return dev.internal_state
Expand Down
74 changes: 57 additions & 17 deletions kasa/cli/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
from pprint import pformat as pf
from typing import TYPE_CHECKING, cast

import asyncclick as click

Expand All @@ -17,8 +18,12 @@
from kasa.discover import (
NEW_DISCOVERY_REDACTORS,
ConnectAttempt,
DeviceDict,
DiscoveredRaw,
DiscoveryResult,
OnDiscoveredCallable,
OnDiscoveredRawCallable,
OnUnsupportedCallable,
)
from kasa.iot.iotdevice import _extract_sys_info
from kasa.protocols.iotprotocol import REDACTORS as IOT_REDACTORS
Expand All @@ -30,15 +35,33 @@

@click.group(invoke_without_command=True)
@click.pass_context
async def discover(ctx):
async def discover(ctx: click.Context):
"""Discover devices in the network."""
if ctx.invoked_subcommand is None:
return await ctx.invoke(detail)


@discover.result_callback()
@click.pass_context
async def _close_protocols(ctx: click.Context, discovered: DeviceDict):
"""Close all the device protocols if discover was invoked directly by the user."""
Comment on lines +44 to +47
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch on this! We should probably also do something similar when error()ing out (as we currently get those warnings from aiohttp), but that's something for a separate PR.

if _discover_is_root_cmd(ctx):
for dev in discovered.values():
await dev.disconnect()
return discovered


def _discover_is_root_cmd(ctx: click.Context) -> bool:
"""Will return true if discover was invoked directly by the user."""
root_ctx = ctx.find_root()
return (
root_ctx.invoked_subcommand is None or root_ctx.invoked_subcommand == "discover"
)


@discover.command()
@click.pass_context
async def detail(ctx):
async def detail(ctx: click.Context) -> DeviceDict:
"""Discover devices in the network using udp broadcasts."""
unsupported = []
auth_failed = []
Expand All @@ -59,10 +82,14 @@ async def print_unsupported(unsupported_exception: UnsupportedDeviceError) -> No
from .device import state

async def print_discovered(dev: Device) -> None:
if TYPE_CHECKING:
assert ctx.parent
async with sem:
try:
await dev.update()
except AuthenticationError:
if TYPE_CHECKING:
assert dev._discovery_info
auth_failed.append(dev._discovery_info)
echo("== Authentication failed for device ==")
_echo_discovery_info(dev._discovery_info)
Expand All @@ -73,9 +100,11 @@ async def print_discovered(dev: Device) -> None:
echo()

discovered = await _discover(
ctx, print_discovered=print_discovered, print_unsupported=print_unsupported
ctx,
print_discovered=print_discovered if _discover_is_root_cmd(ctx) else None,
print_unsupported=print_unsupported,
)
if ctx.parent.parent.params["host"]:
if ctx.find_root().params["host"]:
return discovered

echo(f"Found {len(discovered)} devices")
Expand All @@ -96,7 +125,7 @@ async def print_discovered(dev: Device) -> None:
help="Set flag to redact sensitive data from raw output.",
)
@click.pass_context
async def raw(ctx, redact: bool):
async def raw(ctx: click.Context, redact: bool) -> DeviceDict:
"""Return raw discovery data returned from devices."""

def print_raw(discovered: DiscoveredRaw):
Expand All @@ -116,7 +145,7 @@ def print_raw(discovered: DiscoveredRaw):

@discover.command()
@click.pass_context
async def list(ctx):
async def list(ctx: click.Context) -> DeviceDict:
"""List devices in the network in a table using udp broadcasts."""
sem = asyncio.Semaphore()

Expand Down Expand Up @@ -147,18 +176,24 @@ async def print_unsupported(unsupported_exception: UnsupportedDeviceError):
f"{'HOST':<15} {'MODEL':<9} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} "
f"{'HTTPS':<5} {'LV':<3} {'ALIAS'}"
)
return await _discover(
discovered = await _discover(
ctx,
print_discovered=print_discovered,
print_unsupported=print_unsupported,
do_echo=False,
)
return discovered


async def _discover(
ctx, *, print_discovered=None, print_unsupported=None, print_raw=None, do_echo=True
):
params = ctx.parent.parent.params
ctx: click.Context,
*,
print_discovered: OnDiscoveredCallable | None = None,
print_unsupported: OnUnsupportedCallable | None = None,
print_raw: OnDiscoveredRawCallable | None = None,
do_echo=True,
) -> DeviceDict:
params = ctx.find_root().params
target = params["target"]
username = params["username"]
password = params["password"]
Expand All @@ -170,8 +205,9 @@ async def _discover(
credentials = Credentials(username, password) if username and password else None

if host:
host = cast(str, host)
echo(f"Discovering device {host} for {discovery_timeout} seconds")
return await Discover.discover_single(
dev = await Discover.discover_single(
host,
port=port,
credentials=credentials,
Expand All @@ -180,6 +216,12 @@ async def _discover(
on_unsupported=print_unsupported,
on_discovered_raw=print_raw,
)
if dev:
if print_discovered:
await print_discovered(dev)
return {host: dev}
else:
return {}
if do_echo:
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
discovered_devices = await Discover.discover(
Expand All @@ -193,21 +235,18 @@ async def _discover(
on_discovered_raw=print_raw,
)

for device in discovered_devices.values():
await device.protocol.close()

return discovered_devices


@discover.command()
@click.pass_context
async def config(ctx):
async def config(ctx: click.Context) -> DeviceDict:
"""Bypass udp discovery and try to show connection config for a device.

Bypasses udp discovery and shows the parameters required to connect
directly to the device.
"""
params = ctx.parent.parent.params
params = ctx.find_root().params
username = params["username"]
password = params["password"]
timeout = params["timeout"]
Expand Down Expand Up @@ -239,6 +278,7 @@ def on_attempt(connect_attempt: ConnectAttempt, success: bool) -> None:
f"--encrypt-type {cparams.encryption_type.value} "
f"{'--https' if cparams.https else '--no-https'}"
)
return {host: dev}
else:
error(f"Unable to connect to {host}")

Expand All @@ -251,7 +291,7 @@ def _echo_dictionary(discovery_info: dict) -> None:
echo(f"\t{key_name_and_spaces}{value}")


def _echo_discovery_info(discovery_info) -> None:
def _echo_discovery_info(discovery_info: dict) -> None:
# We don't have discovery info when all connection params are passed manually
if discovery_info is None:
return
Expand Down
17 changes: 12 additions & 5 deletions kasa/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
CatchAllExceptions,
echo,
error,
invoke_subcommand,
json_formatter_cb,
pass_dev_or_child,
)
Expand Down Expand Up @@ -295,9 +296,10 @@ async def cli(
echo("No host name given, trying discovery..")
from .discover import discover

return await ctx.invoke(discover)
return await invoke_subcommand(discover, ctx)

device_updated = False
device_discovered = False

if type is not None and type not in {"smart", "camera"}:
from kasa.deviceconfig import DeviceConfig
Expand Down Expand Up @@ -351,12 +353,14 @@ async def cli(
return
echo(f"Found hostname by alias: {dev.host}")
device_updated = True
else:
else: # host will be set
from .discover import discover

dev = await ctx.invoke(discover)
if not dev:
discovered = await invoke_subcommand(discover, ctx)
if not discovered:
error(f"Unable to create device for {host}")
dev = discovered[host]
device_discovered = True

# Skip update on specific commands, or if device factory,
# that performs an update was used for the device.
Expand All @@ -372,11 +376,14 @@ async def async_wrapped_device(device: Device):

ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev))

if ctx.invoked_subcommand is None:
# discover command has already invoked state
if ctx.invoked_subcommand is None and not device_discovered:
from .device import state

return await ctx.invoke(state)

return dev


@cli.command()
@pass_dev_or_child
Expand Down
Loading