fixed a dtype bfloat16 bug in torch_utils.py #10125
yiyixuxu merged 6 commits into huggingface:main from zhangp365:main
When generating a 1024*1024 image with bfloat16 dtype, there is an exception:
  File "/opt/conda/lib/python3.10/site-packages/diffusers/utils/torch_utils.py", line 107, in fourier_filter
    x_freq = fftn(x, dim=(-2, -1))
RuntimeError: Unsupported dtype BFloat16
hlky left a comment:
@zhangp365 Thanks! Is this when using FreeU? Can we keep the check for non-power-of-2 images and add another for bfloat16? I think that makes it clearer why we're casting to float32.
@hlky Yes, this uses FreeU and sets the pipeline dtype to bfloat16. In this case, when the image has a non-power-of-2 size, the pipeline runs successfully. However, a standard power-of-2 size image fails to run. Therefore, I believe casting the type to float32 is a safe operation that makes the code more robust.
# Non-power of 2 images must be float32
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
    x = x.to(dtype=torch.float32)
# fftn does not support bfloat16
elif x.dtype == torch.bfloat16:
    x = x.to(dtype=torch.float32)

If we always cast, someone looking at the function in the future may wonder why. cc @sayakpaul @DN6 WDYT?
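The two checks in the suggestion above can be sketched as a small decision function, to show which size and dtype combinations end up being cast. This is only an illustration: `needs_float32_cast` and `is_non_power_of_two` are hypothetical names, the dtype is passed as a plain string so the sketch runs without torch installed, and the sizes are example values rather than the actual tensor shapes seen inside `fourier_filter`.

```python
def is_non_power_of_two(n: int) -> bool:
    """Same bit trick as the suggestion: a power of two has exactly one bit
    set, so n & (n - 1) is zero only for powers of two (for positive n)."""
    return (n & (n - 1)) != 0


def needs_float32_cast(height: int, width: int, dtype_name: str) -> bool:
    """Hypothetical helper mirroring the two branches of the suggestion.

    dtype_name is a string stand-in for a torch dtype (an assumption made
    for this sketch only).
    """
    # Non-power of 2 images must be float32
    if is_non_power_of_two(width) or is_non_power_of_two(height):
        return True
    # fftn does not support bfloat16
    if dtype_name == "bfloat16":
        return True
    return False


if __name__ == "__main__":
    # Example inputs only; they are not taken from the PR.
    for h, w, dt in [(1024, 1024, "bfloat16"),
                     (768, 768, "bfloat16"),
                     (1024, 1024, "float16")]:
        print(h, w, dt, needs_float32_cast(h, w, dt))
```

With only the first branch in place, a power-of-2 bfloat16 input would skip the cast, which matches the failure reported for the 1024*1024 case; the second branch covers it while keeping the reason for each cast visible.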
Makes sense to me!
@zhangp365 Can you run |
I tried, but the errors are not from this PR. I think this PR will not affect the make process.
Here's the error from the last run, link. The extra errors you're seeing are because of |
Yes, after running |
Thanks @zhangp365!
* fixed a dtype bfloat16 bug in torch_utils.py
when generating 1024*1024 image with bfloat16 dtype, there is an exception:
File "/opt/conda/lib/python3.10/site-packages/diffusers/utils/torch_utils.py", line 107, in fourier_filter
x_freq = fftn(x, dim=(-2, -1))
RuntimeError: Unsupported dtype BFloat16
* remove whitespace in torch_utils.py
* Update src/diffusers/utils/torch_utils.py
* Update torch_utils.py
---------
Co-authored-by: hlky <hlky@hlky.ac>
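What does this PR do?
Fixes a bug: fourier_filter in torch_utils.py raises "RuntimeError: Unsupported dtype BFloat16" when the input is bfloat16.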
@hlky