@@ -185,6 +185,152 @@ void prepare_matrices_for_broadcasting(
   return output;
 }
 
+
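+// addr on MPS: returns beta * self + alpha * outer(vec1, vec2); torch.addr on an MPS tensor
+// dispatches here. Allocates an empty result and delegates to the out-variant below.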
+Tensor addr_mps(const Tensor& self,
+                const Tensor& vec1, const Tensor& vec2,
+                const Scalar& beta, const Scalar& alpha) {
+  Tensor result = at::empty({0}, self.options());
+  addr_out_mps(self, vec1, vec2, beta, alpha, result);
+  return result;
+}
+
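+// Out-variant of addr on MPS: validates the inputs, then builds (or reuses from the graph
+// cache) an MPSGraph computing beta * self + alpha * outer(vec1, vec2) and runs it on the
+// current MPS stream, writing into `result`.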
+Tensor& addr_out_mps(const Tensor& self,
+                     const Tensor& vec1, const Tensor& vec2,
+                     const Scalar& beta, const Scalar& alpha, Tensor& result) {
+  using namespace mps;
+
+  TORCH_CHECK(result.is_mps());
+  TORCH_CHECK(vec1.dim() == 1 && vec2.dim() == 1, "tensors must be 1-D");
+  TORCH_CHECK(vec1.scalar_type() == ScalarType::Double
+              || vec1.scalar_type() == ScalarType::Float
+              || vec1.scalar_type() == ScalarType::Half, "MPS device does not support addr for non-float input");
+
+  TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {vec1, "vec1", 2}, {vec2, "vec2", 3}};
+  checkAllSameGPU(__func__, args);
+  IntArrayRef vec1_sizes = vec1.sizes();
+  IntArrayRef vec2_sizes = vec2.sizes();
+  IntArrayRef self_sizes;
+  c10::MaybeOwned<Tensor> self_;
+  if (&result != &self) {
+    self_ = expand_size(self, {vec1_sizes[0], vec2_sizes[0]}, "addr");
+    self_sizes = self_->sizes();
+  } else {
+    self_ = c10::MaybeOwned<Tensor>::borrowed(self);
+    self_sizes = self_->sizes();
+    TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
+    TORCH_CHECK(self_sizes[0] == vec1_sizes[0], "self dim 0 must match vec1 dim 0");
+    TORCH_CHECK(self_sizes[1] == vec2_sizes[0], "self dim 1 must match vec2 dim 0");
+  }
+
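+  // Prepare the output: resize it to the broadcast shape and, when beta is non-zero,
+  // copy the (expanded) self into it.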
+  if (&result != &vec1) {
+    result.resize_(self_sizes);
+    if (beta.toComplexDouble() != 0.0) {
+      at::native::copy_(result, *self_);
+    }
+  }
+
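+  // Nothing to compute for an empty output.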
+  IntArrayRef result_sizes = result.sizes();
+  if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
+    return result;
+  }
+
+  MPSStream* stream = getCurrentMPSStream();
+  bool is_beta_non_zero = beta.toDouble() != 0.0;
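+  // View vec1 as an (n x 1) column and vec2 as a (1 x m) row so that a plain matrix
+  // multiplication of the two yields the outer product.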
+  MPSShape* inputShape = @[@(vec1.numel()), @(1)];
+  MPSShape* otherShape = @[@(1), @(vec2.numel())];
+
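+  // Cached graph: placeholders for vec1, vec2 and self, plus the computed result tensor.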
+  struct CachedGraph : public mps::MPSCachedGraph
+  {
+    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
+    MPSGraphTensor* vec1Tensor_ = nil;
+    MPSGraphTensor* vec2Tensor_ = nil;
+    MPSGraphTensor* selfTensor_ = nil;
+    MPSGraphTensor* resultTensor_ = nil;
+  };
+
+  mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
+
+  @autoreleasepool {
+    string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_})
+                 + ":" + to_string(beta.toDouble())
+                 + ":" + to_string(alpha.toDouble());
+    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
+    if (!cachedGraph) {
+      mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
+        CachedGraph* newCachedGraph = nil;
+
+        @autoreleasepool {
+          MPSGraph* mpsGraph = mps::make_mps_graph();
+          newCachedGraph = new CachedGraph(mpsGraph);
+
+          MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1.scalar_type()), inputShape);
+          MPSGraphTensor* t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2.scalar_type()), otherShape);
+          MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *self_);
+
+          // Outer product of vec1 and vec2, computed as an (n x 1) x (1 x m) matrix multiplication
+          MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1
+                                                                          secondaryTensor:t2
+                                                                                     name:@"MM/(vec1Xvec2)"];
+
+          // Intermediates for beta and alpha
+          MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble()
+                                                           dataType:getMPSScalarType((*self_).scalar_type())];
+          MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble()
+                                                            dataType:getMPSScalarType(vec1.scalar_type())];
+
+          // Intermediates for multiplying by beta and alpha
+          MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:productTensor
+                                                                              secondaryTensor:alphaTensor
+                                                                                         name:@"MM/alpha*(vec1Xvec2)"];
+          MPSGraphTensor* selfTimesBetaTensor = selfTensor;
+          if (is_beta_non_zero) {
+            selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:selfTensor
+                                                            secondaryTensor:betaTensor
+                                                                       name:@"MM/beta*input"];
+          }
+
+          MPSGraphTensor* resultTensor = productTimesAlphaTensor;
+          if (is_beta_non_zero) {
+            resultTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor
+                                               secondaryTensor:selfTimesBetaTensor
+                                                          name:@"MM/beta*input+alpha*(vec1@vec2)"];
+          }
+
+          newCachedGraph->vec1Tensor_ = t1;
+          newCachedGraph->vec2Tensor_ = t2;
+          newCachedGraph->selfTensor_ = selfTensor;
+          newCachedGraph->resultTensor_ = resultTensor;
+        }
+        return newCachedGraph;
+      });
+      cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
+    }
+
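+    // Bind the ATen tensors to the graph placeholders (vec1/vec2 as the column/row views
+    // above) and run the graph on the current MPS stream, writing into result.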
+    Placeholder vec1Placeholder = Placeholder(cachedGraph->vec1Tensor_, vec1, inputShape);
+    Placeholder vec2Placeholder = Placeholder(cachedGraph->vec2Tensor_, vec2, otherShape);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, *self_);
+    Placeholder resultPlaceholder = Placeholder(cachedGraph->resultTensor_, result);
+
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
+      vec1Placeholder.getMPSGraphTensor() : vec1Placeholder.getMPSGraphTensorData(),
+      vec2Placeholder.getMPSGraphTensor() : vec2Placeholder.getMPSGraphTensorData(),
+      selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
+    };
+
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
+      resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()
+    };
+
+    mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+  }
+
+  return result;
+}
+
 Tensor& addmm_out_mps_impl(
     const Tensor& bias,
     const Tensor& self, // input