Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions zeroconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
import time
import warnings
from collections import OrderedDict
from typing import Dict, Iterable, List, Optional, Union, cast
from types import TracebackType # noqa # used in type hints
from typing import Dict, Iterable, List, Optional, Type, Union, cast
from typing import Any, Callable, Set, Tuple # noqa # used in type hints

import ifaddr
Expand Down Expand Up @@ -3061,8 +3062,19 @@ def close(self) -> None:
for s in self._respond_sockets:
self.engine.del_reader(s)
self.engine.join()

# shutdown the rest
self.notify_all()
for s in self._respond_sockets:
s.close()

def __enter__(self) -> 'Zeroconf':
return self

def __exit__( # pylint: disable=useless-return
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
self.close()
return None
15 changes: 14 additions & 1 deletion zeroconf/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
import contextlib
import queue
import threading
from typing import Awaitable, Callable, Dict, List, Optional, Union
from types import TracebackType # noqa # used in type hints
from typing import Awaitable, Callable, Dict, List, Optional, Type, Union

from . import (
DNSOutgoing,
Expand Down Expand Up @@ -352,3 +353,15 @@ async def async_remove_all_service_listeners(self) -> None:
await asyncio.gather(
*[self.async_remove_service_listener(listener) for listener in list(self.async_browsers)]
)

async def __aenter__(self) -> 'AsyncZeroconf':
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
await self.async_close()
return None
9 changes: 9 additions & 0 deletions zeroconf/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,15 @@ def test_launch_and_close(self):
rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default)
rv.close()

def test_launch_and_close_context_manager(self):
with r.Zeroconf(interfaces=r.InterfaceChoice.All) as rv:
assert rv.done is False
assert rv.done is True

with r.Zeroconf(interfaces=r.InterfaceChoice.Default) as rv:
assert rv.done is False
assert rv.done is True

def test_launch_and_close_unicast(self):
rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, unicast=True)
rv.close()
Expand Down
24 changes: 24 additions & 0 deletions zeroconf/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,27 @@ def update_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None:
await aiozc.async_close()

assert calls[0] == ('add', type_, registration_name)


@pytest.mark.asyncio
async def test_async_context_manager() -> None:
"""Test using an async context manager."""
type_ = "_test10-sr-type._tcp.local."
name = "xxxyyy"
registration_name = "%s.%s" % (name, type_)

async with AsyncZeroconf(interfaces=['127.0.0.1']) as aiozc:
info = ServiceInfo(
type_,
registration_name,
80,
0,
0,
{'path': '/~paulsm/'},
"ash-2.local.",
addresses=[socket.inet_aton("10.0.1.2")],
)
task = await aiozc.async_register_service(info)
await task
aiosinfo = await aiozc.async_get_service_info(type_, registration_name)
assert aiosinfo is not None