Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions kasa/experimental/sslaestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import base64
import hashlib
import logging
Expand All @@ -12,7 +13,6 @@
from functools import cache
from typing import TYPE_CHECKING, Any, Dict, cast

from urllib3.util import create_urllib3_context
from yarl import URL

from ..aestransport import AesEncyptionSession
Expand Down Expand Up @@ -108,11 +108,7 @@ def __init__(
self._host_port = f"{self._host}:{self._port}"
self._app_url = URL(f"https://{self._host_port}")
self._token_url: URL | None = None
self._ssl_context = create_urllib3_context(
ciphers=self.CIPHERS,
cert_reqs=ssl.CERT_NONE,
options=0,
)
self._ssl_context: ssl.SSLContext | None = None
ref = str(self._token_url) if self._token_url else str(self._app_url)
self._headers = {
**self.COMMON_HEADERS,
Expand Down Expand Up @@ -168,6 +164,21 @@ def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
raise AuthenticationError(msg, error_code=error_code)
raise DeviceError(msg, error_code=error_code)

def _create_ssl_context(self) -> ssl.SSLContext:
context = ssl.SSLContext()
context.set_ciphers(self.CIPHERS)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
return context

async def _get_ssl_context(self) -> ssl.SSLContext:
if not self._ssl_context:
loop = asyncio.get_running_loop()
self._ssl_context = await loop.run_in_executor(
None, self._create_ssl_context
)
return self._ssl_context

async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
"""Send encrypted message as passthrough."""
if self._state is TransportState.ESTABLISHED and self._token_url:
Expand All @@ -194,7 +205,7 @@ async def send_secure_passthrough(self, request: str) -> dict[str, Any]:
url,
json=passthrough_request_str,
headers=headers,
ssl=self._ssl_context,
ssl=await self._get_ssl_context(),
)

if status_code != 200:
Expand Down Expand Up @@ -299,7 +310,10 @@ async def perform_handshake2(self, local_nonce, server_nonce, pwd_hash) -> None:
}
http_client = self._http_client
status_code, resp_dict = await http_client.post(
self._app_url, json=body, headers=self._headers, ssl=self._ssl_context
self._app_url,
json=body,
headers=self._headers,
ssl=await self._get_ssl_context(),
)
if status_code != 200:
raise KasaException(
Expand Down Expand Up @@ -337,7 +351,10 @@ async def perform_handshake1(self) -> tuple[str, str, str]:
http_client = self._http_client

status_code, resp_dict = await http_client.post(
self._app_url, json=body, headers=self._headers, ssl=self._ssl_context
self._app_url,
json=body,
headers=self._headers,
ssl=await self._get_ssl_context(),
)

_LOGGER.debug("Device responded with: %s", resp_dict)
Expand Down