Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions kasa/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def _device_to_serializable(val: SmartDevice):
"--port",
envvar="KASA_PORT",
required=False,
type=int,
help="The port of the device to connect to.",
)
@click.option(
Expand Down Expand Up @@ -138,7 +139,17 @@ def _device_to_serializable(val: SmartDevice):
)
@click.version_option(package_name="python-kasa")
@click.pass_context
async def cli(ctx, host, port, alias, target, debug, type, json, discovery_timeout):
async def cli(
ctx,
host,
port,
alias,
target,
debug,
type,
json,
discovery_timeout,
):
"""A tool for controlling TP-Link smart home devices.""" # noqa
# no need to perform any checks if we are just displaying the help
if sys.argv[-1] == "--help":
Expand Down Expand Up @@ -238,13 +249,29 @@ async def join(dev: SmartDevice, ssid, password, keytype):

@cli.command()
@click.option("--timeout", default=3, required=False)
@click.option(
"--show-unsupported",
envvar="KASA_SHOW_UNSUPPORTED",
required=False,
default=False,
is_flag=True,
help="Print out discovered unsupported devices",
)
@click.pass_context
async def discover(ctx, timeout):
async def discover(ctx, timeout, show_unsupported):
"""Discover devices in the network."""
target = ctx.parent.params["target"]
echo(f"Discovering devices on {target} for {timeout} seconds")
sem = asyncio.Semaphore()
discovered = dict()
unsupported = []

async def print_unsupported(data: Dict):
unsupported.append(data)
if show_unsupported:
echo(f"Found unsupported device (tapo/unknown encryption): {data}")
echo()

echo(f"Discovering devices on {target} for {timeout} seconds")

async def print_discovered(dev: SmartDevice):
await dev.update()
Expand All @@ -255,9 +282,23 @@ async def print_discovered(dev: SmartDevice):
echo()

await Discover.discover(
target=target, timeout=timeout, on_discovered=print_discovered
target=target,
timeout=timeout,
on_discovered=print_discovered,
on_unsupported=print_unsupported,
)

echo(f"Found {len(discovered)} devices")
if unsupported:
echo(
f"Found {len(unsupported)} unsupported devices"
+ (
""
if show_unsupported
else ", to show them use: kasa discover --show-unsupported"
)
)

return discovered


Expand Down
99 changes: 86 additions & 13 deletions kasa/discover.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
"""Discovery module for TP-Link Smart Home devices."""
import asyncio
import binascii
import logging
import socket
from typing import Awaitable, Callable, Dict, Optional, Type, cast

# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout

from kasa.exceptions import UnsupportedDeviceException
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.protocol import TPLinkSmartHomeProtocol
Expand Down Expand Up @@ -36,13 +42,22 @@ def __init__(
target: str = "255.255.255.255",
discovery_packets: int = 3,
interface: Optional[str] = None,
on_unsupported: Optional[Callable[[Dict], Awaitable[None]]] = None,
port: Optional[int] = None,
discovered_event: Optional[asyncio.Event] = None,
):
self.transport = None
self.discovery_packets = discovery_packets
self.interface = interface
self.on_discovered = on_discovered
self.target = (target, Discover.DISCOVERY_PORT)
self.discovery_port = port or Discover.DISCOVERY_PORT
self.target = (target, self.discovery_port)
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
self.discovered_devices = {}
self.unsupported_devices: Dict = {}
self.invalid_device_exceptions: Dict = {}
self.on_unsupported = on_unsupported
self.discovered_event = discovered_event

def connection_made(self, transport) -> None:
"""Set socket options for broadcasting."""
Expand All @@ -69,30 +84,58 @@ def do_discover(self) -> None:
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
for i in range(self.discovery_packets):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore

def datagram_received(self, data, addr) -> None:
"""Handle discovery responses."""
ip, port = addr
if ip in self.discovered_devices:
if (
ip in self.discovered_devices
or ip in self.unsupported_devices
or ip in self.invalid_device_exceptions
):
return

