4 changes: 4 additions & 0 deletions aten/src/ATen/Declarations.cwrap
@@ -2386,6 +2386,8 @@
cname: addcmul
variants:
- function
backends:
- CUDA
return: argument 0
arguments:
- arg: THTensor* result
@@ -2403,6 +2405,8 @@
options:
- cname: addcmul
variants: function
backends:
- CUDA
return: argument 0
arguments:
- THTensor* self
2 changes: 1 addition & 1 deletion aten/src/ATen/core/Tensor.h
@@ -696,7 +696,6 @@ class CAFFE2_API Tensor {
Tensor & remainder_(const Tensor & other);
Tensor & addbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1);
Tensor addbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1) const;
Tensor & addcmul_(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1);
Tensor & addcdiv_(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1);
Tensor & random_(int64_t from, int64_t to, Generator * generator=nullptr);
Tensor & random_(int64_t to, Generator * generator=nullptr);
@@ -731,6 +730,7 @@ class CAFFE2_API Tensor {
std::vector<Tensor> nonzero_numpy() const;
Tensor gather(int64_t dim, const Tensor & index, bool sparse_grad=false) const;
Tensor addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
Tensor & addcmul_(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1);
Tensor addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
std::tuple<Tensor,Tensor> lstsq(const Tensor & A) const;
std::tuple<Tensor,Tensor> triangular_solve(const Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) const;
8 changes: 4 additions & 4 deletions aten/src/ATen/core/TensorMethods.h
@@ -1405,10 +1405,6 @@ inline Tensor Tensor::addbmm(const Tensor & batch1, const Tensor & batch2, Scala
static auto table = globalATenDispatch().getOpTable("aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor");
return table->getOp<Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar, Scalar)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, batch1, batch2, beta, alpha);
}
inline Tensor & Tensor::addcmul_(const Tensor & tensor1, const Tensor & tensor2, Scalar value) {
static auto table = globalATenDispatch().getOpTable("aten::addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)");
return table->getOp<Tensor & (Tensor &, const Tensor &, const Tensor &, Scalar)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, tensor1, tensor2, value);
}
inline Tensor & Tensor::addcdiv_(const Tensor & tensor1, const Tensor & tensor2, Scalar value) {
static auto table = globalATenDispatch().getOpTable("aten::addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)");
return table->getOp<Tensor & (Tensor &, const Tensor &, const Tensor &, Scalar)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, tensor1, tensor2, value);
@@ -1545,6 +1541,10 @@ inline Tensor Tensor::addcmul(const Tensor & tensor1, const Tensor & tensor2, Sc
static auto table = globalATenDispatch().getOpTable("aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor");
return table->getOp<Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, tensor1, tensor2, value);
}
inline Tensor & Tensor::addcmul_(const Tensor & tensor1, const Tensor & tensor2, Scalar value) {
static auto table = globalATenDispatch().getOpTable("aten::addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)");
return table->getOp<Tensor & (Tensor &, const Tensor &, const Tensor &, Scalar)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, tensor1, tensor2, value);
}
inline Tensor Tensor::addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const {
static auto table = globalATenDispatch().getOpTable("aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor");
return table->getOp<Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, tensor1, tensor2, value);
56 changes: 56 additions & 0 deletions aten/src/ATen/native/PointwiseOps.cpp
@@ -0,0 +1,56 @@
// Ternary and higher-order pointwise operations
#include <ATen/native/PointwiseOps.h>

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/TensorIterator.h>

#ifdef BUILD_NAMEDTENSOR
#include <ATen/NamedTensorUtils.h>
#endif

namespace at {
namespace native {

Tensor addcmul_cpu(
    const Tensor& self,
    const Tensor& tensor1,
    const Tensor& tensor2,
    Scalar value) {
  Tensor result = at::empty({0}, self.options());
  return at::addcmul_out(result, self, tensor1, tensor2, value);
}

Tensor& addcmul_cpu_(
    Tensor& self,
    const Tensor& tensor1,
    const Tensor& tensor2,
    Scalar value) {
  return at::addcmul_out(self, self, tensor1, tensor2, value);
}

Tensor& addcmul_cpu_out(
    Tensor& result,
    const Tensor& self,
    const Tensor& tensor1,
    const Tensor& tensor2,
    Scalar value) {
  checkBackend("addcmul_cpu", result, self.type().backend());
  auto iter = at::TensorIterator();
  iter.check_and_add_output(result);
  iter.add_input(self);
  iter.add_input(tensor1);
  iter.add_input(tensor2);
  iter.build();
  addcmul_stub(kCPU, iter, value);
#ifdef BUILD_NAMEDTENSOR
  at::namedinference::propagate_names(result, self);
#endif
  return result;
}

DEFINE_DISPATCH(addcmul_stub);

} // namespace native
} // namespace at
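For reference, here is a minimal standalone sketch (not part of this diff) of how the new CPU path can be exercised and checked against the defining formula out = self + value * tensor1 * tensor2. It assumes a build where at::addcmul dispatches to addcmul_cpu as registered in native_functions.yaml below; the main() driver and tensor shapes are illustrative only.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Random double tensors of matching shape.
  at::Tensor self = at::rand({4, 4}, at::kDouble);
  at::Tensor t1 = at::rand({4, 4}, at::kDouble);
  at::Tensor t2 = at::rand({4, 4}, at::kDouble);
  at::Scalar value = 2.0;

  // Ported kernel vs. the composed reference expression.
  at::Tensor out = at::addcmul(self, t1, t2, value);
  at::Tensor ref = self + 2.0 * t1 * t2;

  std::cout << std::boolalpha << at::allclose(out, ref) << std::endl;
  return 0;
}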
17 changes: 17 additions & 0 deletions aten/src/ATen/native/PointwiseOps.h
@@ -0,0 +1,17 @@
// Ternary and higher-order pointwise operations
#pragma once

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at {

struct TensorIterator;

namespace native {

using addcmul_fn = void (*)(TensorIterator&, Scalar scalar);

DECLARE_DISPATCH(addcmul_fn, addcmul_stub);
} // namespace native
} // namespace at
36 changes: 36 additions & 0 deletions aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp
@@ -0,0 +1,36 @@
// Ternary and higher-order pointwise operations
#include <ATen/ATen.h>

#include <ATen/Dispatch.h>
#include <ATen/native/PointwiseOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>

namespace at {
namespace native {
namespace {

static void addcmul_cpu_kernel(TensorIterator& iter, Scalar value) {
  ScalarType dtype = iter.dtype(0);
  AT_DISPATCH_ALL_TYPES(dtype, "addcmul_cpu_out", [&] {
    scalar_t scalar_val = value.to<scalar_t>();
    auto scalar_vec = Vec256<scalar_t>(scalar_val);
    cpu_kernel_vec(
        iter,
        [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t {
          return self_val + scalar_val * t1_val * t2_val;
        },
        [=](Vec256<scalar_t> self_vec,
            Vec256<scalar_t> t1_vec,
            Vec256<scalar_t> t2_vec) {
          return self_vec + scalar_vec * t1_vec * t2_vec;
        });
  });
}

} // anonymous namespace

REGISTER_DISPATCH(addcmul_stub, &addcmul_cpu_kernel);

} // namespace native
} // namespace at
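The scalar and Vec256 lambdas above both implement the same fused update; as a hedged reference, the plain loop below (a hypothetical helper, not in the PR) spells out that per-element math on contiguous float data, independent of TensorIterator and vectorization.

#include <cstddef>

// Reference loop: out[i] = self[i] + value * t1[i] * t2[i] for each element.
void addcmul_reference(float* out, const float* self,
                       const float* t1, const float* t2,
                       float value, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = self[i] + value * t1[i] * t2[i];
  }
}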
16 changes: 8 additions & 8 deletions aten/src/ATen/native/native_functions.yaml
@@ -3436,12 +3436,6 @@
    CPU: legacy::cpu::_th_addbmm
    CUDA: legacy::cuda::_th_addbmm

- func: addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)
  variants: method
  dispatch:
    CPU: legacy::cpu::_th_addcmul_
    CUDA: legacy::cuda::_th_addcmul_

- func: addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)
  variants: method
  dispatch:
@@ -3755,15 +3749,21 @@

- func: addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU: legacy::cpu::_th_addcmul_out
    CPU: addcmul_cpu_out
    CUDA: legacy::cuda::_th_addcmul_out

- func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
  variants: method, function
  dispatch:
    CPU: legacy::cpu::_th_addcmul
    CPU: addcmul_cpu
    CUDA: legacy::cuda::_th_addcmul

- func: addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)
  variants: method
  dispatch:
    CPU: addcmul_cpu_
    CUDA: legacy::cuda::_th_addcmul_

- func: addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU: legacy::cpu::_th_addcdiv_out
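The yaml entries above re-route the three addcmul schemas (out, functional, in-place) to the new native CPU functions while CUDA still goes through the legacy TH bindings. As a minimal sketch of which user-visible call exercises each schema, assuming CPU tensors and a hypothetical helper name:

#include <ATen/ATen.h>

void call_all_variants(at::Tensor& self, const at::Tensor& t1, const at::Tensor& t2) {
  at::Scalar value = 1;

  // addcmul.out -> addcmul_cpu_out on CPU
  at::Tensor out = at::empty_like(self);
  at::addcmul_out(out, self, t1, t2, value);

  // addcmul (functional) -> addcmul_cpu on CPU
  at::Tensor res = at::addcmul(self, t1, t2, value);

  // addcmul_ (in-place method) -> addcmul_cpu_ on CPU
  self.addcmul_(t1, t2, value);
}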
22 changes: 0 additions & 22 deletions aten/src/TH/generic/THTensorMath.cpp
@@ -575,28 +575,6 @@ void THTensor_(tpow)(THTensor *r_, scalar_t value, THTensor *t)
}
}

void THTensor_(addcmul)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src1, THTensor *src2)
{
  if(r_ != t)
  {
    THTensor_(resizeAs)(r_, t);
    at::Tensor r__wrap = THTensor_wrap(r_);
    at::Tensor t_wrap = THTensor_wrap(t);
    at::native::copy_(r__wrap, t_wrap);
  }
  int64_t r_Size = THTensor_(nElement)(r_);
  int64_t src1Size = THTensor_(nElement)(src1);
  int64_t src2Size = THTensor_(nElement)(src2);
  int r_Contig = THTensor_(isContiguous)(r_);
  int src1Contig = THTensor_(isContiguous)(src1);
  int src2Contig = THTensor_(isContiguous)(src2);
  if( (src1Size == src2Size) && (src1Size == r_Size) ){
    TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, src1Contig, src2Contig, scalar_t, r_, scalar_t, src1, scalar_t, src2, *r__data += value * *src1_data * *src2_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
  } else {
    TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, src1, scalar_t, src2, *r__data += value * *src1_data * *src2_data;);
  }
}

void THTensor_(addcdiv)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src1, THTensor *src2)
{
if(r_ != t)
1 change: 0 additions & 1 deletion aten/src/TH/generic/THTensorMath.h
@@ -129,7 +129,6 @@ TH_API void THTensor_(crshift)(THTensor *r_, THTensor *t, THTensor *src);
TH_API void THTensor_(cfmod)(THTensor *r_, THTensor *t, THTensor *src);
TH_API void THTensor_(cremainder)(THTensor *r_, THTensor *t, THTensor *src);

TH_API void THTensor_(addcmul)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src1, THTensor *src2);
TH_API void THTensor_(addcdiv)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src1, THTensor *src2);

TH_API void THTensor_(addmv)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *mat, THTensor *vec);