fix: inputs_embeds ignored when input_ids is not None in _fast_prepare_inputs_for_generation #3814

Open
siddhudonda wants to merge 1 commit into unslothai:main from siddhudonda:main

Conversation

@siddhudonda

Fixes #3798

@gemini-code-assist
Contributor

Summary of Changes

Hello @siddhudonda, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug in the _fast_prepare_inputs_for_generation function within the unsloth/models/llama.py module. Previously, the function would inadvertently disregard the inputs_embeds parameter if input_ids was also supplied, leading to incorrect model input. The updated logic now correctly prioritizes and utilizes inputs_embeds when available, ensuring that models receive the intended embedded inputs for generation tasks, thereby fixing GitHub issue #3798.

Highlights

  • Bug Fix: inputs_embeds Handling: Resolved an issue where the inputs_embeds parameter was being ignored in the _fast_prepare_inputs_for_generation function when input_ids was also present, leading to incorrect model input.
  • Input Prioritization Logic: Modified the input preparation logic to correctly prioritize inputs_embeds over input_ids when inputs_embeds is provided, ensuring the model receives the intended embedded inputs (a usage sketch follows below).
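For context, the user-facing call this fix is meant to support looks roughly like the following. This is a minimal sketch only; the model name and helper calls mirror the test script later in this thread, not the patch itself.

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/tinyllama-chat-bnb-4bit",  # example model only
    max_seq_length=512,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)

inputs = tokenizer("Who are you?", return_tensors="pt").to(model.device)
# Generate from embeddings rather than token ids. Before this fix,
# _fast_prepare_inputs_for_generation dropped inputs_embeds and used
# input_ids instead, so these embeddings never reached the model.
inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])
outputs = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs["attention_mask"],
    max_new_tokens=10,
    do_sample=False,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))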


Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request correctly fixes an issue where inputs_embeds would be ignored if input_ids were also present in _fast_prepare_inputs_for_generation. The change ensures inputs_embeds is prioritized, which is the correct behavior. The implementation is sound, and I've provided a minor refactoring suggestion to reduce code duplication and improve clarity.

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 233 to 238
    inputs_embeds = kwargs.pop("inputs_embeds", None)
    if inputs_embeds is not None:
        result = {
            "inputs_embeds": inputs_embeds,


P1: Gate inputs_embeds on first step only

When generation continues with cached keys/values, inputs_embeds remains in model_kwargs across steps, so this new branch will keep returning the full inputs_embeds even after past_key_values is set. That bypasses the earlier input_ids = input_ids[:, [-1]] slicing and re-feeds the entire sequence on every cached step, expanding the KV cache with duplicate tokens and producing incorrect logits/performance for any generate call that starts with inputs_embeds. The standard prepare_inputs_for_generation only uses inputs_embeds on the first step and switches back to sliced input_ids once a cache exists; the same guard is needed here.
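To make the suggested guard concrete, here is a minimal sketch of the pattern described above. It is a simplified illustration, not the upstream implementation; the function name is made up.

def prepare_inputs_with_cache_guard(input_ids, attention_mask=None,
                                    inputs_embeds=None, past_key_values=None, **kwargs):
    if past_key_values is not None:
        # Cached steps: feed only the newly generated token id.
        input_ids = input_ids[:, -1:]
    model_inputs = {"attention_mask": attention_mask, "past_key_values": past_key_values, **kwargs}
    if inputs_embeds is not None and past_key_values is None:
        # First step only: drive the forward pass with the caller's embeddings.
        model_inputs["inputs_embeds"] = inputs_embeds
    else:
        # Later steps fall back to the sliced input_ids so the KV cache grows by one token.
        model_inputs["input_ids"] = input_ids
    return model_inputs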


@siddhudonda siddhudonda force-pushed the main branch 2 times, most recently from a06142a to 803d0fe on January 1, 2026 at 11:18
Author

@siddhudonda siddhudonda left a comment

Hi! Thanks for the review feedback. I've addressed both suggestions:

  • @chatgpt-codex-connector's concern: Added a guard to only use inputs_embeds on the first generation step (when past_key_values is None). This prevents re-feeding the entire sequence on cached steps.
  • @gemini-code-assist's refactoring suggestion: Refactored to reduce code duplication by building the base result dict first, then conditionally adding inputs_embeds or input_ids.

Please review the updated changes in commit 803d0fe. Thanks!

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. 🚀


@Datta0
Collaborator

Datta0 commented Jan 2, 2026

Hey @siddhudonda, thanks a lot for the contribution.
If you can provide a script to verify your changes (a before and after would be even better), that would help ease our review.

@siddhudonda
Author

siddhudonda commented Jan 2, 2026

Hi! Thanks for the feedback. Here's a verification script that demonstrates the before/after behavior:

"""
Verification script for PR #3814: inputs_embeds handling fix
https://github.com/unslothai/unsloth/pull/3814

This script demonstrates that inputs_embeds is now correctly used
when passed to model.generate() instead of being silently ignored.

Issue: https://github.com/unslothai/unsloth/issues/3798
"""

import torch

# Mock the _fast_prepare_inputs_for_generation function to test both behaviors

# ============= BEFORE THE FIX =============
def before_fix_prepare_inputs_for_generation(
    self,
    input_ids,
    attention_mask=None,
    inputs_embeds=None,  # accepted (as in the original signature) but never used
    **kwargs,
):
    """Original behavior: inputs_embeds was silently dropped, input_ids was always used"""
    past_key_values = kwargs.get("past_key_values", None)

    if "cache_position" in kwargs:
        kwargs["position_ids"] = kwargs["cache_position"]

    # BUG: always returns input_ids; the inputs_embeds argument never reaches the model
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        **kwargs,
    }


# ============= AFTER THE FIX =============
def after_fix_prepare_inputs_for_generation(
    self,
    input_ids,
    attention_mask=None,
    **kwargs,
):
    """Fixed behavior: inputs_embeds is used on first step when provided"""
    past_key_values = kwargs.get("past_key_values", None)
    
    if "cache_position" in kwargs:
        kwargs["position_ids"] = kwargs["cache_position"]
    
    # FIX: Check for inputs_embeds and use it on first generation step (no cache)
    inputs_embeds = kwargs.pop("inputs_embeds", None)
    result = {
        "attention_mask": attention_mask,
        **kwargs,
    }
    if inputs_embeds is not None and past_key_values is None:
        # First step with inputs_embeds - use embeddings
        result["inputs_embeds"] = inputs_embeds
    else:
        # Subsequent steps with cache OR normal input_ids path
        result["input_ids"] = input_ids
    return result


# ============= TEST CASES =============
def test_inputs_embeds_handling():
    """Test that inputs_embeds is correctly handled"""
    
    # Simulate inputs
    batch_size, seq_len, hidden_dim = 2, 10, 768
    input_ids = torch.randint(0, 32000, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len)
    inputs_embeds = torch.randn(batch_size, seq_len, hidden_dim)
    
    print("=" * 60)
    print("TEST: inputs_embeds handling in _fast_prepare_inputs_for_generation")
    print("=" * 60)
    
    # Test Case 1: First generation step with inputs_embeds (no cache)
    print("\n[Test 1] First step with inputs_embeds (past_key_values=None)")
    print("-" * 60)
    
    # BEFORE fix
    result_before = before_fix_prepare_inputs_for_generation(
        None, input_ids, attention_mask,
        inputs_embeds=inputs_embeds,
        past_key_values=None
    )
    print(f"BEFORE fix:")
    print(f"  - 'inputs_embeds' in result: {'inputs_embeds' in result_before}")
    print(f"  - 'input_ids' in result: {'input_ids' in result_before}")
    print(f"  - Result uses: {'input_ids (WRONG!)' if 'input_ids' in result_before else 'inputs_embeds'}")
    
    # AFTER fix
    result_after = after_fix_prepare_inputs_for_generation(
        None, input_ids, attention_mask,
        inputs_embeds=inputs_embeds.clone(),
        past_key_values=None
    )
    print(f"\nAFTER fix:")
    print(f"  - 'inputs_embeds' in result: {'inputs_embeds' in result_after}")
    print(f"  - 'input_ids' in result: {'input_ids' in result_after}")
    print(f"  - Result uses: {'inputs_embeds (CORRECT!)' if 'inputs_embeds' in result_after else 'input_ids'}")
    
    # Verify
    assert 'input_ids' in result_before, "Before: should have input_ids (bug)"
    assert 'inputs_embeds' in result_after, "After: should have inputs_embeds (fixed)"
    assert 'input_ids' not in result_after, "After: should NOT have input_ids on first step"
    print("\n Test 1 PASSED: inputs_embeds is correctly used on first step")
    
    # Test Case 2: Subsequent step with cache (should use input_ids)
    print("\n[Test 2] Subsequent step with cache (past_key_values set)")
    print("-" * 60)
    
    # Simulate a non-empty cache
    fake_cache = [(torch.randn(2, 8, 5, 64), torch.randn(2, 8, 5, 64))]
    
    result_after_cached = after_fix_prepare_inputs_for_generation(
        None, input_ids[:, -1:], attention_mask,  # Sliced input_ids for cached step
        inputs_embeds=inputs_embeds.clone(),  # This should be ignored when cache exists
        past_key_values=fake_cache
    )
    print(f"AFTER fix (with cache):")
    print(f"  - 'inputs_embeds' in result: {'inputs_embeds' in result_after_cached}")
    print(f"  - 'input_ids' in result: {'input_ids' in result_after_cached}")
    print(f"  - Result uses: {'input_ids (CORRECT! - cached step)' if 'input_ids' in result_after_cached else 'inputs_embeds'}")
    
    assert 'input_ids' in result_after_cached, "After with cache: should use input_ids"
    assert 'inputs_embeds' not in result_after_cached, "After with cache: should NOT use inputs_embeds"
    print("\n Test 2 PASSED: input_ids is correctly used on cached steps")
    
    # Test Case 3: Normal case without inputs_embeds
    print("\n[Test 3] Normal case without inputs_embeds")
    print("-" * 60)
    
    result_normal = after_fix_prepare_inputs_for_generation(
        None, input_ids, attention_mask,
        past_key_values=None
        # No inputs_embeds provided
    )
    print(f"AFTER fix (no inputs_embeds):")
    print(f"  - 'inputs_embeds' in result: {'inputs_embeds' in result_normal}")
    print(f"  - 'input_ids' in result: {'input_ids' in result_normal}")
    
    assert 'input_ids' in result_normal, "Normal case: should use input_ids"
    print("\n Test 3 PASSED: input_ids is correctly used when no inputs_embeds provided")
    
    print("\n" + "=" * 60)
    print("ALL TESTS PASSED! ")
    print("=" * 60)
    print("\nSummary:")
    print("- BEFORE: inputs_embeds was always ignored (bug)")
    print("- AFTER:  inputs_embeds is used on first step, input_ids on cached steps (correct)")


if __name__ == "__main__":
    test_inputs_embeds_handling()

@Pioneer-Weirdo

Pioneer-Weirdo commented Jan 2, 2026

I downloaded the PR locally, but this modified code didn't solve the problem; it's still throwing errors.

# Fix new HF's inference code
def _fast_prepare_inputs_for_generation(
    self,
    input_ids,
    attention_mask = None,
    inputs_embeds = None,
    **kwargs,
):
    past_key_values = kwargs.get("past_key_values", None)
    if past_key_values is not None:
        # Check for uninitialized DynamicCache
        if len(past_key_values) == 0:
            past_key_values = None
            kwargs["past_key_values"] = None
        # New since 4.56
        elif (
            hasattr(past_key_values, "get_seq_length")
            and past_key_values.get_seq_length() == 0
        ):
            past_key_values = None
            kwargs["past_key_values"] = None
        else:
            bs, cache_length = input_ids.shape
            input_ids = input_ids[:, [-1]]

            # Get to the base model
            base_model = self
            if hasattr(base_model, "base_model_prefix"):
                base_model = getattr(base_model, base_model.base_model_prefix)

            if hasattr(
                base_model, "_prepare_4d_causal_attention_mask_with_cache_position"
            ):

                def needs_device_kw(fn) -> bool:
                    try:
                        sig = inspect.signature(inspect.unwrap(fn))
                        return "device" in sig.parameters
                    except:
                        # transformers <= 4.51.3 includes device arg but > 4.51.3 does not
                        return transformers_version < Version("4.52.0")

                kwargs = {
                    "sequence_length": 1,
                    "target_length": cache_length,
                    "dtype": self.dtype,
                    "cache_position": torch.arange(
                        cache_length, cache_length + 1, device = input_ids.device
                    ),
                    "batch_size": bs,
                    "config": self.config,
                    "past_key_values": past_key_values,
                }
                try:
                    if needs_device_kw(
                        base_model._prepare_4d_causal_attention_mask_with_cache_position
                    ):
                        kwargs["device"] = input_ids.device
                except:
                    print(
                        f"Unsloth: Could not inspect signature of {base_model._prepare_4d_causal_attention_mask_with_cache_position}"
                    )

                attention_mask = (
                    base_model._prepare_4d_causal_attention_mask_with_cache_position(
                        attention_mask,
                        **kwargs,
                    )
                )
            else:
                attention_mask = attention_mask[:, [-1]]
                if transformers_version <= Version("4.52.4"):
                    logger.warning_once(
                        f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
                        "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
                        "writing code, see Llama for an example implementation. If you're a user, please report this "
                        "issue on GitHub."
                    )

    if "cache_position" in kwargs:
        kwargs["position_ids"] = kwargs["cache_position"]

    # Handle inputs_embeds - only use it on the FIRST generation step (no cache).
    # Once past_key_values is set, we must use sliced input_ids for subsequent steps.
    # This fixes GitHub issue #3798: inputs_embeds was ignored when input_ids was not None
    inputs_embeds = kwargs.pop("inputs_embeds", None)
    result = {
        "attention_mask": attention_mask,
        **kwargs,
    }
    if inputs_embeds is not None and past_key_values is None:
        # First step with inputs_embeds - use embeddings
        result["inputs_embeds"] = inputs_embeds
    else:
        # Subsequent steps with cache OR normal input_ids path
        result["input_ids"] = input_ids
    return result
"""
Minimal test script to verify UNSLOTH_DISABLE_FAST_GENERATION works with inputs_embeds.
Run this BEFORE the main training script to isolate the issue.

Usage:
    python scripts/test_unsloth_inputs_embeds.py
"""
import os
# MUST be set before ANY unsloth import
os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1"
os.environ["UNSLOTH_STUDIO_DISABLED"] = "1"

print(f"[DEBUG] UNSLOTH_DISABLE_FAST_GENERATION = {os.environ.get('UNSLOTH_DISABLE_FAST_GENERATION')}")

# import torch
from unsloth import FastLanguageModel

def main():
    print("\n=== Loading Model via Unsloth ===")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/tinyllama-chat-bnb-4bit",  # Use smallest model
        max_seq_length=512,
        dtype=None,
        load_in_4bit=True,
    )
    
    # Prepare for inference
    FastLanguageModel.for_inference(model)
    model.eval()
    
    print(f"\n[DEBUG] Model class: {model.__class__.__name__}")
    print(f"[DEBUG] Model has prepare_inputs_for_generation: {hasattr(model, 'prepare_inputs_for_generation')}")
    
    # Check if inputs_embeds is in prepare_inputs_for_generation signature
    import inspect
    sig = inspect.signature(model.prepare_inputs_for_generation)
    params = list(sig.parameters.keys())
    print(f"[DEBUG] prepare_inputs_for_generation params: {params}")
    print(f"[DEBUG] 'inputs_embeds' in signature: {'inputs_embeds' in params}")
    
    # Test 1: Standard generation with input_ids (should work)
    print("\n=== Test 1: Standard generation with input_ids ===")
    inputs = tokenizer("Who are you?", return_tensors="pt").to(model.device)
    try:
        outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False)
        print(f"[PASS] Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    except Exception as e:
        print(f"[FAIL] Error: {e}")
    
    # Test 2: Generation with inputs_embeds (this is what we need to work)
    print("\n=== Test 2: Generation with inputs_embeds ===")
    try:
        input_ids = inputs["input_ids"]
        inputs_embeds = model.get_input_embeddings()(input_ids)
        attention_mask = inputs["attention_mask"]
        
        print(f"[DEBUG] inputs_embeds shape: {inputs_embeds.shape}")
        
        outputs = model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
        print(f"[PASS] Generated with inputs_embeds: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    except ValueError as e:
        print(f"[FAIL] ValueError (expected if env var not working): {e}")
    except Exception as e:
        print(f"[FAIL] Other error: {type(e).__name__}: {e}")

    print("\n=== Test Complete ===")

if __name__ == "__main__":
    main()

=== Test 1: Standard generation with input_ids ===
[PASS] Generated: Who are you?
You are the one who can make me whole

=== Test 2: Generation with inputs_embeds ===
[DEBUG] inputs_embeds shape: torch.Size([1, 5, 2048])
[FAIL] Other error: RuntimeError: shape '[-1, 0]' is invalid for input of size 5

=== Test Complete ===
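For comparison, a plain transformers baseline (no Unsloth fast path) can help isolate whether the failure is specific to the patched prepare_inputs_for_generation. A minimal sketch, assuming any decoder-only chat model; the model name below is only an example:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # example model only
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("Who are you?", return_tensors="pt").to(model.device)
inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])

# If this succeeds while the Unsloth path above fails, the problem lies in the
# patched prepare_inputs_for_generation rather than in generate() itself.
outputs = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs["attention_mask"],
    max_new_tokens=10,
    do_sample=False,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))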

@siddhudonda siddhudonda marked this pull request as draft January 2, 2026 10:37
@siddhudonda siddhudonda marked this pull request as ready for review January 2, 2026 10:37