Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@
myst_heading_anchors = 3


def setup(app):
def setup(app): # noqa: ANN201,ANN001
# add copybutton to hide the >>> prompts, see https://github.com/readthedocs/sphinx_rtd_theme/issues/167
app.add_js_file("copybutton.js")
6 changes: 3 additions & 3 deletions kasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""

from importlib.metadata import version
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from warnings import warn

from kasa.credentials import Credentials
Expand Down Expand Up @@ -101,7 +101,7 @@
}


def __getattr__(name):
def __getattr__(name: str) -> Any:
if name in deprecated_names:
warn(f"{name} is deprecated", DeprecationWarning, stacklevel=2)
return globals()[f"_deprecated_{name}"]
Expand All @@ -117,7 +117,7 @@ def __getattr__(name):
)
return new_class
if name in deprecated_classes:
new_class = deprecated_classes[name]
new_class = deprecated_classes[name] # type: ignore[assignment]
msg = f"{name} is deprecated, use {new_class.__name__} instead"
warn(msg, DeprecationWarning, stacklevel=2)
return new_class
Expand Down
38 changes: 22 additions & 16 deletions kasa/aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def hash_credentials(login_v2: bool, credentials: Credentials) -> tuple[str, str
pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw

def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
def _handle_response_error_code(self, resp_dict: dict, msg: str) -> None:
error_code_raw = resp_dict.get("error_code")
try:
error_code = SmartErrorCode.from_int(error_code_raw)
Expand Down Expand Up @@ -191,14 +191,14 @@ async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
+ f"status code {status_code} to passthrough"
)

self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)

if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
assert self._encryption_session is not None

self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)

raw_response: str = resp_dict["result"]["response"]

try:
Expand All @@ -219,7 +219,7 @@ async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
) from ex
return ret_val # type: ignore[return-value]

async def perform_login(self):
async def perform_login(self) -> None:
"""Login to the device."""
try:
await self.try_login(self._login_params)
Expand Down Expand Up @@ -324,11 +324,11 @@ async def perform_handshake(self) -> None:
+ f"status code {status_code} to handshake"
)

self._handle_response_error_code(resp_dict, "Unable to complete handshake")

if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)

self._handle_response_error_code(resp_dict, "Unable to complete handshake")

handshake_key = resp_dict["result"]["key"]

if (
Expand All @@ -355,7 +355,7 @@ async def perform_handshake(self) -> None:

_LOGGER.debug("Handshake with %s complete", self._host)

def _handshake_session_expired(self):
def _handshake_session_expired(self) -> bool:
"""Return true if session has expired."""
return (
self._session_expire_at is None
Expand Down Expand Up @@ -394,7 +394,9 @@ class AesEncyptionSession:
"""Class for an AES encryption session."""

@staticmethod
def create_from_keypair(handshake_key: str, keypair: KeyPair):
def create_from_keypair(
handshake_key: str, keypair: KeyPair
) -> AesEncyptionSession:
"""Create the encryption session."""
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode())

Expand All @@ -404,19 +406,19 @@ def create_from_keypair(handshake_key: str, keypair: KeyPair):

return AesEncyptionSession(key_and_iv[:16], key_and_iv[16:])

def __init__(self, key, iv):
def __init__(self, key: bytes, iv: bytes) -> None:
self.cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)

def encrypt(self, data) -> bytes:
def encrypt(self, data: bytes) -> bytes:
"""Encrypt the message."""
encryptor = self.cipher.encryptor()
padder = self.padding_strategy.padder()
padded_data = padder.update(data) + padder.finalize()
encrypted = encryptor.update(padded_data) + encryptor.finalize()
return base64.b64encode(encrypted)

def decrypt(self, data) -> str:
def decrypt(self, data: str | bytes) -> str:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be cleaner if we'd accept only bytes here (and leave base64-decoding to the caller), but I feel this is fine as it is.

"""Decrypt the message."""
decryptor = self.cipher.decryptor()
unpadder = self.padding_strategy.unpadder()
Expand All @@ -429,14 +431,16 @@ class KeyPair:
"""Class for generating key pairs."""

