fix: inputs_embeds ignored when input_ids is not None in _fast_prepare_inputs_for_generation#3814
fix: inputs_embeds ignored when input_ids is not None in _fast_prepare_inputs_for_generation#3814siddhudonda wants to merge 1 commit intounslothai:mainfrom
Conversation
Summary of ChangesHello @siddhudonda, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a critical bug in the Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request correctly fixes an issue where inputs_embeds would be ignored if input_ids were also present in _fast_prepare_inputs_for_generation. The change ensures inputs_embeds is prioritized, which is the correct behavior. The implementation is sound, and I've provided a minor refactoring suggestion to reduce code duplication and improve clarity.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
unsloth/models/llama.py
Outdated
| inputs_embeds = kwargs.pop("inputs_embeds", None) | ||
| if inputs_embeds is not None: | ||
| result = { | ||
| "inputs_embeds": inputs_embeds, |
There was a problem hiding this comment.
Gate inputs_embeds on first step only
When generation continues with cached keys/values, inputs_embeds remains in model_kwargs across steps, so this new branch will keep returning the full inputs_embeds even after past_key_values is set. That bypasses the earlier input_ids = input_ids[:, [-1]] slicing and re-feeds the entire sequence on every cached step, expanding the KV cache with duplicate tokens and producing incorrect logits/performance for any generate call that starts with inputs_embeds. The standard prepare_inputs_for_generation only uses inputs_embeds on the first step and switches back to sliced input_ids once a cache exists; the same guard is needed here.
Useful? React with 👍 / 👎.
a06142a to
803d0fe
Compare
siddhudonda
left a comment
There was a problem hiding this comment.
Hi! Thanks for the review feedback. I've addressed both suggestions:
@chatgpt-codex-connector's concern: Added a guard to only use inputs_embeds on the first generation step (when past_key_values is None). This prevents re-feeding the entire sequence on cached steps.
@gemini-code-assist's refactoring suggestion: Refactored to reduce code duplication by building the base result dict first, then conditionally adding inputs_embeds or input_ids.
Please review the updated changes in commit 803d0fe. Thanks!
|
Codex Review: Didn't find any major issues. 🚀 ℹ️ About Codex in GitHubCodex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
If Codex has suggestions, it will comment; otherwise it will react with 👍. When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback". |
|
Hey @siddhudonda thanks a lot for the contribution. |
|
Hi! Thanks for the feedback. Here's a verification script that demonstrates the before/after behavior """
Verification script for PR #3814: inputs_embeds handling fix
https://github.com/unslothai/unsloth/pull/3814
This script demonstrates that inputs_embeds is now correctly used
when passed to model.generate() instead of being silently ignored.
Issue: https://github.com/unslothai/unsloth/issues/3798
"""
import torch
# Mock the _fast_prepare_inputs_for_generation function to test both behaviors
# ============= BEFORE THE FIX =============
def before_fix_prepare_inputs_for_generation(
self,
input_ids,
attention_mask=None,
**kwargs,
):
"""Original behavior: inputs_embeds was always ignored, input_ids was always used"""
past_key_values = kwargs.get("past_key_values", None)
if "cache_position" in kwargs:
kwargs["position_ids"] = kwargs["cache_position"]
# BUG: This always returns input_ids, ignoring inputs_embeds completely
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
**kwargs,
}
# ============= AFTER THE FIX =============
def after_fix_prepare_inputs_for_generation(
self,
input_ids,
attention_mask=None,
**kwargs,
):
"""Fixed behavior: inputs_embeds is used on first step when provided"""
past_key_values = kwargs.get("past_key_values", None)
if "cache_position" in kwargs:
kwargs["position_ids"] = kwargs["cache_position"]
# FIX: Check for inputs_embeds and use it on first generation step (no cache)
inputs_embeds = kwargs.pop("inputs_embeds", None)
result = {
"attention_mask": attention_mask,
**kwargs,
}
if inputs_embeds is not None and past_key_values is None:
# First step with inputs_embeds - use embeddings
result["inputs_embeds"] = inputs_embeds
else:
# Subsequent steps with cache OR normal input_ids path
result["input_ids"] = input_ids
return result
# ============= TEST CASES =============
def test_inputs_embeds_handling():
"""Test that inputs_embeds is correctly handled"""
# Simulate inputs
batch_size, seq_len, hidden_dim = 2, 10, 768
input_ids = torch.randint(0, 32000, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len)
inputs_embeds = torch.randn(batch_size, seq_len, hidden_dim)
print("=" * 60)
print("TEST: inputs_embeds handling in _fast_prepare_inputs_for_generation")
print("=" * 60)
# Test Case 1: First generation step with inputs_embeds (no cache)
print("\n[Test 1] First step with inputs_embeds (past_key_values=None)")
print("-" * 60)
# BEFORE fix
result_before = before_fix_prepare_inputs_for_generation(
None, input_ids, attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=None
)
print(f"BEFORE fix:")
print(f" - 'inputs_embeds' in result: {'inputs_embeds' in result_before}")
print(f" - 'input_ids' in result: {'input_ids' in result_before}")
print(f" - Result uses: {'input_ids (WRONG!)' if 'input_ids' in result_before else 'inputs_embeds'}")
# AFTER fix
result_after = after_fix_prepare_inputs_for_generation(
None, input_ids, attention_mask,
inputs_embeds=inputs_embeds.clone(),
past_key_values=None
)
print(f"\nAFTER fix:")
print(f" - 'inputs_embeds' in result: {'inputs_embeds' in result_after}")
print(f" - 'input_ids' in result: {'input_ids' in result_after}")
print(f" - Result uses: {'inputs_embeds (CORRECT!)' if 'inputs_embeds' in result_after else 'input_ids'}")
# Verify
assert 'input_ids' in result_before, "Before: should have input_ids (bug)"
assert 'inputs_embeds' in result_after, "After: should have inputs_embeds (fixed)"
assert 'input_ids' not in result_after, "After: should NOT have input_ids on first step"
print("\n Test 1 PASSED: inputs_embeds is correctly used on first step")
# Test Case 2: Subsequent step with cache (should use input_ids)
print("\n[Test 2] Subsequent step with cache (past_key_values set)")
print("-" * 60)
# Simulate a non-empty cache
fake_cache = [(torch.randn(2, 8, 5, 64), torch.randn(2, 8, 5, 64))]
result_after_cached = after_fix_prepare_inputs_for_generation(
None, input_ids[:, -1:], attention_mask, # Sliced input_ids for cached step
inputs_embeds=inputs_embeds.clone(), # This should be ignored when cache exists
past_key_values=fake_cache
)
print(f"AFTER fix (with cache):")
print(f" - 'inputs_embeds' in result: {'inputs_embeds' in result_after_cached}")
print(f" - 'input_ids' in result: {'input_ids' in result_after_cached}")
print(f" - Result uses: {'input_ids (CORRECT! - cached step)' if 'input_ids' in result_after_cached else 'inputs_embeds'}")
assert 'input_ids' in result_after_cached, "After with cache: should use input_ids"
assert 'inputs_embeds' not in result_after_cached, "After with cache: should NOT use inputs_embeds"
print("\n Test 2 PASSED: input_ids is correctly used on cached steps")
# Test Case 3: Normal case without inputs_embeds
print("\n[Test 3] Normal case without inputs_embeds")
print("-" * 60)
result_normal = after_fix_prepare_inputs_for_generation(
None, input_ids, attention_mask,
past_key_values=None
# No inputs_embeds provided
)
print(f"AFTER fix (no inputs_embeds):")
print(f" - 'inputs_embeds' in result: {'inputs_embeds' in result_normal}")
print(f" - 'input_ids' in result: {'input_ids' in result_normal}")
assert 'input_ids' in result_normal, "Normal case: should use input_ids"
print("\n Test 3 PASSED: input_ids is correctly used when no inputs_embeds provided")
print("\n" + "=" * 60)
print("ALL TESTS PASSED! ")
print("=" * 60)
print("\nSummary:")
print("- BEFORE: inputs_embeds was always ignored (bug)")
print("- AFTER: inputs_embeds is used on first step, input_ids on cached steps (correct)")
if __name__ == "__main__":
test_inputs_embeds_handling()
''' |
|
I downloaded the pr locally, but this modified code didn't solve the problem; it's still throwing errors. # Fix new HF's inference code
def _fast_prepare_inputs_for_generation(
self,
input_ids,
attention_mask = None,
inputs_embeds = None,
**kwargs,
):
past_key_values = kwargs.get("past_key_values", None)
if past_key_values is not None:
# Check for uninitialized DynamicCache
if len(past_key_values) == 0:
past_key_values = None
kwargs["past_key_values"] = None
# New since 4.56
elif (
hasattr(past_key_values, "get_seq_length")
and past_key_values.get_seq_length() == 0
):
past_key_values = None
kwargs["past_key_values"] = None
else:
bs, cache_length = input_ids.shape
input_ids = input_ids[:, [-1]]
# Get to the base model
base_model = self
if hasattr(base_model, "base_model_prefix"):
base_model = getattr(base_model, base_model.base_model_prefix)
if hasattr(
base_model, "_prepare_4d_causal_attention_mask_with_cache_position"
):
def needs_device_kw(fn) -> bool:
try:
sig = inspect.signature(inspect.unwrap(fn))
return "device" in sig.parameters
except:
# transformers <= 4.51.3 includes device arg but > 4.51.3 does not
return transformers_version < Version("4.52.0")
kwargs = {
"sequence_length": 1,
"target_length": cache_length,
"dtype": self.dtype,
"cache_position": torch.arange(
cache_length, cache_length + 1, device = input_ids.device
),
"batch_size": bs,
"config": self.config,
"past_key_values": past_key_values,
}
try:
if needs_device_kw(
base_model._prepare_4d_causal_attention_mask_with_cache_position
):
kwargs["device"] = input_ids.device
except:
print(
f"Unsloth: Could not inspect signature of {base_model._prepare_4d_causal_attention_mask_with_cache_position}"
)
attention_mask = (
base_model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
**kwargs,
)
)
else:
attention_mask = attention_mask[:, [-1]]
if transformers_version <= Version("4.52.4"):
logger.warning_once(
f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
"defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
"writing code, see Llama for an example implementation. If you're a user, please report this "
"issue on GitHub."
)
if "cache_position" in kwargs:
kwargs["position_ids"] = kwargs["cache_position"]
# Handle inputs_embeds - only use it on the FIRST generation step (no cache).
# Once past_key_values is set, we must use sliced input_ids for subsequent steps.
# This fixes GitHub issue #3798: inputs_embeds was ignored when input_ids was not None
inputs_embeds = kwargs.pop("inputs_embeds", None)
result = {
"attention_mask": attention_mask,
**kwargs,
}
if inputs_embeds is not None and past_key_values is None:
# First step with inputs_embeds - use embeddings
result["inputs_embeds"] = inputs_embeds
else:
# Subsequent steps with cache OR normal input_ids path
result["input_ids"] = input_ids
return result"""
Minimal test script to verify UNSLOTH_DISABLE_FAST_GENERATION works with inputs_embeds.
Run this BEFORE the main training script to isolate the issue.
Usage:
python scripts/test_unsloth_inputs_embeds.py
"""
import os
# MUST be set before ANY unsloth import
os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1"
os.environ["UNSLOTH_STUDIO_DISABLED"] = "1"
print(f"[DEBUG] UNSLOTH_DISABLE_FAST_GENERATION = {os.environ.get('UNSLOTH_DISABLE_FAST_GENERATION')}")
# import torch
from unsloth import FastLanguageModel
def main():
print("\n=== Loading Model via Unsloth ===")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/tinyllama-chat-bnb-4bit", # Use smallest model
max_seq_length=512,
dtype=None,
load_in_4bit=True,
)
# Prepare for inference
FastLanguageModel.for_inference(model)
model.eval()
print(f"\n[DEBUG] Model class: {model.__class__.__name__}")
print(f"[DEBUG] Model has prepare_inputs_for_generation: {hasattr(model, 'prepare_inputs_for_generation')}")
# Check if inputs_embeds is in prepare_inputs_for_generation signature
import inspect
sig = inspect.signature(model.prepare_inputs_for_generation)
params = list(sig.parameters.keys())
print(f"[DEBUG] prepare_inputs_for_generation params: {params}")
print(f"[DEBUG] 'inputs_embeds' in signature: {'inputs_embeds' in params}")
# Test 1: Standard generation with input_ids (should work)
print("\n=== Test 1: Standard generation with input_ids ===")
inputs = tokenizer("Who are you?", return_tensors="pt").to(model.device)
try:
outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False)
print(f"[PASS] Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
except Exception as e:
print(f"[FAIL] Error: {e}")
# Test 2: Generation with inputs_embeds (this is what we need to work)
print("\n=== Test 2: Generation with inputs_embeds ===")
try:
input_ids = inputs["input_ids"]
inputs_embeds = model.get_input_embeddings()(input_ids)
attention_mask = inputs["attention_mask"]
print(f"[DEBUG] inputs_embeds shape: {inputs_embeds.shape}")
outputs = model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=10,
do_sample=False,
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id
)
print(f"[PASS] Generated with inputs_embeds: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
except ValueError as e:
print(f"[FAIL] ValueError (expected if env var not working): {e}")
except Exception as e:
print(f"[FAIL] Other error: {type(e).__name__}: {e}")
print("\n=== Test Complete ===")
if __name__ == "__main__":
main()
=== Test 1: Standard generation with input_ids === === Test 2: Generation with inputs_embeds === === Test Complete === |
Fixes #3798