info = json_loads(TPLinkSmartHomeProtocol.decrypt(data))
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
if port == self.discovery_port:
info = json_loads(TPLinkSmartHomeProtocol.decrypt(data))
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)

elif port == Discover.DISCOVERY_PORT_2:
info = json_loads(data[16:])
self.unsupported_devices[ip] = info
if self.on_unsupported is not None:
asyncio.ensure_future(self.on_unsupported(info))
_LOGGER.debug("[DISCOVERY] Unsupported device found at %s << %s", ip, info)
if self.discovered_event is not None and "255" not in self.target[0].split(
"."
):
self.discovered_event.set()
return

try:
device_class = Discover._get_device_class(info)
except SmartDeviceException as ex:
_LOGGER.debug("Unable to find device type from %s: %s", info, ex)
_LOGGER.debug(
"[DISCOVERY] Unable to find device type from %s: %s", info, ex
)
self.invalid_device_exceptions[ip] = ex
if self.discovered_event is not None and "255" not in self.target[0].split(
"."
):
self.discovered_event.set()
return

device = device_class(ip)
device = device_class(ip, port=port)
device.update_from_discover_info(info)

self.discovered_devices[ip] = device

if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))

if self.discovered_event is not None and "255" not in self.target[0].split("."):
self.discovered_event.set()

def error_received(self, ex):
"""Handle asyncio.Protocol errors."""
_LOGGER.error("Got error: %s", ex)
Expand Down Expand Up @@ -142,6 +185,9 @@ class Discover:
"system": {"get_sysinfo": None},
}

DISCOVERY_PORT_2 = 20002
DISCOVERY_QUERY_2 = binascii.unhexlify("020000010000000000000000463cb5d3")

@staticmethod
async def discover(
*,
Expand All @@ -150,6 +196,7 @@ async def discover(
timeout=5,
discovery_packets=3,
interface=None,
on_unsupported=None,
) -> DeviceDict:
"""Discover supported devices.

Expand Down Expand Up @@ -177,6 +224,7 @@ async def discover(
on_discovered=on_discovered,
discovery_packets=discovery_packets,
interface=interface,
on_unsupported=on_unsupported,
),
local_addr=("0.0.0.0", 0),
)
Expand All @@ -193,22 +241,47 @@ async def discover(
return protocol.discovered_devices

@staticmethod
async def discover_single(host: str, *, port: Optional[int] = None) -> SmartDevice:
async def discover_single(
host: str, *, port: Optional[int] = None, timeout=5
) -> SmartDevice:
"""Discover a single device by the given IP address.

:param host: Hostname of device to query
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
protocol = TPLinkSmartHomeProtocol(host, port=port)
loop = asyncio.get_event_loop()
event = asyncio.Event()
transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(target=host, port=port, discovered_event=event),
local_addr=("0.0.0.0", 0),
)
protocol = cast(_DiscoverProtocol, protocol)

info = await protocol.query(Discover.DISCOVERY_QUERY)
try:
_LOGGER.debug("Waiting a total of %s seconds for responses...", timeout)

device_class = Discover._get_device_class(info)
dev = device_class(host, port=port)
await dev.update()
async with asyncio_timeout(timeout):
await event.wait()
except asyncio.TimeoutError:
raise SmartDeviceException(
f"Timed out getting discovery response for {host}"
)
finally:
transport.close()

return dev
if host in protocol.discovered_devices:
dev = protocol.discovered_devices[host]
await dev.update()
return dev
elif host in protocol.unsupported_devices:
raise UnsupportedDeviceException(
f"Unsupported device {host}: {protocol.unsupported_devices[host]}"
)
elif host in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[host]
else:
raise SmartDeviceException(f"Unable to get discovery response for {host}")

@staticmethod
def _get_device_class(info: dict) -> Type[SmartDevice]:
Expand Down
4 changes: 4 additions & 0 deletions kasa/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@

class SmartDeviceException(Exception):
"""Base exception for device errors."""


class UnsupportedDeviceException(SmartDeviceException):
"""Exception for trying to connect to unsupported devices."""
Loading