@staticmethod
def create_key_pair(key_size: int = 1024):
def create_key_pair(key_size: int = 1024) -> KeyPair:
"""Create a key pair."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
public_key = private_key.public_key()
return KeyPair(private_key, public_key)

@staticmethod
def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str):
def create_from_der_keys(
private_key_der_b64: str, public_key_der_b64: str
) -> KeyPair:
"""Create a key pair."""
key_bytes = base64.b64decode(private_key_der_b64.encode())
private_key = cast(
Expand All @@ -449,7 +453,9 @@ def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str):

return KeyPair(private_key, public_key)

def __init__(self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey):
def __init__(
self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey
) -> None:
self.private_key = private_key
self.public_key = public_key
self.private_key_der_bytes = self.private_key.private_bytes(
Expand Down
23 changes: 15 additions & 8 deletions kasa/cli/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
from contextlib import contextmanager
from functools import singledispatch, update_wrapper, wraps
from typing import Final
from typing import TYPE_CHECKING, Any, Callable, Final

import asyncclick as click

Expand Down Expand Up @@ -37,7 +37,7 @@ def _strip_rich_formatting(echo_func):
"""Strip rich formatting from messages."""

@wraps(echo_func)
def wrapper(message=None, *args, **kwargs):
def wrapper(message=None, *args, **kwargs) -> None:
if message is not None:
message = rich_formatting.sub("", message)
echo_func(message, *args, **kwargs)
Expand All @@ -47,20 +47,20 @@ def wrapper(message=None, *args, **kwargs):
_echo = _strip_rich_formatting(click.echo)


def echo(*args, **kwargs):
def echo(*args, **kwargs) -> None:
"""Print a message."""
ctx = click.get_current_context().find_root()
if "json" not in ctx.params or ctx.params["json"] is False:
_echo(*args, **kwargs)


def error(msg: str):
def error(msg: str) -> None:
"""Print an error and exit."""
echo(f"[bold red]{msg}[/bold red]")
sys.exit(1)


def json_formatter_cb(result, **kwargs):
def json_formatter_cb(result: Any, **kwargs) -> None:
"""Format and output the result as JSON, if requested."""
if not kwargs.get("json"):
return
Expand All @@ -82,7 +82,7 @@ def _device_to_serializable(val: Device):
print(json_content)


def pass_dev_or_child(wrapped_function):
def pass_dev_or_child(wrapped_function: Callable) -> Callable:
"""Pass the device or child to the click command based on the child options."""
child_help = (
"Child ID or alias for controlling sub-devices. "
Expand Down Expand Up @@ -133,7 +133,10 @@ async def wrapper(ctx: click.Context, dev, *args, child, child_index, **kwargs):


async def _get_child_device(
device: Device, child_option, child_index_option, info_command
device: Device,
child_option: str | None,
child_index_option: int | None,
info_command: str | None,
) -> Device | None:
def _list_children():
return "\n".join(
Expand Down Expand Up @@ -178,11 +181,15 @@ def _list_children():
f"{child_option} children are:\n{_list_children()}"
)

if TYPE_CHECKING:
assert isinstance(child_index_option, int)

if child_index_option + 1 > len(device.children) or child_index_option < 0:
error(
f"Invalid index {child_index_option}, "
f"device has {len(device.children)} children"
)

child_by_index = device.children[child_index_option]
echo(f"Targeting child device {child_by_index.alias}")
return child_by_index
Expand All @@ -195,7 +202,7 @@ def CatchAllExceptions(cls):
https://stackoverflow.com/questions/52213375
"""

def _handle_exception(debug, exc):
def _handle_exception(debug, exc) -> None:
if isinstance(exc, click.ClickException):
raise
# Handle exit request from click.
Expand Down
2 changes: 1 addition & 1 deletion kasa/cli/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

@click.group()
@pass_dev_or_child
def device(dev):
def device(dev) -> None:
"""Commands to control basic device settings."""


Expand Down
8 changes: 4 additions & 4 deletions kasa/cli/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def detail(ctx):
auth_failed = []
sem = asyncio.Semaphore()

async def print_unsupported(unsupported_exception: UnsupportedDeviceError):
async def print_unsupported(unsupported_exception: UnsupportedDeviceError) -> None:
unsupported.append(unsupported_exception)
async with sem:
if unsupported_exception.discovery_result:
Expand All @@ -50,7 +50,7 @@ async def print_unsupported(unsupported_exception: UnsupportedDeviceError):

from .device import state

