Add KVzipPress #93

Merged

maxjeblick merged 13 commits into NVIDIA:main from Janghyun1230:main on Jul 25, 2025

Add KVzipPress#93
maxjeblick merged 13 commits intoNVIDIA:mainfrom
Janghyun1230:main

Conversation


@Janghyun1230 (Contributor) commented on Jul 8, 2025

PR description

Hi! I've tried to add KVzip, a recent work on query-agnostic KV cache eviction.

KVzip achieves near-lossless compression at eviction ratios of up to 80% on RULER-4k with LLaMA3.1-8B (evaluated using the evaluation.py script from this repository). I've uploaded the result JSON files to Drive.

| Compression ratio | 0 | 0.1 | 0.25 | 0.5 | 0.6 | 0.7 | 0.8 | 0.9 |
|---|---|---|---|---|---|---|---|---|
| Average performance | 95.7 | 95.5 | 95.5 | 95.5 | 95.5 | 95.3 | 94.9 | 90.5 |

KVzip introduces compression overhead (2× prefilling time, with negligible memory overhead). The original KVzip repository also provides a version without compression overhead at the cost of performance, using DuoAttention-style head-level eviction.

I tried to make minimal changes to this repository, but I had to make some additions in pipeline.py. I followed the fake compression strategy used for AdaKV in this repository, whereas the original KVzip repository provides optimized code that improves decoding speed by 2×.
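
For context: fake compression keeps the KV cache at its full size and merely neutralizes "evicted" positions, so accuracy can be measured without the memory or speed benefit of real eviction. A minimal sketch of the idea (a hypothetical helper, not the actual implementation in this repository or in KVzip):

    import torch

    def fake_eviction_bias(scores: torch.Tensor, compression_ratio: float) -> torch.Tensor:
        """Build an additive attention bias that masks out the lowest-scoring KV
        positions per head, emulating eviction without changing the cache shape.

        scores: (batch, num_heads, seq_len) importance scores per KV pair.
        Returns a (batch, num_heads, 1, seq_len) bias; -inf entries drive the
        attention weights of "evicted" positions to zero after softmax.
        """
        n_evict = int(scores.shape[-1] * compression_ratio)
        evict_idx = scores.topk(n_evict, dim=-1, largest=False).indices
        bias = torch.zeros_like(scores)
        bias.scatter_(-1, evict_idx, float("-inf"))
        return bias.unsqueeze(2)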

Please review and let me know if there are any issues or if everything looks fine. Truly appreciate your great repository!

Checklist

  • Tests are working (make test)
  • Code is formatted correctly (make style, on errors try fix with make format)
  • Copyright header is included
  • All commits are signed-off using git commit -s
  • (new press) mypress_press.py is in the presses directory
  • (new press) MyPress is in __init__.py
  • (new press) README.md is updated with a one-liner about the new press in the Available presses section
  • (new press) New press is in the default_presses list in tests/default_presses.py
  • (new press) A docstring is provided that follows the same structure as the existing ones

@maxjeblick self-requested a review on July 10, 2025 12:20

@maxjeblick (Collaborator) commented on Jul 13, 2025

Thanks a lot for your PR; the results look very promising!
I went through the PR and have the following suggestion:
Instead of modifying pipeline.py, try to move the logic into the press's __call__ context manager method.
I'll leave a code stub below that exemplifies this. (I haven't run the stub myself; it serves more as a template.)
In particular:

  • The context input ids and question suffix are fetched/computed in the press itself (answer_prefix isn't part of the question suffix; I don't know if this affects performance).
  • Instead of using a .do_compress attribute, the forward hooks are only registered after the initial forward pass of the model (upon exiting the context manager).
  • The model forward loop is part of the __call__ method.

Feel free to discuss these proposed changes here.

    # Imports assumed for this stub; SUPPORTED_MODELS and logger are
    # module-level names in the press module.
    from contextlib import contextmanager
    from typing import Generator

    import torch
    from torch import nn
    from transformers import (AutoTokenizer, Gemma3ForCausalLM,
                              PreTrainedModel, PreTrainedTokenizer)

    @contextmanager
    def __call__(self, model: PreTrainedModel) -> Generator:
        """
        Context manager that handles both initial prefilling and KVzip scoring/compression.
        
        This overrides the base class __call__ method to implement the full KVzip algorithm:
        1. First yield: allows initial prefilling with context
        2. After yield: performs KVzip scoring and compression using context reconstruction
        """
        if not isinstance(model, SUPPORTED_MODELS):
            logger.warning(f"Model {type(model)} not tested, supported models: {SUPPORTED_MODELS}")

        if isinstance(model, Gemma3ForCausalLM):
            logger.warning("Compression in Gemma3 is only applied to layer without sliding window attention")

        # Load the tokenizer to compute the question suffix ids
        tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)

        # Get suffix_ids directly using tokenizer's chat template (do this once, not in hook)
        if tokenizer.chat_template is None:
            suffix_text = "\n"  # Default suffix for models without chat template
        else:
            # Use a dummy context to extract the question suffix from chat template
            dummy_context = "dummy context"
            separator = "\n" + "#" * len(dummy_context)
            temp_context = tokenizer.apply_chat_template(
                [{"role": "user", "content": dummy_context + separator}],
                add_generation_prompt=True,
                tokenize=False
            )
            _, suffix_text = temp_context.split(separator)
        
        # Tokenize suffix directly to ids
        self._suffix_ids = tokenizer.encode(suffix_text, return_tensors="pt", add_special_tokens=False)

        # Register embedding hook to capture context information
        hooks = []
        try:
            # First yield: Initial prefilling phase (no compression hooks yet)
            embedding_hook = model.model.embed_tokens.register_forward_hook(self._forward_hook_embedding,
                                                                            with_kwargs=True)
            yield
            # Remove embedding hook since we no longer need it
            embedding_hook.remove()

            # After yield: KVzip scoring and compression phase
            if self.compression_ratio > 0 and self._context_ids is not None:
                # Now register attention hooks for compression
                for layer in model.model.layers:
                    if isinstance(model, Gemma3ForCausalLM) and layer.is_sliding:
                        continue
                    layer.self_attn.rotary_emb = model.model.rotary_emb
                    hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))

                self._perform_kvzip_compression(model, tokenizer)
        finally:
            for hook in hooks:
                hook.remove()

    def _forward_hook_embedding(self, module: nn.Module, input: tuple, kwargs: dict, output: torch.Tensor):
        """
        Hook on the embedding layer that captures the context ids and the KV cache
        from the initial forward pass.
        """
        self._context_ids = input[0]
        self._cache = ...  # fetch the past_key_values cache from kwargs

        return output


    def _perform_kvzip_compression(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
        """
        Perform the KVzip scoring and compression algorithm.
        """
        context_length = self._context_ids.shape[1]
        self.context_length = context_length

        # Prepare chunked inputs for context reconstruction
        input_ids = self.prepare(model, tokenizer, context_length)

        # Reset start_idx for scoring
        self.start_idx = 0

        # Perform scoring through context reconstruction
        # Use the stored cache from the initial forward pass
        for prefill_ids, repeat_ids in input_ids:
            self.end_idx = self.start_idx + prefill_ids.shape[1]
            # Pass the cache that was used in the initial forward pass
            model(
                input_ids=repeat_ids.to(model.device),
                past_key_values=self._cache,
                num_logits_to_keep=1,
            )
            self.start_idx = self.end_idx

        # Verify tokenization consistency
        assert self.end_idx == context_length, "Tokenization is not consistent"

        # Perform final compression
        self.compress_post(model)

@Janghyun1230 force-pushed the main branch 2 times, most recently from cd1e722 to 2799cab on July 14, 2025 08:37

@Janghyun1230 (Contributor, Author) commented on Jul 14, 2025

Following your guidelines, I've moved the modifications from pipeline.py into the press's __call__ context manager method. I also merged the latest upstream commits into my branch and updated the code accordingly.

I ran some tests and confirmed that the current version maintains performance. (The prefilling and compression time has slightly increased.) Please review and let me know if you have any further suggestions!

@maxjeblick requested a review from alessiodevoto on July 15, 2025 09:19
@maxjeblick (Collaborator)

Thanks a lot for the quick updates!
We will review the PR; please expect this to take a few days.

@maxjeblick (Collaborator) left a comment

Thanks a lot for the extensive refactoring of the press!

I've tested the press on the RULER-4k benchmark, and the results look very nice!
We will add your press to our benchmark once it is merged.

For the press to be merged, I kindly ask you to:

  • Add a warning in the press post-init method, informing the user that the press uses multiple forward passes (see the sketch after this list).
  • Refactor the press implementation in several places. I left some comments in the code; there are other places where refactoring can help as well. Please also add more comments/docstrings to help users.
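
A minimal sketch of such a warning, assuming the press is a dataclass with a __post_init__ hook like the other presses in this repository:

    import logging

    logger = logging.getLogger(__name__)

    # Hypothetical addition to KVzipPress:
    def __post_init__(self):
        # Inform the user about the extra prefilling cost of KVzip scoring.
        logger.warning(
            "KVzipPress performs multiple forward passes over the context to score "
            "KV pairs, which roughly doubles prefilling time."
        )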

@Janghyun1230 (Contributor, Author)

I appreciate your detailed feedback! Your comments improve the clarity and interpretability of the code. I've incorporated all of your suggestions, leaving comments on some specific points.

One thing I'd like to mention is that the current KVzipPress implementation is not compatible with ComposedPress. This is because KVzipPress follows slightly different logic and adopts fake compression like AdaKV, which is also incompatible with ComposedPress.

This incompatibility causes an error in make test (tests/presses/test_presses.py, line 86), where the test invokes ComposedPress with KVzipPress. Aside from this, I found no other issues during testing.

@maxjeblick (Collaborator) left a comment

Thanks a lot for the refactoring, code looks good!

> One thing I'd like to mention is that the current KVzipPress implementation is not compatible with ComposedPress. This is because KVzipPress follows slightly different logic and adopts fake compression like AdaKV, which is also incompatible with ComposedPress.

Ok, thanks for the notice. For now, you can add an if statement in the test, skipping that combination. I'll merge the PR once the tests pass.
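
A minimal sketch of such a skip (the parametrization and names in tests/presses/test_presses.py are assumed, not copied):

    import pytest
    from kvpress import ComposedPress, KVzipPress

    def test_composed_press(press_cls):
        # Skip the combination that is known to be incompatible.
        if press_cls is KVzipPress:
            pytest.skip("KVzipPress uses fake compression and is incompatible with ComposedPress")
        press = ComposedPress(presses=[press_cls()])
        # ... rest of the existing test body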

@Janghyun1230 (Contributor, Author)

Thank you for the review! I've added an if statement in tests/presses/test_presses.py to skip that combination.

@maxjeblick merged commit fb93b31 into NVIDIA:main on Jul 25, 2025
3 checks passed