Skip to content

Commit 3f8f727

Browse files
OwenElliottDev
authored and DenisVieriu97 committed
Add range MPS support
1 parent 701412a commit 3f8f727

File tree

3 files changed

+79
-0
lines changed

3 files changed

+79
-0
lines changed

aten/src/ATen/native/mps/operations/RangeFactories.mm

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,77 @@
129129
return result;
130130
}
131131

132+
// MPS backend implementation of torch.range: fills `result` with evenly
// spaced values start, start+step, ..., up to and INCLUDING `end`
// (inclusive upper bound — unlike arange). `result` is resized to the
// computed length if its current numel does not match.
// Returns `result` for chaining.
Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
  AT_DISPATCH_MPS_TYPES(result.scalar_type(), "arange_mps", [&]() {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto xstart = start.to<accscalar_t>();
    auto xend = end.to<accscalar_t>();
    auto xstep = step.to<accscalar_t>();

    // Element count: (end - start) / step + 1, evaluated with floating-point
    // division in both branches (the difference is cast to double BEFORE the
    // divide for integral dtypes, so no integer truncation occurs here).
    double size_d;
    if (std::is_same<scalar_t, int64_t>::value) {
      size_d = static_cast<double>(xend - xstart) / xstep + 1;
    } else {
      size_d = (end.to<double>() - start.to<double>()) / step.to<double>() + 1;
    }

    // xstep > 0 || xstep < 0 also rejects NaN steps, not just zero.
    TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
    TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
                std::isfinite(static_cast<double>(xend)),
                "unsupported range: ", xstart, " -> ", xend);
    TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
                "upper bound and larger bound inconsistent with step sign");
    TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
                "invalid size, possible overflow?");

    int64_t size = static_cast<int64_t>(size_d);

    if (result.numel() != size) {
      result.resize_({size});
    }

    // The MPS graph writes into contiguous storage; stage through a
    // contiguous temporary when `result` is not contiguous.
    bool is_contiguous = result.is_contiguous();
    Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result;
    using namespace mps;
    auto cache_ = MPSGraphCache::getInstance();
    auto stream = getCurrentMPSStream();
    auto mpsDataType = getMPSDataType(result.scalar_type());
    @autoreleasepool {
      // Cache key includes the output tensor signature and the length, since
      // RangeCachedGraph bakes `size` into the graph; start/step are fed as
      // scalar placeholders at run time.
      string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
      auto cachedGraph = static_cast<RangeCachedGraph *>(cache_->LookUp(key));
      if (!cachedGraph) {
        auto *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph *() {
          auto mpsGraph = make_mps_graph();
          return new RangeCachedGraph(mpsGraph, mpsDataType, size);
        });
        cachedGraph = static_cast<RangeCachedGraph *>(tmpCachedGraph);
      }
      Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, r);
      NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
      MPSScalar startScalar = getMPSScalar(start, result.scalar_type());
      feeds[cachedGraph->startTensor] = getMPSGraphTensorFromScalar(stream, startScalar);
      MPSScalar stepScalar = getMPSScalar(step, result.scalar_type());
      feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, stepScalar);

      NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
        outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
      };
      runMPSGraph(stream, cachedGraph->graph(), feeds, results);
    }

    if (!is_contiguous) {
      result.copy_(r);
    }
  });

  return result;
}
202+
132203
Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps, Tensor& result) {
133204
using namespace mps;
134205

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4458,6 +4458,7 @@
44584458
dispatch:
44594459
CPU, Meta: range_out
44604460
CUDA: range_cuda_out
4461+
MPS: range_mps_out
44614462
cpp_no_default_args: ['step']
44624463

44634464
- func: ravel(Tensor(a) self) -> Tensor(a)

test/test_mps.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5794,6 +5794,13 @@ def test_arange_empty(self):
57945794
y_cpu = torch.arange(0, 0, 1, out=out_cpu)
57955795
self.assertEqual(y_mps, y_cpu)
57965796

5797+
# Test range (torch.range is inclusive of the upper bound, unlike arange)
def test_range(self):
    self.assertEqual(np.arange(11, dtype=np.float32), torch.range(0, 10, device='mps'))
    self.assertEqual(np.arange(7, 0, -1, dtype=np.float32), torch.range(7, 1, -1, device='mps'))
    self.assertEqual(np.array([1.0000, 1.3000, 1.6000, 1.9000], dtype=np.float32), torch.range(1, 2, .3, device='mps'))
    self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(0, 6.3, device='mps'))
5803+
57975804
# Test softmax
57985805
def test_softmax(self):
57995806
def helper(shape, dim, channels_last=False):

0 commit comments

Comments
 (0)