1919"""This module contains the base class for handlers as used by the Dispatcher."""
2020import warnings
2121
22- from telegram .utils .inspection import get_positional_arguments
22+ from telegram .utils .inspection import inspect_arguments
2323
2424
2525class Handler (object ):
@@ -77,12 +77,12 @@ def __init__(self,
7777 pass_chat_data = False ):
7878 self .callback = callback
7979 self .autowire = autowire
80- if self .autowire and any ((pass_update_queue , pass_job_queue , pass_user_data , pass_chat_data )):
81- warnings .warn ('If `autowire` is set to `True`, it is unnecessary to provide any `pass_*` flags.' )
8280 self .pass_update_queue = pass_update_queue
8381 self .pass_job_queue = pass_job_queue
8482 self .pass_user_data = pass_user_data
8583 self .pass_chat_data = pass_chat_data
84+ self ._autowire_initialized = False
85+ self ._callback_args = None
8686
8787 def check_update (self , update ):
8888 """
@@ -113,8 +113,77 @@ def handle_update(self, update, dispatcher):
113113 """
114114 raise NotImplementedError
115115
116+ def __get_available_pass_flags (self ):
117+ """
118+ Used to provide warnings if the user decides to use `autowire` in conjunction with
119+ `pass_*` flags, and to recalculate all flags.
120+
121+ Getting objects dynamically is better than hard-coding all passable objects and setting
122+ them to False in here, because the base class should not know about the existence of
123+ passable objects that are only relevant to subclasses (e.g. args, groups, groupdict).
124+ """
125+ return [f for f in dir (self ) if f .startswith ('pass_' )]
126+
127+ def set_autowired_flags (self , passable = {'update_queue' , 'job_queue' , 'user_data' , 'chat_data' }):
128+ """
129+
130+ Make the passable arguments explicit as opposed to dynamically generated to be absolutely
131+ safe that no arguments will be passed that are not allowed.
132+ """
133+
134+ if not self .autowire :
135+ raise ValueError ("This handler is not autowired." )
136+
137+ if self ._autowire_initialized :
138+ # In case that users decide to change their callback signatures at runtime, give the
139+ # possibility to recalculate all flags.
140+ for flag in self .__get_available_pass_flags ():
141+ setattr (self , flag , False )
142+
143+ all_passable_objects = {'update_queue' , 'job_queue' , 'user_data' , 'chat_data' , 'args' , 'groups' , 'groupdict' }
144+
145+ self ._callback_args = inspect_arguments (self .callback )
146+
147+ def should_pass_obj (name ):
148+ """
149+ Utility to determine whether a passable object is part of
150+ the user handler's signature, makes sense in this context,
151+ and is not explicitly set to `False`.
152+ """
153+ is_requested = name in all_passable_objects and name in self ._callback_args
154+ if is_requested and name not in passable :
155+ warnings .warn ("The argument `{}` cannot be autowired since it is not available "
156+ "on `{}s`." .format (name , type (self ).__name__ ))
157+ return False
158+ return is_requested
159+
160+ # Check whether the user has set any `pass_*` flag to True in addition to `autowire`
161+ for flag in self .__get_available_pass_flags ():
162+ to_pass = bool (getattr (self , flag ))
163+ if to_pass is True :
164+ warnings .warn ('If `autowire` is set to `True`, it is unnecessary '
165+ 'to provide the `{}` flag.' .format (flag ))
166+
167+ if should_pass_obj ('update_queue' ):
168+ self .pass_update_queue = True
169+ if should_pass_obj ('job_queue' ):
170+ self .pass_job_queue = True
171+ if should_pass_obj ('user_data' ):
172+ self .pass_user_data = True
173+ if should_pass_obj ('chat_data' ):
174+ self .pass_chat_data = True
175+ if should_pass_obj ('args' ):
176+ self .pass_args = True
177+ if should_pass_obj ('groups' ):
178+ self .pass_groups = True
179+ if should_pass_obj ('groupdict' ):
180+ self .pass_groupdict = True
181+
182+ self ._autowire_initialized = True
183+
116184 def collect_optional_args (self , dispatcher , update = None ):
117- """Prepares the optional arguments that are the same for all types of handlers.
185+ """
186+ Prepares the optional arguments that are the same for all types of handlers.
118187
119188 Args:
120189 dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher.
@@ -123,27 +192,32 @@ def collect_optional_args(self, dispatcher, update=None):
123192 optional_args = dict ()
124193
125194 if self .autowire :
126- callback_args = get_positional_arguments (self .callback )
127- if 'update_queue' in callback_args :
128- optional_args ['update_queue' ] = dispatcher .update_queue
129- if 'job_queue' in callback_args :
130- optional_args ['job_queue' ] = dispatcher .job_queue
131- if 'user_data' in callback_args :
132- user = update .effective_user
133- optional_args ['user_data' ] = dispatcher .user_data [user .id if user else None ]
134- if 'chat_data' in callback_args :
135- chat = update .effective_chat
136- optional_args ['chat_data' ] = dispatcher .chat_data [chat .id if chat else None ]
137- else :
138- if self .pass_update_queue :
139- optional_args ['update_queue' ] = dispatcher .update_queue
140- if self .pass_job_queue :
141- optional_args ['job_queue' ] = dispatcher .job_queue
142- if self .pass_user_data :
143- user = update .effective_user
144- optional_args ['user_data' ] = dispatcher .user_data [user .id if user else None ]
145- if self .pass_chat_data :
146- chat = update .effective_chat
147- optional_args ['chat_data' ] = dispatcher .chat_data [chat .id if chat else None ]
195+ # Subclasses are responsible for calling `set_autowired_flags` in their __init__
196+ assert self ._autowire_initialized
197+
198+ if self .pass_update_queue :
199+ optional_args ['update_queue' ] = dispatcher .update_queue
200+ if self .pass_job_queue :
201+ optional_args ['job_queue' ] = dispatcher .job_queue
202+ if self .pass_user_data :
203+ user = update .effective_user
204+ optional_args ['user_data' ] = dispatcher .user_data [user .id if user else None ]
205+ if self .pass_chat_data :
206+ chat = update .effective_chat
207+ optional_args ['chat_data' ] = dispatcher .chat_data [chat .id if chat else None ]
148208
149209 return optional_args
210+
211+ def collect_bot_update_args (self , dispatcher , update ):
212+ if self .autowire :
213+ # Subclasses are responsible for calling `set_autowired_flags` in their __init__
214+ assert self ._autowire_initialized
215+
216+ positional_args = []
217+ if 'bot' in self ._callback_args :
218+ positional_args .append (dispatcher .bot )
219+ if 'update' in self ._callback_args :
220+ positional_args .append (update )
221+ return positional_args
222+ else :
223+ return (dispatcher .bot , update )
0 commit comments