Skip to content

Commit 524574a

Browse files
lhuang04facebook-github-bot
authored andcommitted
Define THPStorage struct only once (rather than N times) (#14802)
Summary: Pull Request resolved: #14802 The definetion of THPStorage does not depend on any Real, its macro defintion is unnecessary, refactor the code so that THPStorage is not macro defined. Reviewed By: ezyang Differential Revision: D13340445 fbshipit-source-id: 343393d0a36c868b9a06eea2ad9b80f5e395e947
1 parent ca6311d commit 524574a

File tree

10 files changed

+39
-43
lines changed

10 files changed

+39
-43
lines changed

torch/csrc/Storage.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#ifndef THP_STORAGE_INC
22
#define THP_STORAGE_INC
33

4-
#define THPStorage TH_CONCAT_3(THP,Real,Storage)
54
#define THPStorageStr TH_CONCAT_STRING_3(torch.,Real,Storage)
65
#define THPStorageClass TH_CONCAT_3(THP,Real,StorageClass)
76
#define THPStorage_(NAME) TH_CONCAT_4(THP,Real,Storage_,NAME)

torch/csrc/StorageDefs.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
struct THPStorage {
3+
PyObject_HEAD
4+
THWStorage *cdata;
5+
};

torch/csrc/cuda/Storage.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#ifndef THCP_STORAGE_INC
22
#define THCP_STORAGE_INC
33

4-
#define THCPStorage TH_CONCAT_3(THCP,Real,Storage)
54
#define THCPStorageStr TH_CONCAT_STRING_3(torch.cuda.,Real,Storage)
65
#define THCPStorageClass TH_CONCAT_3(THCP,Real,StorageClass)
76
#define THCPStorage_(NAME) TH_CONCAT_4(THCP,Real,Storage_,NAME)

torch/csrc/cuda/override_macros.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#define THWTensor_(NAME) THCTensor_(NAME)
1212

1313
#define THPStorage_(NAME) TH_CONCAT_4(THCP,Real,Storage_,NAME)
14-
#define THPStorage THCPStorage
1514
#define THPStorageBaseStr THCPStorageBaseStr
1615
#define THPStorageStr THCPStorageStr
1716
#define THPStorageClass THCPStorageClass

torch/csrc/cuda/restore_macros.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#define THPTensorClass TH_CONCAT_3(THP,Real,TensorClass)
88
#define THPTensor_(NAME) TH_CONCAT_4(THP,Real,Tensor_,NAME)
99

10-
#define THPStorage TH_CONCAT_3(THP,Real,Storage)
1110
#define THPStorageStr TH_CONCAT_STRING_3(torch.,Real,Storage)
1211
#define THPStorageClass TH_CONCAT_3(THP,Real,StorageClass)
1312
#define THPStorage_(NAME) TH_CONCAT_4(THP,Real,Storage_,NAME)

torch/csrc/cuda/undef_macros.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#undef THPTensorType
1717

1818
#undef THPStorage_
19-
#undef THPStorage
2019
#undef THPStorageBaseStr
2120
#undef THPStorageStr
2221
#undef THPStorageClass

torch/csrc/generic/Storage.cpp

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -291,37 +291,37 @@ void THPStorage_(initCopyMethods)()
291291
#ifndef THD_GENERIC_FILE
292292
auto& h = THWStorage_(copy_functions);
293293
// copy from CPU types
294-
THPInsertStorageCopyFunction<THPStorage, THPByteStorage>(&THPByteStorageType, h, &THWStorage_(copyByte));
295-
THPInsertStorageCopyFunction<THPStorage, THPCharStorage>(&THPCharStorageType, h, &THWStorage_(copyChar));
296-
THPInsertStorageCopyFunction<THPStorage, THPShortStorage>(&THPShortStorageType, h, &THWStorage_(copyShort));
297-
THPInsertStorageCopyFunction<THPStorage, THPIntStorage>(&THPIntStorageType, h, &THWStorage_(copyInt));
298-
THPInsertStorageCopyFunction<THPStorage, THPLongStorage>(&THPLongStorageType, h, &THWStorage_(copyLong));
299-
THPInsertStorageCopyFunction<THPStorage, THPHalfStorage>(&THPHalfStorageType, h, &THWStorage_(copyHalf));
300-
THPInsertStorageCopyFunction<THPStorage, THPFloatStorage>(&THPFloatStorageType, h, &THWStorage_(copyFloat));
301-
THPInsertStorageCopyFunction<THPStorage, THPDoubleStorage>(&THPDoubleStorageType, h, &THWStorage_(copyDouble));
294+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPByteStorageType, h, &THWStorage_(copyByte));
295+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPCharStorageType, h, &THWStorage_(copyChar));
296+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPShortStorageType, h, &THWStorage_(copyShort));
297+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPIntStorageType, h, &THWStorage_(copyInt));
298+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPLongStorageType, h, &THWStorage_(copyLong));
299+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPHalfStorageType, h, &THWStorage_(copyHalf));
300+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPFloatStorageType, h, &THWStorage_(copyFloat));
301+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPDoubleStorageType, h, &THWStorage_(copyDouble));
302302
#ifdef THC_GENERIC_FILE
303303
// copy from GPU types
304-
THPInsertStorageCopyFunction<THPStorage, THCPByteStorage>(&THCPByteStorageType, h, &THWStorage_(copyCudaByte));
305-
THPInsertStorageCopyFunction<THPStorage, THCPCharStorage>(&THCPCharStorageType, h, &THWStorage_(copyCudaChar));
306-
THPInsertStorageCopyFunction<THPStorage, THCPShortStorage>(&THCPShortStorageType, h, &THWStorage_(copyCudaShort));
307-
THPInsertStorageCopyFunction<THPStorage, THCPIntStorage>(&THCPIntStorageType, h, &THWStorage_(copyCudaInt));
308-
THPInsertStorageCopyFunction<THPStorage, THCPLongStorage>(&THCPLongStorageType, h, &THWStorage_(copyCudaLong));
309-
THPInsertStorageCopyFunction<THPStorage, THCPFloatStorage>(&THCPFloatStorageType, h, &THWStorage_(copyCudaFloat));
310-
THPInsertStorageCopyFunction<THPStorage, THCPDoubleStorage>(&THCPDoubleStorageType, h, &THWStorage_(copyCudaDouble));
311-
THPInsertStorageCopyFunction<THPStorage, THCPHalfStorage>(&THCPHalfStorageType, h, &THWStorage_(copyCudaHalf));
304+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPByteStorageType, h, &THWStorage_(copyCudaByte));
305+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPCharStorageType, h, &THWStorage_(copyCudaChar));
306+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPShortStorageType, h, &THWStorage_(copyCudaShort));
307+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPIntStorageType, h, &THWStorage_(copyCudaInt));
308+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPLongStorageType, h, &THWStorage_(copyCudaLong));
309+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPFloatStorageType, h, &THWStorage_(copyCudaFloat));
310+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, h, &THWStorage_(copyCudaDouble));
311+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, h, &THWStorage_(copyCudaHalf));
312312
// add CPU <- GPU copies to base type
313-
#define THPCpuStorage TH_CONCAT_3(THP, Real, Storage)
313+
/// #define THPCpuStorage TH_CONCAT_3(THP, Real, Storage)
314314
#define THCpuStorage_(name) TH_CONCAT_4(TH, Real, Storage_, name)
315315
extern THPCopyList THCpuStorage_(copy_functions);
316316
auto& b = THCpuStorage_(copy_functions);
317-
THPInsertStorageCopyFunction<THPCpuStorage, THCPByteStorage>(&THCPByteStorageType, b, &THCpuStorage_(copyCudaByte));
318-
THPInsertStorageCopyFunction<THPCpuStorage, THCPCharStorage>(&THCPCharStorageType, b, &THCpuStorage_(copyCudaChar));
319-
THPInsertStorageCopyFunction<THPCpuStorage, THCPShortStorage>(&THCPShortStorageType, b, &THCpuStorage_(copyCudaShort));
320-
THPInsertStorageCopyFunction<THPCpuStorage, THCPIntStorage>(&THCPIntStorageType, b, &THCpuStorage_(copyCudaInt));
321-
THPInsertStorageCopyFunction<THPCpuStorage, THCPLongStorage>(&THCPLongStorageType, b, &THCpuStorage_(copyCudaLong));
322-
THPInsertStorageCopyFunction<THPCpuStorage, THCPFloatStorage>(&THCPFloatStorageType, b, &THCpuStorage_(copyCudaFloat));
323-
THPInsertStorageCopyFunction<THPCpuStorage, THCPDoubleStorage>(&THCPDoubleStorageType, b, &THCpuStorage_(copyCudaDouble));
324-
THPInsertStorageCopyFunction<THPCpuStorage, THCPHalfStorage>(&THCPHalfStorageType, b, &THCpuStorage_(copyCudaHalf));
317+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPByteStorageType, b, &THCpuStorage_(copyCudaByte));
318+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPCharStorageType, b, &THCpuStorage_(copyCudaChar));
319+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPShortStorageType, b, &THCpuStorage_(copyCudaShort));
320+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPIntStorageType, b, &THCpuStorage_(copyCudaInt));
321+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPLongStorageType, b, &THCpuStorage_(copyCudaLong));
322+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPFloatStorageType, b, &THCpuStorage_(copyCudaFloat));
323+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, b, &THCpuStorage_(copyCudaDouble));
324+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, b, &THCpuStorage_(copyCudaHalf));
325325
#undef THCpuStorage
326326
#undef THCpuStorage_
327327
#endif

torch/csrc/generic/Storage.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22
#define TH_GENERIC_FILE "generic/Storage.h"
33
#else
44

5-
struct THPStorage {
6-
PyObject_HEAD
7-
THWStorage *cdata;
8-
};
5+
#include "torch/csrc/StorageDefs.h"
96

107
THP_API PyObject * THPStorage_(New)(THWStorage *ptr);
118
extern PyObject *THPStorageClass;

torch/csrc/generic/utils.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,6 @@
77
#else
88
#define GENERATE_SPARSE 1
99
#endif
10-
11-
template<>
12-
void THPPointer<THPStorage>::free() {
13-
if (ptr)
14-
Py_DECREF(ptr);
15-
}
16-
17-
template class THPPointer<THPStorage>;
18-
1910
#undef GENERATE_SPARSE
2011

2112
#endif

torch/csrc/utils.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,11 @@ void THPPointer<THTensor>::free() {
234234
THTensor_free(LIBRARY_STATE ptr);
235235
}
236236
}
237+
238+
template<>
239+
void THPPointer<THPStorage>::free() {
240+
if (ptr)
241+
Py_DECREF(ptr);
242+
}
243+
244+
template class THPPointer<THPStorage>;

0 commit comments

Comments
 (0)