
Conversation

@misrasaurabh1 (Contributor) commented Jul 3, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) Any necessary documentation updates, such as updating supported_models.md and examples for a new model.

Purpose

Incremental decoding can be slow, according to an existing comment on this code section. I optimized it with Codeflash (the automated performance-optimization tool I'm building) and found this improvement. I verified it with the benchmark below and generated many tests to ensure correctness.

⏱️ Runtime: 0.28 seconds → 0.019 seconds (per 1000 loops)

The main optimization is to avoid calling tokenizer.get_added_vocab() inside the loop, since line profiling showed that call accounted for a large share of the runtime. tokenizer.all_special_tokens is now converted to a set only when it is actually needed (skip_special_tokens=True), and tokenizer.convert_tokens_to_string is bound to a local variable to reduce attribute-lookup overhead in the hot loop.
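For illustration, here is a minimal sketch of the optimized structure. This is my paraphrase of the change, not the exact merged code; the real function lives in vllm/transformers_utils/detokenizer_utils.py, and the signature below simply mirrors how the function is called in the benchmark script further down.

def _convert_tokens_to_string_with_added_encoders(
    tokenizer,
    output_tokens,
    skip_special_tokens,
    spaces_between_special_tokens,
):
    sub_texts = []
    current_sub_text = []
    # Hoisted out of the loop: previously tokenizer.get_added_vocab() was called
    # once per token, which line profiling showed dominated the runtime.
    added_vocab = set(tokenizer.get_added_vocab())
    # Build the special-token set only when it is actually needed; with
    # skip_special_tokens=False an empty tuple is used, so the membership test
    # below always fails cheaply.
    special_tokens = set(tokenizer.all_special_tokens) if skip_special_tokens else ()
    # Localize the bound method to avoid repeated attribute lookups in the hot loop.
    convert_tokens_to_string = tokenizer.convert_tokens_to_string
    for token in output_tokens:
        if token in special_tokens:
            continue
        if token in added_vocab:
            if current_sub_text:
                sub_texts.append(convert_tokens_to_string(current_sub_text))
                current_sub_text.clear()
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_texts.append(convert_tokens_to_string(current_sub_text))
    if spaces_between_special_tokens:
        return " ".join(sub_texts)
    return "".join(sub_texts)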

Test Plan

Benchmarking script I wrote manually to verify the performance gains; the original implementation is kept alongside as the _slow function -

from vllm.transformers_utils.detokenizer_utils import _convert_tokens_to_string_with_added_encoders, _convert_tokens_to_string_with_added_encoders_slow
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")

tokens = tokenizer.tokenize("Hello world! Isn't this a beautiful day? I'm so glad I'm alive! I like to go for a walk in the park.")

import timeit

time_taken = timeit.timeit(lambda: _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=False), number=1000)
print(f"Fast skip_special_tokens=False Average time taken: {time_taken:.6f}")

time_taken = timeit.timeit(lambda: _convert_tokens_to_string_with_added_encoders_slow(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=False), number=1000)
print(f"Slow skip_special_tokens=False Average time taken: {time_taken:.6f}")

time_taken = timeit.timeit(lambda: _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=True, spaces_between_special_tokens=False), number=1000)
print(f"Fast skip_special_tokens=True Average time taken: {time_taken:.6f}")

time_taken = timeit.timeit(lambda: _convert_tokens_to_string_with_added_encoders_slow(tokenizer, tokens, skip_special_tokens=True, spaces_between_special_tokens=False), number=1000)
print(f"Slow skip_special_tokens=True Average time taken: {time_taken:.6f}")

I also verified correctness by generating and running the regression tests below. The behavior of the two functions was exactly the same.

🌀 Generated Regression Tests Summary
import pytest
from vllm.transformers_utils.detokenizer_utils import \
    _convert_tokens_to_string_with_added_encoders

# --- Minimal mock tokenizer for testing ---

class DummyTokenizer:
    """
    A minimal tokenizer mock to support the required interface for _convert_tokens_to_string_with_added_encoders.
    Allows configuration of added vocab and special tokens.
    """
    def __init__(self, added_vocab=None, special_tokens=None, join_with=" "):
        # added_vocab: dict mapping token string to index (e.g. {"<foo>": 1000})
        # special_tokens: iterable of special token strings (e.g. ["<s>", "</s>"])
        self._added_vocab = added_vocab or {}
        self.all_special_tokens = set(special_tokens or [])
        self._join_with = join_with

    def get_added_vocab(self):
        return self._added_vocab

    def convert_tokens_to_string(self, tokens):
        # Joins tokens with the configured separator
        return self._join_with.join(tokens)
from vllm.transformers_utils.detokenizer_utils import \
    _convert_tokens_to_string_with_added_encoders

# --- Unit tests ---

# -------------------- Basic Test Cases --------------------

def test_basic_no_added_no_special_tokens():
    # No added vocab, no special tokens, basic join
    tokenizer = DummyTokenizer()
    tokens = ["hello", "world"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 2.75μs -> 2.91μs (5.47% slower)

def test_basic_with_added_token():
    # Added vocab in the middle, should appear as-is, with correct join
    tokenizer = DummyTokenizer(added_vocab={"<foo>": 100})
    tokens = ["hello", "<foo>", "world"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 3.37μs -> 3.79μs (11.1% slower)

def test_basic_with_special_tokens_skip_false():
    # Special tokens present, skip_special_tokens=False
    tokenizer = DummyTokenizer(special_tokens=["<s>", "</s>"])
    tokens = ["<s>", "hello", "world", "</s>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 3.07μs -> 2.97μs (3.40% faster)

def test_basic_with_special_tokens_skip_true():
    # Special tokens present, skip_special_tokens=True
    tokenizer = DummyTokenizer(special_tokens=["<s>", "</s>"])
    tokens = ["<s>", "hello", "world", "</s>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, True, False); result = codeflash_output # 2.80μs -> 3.24μs (13.6% slower)

def test_basic_added_and_special_tokens():
    # Both added vocab and special tokens
    tokenizer = DummyTokenizer(added_vocab={"<foo>": 1}, special_tokens=["<s>", "</s>"])
    tokens = ["<s>", "hello", "<foo>", "world", "</s>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 4.05μs -> 4.24μs (4.48% slower)

def test_basic_added_and_special_tokens_skip_special():
    # Both added vocab and special tokens, skip_special_tokens=True
    tokenizer = DummyTokenizer(added_vocab={"<foo>": 1}, special_tokens=["<s>", "</s>"])
    tokens = ["<s>", "hello", "<foo>", "world", "</s>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, True, False); result = codeflash_output # 3.26μs -> 3.96μs (17.7% slower)

def test_basic_spaces_between_special_tokens():
    # spaces_between_special_tokens=True should join sub_texts with spaces
    tokenizer = DummyTokenizer(added_vocab={"<foo>": 1}, special_tokens=["<s>", "</s>"])
    tokens = ["<s>", "hello", "<foo>", "world", "</s>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, True); result = codeflash_output # 3.04μs -> 4.39μs (30.8% slower)

def test_basic_spaces_between_special_tokens_skip_special():
    # spaces_between_special_tokens=True and skip_special_tokens=True
    tokenizer = DummyTokenizer(added_vocab={"<foo>": 1}, special_tokens=["<s>", "</s>"])
    tokens = ["<s>", "hello", "<foo>", "world", "</s>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, True, True); result = codeflash_output # 3.27μs -> 4.32μs (24.3% slower)

def test_basic_multiple_added_tokens():
    # Multiple added tokens, should split correctly
    tokenizer = DummyTokenizer(added_vocab={"<foo>": 1, "<bar>": 2})
    tokens = ["a", "<foo>", "b", "<bar>", "c"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 3.14μs -> 3.79μs (17.2% slower)

def test_basic_multiple_added_tokens_with_spaces():
    # Multiple added tokens, spaces_between_special_tokens=True
    tokenizer = DummyTokenizer(added_vocab={"<foo>": 1, "<bar>": 2})
    tokens = ["a", "<foo>", "b", "<bar>", "c"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, True); result = codeflash_output # 3.47μs -> 4.36μs (20.4% slower)

# -------------------- Edge Test Cases --------------------

def test_edge_empty_tokens_list():
    # Empty tokens list should return empty string
    tokenizer = DummyTokenizer()
    tokens = []
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 1.42μs -> 2.19μs (35.2% slower)

def test_edge_all_special_tokens_skipped():
    # All tokens are special and skip_special_tokens=True
    tokenizer = DummyTokenizer(special_tokens=["<s>", "</s>"])
    tokens = ["<s>", "</s>", "<s>", "</s>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, True, False); result = codeflash_output # 1.62μs -> 2.62μs (38.2% slower)

def test_edge_all_added_tokens():
    # All tokens are added tokens
    tokenizer = DummyTokenizer(added_vocab={"<foo>": 1, "<bar>": 2})
    tokens = ["<foo>", "<bar>", "<foo>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 2.20μs -> 3.00μs (26.6% slower)

def test_edge_added_and_special_overlap():
    # Token is both in added_vocab and special_tokens
    tokenizer = DummyTokenizer(added_vocab={"<foo>": 1}, special_tokens=["<foo>"])
    tokens = ["<foo>", "a", "<foo>"]
    # skip_special_tokens=True, should skip <foo> even if it's in added_vocab
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, True, False); result = codeflash_output # 2.35μs -> 3.09μs (23.9% slower)
    # skip_special_tokens=False, should treat <foo> as added token
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result2 = codeflash_output # 1.44μs -> 1.77μs (18.5% slower)

def test_edge_custom_joiner():
    # Tokenizer joins with a custom character
    tokenizer = DummyTokenizer(join_with="-")
    tokens = ["a", "b", "c"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 2.45μs -> 3.07μs (20.2% slower)

def test_edge_consecutive_added_tokens():
    # Consecutive added tokens should not produce empty sub_texts
    tokenizer = DummyTokenizer(added_vocab={"<foo>": 1, "<bar>": 2})
    tokens = ["a", "<foo>", "<bar>", "b"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 2.97μs -> 3.96μs (25.0% slower)

def test_edge_added_token_at_start_and_end():
    tokenizer = DummyTokenizer(added_vocab={"<foo>": 1})
    tokens = ["<foo>", "a", "b", "<foo>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 2.88μs -> 3.57μs (19.3% slower)

def test_edge_tokenizer_with_no_added_vocab_method():
    # Tokenizer without get_added_vocab method should raise AttributeError
    class IncompleteTokenizer:
        def __init__(self):
            self.all_special_tokens = set()
        def convert_tokens_to_string(self, tokens):
            return " ".join(tokens)
    tokenizer = IncompleteTokenizer()
    tokens = ["a", "b"]
    with pytest.raises(AttributeError):
        _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False)


def test_edge_added_token_is_empty_string():
    # Added token is empty string, should be skipped in output
    tokenizer = DummyTokenizer(added_vocab={"": 1})
    tokens = ["a", "", "b"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 2.97μs -> 3.48μs (14.7% slower)

# -------------------- Large Scale Test Cases --------------------

def test_large_many_tokens_no_added():
    # 1000 tokens, no added vocab
    tokenizer = DummyTokenizer()
    tokens = [f"tok{i}" for i in range(1000)]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 72.2μs -> 45.7μs (58.0% faster)

def test_large_many_tokens_with_added_every_100():
    # 1000 tokens, every 100th is an added token
    added = {f"<added{i}>": i for i in range(10)}
    tokenizer = DummyTokenizer(added_vocab=added)
    tokens = []
    for i in range(1000):
        if i % 100 == 0:
            tokens.append(f"<added{i//100}>")
        else:
            tokens.append(f"tok{i}")
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output
    # Should split at each added token
    # For i in 0..9: the segment after <added{i}> is 99 tokens, joined with space
    expected = []
    idx = 0
    while idx < 1000:
        # added token
        expected.append(tokens[idx])
        # next 99 tokens
        if idx + 1 < 1000:
            segment = tokens[idx+1:idx+100]
            if segment:
                expected.append(" ".join(segment))
        idx += 100

def test_large_all_added_tokens():
    # 1000 tokens, all added
    added = {f"<added{i}>": i for i in range(1000)}
    tokenizer = DummyTokenizer(added_vocab=added)
    tokens = [f"<added{i}>" for i in range(1000)]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False); result = codeflash_output # 91.0μs -> 85.0μs (7.07% faster)

def test_large_skip_special_tokens():
    # 1000 tokens, every 10th is special, skip_special_tokens=True
    special = {f"<spec{i}>": i for i in range(100)}
    tokenizer = DummyTokenizer(special_tokens=list(special.keys()))
    tokens = []
    for i in range(1000):
        if i % 10 == 0:
            tokens.append(f"<spec{i//10}>")
        else:
            tokens.append(f"tok{i}")
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, True, False); result = codeflash_output
    # Should skip all special tokens
    filtered = [t for t in tokens if t not in special]

def test_large_spaces_between_special_tokens():
    # 1000 tokens, every 100th is an added token, spaces_between_special_tokens=True
    added = {f"<added{i}>": i for i in range(10)}
    tokenizer = DummyTokenizer(added_vocab=added)
    tokens = []
    for i in range(1000):
        if i % 100 == 0:
            tokens.append(f"<added{i//100}>")
        else:
            tokens.append(f"tok{i}")
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, True); result = codeflash_output
    # Should split at each added token, join sub_texts with spaces
    expected = []
    idx = 0
    while idx < 1000:
        expected.append(tokens[idx])
        if idx + 1 < 1000:
            segment = tokens[idx+1:idx+100]
            if segment:
                expected.append(" ".join(segment))
        idx += 100

def test_large_alternating_added_and_special():
    # 500 tokens, alternating added and special tokens
    added = {f"<added{i}>": i for i in range(250)}
    special = [f"<spec{i}>" for i in range(250)]
    tokenizer = DummyTokenizer(added_vocab=added, special_tokens=special)
    tokens = []
    for i in range(250):
        tokens.append(f"<added{i}>")
        tokens.append(f"<spec{i}>")
    # skip_special_tokens=True should remove all special tokens
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, True, False); result = codeflash_output
    # Only added tokens remain
    expected = "".join([f"<added{i}>" for i in range(250)])

def test_large_performance():
    # This test is to check that the function completes in reasonable time for 1000 tokens
    import time
    tokenizer = DummyTokenizer(added_vocab={f"<added{i}>": i for i in range(10)})
    tokens = []
    for i in range(1000):
        if i % 100 == 0:
            tokens.append(f"<added{i//100}>")
        else:
            tokens.append(f"tok{i}")
    start = time.time()
    _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, False, False)
    duration = time.time() - start
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

import pytest
from vllm.transformers_utils.detokenizer_utils import \
    _convert_tokens_to_string_with_added_encoders

# function to test (already provided above, not repeated here for clarity)

# --- Minimal mock tokenizer for testing ---

class DummyTokenizer:
    """
    Minimal mock tokenizer for testing _convert_tokens_to_string_with_added_encoders.
    Allows configuring:
      - all_special_tokens: set of special tokens
      - added_vocab: dict mapping added token string to index
      - convert_tokens_to_string: method to join tokens (can be overridden)
    """
    def __init__(self, all_special_tokens=None, added_vocab=None, join_with=" "):
        self.all_special_tokens = set(all_special_tokens) if all_special_tokens else set()
        self._added_vocab = dict(added_vocab) if added_vocab else {}
        self._join_with = join_with

    def get_added_vocab(self):
        return dict(self._added_vocab)

    def convert_tokens_to_string(self, tokens):
        # Join tokens with the configured separator
        return self._join_with.join(tokens)

# --- Unit Tests ---

# 1. BASIC TEST CASES

def test_basic_no_added_tokens_no_special_tokens():
    # No added tokens, no special tokens, just join tokens normally
    tokenizer = DummyTokenizer()
    tokens = ["hello", "world"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output # 3.79μs -> 4.31μs (12.1% slower)

def test_basic_with_added_token():
    # Added token in the middle, should be split out
    tokenizer = DummyTokenizer(added_vocab={"<NEW>": 100})
    tokens = ["hello", "<NEW>", "world"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output # 4.21μs -> 4.51μs (6.63% slower)

def test_basic_with_multiple_added_tokens():
    # Multiple added tokens, should split each out
    tokenizer = DummyTokenizer(added_vocab={"<A>": 1, "<B>": 2})
    tokens = ["<A>", "foo", "<B>", "bar", "<A>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output # 5.06μs -> 4.06μs (24.6% faster)

def test_basic_skip_special_tokens():
    # Should skip special tokens if skip_special_tokens=True
    tokenizer = DummyTokenizer(all_special_tokens={"<SEP>", "<PAD>"})
    tokens = ["hello", "<SEP>", "world", "<PAD>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=True, spaces_between_special_tokens=True); result = codeflash_output # 3.64μs -> 4.48μs (18.8% slower)

def test_basic_spaces_between_special_tokens_false():
    # Should join with no spaces if spaces_between_special_tokens=False
    tokenizer = DummyTokenizer(added_vocab={"<A>": 1, "<B>": 2})
    tokens = ["foo", "<A>", "bar", "<B>", "baz"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=False); result = codeflash_output # 4.87μs -> 4.84μs (0.620% faster)

def test_basic_added_token_at_edges():
    # Added token at start and end
    tokenizer = DummyTokenizer(added_vocab={"<X>": 1})
    tokens = ["<X>", "foo", "bar", "<X>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output # 4.45μs -> 4.26μs (4.46% faster)

def test_basic_added_token_repeated():
    # Repeated added tokens should always be split out
    tokenizer = DummyTokenizer(added_vocab={"<A>": 1})
    tokens = ["foo", "<A>", "<A>", "bar"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output # 4.42μs -> 4.39μs (0.683% faster)

# 2. EDGE TEST CASES

def test_empty_token_list():
    # Empty input should return empty string
    tokenizer = DummyTokenizer()
    tokens = []
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output # 1.73μs -> 2.84μs (39.1% slower)

def test_all_added_tokens():
    # All tokens are added tokens
    tokenizer = DummyTokenizer(added_vocab={"<A>": 1, "<B>": 2})
    tokens = ["<A>", "<B>", "<A>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output # 3.36μs -> 3.57μs (5.88% slower)

def test_all_special_tokens_skip_all():
    # All tokens are special, skip_special_tokens=True, should return empty string
    tokenizer = DummyTokenizer(all_special_tokens={"<SEP>", "<PAD>"})
    tokens = ["<SEP>", "<PAD>", "<SEP>"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=True, spaces_between_special_tokens=True); result = codeflash_output # 1.78μs -> 3.34μs (46.7% slower)

def test_special_and_added_tokens_overlap():
    # Token is both special and added, should be treated as added token (not skipped)
    tokenizer = DummyTokenizer(all_special_tokens={"<A>"}, added_vocab={"<A>": 1})
    tokens = ["foo", "<A>", "bar"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=True, spaces_between_special_tokens=True); result = codeflash_output # 3.82μs -> 4.23μs (9.69% slower)

def test_added_token_is_empty_string():
    # Added token is empty string (should be ignored as token)
    tokenizer = DummyTokenizer(added_vocab={"": 1})
    tokens = ["foo", "", "bar"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output # 4.20μs -> 4.04μs (3.96% faster)

def test_tokenizer_with_custom_join():
    # Tokenizer joins with underscores
    tokenizer = DummyTokenizer(join_with="_")
    tokens = ["a", "b", "c"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output # 3.70μs -> 3.95μs (6.30% slower)

def test_added_token_adjacent_to_normal():
    # Added token adjacent to normal tokens
    tokenizer = DummyTokenizer(added_vocab={"<X>": 1})
    tokens = ["foo", "<X>", "bar"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output # 4.37μs -> 4.22μs (3.58% faster)

def test_multiple_consecutive_added_tokens():
    tokenizer = DummyTokenizer(added_vocab={"<A>": 1, "<B>": 2})
    tokens = ["foo", "<A>", "<B>", "bar"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output # 4.50μs -> 4.46μs (0.897% faster)

def test_skip_special_tokens_and_added_tokens():
    # Added tokens are not skipped, only special tokens are
    tokenizer = DummyTokenizer(all_special_tokens={"<S>"}, added_vocab={"<A>": 1})
    tokens = ["foo", "<A>", "<S>", "bar"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=True, spaces_between_special_tokens=True); result = codeflash_output # 4.48μs -> 4.37μs (2.52% faster)

def test_no_space_between_special_tokens_false_and_true():
    # Test both values for spaces_between_special_tokens
    tokenizer = DummyTokenizer(added_vocab={"<A>": 1})
    tokens = ["foo", "<A>", "bar"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result1 = codeflash_output # 4.24μs -> 4.05μs (4.69% faster)
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=False); result2 = codeflash_output # 1.88μs -> 1.64μs (14.6% faster)

def test_skip_special_tokens_with_added_token_that_is_special():
    # If a token is both added and special, added wins (not skipped)
    tokenizer = DummyTokenizer(all_special_tokens={"<A>"}, added_vocab={"<A>": 1})
    tokens = ["foo", "<A>", "bar"]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=True, spaces_between_special_tokens=True); result = codeflash_output # 3.33μs -> 4.14μs (19.6% slower)

# 3. LARGE SCALE TEST CASES

def test_large_input_no_added_tokens():
    # Large input, no added tokens, should just join all
    tokenizer = DummyTokenizer()
    tokens = [f"tok{i}" for i in range(1000)]
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output # 102μs -> 45.8μs (123% faster)

def test_large_input_with_added_tokens_every_10():
    # Added token every 10th position
    added_token = "<ADDED>"
    tokenizer = DummyTokenizer(added_vocab={added_token: 1})
    tokens = []
    expected = []
    for i in range(1000):
        if i % 10 == 0:
            tokens.append(added_token)
            expected.append(added_token)
        else:
            tokens.append(f"tok{i}")
    # Build expected output
    # For every sequence of normal tokens between added tokens, join with space
    sub_texts = []
    curr = []
    for t in tokens:
        if t == added_token:
            if curr:
                sub_texts.append(" ".join(curr))
                curr = []
            sub_texts.append(added_token)
        else:
            curr.append(t)
    if curr:
        sub_texts.append(" ".join(curr))
    expected_str = " ".join(sub_texts)
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=True); result = codeflash_output

def test_large_input_with_special_tokens_skipped():
    # Large input, every 5th token is special, skip them
    special_token = "<SPECIAL>"
    tokenizer = DummyTokenizer(all_special_tokens={special_token})
    tokens = []
    expected = []
    for i in range(1000):
        if i % 5 == 0:
            tokens.append(special_token)
        else:
            tokens.append(f"tok{i}")
            expected.append(f"tok{i}")
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=True, spaces_between_special_tokens=True); result = codeflash_output

def test_large_input_with_added_and_special_tokens():
    # Large input, both added and special tokens
    added_token = "<ADDED>"
    special_token = "<SPECIAL>"
    tokenizer = DummyTokenizer(all_special_tokens={special_token}, added_vocab={added_token: 1})
    tokens = []
    expected_sub_texts = []
    curr = []
    for i in range(1000):
        if i % 13 == 0:
            tokens.append(added_token)
            if curr:
                expected_sub_texts.append(" ".join(curr))
                curr = []
            expected_sub_texts.append(added_token)
        elif i % 7 == 0:
            tokens.append(special_token)
            # skip_special_tokens=True, so skip in output
        else:
            tokens.append(f"tok{i}")
            curr.append(f"tok{i}")
    if curr:
        expected_sub_texts.append(" ".join(curr))
    expected_str = " ".join(expected_sub_texts)
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=True, spaces_between_special_tokens=True); result = codeflash_output

def test_large_input_no_spaces_between_special_tokens():
    # Large input, spaces_between_special_tokens=False
    added_token = "<ADDED>"
    tokenizer = DummyTokenizer(added_vocab={added_token: 1})
    tokens = []
    for i in range(1000):
        if i % 100 == 0:
            tokens.append(added_token)
        else:
            tokens.append(f"tok{i}")
    # Build expected output
    sub_texts = []
    curr = []
    for t in tokens:
        if t == added_token:
            if curr:
                sub_texts.append(" ".join(curr))
                curr = []
            sub_texts.append(added_token)
        else:
            curr.append(t)
    if curr:
        sub_texts.append(" ".join(curr))
    expected_str = "".join(sub_texts)
    codeflash_output = _convert_tokens_to_string_with_added_encoders(tokenizer, tokens, skip_special_tokens=False, spaces_between_special_tokens=False); result = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

Test Result

Benchmarking Results -

Fast skip_special_tokens=False Average time taken: 0.019143
Slow skip_special_tokens=False Average time taken: 0.281349
Fast skip_special_tokens=True Average time taken: 0.019412
Slow skip_special_tokens=True Average time taken: 0.280456
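
As a quick sanity check, the ratio of the slow to the fast timings above comes out to roughly 14.7x on this particular microbenchmark (the exact speedup will vary with the tokenizer and input used):

slow, fast = 0.281349, 0.019143
print(f"speedup: {slow / fast:.1f}x")  # ~14.7x for the numbers above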

Correctness verification report:

Test                             Status
🌀 Generated Regression Tests    51 Passed
📊 Tests Coverage                100.0%

(Optional) Documentation Update

…by 60%

Here is an optimized rewrite of your program focusing mainly on _convert_tokens_to_string_with_added_encoders, which (according to your profiling) is the bottleneck.
- Cache repeated lookups (such as tokenizer.get_added_vocab() and tokenizer.all_special_tokens).
- Avoid unnecessary list[str] allocations inside hot loops.
- Use local variables for repeated attribute lookups.
- Precompute set of added tokens (O(1) lookups), and convert all_special_tokens to a set only once.
- Convert current_sub_text to string only when strictly necessary.
- Append/clear efficiently.

Other parts (like class PreTrainedTokenizer and TokenizerBase) are left unchanged, as they're not relevant for performance.  
All docstrings and copyright notices preserved.



**Summary of optimizations:**
- Compute `set(tokenizer.all_special_tokens)` and `set(tokenizer.get_added_vocab())` once before the loop, **not** inside.
- Use `current_sub_text.clear()` instead of allocating a new list.
- Store `convert_tokens_to_string` to a local variable to avoid repeated attribute lookups.
- Avoid unnecessary creation of empty sets or dicts.
- Use an empty tuple `()` instead of a set when skip_special_tokens=False, avoiding the set construction (the membership test against an empty tuple always fails cheaply).

This will provide a meaningful speedup for the profiled bottleneck.
github-actions bot commented Jul 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot left a comment

Summary of Changes

Hello @misrasaurabh1, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a substantial performance optimization to the _convert_tokens_to_string_with_added_encoders function, which is a critical component in the detokenization process. The changes primarily focus on reducing redundant attribute lookups and pre-computing frequently accessed tokenizer properties, leading to a significant speedup in incremental decoding.

Highlights

  • Performance Optimization: The core _convert_tokens_to_string_with_added_encoders function has been significantly optimized, achieving a reported 13.7x speedup in detokenization, crucial for incremental decoding.
  • Caching Tokenizer Properties and Methods: To reduce overhead in the hot loop, the tokenizer.convert_tokens_to_string method is now localized to a variable, and the result of tokenizer.get_added_vocab() is pre-computed into a set (added_vocab_set) before iteration. Additionally, tokenizer.all_special_tokens is conditionally converted to a set only when skip_special_tokens is true, otherwise an empty tuple is used.
  • List Manipulation Efficiency: The current_sub_text list is now cleared using current_sub_text.clear() instead of reassigning an empty list, which can offer minor efficiency gains.

@gemini-code-assist bot left a comment

Code Review

This pull request optimizes the _convert_tokens_to_string_with_added_encoders function, significantly improving performance. The changes involve caching vocabulary lookups and localizing method calls, supported by benchmarks and regression tests. A suggestion is provided to enhance code conciseness.

Comment on lines 48 to +50

@gemini-code-assist bot commented (medium):

Consider using str.join with a conditional expression for conciseness. This avoids the if/else block and directly returns the joined string based on the condition.

    return " ".join(sub_texts) if spaces_between_special_tokens else "".join(sub_texts)

@mgoin requested a review from njhill on July 3, 2025 21:59
@njhill (Member) left a comment

Thanks @misrasaurabh1, looks good to me.

I'll note that this path should no longer be very commonly used anyhow, since it applies only to the "slow" incremental detokenizer case; the majority of tokenizers support fast incremental detokenization (see https://github.com/vllm-project/vllm/blob/main/vllm/v1/engine/detokenizer.py).

@njhill (Member) commented Jul 7, 2025

@misrasaurabh1 please also sign-off your commits per https://github.com/vllm-project/vllm/pull/20413/checks?check_run_id=45267543355, thanks!

Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
@misrasaurabh1 force-pushed the codeflash/optimize-_convert_tokens_to_string_with_added_encoders-mcmk4l63 branch from 29705e9 to 70c5640 on July 16, 2025 01:14
Co-authored-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
@misrasaurabh1 force-pushed the codeflash/optimize-_convert_tokens_to_string_with_added_encoders-mcmk4l63 branch from 04d54d1 to 9a72549 on July 16, 2025 01:17
@misrasaurabh1 requested a review from njhill on July 16, 2025 01:17
@misrasaurabh1 (Contributor, Author) commented:
Hi @njhill, I fixed the issue above; this should now be ready to be merged.

@njhill (Member) commented Jul 16, 2025

Thanks @misrasaurabh1. Please also fix the pre-commit linter errors.

Signed-off-by: Aseem Saxena <aseem.bits@gmail.com>
@aseembits93 (Contributor) commented:
@njhill, we have fixed the pre-commit linter errors; it's ready to merge!

@njhill (Member) commented Aug 19, 2025

Thanks @misrasaurabh1 @aseembits93

@njhill added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Aug 19, 2025
@njhill merged commit bf7c99d into vllm-project:main on Aug 20, 2025
36 checks passed
djmmoss pushed a commit to djmmoss/vllm that referenced this pull request Aug 21, 2025
…rs` by 13.7x (vllm-project#20413)

Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
Signed-off-by: Aseem Saxena <aseem.bits@gmail.com>
Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
Co-authored-by: Aseem Saxena <aseem.bits@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
…rs` by 13.7x (vllm-project#20413)

Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
Signed-off-by: Aseem Saxena <aseem.bits@gmail.com>
Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
Co-authored-by: Aseem Saxena <aseem.bits@gmail.com>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
…rs` by 13.7x (vllm-project#20413)

Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
Signed-off-by: Aseem Saxena <aseem.bits@gmail.com>
Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
Co-authored-by: Aseem Saxena <aseem.bits@gmail.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
…rs` by 13.7x (vllm-project#20413)

Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
Signed-off-by: Aseem Saxena <aseem.bits@gmail.com>
Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
Co-authored-by: Aseem Saxena <aseem.bits@gmail.com>
mengxingkongzhouhan pushed a commit to mengxingkongzhouhan/vllm that referenced this pull request Aug 30, 2025
…rs` by 13.7x (vllm-project#20413)

Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
Signed-off-by: Aseem Saxena <aseem.bits@gmail.com>
Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
Co-authored-by: Aseem Saxena <aseem.bits@gmail.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Sep 3, 2025
…rs` by 13.7x (vllm-project#20413)

Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
Signed-off-by: Aseem Saxena <aseem.bits@gmail.com>
Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
Co-authored-by: Aseem Saxena <aseem.bits@gmail.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…rs` by 13.7x (vllm-project#20413)

Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
Signed-off-by: Aseem Saxena <aseem.bits@gmail.com>
Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
Co-authored-by: Aseem Saxena <aseem.bits@gmail.com>