Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kasa/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ async def state(ctx, dev: SmartDevice):
"""Print out device state and versions."""
await dev.update()
click.echo(click.style(f"== {dev.alias} - {dev.model} ==", bold=True))
click.echo(f"\tHost: {dev.host}")
click.echo(f"\tHost: {dev.protocol.host}")
click.echo(
click.style(
"\tDevice state: {}\n".format("ON" if dev.is_on else "OFF"),
Expand Down
8 changes: 4 additions & 4 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
self.timeout = timeout
self.interface = interface
self.on_discovered = on_discovered
self.protocol = TPLinkSmartHomeProtocol()
self.protocol = TPLinkSmartHomeProtocol(target)
self.target = (target, Discover.DISCOVERY_PORT)
self.discovered_devices = {}
self.discovered_devices_raw = {}
Expand Down Expand Up @@ -201,13 +201,13 @@ async def discover(
async def discover_single(host: str) -> SmartDevice:
"""Discover a single device by the given IP address.

:param host: Hostname of device to query
:param host: fname of device to query
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
protocol = TPLinkSmartHomeProtocol()
protocol = TPLinkSmartHomeProtocol(host)

info = await protocol.query(host, Discover.DISCOVERY_QUERY)
info = await protocol.query(Discover.DISCOVERY_QUERY)

device_class = Discover._get_device_class(info)
if device_class is not None:
Expand Down
23 changes: 14 additions & 9 deletions kasa/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,31 @@ class TPLinkSmartHomeProtocol:
DEFAULT_PORT = 9999
DEFAULT_TIMEOUT = 5

@staticmethod
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Request information from a TP-Link SmartHome Device.
def __init__(self, host: str, retry_count: int = 3):
"""Initialize a new instance of protocol.

:param str host: host name or ip address of the device
:param int retry_count: how many retries to do in case of failure
"""
self.host = host
self.port = TPLinkSmartHomeProtocol.DEFAULT_PORT
self.retry_count = retry_count

async def query(self, request: Union[str, Dict]) -> Dict:
"""Request information from a TP-Link SmartHome Device.

:param request: command to send to the device (can be either dict or
json string)
:param retry_count: how many retries to do in case of failure
:return: response dict
"""
if isinstance(request, dict):
request = json.dumps(request)

timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
writer = None
for retry in range(retry_count + 1):
for retry in range(self.retry_count + 1):
try:
task = asyncio.open_connection(
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
task = asyncio.open_connection(self.host, self.port)
reader, writer = await asyncio.wait_for(task, timeout=timeout)
_LOGGER.debug("> (%i) %s", len(request), request)
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
Expand All @@ -73,7 +78,7 @@ async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> D
return json_payload

except Exception as ex:
if retry >= retry_count:
if retry >= self.retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
"Unable to query the device: %s" % ex
Expand Down
14 changes: 6 additions & 8 deletions kasa/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,9 @@ def __init__(self, host: str) -> None:

:param str host: host name or ip address on which the device listens
"""
self.host = host

self.protocol = TPLinkSmartHomeProtocol()
self.protocol = TPLinkSmartHomeProtocol(host)
self.emeter_type = "emeter"
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
_LOGGER.debug("Initializing %s of type %s", self.protocol.host, type(self))
self._device_type = DeviceType.Unknown
# TODO: typing Any is just as using Optional[Dict] would require separate checks in
# accessors. the @updated_required decorator does not ensure mypy that these
Expand Down Expand Up @@ -255,7 +253,7 @@ async def _query_helper(
request = self._create_request(target, cmd, arg, child_ids)

try:
response = await self.protocol.query(host=self.host, request=request)
response = await self.protocol.query(request=request)
except Exception as ex:
raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex

Expand Down Expand Up @@ -300,7 +298,7 @@ async def update(self):
# Check for emeter if we were never updated, or if the device has emeter
if self._last_update is None or self.has_emeter:
req.update(self._create_emeter_request())
self._last_update = await self.protocol.query(self.host, req)
self._last_update = await self.protocol.query(req)
# TODO: keep accessible for tests
self._sys_info = self._last_update["system"]["get_sysinfo"]

Expand Down Expand Up @@ -741,5 +739,5 @@ def is_color(self) -> bool:

def __repr__(self):
if self._last_update is None:
return f"<{self._device_type} at {self.host} - update() needed>"
return f"<{self._device_type} model {self.model} at {self.host} ({self.alias}), is_on: {self.is_on} - dev specific: {self.state_information}>"
return f"<{self._device_type} at {self.protocol.host} - update() needed>"
return f"<{self._device_type} model {self.model} at {self.protocol.host} ({self.alias}), is_on: {self.is_on} - dev specific: {self.state_information}>"
4 changes: 3 additions & 1 deletion kasa/smartstrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ async def update(self):
_LOGGER.debug("Initializing %s child sockets", len(children))
for child in children:
self.children.append(
SmartStripPlug(self.host, parent=self, child_id=child["id"])
SmartStripPlug(
self.protocol.host, parent=self, child_id=child["id"]
)
)

async def turn_on(self, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion kasa/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def get_device_for_file(file):
sysinfo = json.load(f)
model = basename(file)
p = device_for_file(model)(host="123.123.123.123")
p.protocol = FakeTransportProtocol(sysinfo)
p.protocol = FakeTransportProtocol("123.123.123.123", sysinfo)
asyncio.run(p.update())
return p

Expand Down
5 changes: 3 additions & 2 deletions kasa/tests/newfakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,9 @@ def success(res):


class FakeTransportProtocol(TPLinkSmartHomeProtocol):
def __init__(self, info):
def __init__(self, host, info):
self.discovery_data = info
self.host = host
proto = FakeTransportProtocol.baseproto

for target in info:
Expand Down Expand Up @@ -415,7 +416,7 @@ def light_state(self, x, *args):
},
}

async def query(self, host, request, port=9999):
async def query(self, request):
proto = self.proto

# collect child ids from context
Expand Down
3 changes: 2 additions & 1 deletion kasa/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def aio_mock_writer(_, __):

conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
with pytest.raises(SmartDeviceException):
await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count)
protocol = TPLinkSmartHomeProtocol("127.0.0.1", retry_count=retry_count)
await protocol.query({})

assert conn.call_count == retry_count + 1

Expand Down