Skip to content

Commit 5f6138b

Browse files
authored
Merge pull request python-telegram-bot#409 from python-telegram-bot/entities-filter
Add entities filter
2 parents e1242b3 + 1b99caa commit 5f6138b

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

telegram/ext/messagehandler.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,23 @@ def status_update(message):
8585
def forwarded(message):
8686
return bool(message.forward_date)
8787

88+
@staticmethod
89+
def entity(entity_type):
90+
"""Filters messages to only allow those which have a :class:`telegram.MessageEntity`
91+
where their `type` matches `entity_type`.
92+
93+
Args:
94+
entity_type: Entity type to check for. All types can be found as constants
95+
in :class:`telegram.MessageEntity`.
96+
97+
Returns: function to use as filter
98+
"""
99+
100+
def entities_filter(message):
101+
return any([entity.type == entity_type for entity in message.entities])
102+
103+
return entities_filter
104+
88105

89106
class MessageHandler(Handler):
90107
"""
@@ -141,7 +158,8 @@ def handle_update(self, update, dispatcher):
141158

142159
return self.callback(dispatcher.bot, update, **optional_args)
143160

144-
# old non-PEP8 Handler methods
161+
# old non-PEP8 Handler methods
162+
145163
m = "telegram.MessageHandler."
146164
checkUpdate = deprecate(check_update, m + "checkUpdate", m + "check_update")
147165
handleUpdate = deprecate(handle_update, m + "handleUpdate", m + "handle_update")

tests/test_filters.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
import sys
2424
import unittest
2525
from datetime import datetime
26+
import functools
2627

2728
sys.path.append('.')
2829

29-
from telegram import Message, User, Chat
30+
from telegram import Message, User, Chat, MessageEntity
3031
from telegram.ext import Filters
3132
from tests.base import BaseTest
3233

@@ -150,6 +151,21 @@ def test_filters_status_update(self):
150151
self.assertTrue(Filters.status_update(self.message))
151152
self.message.pinned_message = None
152153

154+
def test_entities_filter(self):
155+
e = functools.partial(MessageEntity, offset=0, length=0)
156+
157+
self.message.entities = [e(MessageEntity.MENTION)]
158+
self.assertTrue(Filters.entity(MessageEntity.MENTION)(self.message))
159+
160+
self.message.entities = []
161+
self.assertFalse(Filters.entity(MessageEntity.MENTION)(self.message))
162+
163+
self.message.entities = [e(MessageEntity.BOLD)]
164+
self.assertFalse(Filters.entity(MessageEntity.MENTION)(self.message))
165+
166+
self.message.entities = [e(MessageEntity.BOLD), e(MessageEntity.MENTION)]
167+
self.assertTrue(Filters.entity(MessageEntity.MENTION)(self.message))
168+
153169

154170
if __name__ == '__main__':
155171
unittest.main()

0 commit comments

Comments
 (0)