Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kasa/smart/modules/led.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Led(SmartModule, LedInterface):

def query(self) -> dict:
"""Query to execute during the update cycle."""
return {self.QUERY_GETTER_NAME: {"led_rule": None}}
return {self.QUERY_GETTER_NAME: None}

@property
def mode(self):
Expand Down
3 changes: 3 additions & 0 deletions kasa/smart/modules/lightpreset.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def query(self) -> dict:
"""Query to execute during the update cycle."""
if self._state_in_sysinfo: # Child lights can have states in the child info
return {}
if self.supported_version < 3:
return {self.QUERY_GETTER_NAME: None}

return {self.QUERY_GETTER_NAME: {"start_index": 0}}

async def _check_supported(self):
Expand Down
2 changes: 1 addition & 1 deletion kasa/smart/modules/lighttransition.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def query(self) -> dict:
if self._state_in_sysinfo:
return {}
else:
return {self.QUERY_GETTER_NAME: {}}
return {self.QUERY_GETTER_NAME: None}

async def _check_supported(self):
"""Additional check to see if the module is supported by the device."""
Expand Down
93 changes: 8 additions & 85 deletions kasa/smartprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def __init__(
"""Create a protocol object."""
super().__init__(transport=transport)
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
self._request_id_generator = SnowflakeId(1, 1)
self._query_lock = asyncio.Lock()
self._multi_request_batch_size = (
self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
Expand All @@ -77,11 +76,11 @@ def get_smart_request(self, method, params=None) -> str:
"""Get a request message as a string."""
request = {
"method": method,
"params": params,
"requestID": self._request_id_generator.generate_id(),
"request_time_milis": round(time.time() * 1000),
"terminal_uuid": self._terminal_uuid,
}
if params:
request["params"] = params
return json_dumps(request)

async def query(self, request: str | dict, retry_count: int = 3) -> dict:
Expand Down Expand Up @@ -157,8 +156,10 @@ async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dic
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
multi_result: dict[str, Any] = {}
smart_method = "multipleRequest"

multi_requests = [
{"method": method, "params": params} for method, params in requests.items()
{"method": method, "params": params} if params else {"method": method}
for method, params in requests.items()
]

end = len(multi_requests)
Expand All @@ -168,7 +169,7 @@ async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dic
# If step is 1 do not send request batches
for request in multi_requests:
method = request["method"]
req = self.get_smart_request(method, request["params"])
req = self.get_smart_request(method, request.get("params"))
resp = await self._transport.send(req)
self._handle_response_error_code(resp, method, raise_on_error=False)
multi_result[method] = resp["result"]
Expand Down Expand Up @@ -347,86 +348,6 @@ async def close(self) -> None:
await self._transport.close()


class SnowflakeId:
"""Class for generating snowflake ids."""

EPOCH = 1420041600000 # Custom epoch (in milliseconds)
WORKER_ID_BITS = 5
DATA_CENTER_ID_BITS = 5
SEQUENCE_BITS = 12

MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1

SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1

def __init__(self, worker_id, data_center_id):
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
raise ValueError(
"Worker ID can't be greater than "
+ str(SnowflakeId.MAX_WORKER_ID)
+ " or less than 0"
)
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
raise ValueError(
"Data center ID can't be greater than "
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
+ " or less than 0"
)

self.worker_id = worker_id
self.data_center_id = data_center_id
self.sequence = 0
self.last_timestamp = -1

def generate_id(self):
"""Generate a snowflake id."""
timestamp = self._current_millis()

if timestamp < self.last_timestamp:
raise ValueError("Clock moved backwards. Refusing to generate ID.")

if timestamp == self.last_timestamp:
# Within the same millisecond, increment the sequence number
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
if self.sequence == 0:
# Sequence exceeds its bit range, wait until the next millisecond
timestamp = self._wait_next_millis(self.last_timestamp)
else:
# New millisecond, reset the sequence number
self.sequence = 0

# Update the last timestamp
self.last_timestamp = timestamp

# Generate and return the final ID
return (
(
(timestamp - SnowflakeId.EPOCH)
<< (
SnowflakeId.WORKER_ID_BITS
+ SnowflakeId.SEQUENCE_BITS
+ SnowflakeId.DATA_CENTER_ID_BITS
)
)
| (
self.data_center_id
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
)
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
| self.sequence
)

def _current_millis(self):
return round(time.monotonic() * 1000)

def _wait_next_millis(self, last_timestamp):
timestamp = self._current_millis()
while timestamp <= last_timestamp:
timestamp = self._current_millis()
return timestamp


class _ChildProtocolWrapper(SmartProtocol):
"""Protocol wrapper for controlling child devices.

Expand Down Expand Up @@ -456,6 +377,8 @@ def _get_method_and_params_for_request(self, request):
smart_method = "multipleRequest"
requests = [
{"method": method, "params": params}
if params
else {"method": method}
for method, params in request.items()
]
smart_params = {"requests": requests}
Expand Down
10 changes: 6 additions & 4 deletions kasa/tests/fakeprotocol_smart.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ def credentials_hash(self):
async def send(self, request: str):
request_dict = json_loads(request)
method = request_dict["method"]
params = request_dict["params"]

if method == "multipleRequest":
params = request_dict["params"]
responses = []
for request in params["requests"]:
response = self._send_request(request) # type: ignore[arg-type]
Expand Down Expand Up @@ -308,12 +309,13 @@ def _edit_preset_rules(self, info, params):

def _send_request(self, request_dict: dict):
method = request_dict["method"]
params = request_dict["params"]

info = self.info
if method == "control_child":
return self._handle_control_child(params)
elif method == "component_nego" or method[:4] == "get_":
return self._handle_control_child(request_dict["params"])

params = request_dict.get("params")
if method == "component_nego" or method[:4] == "get_":
if method in info:
result = copy.deepcopy(info[method])
if "start_index" in result and "sum" in result:
Expand Down