Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions kasa/discover.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Discovery module for TP-Link Smart Home devices."""
import asyncio
import binascii
import ipaddress
import logging
import socket
from typing import Awaitable, Callable, Dict, Optional, Type, cast
Expand Down Expand Up @@ -273,9 +274,34 @@ async def discover_single(
"""
loop = asyncio.get_event_loop()
event = asyncio.Event()

try:
ipaddress.ip_address(host)
ip = host
except ValueError:
try:
adrrinfo = await loop.getaddrinfo(
host,
0,
type=socket.SOCK_DGRAM,
family=socket.AF_INET,
)
# getaddrinfo returns a list of 5 tuples with the following structure:
# (family, type, proto, canonname, sockaddr)
# where sockaddr is 2 tuple (ip, port).
# hence [0][4][0] is a stable array access because if no socket
# address matches the host for SOCK_DGRAM AF_INET the gaierror
# would be raised.
# https://docs.python.org/3/library/socket.html#socket.getaddrinfo
ip = adrrinfo[0][4][0]
except socket.gaierror as gex:
raise SmartDeviceException(
f"Could not resolve hostname {host}"
) from gex

transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(
target=host,
target=ip,
port=port,
discovered_event=event,
credentials=credentials,
Expand All @@ -297,16 +323,17 @@ async def discover_single(
finally:
transport.close()

if host in protocol.discovered_devices:
dev = protocol.discovered_devices[host]
if ip in protocol.discovered_devices:
dev = protocol.discovered_devices[ip]
dev.host = host
await dev.update()
return dev
elif host in protocol.unsupported_devices:
elif ip in protocol.unsupported_devices:
raise UnsupportedDeviceException(
f"Unsupported device {host}: {protocol.unsupported_devices[host]}"
f"Unsupported device {host}: {protocol.unsupported_devices[ip]}"
)
elif host in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[host]
elif ip in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[ip]
else:
raise SmartDeviceException(f"Unable to get discovery response for {host}")

Expand Down
26 changes: 26 additions & 0 deletions kasa/tests/test_discovery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# type: ignore
import re
import socket
import sys

import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
Expand Down Expand Up @@ -74,6 +75,31 @@ def mock_discover(self):
assert x.port == custom_port or x.port == 9999


async def test_discover_single_hostname(discovery_data: dict, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "foobar"
ip = "127.0.0.1"

def mock_discover(self):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:],
(ip, 9999),
)

mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
mocker.patch("socket.getaddrinfo", return_value=[(None, None, None, None, (ip, 0))])

x = await Discover.discover_single(host)
assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None
assert x.host == host

mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror())
with pytest.raises(SmartDeviceException):
x = await Discover.discover_single(host)


@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_single(discovery_data: dict, mocker, custom_port):
"""Make sure that connect_single returns an initialized SmartDevice instance."""
Expand Down