Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aten/src/ATen/Declarations.cwrap
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
cname: indexSelect
variants:
- function
cpu_half: True
return: argument 0
arguments:
- arg: THTensor* result
Expand Down
3 changes: 0 additions & 3 deletions aten/src/ATen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,6 @@ def legacy_iterate_types():
for scalar_type in (scalar_types + quantized_scalar_types):
if density == 'Mkldnn' and (backend != 'CPU' or scalar_type[0] != 'Float'):
continue
if density == 'Sparse' and scalar_type[0] == 'Half':
# THS does not do half type yet.
continue
else:
yield (backend, density, scalar_type)
for backend in quantized_backends:
Expand Down
16 changes: 11 additions & 5 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,19 @@ void sub_kernel(TensorIterator& iter, Scalar alpha_scalar) {
}

void mul_kernel(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES(iter.dtype(), "mul_cpu", [&]() {
binary_kernel_vec(iter,
[=](scalar_t a, scalar_t b) -> scalar_t { return a * b; },
[=](Vec256<scalar_t> a, Vec256<scalar_t> b) {
if( iter.dtype() == ScalarType::Half ) {
binary_kernel(iter, [](Half a, Half b) -> Half {
return a * b;
});
});
} else {
AT_DISPATCH_ALL_TYPES(iter.dtype(), "mul_cpu", [&]() {
binary_kernel_vec(iter,
[=](scalar_t a, scalar_t b) -> scalar_t { return a * b; },
[=](Vec256<scalar_t> a, Vec256<scalar_t> b) {
return a * b;
});
});
}
}

void div_kernel(TensorIterator& iter) {
Expand Down
18 changes: 12 additions & 6 deletions aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@ namespace at { namespace native { namespace {
using namespace vec256;

static void sum_kernel_impl(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES(iter.dtype(), "sum_cpu", [&] {
binary_kernel_reduce_vec(
iter,
[=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
[=](Vec256<scalar_t> a, Vec256<scalar_t> b) { return a + b; });
});
if( iter.dtype() == ScalarType::Half ) {
binary_kernel(iter, [](Half a, Half b) -> Half {
return a + b;
});
} else {
AT_DISPATCH_ALL_TYPES(iter.dtype(), "sum_cpu", [&] {
binary_kernel_reduce_vec(
iter,
[=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
[=](Vec256<scalar_t> a, Vec256<scalar_t> b) { return a + b; });
});
}
}

static void mean_kernel_impl(TensorIterator& iter) {
Expand Down
7 changes: 4 additions & 3 deletions aten/src/ATen/native/sparse/SparseTensorMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ SparseTensor& add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const S
Tensor s_values = src._values();
r.resize_as_(src);

if (s_values.is_contiguous() && t_values.is_contiguous()) {
if (s_values.is_contiguous() && t_values.is_contiguous() && t_values.scalar_type() != ScalarType::Half) {
LongTensor r_indices = at::empty({sparse_dim, max_nnz}, t_indices.options());
Tensor r_values = new_values_with_size_of(s_values, max_nnz).zero_();
get_sparse_impl(r)->set_indices_and_values_unsafe(r_indices, r_values);
Expand Down Expand Up @@ -287,8 +287,9 @@ SparseTensor& add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const S
} else {
// If `t` or `src` contains non-contiguous `values`, `THBlas_axpy` doesn't work
// and we concat the indices and values tensors instead.
AT_DISPATCH_ALL_TYPES(
s_values.scalar_type(), "add_out_sparse_cuda", [&] {
// Also THBlas_axpy isn't implemented for Half types.
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Half,
s_values.scalar_type(), "add_out_sparse_cpu", [&] {
if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
s_values = s_values.mul(value);
}
Expand Down
11 changes: 6 additions & 5 deletions aten/src/TH/THGenerateHalfType.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@
#error "You must define TH_GENERIC_FILE before including THGenerateHalfType.h"
#endif

#include <TH/THHalf.h>
#define scalar_t THHalf
#define accreal float
#include "THHalf.h"

#define scalar_t at::Half
#define accreal double
#define TH_CONVERT_REAL_TO_ACCREAL(_val) (accreal)(_val)
#define TH_CONVERT_ACCREAL_TO_REAL(_val) (scalar_t)(_val)
#define Real Half
#define THInf TH_HALF_BITS_TO_LITERAL(TH_HALF_INF)
#define THInf std::numeric_limits<at::Half>::max()
#define TH_REAL_IS_HALF
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef scalar_t
#undef accreal
#undef scalar_t
#undef Real
#undef THInf
#undef TH_REAL_IS_HALF
Expand Down
3 changes: 3 additions & 0 deletions aten/src/TH/THTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
#include <TH/generic/THTensorMath.h>
#include <TH/THGenerateBoolType.h>

#include <TH/generic/THTensorMath.h>
#include <TH/THGenerateHalfType.h>

/* fill and zero*/
#include <TH/generic/THTensorFill.h>
#include <TH/THGenerateAllTypes.h>
Expand Down
3 changes: 3 additions & 0 deletions aten/src/TH/THTensorEvenMoreMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@

#include <TH/generic/THTensorEvenMoreMath.cpp>
#include <TH/THGenerateBoolType.h>

#include <TH/generic/THTensorEvenMoreMath.cpp>
#include <TH/THGenerateHalfType.h>
3 changes: 3 additions & 0 deletions aten/src/TH/THTensorMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@

#include <TH/generic/THTensorMath.cpp>
#include <TH/THGenerateBoolType.h>

#include <TH/generic/THTensorMath.cpp>
#include <TH/THGenerateHalfType.h>
3 changes: 3 additions & 0 deletions aten/src/TH/THTensorMoreMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@

#include <TH/generic/THTensorMoreMath.cpp>
#include <TH/THGenerateBoolType.h>

#include <TH/generic/THTensorMoreMath.cpp>
#include <TH/THGenerateHalfType.h>
29 changes: 22 additions & 7 deletions aten/src/TH/generic/THTensorEvenMoreMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@

#include <TH/generic/THTensorApply.hpp>


#ifdef TH_REAL_IS_HALF
#include "c10/util/Half.h"
#endif

// Finds non-zero elements of a tensor and returns their subscripts
void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
{
ptrdiff_t numel = 0;
int64_t *subscript_data;
int64_t i = 0;
#ifdef TH_REAL_IS_HALF
#define IS_NONZERO(val) ((val.x & 0x7fff) != 0)
#define IS_NONZERO(val) (c10::Half(0)!=val)
#else
#define IS_NONZERO(val) ((val)!=0)
#endif
Expand Down Expand Up @@ -65,6 +70,9 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
);
delete [] sizes;
delete [] idx;

#undef IS_NONZERO

}

#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
Expand Down Expand Up @@ -361,6 +369,8 @@ void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int ac
THLongTensor_free(index);
}

#if !defined(TH_REAL_IS_HALF) // skipping because we don't have blas to define cadd

void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
{
ptrdiff_t i, numel;
Expand Down Expand Up @@ -401,6 +411,7 @@ void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTenso
}
THLongTensor_free(index);
}
#endif

void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val)
{
Expand Down Expand Up @@ -523,6 +534,8 @@ void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTen
})
}

#if ! defined(TH_REAL_IS_HALF) /* blas not implemented for half */

void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val)
{
int64_t elems_per_row, i, idx;
Expand Down Expand Up @@ -566,6 +579,8 @@ accreal THTensor_(dot)(THTensor *tensor, THTensor *src)
return sum;
}

#endif /* end ! half section */

scalar_t THTensor_(minall)(THTensor *tensor)
{
scalar_t theMin;
Expand Down Expand Up @@ -835,7 +850,7 @@ void THTensor_(fmod)(THTensor *r_, THTensor *t, scalar_t value)
int64_t i;
#pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<r_Size; i++) {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
rp[i] = fmod(tp[i], value);
#else
rp[i] = tp[i] % value;
Expand All @@ -847,7 +862,7 @@ void THTensor_(fmod)(THTensor *r_, THTensor *t, scalar_t value)
if (inOMP) {
serial_path = 1;
} else {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = fmod(*t_data, value);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
#else
TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = (*t_data % value);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
Expand All @@ -858,7 +873,7 @@ void THTensor_(fmod)(THTensor *r_, THTensor *t, scalar_t value)
#endif
}
if (serial_path) {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
TH_TENSOR_APPLY2(scalar_t, r_, scalar_t, t, *r__data = fmod(*t_data, value););
#else
TH_TENSOR_APPLY2(scalar_t, r_, scalar_t, t, *r__data = (*t_data % value););
Expand All @@ -884,7 +899,7 @@ void THTensor_(remainder)(THTensor *r_, THTensor *t, scalar_t value)
int64_t i;
#pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<r_Size; i++) {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
rp[i] = (value == 0)? NAN : tp[i] - value * floor(tp[i] / value);
#else
// There is no NAN for integers
Expand All @@ -899,7 +914,7 @@ void THTensor_(remainder)(THTensor *r_, THTensor *t, scalar_t value)
if (inOMP) {
serial_path = 1;
} else {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = (value == 0)? NAN : *t_data - value * floor(*t_data / value);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
#else
// There is no NAN for integers
Expand All @@ -912,7 +927,7 @@ void THTensor_(remainder)(THTensor *r_, THTensor *t, scalar_t value)
#endif
}
if (serial_path) {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
TH_TENSOR_APPLY2(scalar_t, r_, scalar_t, t, *r__data = (value == 0)? NAN : *t_data - value * floor(*t_data / value););
#else
// There is no NAN for integers
Expand Down
13 changes: 7 additions & 6 deletions aten/src/TH/generic/THTensorMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ void THTensor_(clamp)(THTensor *r_, THTensor *t, scalar_t min_value, scalar_t ma
}
}

#if ! defined(TH_REAL_IS_HALF) // we don't have half blas functions so skipping these
void THTensor_(cadd)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src)
{
THTensor_(resizeAs)(r_, t);
Expand Down Expand Up @@ -179,6 +180,8 @@ void THTensor_(csub)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src)
THTensor_(cadd)(r_, t, -value, src);
}

#endif // ! Half

void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src)
{
THTensor_(resizeAs)(r_, t);
Expand Down Expand Up @@ -344,11 +347,11 @@ void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src)
}
}

#if !defined(TH_REAL_IS_HALF)
// return THError("clshift is not supported for torch.HalfTensor");

void THTensor_(clshift)(THTensor *r_, THTensor *t, THTensor *src)
{
#if defined(TH_REAL_IS_HALF)
return THError("clshift is not supported for torch.HalfTensor");
#endif
THTensor_(resizeAs)(r_, t);
int64_t r_Size = THTensor_(nElement)(r_);
int64_t srcSize = THTensor_(nElement)(src);
Expand Down Expand Up @@ -412,9 +415,6 @@ void THTensor_(clshift)(THTensor *r_, THTensor *t, THTensor *src)

void THTensor_(crshift)(THTensor *r_, THTensor *t, THTensor *src)
{
#if defined(TH_REAL_IS_HALF)
return THError("crshift is not supported for torch.HalfTensor");
#endif
THTensor_(resizeAs)(r_, t);
int64_t r_Size = THTensor_(nElement)(r_);
int64_t srcSize = THTensor_(nElement)(src);
Expand Down Expand Up @@ -1197,6 +1197,7 @@ void THTensor_(addbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t al
c10::raw::intrusive_ptr::decref(matrix1);
c10::raw::intrusive_ptr::decref(matrix2);
}
#endif // ! Half

#endif /* !defined(TH_REAL_IS_BOOL) */

Expand Down
4 changes: 4 additions & 0 deletions aten/src/TH/generic/THTensorMoreMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ TENSOR_IMPLEMENT_LOGICAL(ne,!=)

#if !defined(TH_REAL_IS_BOOL) /* non bool only part */

#if !defined(TH_REAL_IS_HALF) // baddbmm not implemented for half.

void THTensor_(baddbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *batch1, THTensor *batch2)
{
int64_t batch;
Expand Down Expand Up @@ -87,6 +89,8 @@ void THTensor_(baddbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t a
c10::raw::intrusive_ptr::decref(result_matrix);
}

#endif // !defined(TH_REAL_IS_HALF)

ptrdiff_t THTensor_(numel)(THTensor *t)
{
return THTensor_(nElement)(t);
Expand Down
Loading