Skip to content

Commit 3ea16cb

Browse files
authored
Merge pull request python-telegram-bot#675 from python-telegram-bot/name-filters
Allow filters to have a name.
2 parents eee0f78 + 04acbc4 commit 3ea16cb

File tree

2 files changed

+61
-17
lines changed

2 files changed

+61
-17
lines changed

telegram/ext/filters.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,14 @@ class BaseFilter(object):
5151
a `filter` method that returns a boolean: `True` if the message should be handled, `False`
5252
otherwise. Note that the filters work only as class instances, not actual class objects
5353
(so remember to initialize your filter classes).
54+
55+
By default the filters name (what will get printed when converted to a string for display)
56+
will be the class name. If you want to overwrite this assign a better name to the `name`
57+
class variable.
5458
"""
5559

60+
name = None
61+
5662
def __call__(self, message):
5763
return self.filter(message)
5864

@@ -65,6 +71,12 @@ def __or__(self, other):
6571
def __invert__(self):
6672
return InvertedFilter(self)
6773

74+
def __repr__(self):
75+
# We do this here instead of in a __init__ so filter don't have to call __init__ or super()
76+
if self.name is None:
77+
self.name = self.__class__.__name__
78+
return self.name
79+
6880
def filter(self, message):
6981
raise NotImplementedError
7082

@@ -82,10 +94,8 @@ def __init__(self, f):
8294
def filter(self, message):
8395
return not self.f(message)
8496

85-
def __str__(self):
86-
return "<telegram.ext.filters.InvertedFilter inverting {}>".format(self.f)
87-
88-
__repr__ = __str__
97+
def __repr__(self):
98+
return "<inverted {}>".format(self.f)
8999

90100

91101
class MergedFilter(BaseFilter):
@@ -108,12 +118,9 @@ def filter(self, message):
108118
elif self.or_filter:
109119
return self.base_filter(message) or self.or_filter(message)
110120

111-
def __str__(self):
112-
return ("<telegram.ext.filters.MergedFilter consisting of"
113-
" {} {} {}>").format(self.base_filter, "and" if self.and_filter else "or",
114-
self.and_filter or self.or_filter)
115-
116-
__repr__ = __str__
121+
def __repr__(self):
122+
return "<{} {} {}>".format(self.base_filter, "and" if self.and_filter else "or",
123+
self.and_filter or self.or_filter)
117124

118125

119126
class Filters(object):
@@ -122,90 +129,103 @@ class Filters(object):
122129
"""
123130

124131
class _All(BaseFilter):
132+
name = 'Filters.all'
125133

126134
def filter(self, message):
127135
return True
128136

129137
all = _All()
130138

131139
class _Text(BaseFilter):
140+
name = 'Filters.text'
132141

133142
def filter(self, message):
134143
return bool(message.text and not message.text.startswith('/'))
135144

136145
text = _Text()
137146

138147
class _Command(BaseFilter):
148+
name = 'Filters.command'
139149

140150
def filter(self, message):
141151
return bool(message.text and message.text.startswith('/'))
142152

153+
command = _Command()
154+
143155
class _Reply(BaseFilter):
156+
name = 'Filters.reply'
144157

145158
def filter(self, message):
146159
return bool(message.reply_to_message)
147160

148161
reply = _Reply()
149162

150-
command = _Command()
151-
152163
class _Audio(BaseFilter):
164+
name = 'Filters.audio'
153165

154166
def filter(self, message):
155167
return bool(message.audio)
156168

157169
audio = _Audio()
158170

159171
class _Document(BaseFilter):
172+
name = 'Filters.document'
160173

161174
def filter(self, message):
162175
return bool(message.document)
163176

164177
document = _Document()
165178

166179
class _Photo(BaseFilter):
180+
name = 'Filters.photo'
167181

168182
def filter(self, message):
169183
return bool(message.photo)
170184

171185
photo = _Photo()
172186

173187
class _Sticker(BaseFilter):
188+
name = 'Filters.sticker'
174189

175190
def filter(self, message):
176191
return bool(message.sticker)
177192

178193
sticker = _Sticker()
179194

180195
class _Video(BaseFilter):
196+
name = 'Filters.video'
181197

182198
def filter(self, message):
183199
return bool(message.video)
184200

185201
video = _Video()
186202

187203
class _Voice(BaseFilter):
204+
name = 'Filters.voice'
188205

189206
def filter(self, message):
190207
return bool(message.voice)
191208

192209
voice = _Voice()
193210

194211
class _Contact(BaseFilter):
212+
name = 'Filters.contact'
195213

196214
def filter(self, message):
197215
return bool(message.contact)
198216

199217
contact = _Contact()
200218

201219
class _Location(BaseFilter):
220+
name = 'Filters.location'
202221

203222
def filter(self, message):
204223
return bool(message.location)
205224

206225
location = _Location()
207226

208227
class _Venue(BaseFilter):
228+
name = 'Filters.venue'
209229

210230
def filter(self, message):
211231
return bool(message.venue)
@@ -215,41 +235,47 @@ def filter(self, message):
215235
class _StatusUpdate(BaseFilter):
216236

217237
class _NewChatMembers(BaseFilter):
238+
name = 'Filters.status_update.new_chat_members'
218239

219240
def filter(self, message):
220241
return bool(message.new_chat_members)
221242

222243
new_chat_members = _NewChatMembers()
223244

