Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions aten/src/ATen/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ _(at::Half,Half,d) \
_(float,Float,d) \
_(double,Double,d)

// Expands the callback macro _ once per supported scalar type except Half,
// as _(ctype, Name, tag); the tag is 'i' for integral types and 'd' for
// floating-point types (mirroring AT_FORALL_SCALAR_TYPES above).
#define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \
_(uint8_t,Byte,i) \
_(int8_t,Char,i) \
_(int16_t,Short,i) \
_(int,Int,i) \
_(int64_t,Long,i) \
_(float,Float,d) \
_(double,Double,d)

enum class ScalarType {
#define DEFINE_ENUM(_1,n,_2) \
n,
Expand Down
17 changes: 17 additions & 0 deletions aten/src/ATen/native/TensorFactories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "ATen/NativeFunctions.h"
#include "ATen/ScalarType.h"
#include "ATen/Deprecated.h"
#include "ATen/DeviceGuard.h"
#include "TH/THRandom.h"

#include <algorithm>
Expand Down Expand Up @@ -593,5 +594,21 @@ Tensor hann_window(
return native::hamming_window(
window_length, periodic, /*alpha=*/0.5, /*beta=*/0.5, options);
}

// Creates a 1-D tensor containing `values`, allocated according to `options`
// (dtype/device/layout). Element-wise assignment through operator[] keeps
// this implementation device-agnostic.
template <typename T>
Tensor tensor(ArrayRef<T> values, const TensorOptions& options) {
  auto result = at::empty(values.size(), options);
  size_t index = 0;
  for (const auto& value : values) {
    result[index++] = value;
  }
  return result;
}

// Stamp out a non-template `tensor` overload for each supported scalar type,
// forwarding to the templated implementation defined in this file.
#define TENSOR(T, _1, _2)                                           \
  Tensor tensor(ArrayRef<T> values, const TensorOptions& options) { \
    return tensor<T>(values, options);                              \
  }
AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(TENSOR)
#undef TENSOR
} // namespace native
} // namespace at
8 changes: 4 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1573,10 +1573,10 @@
SparseCPU: new_with_size_sparse

- func: tensor(Type dtype) -> Tensor
variants: function
variants: []

- func: tensor(Type dtype, IntList size) -> Tensor
variants: function
variants: []


# NB: The function overloads are removed to avoid a nasty bug where
Expand All @@ -1598,10 +1598,10 @@
SparseCPU: new_with_tensor_and_size_sparse

- func: sparse_coo_tensor(IndexTensor indices, Tensor values) -> Tensor
variants: function
variants: []

- func: sparse_coo_tensor(IndexTensor indices, Tensor values, IntList size) -> Tensor
variants: function
variants: []


