Add SOT multi-talker ASR with Whisper #6405
Add Serialized Output Training (SOT) for multi-talker ASR using native OpenAI Whisper encoder/decoder with tiktoken tokenization.
Core pipeline:
- SOTWhisperModel: extends ESPnetASRModel with min-CE loss over case variants for case-invariant training
- SOTWhisperPreprocessor: tiktoken-based tokenizer handling timestamps and speaker change tokens, with support for reusing existing BPE tokens as special tokens
- SOTBeamSearch: beam search with probability-based timestamp forcing and speaker separator preservation
- SOTConstraintScorer: enforces valid SOT output structure (timestamp pairing, non-decreasing order, separator handling)
- sot_postprocess: repetition truncation for hallucination prevention
AMI SOT recipe (egs2/ami/sot_asr1):
- Pipeline following ESPnet asr.sh convention (stages 1/5/11/12/13)
- Lhotse CutSet to Kaldi-format data preparation
- meeteval-based utterance-group cpWER evaluation
- Configs for Whisper-small (production) and Whisper-tiny (testing)
Whisper encoder/decoder fixes:
- Encoder: derive n_mels from model instead of importing removed N_MELS constant, fixing v3/turbo compatibility
- Decoder: handle positional embedding overflow for long sequences
Unit tests: 22 tests covering model, preprocessor, task, and postprocessor.
Code Review
This pull request introduces a Serialized Output Training (SOT) multi-talker ASR recipe for the AMI dataset using native OpenAI Whisper components. Key additions include a custom SOTWhisperModel with uppercase min-CE loss, a tiktoken-based preprocessor, and a SOTConstraintScorer for structured decoding. The changes also refine the Whisper decoder to handle positional embeddings for longer sequences and update the encoder to dynamically determine Mel-bin counts. Feedback suggests improving the robustness of the inference script by avoiding generic exception handling and simplifying the positional embedding logic in the decoder for better idiomaticity.
| max_pos = self.decoders.positional_embedding.size(0) | ||
| pos_len = min(tgt.size(1), max_pos) | ||
| x = ( | ||
| self.decoders.token_embedding(tgt) | ||
| + self.decoders.positional_embedding[: tgt.size(1)] | ||
| self.decoders.token_embedding(tgt[:, -pos_len:]) | ||
| + self.decoders.positional_embedding[:pos_len] | ||
| ) |
The positional embedding slice logic is slightly complex. Using tgt.size(1) directly in the slice index is safer and more idiomatic in PyTorch when handling sequence lengths, as it avoids potential off-by-one errors or unnecessary min/max operations if the sequence length is guaranteed to be within bounds.
x = (
self.decoders.token_embedding(tgt)
+ self.decoders.positional_embedding[: tgt.size(1)]
)
| except Exception as e: | ||
| logging.warning(f"Utterance {keys} failed: {e}") | ||
| hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) | ||
| results = [(" ", ["<space>"], [2], hyp)] * nbest |
Pull request overview
Adds Serialized Output Training (SOT) support for multi-talker ASR using native OpenAI Whisper components (encoder/decoder + tiktoken), plus an AMI recipe and related utilities/tests.
Changes:
- Introduces SOT-specific core components: model, tiktoken preprocessor, constraint scorer, postprocessing, and inference entrypoints.
- Extends task/model registration to support sot_whisper and adds CLI scripts for SOT training/inference.
- Adds an egs2/ami/sot_asr1 recipe with data prep, decoding, scoring utilities, and configs; plus new unit tests.
Reviewed changes
Copilot reviewed 30 out of 34 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| test/espnet2/train/test_sot_preprocessor.py | Adds unit coverage for SOT tiktoken preprocessor behaviors (prefix, timestamps, added tokens). |
| test/espnet2/tasks/test_sot_asr.py | Adds smoke tests for the new SOT task CLI/parser behaviors. |
| test/espnet2/asr/test_sot_espnet_model.py | Adds unit tests for SOT Whisper model initialization and forward/backward. |
| test/espnet2/asr/postprocess/test_sot_postprocess.py | Adds tests for repetition-truncation postprocessing utilities. |
| espnet2/train/sot_preprocessor.py | Implements tiktoken-based SOT text parsing/tokenization and token list generation. |
| espnet2/tasks/sot_asr.py | Registers a SOT-specific ASR task choosing SOT model + preprocessor by default. |
| espnet2/tasks/asr.py | Adds sot_whisper into the generic ASR model choice registry. |
| espnet2/bin/sot_train.py | Adds a SOT training entrypoint wiring to SOTASRTask. |
| espnet2/bin/sot_inference.py | Adds SOT decoding with constraints, timestamp-forcing beam search, and postprocessing. |
| espnet2/asr/sot_espnet_model.py | Adds SOTWhisperModel with optional uppercase min-CE attention loss. |
| espnet2/asr/scorers/sot_constraint_scorer.py | Adds a constraint scorer enforcing valid SOT output structure during decoding. |
| espnet2/asr/postprocess/sot_postprocess.py | Adds hallucination mitigation (repetition truncation) and SOT output reconstruction. |
| espnet2/asr/encoder/whisper_encoder.py | Fixes Whisper encoder mel-bin handling by deriving n_mels from the model. |
| espnet2/asr/decoder/whisper_decoder.py | Adds a workaround for positional embedding overflow in long sequences. |
| egs2/ami/sot_asr1/utils | Recipe linkage to template utils. |
| egs2/ami/sot_asr1/steps | Recipe linkage to template steps. |
| egs2/ami/sot_asr1/pyscripts | Recipe linkage to template pyscripts. |
| egs2/ami/sot_asr1/scripts/toy_pipeline_test.sh | Adds an end-to-end toy pipeline script exercising token list, training, and decoding. |
| egs2/ami/sot_asr1/run_decode.py | Adds a simple inference helper script bypassing the DataLoader. |
| egs2/ami/sot_asr1/run.sh | Adds the AMI SOT recipe pipeline (prep/train/decode/score). |
| egs2/ami/sot_asr1/local/prepare_sot.py | Implements Lhotse CutSet → Kaldi-format SOT data preparation. |
| egs2/ami/sot_asr1/local/generate_token_list.py | Adds a CLI wrapper for generating token_list from tiktoken. |
| egs2/ami/sot_asr1/local/generate_config_yaml.py | Adds a helper to generate an inference config.yaml without training. |
| egs2/ami/sot_asr1/local/evaluate_sot.py | Adds meeteval-based cpWER scoring for SOT outputs. |
| egs2/ami/sot_asr1/local/added_tokens.txt | Adds the recipe’s speaker-separator token list file. |
| egs2/ami/sot_asr1/conf/tuning/train_sot_tiny.yaml | Adds a tiny Whisper SOT training config for quick tests. |
| egs2/ami/sot_asr1/conf/tuning/train_sot_small.yaml | Adds a small Whisper SOT training config for recipe runs. |
| egs2/ami/sot_asr1/conf/tuning/decode_sot.yaml | Adds decoding defaults for SOT inference. |
| egs2/ami/sot_asr1/conf/slurm.conf | Adds recipe scheduler config. |
| egs2/ami/sot_asr1/conf/queue.conf | Adds recipe scheduler config. |
| egs2/ami/sot_asr1/conf/pbs.conf | Adds recipe scheduler config. |
| egs2/ami/sot_asr1/cmd.sh | Adds recipe command backend selection. |
| sym_eos: str = "<|endoftext|>", | ||
| autocast_frontend: bool = False, | ||
| extract_feats_in_collect_stats: bool = True, | ||
| lang_token_id: int = -1, |
With the default lang_token_id=-1, the condition self.lang_token_id is not None is always true, so the model prepends -1 to every target sequence. In PyTorch, -1 indexes the last embedding row, which silently corrupts training targets. Use a sentinel-aware check (e.g., self.lang_token_id != -1 / >= 0) or change the default to None so language-token prepending only happens when explicitly enabled.
| lang_token_id: int = -1, | |
| lang_token_id: Optional[int] = None, |
| if hasattr(self, "lang_token_id") and self.lang_token_id is not None: | ||
| ys_pad = torch.cat( | ||
| [ | ||
| self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device), | ||
| ys_pad, | ||
| ], | ||
| dim=1, | ||
| ) | ||
| ys_pad_lens += 1 |
With the default lang_token_id=-1, the condition self.lang_token_id is not None is always true, so the model prepends -1 to every target sequence. In PyTorch, -1 indexes the last embedding row, which silently corrupts training targets. Use a sentinel-aware check (e.g., self.lang_token_id != -1 / >= 0) or change the default to None so language-token prepending only happens when explicitly enabled.
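The sentinel-aware check described in this comment can be sketched without PyTorch. The following is a minimal illustration on plain Python lists, not the PR's implementation; the helper name prepend_lang_token is hypothetical, and the token ID in the usage note is only an example value.

```python
from typing import List, Optional


def prepend_lang_token(
    targets: List[List[int]], lang_token_id: Optional[int] = None
) -> List[List[int]]:
    """Prepend lang_token_id to every target only when it is a valid ID.

    Hypothetical helper: both None and negative sentinels such as -1
    are treated as "language-token prepending disabled".
    """
    if lang_token_id is None or lang_token_id < 0:
        # Return copies so callers never share lists with the input.
        return [list(t) for t in targets]
    return [[lang_token_id] + list(t) for t in targets]
```

With this guard, prepend_lang_token([[1, 2]], -1) leaves the targets unchanged, while an explicit non-negative ID such as 50259 is prepended to every sequence.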
| report_wer: bool = True, | ||
| sym_space: str = "<space>", | ||
| sym_blank: str = "<blank>", | ||
| transducer_multi_blank_durations: List = [], |
This introduces a mutable default argument ([]), which can be shared across instances and lead to unexpected state leakage. Use None as the default and assign an empty list inside __init__ when needed.
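The state leakage mentioned here is easy to reproduce in isolation. The sketch below uses two hypothetical classes (LeakyConfig and SafeConfig are illustration names, not classes from the PR) to contrast a mutable default with the None-based pattern the comment recommends.

```python
from typing import List, Optional


class LeakyConfig:
    # Anti-pattern: the default list is created once, at function
    # definition time, and is shared by every instance that omits it.
    def __init__(self, durations: List[int] = []):
        self.durations = durations


class SafeConfig:
    # Recommended pattern: default to None and build a fresh list
    # for each instance inside __init__.
    def __init__(self, durations: Optional[List[int]] = None):
        self.durations = [] if durations is None else list(durations)


a, b = LeakyConfig(), LeakyConfig()
a.durations.append(2)
leaked = b.durations  # b observes the value appended through a

c, d = SafeConfig(), SafeConfig()
c.durations.append(2)
isolated = d.durations  # d keeps its own empty list
```

The same reasoning applies to the transducer_multi_blank_durations argument in the diff: defaulting it to None avoids sharing one list across model instances.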
| last_was_timestamp = ( | ||
| len(current_block) >= 1 and current_block[-1] >= self.timestamp_begin | ||
| ) | ||
| penultimate_was_timestamp = ( | ||
| len(current_block) < 2 or current_block[-2] >= self.timestamp_begin | ||
| ) | ||
|
|
The scorer treats any token ID >= timestamp_begin as a timestamp. If the separator token is allocated above the timestamp range (e.g., <sc> as a newly-added token at 51865+), it will be misclassified as a timestamp and break the pairing/non-decreasing constraints. Exclude custom special tokens (e.g., self._custom_special_above_ts) from timestamp detection and the timestamps list so separators do not participate in timestamp constraints.
|
|
||
| # A4: Non-decreasing timestamps within current block | ||
| timestamps = [t for t in current_block if t >= self.timestamp_begin] | ||
| if timestamps: |
The scorer treats any token ID >= timestamp_begin as a timestamp. If the separator token is allocated above the timestamp range (e.g., <sc> as a newly-added token at 51865+), it will be misclassified as a timestamp and break the pairing/non-decreasing constraints. Exclude custom special tokens (e.g., self._custom_special_above_ts) from timestamp detection and the timestamps list so separators do not participate in timestamp constraints.
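One way to apply this suggestion is to route every timestamp test through a single predicate that excludes custom special tokens. The sketch below is an assumption-laden illustration rather than the scorer's code: the function names are hypothetical, and the IDs in the usage note (50364 for the first timestamp, 51865 for an added separator) are example values chosen to match the range mentioned in the comment.

```python
from typing import List, Set


def is_timestamp(
    token_id: int, timestamp_begin: int, custom_special_ids: Set[int]
) -> bool:
    """True only for genuine timestamp tokens.

    Custom special tokens allocated above the timestamp range
    (for example an added separator) are explicitly excluded.
    """
    return token_id >= timestamp_begin and token_id not in custom_special_ids


def timestamps_non_decreasing(
    block: List[int], timestamp_begin: int, custom_special_ids: Set[int]
) -> bool:
    """Check the non-decreasing constraint over one block of token IDs."""
    stamps = [
        t for t in block if is_timestamp(t, timestamp_begin, custom_special_ids)
    ]
    return all(a <= b for a, b in zip(stamps, stamps[1:]))
```

For example, is_timestamp(51865, 50364, {51865}) is False, whereas the unguarded comparison 51865 >= 50364 would classify the separator as a timestamp.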
| # Append added tokens | ||
| for token in extra_tokens: | ||
| lines.append(token) |
generate_token_list() appends all extra_tokens unconditionally, even when an ‘added’ token already exists as a single BPE token in the base vocab (which this preprocessor explicitly supports reusing). This can create duplicate token strings in token_list and artificially enlarge len(token_list), potentially triggering unintended decoder embedding expansion and downstream ID mismatches. Align token-list generation with the runtime mapping: only append tokens that require new IDs; for reused single-token BPE specials, do not append (or otherwise ensure the token list’s indices remain consistent with the IDs actually used).
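The alignment this comment asks for can be sketched as a membership check before appending. This is a simplified illustration: the function build_token_list is hypothetical, and the base vocabulary is modeled as a plain list of strings, whereas the real preprocessor would presumably decide reuse by checking whether a token encodes to a single BPE ID.

```python
from typing import List


def build_token_list(base_vocab: List[str], extra_tokens: List[str]) -> List[str]:
    """Append only the extra tokens that need a new ID.

    Tokens already present in the base vocabulary keep their existing
    index, so the list length grows only for genuinely new tokens.
    """
    seen = set(base_vocab)
    token_list = list(base_vocab)
    for token in extra_tokens:
        if token in seen:
            continue  # reused token: no duplicate entry, no new ID
        token_list.append(token)
        seen.add(token)
    return token_list
```

Keeping the list free of duplicates means len(token_list) matches the number of IDs actually in use, which avoids an unintended embedding expansion.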
| RECIPE_DIR=/work/nvme/bbjs/chuang14/espnet-owsm-dtai/egs2/ami/sot_asr1 | ||
| ESPNET_ROOT=/work/nvme/bbjs/chuang14/espnet-owsm-dtai | ||
| export PYTHONPATH="${ESPNET_ROOT}:${PYTHONPATH:-}" | ||
|
|
||
| # Activate conda environment | ||
| eval "$(conda shell.bash hook)" | ||
| conda activate espnet-owsm | ||
|
|
||
| cd "${RECIPE_DIR}" | ||
|
|
||
| # Source data from dicow_asr1 (already pre-segmented) | ||
| DICOW_DIR=/work/nvme/bbjs/chuang14/espnet-owsm-dtai/egs2/ami/dicow_asr1 |
This script hard-codes absolute filesystem paths and a specific conda environment, making it unusable for other users and unsuitable for inclusion in a general recipe. Convert these to relative paths (derived from the script location / pwd) and/or configurable arguments/env vars, and avoid conda activate inside the script (document environment requirements instead).
| RECIPE_DIR=/work/nvme/bbjs/chuang14/espnet-owsm-dtai/egs2/ami/sot_asr1 | |
| ESPNET_ROOT=/work/nvme/bbjs/chuang14/espnet-owsm-dtai | |
| export PYTHONPATH="${ESPNET_ROOT}:${PYTHONPATH:-}" | |
| # Activate conda environment | |
| eval "$(conda shell.bash hook)" | |
| conda activate espnet-owsm | |
| cd "${RECIPE_DIR}" | |
| # Source data from dicow_asr1 (already pre-segmented) | |
| DICOW_DIR=/work/nvme/bbjs/chuang14/espnet-owsm-dtai/egs2/ami/dicow_asr1 | |
| # Determine script location and default recipe/repo roots. | |
| SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" | |
| RECIPE_DIR="${RECIPE_DIR:-$(cd "${SCRIPT_DIR}/.." && pwd)}" | |
| ESPNET_ROOT="${ESPNET_ROOT:-$(cd "${RECIPE_DIR}/../../.." && pwd)}" | |
| export PYTHONPATH="${ESPNET_ROOT}:${PYTHONPATH:-}" | |
| # NOTE: This script assumes that a suitable Python/ESPnet environment | |
| # (e.g., the "espnet-owsm" conda environment) is already activated | |
| # before invoking this script. Environment activation is intentionally | |
| # not performed here to keep the script portable. | |
| cd "${RECIPE_DIR}" | |
| # Source data from dicow_asr1 (already pre-segmented) | |
| DICOW_DIR="${DICOW_DIR:-${ESPNET_ROOT}/egs2/ami/dicow_asr1}" |
| RECIPE_DIR=/work/nvme/bbjs/chuang14/espnet-owsm-dtai/egs2/ami/sot_asr1 | ||
| ESPNET_ROOT=/work/nvme/bbjs/chuang14/espnet-owsm-dtai | ||
| export PYTHONPATH="${ESPNET_ROOT}:${PYTHONPATH:-}" | ||
|
|
||
| # Activate conda environment | ||
| eval "$(conda shell.bash hook)" | ||
| conda activate espnet-owsm | ||
|
|
||
| cd "${RECIPE_DIR}" | ||
|
|
||
| # Source data from dicow_asr1 (already pre-segmented) | ||
| DICOW_DIR=/work/nvme/bbjs/chuang14/espnet-owsm-dtai/egs2/ami/dicow_asr1 |
This script hard-codes absolute filesystem paths and a specific conda environment, making it unusable for other users and unsuitable for inclusion in a general recipe. Convert these to relative paths (derived from the script location / pwd) and/or configurable arguments/env vars, and avoid conda activate inside the script (document environment requirements instead).
| RECIPE_DIR=/work/nvme/bbjs/chuang14/espnet-owsm-dtai/egs2/ami/sot_asr1 | |
| ESPNET_ROOT=/work/nvme/bbjs/chuang14/espnet-owsm-dtai | |
| export PYTHONPATH="${ESPNET_ROOT}:${PYTHONPATH:-}" | |
| # Activate conda environment | |
| eval "$(conda shell.bash hook)" | |
| conda activate espnet-owsm | |
| cd "${RECIPE_DIR}" | |
| # Source data from dicow_asr1 (already pre-segmented) | |
| DICOW_DIR=/work/nvme/bbjs/chuang14/espnet-owsm-dtai/egs2/ami/dicow_asr1 | |
| # Base directories: derive from this script location by default, but allow | |
| # overriding via environment variables (RECIPE_DIR, ESPNET_ROOT, DICOW_DIR). | |
| SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-${0}}")" && pwd)" | |
| RECIPE_DIR="${RECIPE_DIR:-$(cd "${SCRIPT_DIR}/.." && pwd)}" | |
| ESPNET_ROOT="${ESPNET_ROOT:-$(cd "${RECIPE_DIR}/../../.." && pwd)}" | |
| export PYTHONPATH="${ESPNET_ROOT}:${PYTHONPATH:-}" | |
| # NOTE: This script assumes that an appropriate Python environment | |
| # (e.g., the espnet-owsm conda environment) is already activated | |
| # before running. Do not run `conda activate` inside this script. | |
| cd "${RECIPE_DIR}" | |
| # Source data from dicow_asr1 (already pre-segmented) | |
| DICOW_DIR="${DICOW_DIR:-$(cd "${RECIPE_DIR}/.." && pwd)/dicow_asr1}" |
| sys.path.insert(0, "/work/nvme/bbjs/chuang14/mtasr/TS-ASR-Whisper/src") | ||
| from txt_norm import get_text_norm | ||
|
|
||
| text_norm = get_text_norm(args.text_norm) | ||
| logger.info(f"Text normalizer: {args.text_norm}") | ||
|
|
This introduces a hard-coded local path dependency outside the repository, which will fail for other environments. If text normalization is required, it should be implemented within the recipe/repo (or declared as an installable dependency), and the import path should be resolved via standard Python packaging rather than an absolute sys.path insertion.
| sys.path.insert(0, "/work/nvme/bbjs/chuang14/mtasr/TS-ASR-Whisper/src") | |
| from txt_norm import get_text_norm | |
| text_norm = get_text_norm(args.text_norm) | |
| logger.info(f"Text normalizer: {args.text_norm}") | |
| try: | |
| from txt_norm import get_text_norm | |
| except ImportError as e: | |
| logger.error( | |
| "Requested text normalizer '%s' but the 'txt_norm' package " | |
| "is not available: %s", | |
| args.text_norm, | |
| e, | |
| ) | |
| logger.error( | |
| "Install 'txt_norm' as a dependency or adjust PYTHONPATH, " | |
| "or rerun with '--text_norm none' to disable normalization." | |
| ) | |
| else: | |
| text_norm = get_text_norm(args.text_norm) | |
| logger.info(f"Text normalizer: {args.text_norm}") |
| --*) name=$(echo "$1" | sed 's/^--//' | sed 's/-/_/g') | ||
| eval "${name}=\"$2\"" | ||
| shift 2 ;; | ||
| *) break ;; |
Using eval for option parsing allows command injection via option values (e.g., values containing command substitutions). Since this is a user-facing recipe script, prefer sourcing the standard utils/parse_options.sh (the repo already provides utils/) or implement parsing without eval.
| --*) name=$(echo "$1" | sed 's/^--//' | sed 's/-/_/g') | |
| eval "${name}=\"$2\"" | |
| shift 2 ;; | |
| *) break ;; | |
| --*) | |
| name=$(echo "$1" | sed 's/^--//' | sed 's/-/_/g') | |
| # Ensure the option name maps to a safe variable identifier | |
| if ! [[ "${name}" =~ ^[A-Za-z0-9_]+$ ]]; then | |
| log "ERROR: Invalid option name '${1}'" | |
| exit 1 | |
| fi | |
| # Ensure there is a value following the option | |
| if [ $# -lt 2 ]; then | |
| log "ERROR: Option '${1}' requires an argument" | |
| exit 1 | |
| fi | |
| # Safely assign the value to the variable without using eval | |
| printf -v "${name}" '%s' "$2" | |
| shift 2 | |
| ;; | |
| *) | |
| break ;; |
@cyhuang-tw, can you tell me why you need such a large change from https://github.com/espnet/espnet/tree/master/egs2/librimix/sot_asr1?
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@ Coverage Diff @@
## master #6405 +/- ##
===========================================
+ Coverage 56.24% 70.16% +13.91%
===========================================
Files 897 787 -110
Lines 84919 73371 -11548
===========================================
+ Hits 47763 51480 +3717
+ Misses 37156 21891 -15265
@cyhuang-tw, could you respond to my question?
I am currently working on reusing existing files to reduce redundancy and verify functionality. I have identified several files that can be removed. I will update the PR once I finish these changes.
Add a Serialized Output Training (SOT) recipe for multi-talker ASR on the AMI meeting corpus using Whisper encoder/decoder.
Recipe (egs2/ami/sot_asr1/):
- Follows the same asr.sh-based pattern as egs2/librimix/sot_asr1
- Data preparation and utterance-group cpWER evaluation via meeteval
- Whisper-small training config
Whisper timestamp support (espnet2/):
- Add predict_timestamps option to OpenAIWhisperTokenIDConverter to omit <|notimestamps|> from the decoder prefix, enabling Whisper SOT training with timestamp prediction
- Thread predict_timestamps through CommonPreprocessor_multi
db65e74 to 426e7cc
@cyhuang-tw, any update?
Thank you for following up. I have significantly reduced the commit size by removing parts that can be replaced by existing functions. The commit is now mostly related to local recipe files. The only change I added to the core
Please add results and upload models
How about making a new PR?
Thanks for the suggestion. I will make a new PR and include results with models.
@cyhuang-tw, this is a reminder
What did you change?
Add Serialized Output Training (SOT) for multi-talker ASR using Whisper encoder/decoder with tiktoken tokenization.
Core pipeline (espnet2/):
- SOTWhisperModel: extends ESPnetASRModel for SOT training
- SOTWhisperPreprocessor: tiktoken-based tokenizer handling timestamps and speaker change tokens, with support for reusing existing BPE tokens as special tokens
- SOTBeamSearch: beam search with probability-based timestamp forcing and speaker separator preservation
- SOTConstraintScorer: enforces valid SOT output structure (timestamp pairing, non-decreasing order, separator handling)
- sot_postprocess: repetition truncation for hallucination prevention
- SOTASRTask: task registration with native Whisper encoder/decoder
AMI SOT recipe (egs2/ami/sot_asr1/):
- Pipeline following ESPnet asr.sh convention (stages 1/5/11/12/13)
Whisper encoder/decoder fixes:
- Encoder: derive n_mels from model instead of importing removed N_MELS constant, fixing v3/turbo compatibility (128 vs 80 mel bins)
Unit tests: 22 tests covering model, preprocessor, task, and postprocessor.
Why did you make this change?
SOT (Serialized Output Training) enables multi-talker ASR by serializing multiple speakers' transcripts into a single output sequence with speaker change tokens. This approach allows standard encoder-decoder models like Whisper to handle multi-talker speech without requiring separate diarization.
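As a rough illustration of the serialization idea, the sketch below joins utterances from several speakers into one target string. It rests on assumptions not stated in this description: utterances are ordered by start time (a common SOT convention), the separator is written as <sc>, and the helper serialize_sot is a hypothetical name. The recipe's actual target format, including any timestamp tokens, may differ.

```python
from typing import List, Tuple

# (start time in seconds, transcript); an assumed, simplified layout.
Utterance = Tuple[float, str]


def serialize_sot(utterances: List[Utterance], sep: str = "<sc>") -> str:
    """Serialize overlapping utterances into a single target string.

    Utterances are sorted by start time and joined with a speaker
    change token, so one decoder output covers all talkers.
    """
    ordered = sorted(utterances, key=lambda u: u[0])
    return f" {sep} ".join(text for _, text in ordered)


example = serialize_sot([(2.5, "i agree"), (0.0, "hello everyone")])
```

Here example holds the two transcripts in start-time order with a single <sc> between them, which is the kind of sequence a standard encoder-decoder can be trained to emit.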
Is your PR small enough?
This PR adds 34 files with ~3,800 lines. While above the typical guideline, this is a new recipe with a new task type that includes core pipeline components, an AMI recipe, and comprehensive unit tests. The components are tightly coupled (the recipe depends on the task, model, preprocessor, inference, and scorer), making it impractical to split without creating circular dependencies between PRs.
Additional Context