Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions telegram/ext/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from future.utils import string_types

from telegram import Chat, Update
from telegram import Chat, Update, MessageEntity

__all__ = ['Filters', 'BaseFilter', 'InvertedFilter', 'MergedFilter']

Expand Down Expand Up @@ -249,10 +249,7 @@ def filter(self, message):

def __call__(self, update):
if isinstance(update, Update):
if self.update_filter:
return self.filter(update)
else:
return self.filter(update.effective_message)
return self.filter(update.effective_message)
else:
return self._TextIterable(update)

Expand Down Expand Up @@ -296,10 +293,7 @@ def filter(self, message):

def __call__(self, update):
if isinstance(update, Update):
if self.update_filter:
return self.filter(update)
else:
return self.filter(update.effective_message)
return self.filter(update.effective_message)
else:
return self._CaptionIterable(update)

Expand All @@ -321,11 +315,41 @@ def filter(self, message):
class _Command(BaseFilter):
name = 'Filters.command'

class _CommandOnlyStart(BaseFilter):

def __init__(self, only_start):
self.only_start = only_start
self.name = 'Filters.command({})'.format(only_start)

def filter(self, message):
return (message.entities
and any([e.type == MessageEntity.BOT_COMMAND for e in message.entities]))

def __call__(self, update):
if isinstance(update, Update):
return self.filter(update.effective_message)
else:
return self._CommandOnlyStart(update)

def filter(self, message):
return bool(message.text and message.text.startswith('/'))
return (message.entities and message.entities[0].type == MessageEntity.BOT_COMMAND
and message.entities[0].offset == 0)

command = _Command()
"""Messages starting with ``/``."""
"""
Messages with a :attr:`telegram.MessageEntity.BOT_COMMAND`. By default only allows
messages `starting` with a bot command. Pass ``False`` to also allow messages that contain a
bot command `anywhere` in the text.

Examples::

MessageHandler(Filters.command, command_at_start_callback)
MessageHandler(Filters.command(False), command_anywhere_callback)

Args:
update (:obj:`bool`, optional): Whether to only allow messages that `start` with a bot
command. Defaults to ``True``.
"""

class regex(BaseFilter):
"""
Expand Down
29 changes: 28 additions & 1 deletion tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,23 @@ def test_filters_caption_iterable(self, update):
assert Filters.caption({'test', 'test1'})(update)
assert not Filters.caption(['test1', 'test2'])(update)

def test_filters_command(self, update):
def test_filters_command_default(self, update):
update.message.text = 'test'
assert not Filters.command(update)
update.message.text = '/test'
assert not Filters.command(update)
# Only accept commands at the beginning
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 3, 5)]
assert not Filters.command(update)
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]
assert Filters.command(update)

def test_filters_command_anywhere(self, update):
update.message.text = 'test /cmd'
assert not (Filters.command(False))(update)
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 5, 4)]
assert (Filters.command(False))(update)

def test_filters_regex(self, update):
SRE_TYPE = type(re.match("", ""))
update.message.text = '/start deep-linked param'
Expand Down Expand Up @@ -120,6 +131,7 @@ def test_filters_regex_multiple(self, update):
def test_filters_merged_with_regex(self, update):
SRE_TYPE = type(re.match("", ""))
update.message.text = '/start deep-linked param'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)]
result = (Filters.command & Filters.regex(r'linked param'))(update)
assert result
assert isinstance(result, dict)
Expand Down Expand Up @@ -216,6 +228,7 @@ def test_regex_complex_merges(self, update):
result = filter(update)
assert not result
update.message.text = '/start'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)]
result = filter(update)
assert result
assert isinstance(result, bool)
Expand All @@ -230,6 +243,7 @@ def test_regex_complex_merges(self, update):

def test_regex_inverted(self, update):
update.message.text = '/start deep-linked param'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]
filter = ~Filters.regex(r'deep-linked param')
result = filter(update)
assert not result
Expand All @@ -243,6 +257,7 @@ def test_regex_inverted(self, update):
result = filter(update)
assert not result
update.message.text = '/start'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)]
result = filter(update)
assert result
update.message.text = '/linked'
Expand All @@ -251,15 +266,18 @@ def test_regex_inverted(self, update):

filter = (~Filters.regex('linked') | Filters.command)
update.message.text = "it's linked"
update.message.entities = []
result = filter(update)
assert not result
update.message.text = '/start linked'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)]
result = filter(update)
assert result
update.message.text = '/start'
result = filter(update)
assert result
update.message.text = 'nothig'
update.message.entities = []
result = filter(update)
assert result

Expand Down Expand Up @@ -664,14 +682,17 @@ def test_and_or_filters(self, update):

def test_inverted_filters(self, update):
update.message.text = '/test'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]
assert Filters.command(update)
assert not (~Filters.command)(update)
update.message.text = 'test'
update.message.entities = []
assert not Filters.command(update)
assert (~Filters.command)(update)

def test_inverted_and_filters(self, update):
update.message.text = '/test'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]
update.message.forward_date = 1
assert (Filters.forwarded & Filters.command)(update)
assert not (~Filters.forwarded & Filters.command)(update)
Expand All @@ -683,6 +704,7 @@ def test_inverted_and_filters(self, update):
assert not (Filters.forwarded & ~Filters.command)(update)
assert (~(Filters.forwarded & Filters.command))(update)
update.message.text = 'test'
update.message.entities = []
assert not (Filters.forwarded & Filters.command)(update)
assert not (~Filters.forwarded & Filters.command)(update)
assert not (Filters.forwarded & ~Filters.command)(update)
Expand Down Expand Up @@ -746,6 +768,7 @@ def test_update_type_edited_channel_post(self, update):

def test_merged_short_circuit_and(self, update):
update.message.text = '/test'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]

class TestException(Exception):
pass
Expand All @@ -760,6 +783,7 @@ def filter(self, _):
(Filters.command & raising_filter)(update)

update.message.text = 'test'
update.message.entities = []
(Filters.command & raising_filter)(update)

def test_merged_short_circuit_or(self, update):
Expand All @@ -778,10 +802,12 @@ def filter(self, _):
(Filters.command | raising_filter)(update)

update.message.text = '/test'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]
(Filters.command | raising_filter)(update)

def test_merged_data_merging_and(self, update):
update.message.text = '/test'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]

class DataFilter(BaseFilter):
data_filter = True
Expand All @@ -799,6 +825,7 @@ def filter(self, _):
assert result['test'] == ['blah1', 'blah2']

update.message.text = 'test'
update.message.entities = []
result = (Filters.command & DataFilter('blah'))(update)
assert not result

Expand Down