@@ -17,7 +17,7 @@ void addcmul_cuda_scalar_tensor2_kernel(
     const Scalar& value
 );
 
-#if AT_USE_JITERATOR()
+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
 constexpr char addcmul_name[] = "addcmul";
 #endif
 void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
@@ -37,7 +37,10 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
 
   auto dtype = iter.common_dtype();
   if (at::isComplexType(dtype)) {
-#if AT_USE_JITERATOR()
+    // When using Jiterator, addcmul and addcdiv kernels get stuck during a
+    // promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
+    // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
     AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
       auto alpha = value.to<scalar_t>();
       static const auto addcmul_string = jiterator_stringify(
@@ -90,14 +93,17 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
   }
 }
 
-#if AT_USE_JITERATOR()
+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
 constexpr char addcmul_scalar_tensor2_name[] = "addcmul_scalar_tensor2";
 #endif
 void addcmul_cuda_scalar_tensor2_kernel(TensorIteratorBase& iter, const Scalar& scalar_tensor2, const Scalar& value) {
   auto dtype = iter.common_dtype();
 
   if (at::isComplexType(dtype)) {
-#if AT_USE_JITERATOR()
+    // When using Jiterator, addcmul and addcdiv kernels get stuck during a
+    // promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
+    // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
     AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
       auto c = scalar_tensor2.to<scalar_t>();
       auto alpha = value.to<scalar_t>();
@@ -139,14 +145,17 @@ void addcmul_cuda_scalar_tensor2_kernel(TensorIteratorBase& iter, const Scalar&
   }
 }
 
-#if AT_USE_JITERATOR()
+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
 // return a + alpha * (b / static_cast<accscalar_t>(c));
 constexpr char addcdiv_name[] = "addcdiv";
 #endif
 void addcdiv_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
   auto dtype = iter.common_dtype();
   if (at::isComplexType(dtype)) {
-#if AT_USE_JITERATOR()
+    // When using Jiterator, addcmul and addcdiv kernels get stuck during a
+    // promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
+    // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
     AT_DISPATCH_COMPLEX_TYPES(dtype, "addcdiv_cuda", [&]() {
       auto alpha = value.to<scalar_t>();
       static const auto addcdiv_string =
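
For context, below is a minimal standalone sketch of the gating pattern this diff applies (an illustration under stated assumptions, not PyTorch's actual code: USE_JIT and kernel_name are hypothetical stand-ins for AT_USE_JITERATOR() and the jiterator kernel-name constants). CUDA_VERSION is defined by <cuda.h> as 1000 * major + 10 * minor, so 11050 means CUDA 11.5; the check therefore compiles the jiterated complex path only on 11.5+ and keeps the precompiled fallback on older toolkits:

#include <cuda.h>
#include <cstdio>

#define USE_JIT 1  // hypothetical stand-in for AT_USE_JITERATOR()

#if USE_JIT && CUDA_VERSION >= 11050
// On CUDA >= 11.5 the runtime-compiled (jiterated) complex kernel is enabled.
constexpr char kernel_name[] = "addcmul";
#endif

int main() {
#if USE_JIT && CUDA_VERSION >= 11050
  std::printf("jiterator path: JIT-compile '%s' via NVRTC\n", kernel_name);
#else
  // CUDA 11.3/11.4 builds fall back to the precompiled complex kernel, since
  // the jiterated kernels were observed to hang there (see linked comment).
  std::printf("fallback: precompiled complex kernel\n");
#endif
  return 0;
}

Note that the gate is resolved at build time by the preprocessor against the CUDA toolkit headers, not at runtime against the installed driver, which is why the fix is pinned to the toolkit version used to compile PyTorch.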