Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion kasa/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

from kasa.cli.main import cli

cli()
if __name__ == "__main__":
cli()
231 changes: 231 additions & 0 deletions kasa/cli/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""Common cli module."""

from __future__ import annotations

import json
import re
import sys
from contextlib import contextmanager
from functools import singledispatch, update_wrapper, wraps
from typing import Final

import asyncclick as click

from kasa import (
Device,
)

# Value for optional options if passed without a value
OPTIONAL_VALUE_FLAG: Final = "_FLAG_"

# Block list of commands which require no update
SKIP_UPDATE_COMMANDS = ["raw-command", "command"]

pass_dev = click.make_pass_decorator(Device) # type: ignore[type-abstract]


try:
from rich import print as _echo
except ImportError:
# Strip out rich formatting if rich is not installed
# but only lower case tags to avoid stripping out
# raw data from the device that is printed from
# the device state.
rich_formatting = re.compile(r"\[/?[a-z]+]")

def _strip_rich_formatting(echo_func):
"""Strip rich formatting from messages."""

@wraps(echo_func)
def wrapper(message=None, *args, **kwargs):
if message is not None:
message = rich_formatting.sub("", message)
echo_func(message, *args, **kwargs)

return wrapper

_echo = _strip_rich_formatting(click.echo)


def echo(*args, **kwargs):
"""Print a message."""
ctx = click.get_current_context().find_root()
if "json" not in ctx.params or ctx.params["json"] is False:
_echo(*args, **kwargs)


def error(msg: str):
"""Print an error and exit."""
echo(f"[bold red]{msg}[/bold red]")
sys.exit(1)


def json_formatter_cb(result, **kwargs):
"""Format and output the result as JSON, if requested."""
if not kwargs.get("json"):
return

@singledispatch
def to_serializable(val):
"""Regular obj-to-string for json serialization.

The singledispatch trick is from hynek: https://hynek.me/articles/serialization/
"""
return str(val)

@to_serializable.register(Device)
def _device_to_serializable(val: Device):
"""Serialize smart device data, just using the last update raw payload."""
return val.internal_state

json_content = json.dumps(result, indent=4, default=to_serializable)
print(json_content)


def pass_dev_or_child(wrapped_function):
"""Pass the device or child to the click command based on the child options."""
child_help = (
"Child ID or alias for controlling sub-devices. "
"If no value provided will show an interactive prompt allowing you to "
"select a child."
)
child_index_help = "Child index controlling sub-devices"

@contextmanager
def patched_device_update(parent: Device, child: Device):
try:
orig_update = child.update
# patch child update method. Can be removed once update can be called
# directly on child devices
child.update = parent.update # type: ignore[method-assign]
yield child
finally:
child.update = orig_update # type: ignore[method-assign]

@click.pass_obj
@click.pass_context
@click.option(
"--child",
"--name",
is_flag=False,
flag_value=OPTIONAL_VALUE_FLAG,
default=None,
required=False,
type=click.STRING,
help=child_help,
)
@click.option(
"--child-index",
"--index",
required=False,
default=None,
type=click.INT,
help=child_index_help,
)
async def wrapper(ctx: click.Context, dev, *args, child, child_index, **kwargs):
if child := await _get_child_device(dev, child, child_index, ctx.info_name):
ctx.obj = ctx.with_resource(patched_device_update(dev, child))
dev = child
return await ctx.invoke(wrapped_function, dev, *args, **kwargs)

# Update wrapper function to look like wrapped function
return update_wrapper(wrapper, wrapped_function)


async def _get_child_device(
device: Device, child_option, child_index_option, info_command
) -> Device | None:
def _list_children():
return "\n".join(
[
f"{idx}: {child.device_id} ({child.alias})"
for idx, child in enumerate(device.children)
]
)

if child_option is None and child_index_option is None:
return None

if info_command in SKIP_UPDATE_COMMANDS:
# The device hasn't had update called (e.g. for cmd_command)
# The way child devices are accessed requires a ChildDevice to
# wrap the communications. Doing this properly would require creating
# a common interfaces for both IOT and SMART child devices.
# As a stop-gap solution, we perform an update instead.
await device.update()

if not device.children:
error(f"Device: {device.host} does not have children")

if child_option is not None and child_index_option is not None:
raise click.BadOptionUsage(
"child", "Use either --child or --child-index, not both."
)

if child_option is not None:
if child_option is OPTIONAL_VALUE_FLAG:
msg = _list_children()
child_index_option = click.prompt(
f"\n{msg}\nEnter the index number of the child device",
type=click.IntRange(0, len(device.children) - 1),
)
elif child := device.get_child_device(child_option):
echo(f"Targeting child device {child.alias}")
return child
else:
error(
"No child device found with device_id or name: "
f"{child_option} children are:\n{_list_children()}"
)

if child_index_option + 1 > len(device.children) or child_index_option < 0:
error(
f"Invalid index {child_index_option}, "
f"device has {len(device.children)} children"
)
child_by_index = device.children[child_index_option]
echo(f"Targeting child device {child_by_index.alias}")
return child_by_index


def CatchAllExceptions(cls):
"""Capture all exceptions and prints them nicely.

Idea from https://stackoverflow.com/a/44347763 and
https://stackoverflow.com/questions/52213375
"""

def _handle_exception(debug, exc):
if isinstance(exc, click.ClickException):
raise
# Handle exit request from click.
if isinstance(exc, click.exceptions.Exit):
sys.exit(exc.exit_code)

echo(f"Raised error: {exc}")
if debug:
raise
echo("Run with --debug enabled to see stacktrace")
sys.exit(1)

class _CommandCls(cls):
_debug = False

async def make_context(self, info_name, args, parent=None, **extra):
self._debug = any(
[arg for arg in args if arg in ["--debug", "-d", "--verbose", "-v"]]
)
try:
return await super().make_context(
info_name, args, parent=parent, **extra
)
except Exception as exc:
_handle_exception(self._debug, exc)

async def invoke(self, ctx):
try:
return await super().invoke(ctx)
except Exception as exc:
_handle_exception(self._debug, exc)

return _CommandCls
Loading