Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 55 additions & 87 deletions kasa/deviceconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,32 @@
>>> config_dict = device.config.to_dict()
>>> # DeviceConfig.to_dict() can be used to store for later
>>> print(config_dict)
{'host': '127.0.0.3', 'timeout': 5, 'credentials': Credentials(), 'connection_type'\
: {'device_family': 'SMART.TAPOBULB', 'encryption_type': 'KLAP', 'https': False, \
'login_version': 2}, 'uses_http': True}
{'host': '127.0.0.3', 'timeout': 5, 'credentials': {'username': 'user@example.com', \
'password': 'great_password'}, 'connection_type'\
: {'device_family': 'SMART.TAPOBULB', 'encryption_type': 'KLAP', 'login_version': 2, \
'https': False}, 'uses_http': True}
Comment on lines -20 to +23
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Identified a bug here whereby credentials were not actually being fully serialized previously.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's sort of expected with any custom (de)serialization code, so thanks a lot for doing the conversion! :-)


>>> later_device = await Device.connect(config=Device.Config.from_dict(config_dict))
>>> print(later_device.alias) # Alias is available as connect() calls update()
Living Room Bulb

"""

# Module cannot use from __future__ import annotations until migrated to mashumaru
# as dataclass.fields() will not resolve the type.
from __future__ import annotations

import logging
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from dataclasses import dataclass, field, replace
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, TypedDict
from typing import TYPE_CHECKING, Any, Self, TypedDict

from aiohttp import ClientSession
from mashumaro import field_options
from mashumaro.config import BaseConfig
from mashumaro.types import SerializationStrategy

from .credentials import Credentials
from .exceptions import KasaException
from .json import DataClassJSONMixin

if TYPE_CHECKING:
from aiohttp import ClientSession
Expand Down Expand Up @@ -73,45 +80,17 @@ class DeviceFamily(Enum):
SmartIpCamera = "SMART.IPCAMERA"


def _dataclass_from_dict(klass: Any, in_val: dict) -> Any:
if is_dataclass(klass):
fieldtypes = {f.name: f.type for f in fields(klass)}
val = {}
for dict_key in in_val:
if dict_key in fieldtypes:
if hasattr(fieldtypes[dict_key], "from_dict"):
val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key]) # type: ignore[union-attr]
else:
val[dict_key] = _dataclass_from_dict(
fieldtypes[dict_key], in_val[dict_key]
)
else:
raise KasaException(
f"Cannot create dataclass from dict, unknown key: {dict_key}"
)
return klass(**val) # type: ignore[operator]
else:
return in_val


def _dataclass_to_dict(in_val: Any) -> dict:
fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare}
out_val = {}
for field_name in fieldtypes:
val = getattr(in_val, field_name)
if val is None:
continue
elif hasattr(val, "to_dict"):
out_val[field_name] = val.to_dict()
elif is_dataclass(fieldtypes[field_name]):
out_val[field_name] = asdict(val)
else:
out_val[field_name] = val
return out_val
class _DeviceConfigBaseMixin(DataClassJSONMixin):
"""Base class for serialization mixin."""

class Config(BaseConfig):
"""Serialization config."""

omit_none = True


@dataclass
class DeviceConnectionParameters:
class DeviceConnectionParameters(_DeviceConfigBaseMixin):
"""Class to hold the the parameters determining connection type."""

device_family: DeviceFamily
Expand All @@ -125,7 +104,7 @@ def from_values(
encryption_type: str,
login_version: int | None = None,
https: bool | None = None,
) -> "DeviceConnectionParameters":
) -> DeviceConnectionParameters:
"""Return connection parameters from string values."""
try:
if https is None:
Expand All @@ -142,39 +121,17 @@ def from_values(
+ f"{encryption_type}.{login_version}"
) from ex

@staticmethod
def from_dict(connection_type_dict: dict[str, Any]) -> "DeviceConnectionParameters":
"""Return connection parameters from dict."""
if (
isinstance(connection_type_dict, dict)
and (device_family := connection_type_dict.get("device_family"))
and (encryption_type := connection_type_dict.get("encryption_type"))
):
if login_version := connection_type_dict.get("login_version"):
login_version = int(login_version) # type: ignore[assignment]
return DeviceConnectionParameters.from_values(
device_family,
encryption_type,
login_version, # type: ignore[arg-type]
connection_type_dict.get("https", False),
)

raise KasaException(f"Invalid connection type data for {connection_type_dict}")
class _DoNotSerialize(SerializationStrategy):
def serialize(self, value: Any) -> None:
return None # pragma: no cover

def to_dict(self) -> dict[str, str | int | bool]:
"""Convert connection params to dict."""
result: dict[str, str | int] = {
"device_family": self.device_family.value,
"encryption_type": self.encryption_type.value,
"https": self.https,
}
if self.login_version:
result["login_version"] = self.login_version
return result
def deserialize(self, value: Any) -> None:
return None # pragma: no cover


@dataclass
class DeviceConfig:
class DeviceConfig(_DeviceConfigBaseMixin):
"""Class to represent paramaters that determine how to connect to devices."""

DEFAULT_TIMEOUT = 5
Expand Down Expand Up @@ -202,9 +159,12 @@ class DeviceConfig:
#: in order to determine whether they should pass a custom http client if desired.
uses_http: bool = False

# compare=False will be excluded from the serialization and object comparison.
#: Set a custom http_client for the device to use.
http_client: Optional["ClientSession"] = field(default=None, compare=False)
http_client: ClientSession | None = field(
default=None,
compare=False,
metadata=field_options(serialization_strategy=_DoNotSerialize()),
)

aes_keys: KeyPairDict | None = None

Expand All @@ -214,22 +174,30 @@ def __post_init__(self) -> None:
DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor
)

def to_dict(
def __pre_serialize__(self) -> Self:
return replace(self, http_client=None)

def to_dict_control_credentials(
self,
*,
credentials_hash: str | None = None,
exclude_credentials: bool = False,
) -> dict[str, dict[str, str]]:
"""Convert device config to dict."""
if credentials_hash is not None or exclude_credentials:
self.credentials = None
if credentials_hash:
self.credentials_hash = credentials_hash
return _dataclass_to_dict(self)
"""Convert deviceconfig to dict controlling how to serialize credentials.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a real use case for not using the default and avoid having extra code to strip this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really have a case ourselves anymore as we don't serialize the whole deviceconfig in HA anymore. Maybe we could remove this in a subsequent PR or mark it for deprecation. Also perhaps the default should be to not serialize the credentials as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think we could remove it in the future, but it's not that critical. I would perhaps avoid making it too magical, but I'd guess skipping pw&user in favor of credentials_hash makes sense to avoid storing clear text creds (but would break on protocol updates).


