Skip to content

Commit fe07265

Browse files
authored
Simplify get_protocol to prevent clashes with smartcam and robovac (#1377)
1 parent 5918e4d commit fe07265

File tree

3 files changed

+111
-18
lines changed

3 files changed

+111
-18
lines changed

kasa/device_factory.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .device import Device
1010
from .device_type import DeviceType
11-
from .deviceconfig import DeviceConfig
11+
from .deviceconfig import DeviceConfig, DeviceFamily
1212
from .exceptions import KasaException, UnsupportedDeviceError
1313
from .iot import (
1414
IotBulb,
@@ -179,20 +179,29 @@ def get_device_class_from_family(
179179
def get_protocol(
180180
config: DeviceConfig,
181181
) -> BaseProtocol | None:
182-
"""Return the protocol from the connection name."""
183-
protocol_name = config.connection_type.device_family.value.split(".")[0]
182+
"""Return the protocol from the connection name.
183+
184+
For cameras and vacuums the device family is a simple mapping to
185+
the protocol/transport. For other device types the transport varies
186+
based on the discovery information.
187+
"""
184188
ctype = config.connection_type
189+
protocol_name = ctype.device_family.value.split(".")[0]
190+
191+
if ctype.device_family is DeviceFamily.SmartIpCamera:
192+
return SmartCamProtocol(transport=SslAesTransport(config=config))
193+
194+
if ctype.device_family is DeviceFamily.IotIpCamera:
195+
return IotProtocol(transport=LinkieTransportV2(config=config))
196+
197+
if ctype.device_family is DeviceFamily.SmartTapoRobovac:
198+
return SmartProtocol(transport=SslTransport(config=config))
185199

186200
protocol_transport_key = (
187201
protocol_name
188202
+ "."
189203
+ ctype.encryption_type.value
190204
+ (".HTTPS" if ctype.https else "")
191-
+ (
192-
f".{ctype.login_version}"
193-
if ctype.login_version and ctype.login_version > 1
194-
else ""
195-
)
196205
)
197206

198207
_LOGGER.debug("Finding transport for %s", protocol_transport_key)
@@ -201,12 +210,11 @@ def get_protocol(
201210
] = {
202211
"IOT.XOR": (IotProtocol, XorTransport),
203212
"IOT.KLAP": (IotProtocol, KlapTransport),
204-
"IOT.XOR.HTTPS.2": (IotProtocol, LinkieTransportV2),
205213
"SMART.AES": (SmartProtocol, AesTransport),
206-
"SMART.AES.2": (SmartProtocol, AesTransport),
207-
"SMART.KLAP.2": (SmartProtocol, KlapTransportV2),
208-
"SMART.AES.HTTPS.2": (SmartCamProtocol, SslAesTransport),
209-
"SMART.AES.HTTPS": (SmartProtocol, SslTransport),
214+
"SMART.KLAP": (SmartProtocol, KlapTransportV2),
215+
# H200 is device family SMART.TAPOHUB and uses SmartCamProtocol so use
216+
# https to distuingish from SmartProtocol devices
217+
"SMART.AES.HTTPS": (SmartCamProtocol, SslAesTransport),
210218
}
211219
if not (prot_tran_cls := supported_device_protocols.get(protocol_transport_key)):
212220
return None

kasa/discover.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -847,12 +847,12 @@ def _get_device_instance(
847847
):
848848
encrypt_type = encrypt_info.sym_schm
849849

850-
if (
851-
not (login_version := encrypt_schm.lv)
852-
and (et := discovery_result.encrypt_type)
853-
and et == ["3"]
850+
if not (login_version := encrypt_schm.lv) and (
851+
et := discovery_result.encrypt_type
854852
):
855-
login_version = 2
853+
# Known encrypt types are ["1","2"] and ["3"]
854+
# Reuse the login_version attribute to pass the max to transport
855+
login_version = max([int(i) for i in et])
856856

857857
if not encrypt_type:
858858
raise UnsupportedDeviceError(

tests/test_device_factory.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@
1313
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
1414

1515
from kasa import (
16+
BaseProtocol,
1617
Credentials,
1718
Discover,
19+
IotProtocol,
1820
KasaException,
21+
SmartCamProtocol,
22+
SmartProtocol,
1923
)
2024
from kasa.device_factory import (
2125
Device,
@@ -33,6 +37,16 @@
3337
DeviceFamily,
3438
)
3539
from kasa.discover import DiscoveryResult
40+
from kasa.transports import (
41+
AesTransport,
42+
BaseTransport,
43+
KlapTransport,
44+
KlapTransportV2,
45+
LinkieTransportV2,
46+
SslAesTransport,
47+
SslTransport,
48+
XorTransport,
49+
)
3650

3751
from .conftest import DISCOVERY_MOCK_IP
3852

@@ -203,3 +217,74 @@ async def test_device_class_from_unknown_family(caplog):
203217
with caplog.at_level(logging.DEBUG):
204218
assert get_device_class_from_family(dummy_name, https=False) == SmartDevice
205219
assert f"Unknown SMART device with {dummy_name}" in caplog.text
220+
221+
222+
# Aliases to make the test params more readable
223+
CP = DeviceConnectionParameters
224+
DF = DeviceFamily
225+
ET = DeviceEncryptionType
226+
227+
228+
@pytest.mark.parametrize(
229+
("conn_params", "expected_protocol", "expected_transport"),
230+
[
231+
pytest.param(
232+
CP(DF.SmartIpCamera, ET.Aes, https=True),
233+
SmartCamProtocol,
234+
SslAesTransport,
235+
id="smartcam",
236+
),
237+
pytest.param(
238+
CP(DF.SmartTapoHub, ET.Aes, https=True),
239+
SmartCamProtocol,
240+
SslAesTransport,
241+
id="smartcam-hub",
242+
),
243+
pytest.param(
244+
CP(DF.IotIpCamera, ET.Aes, https=True),
245+
IotProtocol,
246+
LinkieTransportV2,
247+
id="kasacam",
248+
),
249+
pytest.param(
250+
CP(DF.SmartTapoRobovac, ET.Aes, https=True),
251+
SmartProtocol,
252+
SslTransport,
253+
id="robovac",
254+
),
255+
pytest.param(
256+
CP(DF.IotSmartPlugSwitch, ET.Klap, https=False),
257+
IotProtocol,
258+
KlapTransport,
259+
id="iot-klap",
260+
),
261+
pytest.param(
262+
CP(DF.IotSmartPlugSwitch, ET.Xor, https=False),
263+
IotProtocol,
264+
XorTransport,
265+
id="iot-xor",
266+
),
267+
pytest.param(
268+
CP(DF.SmartTapoPlug, ET.Aes, https=False),
269+
SmartProtocol,
270+
AesTransport,
271+
id="smart-aes",
272+
),
273+
pytest.param(
274+
CP(DF.SmartTapoPlug, ET.Klap, https=False),
275+
SmartProtocol,
276+
KlapTransportV2,
277+
id="smart-klap",
278+
),
279+
],
280+
)
281+
async def test_get_protocol(
282+
conn_params: DeviceConnectionParameters,
283+
expected_protocol: type[BaseProtocol],
284+
expected_transport: type[BaseTransport],
285+
):
286+
"""Test get_protocol returns the right protocol."""
287+
config = DeviceConfig("127.0.0.1", connection_type=conn_params)
288+
protocol = get_protocol(config)
289+
assert isinstance(protocol, expected_protocol)
290+
assert isinstance(protocol._transport, expected_transport)

0 commit comments

Comments
 (0)