Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast

from aiohttp import ClientSession

# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
Expand Down Expand Up @@ -533,6 +535,7 @@ async def try_connect_all(
port: int | None = None,
timeout: int | None = None,
credentials: Credentials | None = None,
http_client: ClientSession | None = None,
) -> Device | None:
"""Try to connect directly to a device with all possible parameters.

Expand All @@ -544,6 +547,7 @@ async def try_connect_all(
:param port: Optionally set a different port for legacy devices using port 9999
:param timeout: Timeout in seconds device for devices queries
:param credentials: Credentials for devices that require authentication.
:param http_client: Optional client session for devices that use http.
username and password are ignored if provided.
"""
from .device_factory import _connect
Expand All @@ -570,6 +574,8 @@ async def try_connect_all(
timeout=timeout,
port_override=port,
credentials=credentials,
http_client=http_client,
uses_http=encrypt is not Device.EncryptionType.Xor,
)
)
and (protocol := get_protocol(config))
Expand Down
6 changes: 5 additions & 1 deletion kasa/tests/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,9 +697,13 @@ async def _update(self, *args, **kwargs):
mocker.patch("kasa.SmartProtocol.query", new=_query)
mocker.patch.object(dev_class, "update", new=_update)

dev = await Discover.try_connect_all(discovery_mock.ip)
session = aiohttp.ClientSession()
dev = await Discover.try_connect_all(discovery_mock.ip, http_client=session)

assert dev
assert isinstance(dev, dev_class)
assert isinstance(dev.protocol, protocol_class)
assert isinstance(dev.protocol._transport, transport_class)
assert dev.config.uses_http is (transport_class != XorTransport)
if transport_class != XorTransport:
assert dev.protocol._transport._http_client.client == session