Skip to content

Commit fb23e62

Browse files
gchananezyang
authored andcommitted
Remove templatization of PyTypeObject in THP copy storage methods. (#7811)
* Remove templatization of PyTypeObject in THP copy storage methods. An in-progress refactoring of THStorage is collapsing the types of THStorages to not be ScalarType-specific. The revelant PyTypeObject to use for the THPStorageType is currently templatized based on the current THStorage; this doesn't work if the ScalarType is collapsed. Instead, just pass it explicitly. * Pass src type instead of dst type. * Line up columns.
1 parent 8b85b8a commit fb23e62

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

torch/csrc/copy_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ inline PyObject * THPStorageCopyMethod(const THPCopyList& v, PyObject *self, PyO
5656

5757
template <typename StorageDst, typename StorageSrc>
5858
void THPInsertStorageCopyFunction(
59+
PyTypeObject *srcType,
5960
THPCopyList& copyList,
6061
void (*copyFunc)(LIBRARY_STATE_TYPE StorageDst* x, StorageSrc* z),
6162
bool non_blocking=false)
@@ -77,6 +78,5 @@ void THPInsertStorageCopyFunction(
7778
}
7879
};
7980

80-
PyTypeObject* srcType = THPTypeInfo<StorageSrc>::pyType();
8181
copyList.push_back({ srcType, wrapper, non_blocking, false });
8282
}

torch/csrc/generic/Storage.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -327,39 +327,39 @@ void THPStorage_(initCopyMethods)()
327327
#ifndef THD_GENERIC_FILE
328328
auto& h = THStorage_(copy_functions);
329329
// copy from CPU types
330-
THPInsertStorageCopyFunction(h, &THStorage_(copyByte));
331-
THPInsertStorageCopyFunction(h, &THStorage_(copyChar));
332-
THPInsertStorageCopyFunction(h, &THStorage_(copyShort));
333-
THPInsertStorageCopyFunction(h, &THStorage_(copyInt));
334-
THPInsertStorageCopyFunction(h, &THStorage_(copyLong));
335-
THPInsertStorageCopyFunction(h, &THStorage_(copyHalf));
336-
THPInsertStorageCopyFunction(h, &THStorage_(copyFloat));
337-
THPInsertStorageCopyFunction(h, &THStorage_(copyDouble));
330+
THPInsertStorageCopyFunction(&THPByteStorageType, h, &THStorage_(copyByte));
331+
THPInsertStorageCopyFunction(&THPCharStorageType, h, &THStorage_(copyChar));
332+
THPInsertStorageCopyFunction(&THPShortStorageType, h, &THStorage_(copyShort));
333+
THPInsertStorageCopyFunction(&THPIntStorageType, h, &THStorage_(copyInt));
334+
THPInsertStorageCopyFunction(&THPLongStorageType, h, &THStorage_(copyLong));
335+
THPInsertStorageCopyFunction(&THPHalfStorageType, h, &THStorage_(copyHalf));
336+
THPInsertStorageCopyFunction(&THPFloatStorageType, h, &THStorage_(copyFloat));
337+
THPInsertStorageCopyFunction(&THPDoubleStorageType, h, &THStorage_(copyDouble));
338338
#ifdef THC_GENERIC_FILE
339339
// copy from GPU types
340-
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaByte));
341-
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaChar));
342-
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaShort));
343-
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaInt));
344-
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaLong));
345-
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaFloat));
346-
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaDouble));
340+
THPInsertStorageCopyFunction(&THCPByteStorageType, h, &THStorage_(copyCudaByte));
341+
THPInsertStorageCopyFunction(&THCPCharStorageType, h, &THStorage_(copyCudaChar));
342+
THPInsertStorageCopyFunction(&THCPShortStorageType, h, &THStorage_(copyCudaShort));
343+
THPInsertStorageCopyFunction(&THCPIntStorageType, h, &THStorage_(copyCudaInt));
344+
THPInsertStorageCopyFunction(&THCPLongStorageType, h, &THStorage_(copyCudaLong));
345+
THPInsertStorageCopyFunction(&THCPFloatStorageType, h, &THStorage_(copyCudaFloat));
346+
THPInsertStorageCopyFunction(&THCPDoubleStorageType, h, &THStorage_(copyCudaDouble));
347347
#ifdef CUDA_HALF_TENSOR
348-
THPInsertStorageCopyFunction(h, &THStorage_(copyCudaHalf));
348+
THPInsertStorageCopyFunction(&THCPHalfStorageType, h, &THStorage_(copyCudaHalf));
349349
#endif
350350
// add CPU <- GPU copies to base type
351351
#define THCpuStorage_(name) TH_CONCAT_4(TH, Real, Storage_, name)
352352
extern THPCopyList THCpuStorage_(copy_functions);
353353
auto& b = THCpuStorage_(copy_functions);
354-
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaByte));
355-
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaChar));
356-
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaShort));
357-
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaInt));
358-
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaLong));
359-
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaFloat));
360-
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaDouble));
354+
THPInsertStorageCopyFunction(&THCPByteStorageType, b, &THCpuStorage_(copyCudaByte));
355+
THPInsertStorageCopyFunction(&THCPCharStorageType, b, &THCpuStorage_(copyCudaChar));
356+
THPInsertStorageCopyFunction(&THCPShortStorageType, b, &THCpuStorage_(copyCudaShort));
357+
THPInsertStorageCopyFunction(&THCPIntStorageType, b, &THCpuStorage_(copyCudaInt));
358+
THPInsertStorageCopyFunction(&THCPLongStorageType, b, &THCpuStorage_(copyCudaLong));
359+
THPInsertStorageCopyFunction(&THCPFloatStorageType, b, &THCpuStorage_(copyCudaFloat));
360+
THPInsertStorageCopyFunction(&THCPDoubleStorageType, b, &THCpuStorage_(copyCudaDouble));
361361
#ifdef CUDA_HALF_TENSOR
362-
THPInsertStorageCopyFunction(b, &THCpuStorage_(copyCudaHalf));
362+
THPInsertStorageCopyFunction(&THCPHalfStorageType, b, &THCpuStorage_(copyCudaHalf));
363363
#endif
364364
#undef THCpuStorage_
365365
#endif

0 commit comments

Comments
 (0)