Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions aten/src/ATen/native/mps/OperationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ MPSDataType getMPSDataType(ScalarType scalar_type);
MPSDataType getMPSScalarType(ScalarType scalar_type);
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
std::string getMPSTypeString(ScalarType scalar_type);
NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value = false);
std::string getArrayRefString(const IntArrayRef s);
Expand Down Expand Up @@ -127,6 +129,13 @@ struct MPSUnaryCachedGraph : public MPSCachedGraph
MPSGraphTensor *outputTensor_ = nil;
};

// Cached MPSGraph for binary ops: holds placeholders for two operands and one
// result. Mirrors MPSUnaryCachedGraph (above) with one extra input tensor, so
// kernels that sometimes need a second operand (e.g. the cdist path of the
// norm implementation) can share a single cached-graph type; otherTensor_ may
// be left nil when only one input is used.
struct MPSBinaryCachedGraph : public MPSCachedGraph
{
MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;   // first operand placeholder
MPSGraphTensor *otherTensor_ = nil;   // second operand placeholder (may stay nil)
MPSGraphTensor *outputTensor_ = nil;  // graph output
};

// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
Expand Down
24 changes: 24 additions & 0 deletions aten/src/ATen/native/mps/OperationUtils.mm
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,30 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
}
}

// Returns the complete axis list [0, 1, ..., t.dim()-1] for tensor `t`,
// boxed as NSNumbers for use with MPSGraph reduction APIs.
NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
const int64_t rank = t.dim();
NSMutableArray<NSNumber*>* allAxes = [NSMutableArray<NSNumber*> arrayWithCapacity:rank];
for (const auto axis : c10::irange(rank)) {
  [allAxes addObject:@(axis)];
}
return allAxes;
}

// Returns the reduction axes for `t`: the dims listed in `dim` when it is
// present and non-empty, otherwise every axis of `t` (full reduction).
// NOTE(review): dims are returned as-is — negative dims are not wrapped here;
// callers appear to wrap via maybe_wrap_dim before use — confirm.
NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim) {
if (dim.has_value() && !dim.value().empty()) {
  IntArrayRef dimValues = dim.value();
  // Use size_t to match IntArrayRef::size(); the previous `int` silently
  // narrowed the value.
  const auto ndim = dimValues.size();
  auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
  for (const auto i: c10::irange(ndim)) {
    axes[i] = [NSNumber numberWithInteger:dimValues[i]];
  }
  return axes;
}

// No explicit dims (or an empty list): reduce over all axes.
return getTensorAxes(t);
}

