146 changes: 146 additions & 0 deletions aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -185,6 +185,152 @@ void prepare_matrices_for_broadcasting(
return output;
}


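// Out-of-place addr: allocate an empty result and delegate to addr_out_mps.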
Tensor addr_mps(const Tensor& self,
const Tensor& vec1, const Tensor& vec2,
const Scalar& beta, const Scalar& alpha) {
Tensor result = at::empty({0}, self.options());
addr_out_mps(self, vec1, vec2, beta, alpha, result);
return result;
}


Tensor& addr_out_mps(const Tensor& self,
const Tensor& vec1, const Tensor& vec2,
const Scalar& beta, const Scalar& alpha, Tensor &result) {
using namespace mps;

TORCH_CHECK(result.is_mps(), "addr: expected the out tensor to be an MPS tensor");
TORCH_CHECK(vec1.dim() == 1 && vec2.dim() == 1, "tensors must be 1-D");
TORCH_CHECK(vec1.scalar_type() == ScalarType::Double
|| vec1.scalar_type() == ScalarType::Float
|| vec1.scalar_type() == ScalarType::Half, "MPS device does not support addr for non-float input");

TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {vec1, "vec1", 2}, {vec2, "vec2", 3}};
checkAllSameGPU(__func__, args);

IntArrayRef vec1_sizes = vec1.sizes();
IntArrayRef vec2_sizes = vec2.sizes();
IntArrayRef self_sizes;

c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
self_ = expand_size(self, {vec1_sizes[0], vec2_sizes[0]}, "addr");
self_sizes = self_->sizes();
} else {
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
self_sizes = self_->sizes();
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self_sizes[0] == vec1_sizes[0], "self dim 0 must match vec1 dim 0");
TORCH_CHECK(self_sizes[1] == vec2_sizes[0], "self dim 1 must match vec2 dim 0");
}

if (&result != &self) {
result.resize_(self_sizes);
if (beta.toComplexDouble() != 0.0) {
at::native::copy_(result, *self_);
}
}

IntArrayRef result_sizes = result.sizes();
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
return result;
}

MPSStream* stream = getCurrentMPSStream();
bool is_beta_non_zero = beta.toDouble() != 0.0;
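// View vec1 as an (n x 1) column and vec2 as a (1 x m) row so the matrix multiply below yields their outer product.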
MPSShape* inputShape = @[@(vec1.numel()), @(1)];
MPSShape* otherShape = @[@(1), @(vec2.numel())];

struct CachedGraph : public mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *vec1Tensor_ = nil;
MPSGraphTensor *vec2Tensor_ = nil;
MPSGraphTensor *selfTensor_ = nil;
MPSGraphTensor *resultTensor_ = nil;
};
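// Placeholders are filled per call; the compiled graph itself is reused via the cache below.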

mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();

@autoreleasepool {
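// beta and alpha are baked into the graph as constant tensors, so they must be part of the cache key.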
string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_})
+ ":" + to_string(beta.toDouble())
+ ":" + to_string(alpha.toDouble());
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {

mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;

@autoreleasepool{
MPSGraph *mpsGraph = mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

MPSGraphTensor *t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1.scalar_type()), inputShape);
MPSGraphTensor *t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2.scalar_type()), otherShape);
MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *self_);

// Outer product of vec1 and vec2, computed as (n x 1) @ (1 x m)
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1
secondaryTensor:t2
name:@"MM/(vec1Xvec2)"];

// Intermediates for beta and alpha
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble()
dataType:getMPSScalarType((*self_).scalar_type())];
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble()
dataType:getMPSScalarType(vec1.scalar_type())];

// Intermediates for multiplying by beta and alpha
MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:productTensor
secondaryTensor:alphaTensor
name:@"MM/alpha*(vec1Xvec2)"];
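// When beta == 0, addr ignores the input entirely, so skip the multiply/add below;
// this also keeps NaN/Inf in self from propagating through 0 * self.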
MPSGraphTensor* selfTimesBetaTensor = selfTensor;
if (is_beta_non_zero) {
selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:selfTensor
secondaryTensor:betaTensor
name:@"MM/beta*input"];
}

MPSGraphTensor* resultTensor = productTimesAlphaTensor;
if (is_beta_non_zero) {
resultTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor
secondaryTensor:selfTimesBetaTensor
name:@"MM/beta*input+alpha*(vec1@vec2)"];
}

newCachedGraph->vec1Tensor_ = t1;
newCachedGraph->vec2Tensor_ = t2;
newCachedGraph->selfTensor_ = selfTensor;
newCachedGraph->resultTensor_ = resultTensor;
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}

Placeholder vec1Placeholder = Placeholder(cachedGraph->vec1Tensor_, vec1, inputShape);
Placeholder vec2Placeholder = Placeholder(cachedGraph->vec2Tensor_, vec2, otherShape);
Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, *self_);
Placeholder resultPlaceholder = Placeholder(cachedGraph->resultTensor_, result);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
vec1Placeholder.getMPSGraphTensor() : vec1Placeholder.getMPSGraphTensorData(),
vec2Placeholder.getMPSGraphTensor() : vec2Placeholder.getMPSGraphTensorData(),
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()
};

mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}

return result;
}
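For reference, the graph above computes the standard addr contract, result = beta * self + alpha * outer(vec1, vec2). A minimal sketch checking the MPS kernel against a CPU reference (assuming a Metal-enabled build where torch.backends.mps.is_available() is true):

import torch

vec1 = torch.randn(3, device="mps")
vec2 = torch.randn(4, device="mps")
inp = torch.randn(3, 4, device="mps")

# beta * inp + alpha * outer(vec1, vec2), computed on the MPS device
out = torch.addr(inp, vec1, vec2, beta=0.8, alpha=1.2)
ref = 0.8 * inp.cpu() + 1.2 * torch.outer(vec1.cpu(), vec2.cpu())
torch.testing.assert_close(out.cpu(), ref)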

Tensor& addmm_out_mps_impl(
const Tensor& bias,
const Tensor& self, // input
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -596,6 +596,7 @@
variants: function, method
dispatch:
CPU, CUDA: addr
MPS: addr_mps
CompositeExplicitAutograd: math_addr

- func: addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
@@ -606,6 +607,7 @@
- func: addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: addr_out
MPS: addr_out_mps
CompositeExplicitAutograd: math_addr_out

- func: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor
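With these two dispatch entries, torch.addr and torch.addr(..., out=...) on MPS tensors route to addr_mps and addr_out_mps instead of falling through to the CompositeExplicitAutograd fallback. A quick sketch of the out= path (again assuming an MPS-enabled build):

import torch

buf = torch.empty(5, 10, device="mps")
torch.addr(torch.ones(5, 10, device="mps"),
           torch.ones(5, device="mps"),
           torch.ones(10, device="mps"),
           out=buf)  # dispatched to addr_out_mps, writes into buf in place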
33 changes: 32 additions & 1 deletion test/test_mps.py
@@ -522,6 +522,13 @@ def test_bmm(self):
self.assertEqual(output_cpu, output_mps)
self.assertEqual(output_cpu.size(), output_mps.size())

def test_addr(self):
A = torch.ones(5, 10).to("mps")
B = torch.ones(5).to("mps")
C = torch.ones(10).to("mps")
D = torch.addr(A, B, C).to("cpu")
torch.testing.assert_close(D, torch.full((5, 10), 2.0))

def test_trace(self):
M_cpu = torch.randn(3, 3)
M_mps = M_cpu.detach().clone().to("mps")
@@ -6341,6 +6348,30 @@ def maybe_transpose(cond, m):
m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)

def _test_addr(self, f, t, m, v, alpha=None, beta=None):
dtype = t.dtype
numpy_dtype = dtype
alpha = 1.2 if alpha is None else alpha
beta = 0.8 if beta is None else beta
res1 = f(t, m, v, alpha=alpha, beta=beta)
res2 = alpha * np.outer(m.to(numpy_dtype).cpu().numpy(), v.to(numpy_dtype).cpu().numpy())
if beta != 0:
res2 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
res2 = torch.from_numpy(res2).to(dtype)
self.assertEqual(res1, res2)

def test_addr(self, device="mps", dtype=torch.float32):
M = torch.randn(10, 25, device=device).to(dtype)
m1 = torch.randn(10, device=device).to(dtype)
m2 = torch.randn(25, device=device).to(dtype)
self._test_addr(torch.addr, M, m1, m2)

# Test beta=0, M=nan
M = torch.full((10, 25), math.nan, device=device).to(dtype)
m1 = torch.randn(10, device=device).to(dtype)
m2 = torch.randn(25, device=device).to(dtype)
self._test_addr(torch.addr, M, m1, m2, beta=0)
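The beta=0 case is the reason the kernel skips the beta multiply rather than multiplying by a zero constant: addr defines beta == 0 as "ignore self", even when self contains NaN or Inf, whereas a literal 0 * NaN would poison the output. A standalone sketch of the property being tested (hypothetical shapes):

import math
import torch

M = torch.full((2, 3), math.nan, device="mps")
out = torch.addr(M, torch.ones(2, device="mps"), torch.ones(3, device="mps"), beta=0)
assert not out.isnan().any()  # NaNs in M must not leak through when beta == 0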

class TestGatherScatter(TestCase):
def test_slicing_with_step(self):
# Slicing with step
@@ -8626,7 +8657,7 @@ class TestConsistency(TestCase):
'addcmul': ['f32', 'i16', 'i32', 'i64', 'u8'],
'addmm': ['f32'],
'addmv': ['f32'],
'addr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
'addr': ['f32'],
'all': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'allclose': ['f16', 'f32'],
'any': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],