Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions telegram/ext/_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,9 +568,8 @@ async def start(self) -> None:
try:
if self.persistence:
self.__update_persistence_task = asyncio.create_task(
self._persistence_updater()
# TODO: Add this once we drop py3.7
# name=f'Application:{self.bot.id}:persistence_updater'
self._persistence_updater(),
name=f"Application:{self.bot.id}:persistence_updater",
)
_LOGGER.debug("Loop for updating persistence started")

Expand All @@ -579,9 +578,7 @@ async def start(self) -> None:
_LOGGER.debug("JobQueue started")

self.__update_fetcher_task = asyncio.create_task(
self._update_fetcher(),
# TODO: Add this once we drop py3.7
# name=f'Application:{self.bot.id}:update_fetcher'
self._update_fetcher(), name=f"Application:{self.bot.id}:update_fetcher"
)
_LOGGER.info("Application started")

Expand Down Expand Up @@ -955,6 +952,8 @@ def create_task(
self,
coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
update: Optional[object] = None,
*,
name: Optional[str] = None,
) -> "asyncio.Task[RT]":
"""Thin wrapper around :func:`asyncio.create_task` that handles exceptions raised by
the :paramref:`coroutine` with :meth:`process_error`.
Expand All @@ -977,16 +976,22 @@ def create_task(
:attr:`chat_data` and :attr:`user_data` entries will be updated in the next run of
:meth:`update_persistence` after the :paramref:`coroutine` is finished.

Keyword Args:
name (:obj:`str`, optional): The name of the task.

.. versionadded:: NEXT.VERSION

Returns:
:class:`asyncio.Task`: The created task.
"""
return self.__create_task(coroutine=coroutine, update=update)
return self.__create_task(coroutine=coroutine, update=update, name=name)

def __create_task(
self,
coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
update: Optional[object] = None,
is_error_handler: bool = False,
name: Optional[str] = None,
) -> "asyncio.Task[RT]":
# Unfortunately, we can't know if `coroutine` runs one of the error handler functions
# but by passing `is_error_handler=True` from `process_error`, we can make sure that we
Expand All @@ -995,7 +1000,8 @@ def __create_task(
task: "asyncio.Task[RT]" = asyncio.create_task(
self.__create_task_callback(
coroutine=coroutine, update=update, is_error_handler=is_error_handler
)
),
name=name,
)

if self.running:
Expand Down Expand Up @@ -1076,6 +1082,7 @@ async def _update_fetcher(self) -> None:
self.create_task(
self.__process_update_wrapper(update),
update=update,
name=f"Application:{self.bot.id}:process_concurrent_update",
)
else:
await self.__process_update_wrapper(update)
Expand Down Expand Up @@ -1132,7 +1139,12 @@ async def process_update(self, update: object) -> None:
and self.bot.defaults
and not self.bot.defaults.block
):
self.create_task(coroutine, update=update)
self.create_task(
coroutine,
update=update,
name=f"Application:{self.bot.id}:process_update_non_blocking"
f":{handler}",
)
else:
any_blocking = True
await coroutine
Expand Down Expand Up @@ -1203,7 +1215,10 @@ def add_handler(self, handler: BaseHandler[Any, CCT], group: int = DEFAULT_GROUP
f"can not be persistent if application has no persistence"
)
if self._initialized:
self.create_task(self._add_ch_to_persistence(handler))
self.create_task(
self._add_ch_to_persistence(handler),
name=f"Application:{self.bot.id}:add_handler:conversation_handler_after_init",
)
warn(
"A persistent `ConversationHandler` was passed to `add_handler`, "
"after `Application.initialize` was called. This is discouraged."
Expand Down Expand Up @@ -1683,7 +1698,10 @@ async def process_error(
and not self.bot.defaults.block
):
self.__create_task(
callback(update, context), update=update, is_error_handler=True
callback(update, context),
update=update,
is_error_handler=True,
name=f"Application:{self.bot.id}:process_error:non_blocking",
)
else:
try:
Expand Down
2 changes: 2 additions & 0 deletions telegram/ext/_conversationhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ async def handle_update( # type: ignore[override]
update, application, handler_check_result, context
),
update=update,
name=f"ConversationHandler:{update.update_id}:handle_update:non_blocking_cb",
)
except ApplicationHandlerStop as exception:
new_state = exception.state
Expand All @@ -856,6 +857,7 @@ async def handle_update( # type: ignore[override]
new_state, application, update, context, conversation_key
),
update=update,
name=f"ConversationHandler:{update.update_id}:handle_update:timeout_job",
)
else:
self._schedule_job(new_state, application, update, context, conversation_key)
Expand Down
5 changes: 4 additions & 1 deletion telegram/ext/_jobqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,10 @@ async def _run(
await context.refresh_data()
await self.callback(context)
except Exception as exc:
await application.create_task(application.process_error(None, exc, job=self))
await application.create_task(
application.process_error(None, exc, job=self),
name=f"Job:{self.id}:run:process_error",
)
finally:
# This is internal logic of application - let's keep it private for now
application._mark_for_persistence_update(job=self) # pylint: disable=protected-access
Expand Down
3 changes: 2 additions & 1 deletion telegram/ext/_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ def default_error_callback(exc: TelegramError) -> None:
on_err_cb=error_callback or default_error_callback,
description="getting Updates",
interval=poll_interval,
)
),
name="Updater:start_polling:polling_task",
)

