Skip to content

Commit b9f90dc

Browse files
nithinraok and eustlb authored
add support for saving encoder only so any parakeet model can be loaded for inference (#41969)
* add support for saving encoder only so any decoder model can be loaded
  Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
* use convolution_bias
* convert modular
* convolution_bias in conversion script
---------
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
1 parent 37a6296 commit b9f90dc

File tree

5 files changed

+126
-32
lines changed

5 files changed

+126
-32
lines changed

src/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ class FastSpeech2ConformerConfig(PreTrainedConfig):
147147
Speaker embedding dimension. If set to > 0, assume that speaker_embedding will be provided as the input.
148148
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
149149
Specifies whether the model is an encoder-decoder.
150+
convolution_bias (`bool`, *optional*, defaults to `True`):
151+
Specifies whether to use bias in convolutions of the conformer's convolution module.
150152
151153
Example:
152154
@@ -224,6 +226,7 @@ def __init__(
224226
num_languages=None,
225227
speaker_embed_dim=None,
226228
is_encoder_decoder=True,
229+
convolution_bias=True,
227230
**kwargs,
228231
):
229232
if positionwise_conv_kernel_size % 2 == 0:
@@ -318,6 +321,7 @@ def __init__(
318321
self.speaker_embed_dim = speaker_embed_dim
319322
self.duration_predictor_dropout_rate = duration_predictor_dropout_rate
320323
self.is_encoder_decoder = is_encoder_decoder
324+
self.convolution_bias = convolution_bias
321325

322326
super().__init__(
323327
is_encoder_decoder=is_encoder_decoder,

src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -490,12 +490,22 @@ def __init__(self, config: FastSpeech2ConformerConfig, module_config=None):
490490
kernel_size = module_config["kernel_size"]
491491
self.activation = ACT2FN[module_config.get("activation", "silu")]
492492
self.padding = (kernel_size - 1) // 2
493-
self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
493+
self.pointwise_conv1 = nn.Conv1d(
494+
channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
495+
)
494496
self.depthwise_conv = nn.Conv1d(
495-
channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
497+
channels,
498+
channels,
499+
kernel_size,
500+
stride=1,
501+
padding=self.padding,
502+
groups=channels,
503+
bias=config.convolution_bias,
496504
)
497505
self.norm = nn.BatchNorm1d(channels)
498-
self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
506+
self.pointwise_conv2 = nn.Conv1d(
507+
channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
508+
)
499509

500510
def forward(self, hidden_states, attention_mask=None):
501511
"""

src/transformers/models/parakeet/configuration_parakeet.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ class ParakeetEncoderConfig(PreTrainedConfig):
4444
The non-linear activation function (function or string) in the encoder and pooler.
4545
attention_bias (`bool`, *optional*, defaults to `True`):
4646
Whether to use bias in the attention layers.
47+
convolution_bias (`bool`, *optional*, defaults to `True`):
48+
Whether to use bias in convolutions of the conformer's convolution module.
4749
conv_kernel_size (`int`, *optional*, defaults to 9):
4850
The kernel size of the convolution layers in the Conformer block.
4951
subsampling_factor (`int`, *optional*, defaults to 8):
@@ -102,6 +104,7 @@ def __init__(
102104
intermediate_size=4096,
103105
hidden_act="silu",
104106
attention_bias=True,
107+
convolution_bias=True,
105108
conv_kernel_size=9,
106109
subsampling_factor=8,
107110
subsampling_conv_channels=256,
@@ -128,6 +131,7 @@ def __init__(
128131
self.intermediate_size = intermediate_size
129132
self.hidden_act = hidden_act
130133
self.attention_bias = attention_bias
134+
self.convolution_bias = convolution_bias
131135

132136
if (conv_kernel_size - 1) % 2 != 0:
133137
raise ValueError(f"conv_kernel_size must be odd, got {conv_kernel_size}")

src/transformers/models/parakeet/convert_nemo_to_hf.py

Lines changed: 92 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
from transformers import (
2727
ParakeetCTCConfig,
28+
ParakeetEncoder,
29+
ParakeetEncoderConfig,
2830
ParakeetFeatureExtractor,
2931
ParakeetForCTC,
3032
ParakeetProcessor,
@@ -203,7 +205,8 @@ def write_processor(nemo_config: dict, model_files, output_dir, push_to_repo_id=
203205
processor.push_to_hub(push_to_repo_id)
204206

205207

206-
def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None):
208+
def convert_encoder_config(nemo_config):
209+
"""Convert NeMo encoder config to HF encoder config."""
207210
encoder_keys_to_ignore = [
208211
"att_context_size",
209212
"causal_downsampling",
@@ -220,8 +223,11 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i
220223
"stochastic_depth_mode",
221224
"conv_context_size",
222225
"dropout_pre_encoder",
226+
"reduction",
227+
"reduction_factor",
228+
"reduction_position",
223229
]
224-
enocder_config_keys_mapping = {
230+
encoder_config_keys_mapping = {
225231
"d_model": "hidden_size",
226232
"n_heads": "num_attention_heads",
227233
"n_layers": "num_hidden_layers",
@@ -234,17 +240,26 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i
234240
"dropout_emb": "dropout_positions",
235241
"dropout_att": "attention_dropout",
236242
"xscaling": "scale_input",
243+
"use_bias": "attention_bias",
237244
}
238245
converted_encoder_config = {}
239246

240247
for key, value in nemo_config["encoder"].items():
241248
if key in encoder_keys_to_ignore:
242249
continue
243-
if key in enocder_config_keys_mapping:
244-
converted_encoder_config[enocder_config_keys_mapping[key]] = value
250+
if key in encoder_config_keys_mapping:
251+
converted_encoder_config[encoder_config_keys_mapping[key]] = value
252+
# NeMo uses 'use_bias' for both attention and convolution bias, but HF separates them
253+
if key == "use_bias":
254+
converted_encoder_config["convolution_bias"] = value
245255
else:
246-
raise ValueError(f"Key {key} not found in enocder_config_keys_mapping")
256+
raise ValueError(f"Key {key} not found in encoder_config_keys_mapping")
257+
258+
return ParakeetEncoderConfig(**converted_encoder_config)
259+
247260

261+
def load_and_convert_state_dict(model_files):
262+
"""Load NeMo state dict and convert keys to HF format."""
248263
state_dict = torch.load(model_files["model_weights"], map_location="cpu", weights_only=True)
249264
converted_state_dict = {}
250265
for key, value in state_dict.items():
@@ -255,31 +270,80 @@ def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_i
255270
converted_key = convert_key(key, NEMO_TO_HF_WEIGHT_MAPPING)
256271
converted_state_dict[converted_key] = value
257272

258-
if model_type == "ctc":
259-
model_config = ParakeetCTCConfig(
260-
encoder_config=converted_encoder_config,
261-
)
262-
print("Loading the checkpoint in a Parakeet CTC model.")
263-
with torch.device("meta"):
264-
model = ParakeetForCTC(model_config)
265-
model.load_state_dict(converted_state_dict, strict=True, assign=True)
266-
print("Checkpoint loaded successfully.")
267-
del model.config._name_or_path
273+
return converted_state_dict
274+
275+
276+
def write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None):
277+
"""Write CTC model using encoder config and converted state dict."""
278+
model_config = ParakeetCTCConfig.from_encoder_config(encoder_config)
279+
280+
print("Loading the checkpoint in a Parakeet CTC model.")
281+
with torch.device("meta"):
282+
model = ParakeetForCTC(model_config)
283+
model.load_state_dict(converted_state_dict, strict=True, assign=True)
284+
print("Checkpoint loaded successfully.")
285+
del model.config._name_or_path
286+
287+
print("Saving the model.")
288+
model.save_pretrained(output_dir)
289+
290+
if push_to_repo_id:
291+
model.push_to_hub(push_to_repo_id)
268292

269-
print("Saving the model.")
270-
model.save_pretrained(output_dir)
293+
del model
271294

272-
if push_to_repo_id:
273-
model.push_to_hub(push_to_repo_id)
295+
# Safety check: reload the converted model
296+
gc.collect()
297+
print("Reloading the model to check if it's saved correctly.")
298+
ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
299+
print("Model reloaded successfully.")
274300

275-
del converted_state_dict, model
276301

277-
# Safety check: reload the converted model
278-
gc.collect()
279-
print("Reloading the model to check if it's saved correctly.")
280-
ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
281-
print("Model reloaded successfully.")
302+
def write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id=None):
303+
"""Write encoder model using encoder config and converted state dict."""
304+
# Filter to only encoder weights (exclude CTC head if present)
305+
encoder_state_dict = {
306+
k.replace("encoder.", "", 1) if k.startswith("encoder.") else k: v
307+
for k, v in converted_state_dict.items()
308+
if k.startswith("encoder.")
309+
}
310+
311+
print("Loading the checkpoint in a Parakeet Encoder model (for TDT).")
312+
with torch.device("meta"):
313+
model = ParakeetEncoder(encoder_config)
314+
315+
model.load_state_dict(encoder_state_dict, strict=True, assign=True)
316+
print("Checkpoint loaded successfully.")
317+
del model.config._name_or_path
318+
319+
print("Saving the model.")
320+
model.save_pretrained(output_dir)
321+
322+
if push_to_repo_id:
323+
model.push_to_hub(push_to_repo_id)
324+
del model
325+
326+
# Safety check: reload the converted model
327+
gc.collect()
328+
print("Reloading the model to check if it's saved correctly.")
329+
ParakeetEncoder.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
330+
print("Model reloaded successfully.")
331+
282332

333+
def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None):
334+
"""Main model conversion function."""
335+
# Step 1: Convert encoder config (shared across all model types)
336+
encoder_config = convert_encoder_config(nemo_config)
337+
print(f"Converted encoder config: {encoder_config}")
338+
339+
# Step 2: Load and convert state dict (shared across all model types)
340+
converted_state_dict = load_and_convert_state_dict(model_files)
341+
342+
# Step 3: Write model based on type
343+
if model_type == "encoder":
344+
write_encoder_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id)
345+
elif model_type == "ctc":
346+
write_ctc_model(encoder_config, converted_state_dict, output_dir, push_to_repo_id)
283347
else:
284348
raise ValueError(f"Model type {model_type} not supported.")
285349

@@ -303,7 +367,9 @@ def main(
303367
if __name__ == "__main__":
304368
parser = argparse.ArgumentParser()
305369
parser.add_argument("--hf_repo_id", required=True, help="Model repo on huggingface.co")
306-
parser.add_argument("--model_type", required=True, choices=["ctc"], help="Model type (`ctc`, `tdt`)")
370+
parser.add_argument(
371+
"--model_type", required=True, choices=["encoder", "ctc"], help="Model type (`encoder`, `ctc`)"
372+
)
307373
parser.add_argument("--output_dir", required=True, help="Output directory for HuggingFace model")
308374
parser.add_argument("--push_to_repo_id", help="Repository ID to push the model to on the Hub")
309375
args = parser.parse_args()

src/transformers/models/parakeet/modeling_parakeet.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,22 @@ def __init__(self, config: ParakeetEncoderConfig, module_config=None):
130130
kernel_size = module_config["kernel_size"]
131131
self.activation = ACT2FN[module_config.get("activation", "silu")]
132132
self.padding = (kernel_size - 1) // 2
133-
self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
133+
self.pointwise_conv1 = nn.Conv1d(
134+
channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
135+
)
134136
self.depthwise_conv = nn.Conv1d(
135-
channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
137+
channels,
138+
channels,
139+
kernel_size,
140+
stride=1,
141+
padding=self.padding,
142+
groups=channels,
143+
bias=config.convolution_bias,
136144
)
137145
self.norm = nn.BatchNorm1d(channels)
138-
self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
146+
self.pointwise_conv2 = nn.Conv1d(
147+
channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
148+
)
139149

140150
def forward(self, hidden_states, attention_mask=None):
141151
"""

0 commit comments

Comments
 (0)