Use torch in get_2d_rotary_pos_embed#10155
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
Downstream usage - should be ok, this was already returning |
|
did we run hunyuan test? |
|
Checkpoint used in the slow test is 404 |
|
just run its docstring example manually would be fine for now |
|
There a slight change to the image. import torch
from diffusers import HunyuanDiTPipeline
pipe = HunyuanDiTPipeline.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
)
pipe.to("cuda")
prompt = "An astronaut riding a horse"
image = pipe(prompt, generator=torch.Generator("cuda").manual_seed(0)).images[0] |
|
Unclear why though, I'll run the test again. Edit: I haven't ran the reproduction on CUDA, might account for the difference. >>> torch.abs(image_rotary_emb[0].flatten() - image_rotary_emb_np[0].flatten()).max()
tensor(0.)
>>> torch.abs(image_rotary_emb[1].flatten() - image_rotary_emb_np[1].flatten()).max()
tensor(0.) |
|
Yes there's a very minor difference when we create the tensors on CUDA. It's below PyTorch's tolerance for float32 though https://pytorch.org/docs/stable/testing.html cc @yiyixuxu |
af5ecd9 to
f2e7731
Compare
|
I've added |
|
It's a bit of a surprise that something numerically < 1e-7 would cause a visual difference like this, but it is not worse, is it? my eyes are not very good with it |
|
It is surprising, we can run more tests before merge if you want, the visual difference is acceptable imo. |
|
thanks @hlky great work as always:) |
* Use `torch` in `get_2d_rotary_pos_embed` * Add deprecation


What does this PR do?
Refactors
get_2d_rotary_pos_embedto usetorchinstead ofnumpy, and addsdeviceargument so that tensors can be created on e.g.cuda.Usage of
get_2d_rotary_pos_embedin HunyuanDiT pipelines is updated to passdevice.torchandnumpyversions match numerically.Reproduction
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul @yiyixuxu