224245
class _LeftChatMember(BaseFilter):
246+
name = 'Filters.status_update.left_chat_member'
225247

226248
def filter(self, message):
227249
return bool(message.left_chat_member)
228250

229251
left_chat_member = _LeftChatMember()
230252

231253
class _NewChatTitle(BaseFilter):
254+
name = 'Filters.status_update.new_chat_title'
232255

233256
def filter(self, message):
234257
return bool(message.new_chat_title)
235258

236259
new_chat_title = _NewChatTitle()
237260

238261
class _NewChatPhoto(BaseFilter):
262+
name = 'Filters.status_update.new_chat_photo'
239263

240264
def filter(self, message):
241265
return bool(message.new_chat_photo)
242266

243267
new_chat_photo = _NewChatPhoto()
244268

245269
class _DeleteChatPhoto(BaseFilter):
270+
name = 'Filters.status_update.delete_chat_photo'
246271

247272
def filter(self, message):
248273
return bool(message.delete_chat_photo)
249274

250275
delete_chat_photo = _DeleteChatPhoto()
251276

252277
class _ChatCreated(BaseFilter):
278+
name = 'Filters.status_update.chat_created'
253279

254280
def filter(self, message):
255281
return bool(message.group_chat_created or message.supergroup_chat_created or
@@ -258,19 +284,23 @@ def filter(self, message):
258284
chat_created = _ChatCreated()
259285

260286
class _Migrate(BaseFilter):
287+
name = 'Filters.status_update.migrate'
261288

262289
def filter(self, message):
263290
return bool(message.migrate_from_chat_id or message.migrate_to_chat_id)
264291

265292
migrate = _Migrate()
266293

267294
class _PinnedMessage(BaseFilter):
295+
name = 'Filters.status_update.pinned_message'
268296

269297
def filter(self, message):
270298
return bool(message.pinned_message)
271299

272300
pinned_message = _PinnedMessage()
273301

302+
name = 'Filters.status_update'
303+
274304
def filter(self, message):
275305
return bool(self.new_chat_members(message) or self.left_chat_member(message) or
276306
self.new_chat_title(message) or self.new_chat_photo(message) or
@@ -280,13 +310,15 @@ def filter(self, message):
280310
status_update = _StatusUpdate()
281311

282312
class _Forwarded(BaseFilter):
313+
name = 'Filters.forwarded'
283314

284315
def filter(self, message):
285316
return bool(message.forward_date)
286317

287318
forwarded = _Forwarded()
288319

289320
class _Game(BaseFilter):
321+
name = 'Filters.game'
290322

291323
def filter(self, message):
292324
return bool(message.game)
@@ -306,32 +338,37 @@ class entity(BaseFilter):
306338

307339
def __init__(self, entity_type):
308340
self.entity_type = entity_type
341+
self.name = 'Filters.entity({})'.format(self.entity_type)
309342

310343
def filter(self, message):
311344
return any([entity.type == self.entity_type for entity in message.entities])
312345

313346
class _Private(BaseFilter):
347+
name = 'Filters.private'
314348

315349
def filter(self, message):
316350
return message.chat.type == Chat.PRIVATE
317351

318352
private = _Private()
319353

320354
class _Group(BaseFilter):
355+
name = 'Filters.group'
321356

322357
def filter(self, message):
323358
return message.chat.type in [Chat.GROUP, Chat.SUPERGROUP]
324359

325360
group = _Group()
326361

327362
class _Invoice(BaseFilter):
363+
name = 'Filters.invoice'
328364

329365
def filter(self, message):
330366
return bool(message.invoice)
331367

332368
invoice = _Invoice()
333369

334370
class _SuccessfulPayment(BaseFilter):
371+
name = 'Filters.successful_payment'
335372

336373
def filter(self, message):
337374
return bool(message.successful_payment)
@@ -354,6 +391,7 @@ def __init__(self, lang):
354391
self.lang = [lang]
355392
else:
356393
self.lang = lang
394+
self.name = 'Filters.language({})'.format(self.lang)
357395

358396
def filter(self, message):
359397
return message.from_user.language_code and any(

tests/test_filters.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -254,12 +254,10 @@ def test_and_or_filters(self):
254254
self.assertTrue((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION))
255255
)(self.message))
256256

257-
self.assertRegexpMatches(
257+
self.assertEqual(
258258
str((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION)))),
259-
r"<telegram.ext.filters.MergedFilter consisting of <telegram.ext.filters.(Filters.)?_"
260-
r"Text object at .*?> and <telegram.ext.filters.MergedFilter consisting of "
261-
r"<telegram.ext.filters.(Filters.)?_Forwarded object at .*?> or "
262-
r"<telegram.ext.filters.(Filters.)?entity object at .*?>>>")
259+
'<Filters.text and <Filters.forwarded or Filters.entity(mention)>>'
260+
)
263261

264262
def test_inverted_filters(self):
265263
self.message.text = '/test'
@@ -323,6 +321,14 @@ def test_language_filter_multiple(self):
323321
self.message.from_user.language_code = 'da'
324322
self.assertTrue(f(self.message))
325323

324+
def test_custom_unnamed_filter(self):
325+
class Unnamed(BaseFilter):
326+
def filter(self, message):
327+
return True
328+
329+
unnamed = Unnamed()
330+
self.assertEqual(str(unnamed), Unnamed.__name__)
331+
326332

327333
if __name__ == '__main__':
328334
unittest.main()

0 commit comments

Comments
 (0)