Enable test_torch.py tests for BFloat16 on cuda #22428
Conversation
BFloat16 for cuda gh-metadata: pytorch pytorch 22428 gh/izdeby/13/head
```diff
-backend_types['CUDA'].discard('BFloat16')
+if not option.get('cuda_bfloat16', False):
+    if 'CUDA' in backend_types:
+        backend_types['CUDA'].discard('BFloat16')
```
Not really removed, eh? :)
It's coming soon :)
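For context, a minimal runnable sketch of the gating in the diff above. The `filter_backend_types` wrapper and the exact shapes of `option` and `backend_types` are assumptions; only the three gating lines come from the PR:

```python
# Sketch only: `option` is assumed to be a per-declaration dict and
# `backend_types` a map from backend name to a set of type names.
def filter_backend_types(option, backend_types):
    # BFloat16 on CUDA is opt-in: unless a declaration explicitly sets
    # `cuda_bfloat16: True`, BFloat16 is dropped from the CUDA type set.
    if not option.get('cuda_bfloat16', False):
        if 'CUDA' in backend_types:
            backend_types['CUDA'].discard('BFloat16')
    return backend_types

# A declaration without the flag loses BFloat16 on CUDA:
types = {'CUDA': {'Float', 'Half', 'BFloat16'}}
filter_backend_types({}, types)
assert 'BFloat16' not in types['CUDA']
```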
It's a pretty big patch, but it all looks reasonable to me. Letting @gchanan take a look.
@ezyang, yeah, I know it's big. The problem is that there are a lot of internal dependencies from one thing to another. For example, I had to enable a couple of methods just to make the test utils work with the BFloat16 type as-is and without hacks.
BFloat16 for cuda gh-metadata: pytorch pytorch 22428 gh/izdeby/13/head
```yaml
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
cuda_bfloat16: True
```
I guess this is more related to the CPU review, but I would have expected corresponding bfloat16 definitions everywhere there are bool definitions, e.g. for masked_fill. Is there a reason those aren't enabled?
This PR is already big enough; more functionality will come in follow-up PRs.
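To make the reviewer's point concrete: in these declaration files, per-type support is enabled by adding flags to each entry. A hypothetical sketch (not the actual `masked_fill` declaration) of an op that carries the bool flags but not yet the bfloat16 ones:

```yaml
# Hypothetical sketch; the real masked_fill_ entry may look different.
- name: masked_fill_
  cpu_bool: True
  cuda_bool: True
  # No cpu_bfloat16 / cuda_bfloat16 flags, so the codegen gating above
  # discards BFloat16 for this op until a follow-up PR adds them.
```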
```diff
 void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
-  AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "add_cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "add_cuda", [&]() {
```
why are we doing math here? I thought this PR was for adding non-math support?
This is needed for test/common_utils.py: the way we compare two tensors requires some math ops.
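For reference, a sketch of what the full kernel looks like with the two-extra-types macro. `AT_DISPATCH_ALL_TYPES_AND2` instantiates the lambda for all standard types plus the two extras named first (here `kHalf` and `kBFloat16`); the kernel body below is an assumption modeled on ATen's binary-op kernels of this period, not a verbatim copy of the PR:

```cpp
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

namespace at { namespace native {

void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
  // AND2 extends the standard dispatch type list with two extra scalar
  // types, so the same lambda is now also compiled for Half and BFloat16.
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "add_cuda", [&]() {
    auto alpha = alpha_scalar.to<scalar_t>();
    gpu_kernel_with_scalars(iter, [alpha] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
      return a + alpha * b;
    });
  });
}

}} // namespace at::native
```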
aten/src/THC/CMakeLists.txt
```cmake
  LIST(APPEND extra_src "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}Bool.cu")
endforeach()

foreach(THC_FILE TensorMathReduce TensorMathCompareT TensorMathCompare TensorMathPointwise)
```
is there a reason these files are in different orders in the various lists? It makes it hard to quickly figure out what files are supported for which types.
Sorted.
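For illustration, here is what consistently ordered lists could look like, so per-type coverage can be compared at a glance. The file names come from the diff above; treating these four as the complete contents of both lists is an assumption:

```cmake
# Same alphabetical order in every list makes it easy to see which
# generated sources exist for Bool vs. BFloat16.
foreach(THC_FILE TensorMathCompare TensorMathCompareT TensorMathPointwise TensorMathReduce)
  LIST(APPEND extra_src "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}Bool.cu")
endforeach()

foreach(THC_FILE TensorMathCompare TensorMathCompareT TensorMathPointwise TensorMathReduce)
  LIST(APPEND extra_src "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}BFloat16.cu")
endforeach()
```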
Enable test_torch BFloat16 tests for cuda gh-metadata: pytorch pytorch 22428 gh/izdeby/13/head
Stack from ghstack:
Differential Revision: D16083575