std::string getMPSShapeString(MPSShape* shape) {
std::string str;
for(NSNumber *elem in shape) {
Expand Down
212 changes: 166 additions & 46 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
namespace at {
namespace native {

typedef MPSGraphTensor* (^NormOpBlock)(mps::MPSBinaryCachedGraph*, MPSGraphTensor*, MPSGraphTensor*);
#define NormOpFn(graph, primary, secondary) MPSGraphTensor* (mps::MPSBinaryCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary)

enum StdVarType {
STANDARD_VARIANCE,
STANDARD_DEVIATION
Expand All @@ -34,15 +37,6 @@

using namespace mps;

NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
int64_t ndim = t.dim();
auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
for (const auto i: c10::irange(ndim)) {
axes[i] = [NSNumber numberWithInteger:i];
}
return axes;
}

void set_apparent_shapes(NSMutableArray<NSNumber*> * &apparent_out_shape,
NSMutableArray<NSNumber*> * &apparent_in_shape,
int64_t num_reduce_dims,
Expand Down Expand Up @@ -410,19 +404,28 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims) {
reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::MEAN, "mean_out_mps");
}

TORCH_IMPL_FUNC(norm_out_mps)
(const Tensor& input_tensor,
const OptionalScalarRef opt_p,
IntArrayRef dim,
bool keepdim,
const Tensor& output_t) {
void impl_func_norm_mps(
const Tensor& input_tensor,
const Tensor& other_tensor,
const OptionalScalarRef& opt_p,
IntArrayRef dim,
bool keepdim,
c10::optional<ScalarType> opt_dtype,
const Tensor& output_t,
bool cdist = false,
c10::optional<IntArrayRef> input_broadcasted_shape = c10::nullopt,
NormOpBlock normOpBlock = nullptr
) {

if (input_tensor.numel() == 0) {
return;
}

auto input_t = (input_tensor.sizes().size() == 0) ? input_tensor.view({1}) : input_tensor;
auto in_dtype = opt_dtype.value_or(input_tensor.scalar_type());
auto mps_input_dtype = getMPSDataType(in_dtype);

IntArrayRef input_shape = input_t.sizes();
IntArrayRef input_shape = cdist ? input_broadcasted_shape.value() : input_t.sizes();

for (const auto dim_val: dim) {
auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size());
Expand Down Expand Up @@ -456,6 +459,13 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims) {
num_output_dims,
input_shape,
axes);

NSArray<NSNumber*>* wrappedAxes = mps::getTensorAxes(input_t, dim);
if (cdist) {
apparent_input_shape = [mps::getMPSShape(input_tensor.sizes()) mutableCopy];
apparent_output_shape = [mps::getMPSShape(output_t.sizes()) mutableCopy];
}

if (output_t.numel() == 0) {
return;
}
Expand All @@ -465,100 +475,210 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims) {
@autoreleasepool {
NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","];
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + getMPSTypeString(input_t.scalar_type()) + ":p" + to_string(p) + ":" + keepdim_info;
string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : mps::getTensorsStringKey({input_t});
string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" + keepdim_info;

auto cachedGraph = cache_->LookUpAs<MPSUnaryCachedGraph>(key);
auto cachedGraph = cache_->LookUpAs<MPSBinaryCachedGraph>(key);

if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<MPSUnaryCachedGraph>(key, ^ MPSCachedGraph * () {
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<MPSBinaryCachedGraph>(key, ^ MPSCachedGraph * () {

MPSUnaryCachedGraph *newCachedGraph = nil;
MPSBinaryCachedGraph *newCachedGraph = nil;

@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new MPSUnaryCachedGraph(mpsGraph);
newCachedGraph = new MPSBinaryCachedGraph(mpsGraph);
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);

MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()));
if (cdist) {
newCachedGraph->otherTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, other_tensor);
}

MPSGraphTensor *outputTensor = nil;
MPSGraphTensor* inputTensor = cdist ? normOpBlock(newCachedGraph, newCachedGraph->inputTensor_, newCachedGraph->otherTensor_) :
newCachedGraph->inputTensor_;
if (opt_dtype.has_value()) {
inputTensor = [mpsGraph castTensor:inputTensor
toType:mps_input_dtype
name:@"castInputTensor"];
}

MPSGraphTensor *outputTensor;

if (pIsZero) {
MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor
name:nil];
MPSGraphTensor *powerValTensor = [mpsGraph constantWithScalar:p
dataType:getMPSDataType(input_t.scalar_type())];
dataType:mps_input_dtype];
MPSGraphTensor *powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor
secondaryTensor:powerValTensor
name:nil];
outputTensor = [mpsGraph reductionSumWithTensor:powerTensor
axes:axes
axes:wrappedAxes
name:nil];
} else if (pIsPosInf) {
}
else if (pIsPosInf) {
MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor
name:nil];
outputTensor = [mpsGraph reductionMaximumWithTensor:absoluteTensor
axes:axes
axes:wrappedAxes
name:nil];
} else if (pIsNegInf) {
}
else if (pIsNegInf) {
MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor
name:nil];
outputTensor = [mpsGraph reductionMinimumWithTensor:absoluteTensor
axes:axes
axes:wrappedAxes
name:nil];
} else {
MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor
name:nil];

MPSGraphTensor *powerValTensor = [mpsGraph constantWithScalar:p
dataType:getMPSDataType(input_t.scalar_type())];
dataType:mps_input_dtype];

