Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 50 additions & 12 deletions telegram/ext/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,14 @@ class BaseFilter(object):
a `filter` method that returns a boolean: `True` if the message should be handled, `False`
otherwise. Note that the filters work only as class instances, not actual class objects
(so remember to initialize your filter classes).

By default the filters name (what will get printed when converted to a string for display)
will be the class name. If you want to overwrite this assign a better name to the `name`
class variable.
"""

name = None

def __call__(self, message):
return self.filter(message)

Expand All @@ -65,6 +71,12 @@ def __or__(self, other):
def __invert__(self):
return InvertedFilter(self)

def __repr__(self):
# We do this here instead of in a __init__ so filter don't have to call __init__ or super()
if self.name is None:
self.name = self.__class__.__name__
return self.name

def filter(self, message):
raise NotImplementedError

Expand All @@ -82,10 +94,8 @@ def __init__(self, f):
def filter(self, message):
return not self.f(message)

def __str__(self):
return "<telegram.ext.filters.InvertedFilter inverting {}>".format(self.f)

__repr__ = __str__
def __repr__(self):
return "<inverted {}>".format(self.f)


class MergedFilter(BaseFilter):
Expand All @@ -108,12 +118,9 @@ def filter(self, message):
elif self.or_filter:
return self.base_filter(message) or self.or_filter(message)

def __str__(self):
return ("<telegram.ext.filters.MergedFilter consisting of"
" {} {} {}>").format(self.base_filter, "and" if self.and_filter else "or",
self.and_filter or self.or_filter)

__repr__ = __str__
def __repr__(self):
return "<{} {} {}>".format(self.base_filter, "and" if self.and_filter else "or",
self.and_filter or self.or_filter)


class Filters(object):
Expand All @@ -122,90 +129,103 @@ class Filters(object):
"""

class _All(BaseFilter):
name = 'Filters.all'

def filter(self, message):
return True

all = _All()

class _Text(BaseFilter):
name = 'Filters.text'

def filter(self, message):
return bool(message.text and not message.text.startswith('/'))

text = _Text()

class _Command(BaseFilter):
name = 'Filters.command'

def filter(self, message):
return bool(message.text and message.text.startswith('/'))

command = _Command()

class _Reply(BaseFilter):
name = 'Filters.reply'

def filter(self, message):
return bool(message.reply_to_message)

reply = _Reply()

command = _Command()

class _Audio(BaseFilter):
name = 'Filters.audio'

def filter(self, message):
return bool(message.audio)

audio = _Audio()

class _Document(BaseFilter):
name = 'Filters.document'

def filter(self, message):
return bool(message.document)

document = _Document()

class _Photo(BaseFilter):
name = 'Filters.photo'

def filter(self, message):
return bool(message.photo)

photo = _Photo()

class _Sticker(BaseFilter):
name = 'Filters.sticker'

def filter(self, message):
return bool(message.sticker)

sticker = _Sticker()

class _Video(BaseFilter):
name = 'Filters.video'

def filter(self, message):
return bool(message.video)

video = _Video()

class _Voice(BaseFilter):
name = 'Filters.voice'

def filter(self, message):
return bool(message.voice)

voice = _Voice()

class _Contact(BaseFilter):
name = 'Filters.contact'

def filter(self, message):
return bool(message.contact)

contact = _Contact()

class _Location(BaseFilter):
name = 'Filters.location'

def filter(self, message):
return bool(message.location)

location = _Location()

class _Venue(BaseFilter):
name = 'Filters.venue'

def filter(self, message):
return bool(message.venue)
Expand All @@ -215,41 +235,47 @@ def filter(self, message):
class _StatusUpdate(BaseFilter):

class _NewChatMembers(BaseFilter):
name = 'Filters.status_update.new_chat_members'

def filter(self, message):
return bool(message.new_chat_members)

new_chat_members = _NewChatMembers()

class _LeftChatMember(BaseFilter):
name = 'Filters.status_update.left_chat_member'

def filter(self, message):
return bool(message.left_chat_member)

left_chat_member = _LeftChatMember()

class _NewChatTitle(BaseFilter):
name = 'Filters.status_update.new_chat_title'

def filter(self, message):
return bool(message.new_chat_title)

new_chat_title = _NewChatTitle()

class _NewChatPhoto(BaseFilter):
name = 'Filters.status_update.new_chat_photo'

def filter(self, message):
return bool(message.new_chat_photo)

new_chat_photo = _NewChatPhoto()

class _DeleteChatPhoto(BaseFilter):
name = 'Filters.status_update.delete_chat_photo'

def filter(self, message):
return bool(message.delete_chat_photo)

delete_chat_photo = _DeleteChatPhoto()

class _ChatCreated(BaseFilter):
name = 'Filters.status_update.chat_created'

def filter(self, message):
return bool(message.group_chat_created or message.supergroup_chat_created or
Expand All @@ -258,19 +284,23 @@ def filter(self, message):
chat_created = _ChatCreated()

class _Migrate(BaseFilter):
name = 'Filters.status_update.migrate'

def filter(self, message):
return bool(message.migrate_from_chat_id or message.migrate_to_chat_id)

migrate = _Migrate()

class _PinnedMessage(BaseFilter):
name = 'Filters.status_update.pinned_message'

def filter(self, message):
return bool(message.pinned_message)

pinned_message = _PinnedMessage()

name = 'Filters.status_update'

def filter(self, message):
return bool(self.new_chat_members(message) or self.left_chat_member(message) or
self.new_chat_title(message) or self.new_chat_photo(message) or
Expand All @@ -280,13 +310,15 @@ def filter(self, message):
status_update = _StatusUpdate()

class _Forwarded(BaseFilter):
name = 'Filters.forwarded'

def filter(self, message):
return bool(message.forward_date)

forwarded = _Forwarded()

class _Game(BaseFilter):
name = 'Filters.game'

def filter(self, message):
return bool(message.game)
Expand All @@ -306,32 +338,37 @@ class entity(BaseFilter):

def __init__(self, entity_type):
self.entity_type = entity_type
self.name = 'Filters.entity({})'.format(self.entity_type)

def filter(self, message):
return any([entity.type == self.entity_type for entity in message.entities])

class _Private(BaseFilter):
name = 'Filters.private'

def filter(self, message):
return message.chat.type == Chat.PRIVATE

private = _Private()

class _Group(BaseFilter):
name = 'Filters.group'

def filter(self, message):
return message.chat.type in [Chat.GROUP, Chat.SUPERGROUP]

group = _Group()

class _Invoice(BaseFilter):
name = 'Filters.invoice'

def filter(self, message):
return bool(message.invoice)

invoice = _Invoice()

class _SuccessfulPayment(BaseFilter):
name = 'Filters.successful_payment'

def filter(self, message):
return bool(message.successful_payment)
Expand All @@ -354,6 +391,7 @@ def __init__(self, lang):
self.lang = [lang]
else:
self.lang = lang
self.name = 'Filters.language({})'.format(self.lang)

def filter(self, message):
return message.from_user.language_code and any(
Expand Down
16 changes: 11 additions & 5 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,10 @@ def test_and_or_filters(self):
self.assertTrue((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION))
)(self.message))

self.assertRegexpMatches(
self.assertEqual(
str((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION)))),
r"<telegram.ext.filters.MergedFilter consisting of <telegram.ext.filters.(Filters.)?_"
r"Text object at .*?> and <telegram.ext.filters.MergedFilter consisting of "
r"<telegram.ext.filters.(Filters.)?_Forwarded object at .*?> or "
r"<telegram.ext.filters.(Filters.)?entity object at .*?>>>")
'<Filters.text and <Filters.forwarded or Filters.entity(mention)>>'
)

def test_inverted_filters(self):
self.message.text = '/test'
Expand Down Expand Up @@ -323,6 +321,14 @@ def test_language_filter_multiple(self):
self.message.from_user.language_code = 'da'
self.assertTrue(f(self.message))

def test_custom_unnamed_filter(self):
class Unnamed(BaseFilter):
def filter(self, message):
return True

unnamed = Unnamed()
self.assertEqual(str(unnamed), Unnamed.__name__)


if __name__ == '__main__':
unittest.main()