We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8dff966 commit 73d9dccCopy full SHA for 73d9dcc
aten/src/ATen/native/mps/operations/RangeFactories.mm
@@ -132,7 +132,15 @@
132
auto xend = end.to<accscalar_t>();
133
auto xstep = step.to<accscalar_t>();
134
135
- double size_d = ((xend - xstart) / xstep) + 1;
+ // double size_d = ((xend - xstart) / xstep) + 1;
136
+ double size_d;
137
+ if (std::is_same<scalar_t, int64_t>::value) {
138
+ size_d = static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>())
139
+ / step.to<accscalar_t>() + 1;
140
+ } else {
141
+ size_d = static_cast<double>(end.to<double>() - start.to<double>())
142
+ / step.to<double>() + 1;
143
+ }
144
145
TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
146
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
0 commit comments