MPSGraphTensor *reciprocalPowerValTensor = [mpsGraph constantWithScalar:reciprocal_p
dataType:getMPSDataType(input_t.scalar_type())];
dataType:mps_input_dtype];

MPSGraphTensor *powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor
secondaryTensor:powerValTensor
name:nil];

MPSGraphTensor *reductionSumTensor = [mpsGraph reductionSumWithTensor:powerTensor
axes:axes
axes:wrappedAxes
name:nil];

outputTensor = [mpsGraph powerWithPrimaryTensor:reductionSumTensor
secondaryTensor:reciprocalPowerValTensor
name:nil];
}

newCachedGraph->inputTensor_ = inputTensor;
if (cdist) {
outputTensor= [mpsGraph reshapeTensor:outputTensor withShape:mps::getMPSShape(output_t) name: nil];
}

newCachedGraph->outputTensor_ = outputTensor;
}
return newCachedGraph;
});
}

auto inputPlaceholder = Placeholder();

if (apparent_input_shape) {
inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape);
} else {
inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
}

auto otherPlaceholder = Placeholder();
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape);

NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =[NSMutableDictionary dictionary];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
};
if (cdist) {
otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other_tensor);
feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
}

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);

}
}

// Structured-kernel entry point for norm.out on MPS. Delegates to
// impl_func_norm_mps with self passed as both operands (the second is ignored
// when cdist=false) and no dtype override.
TORCH_IMPL_FUNC(norm_out_mps)
(const Tensor& self,
const OptionalScalarRef opt_p,
IntArrayRef dim,
bool keepdim,
const Tensor& result) {
impl_func_norm_mps(self, self, opt_p, dim, keepdim, c10::nullopt, result, /*cdist=*/false);
}

// Structured-kernel entry point for norm.dtype_out on MPS. Same as
// norm_out_mps but forwards the caller-supplied accumulation/output dtype to
// impl_func_norm_mps.
TORCH_IMPL_FUNC(norm_dtype_out_mps)
(const Tensor& self,
const OptionalScalarRef opt_p,
IntArrayRef dim,
bool keepdim,
ScalarType dtype,
const Tensor& result) {
impl_func_norm_mps(self, self, opt_p, dim, keepdim, dtype, result, /*cdist=*/false);
}


// MPS implementation of at::_cdist_forward: pairwise p-norm distance between
// the rows of x1 ([..., r1, c]) and x2 ([..., r2, c]), producing [..., r1, r2].
// Builds the pairwise-difference tensor inside an MPSGraph block and delegates
// the p-norm reduction over the last dim to impl_func_norm_mps(cdist=true).
// NOTE(review): `compute_mode` is validated but otherwise unused here — the
// mm-based fast path for p=2 (modes 1/2) does not appear to be implemented.
Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
using namespace mps;
// Input validation mirroring the CPU/CUDA _cdist_forward contract.
TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1));
TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type());
auto device1 = x1.device().type();
TORCH_CHECK(at::isFloatingType(x2.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type());
auto device2 = x2.device().type();
TORCH_CHECK(p >= 0, "cdist only supports non-negative p values");
TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
TORCH_CHECK(x1.is_mps() && (x1.get_device() == x2.get_device()), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")");

// c1 == c2 is guaranteed by the check above (the feature/column dimension).
int64_t c1 = x1.size(-1);
int64_t c2 = x2.size(-1);

auto dim1 = x1.dim();
auto dim2 = x2.dim();
int64_t mode = compute_mode.value_or(0);
TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode);

// Row counts of the two point sets.
int64_t r1 = x1.size(-2);
int64_t r2 = x2.size(-2);

//For batch calculation we expand all dimensions(except the last two) to one, with size that equals to product of them.
//The last two dimensions will stay the same
IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2);
IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2);
// Broadcast the leading batch dims of x1 and x2 against each other.
std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1});
std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});

// Collapse all batch dims into one: [B, r, c] views of the broadcast inputs.
const int64_t expand_batch_product = c10::multiply_integers(expand_batch_portion);
std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
std::vector<int64_t> tensor2_view{expand_batch_product, r2, c2};