async def print_discovered(dev: Device):
async def print_discovered(dev: Device) -> None:
async with sem:
try:
await dev.update()
Expand Down Expand Up @@ -189,15 +189,15 @@ def on_attempt(connect_attempt: ConnectAttempt, success: bool) -> None:
error(f"Unable to connect to {host}")


def _echo_dictionary(discovery_info: dict):
def _echo_dictionary(discovery_info: dict) -> None:
echo("\t[bold]== Discovery information ==[/bold]")
for key, value in discovery_info.items():
key_name = " ".join(x.capitalize() or "_" for x in key.split("_"))
key_name_and_spaces = "{:<15}".format(key_name + ":")
echo(f"\t{key_name_and_spaces}{value}")


def _echo_discovery_info(discovery_info):
def _echo_discovery_info(discovery_info) -> None:
# We don't have discovery info when all connection params are passed manually
if discovery_info is None:
return
Expand Down
6 changes: 4 additions & 2 deletions kasa/cli/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _echo_features(
category: Feature.Category | None = None,
verbose: bool = False,
indent: str = "\t",
):
) -> None:
"""Print out a listing of features and their values."""
if category is not None:
features = {
Expand All @@ -43,7 +43,9 @@ def _echo_features(
echo(f"{indent}{feat.name} ({feat.id}): [red]got exception ({ex})[/red]")


def _echo_all_features(features, *, verbose=False, title_prefix=None, indent=""):
def _echo_all_features(
features, *, verbose=False, title_prefix=None, indent=""
) -> None:
"""Print out all features by category."""
if title_prefix is not None:
echo(f"[bold]\n{indent}== {title_prefix} ==[/bold]")
Expand Down
8 changes: 5 additions & 3 deletions kasa/cli/lazygroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
Taken from the click help files.
"""

from __future__ import annotations

import importlib

import asyncclick as click
Expand All @@ -11,7 +13,7 @@
class LazyGroup(click.Group):
"""Lazy group class."""

def __init__(self, *args, lazy_subcommands=None, **kwargs):
def __init__(self, *args, lazy_subcommands=None, **kwargs) -> None:
super().__init__(*args, **kwargs)
# lazy_subcommands is a map of the form:
#
Expand All @@ -31,9 +33,9 @@ def get_command(self, ctx, cmd_name):
return self._lazy_load(cmd_name)
return super().get_command(ctx, cmd_name)

def format_commands(self, ctx, formatter):
def format_commands(self, ctx, formatter) -> None:
"""Format the top level help output."""
sections = {}
sections: dict[str, list] = {}
for cmd, parent in self.lazy_subcommands.items():
sections.setdefault(parent, [])
cmd_obj = self.get_command(ctx, cmd)
Expand Down
2 changes: 1 addition & 1 deletion kasa/cli/light.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@click.group()
@pass_dev_or_child
def light(dev):
def light(dev) -> None:
"""Commands to control light settings."""


Expand Down
6 changes: 3 additions & 3 deletions kasa/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
DEFAULT_TARGET = "255.255.255.255"


def _legacy_type_to_class(_type):
def _legacy_type_to_class(_type: str) -> Any:
from kasa.iot import (
IotBulb,
IotDimmer,
Expand Down Expand Up @@ -396,9 +396,9 @@

@cli.command()
@pass_dev_or_child
async def shell(dev: Device):
async def shell(dev: Device) -> None:
"""Open interactive shell."""
echo("Opening shell for %s" % dev)
echo(f"Opening shell for {dev}")

Check warning on line 401 in kasa/cli/main.py

View check run for this annotation

Codecov / codecov/patch

kasa/cli/main.py#L401

Added line #L401 was not covered by tests
from ptpython.repl import embed

logging.getLogger("parso").setLevel(logging.WARNING) # prompt parsing
Expand Down
2 changes: 1 addition & 1 deletion kasa/cli/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@click.group()
@pass_dev
async def schedule(dev):
async def schedule(dev) -> None:
"""Scheduling commands."""


Expand Down
2 changes: 1 addition & 1 deletion kasa/cli/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

@click.group(invoke_without_command=True)
@click.pass_context
async def time(ctx: click.Context):
async def time(ctx: click.Context) -> None:
"""Get and set time."""
if ctx.invoked_subcommand is None:
await ctx.invoke(time_get)
Expand Down
Loading
Loading