Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions kasa/device_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .device import Device
from .device_type import DeviceType
from .deviceconfig import DeviceConfig
from .deviceconfig import DeviceConfig, DeviceFamily
from .exceptions import KasaException, UnsupportedDeviceError
from .iot import (
IotBulb,
Expand Down Expand Up @@ -179,20 +179,29 @@ def get_device_class_from_family(
def get_protocol(
config: DeviceConfig,
) -> BaseProtocol | None:
"""Return the protocol from the connection name."""
protocol_name = config.connection_type.device_family.value.split(".")[0]
"""Return the protocol from the connection name.

For cameras and vacuums the device family is a simple mapping to
the protocol/transport. For other device types the transport varies
based on the discovery information.
"""
ctype = config.connection_type
protocol_name = ctype.device_family.value.split(".")[0]

if ctype.device_family is DeviceFamily.SmartIpCamera:
return SmartCamProtocol(transport=SslAesTransport(config=config))

if ctype.device_family is DeviceFamily.IotIpCamera:
return IotProtocol(transport=LinkieTransportV2(config=config))

if ctype.device_family is DeviceFamily.SmartTapoRobovac:
return SmartProtocol(transport=SslTransport(config=config))

protocol_transport_key = (
protocol_name
+ "."
+ ctype.encryption_type.value
+ (".HTTPS" if ctype.https else "")
+ (
f".{ctype.login_version}"
if ctype.login_version and ctype.login_version > 1
else ""
)
)

_LOGGER.debug("Finding transport for %s", protocol_transport_key)
Expand All @@ -201,12 +210,11 @@ def get_protocol(
] = {
"IOT.XOR": (IotProtocol, XorTransport),
"IOT.KLAP": (IotProtocol, KlapTransport),
"IOT.XOR.HTTPS.2": (IotProtocol, LinkieTransportV2),
"SMART.AES": (SmartProtocol, AesTransport),
"SMART.AES.2": (SmartProtocol, AesTransport),
"SMART.KLAP.2": (SmartProtocol, KlapTransportV2),
"SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport),
"SMART.AES.HTTPS": (SmartProtocol, SslTransport),
"SMART.KLAP": (SmartProtocol, KlapTransportV2),
# H200 is device family SMART.TAPOHUB and uses SmartCamProtocol so use
# https to distuingish from SmartProtocol devices
"SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport),
}
if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
return None
Expand Down
10 changes: 5 additions & 5 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,12 +847,12 @@ def _get_device_instance(
):
encrypt_type = encrypt_info.sym_schm

if (
not (login_version := encrypt_schm.lv)
and (et := discovery_result.encrypt_type)
and et == ["3"]
if not (login_version := encrypt_schm.lv) and (
et := discovery_result.encrypt_type
):
login_version = 2
# Known encrypt types are ["1","2"] and ["3"]
# Reuse the login_version attribute to pass the max to transport
login_version = max([int(i) for i in et])

if not encrypt_type:
raise UnsupportedDeviceError(
Expand Down
85 changes: 85 additions & 0 deletions tests/test_device_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342

from kasa import (
BaseProtocol,
Credentials,
Discover,
IotProtocol,
KasaException,
SmartCamProtocol,
SmartProtocol,
)
from kasa.device_factory import (
Device,
Expand All @@ -33,6 +37,16 @@
DeviceFamily,
)
from kasa.discover import DiscoveryResult
from kasa.transports import (
AesTransport,
BaseTransport,
KlapTransport,
KlapTransportV2,
LinkieTransportV2,
SslAesTransport,
SslTransport,
XorTransport,
)

from .conftest import DISCOVERY_MOCK_IP

Expand Down Expand Up @@ -203,3 +217,74 @@ async def test_device_class_from_unknown_family(caplog):
with caplog.at_level(logging.DEBUG):
assert get_device_class_from_family(dummy_name, https=False) == SmartDevice
assert f"Unknown SMART device with {dummy_name}" in caplog.text


# Aliases to make the test params more readable
CP = DeviceConnectionParameters
DF = DeviceFamily
ET = DeviceEncryptionType


@pytest.mark.parametrize(
("conn_params", "expected_protocol", "expected_transport"),
[
pytest.param(
CP(DF.SmartIpCamera, ET.Aes, https=True),
SmartCamProtocol,
SslAesTransport,
id="smartcam",
),
pytest.param(
CP(DF.SmartTapoHub, ET.Aes, https=True),
SmartCamProtocol,
SslAesTransport,
id="smartcam-hub",
),
pytest.param(
CP(DF.IotIpCamera, ET.Aes, https=True),
IotProtocol,
LinkieTransportV2,
id="kasacam",
),
pytest.param(
CP(DF.SmartTapoRobovac, ET.Aes, https=True),
SmartProtocol,
SslTransport,
id="robovac",
),
pytest.param(
CP(DF.IotSmartPlugSwitch, ET.Klap, https=False),
IotProtocol,
KlapTransport,
id="iot-klap",
),
pytest.param(
CP(DF.IotSmartPlugSwitch, ET.Xor, https=False),
IotProtocol,
XorTransport,
id="iot-xor",
),
pytest.param(
CP(DF.SmartTapoPlug, ET.Aes, https=False),
SmartProtocol,
AesTransport,
id="smart-aes",
),
pytest.param(
CP(DF.SmartTapoPlug, ET.Klap, https=False),
SmartProtocol,
KlapTransportV2,
id="smart-klap",
),
],
)
async def test_get_protocol(
conn_params: DeviceConnectionParameters,
expected_protocol: type[BaseProtocol],
expected_transport: type[BaseTransport],
):
"""Test get_protocol returns the right protocol."""
config = DeviceConfig("127.0.0.1", connection_type=conn_params)
protocol = get_protocol(config)
assert isinstance(protocol, expected_protocol)
assert isinstance(protocol._transport, expected_transport)
Loading