Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion kasa/transports/sslaestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR
return error_code

def _get_response_inner_error(self, resp_dict: Any) -> SmartErrorCode | None:
error_code_raw = resp_dict.get("data", {}).get("code")
if error_code_raw is None:
return None
try:
error_code = SmartErrorCode.from_int(error_code_raw)
except ValueError:
_LOGGER.warning(

Check warning on line 170 in kasa/transports/sslaestransport.py

View check run for this annotation

Codecov / codecov/patch

kasa/transports/sslaestransport.py#L169-L170

Added lines #L169 - L170 were not covered by tests
"Device %s received unknown error code: %s", self._host, error_code_raw
)
error_code = SmartErrorCode.INTERNAL_UNKNOWN_ERROR

Check warning on line 173 in kasa/transports/sslaestransport.py

View check run for this annotation

Codecov / codecov/patch

kasa/transports/sslaestransport.py#L173

Added line #L173 was not covered by tests
return error_code

def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
error_code = self._get_response_error(resp_dict)
if error_code is SmartErrorCode.SUCCESS:
Expand Down Expand Up @@ -383,13 +396,29 @@
error_code = default_error_code
resp_dict = default_resp_dict

# If the default login worked it's ok not to provide credentials but if
# it didn't raise auth error here.
if not self._username:
raise AuthenticationError(
f"Credentials must be supplied to connect to {self._host}"
)

# Device responds with INVALID_NONCE and a "nonce" to indicate ready
# for secure login. Otherwise error.
if error_code is not SmartErrorCode.INVALID_NONCE or (
resp_dict and "nonce" not in resp_dict["result"].get("data", {})
resp_dict and "nonce" not in resp_dict.get("result", {}).get("data", {})
):
if (
resp_dict
and self._get_response_inner_error(resp_dict)
is SmartErrorCode.DEVICE_BLOCKED
):
sec_left = resp_dict.get("data", {}).get("sec_left")
msg = "Device blocked" + (
f" for {sec_left} seconds" if sec_left else ""
)
raise DeviceError(msg, error_code=SmartErrorCode.DEVICE_BLOCKED)

raise AuthenticationError(f"Error trying handshake1: {resp_dict}")

if TYPE_CHECKING:
Expand Down
27 changes: 27 additions & 0 deletions tests/transports/test_sslaestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from kasa.deviceconfig import DeviceConfig
from kasa.exceptions import (
AuthenticationError,
DeviceError,
KasaException,
SmartErrorCode,
)
Expand Down Expand Up @@ -200,6 +201,22 @@ async def test_unencrypted_response(mocker, caplog):
)


async def test_device_blocked_response(mocker):
host = "127.0.0.1"
mock_ssl_aes_device = MockSslAesDevice(host, device_blocked=True)
mocker.patch.object(
aiohttp.ClientSession, "post", side_effect=mock_ssl_aes_device.post
)

transport = SslAesTransport(
config=DeviceConfig(host, credentials=Credentials(MOCK_USER, MOCK_PWD))
)
msg = "Device blocked for 1685 seconds"

with pytest.raises(DeviceError, match=msg):
await transport.perform_handshake()


async def test_port_override():
"""Test that port override sets the app_url."""
host = "127.0.0.1"
Expand Down Expand Up @@ -235,6 +252,11 @@ class MockSslAesDevice:
},
}

DEVICE_BLOCKED_RESP = {
"data": {"code": SmartErrorCode.DEVICE_BLOCKED.value, "sec_left": 1685},
"error_code": SmartErrorCode.SESSION_EXPIRED.value,
}

class _mock_response:
def __init__(self, status, request: dict):
self.status = status
Expand Down Expand Up @@ -263,6 +285,7 @@ def __init__(
send_error_code=0,
secure_passthrough_error_code=0,
digest_password_fail=False,
device_blocked=False,
):
self.host = host
self.http_client = HttpClient(DeviceConfig(self.host))
Expand All @@ -277,6 +300,7 @@ def __init__(
self.do_not_encrypt_response = do_not_encrypt_response
self.want_default_username = want_default_username
self.digest_password_fail = digest_password_fail
self.device_blocked = device_blocked

async def post(self, url: URL, params=None, json=None, data=None, *_, **__):
if data:
Expand All @@ -303,6 +327,9 @@ async def _return_handshake1_response(self, url: URL, request: dict[str, Any]):
request_nonce = request["params"].get("cnonce")
request_username = request["params"].get("username")

if self.device_blocked:
return self._mock_response(self.status_code, self.DEVICE_BLOCKED_RESP)

if (self.want_default_username and request_username != MOCK_ADMIN_USER) or (
not self.want_default_username and request_username != MOCK_USER
):
Expand Down
Loading