if ready is not None:
Expand Down
12 changes: 10 additions & 2 deletions tests/ext/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ async def callback(u, c):
async with app:
await app.start()
assert app.running
tasks = asyncio.all_tasks()
assert any(":update_fetcher" in task.get_name() for task in tasks)
if job_queue:
assert app.job_queue.scheduler.running
else:
Expand Down Expand Up @@ -551,7 +553,6 @@ async def test_add_remove_handler_non_default_group(self, app):
app.remove_handler(handler)
app.remove_handler(handler, group=2)

#
async def test_handler_order_in_group(self, app):
app.add_handler(MessageHandler(filters.PHOTO, self.callback_set_count(1)))
app.add_handler(MessageHandler(filters.ALL, self.callback_set_count(2)))
Expand Down Expand Up @@ -964,6 +965,8 @@ async def callback(update, context):
await app.update_queue.put(1)
task = asyncio.create_task(app.stop())
await asyncio.sleep(0.05)
tasks = asyncio.all_tasks()
assert any(":process_update_non_blocking" in t.get_name() for t in tasks)
assert self.count == 1
# Make sure that app stops only once all non blocking callbacks are done
assert not task.done()
Expand Down Expand Up @@ -1029,6 +1032,8 @@ async def normal_error_handler(update, context):
await app.update_queue.put(self.message_update)
task = asyncio.create_task(app.stop())
await asyncio.sleep(0.05)
tasks = asyncio.all_tasks()
assert any(":process_error:non_blocking" in t.get_name() for t in tasks)
assert self.count == 42
assert self.received is None
event.set()
Expand Down Expand Up @@ -1196,7 +1201,8 @@ async def callback():
self.count = 42
return 43

task = app.create_task(callback())
task = app.create_task(callback(), name="test_task")
assert task.get_name() == "test_task"
await asyncio.sleep(0.01)
assert not task.done()
out = await task
Expand Down Expand Up @@ -1377,6 +1383,8 @@ async def callback(u, c):
assert not events[i].is_set()

await asyncio.sleep(0.9)
tasks = asyncio.all_tasks()
assert any(":process_concurrent_update" in task.get_name() for task in tasks)
for i in range(app.update_processor.max_concurrent_updates):
assert events[i].is_set()
for i in range(
Expand Down
24 changes: 24 additions & 0 deletions tests/ext/test_basepersistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,8 @@ async def test_add_conversation_handler_after_init(self, papp: Application, recw
papp.add_handler(conversation)

assert len(recwarn) >= 1
tasks = asyncio.all_tasks()
assert any("conversation_handler_after_init" in t.get_name() for t in tasks)
found = False
for warning in recwarn:
if "after `Application.initialize` was called" in str(warning.message):
Expand Down Expand Up @@ -584,6 +586,28 @@ async def test_add_conversation_handler_without_name(self, papp: Application):
with pytest.raises(ValueError, match="when handler is unnamed"):
papp.add_handler(build_conversation_handler(name=None, persistent=True))

@pytest.mark.parametrize(
"papp",
[
PappInput(update_interval=0.0),
],
indirect=True,
)
async def test_update_persistence_called(self, papp: Application, monkeypatch):
"""Tests if Application.update_persistence is called from app.start()"""
called = asyncio.Event()

async def update_persistence(*args, **kwargs):
called.set()

monkeypatch.setattr(papp, "update_persistence", update_persistence)
async with papp:
await papp.start()
tasks = asyncio.all_tasks()
assert any(":persistence_updater" in task.get_name() for task in tasks)
assert await called.wait()
await papp.stop()

@pytest.mark.flaky(3, 1)
@pytest.mark.parametrize(
"papp",
Expand Down
3 changes: 3 additions & 0 deletions tests/ext/test_conversationhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2068,6 +2068,9 @@ async def callback(_, __):
assert conv_handler.check_update(Update(0, message=message))
await app.process_update(Update(0, message=message))
await asyncio.sleep(0.7)
tasks = asyncio.all_tasks()
assert any(":handle_update:non_blocking_cb" in t.get_name() for t in tasks)
assert any(":handle_update:timeout_job" in t.get_name() for t in tasks)
assert not self.is_timeout
event.set()
await asyncio.sleep(0.7)
Expand Down
2 changes: 2 additions & 0 deletions tests/ext/test_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ async def delete_webhook(*args, **kwargs):
# We call the same logic twice to make sure that restarting the updater works as well
await updater.start_polling(drop_pending_updates=drop_pending_updates)
assert updater.running
tasks = asyncio.all_tasks()
assert any("Updater:start_polling:polling_task" in t.get_name() for t in tasks)
await updates.join()
await updater.stop()
assert not updater.running
Expand Down