Skip to content

Commit b36a041

Browse files
Roy Lifacebook-github-bot
authored andcommitted
Move UnsafeTensorFromTH and UnsafeStorageFromTH off Type (#21923)
Summary: Pull Request resolved: #21923 ghimport-source-id: f015c85 Test Plan: Imported from OSS Differential Revision: D15883390 Pulled By: li-roy fbshipit-source-id: 6a7a7ffbe6000199d41cdca5efb97371f46dd8fe
1 parent 5d7cf66 commit b36a041

File tree

13 files changed

+29
-45
lines changed

13 files changed

+29
-45
lines changed

aten/src/ATen/ATen.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@
2525
#include <c10/core/TensorOptions.h>
2626
#include <c10/util/Exception.h>
2727
#include <ATen/core/ATenDispatch.h>
28+
#include <ATen/core/UnsafeFromTH.h>

aten/src/ATen/UndefinedType.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,6 @@ Device UndefinedType::getDeviceFromPtr(void*) const {
1717
AT_ERROR("getDeviceFromPtr not defined for UndefinedType");
1818
}
1919

20-
Storage UndefinedType::unsafeStorageFromTH(void * th_pointer, bool retain) const {
21-
AT_ERROR("unsafeStorageFromTH not defined for UndefinedType");
22-
}
23-
Tensor UndefinedType::unsafeTensorFromTH(void * th_pointer, bool retain) const {
24-
AT_ERROR("unsafeTensorFromTH not defined for UndefinedType");
25-
}
26-
2720
const char * UndefinedType::toString() const {
2821
return "UndefinedType";
2922
}

aten/src/ATen/UndefinedType.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ struct UndefinedType final : public TypeDefault {
2020
virtual Type & toBackend(Backend b) const override;
2121
virtual Type & toScalarType(ScalarType s) const override;
2222
virtual TypeID ID() const override;
23-
virtual Storage unsafeStorageFromTH(void * th_pointer, bool retain) const override;
24-
virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const override;
2523
};
2624

2725
} // namespace at

aten/src/ATen/core/DeprecatedTypeProperties.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22

33
#include <ATen/core/LegacyTypeDispatch.h>
44
#include <ATen/core/Tensor.h>
5+
#include <ATen/core/UnsafeFromTH.h>
56

67
namespace at {
78

89
Tensor DeprecatedTypeProperties::unsafeTensorFromTH(void * th_pointer, bool retain) const {
9-
return getDispatchType().unsafeTensorFromTH(th_pointer, retain);
10+
return at::unsafeTensorFromTH(th_pointer, retain);
1011
}
1112

1213
Storage DeprecatedTypeProperties::unsafeStorageFromTH(void * th_pointer, bool retain) const {
13-
return getDispatchType().unsafeStorageFromTH(th_pointer, retain);
14+
return at::unsafeStorageFromTH(th_pointer, retain);
1415
}
1516

1617
Tensor DeprecatedTypeProperties::copy(const Tensor & src, bool non_blocking, c10::optional<Device> to_device) const {

aten/src/ATen/core/Type.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,6 @@ struct CAFFE2_API Type {
8181
bool is_undefined() const noexcept { return is_undefined_; }
8282
virtual Allocator * allocator() const = 0;
8383
virtual Device getDeviceFromPtr(void * data) const = 0;
84-
virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const = 0;
85-
virtual Storage unsafeStorageFromTH(void * th_pointer, bool retain) const = 0;
8684
virtual const char * toString() const = 0;
8785
virtual Type & toBackend(Backend b) const = 0;
8886
virtual Type & toScalarType(ScalarType s) const = 0;

aten/src/ATen/core/UnsafeFromTH.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#include <ATen/core/Tensor.h>
2+
3+
namespace at {
4+
5+
inline Tensor unsafeTensorFromTH(void * th_pointer, bool retain) {
6+
auto tensor_impl = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(static_cast<TensorImpl*>(th_pointer));
7+
if (retain && tensor_impl.get() != UndefinedTensorImpl::singleton()) {
8+
c10::raw::intrusive_ptr::incref(tensor_impl.get());
9+
}
10+
return Tensor(std::move(tensor_impl));
11+
}
12+
13+
inline Storage unsafeStorageFromTH(void * th_pointer, bool retain) {
14+
if (retain && th_pointer) {
15+
c10::raw::intrusive_ptr::incref(static_cast<StorageImpl*>(th_pointer));
16+
}
17+
return Storage(c10::intrusive_ptr<StorageImpl>::reclaim(static_cast<StorageImpl*>(th_pointer)));
18+
}
19+
20+
}

aten/src/ATen/templates/Type.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,6 @@ struct CAFFE2_API Type {
7474
bool is_undefined() const noexcept { return is_undefined_; }
7575
virtual Allocator * allocator() const = 0;
7676
virtual Device getDeviceFromPtr(void * data) const = 0;
77-
virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const = 0;
78-
virtual Storage unsafeStorageFromTH(void * th_pointer, bool retain) const = 0;
7977
virtual const char * toString() const = 0;
8078
virtual Type & toBackend(Backend b) const = 0;
8179
virtual Type & toScalarType(ScalarType s) const = 0;

aten/src/ATen/templates/TypeDefault.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,6 @@ Type & TypeDefault::toScalarType(ScalarType s) const {
3838
return at::globalContext().getNonVariableType(backend(),s);
3939
}
4040

41-
Tensor TypeDefault::unsafeTensorFromTH(void * th_pointer, bool retain) const {
42-
auto tensor_impl = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(static_cast<TensorImpl*>(th_pointer));
43-
if (retain && tensor_impl.get() != UndefinedTensorImpl::singleton()) {
44-
c10::raw::intrusive_ptr::incref(tensor_impl.get());
45-
}
46-
return Tensor(std::move(tensor_impl));
47-
}
48-
Storage TypeDefault::unsafeStorageFromTH(void * th_pointer, bool retain) const {
49-
if (retain && th_pointer) {
50-
c10::raw::intrusive_ptr::incref(static_cast<StorageImpl*>(th_pointer));
51-
}
52-
return Storage(c10::intrusive_ptr<StorageImpl>::reclaim(static_cast<StorageImpl*>(th_pointer)));
53-
}
54-
5541
${type_method_definitions}
5642

5743
static auto& registerer = globalATenDispatch()

aten/src/ATen/templates/TypeDefault.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,6 @@ struct CAFFE2_API TypeDefault : public TypeExtendedInterface {
3838
bool create_graph) const override;
3939
void set_data(Tensor & self, Tensor new_data) const override;
4040

41-
Storage unsafeStorageFromTH(void * th_pointer, bool retain) const override;
42-
Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const override;
43-
4441
// example
4542
// virtual Tensor * add(Tensor & a, Tensor & b) = 0;
4643
${type_method_declarations}

aten/src/ATen/test/basic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ void TestTensorFromTH() {
207207
int a = 4;
208208
THFloatTensor* t = THFloatTensor_newWithSize2d(a, a);
209209
THFloatTensor_fill(t, a);
210-
ASSERT_NO_THROW(CPU(kFloat).unsafeTensorFromTH(t, false));
210+
ASSERT_NO_THROW(at::unsafeTensorFromTH(t, false));
211211
}
212212

213213
void TestToCFloat() {

0 commit comments

Comments
 (0)