- func: _native_sparse_coo_tensor_unsafe(IndexTensor indices, Tensor values, IntList size) -> Tensor
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/templates/Functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
namespace at {

using native::from_blob;
using native::tensor;

${function_declarations}

Expand Down
22 changes: 22 additions & 0 deletions aten/src/ATen/templates/NativeFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,28 @@ inline Tensor from_blob(
return native::from_blob(data, sizes, [](void*) {}, options);
}

// These functions are defined in native/TensorFactories.cpp.
//
// For every scalar type T except Half, this generates the overload set
//   tensor(ArrayRef<T>, options) / tensor(initializer_list<T>, options) /
//   tensor(T, options)
// plus options-less variants that default the dtype to T's ScalarType (kS).
//
// Note: ArrayRef is a cheap, trivially copyable non-owning view, so it is
// passed by value; the original `std::move(values)` was misleading no-op
// noise (clang-tidy performance-move-const-arg) and has been removed.
#define TENSOR(T, S, _1)                                               \
  Tensor tensor(ArrayRef<T> values, const TensorOptions& options);     \
  inline Tensor tensor(                                                \
      std::initializer_list<T> values, const TensorOptions& options) { \
    return native::tensor(ArrayRef<T>(values), options);               \
  }                                                                    \
  inline Tensor tensor(T value, const TensorOptions& options) {        \
    return native::tensor(ArrayRef<T>(value), options);                \
  }                                                                    \
  inline Tensor tensor(ArrayRef<T> values) {                           \
    return native::tensor(values, at::dtype(k##S));                    \
  }                                                                    \
  inline Tensor tensor(std::initializer_list<T> values) {              \
    return native::tensor(ArrayRef<T>(values));                        \
  }                                                                    \
  inline Tensor tensor(T value) {                                      \
    return native::tensor(ArrayRef<T>(value));                         \
  }
AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(TENSOR)
#undef TENSOR

${native_function_declarations}

} // namespace native
Expand Down
9 changes: 0 additions & 9 deletions aten/src/TH/THBlasUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,6 @@
// rather than by name directly. Someone should figure out a reasonable way to
// rewrite these in more idiomatic ATen and move it into ATen proper.

#define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \
_(uint8_t,Byte,i) \
_(int8_t,Char,i) \
_(int16_t,Short,i) \
_(int,Int,i) \
_(int64_t,Long,i) \
_(float,Float,d) \
_(double,Double,d)

template<typename T>
inline void THBlas_axpy(int64_t n, T a, T *x, int64_t incx, T *y, int64_t incy);

Expand Down
78 changes: 78 additions & 0 deletions test/cpp/api/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@

#include <ATen/ATen.h>

#include <cmath>

// True iff the one-element tensor `left` holds exactly the value `right`
// after conversion to T (bitwise-exact comparison; use for integral types).
template <typename T>
bool exactly_equal(at::Tensor left, T right) {
  const auto actual = at::Scalar(left).to<T>();
  return actual == right;
}

// True iff the one-element tensor `left` is within `tolerance` of `right`
// after conversion to T (absolute-difference comparison, for floating point;
// NOTE(review): the 1e-4 default truncates to 0 for integral T — callers
// here only use it with float/double).
template <typename T>
bool almost_equal(at::Tensor left, T right, T tolerance = 1e-4) {
  const auto difference = at::Scalar(left).to<T>() - right;
  return std::abs(difference) < tolerance;
}

#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
REQUIRE(tensor.device().type() == at::Device((device_), (index_)).type()); \
REQUIRE(tensor.device().index() == at::Device((device_), (index_)).index()); \
Expand Down Expand Up @@ -83,3 +95,69 @@ TEST_CASE("Tensor/ToDoesNotCopyWhenOptionsAreAllTheSame") {
auto hopefully_not_copy = tensor.to(at::kFloat);
REQUIRE(hopefully_not_copy.data<float>() == tensor.data<float>());
}

TEST_CASE("Tensor/ContainsCorrectValueForSingleValue") {
  // A bare int literal should infer a one-element kInt tensor.
  auto tensor = at::tensor(123);
  REQUIRE(tensor.numel() == 1);
  REQUIRE(tensor.dtype() == at::kInt);
  // Use the shared exactly_equal helper rather than raw toCInt(), for
  // consistency with the other value checks in this file.
  REQUIRE(exactly_equal(tensor[0], 123));

  // A float literal should infer kFloat.
  tensor = at::tensor(123.456f);
  REQUIRE(tensor.numel() == 1);
  REQUIRE(tensor.dtype() == at::kFloat);
  REQUIRE(almost_equal(tensor[0], 123.456f));

  // A double literal should infer kDouble.
  tensor = at::tensor(123.456);
  REQUIRE(tensor.numel() == 1);
  REQUIRE(tensor.dtype() == at::kDouble);
  REQUIRE(almost_equal(tensor[0], 123.456));
}

TEST_CASE("Tensor/ContainsCorrectValuesForManyValues") {
  // An initializer list of ints should infer kInt and preserve the values.
  auto tensor = at::tensor({1, 2, 3});
  REQUIRE(tensor.numel() == 3);
  REQUIRE(tensor.dtype() == at::kInt);
  const int expected_ints[] = {1, 2, 3};
  for (int i = 0; i < 3; ++i) {
    REQUIRE(exactly_equal(tensor[i], expected_ints[i]));
  }

  // An initializer list of doubles should infer kDouble.
  tensor = at::tensor({1.5, 2.25, 3.125});
  REQUIRE(tensor.numel() == 3);
  REQUIRE(tensor.dtype() == at::kDouble);
  const double expected_doubles[] = {1.5, 2.25, 3.125};
  for (int i = 0; i < 3; ++i) {
    REQUIRE(almost_equal(tensor[i], expected_doubles[i]));
  }
}

TEST_CASE("Tensor/ContainsCorrectValuesWhenConstructedFromVector") {
  // A vector<int> should yield a kInt tensor with matching contents.
  std::vector<int> ints = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  auto tensor = at::tensor(ints);
  REQUIRE(tensor.numel() == ints.size());
  REQUIRE(tensor.dtype() == at::kInt);
  size_t index = 0;
  for (const auto expected : ints) {
    REQUIRE(exactly_equal(tensor[index], expected));
    ++index;
  }

  // A vector<float> should yield a kFloat tensor with matching contents.
  std::vector<float> floats = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0};
  tensor = at::tensor(floats);
  REQUIRE(tensor.numel() == floats.size());
  REQUIRE(tensor.dtype() == at::kFloat);
  index = 0;
  for (const auto expected : floats) {
    REQUIRE(almost_equal(tensor[index], expected));
    ++index;
  }
}

TEST_CASE("Tensor/UsesOptionsThatAreSupplied") {
  // An explicit dtype option overrides the type inferred from the value.
  auto tensor = at::tensor(123, dtype(at::kFloat)) + 0.5;
  REQUIRE(tensor.numel() == 1);
  REQUIRE(tensor.dtype() == at::kFloat);
  REQUIRE(almost_equal(tensor[0], 123.5));

  // Double inputs with an integral dtype are converted to 1, 2, 3.
  tensor = at::tensor({1.1, 2.2, 3.3}, dtype(at::kInt));
  REQUIRE(tensor.numel() == 3);
  REQUIRE(tensor.dtype() == at::kInt);
  REQUIRE(tensor.layout() == at::kStrided);
  for (int i = 0; i < 3; ++i) {
    REQUIRE(exactly_equal(tensor[i], i + 1));
  }
}
10 changes: 10 additions & 0 deletions test/cpp/api/tensor_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#include <catch.hpp>

#include <ATen/ATen.h>

#include <cmath>

TEST_CASE("Tensor/AllocatesTensorOnTheCorrectDevice", "[cuda]") {
  // Passing an explicit device option should place the tensor on CUDA:1.
  const auto expected_device = at::Device(at::kCUDA, 1);
  auto tensor = at::tensor({1, 2, 3}, at::device({at::kCUDA, 1}));
  REQUIRE(tensor.device() == expected_device);
}
1 change: 1 addition & 0 deletions torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ if (TORCH_BUILD_TEST)
${TORCH_API_TEST_DIR}/serialization.cpp
${TORCH_API_TEST_DIR}/static.cpp
${TORCH_API_TEST_DIR}/tensor.cpp
${TORCH_API_TEST_DIR}/tensor_cuda.cpp
# Temporary until ATen tests are built with Caffe2
${TORCH_API_TEST_DIR}/tensor_options.cpp
${TORCH_API_TEST_DIR}/tensor_options_cuda.cpp
Expand Down