@@ -89,6 +89,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
8989
9090 self .to_json_file (output_config_file )
9191 logger .info (f"ConfigMixinuration saved in { output_config_file } " )
92+
9293
9394 @classmethod
9495 def get_config_dict (
@@ -182,35 +183,43 @@ def get_config_dict(
182183 logger .info (f"loading configuration file { config_file } " )
183184 else :
184185 logger .info (f"loading configuration file { config_file } from cache at { resolved_config_file } " )
186+
187+ return config_dict
185188
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
    """Split a loaded config into ``__init__`` arguments and leftover values.

    Args:
        config_dict: Mapping of configuration values, e.g. the result of
            ``get_config_dict``. It is not modified by this method.
        **kwargs: Overrides. A keyword whose name matches an ``__init__``
            parameter takes precedence over the entry in ``config_dict``.

    Returns:
        A tuple ``(init_dict, unused_kwargs)``. ``init_dict`` holds the values
        for the named parameters of ``cls.__init__``; ``unused_kwargs`` holds
        every remaining entry of ``config_dict`` and ``kwargs`` (``kwargs``
        wins when a name appears in both).
    """
    # Work on a copy so the caller's dict is left untouched.
    remaining_config = dict(config_dict)

    # Only named parameters can be filled from the config; ``self`` and
    # ``*args`` / ``**kwargs`` catch-alls are not configuration keys.
    parameters = inspect.signature(cls.__init__).parameters
    expected_keys = {
        name
        for name, param in parameters.items()
        if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
    }
    expected_keys.discard("self")

    init_dict = {}
    for key in expected_keys:
        if key in kwargs:
            # An explicit keyword overrides the stored value; drop the stored
            # value too so it does not show up among the unused entries.
            init_dict[key] = kwargs.pop(key)
            remaining_config.pop(key, None)
        elif key in remaining_config:
            # Fall back to the value from the config dict.
            init_dict[key] = remaining_config.pop(key)

    # ``dict.update`` returns None, so build the merged mapping explicitly.
    unused_kwargs = {**remaining_config, **kwargs}

    missing_keys = expected_keys - set(init_dict)
    if missing_keys:
        logger.warning(
            f"{missing_keys} was not found in config. Values will be initialized to default values."
        )

    return init_dict, unused_kwargs
206213
207214 @classmethod
208215 def from_config (cls , pretrained_model_name_or_path : Union [str , os .PathLike ], return_unused_kwargs = False , ** kwargs ):
209- config_dict , unused_kwargs = cls .get_config_dict (
216+ config_dict = cls .get_config_dict (
210217 pretrained_model_name_or_path = pretrained_model_name_or_path , ** kwargs
211218 )
212219
213- model = cls (** config_dict )
220+ init_dict , unused_kwargs = cls .extract_init_dict (config_dict , ** kwargs )
221+
222+ model = cls (** init_dict )
214223
215224 if return_unused_kwargs :
216225 return model , unused_kwargs
0 commit comments