// Output keeps the broadcast batch dims plus the r1 x r2 distance matrix.
std::vector<int64_t> output_shape(expand_batch_portion);
output_shape.insert(output_shape.end(), {r1, r2});
Tensor result = at::empty(output_shape, x1.options());

// Graph fragment run by impl_func_norm_mps: produces the [B, r1*r2, c]
// pairwise-difference tensor whose last-dim p-norm is the cdist result.
NormOpBlock norm_op_block = ^NormOpFn(cachedGraph, x1Tensor, x2Tensor) {
MPSGraph* mpsGraph = cachedGraph->graph();

// Broadcast each input to the common batch shape, then flatten the batch.
MPSGraphTensor* inputBroadcast = [mpsGraph broadcastTensor:x1Tensor toShape:getMPSShape(tensor1_expand_size) name:nil];
MPSGraphTensor* inputBroadcastReshape = [mpsGraph reshapeTensor:inputBroadcast withShape:getMPSShape(tensor1_view) name:nil];

MPSGraphTensor* otherBroadcast = [mpsGraph broadcastTensor:x2Tensor toShape:getMPSShape(tensor2_expand_size) name:nil];
MPSGraphTensor* otherBroadcastReshape = [mpsGraph reshapeTensor:otherBroadcast withShape:getMPSShape(tensor2_view) name:nil];

// r2 copies of x1's rows and r1 copies of x2's rows; concatenated with
// interleave YES vs NO so that row k = i*r2 + j pairs x1-row i with
// x2-row j for every (i, j).
// NOTE(review): the capacity hints look swapped relative to the loop
// counts below (inputArray grows to r2 entries, otherArray to r1) —
// harmless, since arrayWithCapacity is only a hint, but worth confirming.
NSMutableArray<MPSGraphTensor*> *inputArray = [NSMutableArray arrayWithCapacity:tensor1_view[1]];
NSMutableArray<MPSGraphTensor*> *otherArray = [NSMutableArray arrayWithCapacity:tensor2_view[1]];

for (const auto i : c10::irange(tensor2_view[1])) {
inputArray[i] = inputBroadcastReshape;
}

for (const auto i : c10::irange(tensor1_view[1])) {
otherArray[i] = otherBroadcastReshape;
}

MPSGraphTensor *inputTensorReshaped = [mpsGraph concatTensors:inputArray dimension:1 interleave:YES name:nil];
MPSGraphTensor *otherTensorReshaped = [mpsGraph concatTensors:otherArray dimension:1 interleave:NO name:nil];


// Element-wise x1[i] - x2[j] for every row pair; the caller reduces this
// over the last axis with the p-norm.
MPSGraphTensor *inputTensorPNorm = [mpsGraph subtractionWithPrimaryTensor: inputTensorReshaped
secondaryTensor: otherTensorReshaped
name: nil];
return inputTensorPNorm;
};

// Reduce over dim=2 (the column/feature axis of the [B, r1*r2, c] diff
// tensor); keepdim=false, result reshaped to output_shape inside the impl.
c10::optional<IntArrayRef> inputBroadcastSize = c10::make_optional(makeArrayRef(tensor1_view.data(), tensor1_view.size()));
impl_func_norm_mps(x1, x2, OptionalScalarRef(p), makeArrayRef<int64_t>(2), false, c10::nullopt, result, /*cdist=*/true, inputBroadcastSize, norm_op_block);
return result;
}

Tensor std_var_common_impl_mps(const Tensor & input_t,
at::OptionalIntArrayRef dim,
c10::optional<int64_t> correction,
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4090,6 +4090,7 @@
- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
dispatch:
CPU, CUDA: _cdist_forward
MPS: _cdist_forward_mps
autogen: _cdist_forward.out

- func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor
Expand Down Expand Up @@ -6174,6 +6175,7 @@
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: norm_dtype_out
MPS: norm_dtype_out_mps

- func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
Expand Down
Loading