Fix #12116: preserve boolean dtype for attention masks in ChromaPipeline #12263
DN6 merged 8 commits into huggingface:main
Conversation
- Convert attention masks to bool and prevent dtype corruption (sketched below)
- Fix both positive and negative mask handling in _get_t5_prompt_embeds
- Remove the float conversion in the _prepare_attention_mask method

Fixes huggingface#12116
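For illustration, here is a minimal sketch of the kind of change described above. The helper name and arguments are hypothetical, not the actual diffusers code; it only assumes the mask comes from the tokenizer output.

```python
import torch

def build_t5_attention_mask(text_inputs, device, dtype=None):
    # Hypothetical helper illustrating the intent of the fix, not the
    # actual ChromaPipeline code.
    attention_mask = text_inputs.attention_mask.clone()

    # Before the fix (problematic): a 0/1 mask cast to float16/bfloat16 is
    # treated by scaled_dot_product_attention as an additive bias.
    # attention_mask = attention_mask.to(device=device, dtype=dtype)

    # After the fix: a boolean mask is interpreted as keep/ignore.
    return attention_mask.to(device=device, dtype=torch.bool)
```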
Hello @DN6, just checking in to see if you've had a chance to look at the PR above. If you're not the right person or are busy, would you mind pointing me to someone who could review it? Thanks!
Thanks @akshay-babbar
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
DN6
left a comment
Thanks @akshay-babbar. Minor changes requested.
| ) | ||
| text_input_ids = text_inputs.input_ids | ||
| attention_mask = text_inputs.attention_mask.clone() | ||
| attention_mask = attention_mask.bool() |
Don't think we need the type conversion here. We can just convert the final mask to bool.
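For clarity, a hypothetical sketch of that suggestion (the helper and variable names are illustrative, not the actual pipeline code): intermediate masks keep the tokenizer's integer dtype, and only the mask ultimately passed to the attention call is converted.

```python
import torch

def finalize_attention_mask(attention_mask: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Single bool conversion at the end, applied to the final mask only.
    return attention_mask.to(device).bool()
```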
    assert (output_height, output_width) == (expected_height, expected_width)


class ChromaPipelineAttentionMaskTests(unittest.TestCase):
Dedicated tests aren't needed here. The existing tests should catch changes in numerical output if they are significant.
Hello @DN6, thanks for the review! I have made the changes; please let me know your feedback and the next steps. Thanks!
Sorry for not seeing this earlier; I was only just notified by the commit. Are you sure about passing the attention mask as boolean...? It seems to work, I guess, but it's documented to require a FloatTensor (text encoder: float between 0 and 1).






Problem
Fixes #12116
Short prompts generate corrupted images due to an attention mask dtype conversion bug.
Root Cause
Attention masks were converted from bool to float16/bfloat16, but PyTorch's scaled_dot_product_attention treats a float mask as an additive bias rather than a keep/ignore mask, so padding tokens still contribute to attention; the mask needs to stay boolean.
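A toy demonstration of why the mask dtype matters (not ChromaPipeline code; shapes and values are made up for illustration):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 1, 4, 8)              # (batch, heads, seq, dim)
keep = torch.tensor([True, True, False, False])  # last two tokens are padding

bool_mask = keep.view(1, 1, 1, 4)                # True = attend, False = ignore
float_mask = bool_mask.to(q.dtype)               # 1.0 / 0.0, added to the logits

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

# With the boolean mask the padded keys are excluded; with the 0/1 float
# mask they only get a +0/+1 bias and still contribute to the output.
print(torch.allclose(out_bool, out_float))       # typically False
```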
Solution
Keep the attention mask boolean in _get_t5_prompt_embeds and remove the float conversion in _prepare_attention_mask.
Testing
✅ Added @slow unit tests for dtype preservation
✅ Verified fix with prompts: "man", "cat"
✅ All tests pass locally
Please review when you have a chance. Thank you for your time and consideration!