Skip to content

Commit 33221b1

Browse files
pbelevichfacebook-github-bot
authored andcommitted
C++ API parity: at::Tensor::data
Summary: Pull Request resolved: #26008 Test Plan: Imported from OSS Differential Revision: D17343488 Pulled By: pbelevich fbshipit-source-id: b9ba5e26cad621a428a14292446d7fb5a6e5535d
1 parent 5e2d25a commit 33221b1

File tree

10 files changed

+34
-5
lines changed

10 files changed

+34
-5
lines changed

aten/src/ATen/core/TensorBody.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ class CAFFE2_API Tensor {
394394
//Tensor * add(Tensor & b);
395395
void backward(const Tensor & gradient={}, bool keep_graph=false, bool create_graph=false) const;
396396
void set_data(const Tensor & new_data) const;
397+
Tensor data() const;
397398
#ifdef BUILD_NAMEDTENSOR
398399
Tensor & names_(c10::optional<DimnameList> names) const;
399400
#endif

aten/src/ATen/core/TensorMethods.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ inline void Tensor::set_data(const Tensor & new_data) const {
7373
return table->getOp<void (const Tensor &, const Tensor &)>(type_set())(const_cast<Tensor&>(*this), new_data);
7474
#endif
7575
}
76+
inline Tensor Tensor::data() const {
77+
#ifdef USE_STATIC_DISPATCH
78+
return TypeDefault::data(const_cast<Tensor&>(*this));
79+
#else
80+
static auto table = globalATenDispatch().getOpTable("aten::data(Tensor self) -> Tensor");
81+
return table->getOp<Tensor (const Tensor &)>(type_set())(const_cast<Tensor&>(*this));
82+
#endif
83+
}
7684
#ifdef BUILD_NAMEDTENSOR
7785
inline Tensor & Tensor::names_(c10::optional<DimnameList> names) const {
7886
#ifdef USE_STATIC_DISPATCH

aten/src/ATen/native/VariableMethodStubs.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,9 @@ void set_data(const Tensor& self, const Tensor& new_data) {
1212
AT_ERROR("set_data is not implemented for Tensor");
1313
}
1414

15+
Tensor data(const Tensor& self) {
16+
AT_ERROR("data is not implemented for Tensor");
17+
}
18+
1519
} // namespace native
1620
} // namespace at

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
- func: set_data(Tensor(a!) self, Tensor new_data) -> void
3636
variants: method
3737

38+
- func: data(Tensor self) -> Tensor
39+
variants: method
40+
3841
- func: names_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
3942
variants: method
4043
named_guard: False

aten/src/ATen/native/quantized/cpu/qconv.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ class QConv2dInt8 final : public c10::OperatorKernel {
235235
conv_p,
236236
act_ptr,
237237
*packB,
238-
reinterpret_cast<uint8_t*>(output.data<c10::quint8>()),
239-
buffer.data<int32_t>(),
238+
reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
239+
buffer.data_ptr<int32_t>(),
240240
outputProcObj,
241241
0 /* thread_id*/,
242242
1 /* num_threads */);

aten/src/ATen/test/quantized_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,14 @@ TEST(TestQTensor, EmptyPerchannelQuantized) {
131131
{ch_axis},
132132
at::device(at::kCPU).dtype(kQUInt8));
133133
// Assigning to QTensor
134-
auto* q_data = q.data<quint8>();
134+
auto* q_data = q.data_ptr<quint8>();
135135
for (int i = 0; i < numel; ++i) {
136136
q_data[i].val_ = val;
137137
}
138138

139139
// dequantize
140140
auto r = q.dequantize();
141-
auto* r_data = r.data<float>();
141+
auto* r_data = r.data_ptr<float>();
142142
for (int i = 0; i < numel; ++i) {
143143
ASSERT_EQ(
144144
r_data[i],

test/cpp/api/tensor.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,11 @@ TEST(TensorTest, DataPtr) {
311311
ASSERT_EQ(tensor_not_copy.data_ptr<float>(), tensor.data_ptr<float>());
312312
ASSERT_EQ(tensor_not_copy.data_ptr(), tensor.data_ptr());
313313
}
314+
315+
TEST(TensorTest, Data) {
316+
const auto tensor = torch::empty({3, 3});
317+
ASSERT_TRUE(torch::equal(tensor, tensor.data()));
318+
319+
const auto tensor2 = at::empty({3, 3});
320+
ASSERT_THROW(tensor2.data(), c10::Error);
321+
}

tools/autograd/gen_python_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
'set_quantizer_',
3737
'set_data',
3838
'.*_overrideable', # overrideable functions for backend extension
39+
'data'
3940
]
4041

4142
# These function signatures are not exposed to Python. Note that this signature

tools/autograd/gen_variable_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
# These functions are written manually in templates/VariableType.cpp
3131
MANUAL_IMPLEMENTATIONS = {
32-
'resize_', 'resize_as_', 'detach', 'detach_', 'copy_', 'backward', 'set_data'
32+
'resize_', 'resize_as_', 'detach', 'detach_', 'copy_', 'backward', 'set_data', 'data'
3333
}
3434

3535
# These functions we don't want to record for tracing, because we always want

torch/csrc/autograd/VariableTypeManual.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ void VariableType::set_data(const Tensor & self, const Tensor & new_data) {
9393
as_variable_ref(self).set_data(new_data);
9494
}
9595

96+
Tensor VariableType::data(const Tensor & self) {
97+
return as_variable_ref(self).variable_data();
98+
}
99+
96100
// We don't have an outplace copy, so this can't be generated automatically
97101
Tensor & VariableType::copy_(Tensor & self, const Tensor & src, bool non_blocking) {
98102
jit::Value* output = nullptr;

0 commit comments

Comments
 (0)