
Commit 507b8c3

DenisVieriu97 authored and pytorchmergebot committed
[MPS] Native implementation for addr (#94538)
```
- addr_out_mps to perform res = beta*input + alpha*(vec1Xvec2)
- move addr f16 to low-precision list
- move addr non-float to unsupported list
- add test_addr tests
```
Pull Request resolved: #94538
Approved by: https://github.com/razarmehr
1 parent d51ca38 commit 507b8c3
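
For context, `addr` computes the rank-1 update `res = beta*input + alpha*(vec1 ⊗ vec2)`. A minimal usage sketch of the operation this commit enables (assumes a machine where `torch.backends.mps.is_available()` is true):

```python
import torch

# addr computes: res = beta * input + alpha * outer(vec1, vec2)
inp = torch.ones(3, 2, device="mps")
vec1 = torch.tensor([1.0, 2.0, 3.0], device="mps")
vec2 = torch.tensor([10.0, 20.0], device="mps")

res = torch.addr(inp, vec1, vec2, beta=1, alpha=1)

# Reference computation on CPU for comparison
ref = 1 * inp.cpu() + 1 * torch.outer(vec1.cpu(), vec2.cpu())
assert torch.allclose(res.cpu(), ref)
```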

File tree

3 files changed: +180 −1 lines changed

aten/src/ATen/native/mps/operations/LinearAlgebra.mm

Lines changed: 146 additions & 0 deletions
```diff
@@ -185,6 +185,152 @@ void prepare_matrices_for_broadcasting(
   return output;
 }
 
+Tensor addr_mps(const Tensor& self,
+                const Tensor& vec1, const Tensor& vec2,
+                const Scalar& beta, const Scalar& alpha) {
+  Tensor result = at::empty({0}, self.options());
+  addr_out_mps(self, vec1, vec2, beta, alpha, result);
+  return result;
+}
+
+Tensor& addr_out_mps(const Tensor& self,
+                     const Tensor& vec1, const Tensor& vec2,
+                     const Scalar& beta, const Scalar& alpha, Tensor& result) {
+  using namespace mps;
+
+  TORCH_CHECK(result.is_mps());
+  TORCH_CHECK(vec1.dim() == 1 && vec2.dim() == 1, "tensors must be 1-D");
+  TORCH_CHECK(vec1.scalar_type() == ScalarType::Double
+              || vec1.scalar_type() == ScalarType::Float
+              || vec1.scalar_type() == ScalarType::Half, "MPS device does not support addr for non-float input");
+
+  TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {vec1, "vec1", 2}, {vec2, "vec2", 3}};
+  checkAllSameGPU(__func__, args);
+
+  IntArrayRef vec1_sizes = vec1.sizes();
+  IntArrayRef vec2_sizes = vec2.sizes();
+  IntArrayRef self_sizes;
+
+  c10::MaybeOwned<Tensor> self_;
+  if (&result != &self) {
+    self_ = expand_size(self, {vec1_sizes[0], vec2_sizes[0]}, "addr");
+    self_sizes = self_->sizes();
+  } else {
+    self_ = c10::MaybeOwned<Tensor>::borrowed(self);
+    self_sizes = self_->sizes();
+    TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
+    TORCH_CHECK(self_sizes[0] == vec1_sizes[0], "self dim 0 must match vec1 dim 0");
+    TORCH_CHECK(self_sizes[1] == vec2_sizes[0], "self dim 1 must match vec2 dim 0");
+  }
+
+  if (&result != &vec1) {
+    result.resize_(self_sizes);
+    if (beta.toComplexDouble() != 0.0) {
+      at::native::copy_(result, *self_);
+    }
+  }
+
+  IntArrayRef result_sizes = result.sizes();
+  if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
+    return result;
+  }
+
+  MPSStream* stream = getCurrentMPSStream();
+  bool is_beta_non_zero = beta.toDouble() != 0.0;
+  MPSShape* inputShape = @[@(vec1.numel()), @(1)];
+  MPSShape* otherShape = @[@(1), @(vec2.numel())];
+
+  struct CachedGraph : public mps::MPSCachedGraph {
+    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
+    MPSGraphTensor* vec1Tensor_ = nil;
+    MPSGraphTensor* vec2Tensor_ = nil;
+    MPSGraphTensor* selfTensor_ = nil;
+    MPSGraphTensor* resultTensor_ = nil;
+  };
+
+  mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
+
+  @autoreleasepool {
+    string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_})
+                 + ":" + to_string(beta.toDouble())
+                 + ":" + to_string(alpha.toDouble());
+    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
+    if (!cachedGraph) {
+      mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
+        CachedGraph* newCachedGraph = nil;
+
+        @autoreleasepool {
+          MPSGraph* mpsGraph = mps::make_mps_graph();
+          newCachedGraph = new CachedGraph(mpsGraph);
+
+          MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1.scalar_type()), inputShape);
+          MPSGraphTensor* t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2.scalar_type()), otherShape);
+          MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *self_);
+
+          // Intermediate as placeholder
+          MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1
+                                                                          secondaryTensor:t2
+                                                                                     name:@"MM/(vec1Xvec2)"];
+
+          // Intermediates for beta and alpha
+          MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble()
+                                                           dataType:getMPSScalarType((*self_).scalar_type())];
+          MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble()
+                                                            dataType:getMPSScalarType(vec1.scalar_type())];
+
+          // Intermediates for multiplying by beta and alpha
+          MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:productTensor
+                                                                              secondaryTensor:alphaTensor
+                                                                                         name:@"MM/alpha*(vec1Xvec2)"];
+          MPSGraphTensor* selfTimesBetaTensor = selfTensor;
+          if (is_beta_non_zero) {
+            selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:selfTensor
+                                                            secondaryTensor:betaTensor
+                                                                       name:@"MM/beta*input"];
+          }
+
+          MPSGraphTensor* resultTensor = productTimesAlphaTensor;
+          if (is_beta_non_zero) {
+            resultTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor
+                                               secondaryTensor:selfTimesBetaTensor
+                                                          name:@"MM/beta*input+alpha*(vec1@vec2)"];
+          }
+
+          newCachedGraph->vec1Tensor_ = t1;
+          newCachedGraph->vec2Tensor_ = t2;
+          newCachedGraph->selfTensor_ = selfTensor;
+          newCachedGraph->resultTensor_ = resultTensor;
+        }
+        return newCachedGraph;
+      });
+      cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
+    }
+
+    Placeholder vec1Placeholder = Placeholder(cachedGraph->vec1Tensor_, vec1, inputShape);
+    Placeholder vec2Placeholder = Placeholder(cachedGraph->vec2Tensor_, vec2, otherShape);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, *self_);
+    Placeholder resultPlaceholder = Placeholder(cachedGraph->resultTensor_, result);
+
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
+      vec1Placeholder.getMPSGraphTensor() : vec1Placeholder.getMPSGraphTensorData(),
+      vec2Placeholder.getMPSGraphTensor() : vec2Placeholder.getMPSGraphTensorData(),
+      selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
+    };
+
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
+      resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()
+    };
+
+    mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+  }
+
+  return result;
+}
+
 Tensor& addmm_out_mps_impl(
     const Tensor& bias,
     const Tensor& self, // input
```
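
The kernel builds the outer product as a matrix multiply of an `(n, 1)` column (`inputShape`) by a `(1, m)` row (`otherShape`), scales by `alpha`, and only adds `beta * input` when `beta` is nonzero, so a NaN-filled `input` with `beta=0` never contaminates the result. A rough PyTorch-level sketch of the same decomposition (illustrative only; `addr_sketch` is not part of the codebase and this is not the MPSGraph API):

```python
import torch

def addr_sketch(inp, vec1, vec2, beta=1.0, alpha=1.0):
    # Outer product via matmul of an (n, 1) column and a (1, m) row,
    # mirroring inputShape = [n, 1] and otherShape = [1, m] above.
    product = vec1.reshape(-1, 1) @ vec2.reshape(1, -1)
    out = alpha * product
    if beta != 0:  # mirrors is_beta_non_zero: beta == 0 skips the add entirely
        out = out + beta * inp
    return out
```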

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -596,6 +596,7 @@
   variants: function, method
   dispatch:
     CPU, CUDA: addr
+    MPS: addr_mps
     CompositeExplicitAutograd: math_addr
 
 - func: addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
@@ -606,6 +607,7 @@
 - func: addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU, CUDA: addr_out
+    MPS: addr_out_mps
     CompositeExplicitAutograd: math_addr_out
 
 - func: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor
```
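
With these dispatch entries in place, calls on MPS tensors route to the new native kernels rather than the `CompositeExplicitAutograd` fallback; both the functional and `out=` forms are covered. A quick sketch exercising both entry points (assumes MPS is available):

```python
import torch

a = torch.randn(4, 6, device="mps")
v1 = torch.randn(4, device="mps")
v2 = torch.randn(6, device="mps")

res = torch.addr(a, v1, v2, beta=0.5, alpha=2.0)     # dispatches to addr_mps
out = torch.empty(4, 6, device="mps")
torch.addr(a, v1, v2, beta=0.5, alpha=2.0, out=out)  # dispatches to addr_out_mps
assert torch.allclose(res, out)
```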

test/test_mps.py

Lines changed: 32 additions & 1 deletion
```diff
@@ -522,6 +522,13 @@ def test_bmm(self):
         self.assertEqual(output_cpu, output_mps)
         self.assertEqual(output_cpu.size(), output_mps.size())
 
+    def test_addr(self):
+        A = torch.ones(5, 10).to("mps")
+        B = torch.ones(5).to("mps")
+        C = torch.ones(10).to("mps")
+        D = torch.addr(A, B, C).to("cpu")
+        torch.testing.assert_close(D, torch.full((5, 10), 2.0))
+
     def test_trace(self):
         M_cpu = torch.randn(3, 3)
         M_mps = M_cpu.detach().clone().to("mps")
@@ -6422,6 +6429,30 @@ def maybe_transpose(cond, m):
         m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
         self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)
 
+    def _test_addr(self, f, t, m, v, alpha=None, beta=None):
+        dtype = t.dtype
+        numpy_dtype = dtype
+        alpha = 1.2 if alpha is None else alpha
+        beta = 0.8 if beta is None else beta
+        res1 = f(t, m, v, alpha=alpha, beta=beta)
+        res2 = alpha * np.outer(m.to(numpy_dtype).cpu().numpy(), v.to(numpy_dtype).cpu().numpy())
+        if beta != 0:
+            res2 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
+        res2 = torch.from_numpy(res2).to(dtype)
+        self.assertEqual(res1, res2)
+
+    def test_addr(self, device="mps", dtype=torch.float32):
+        M = torch.randn(10, 25, device=device).to(dtype)
+        m1 = torch.randn(10, device=device).to(dtype)
+        m2 = torch.randn(25, device=device).to(dtype)
+        self._test_addr(torch.addr, M, m1, m2)
+
+        # Test beta=0, M=nan
+        M = torch.full((10, 25), math.nan, device=device).to(dtype)
+        m1 = torch.randn(10, device=device).to(dtype)
+        m2 = torch.randn(25, device=device).to(dtype)
+        self._test_addr(torch.addr, M, m1, m2, beta=0)
+
 class TestGatherScatter(TestCase):
     def test_slicing_with_step(self):
         # Slicing with step
@@ -8707,7 +8738,7 @@ class TestConsistency(TestCase):
         'addcmul': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'addmm': ['f32'],
         'addmv': ['f32'],
-        'addr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'addr': ['f32'],
         'all': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'allclose': ['f16', 'f32'],
         'any': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
```
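
The `TestConsistency` entry narrows `addr` on MPS to `f32`, matching the `TORCH_CHECK` in the kernel that rejects non-float inputs (with `f16` handled via the low-precision list, per the commit message). A hedged sketch of the expected failure mode for an integer input:

```python
import torch

m = torch.ones(3, 4, dtype=torch.int64, device="mps")
v1 = torch.arange(3, device="mps")
v2 = torch.arange(4, device="mps")

try:
    torch.addr(m, v1, v2)
except RuntimeError as e:
    # Expected to hit the kernel's check:
    # "MPS device does not support addr for non-float input"
    print(e)
```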
