Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions devtools/dump_devinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ async def cli(
click.echo("Host and discovery info given, trying connect on %s." % host)

di = json.loads(discovery_info)
dr = DiscoveryResult(**di)
dr = DiscoveryResult.from_dict(di)
connection_type = DeviceConnectionParameters.from_values(
dr.device_type,
dr.mgt_encrypt_schm.encrypt_type,
Expand All @@ -336,7 +336,7 @@ async def cli(
basedir,
autosave,
device.protocol,
discovery_info=dr.get_dict(),
discovery_info=dr.to_dict(),
batch_size=batch_size,
)
elif device_family and encrypt_type:
Expand Down Expand Up @@ -443,7 +443,7 @@ async def get_legacy_fixture(protocol, *, discovery_info):
if discovery_info and not discovery_info.get("system"):
# Need to recreate a DiscoverResult here because we don't want the aliases
# in the fixture, we want the actual field names as returned by the device.
dr = DiscoveryResult(**protocol._discovery_info)
dr = DiscoveryResult.from_dict(protocol._discovery_info)
final["discovery_result"] = dr.dict(
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
)
Expand Down Expand Up @@ -960,10 +960,8 @@ async def get_smart_fixtures(
# Need to recreate a DiscoverResult here because we don't want the aliases
# in the fixture, we want the actual field names as returned by the device.
if discovery_info:
dr = DiscoveryResult(**discovery_info) # type: ignore
final["discovery_result"] = dr.dict(
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
)
dr = DiscoveryResult.from_dict(discovery_info) # type: ignore
final["discovery_result"] = dr.to_dict()

click.echo("Got %s successes" % len(successes))
click.echo(click.style("## device info file ##", bold=True))
Expand Down
2 changes: 1 addition & 1 deletion kasa/cli/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _echo_discovery_info(discovery_info) -> None:
return

try:
dr = DiscoveryResult(**discovery_info)
dr = DiscoveryResult.from_dict(discovery_info)
except ValidationError:
_echo_dictionary(discovery_info)
return
Expand Down
73 changes: 45 additions & 28 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
import socket
import struct
from asyncio.transports import DatagramTransport
from dataclasses import dataclass, field
from pprint import pformat as pf
from typing import (
TYPE_CHECKING,
Expand All @@ -108,7 +109,8 @@
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from pydantic.v1 import BaseModel, ValidationError
from mashumaro import field_options
from mashumaro.config import BaseConfig

from kasa import Device
from kasa.credentials import Credentials
Expand All @@ -130,6 +132,7 @@
from kasa.experimental import Experimental
from kasa.iot.iotdevice import IotDevice
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
from kasa.json import DataClassJSONMixin
from kasa.json import dumps as json_dumps
from kasa.json import loads as json_loads
from kasa.protocol import mask_mac, redact_data
Expand Down Expand Up @@ -647,7 +650,7 @@
def _get_device_class(info: dict) -> type[Device]:
"""Find SmartDevice subclass for device described by passed data."""
if "result" in info:
discovery_result = DiscoveryResult(**info["result"])
discovery_result = DiscoveryResult.from_dict(info["result"])
https = discovery_result.mgt_encrypt_schm.is_support_https
dev_class = get_device_class_from_family(
discovery_result.device_type, https=https
Expand Down Expand Up @@ -721,12 +724,8 @@
f"Unable to read response from device: {config.host}: {ex}"
) from ex
try:
discovery_result = DiscoveryResult(**info["result"])
if (
encrypt_info := discovery_result.encrypt_info
) and encrypt_info.sym_schm == "AES":
Discover._decrypt_discovery_data(discovery_result)
except ValidationError as ex:
discovery_result = DiscoveryResult.from_dict(info["result"])
except Exception as ex:
if debug_enabled:
data = (
redact_data(info, NEW_DISCOVERY_REDACTORS)
Expand All @@ -742,6 +741,16 @@
f"Unable to parse discovery from device: {config.host}: {ex}",
host=config.host,
) from ex
# Decrypt the data
if (
encrypt_info := discovery_result.encrypt_info
) and encrypt_info.sym_schm == "AES":
try:
Discover._decrypt_discovery_data(discovery_result)
except Exception:
_LOGGER.exception(

Check warning on line 751 in kasa/discover.py

View check run for this annotation

Codecov / codecov/patch

kasa/discover.py#L748-L751

Added lines #L748 - L751 were not covered by tests
"Unable to decrypt discovery data %s: %s", config.host, data
)

type_ = discovery_result.device_type
encrypt_schm = discovery_result.mgt_encrypt_schm
Expand All @@ -754,7 +763,7 @@
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ "with no encryption type",
discovery_result=discovery_result.get_dict(),
discovery_result=discovery_result.to_dict(),
host=config.host,
)
config.connection_type = DeviceConnectionParameters.from_values(
Expand All @@ -767,7 +776,7 @@
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}",
discovery_result=discovery_result.get_dict(),
discovery_result=discovery_result.to_dict(),
host=config.host,
) from ex
if (
Expand All @@ -778,7 +787,7 @@
_LOGGER.warning("Got unsupported device type: %s", type_)
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_}: {info}",
discovery_result=discovery_result.get_dict(),
discovery_result=discovery_result.to_dict(),
host=config.host,
)
if (protocol := get_protocol(config)) is None:
Expand All @@ -788,7 +797,7 @@
raise UnsupportedDeviceError(
f"Unsupported encryption scheme {config.host} of "
+ f"type {config.connection_type.to_dict()}: {info}",
discovery_result=discovery_result.get_dict(),
discovery_result=discovery_result.to_dict(),
host=config.host,
)

Expand All @@ -801,42 +810,59 @@
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data))
device = device_class(config.host, protocol=protocol)

