2525
2626from transformers import (
2727 ParakeetCTCConfig ,
28+ ParakeetEncoder ,
29+ ParakeetEncoderConfig ,
2830 ParakeetFeatureExtractor ,
2931 ParakeetForCTC ,
3032 ParakeetProcessor ,
@@ -203,7 +205,8 @@ def write_processor(nemo_config: dict, model_files, output_dir, push_to_repo_id=
203205 processor .push_to_hub (push_to_repo_id )
204206
205207
206- def write_model (nemo_config , model_files , model_type , output_dir , push_to_repo_id = None ):
208+ def convert_encoder_config (nemo_config ):
209+ """Convert NeMo encoder config to HF encoder config."""
207210 encoder_keys_to_ignore = [
208211 "att_context_size" ,
209212 "causal_downsampling" ,
@@ -220,8 +223,11 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i
220223 "stochastic_depth_mode" ,
221224 "conv_context_size" ,
222225 "dropout_pre_encoder" ,
226+ "reduction" ,
227+ "reduction_factor" ,
228+ "reduction_position" ,
223229 ]
224- enocder_config_keys_mapping = {
230+ encoder_config_keys_mapping = {
225231 "d_model" : "hidden_size" ,
226232 "n_heads" : "num_attention_heads" ,
227233 "n_layers" : "num_hidden_layers" ,
@@ -234,17 +240,26 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i
234240 "dropout_emb" : "dropout_positions" ,
235241 "dropout_att" : "attention_dropout" ,
236242 "xscaling" : "scale_input" ,
243+ "use_bias" : "attention_bias" ,
237244 }
238245 converted_encoder_config = {}
239246
240247 for key , value in nemo_config ["encoder" ].items ():
241248 if key in encoder_keys_to_ignore :
242249 continue
243- if key in enocder_config_keys_mapping :
244- converted_encoder_config [enocder_config_keys_mapping [key ]] = value
250+ if key in encoder_config_keys_mapping :
251+ converted_encoder_config [encoder_config_keys_mapping [key ]] = value
252+ # NeMo uses 'use_bias' for both attention and convolution bias, but HF separates them
253+ if key == "use_bias" :
254+ converted_encoder_config ["convolution_bias" ] = value
245255 else :
246- raise ValueError (f"Key { key } not found in enocder_config_keys_mapping" )
256+ raise ValueError (f"Key { key } not found in encoder_config_keys_mapping" )
257+
258+ return ParakeetEncoderConfig (** converted_encoder_config )
259+
247260
261+ def load_and_convert_state_dict (model_files ):
262+ """Load NeMo state dict and convert keys to HF format."""
248263 state_dict = torch .load (model_files ["model_weights" ], map_location = "cpu" , weights_only = True )
249264 converted_state_dict = {}
250265 for key , value in state_dict .items ():
@@ -255,31 +270,80 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i
255270 converted_key = convert_key (key , NEMO_TO_HF_WEIGHT_MAPPING )
256271 converted_state_dict [converted_key ] = value
257272
258- if model_type == "ctc" :
259- model_config = ParakeetCTCConfig (
260- encoder_config = converted_encoder_config ,
261- )
262- print ("Loading the checkpoint in a Parakeet CTC model." )
263- with torch .device ("meta" ):
264- model = ParakeetForCTC (model_config )
265- model .load_state_dict (converted_state_dict , strict = True , assign = True )
266- print ("Checkpoint loaded successfully." )
267- del model .config ._name_or_path
273+ return converted_state_dict
274+
275+
def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None):
    """Build a Parakeet CTC model from converted weights, save it, and reload it as a check.

    Args:
        encoder_config: Encoder configuration handed to
            `ParakeetCTCConfig.from_encoder_config`.
        converted_state_dict: Checkpoint weights loaded into the model with
            `strict=True`, so any missing or unexpected key raises.
        output_dir: Directory passed to `save_pretrained` and read back by
            `from_pretrained` for the final check.
        push_to_repo_id: Optional Hub repository ID. When truthy, the model is
            also pushed there after being saved locally.
    """
    ctc_config = ParakeetCTCConfig.from_encoder_config(encoder_config)

    print("Loading the checkpoint in a Parakeet CTC model.")
    # Instantiate under the meta device, then let `assign=True` make the
    # checkpoint tensors become the module's parameters.
    with torch.device("meta"):
        ctc_model = ParakeetForCTC(ctc_config)
    ctc_model.load_state_dict(converted_state_dict, strict=True, assign=True)
    print("Checkpoint loaded successfully.")
    del ctc_model.config._name_or_path

    print("Saving the model.")
    ctc_model.save_pretrained(output_dir)
    if push_to_repo_id:
        ctc_model.push_to_hub(push_to_repo_id)

    # Drop the in-memory model before reloading it from disk.
    del ctc_model
    gc.collect()

    # Safety check: the saved files must load back into the same model class.
    print("Reloading the model to check if it's saved correctly.")
    ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
    print("Model reloaded successfully.")
274300
275- del converted_state_dict , model
276301
277- # Safety check: reload the converted model
278- gc .collect ()
279- print ("Reloading the model to check if it's saved correctly." )
280- ParakeetForCTC .from_pretrained (output_dir , dtype = torch .bfloat16 , device_map = "auto" )
281- print ("Model reloaded successfully." )
def write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None):
    """Write a standalone Parakeet encoder from the encoder config and converted state dict.

    Args:
        encoder_config: Configuration used to instantiate `ParakeetEncoder`.
        converted_state_dict: Full converted checkpoint. Only entries whose key
            starts with "encoder." are used; the rest are ignored.
        output_dir: Directory passed to `save_pretrained` and read back by
            `from_pretrained` for the final check.
        push_to_repo_id: Optional Hub repository ID. When truthy, the model is
            also pushed there after being saved locally.
    """
    # Keep only encoder weights (this drops e.g. a CTC head if present) and
    # strip the leading "encoder." from each kept key. The filter clause
    # already guarantees the prefix, so the key expression needs no second test.
    encoder_prefix = "encoder."
    encoder_state_dict = {
        key[len(encoder_prefix):]: tensor
        for key, tensor in converted_state_dict.items()
        if key.startswith(encoder_prefix)
    }

    print("Loading the checkpoint in a Parakeet Encoder model (for TDT).")
    # Instantiate under the meta device, then let `assign=True` make the
    # checkpoint tensors become the module's parameters.
    with torch.device("meta"):
        model = ParakeetEncoder(encoder_config)

    # strict=True: any missing or unexpected key raises here.
    model.load_state_dict(encoder_state_dict, strict=True, assign=True)
    print("Checkpoint loaded successfully.")
    del model.config._name_or_path

    print("Saving the model.")
    model.save_pretrained(output_dir)

    if push_to_repo_id:
        model.push_to_hub(push_to_repo_id)
    del model

    # Safety check: reload the converted model
    gc.collect()
    print("Reloading the model to check if it's saved correctly.")
    ParakeetEncoder.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
    print("Model reloaded successfully.")
331+
282332
def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None):
    """Convert a NeMo checkpoint and write it out as the requested model type.

    Args:
        nemo_config: NeMo configuration passed to `convert_encoder_config`.
        model_files: Passed to `load_and_convert_state_dict` to locate the weights.
        model_type: Either "encoder" or "ctc"; selects which writer is used.
        output_dir: Directory the selected writer saves the model into.
        push_to_repo_id: Optional Hub repository ID forwarded to the writer.

    Raises:
        ValueError: If `model_type` is neither "encoder" nor "ctc". The config
            and state dict have already been converted when this is raised.
    """
    # The encoder config and the converted weights are needed by every model type.
    encoder_config = convert_encoder_config(nemo_config)
    print(f"Converted encoder config: {encoder_config}")
    converted_state_dict = load_and_convert_state_dict(model_files)

    # Choose the writer for the requested model type.
    writers = {
        "encoder": write_encoder_model,
        "ctc": write_ctc_model,
    }
    writer = writers.get(model_type)
    if writer is None:
        raise ValueError(f"Model type {model_type} not supported.")
    writer(encoder_config, converted_state_dict, output_dir, push_to_repo_id)
285349
@@ -303,7 +367,9 @@ def main(
303367if __name__ == "__main__" :
304368 parser = argparse .ArgumentParser ()
305369 parser .add_argument ("--hf_repo_id" , required = True , help = "Model repo on huggingface.co" )
306- parser .add_argument ("--model_type" , required = True , choices = ["ctc" ], help = "Model type (`ctc`, `tdt`)" )
370+ parser .add_argument (
371+ "--model_type" , required = True , choices = ["encoder" , "ctc" ], help = "Model type (`encoder`, `ctc`)"
372+ )
307373 parser .add_argument ("--output_dir" , required = True , help = "Output directory for HuggingFace model" )
308374 parser .add_argument ("--push_to_repo_id" , help = "Repository ID to push the model to on the Hub" )
309375 args = parser .parse_args ()
0 commit comments