Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ The following wonderful people contributed directly or indirectly to this projec
- `d-qoi <https://github.com/d-qoi>`_
- `daimajia <https://github.com/daimajia>`_
- `Daniel Reed <https://github.com/nmlorg>`_
- `Dmitry Grigoryev <https://github.com/icecom-dg>`_
- `Ehsan Online <https://github.com/ehsanonline>`_
- `Eli Gao <https://github.com/eligao>`_
- `Emilio Molinari <https://github.com/xates>`_
Expand Down
113 changes: 78 additions & 35 deletions telegram/ext/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from future.utils import string_types

from telegram import Chat
from telegram import Chat, Update

__all__ = ['Filters', 'BaseFilter', 'InvertedFilter', 'MergedFilter']

Expand Down Expand Up @@ -236,11 +236,87 @@ def filter(self, message):
class _Text(BaseFilter):
name = 'Filters.text'

class _TextIterable(BaseFilter):

def __init__(self, iterable):
self.iterable = iterable
self.name = 'Filters.text({})'.format(iterable)

def filter(self, message):
if message.text and not message.text.startswith('/'):
return message.text in self.iterable
return False

def __call__(self, update):
if isinstance(update, Update):
if self.update_filter:
return self.filter(update)
else:
return self.filter(update.effective_message)
else:
return self._TextIterable(update)

def filter(self, message):
return bool(message.text and not message.text.startswith('/'))

text = _Text()
"""Text Messages."""
"""Text Messages. If an iterable of strings is passed, it filters messages to only allow those
whose text is appearing in the given iterable.

Examples:
To allow any text message, simply use
``MessageHandler(Filters.text, callback_method)``.

A simple usecase for passing an iterable is to allow only messages that were send by a
custom :class:`telegram.ReplyKeyboardMarkup`::

buttons = ['Start', 'Settings', 'Back']
markup = ReplyKeyboardMarkup.from_column(buttons)
...
MessageHandler(Filters.text(buttons), callback_method)

Args:
update (Iterable[:obj:`str`], optional): Which messages to allow. Only exact matches
are allowed. If not specified, will allow any text message.
"""

class _Caption(BaseFilter):
name = 'Filters.caption'

class _CaptionIterable(BaseFilter):

def __init__(self, iterable):
self.iterable = iterable
self.name = 'Filters.caption({})'.format(iterable)

def filter(self, message):
if message.caption:
return message.caption in self.iterable
return False

def __call__(self, update):
if isinstance(update, Update):
if self.update_filter:
return self.filter(update)
else:
return self.filter(update.effective_message)
else:
return self._CaptionIterable(update)

def filter(self, message):
return bool(message.caption)

caption = _Caption()
"""Messages with a caption. If an iterable of strings is passed, it filters messages to only
allow those whose caption is appearing in the given iterable.

Examples:
``MessageHandler(Filters.caption, callback_method)``

Args:
update (Iterable[:obj:`str`], optional): Which captions to allow. Only exact matches
are allowed. If not specified, will allow any message with a caption.
"""

class _Command(BaseFilter):
name = 'Filters.command'
Expand Down Expand Up @@ -909,39 +985,6 @@ def filter(self, message):
return message.from_user.language_code and any(
[message.from_user.language_code.startswith(x) for x in self.lang])

class msg_in(BaseFilter):
"""Filters messages to only allow those whose text/caption appears in a given list.

Examples:
A simple usecase is to allow only messages that were send by a custom
:class:`telegram.ReplyKeyboardMarkup`::

buttons = ['Start', 'Settings', 'Back']
markup = ReplyKeyboardMarkup.from_column(buttons)
...
MessageHandler(Filters.msg_in(buttons), callback_method)

Args:
list_ (List[:obj:`str`]): Which messages to allow through. Only exact matches
are allowed.
caption (:obj:`bool`): Optional. Whether the caption should be used instead of text.
Default is ``False``.

"""

def __init__(self, list_, caption=False):
self.list_ = list_
self.caption = caption
self.name = 'Filters.msg_in({!r}, caption={!r})'.format(self.list_, self.caption)

def filter(self, message):
if self.caption:
txt = message.caption
else:
txt = message.text

return txt in self.list_

class _UpdateType(BaseFilter):
update_filter = True

Expand Down
30 changes: 18 additions & 12 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,25 @@ def test_filters_all(self, update):

def test_filters_text(self, update):
update.message.text = 'test'
assert Filters.text(update)
assert (Filters.text)(update)
update.message.text = '/test'
assert not Filters.text(update)
assert not (Filters.text)(update)

def test_filters_text_iterable(self, update):
update.message.text = 'test'
assert Filters.text({'test', 'test1'})(update)
assert not Filters.text(['test1', 'test2'])(update)

def test_filters_caption(self, update):
update.message.caption = 'test'
assert (Filters.caption)(update)
update.message.caption = None
assert not (Filters.caption)(update)

def test_filters_caption_iterable(self, update):
update.message.caption = 'test'
assert Filters.caption({'test', 'test1'})(update)
assert not Filters.caption(['test1', 'test2'])(update)

def test_filters_command(self, update):
update.message.text = 'test'
Expand Down Expand Up @@ -604,16 +620,6 @@ def test_language_filter_multiple(self, update):
update.message.from_user.language_code = 'da'
assert f(update)

def test_msg_in_filter(self, update):
update.message.text = 'test'
update.message.caption = 'caption'

assert Filters.msg_in(['test'])(update)
assert Filters.msg_in(['caption'], caption=True)(update)

assert not Filters.msg_in(['test'], caption=True)(update)
assert not Filters.msg_in(['caption'])(update)

def test_and_filters(self, update):
update.message.text = 'test'
update.message.forward_date = datetime.datetime.utcnow()
Expand Down