If credentials_hash is provided credentials will be None.
If credentials_hash is '' credentials_hash and credentials will be None.
exclude credentials controls whether to include credentials.
The defaults are the same as calling to_dict().
"""
if credentials_hash is None:
if not exclude_credentials:
return self.to_dict()
else:
return replace(self, credentials=None).to_dict()

@staticmethod
def from_dict(config_dict: dict[str, dict[str, str]]) -> "DeviceConfig":
"""Return device config from dict."""
if isinstance(config_dict, dict):
return _dataclass_from_dict(DeviceConfig, config_dict)
raise KasaException(f"Invalid device config data: {config_dict}")
return replace(
self,
credentials_hash=credentials_hash if credentials_hash else None,
credentials=None,
).to_dict()
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import sys
import warnings
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -21,6 +22,13 @@
turn_on = pytest.mark.parametrize("turn_on", [True, False])


def load_fixture(foldername, filename):
"""Load a fixture."""
path = Path(Path(__file__).parent / "fixtures" / foldername / filename)
with path.open() as fdp:
return fdp.read()


async def handle_turn_on(dev, turn_on):
if turn_on:
await dev.turn_on()
Expand Down
10 changes: 10 additions & 0 deletions tests/fixtures/serialization/deviceconfig_camera-aes-https.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"host": "127.0.0.1",
"timeout": 5,
"connection_type": {
"device_family": "SMART.IPCAMERA",
"encryption_type": "AES",
"https": true
},
"uses_http": false
}
11 changes: 11 additions & 0 deletions tests/fixtures/serialization/deviceconfig_plug-klap.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"host": "127.0.0.1",
"timeout": 5,
"connection_type": {
"device_family": "SMART.TAPOPLUG",
"encryption_type": "KLAP",
"https": false,
"login_version": 2
},
"uses_http": false
}
10 changes: 10 additions & 0 deletions tests/fixtures/serialization/deviceconfig_plug-xor.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"host": "127.0.0.1",
"timeout": 5,
"connection_type": {
"device_family": "IOT.SMARTPLUGSWITCH",
"encryption_type": "XOR",
"https": false
},
"uses_http": false
}
80 changes: 71 additions & 9 deletions tests/test_deviceconfig.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,97 @@
import json
from dataclasses import replace
from json import dumps as json_dumps
from json import loads as json_loads

import aiohttp
import pytest
from mashumaro import MissingField

from kasa.credentials import Credentials
from kasa.deviceconfig import (
DeviceConfig,
DeviceConnectionParameters,
DeviceEncryptionType,
DeviceFamily,
)

from .conftest import load_fixture

PLUG_XOR_CONFIG = DeviceConfig(host="127.0.0.1")
PLUG_KLAP_CONFIG = DeviceConfig(
host="127.0.0.1",
connection_type=DeviceConnectionParameters(
DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Klap, login_version=2
),
)
CAMERA_AES_CONFIG = DeviceConfig(
host="127.0.0.1",
connection_type=DeviceConnectionParameters(
DeviceFamily.SmartIpCamera, DeviceEncryptionType.Aes, https=True
),
)
from kasa.exceptions import KasaException


async def test_serialization():
"""Test device config serialization."""
config = DeviceConfig(host="Foo", http_client=aiohttp.ClientSession())
config_dict = config.to_dict()
config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict)
assert config == config2
assert config.to_dict_control_credentials() == config.to_dict()


@pytest.mark.parametrize(
("fixture_name", "expected_value"),
[
("deviceconfig_plug-xor.json", PLUG_XOR_CONFIG),
("deviceconfig_plug-klap.json", PLUG_KLAP_CONFIG),
("deviceconfig_camera-aes-https.json", CAMERA_AES_CONFIG),
],
ids=lambda arg: arg.split("_")[-1] if isinstance(arg, str) else "",
)
async def test_deserialization(fixture_name: str, expected_value: DeviceConfig):
"""Test device config deserialization."""
dict_val = json.loads(load_fixture("serialization", fixture_name))
config = DeviceConfig.from_dict(dict_val)
assert config == expected_value
assert expected_value.to_dict() == dict_val


async def test_serialization_http_client():
"""Test that the http client does not try to serialize."""
dict_val = json.loads(load_fixture("serialization", "deviceconfig_plug-klap.json"))

config = replace(PLUG_KLAP_CONFIG, http_client=object())
assert config.http_client

assert config.to_dict() == dict_val


async def test_conn_param_no_https():
"""Test no https in connection param defaults to False."""
dict_val = {
"device_family": "SMART.TAPOPLUG",
"encryption_type": "KLAP",
"login_version": 2,
}
param = DeviceConnectionParameters.from_dict(dict_val)
assert param.https is False
assert param.to_dict() == {**dict_val, "https": False}


@pytest.mark.parametrize(
("input_value", "expected_msg"),
("input_value", "expected_error"),
[
({"Foo": "Bar"}, "Cannot create dataclass from dict, unknown key: Foo"),
("foobar", "Invalid device config data: foobar"),
({"Foo": "Bar"}, MissingField),
("foobar", ValueError),
],
ids=["invalid-dict", "not-dict"],
)
def test_deserialization_errors(input_value, expected_msg):
with pytest.raises(KasaException, match=expected_msg):
def test_deserialization_errors(input_value, expected_error):
with pytest.raises(expected_error):
DeviceConfig.from_dict(input_value)


Expand All @@ -39,7 +101,7 @@ async def test_credentials_hash():
http_client=aiohttp.ClientSession(),
credentials=Credentials("foo", "bar"),
)
config_dict = config.to_dict(credentials_hash="credhash")
config_dict = config.to_dict_control_credentials(credentials_hash="credhash")
config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict)
Expand All @@ -53,7 +115,7 @@ async def test_blank_credentials_hash():
http_client=aiohttp.ClientSession(),
credentials=Credentials("foo", "bar"),
)
config_dict = config.to_dict(credentials_hash="")
config_dict = config.to_dict_control_credentials(credentials_hash="")
config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict)
Expand All @@ -67,7 +129,7 @@ async def test_exclude_credentials():
http_client=aiohttp.ClientSession(),
credentials=Credentials("foo", "bar"),
)
config_dict = config.to_dict(exclude_credentials=True)
config_dict = config.to_dict_control_credentials(exclude_credentials=True)
config_json = json_dumps(config_dict)
config2_dict = json_loads(config_json)
config2 = DeviceConfig.from_dict(config2_dict)
Expand Down