di = discovery_result.get_dict()
di = discovery_result.to_dict()
di["model"], _, _ = discovery_result.device_model.partition("(")
device.update_from_discover_info(di)
return device


class EncryptionScheme(BaseModel):
class _DiscoveryBaseMixin(DataClassJSONMixin):
"""Base class for serialization mixin."""

class Config(BaseConfig):
"""Serialization config."""

omit_none = True
omit_default = True
serialize_by_alias = True


@dataclass
class EncryptionScheme(_DiscoveryBaseMixin):
"""Base model for encryption scheme of discovery result."""

is_support_https: bool
encrypt_type: Optional[str] # noqa: UP007
encrypt_type: Optional[str] = None # noqa: UP007
http_port: Optional[int] = None # noqa: UP007
lv: Optional[int] = None # noqa: UP007


class EncryptionInfo(BaseModel):
@dataclass
class EncryptionInfo(_DiscoveryBaseMixin):
"""Base model for encryption info of discovery result."""

sym_schm: str
key: str
data: str


class DiscoveryResult(BaseModel):
@dataclass
class DiscoveryResult(_DiscoveryBaseMixin):
"""Base model for discovery result."""

device_type: str
device_model: str
device_name: Optional[str] # noqa: UP007
device_id: str
ip: str
mac: str
mgt_encrypt_schm: EncryptionScheme
device_name: Optional[str] = None # noqa: UP007
encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007
encrypt_type: Optional[list[str]] = None # noqa: UP007
decrypted_data: Optional[dict] = None # noqa: UP007
device_id: str
is_reset_wifi: Optional[bool] = field( # noqa: UP007
metadata=field_options(alias="isResetWiFi"), default=None
)

firmware_version: Optional[str] = None # noqa: UP007
hardware_version: Optional[str] = None # noqa: UP007
Expand All @@ -845,12 +871,3 @@
is_support_iot_cloud: Optional[bool] = None # noqa: UP007
obd_src: Optional[str] = None # noqa: UP007
factory_default: Optional[bool] = None # noqa: UP007

def get_dict(self) -> dict:
"""Return a dict for this discovery result.

containing only the values actually set and with aliases as field names.
"""
return self.dict(
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
)
10 changes: 10 additions & 0 deletions kasa/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,13 @@ def dumps(obj: Any, *, default: Callable | None = None) -> str:
return json.dumps(obj, separators=(",", ":"))

loads = json.loads


try:
from mashumaro.mixins.orjson import DataClassORJSONMixin

DataClassJSONMixin = DataClassORJSONMixin
except ImportError:
from mashumaro.mixins.json import DataClassJSONMixin as JSONMixin

DataClassJSONMixin = JSONMixin # type: ignore[assignment, misc]
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"aiohttp>=3",
"typing-extensions>=4.12.2,<5.0",
"tzdata>=2024.2 ; platform_system == 'Windows'",
"mashumaro>=3.14",
]

classifiers = [
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ async def _state(dev: Device):

mocker.patch("kasa.cli.device.state", new=_state)

dr = DiscoveryResult(**discovery_mock.discovery_data["result"])
dr = DiscoveryResult.from_dict(discovery_mock.discovery_data["result"])
res = await runner.invoke(
cli,
[
Expand Down
2 changes: 1 addition & 1 deletion tests/test_device_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
def _get_connection_type_device_class(discovery_info):
if "result" in discovery_info:
device_class = Discover._get_device_class(discovery_info)
dr = DiscoveryResult(**discovery_info["result"])
dr = DiscoveryResult.from_dict(discovery_info["result"])

connection_type = DeviceConnectionParameters.from_values(
dr.device_type, dr.mgt_encrypt_schm.encrypt_type
Expand Down
6 changes: 3 additions & 3 deletions tests/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,8 +391,8 @@ async def test_device_update_from_new_discovery_info(discovery_mock):
discovery_data = discovery_mock.discovery_data
device_class = Discover._get_device_class(discovery_data)
device = device_class("127.0.0.1")
discover_info = DiscoveryResult(**discovery_data["result"])
discover_dump = discover_info.get_dict()
discover_info = DiscoveryResult.from_dict(discovery_data["result"])
discover_dump = discover_info.to_dict()
model, _, _ = discover_dump["device_model"].partition("(")
discover_dump["model"] = model
device.update_from_discover_info(discover_dump)
Expand Down Expand Up @@ -652,7 +652,7 @@ async def test_discovery_decryption():
"sym_schm": "AES",
}
info = {**UNSUPPORTED["result"], "encrypt_info": encrypt_info}
dr = DiscoveryResult(**info)
dr = DiscoveryResult.from_dict(info)
Discover._decrypt_discovery_data(dr)
assert dr.decrypted_data == data_dict

Expand Down
14 changes: 14 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading