Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion telegram/ext/messagehandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,23 @@ def status_update(message):
def forwarded(message):
return bool(message.forward_date)

@staticmethod
def entity(entity_type):
"""Filters messages to only allow those which have a :class:`telegram.MessageEntity`
where their `type` matches `entity_type`.

Args:
entity_type: Entity type to check for. All types can be found as constants
in :class:`telegram.MessageEntity`.

Returns: function to use as filter
"""

def entities_filter(message):
return any([entity.type == entity_type for entity in message.entities])

return entities_filter


class MessageHandler(Handler):
"""
Expand Down Expand Up @@ -141,7 +158,8 @@ def handle_update(self, update, dispatcher):

return self.callback(dispatcher.bot, update, **optional_args)

# old non-PEP8 Handler methods
# old non-PEP8 Handler methods

m = "telegram.MessageHandler."
checkUpdate = deprecate(check_update, m + "checkUpdate", m + "check_update")
handleUpdate = deprecate(handle_update, m + "handleUpdate", m + "handle_update")
18 changes: 17 additions & 1 deletion tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
import sys
import unittest
from datetime import datetime
import functools

sys.path.append('.')

from telegram import Message, User, Chat
from telegram import Message, User, Chat, MessageEntity
from telegram.ext import Filters
from tests.base import BaseTest

Expand Down Expand Up @@ -150,6 +151,21 @@ def test_filters_status_update(self):
self.assertTrue(Filters.status_update(self.message))
self.message.pinned_message = None

def test_entities_filter(self):
e = functools.partial(MessageEntity, offset=0, length=0)

self.message.entities = [e(MessageEntity.MENTION)]
self.assertTrue(Filters.entity(MessageEntity.MENTION)(self.message))

self.message.entities = []
self.assertFalse(Filters.entity(MessageEntity.MENTION)(self.message))

self.message.entities = [e(MessageEntity.BOLD)]
self.assertFalse(Filters.entity(MessageEntity.MENTION)(self.message))

self.message.entities = [e(MessageEntity.BOLD), e(MessageEntity.MENTION)]
self.assertTrue(Filters.entity(MessageEntity.MENTION)(self.message))


if __name__ == '__main__':
unittest.main()