71 changes: 71 additions & 0 deletions aten/src/ATen/native/mps/operations/RangeFactories.mm
@@ -129,6 +129,77 @@
  return result;
}

Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
  AT_DISPATCH_MPS_TYPES(result.scalar_type(), "arange_mps", [&]() {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto xstart = start.to<accscalar_t>();
    auto xend = end.to<accscalar_t>();
    auto xstep = step.to<accscalar_t>();

    // Compute the output length in double precision so overflow can be
    // detected; torch.range includes the endpoint, hence the "+ 1".
    double size_d;
    if (std::is_same<scalar_t, int64_t>::value) {
      size_d = static_cast<double>(xend - xstart) / xstep + 1;
    } else {
      size_d = static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>() + 1;
    }

    TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
    TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
                std::isfinite(static_cast<double>(xend)),
                "unsupported range: ", xstart, " -> ", xend);
    TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
                "upper bound and lower bound inconsistent with step sign");
    TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
                "invalid size, possible overflow?");

    int64_t size = static_cast<int64_t>(size_d);
    int64_t numel = result.numel();

    if (numel != size) {
      result.resize_({size});
    }
    // If the output is not contiguous, compute into a contiguous scratch
    // tensor and copy back once the graph has run.
    bool is_contiguous = result.is_contiguous();
    Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result;

    using namespace mps;
    auto cache_ = MPSGraphCache::getInstance();
    auto stream = getCurrentMPSStream();
    auto mpsDataType = getMPSDataType(result.scalar_type());
    @autoreleasepool {
      string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
      auto cachedGraph = static_cast<RangeCachedGraph*>(cache_->LookUp(key));
      if (!cachedGraph) {
        auto* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
          auto mpsGraph = make_mps_graph();
          return new RangeCachedGraph(mpsGraph, mpsDataType, size);
        });
        cachedGraph = static_cast<RangeCachedGraph*>(tmpCachedGraph);
      }
      Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, r);
      NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
      MPSScalar startScalar = getMPSScalar(start, result.scalar_type());
      feeds[cachedGraph->startTensor] = getMPSGraphTensorFromScalar(stream, startScalar);
      MPSScalar stepScalar = getMPSScalar(step, result.scalar_type());
      feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, stepScalar);

      NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
        outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
      };
      runMPSGraph(stream, cachedGraph->graph(), feeds, results);
    }

    if (!is_contiguous) {
      result.copy_(r);
    }
  });

  return result;
}

Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps, Tensor& result) {
  using namespace mps;

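For reference, a minimal Python sketch of the semantics range_mps_out implements: torch.range includes the endpoint, so the length is floor((end - start) / step) + 1, matching the size_d computation above. This is illustrative only; reference_range is a hypothetical helper, not part of this PR, and the MPS call assumes a Metal-capable machine:

import torch

# torch.range includes `end`; length = floor((end - start) / step) + 1,
# mirroring the kernel's double-precision size_d computation.
def reference_range(start, end, step=1.0):
    size = int((end - start) / step) + 1
    return torch.tensor([start + i * step for i in range(size)])

if torch.backends.mps.is_available():
    x = torch.range(0, 10, device='mps')  # dispatches to range_mps_out
    print(torch.equal(x.cpu(), reference_range(0.0, 10.0)))  # True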
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -4458,6 +4458,7 @@
  dispatch:
    CPU, Meta: range_out
    CUDA: range_cuda_out
    MPS: range_mps_out
  cpp_no_default_args: ['step']

- func: ravel(Tensor(a) self) -> Tensor(a)
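This one-line registration maps the range.out overload to range_mps_out under the MPS dispatch key, so an out= call on an MPS tensor now reaches the new kernel instead of failing with a backend-not-implemented error. A small usage sketch (assumes a PyTorch build with MPS enabled; the kernel resizes out as needed):

import torch

if torch.backends.mps.is_available():
    out = torch.empty(0, device='mps')
    torch.range(1, 5, 1, out=out)  # range.out -> MPS: range_mps_out
    print(out.cpu())               # tensor([1., 2., 3., 4., 5.])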
7 changes: 7 additions & 0 deletions test/test_mps.py
@@ -5794,6 +5794,13 @@ def test_arange_empty(self):
        y_cpu = torch.arange(0, 0, 1, out=out_cpu)
        self.assertEqual(y_mps, y_cpu)

    # Test range
    def test_range(self):
        self.assertEqual(np.arange(11, dtype=np.float32), torch.range(0, 10, device='mps'))
        self.assertEqual(np.arange(7, 0, -1, dtype=np.float32), torch.range(7, 1, -1, device='mps'))
        self.assertEqual(np.array([1.0000, 1.3000, 1.6000, 1.9000], dtype=np.float32), torch.range(1, 2, .3, device='mps'))
        self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(0, 6.3, device='mps'))

    # Test softmax
    def test_softmax(self):
        def helper(shape, dim, channels_last=False):
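The test_range additions above lean on the PyTorch test harness, whose assertEqual accepts numpy arrays and moves MPS tensors to the CPU for comparison. Outside that harness, an equivalent hand-rolled check might look like the following sketch (assumes MPS is available; the tolerance is illustrative):

import numpy as np
import torch

if torch.backends.mps.is_available():
    got = torch.range(1, 2, 0.3, device='mps').cpu().numpy()
    expected = np.array([1.0, 1.3, 1.6, 1.9], dtype=np.float32)
    np.testing.assert_allclose(got, expected, rtol=1e-5)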