Skip to content

Commit 1243680

Browse files
committed
Stable autowiring functionality
1 parent bf15f54 commit 1243680

13 files changed

+169
-43
lines changed

telegram/ext/callbackqueryhandler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ def __init__(self,
118118
self.pass_groups = pass_groups
119119
self.pass_groupdict = pass_groupdict
120120

121+
if self.autowire:
122+
self.set_autowired_flags(passable={'groups', 'groupdict', 'user_data', 'chat_data'})
123+
121124
def check_update(self, update):
122125
"""Determines whether an update should be passed to this handlers :attr:`callback`.
123126
@@ -145,6 +148,7 @@ def handle_update(self, update, dispatcher):
145148
146149
"""
147150
optional_args = self.collect_optional_args(dispatcher, update)
151+
148152
if self.pattern:
149153
match = re.match(self.pattern, update.callback_query.data)
150154

telegram/ext/choseninlineresulthandler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class ChosenInlineResultHandler(Handler):
2828
2929
Attributes:
3030
callback (:obj:`callable`): The callback function for this handler.
31+
autowire (:obj:`bool`): Optional. Determines whether objects will be passed to the
32+
callback function automatically.
3133
pass_update_queue (:obj:`bool`): Optional. Determines whether ``update_queue`` will be
3234
passed to the callback function.
3335
pass_job_queue (:obj:`bool`): Optional. Determines whether ``job_queue`` will be passed to
@@ -47,6 +49,10 @@ class ChosenInlineResultHandler(Handler):
4749
callback (:obj:`callable`): A function that takes ``bot, update`` as positional arguments.
4850
It will be called when the :attr:`check_update` has determined that an update should be
4951
processed by this handler.
52+
autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be
53+
inspected for positional arguments and pass objects whose names match any of the
54+
``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with
55+
``autowire`` will yield a warning.
5056
pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called
5157
``update_queue`` will be passed to the callback function. It will be the ``Queue``
5258
instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher`
@@ -64,16 +70,20 @@ class ChosenInlineResultHandler(Handler):
6470

6571
def __init__(self,
6672
callback,
73+
autowire=False,
6774
pass_update_queue=False,
6875
pass_job_queue=False,
6976
pass_user_data=False,
7077
pass_chat_data=False):
7178
super(ChosenInlineResultHandler, self).__init__(
7279
callback,
80+
autowire=autowire,
7381
pass_update_queue=pass_update_queue,
7482
pass_job_queue=pass_job_queue,
7583
pass_user_data=pass_user_data,
7684
pass_chat_data=pass_chat_data)
85+
if self.autowire:
86+
self.set_autowired_flags()
7787

7888
def check_update(self, update):
7989
"""Determines whether an update should be passed to this handlers :attr:`callback`.

telegram/ext/commandhandler.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
from future.utils import string_types
2323

24-
from .handler import Handler
2524
from telegram import Update
25+
from .handler import Handler
2626

2727

2828
class CommandHandler(Handler):
@@ -73,8 +73,7 @@ class CommandHandler(Handler):
7373
autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be
7474
inspected for positional arguments and pass objects whose names match any of the
7575
``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with
76-
``autowire`` will yield
77-
a warning.
76+
``autowire`` will yield a warning.
7877
pass_args (:obj:`bool`, optional): Determines whether the handler should be passed the
7978
arguments passed to the command as a keyword argument called ``args``. It will contain
8079
a list of strings, which is the text following the command split on single or
@@ -113,13 +112,16 @@ def __init__(self,
113112
pass_user_data=pass_user_data,
114113
pass_chat_data=pass_chat_data)
115114

115+
self.pass_args = pass_args
116+
if self.autowire:
117+
self.set_autowired_flags({'update_queue', 'job_queue', 'user_data', 'chat_data', 'args'})
118+
116119
if isinstance(command, string_types):
117120
self.command = [command.lower()]
118121
else:
119122
self.command = [x.lower() for x in command]
120123
self.filters = filters
121124
self.allow_edited = allow_edited
122-
self.pass_args = pass_args
123125

124126
# We put this up here instead of with the rest of checking code
125127
# in check_update since we don't wanna spam a ton
@@ -139,7 +141,7 @@ def check_update(self, update):
139141
140142
"""
141143
if (isinstance(update, Update)
142-
and (update.message or update.edited_message and self.allow_edited)):
144+
and (update.message or update.edited_message and self.allow_edited)):
143145
message = update.message or update.edited_message
144146

145147
if message.text:
@@ -170,6 +172,7 @@ def handle_update(self, update, dispatcher):
170172
dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the Update.
171173
172174
"""
175+
self.set_autowired_flags({'args', 'update_queue', 'job_queue', 'user_data', 'chat_data'})
173176
optional_args = self.collect_optional_args(dispatcher, update)
174177

175178
message = update.message or update.edited_message

telegram/ext/handler.py

Lines changed: 100 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"""This module contains the base class for handlers as used by the Dispatcher."""
2020
import warnings
2121

22-
from telegram.utils.inspection import get_positional_arguments
22+
from telegram.utils.inspection import inspect_arguments
2323

2424

2525
class Handler(object):
@@ -77,12 +77,12 @@ def __init__(self,
7777
pass_chat_data=False):
7878
self.callback = callback
7979
self.autowire = autowire
80-
if self.autowire and any((pass_update_queue, pass_job_queue, pass_user_data, pass_chat_data)):
81-
warnings.warn('If `autowire` is set to `True`, it is unnecessary to provide any `pass_*` flags.')
8280
self.pass_update_queue = pass_update_queue
8381
self.pass_job_queue = pass_job_queue
8482
self.pass_user_data = pass_user_data
8583
self.pass_chat_data = pass_chat_data
84+
self._autowire_initialized = False
85+
self._callback_args = None
8686

8787
def check_update(self, update):
8888
"""
@@ -113,8 +113,77 @@ def handle_update(self, update, dispatcher):
113113
"""
114114
raise NotImplementedError
115115

116+
def __get_available_pass_flags(self):
117+
"""
118+
Used to provide warnings if the user decides to use `autowire` in conjunction with
119+
`pass_*` flags, and to recalculate all flags.
120+
121+
Getting objects dynamically is better than hard-coding all passable objects and setting
122+
them to False in here, because the base class should not know about the existence of
123+
passable objects that are only relevant to subclasses (e.g. args, groups, groupdict).
124+
"""
125+
return [f for f in dir(self) if f.startswith('pass_')]
126+
127+
def set_autowired_flags(self, passable={'update_queue', 'job_queue', 'user_data', 'chat_data'}):
128+
"""
129+
130+
Make the passable arguments explicit as opposed to dynamically generated to be absolutely
131+
safe that no arguments will be passed that are not allowed.
132+
"""
133+
134+
if not self.autowire:
135+
raise ValueError("This handler is not autowired.")
136+
137+
if self._autowire_initialized:
138+
# In case that users decide to change their callback signatures at runtime, give the
139+
# possibility to recalculate all flags.
140+
for flag in self.__get_available_pass_flags():
141+
setattr(self, flag, False)
142+
143+
all_passable_objects = {'update_queue', 'job_queue', 'user_data', 'chat_data', 'args', 'groups', 'groupdict'}
144+
145+
self._callback_args = inspect_arguments(self.callback)
146+
147+
def should_pass_obj(name):
148+
"""
149+
Utility to determine whether a passable object is part of
150+
the user handler's signature, makes sense in this context,
151+
and is not explicitly set to `False`.
152+
"""
153+
is_requested = name in all_passable_objects and name in self._callback_args
154+
if is_requested and name not in passable:
155+
warnings.warn("The argument `{}` cannot be autowired since it is not available "
156+
"on `{}s`.".format(name, type(self).__name__))
157+
return False
158+
return is_requested
159+
160+
# Check whether the user has set any `pass_*` flag to True in addition to `autowire`
161+
for flag in self.__get_available_pass_flags():
162+
to_pass = bool(getattr(self, flag))
163+
if to_pass is True:
164+
warnings.warn('If `autowire` is set to `True`, it is unnecessary '
165+
'to provide the `{}` flag.'.format(flag))
166+
167+
if should_pass_obj('update_queue'):
168+
self.pass_update_queue = True
169+
if should_pass_obj('job_queue'):
170+
self.pass_job_queue = True
171+
if should_pass_obj('user_data'):
172+
self.pass_user_data = True
173+
if should_pass_obj('chat_data'):
174+
self.pass_chat_data = True
175+
if should_pass_obj('args'):
176+
self.pass_args = True
177+
if should_pass_obj('groups'):
178+
self.pass_groups = True
179+
if should_pass_obj('groupdict'):
180+
self.pass_groupdict = True
181+
182+
self._autowire_initialized = True
183+
116184
def collect_optional_args(self, dispatcher, update=None):
117-
"""Prepares the optional arguments that are the same for all types of handlers.
185+
"""
186+
Prepares the optional arguments that are the same for all types of handlers.
118187
119188
Args:
120189
dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher.
@@ -123,27 +192,32 @@ def collect_optional_args(self, dispatcher, update=None):
123192
optional_args = dict()
124193

125194
if self.autowire:
126-
callback_args = get_positional_arguments(self.callback)
127-
if 'update_queue' in callback_args:
128-
optional_args['update_queue'] = dispatcher.update_queue
129-
if 'job_queue' in callback_args:
130-
optional_args['job_queue'] = dispatcher.job_queue
131-
if 'user_data' in callback_args:
132-
user = update.effective_user
133-
optional_args['user_data'] = dispatcher.user_data[user.id if user else None]
134-
if 'chat_data' in callback_args:
135-
chat = update.effective_chat
136-
optional_args['chat_data'] = dispatcher.chat_data[chat.id if chat else None]
137-
else:
138-
if self.pass_update_queue:
139-
optional_args['update_queue'] = dispatcher.update_queue
140-
if self.pass_job_queue:
141-
optional_args['job_queue'] = dispatcher.job_queue
142-
if self.pass_user_data:
143-
user = update.effective_user
144-
optional_args['user_data'] = dispatcher.user_data[user.id if user else None]
145-
if self.pass_chat_data:
146-
chat = update.effective_chat
147-
optional_args['chat_data'] = dispatcher.chat_data[chat.id if chat else None]
195+
# Subclasses are responsible for calling `set_autowired_flags` in their __init__
196+
assert self._autowire_initialized
197+
198+
if self.pass_update_queue:
199+
optional_args['update_queue'] = dispatcher.update_queue
200+
if self.pass_job_queue:
201+
optional_args['job_queue'] = dispatcher.job_queue
202+
if self.pass_user_data:
203+
user = update.effective_user
204+
optional_args['user_data'] = dispatcher.user_data[user.id if user else None]
205+
if self.pass_chat_data:
206+
chat = update.effective_chat
207+
optional_args['chat_data'] = dispatcher.chat_data[chat.id if chat else None]
148208

149209
return optional_args
210+
211+
def collect_bot_update_args(self, dispatcher, update):
212+
if self.autowire:
213+
# Subclasses are responsible for calling `set_autowired_flags` in their __init__
214+
assert self._autowire_initialized
215+
216+
positional_args = []
217+
if 'bot' in self._callback_args:
218+
positional_args.append(dispatcher.bot)
219+
if 'update' in self._callback_args:
220+
positional_args.append(update)
221+
return positional_args
222+
else:
223+
return (dispatcher.bot, update)

telegram/ext/inlinequeryhandler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def __init__(self,
112112
self.pattern = pattern
113113
self.pass_groups = pass_groups
114114
self.pass_groupdict = pass_groupdict
115+
if self.autowire:
116+
self.set_autowired_flags(passable={'groups', 'groupdict', 'user_data', 'chat_data'})
115117

116118
def check_update(self, update):
117119
"""
@@ -142,6 +144,7 @@ def handle_update(self, update, dispatcher):
142144
"""
143145

144146
optional_args = self.collect_optional_args(dispatcher, update)
147+
145148
if self.pattern:
146149
match = re.match(self.pattern, update.inline_query.query)
147150

telegram/ext/messagehandler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ class MessageHandler(Handler):
6464
callback (:obj:`callable`): A function that takes ``bot, update`` as positional arguments.
6565
It will be called when the :attr:`check_update` has determined that an update should be
6666
processed by this handler.
67+
autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be
68+
inspected for positional arguments and pass objects whose names match any of the
69+
``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with
70+
``autowire`` will yield
71+
a warning.
6772
pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called
6873
``update_queue`` will be passed to the callback function. It will be the ``Queue``
6974
instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher`
@@ -121,6 +126,9 @@ def __init__(self,
121126
self.channel_post_updates = channel_post_updates
122127
self.edited_updates = edited_updates
123128

129+
if self.autowire:
130+
self.set_autowired_flags()
131+
124132
# We put this up here instead of with the rest of checking code
125133
# in check_update since we don't wanna spam a ton
126134
if isinstance(self.filters, list):
@@ -168,6 +176,7 @@ def handle_update(self, update, dispatcher):
168176
dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the Update.
169177
170178
"""
179+
positional_args = self.collect_bot_update_args(dispatcher, update)
171180
optional_args = self.collect_optional_args(dispatcher, update)
172181

173-
return self.callback(dispatcher.bot, update, **optional_args)
182+
return self.callback(*positional_args, **optional_args)

telegram/ext/precheckoutqueryhandler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def __init__(self,
8282
pass_job_queue=pass_job_queue,
8383
pass_user_data=pass_user_data,
8484
pass_chat_data=pass_chat_data)
85+
if self.autowire:
86+
self.set_autowired_flags()
8587

8688
def check_update(self, update):
8789
"""Determines whether an update should be passed to this handlers :attr:`callback`.

telegram/ext/regexhandler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ def __init__(self,
137137
self.pattern = pattern
138138
self.pass_groups = pass_groups
139139
self.pass_groupdict = pass_groupdict
140+
if self.autowire:
141+
self.set_autowired_flags({'groups', 'groupdict', 'update_queue', 'job_queue', 'user_data', 'chat_data'})
140142
self.allow_edited = allow_edited
141143
self.message_updates = message_updates
142144
self.channel_post_updates = channel_post_updates

telegram/ext/shippingqueryhandler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def __init__(self,
8282
pass_job_queue=pass_job_queue,
8383
pass_user_data=pass_user_data,
8484
pass_chat_data=pass_chat_data)
85+
if self.autowire:
86+
self.set_autowired_flags()
8587

8688
def check_update(self, update):
8789
"""Determines whether an update should be passed to this handlers :attr:`callback`.

telegram/ext/stringcommandhandler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,16 @@ class StringCommandHandler(Handler):
6464
def __init__(self,
6565
command,
6666
callback,
67+
autowire=False,
6768
pass_args=False,
6869
pass_update_queue=False,
6970
pass_job_queue=False):
7071
super(StringCommandHandler, self).__init__(
7172
callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue)
7273
self.command = command
7374
self.pass_args = pass_args
75+
if self.autowire:
76+
self.set_autowired_flags(passable={'groups', 'groupdict', 'user_data', 'chat_data', 'args'})
7477

7578
def check_update(self, update):
7679
"""Determines whether an update should be passed to this handlers :attr:`callback`.

0 commit comments

Comments
 (0)