
Commit 3a63a93

mrshenli authored and facebook-github-bot committed
Revert D22517785: [pytorch][PR] Enable TF32 support for cuBLAS
Test Plan: revert-hammer

Differential Revision: D22517785 (288ece8)

Original commit changeset: 87334c893561

fbshipit-source-id: 0a0674f49c1bcfc98f7f88af5a8c7de93b76e458
1 parent 8548a21 commit 3a63a93

12 files changed: +25, -248 lines

aten/src/ATen/Context.cpp

Lines changed: 0 additions & 8 deletions
@@ -86,14 +86,6 @@ void Context::setBenchmarkCuDNN(bool b) {
   benchmark_cudnn = b;
 }
 
-bool Context::allowTF32CuBLAS() const {
-  return allow_tf32_cublas;
-}
-
-void Context::setAllowTF32CuBLAS(bool b) {
-  allow_tf32_cublas = b;
-}
-
 bool Context::hasMKL() const {
 #if AT_MKL_ENABLED()
   return true;

aten/src/ATen/Context.h

Lines changed: 0 additions & 3 deletions
@@ -109,8 +109,6 @@ class CAFFE2_API Context {
   bool deterministic() const;
   void setDeterministic(bool);
   void alertNotDeterministic(c10::string_view const& caller);
-  bool allowTF32CuBLAS() const;
-  void setAllowTF32CuBLAS(bool);
   at::QEngine qEngine() const;
   void setQEngine(at::QEngine e);
   const std::vector<at::QEngine>& supportedQEngines() const;
@@ -138,7 +136,6 @@ class CAFFE2_API Context {
   bool deterministic_cudnn = false;
   bool _deterministic = false;
   bool benchmark_cudnn = false;
-  bool allow_tf32_cublas = true;
   bool enabled_mkldnn = true;
 #ifdef C10_MOBILE
   bool release_original_weights = true;

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 0 additions & 8 deletions
@@ -233,11 +233,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
 #else
   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
   if (prop->major >= 5) {
-#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
-    // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
-    // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
     TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#endif // CUDA_VERSION < 11000
     TORCH_CUDABLAS_CHECK(cublasGemmEx(
         handle,
         opa,
@@ -258,11 +254,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
         ldc,
         CUDA_R_32F,
         CUBLAS_GEMM_DFALT_TENSOR_OP));
-#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
-    // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
-    // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
     TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
-#endif // CUDA_VERSION < 11000
   } else {
     TORCH_CUDABLAS_CHECK(cublasSgemmEx(
         handle,

aten/src/ATen/cuda/CublasHandlePool.cpp

Lines changed: 0 additions & 10 deletions
@@ -41,16 +41,6 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   auto handle = myPoolWindow->reserve(device);
   auto stream = c10::cuda::getCurrentCUDAStream();
   TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
-#if CUDA_VERSION >= 11000
-  // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
-  // FP32 data type calculations based on the value of the allow_tf32 flag.
-  // To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
-  if (at::globalContext().allowTF32CuBLAS()) {
-    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
-  } else {
-    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
-  }
-#endif
   return handle;
 }
 

aten/src/ATen/native/cuda/MiscUtils.h

Lines changed: 1 addition & 17 deletions
@@ -25,17 +25,10 @@ struct MAGMAQueue {
   // Constructor
   explicit MAGMAQueue(int64_t device_id) {
     auto& context = at::globalContext();
-    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-#if CUDA_VERSION >= 11000
-    // Magma operations is numerically sensitive, so TF32 should be off
-    // regardless of the global flag.
-    TORCH_CUDABLAS_CHECK(cublasGetMathMode(handle, &original_math_mode));
-    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
-#endif
     magma_queue_create_from_cuda(
       device_id,
       at::cuda::getCurrentCUDAStream(),
-      handle,
+      at::cuda::getCurrentCUDABlasHandle(),
       at::cuda::getCurrentCUDASparseHandle(),
       &magma_queue_);
   }
@@ -45,20 +38,11 @@ struct MAGMAQueue {
 
   // Destructor
   ~MAGMAQueue() {
-#if CUDA_VERSION >= 11000
-    // We've manually set the math mode to CUBLAS_DEFAULT_MATH, now we
-    // should restore the original math mode back
-    cublasHandle_t handle = magma_queue_get_cublas_handle(magma_queue_);
-    cublasSetMathMode(handle, original_math_mode);
-#endif
     magma_queue_destroy(magma_queue_);
   }
 
 private:
   magma_queue_t magma_queue_;
-#if CUDA_VERSION >= 11000
-  cublasMath_t original_math_mode;
-#endif
 };
 
 static inline magma_int_t magma_int_cast(int64_t value, const char* varname) {
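
The MAGMAQueue change above removes a save/disable/restore pattern around numerically sensitive MAGMA calls. On a build that still carries the reverted torch.backends.cuda.matmul.allow_tf32 flag (documented and tested in the diffs below), the same pattern can be sketched at the Python level; this is an illustration of the idea, not code from this commit:

    from contextlib import contextmanager

    import torch

    @contextmanager
    def tf32_disabled():
        # Save the current setting, force full-FP32 matmul, and restore on exit,
        # mirroring the cublasGetMathMode/cublasSetMathMode pairing removed above.
        # Assumes a build that still exposes the reverted allow_tf32 flag.
        saved = torch.backends.cuda.matmul.allow_tf32
        torch.backends.cuda.matmul.allow_tf32 = False
        try:
            yield
        finally:
            torch.backends.cuda.matmul.allow_tf32 = saved

A caller would wrap a precision-critical block in "with tf32_disabled():".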

aten/src/THC/THCBlas.cu

Lines changed: 0 additions & 8 deletions
@@ -185,22 +185,14 @@ void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, i
                  (int) batchCount, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
                  0, 0));
 #else
-#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
-  // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
-  // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
   THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#endif // CUDA_VERSION < 11000
   THCublasCheck(cublasGemmStridedBatchedEx(handle,
                  opa, opb, (int)m, (int)n, (int)k,
                  (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
                  b, CUDA_R_16F, (int)ldb, strideB,
                  (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,
                  (int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
-  // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
-  // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
   THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
-#endif // CUDA_VERSION < 11000
 #endif // __HIP_PLATFORM_HCC__
 }
 #endif // CUDA_VERSION or __HIP_PLATFORM_HCC__

docs/source/notes/cuda.rst

Lines changed: 0 additions & 65 deletions
@@ -54,71 +54,6 @@ Below you can find a small example showcasing this::
     f = torch.randn(2).cuda(cuda2)
     # d.device, e.device, and f.device are all device(type='cuda', index=2)
 
-.. _tf32_on_ampere:
-
-TensorFloat-32(TF32) on Ampere devices
---------------------------------------
-
-Starting in PyTorch 1.7, there is a new flag called `allow_tf32` which defaults to true.
-This flag controls whether PyTorch is allowed to use the TensorFloat32 (TF32) tensor cores,
-available on new NVIDIA GPUs since Ampere, internally to compute matmul (matrix multiplies
-and batched matrix multiplies) and convolutions.
-
-TF32 tensor cores are designed to achieve better performance on matmul and convolutions on
-`torch.float32` tensors by truncating input data to have 10 bits of mantissa, and accumulating
-results with FP32 precision, maintaining FP32 dynamic range.
-
-matmul and convolutions are controlled separately, and their corresponding flag can be accessed at:
-
-.. code:: python
-
-  # The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
-  torch.backends.cuda.matmul.allow_tf32 = True
-
-  # The allow_tf32 flag for convolutions is not implemented yet
-
-To get an idea of the precision and speed, see the example code below:
-
-.. code:: python
-
-  a_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
-  b_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
-  ab_full = a_full @ b_full
-  mean = ab_full.abs().mean()  # 80.7277
-
-  a = a_full.float()
-  b = b_full.float()
-
-  # Do matmul at TF32 mode.
-  ab_tf32 = a @ b  # takes 0.016s on GA100
-  error = (ab_tf32 - ab_full).abs().max()  # 0.1747
-  relative_error = error / mean  # 0.0022
-
-  # Do matmul with TF32 disabled.
-  torch.backends.cuda.matmul.allow_tf32 = False
-  ab_fp32 = a @ b  # takes 0.11s on GA100
-  error = (ab_fp32 - ab_full).abs().max()  # 0.0031
-  relative_error = error / mean  # 0.000039
-
-From the above example, we can see that with TF32 enabled, the speed is ~7x faster, relative error
-compared to double precision is approximately 2 orders of magnitude larger. If the full FP32 precision
-is needed, users can disable TF32 by:
-
-.. code:: python
-
-  torch.backends.cuda.matmul.allow_tf32 = False
-  # disabling of TF32 for cuDNN is not implemented yet
-
-For more information about TF32, see:
-
-- `TensorFloat-32`_
-- `CUDA 11`_
-- `Ampere architecture`_
-
-.. _TensorFloat-32: https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
-.. _CUDA 11: https://devblogs.nvidia.com/cuda-11-features-revealed/
-.. _Ampere architecture: https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/
-
 Asynchronous execution
 ----------------------
 
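
The removed note above describes TF32 as keeping FP32's dynamic range (8-bit exponent) while truncating inputs to 10 bits of mantissa. A rough, CPU-only emulation of that truncation step, purely to illustrate the precision described in the deleted docs (the real conversion happens inside the tensor cores and may round rather than truncate):

    import struct

    def tf32_truncate(x: float) -> float:
        # Reinterpret the float32 bit pattern and zero the low 13 of its 23
        # mantissa bits, leaving the 10 explicit mantissa bits TF32 keeps.
        bits = struct.unpack("<I", struct.pack("<f", x))[0]
        bits &= ~((1 << 13) - 1)
        return struct.unpack("<f", struct.pack("<I", bits))[0]

    print(tf32_truncate(1.0001))  # 1.0          -- detail below the 2**-10 step is lost
    print(tf32_truncate(1.001))   # 1.0009765625 -- the nearest lower multiple of 2**-10

This is consistent with the removed example's quoted relative error of about 2e-3 with TF32 enabled versus about 4e-5 with it disabled.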

test/test_cuda.py

Lines changed: 0 additions & 7 deletions
@@ -529,13 +529,6 @@ def test_serialization_array_with_storage(self):
         q_copy[1].fill_(10)
         self.assertTrue(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
 
-    def test_allow_tf32_get_set(self):
-        orig = torch.backends.cuda.matmul.allow_tf32
-        self.assertEqual(torch._C._get_cublas_allow_tf32(), orig)
-        torch.backends.cuda.matmul.allow_tf32 = not orig
-        self.assertEqual(torch._C._get_cublas_allow_tf32(), not orig)
-        torch.backends.cuda.matmul.allow_tf32 = orig
-
     def test_type_conversions(self):
         x = torch.randn(5, 5)
         self.assertIsInstance(x.float(), torch.FloatTensor)

test/test_torch.py

Lines changed: 11 additions & 21 deletions
@@ -37,7 +37,6 @@
 from typing import Dict, List, Tuple, Union
 import torch.backends.quantized
 import torch.testing._internal.data
-from torch.testing._internal.common_cuda import tf32_on_and_off
 
 
 # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
@@ -8441,7 +8440,6 @@ def dims_full_for_fn():
         r1 = fntorch(t0_full, t1, t2)
         self.assertEqual(r0, r1)
 
-    @tf32_on_and_off(0.001)
     def test_broadcast_batched_matmul(self, device):
         n_dim = random.randint(1, 8)
         m_dim = random.randint(1, 8)
@@ -10431,7 +10429,6 @@ def check_norm(a, b, expected_norm, gels_result):
 
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
-    @tf32_on_and_off(0.001)
     def test_qr(self, device):
         def run_test(tensor_dims, some):
             A = torch.randn(*tensor_dims, device=device)
@@ -11511,7 +11508,6 @@ def test_cdist_norm_batch(self, device):
             expected = self._brute_cdist(x, y, p=p)
             self.assertEqual(expected, actual)
 
-    @tf32_on_and_off(0.005)
     def test_cdist_large(self, device):
         for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
             x = torch.randn(1000, 10, device=device)
@@ -11521,7 +11517,6 @@ def test_cdist_large(self, device):
             self.assertEqual(expected, actual)
 
     @slowTest
-    @tf32_on_and_off(0.01)
     def test_cdist_large_batch(self, device):
         for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
             x = torch.randn(4, 3, 1000, 10, device=device)
@@ -11530,7 +11525,6 @@ def test_cdist_large_batch(self, device):
             expected = self._brute_cdist(x, y, p=2)
             self.assertEqual(expected, actual)
 
-    @tf32_on_and_off(0.005)
     def test_cdist_non_contiguous(self, device):
         for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
             x = torch.randn(5, 7, device=device).transpose(-1, -2)
@@ -11557,7 +11551,6 @@ def test_cdist_non_contiguous(self, device):
             self.assertTrue(y.is_contiguous())
             self.assertEqual(expected, actual)
 
-    @tf32_on_and_off()
     def test_cdist_non_contiguous_batch(self, device):
         for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
             x = torch.randn(4, 3, 2, 5, 7, device=device).transpose(-1, -2)
@@ -12394,7 +12387,6 @@ def test_empty_tensor_props(self, device):
             self.assertEqual(x.stride(), y.stride())
 
     @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
-    @tf32_on_and_off(0.005)
     def test_tensordot(self, device):
         a = torch.arange(60., device=device).reshape(3, 4, 5)
         b = torch.arange(24., device=device).reshape(4, 3, 2)
@@ -16478,7 +16470,6 @@ def test_addmm(self, device):
     @dtypes(torch.float, torch.double)
     @dtypesIfCUDA(*([torch.float, torch.double] +
                     ([] if TEST_WITH_ROCM else torch.testing.get_all_complex_dtypes())))
-    @tf32_on_and_off(0.005)
     def test_addmm_sizes(self, device, dtype):
         for m in [0, 1, 25]:
             for n in [0, 1, 10]:
@@ -16928,7 +16919,6 @@ def test_remainder_edge_cases(self, device, dtype):
     @onlyOnCPUAndCUDA
     @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.int32, torch.int64, torch.cfloat, torch.cdouble)
     @dtypesIfCUDA(torch.float32, torch.float64)
-    @tf32_on_and_off(0.01)
     def test_mm(self, device, dtype):
         def _test_mm(n, m, p, dtype, genf):
             # helper function
@@ -17974,7 +17964,6 @@ def test_pickle_gradscaler(self, device):
             self.assertEqual(b.scale(torch.tensor([4.0], dtype=torch.float32, device=device)), 12.0)
 
     @onlyCUDA
-    @tf32_on_and_off(0.005)
     def test_mv_stride_0(self, device):
         # Reference: https://github.com/pytorch/pytorch/issues/38315
         mat = torch.randn(2, 2, device=device)
@@ -18930,6 +18919,8 @@ def test_split_view(self, device):
 
 _float_types_no_half = [torch.float, torch.double]
 
+_complex_types = [torch.cfloat, torch.cdouble]
+
 # _float_types2 adds bfloat16 type to _float_types only on ROCm. Should eventually be unified
 # with _float_types when bfloat16 bringup is complete on all platforms
 _float_types2 = _float_types + [torch.bfloat16] if TEST_WITH_ROCM else _float_types
@@ -19104,13 +19095,13 @@ def inner(self, device, dtype):
     ('pow', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d).abs()],
         1e-1, 1e-1, 1e-5, _float_types2),
     ('addbmm', '', _small_2d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
-        1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True, [tf32_on_and_off(0.005)]),
+        1e-1, 1e-1, 1e-4, _float_types2),
    ('addbmm', 'scalar', _small_2d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
         1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
     ('addbmm', 'two_scalars', _small_2d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
         1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
     ('baddbmm', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
         1e-2, 1e-1, 1e-4, _float_types2),
     ('baddbmm', 'scalar', _small_3d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
@@ -19135,26 +19126,25 @@ def inner(self, device, dtype):
         1e-1, 1e-5, _types2, _cpu_types, True,
         [_wrap_maybe_warns("This overload of addcmul_? is deprecated")]),
     ('addmm', '', _medium_2d, lambda t, d: [_medium_2d(t, d), _medium_2d(t, d)],
-        1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True, [tf32_on_and_off(0.005)]),
+        1e-1, 1e-1, 1e-4, _float_types2),
     ('addmm', 'scalar', _medium_2d,
         lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)],
         1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addmm_? is deprecated")]),
     ('addmm', 'two_scalars', _medium_2d,
         lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)],
         1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addmm_? is deprecated")]),
     ('addmv', '', _medium_1d, lambda t, d: [_medium_2d(t, d), _medium_1d(t, d)],
-        1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm, _cpu_types,
-        True, [tf32_on_and_off(0.005)]),
+        1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm),
     ('addmv', 'scalar', _medium_1d,
         lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)],
         1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmv_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addmv_? is deprecated")]),
     ('addmv', 'two_scalars', _medium_1d,
         lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)],
         1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmv_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addmv_? is deprecated")]),
     ('addr', '', _medium_2d, lambda t, d: [_medium_1d(t, d), _medium_1d(t, d)],
         1e-2, 1e-1, 1e-4, _float_types2),
     ('addr', 'scalar', _medium_2d,

torch/backends/cuda/__init__.py

Lines changed: 13 additions & 12 deletions
@@ -83,15 +83,16 @@ def __setattr__(self, name, value):
         return super(cuFFTPlanCacheManager, self).__setattr__(name, value)
 
 
-class cuBLASModule:
-    def __getattr__(self, name):
-        assert name == "allow_tf32", "Unknown attribute " + name
-        return torch._C._get_cublas_allow_tf32()
-
-    def __setattr__(self, name, value):
-        assert name == "allow_tf32", "Unknown attribute " + name
-        return torch._C._set_cublas_allow_tf32(value)
-
-
-cufft_plan_cache = cuFFTPlanCacheManager()
-matmul = cuBLASModule()
+class CUDAModule(object):
+    def __init__(self, m):
+        self.__dict__ = m.__dict__
+        # You have to retain the old module, otherwise it will
+        # get GC'ed and a lot of things will break. See:
+        # https://stackoverflow.com/questions/47540722/how-do-i-use-the-sys-modules-replacement-trick-in-init-py-on-python-2
+        self.__old_mod = m
+
+cufft_plan_cache = cuFFTPlanCacheManager()
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = CUDAModule(sys.modules[__name__])
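
The restored torch/backends/cuda/__init__.py relies on the sys.modules replacement trick referenced in its comments: the module object is swapped for a class instance so attribute access on the package can be customized. A minimal standalone sketch of the pattern (hypothetical package and attribute names, not PyTorch code):

    # __init__.py of a hypothetical package, e.g. "mypkg"
    import sys

    class _ModuleWrapper(object):
        def __init__(self, m):
            # Adopt the real module's namespace so existing globals keep working.
            self.__dict__ = m.__dict__
            # Keep a reference to the original module object; once sys.modules no
            # longer points at it, nothing else keeps it alive.
            self.__old_mod = m

        @property
        def answer(self):
            # Attribute access on the package now runs code instead of a plain lookup.
            return 42

    # Swap the module object in sys.modules for the wrapper instance.
    sys.modules[__name__] = _ModuleWrapper(sys.modules[__name__])

After this runs, "import mypkg; mypkg.answer" goes through the property. The removed cuBLASModule achieved a similar effect for matmul.allow_tf32 with __getattr__/__setattr__ on a plain object instead.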
