Skip to content

Commit 73d9dcc

Browse files
Owen ElliottOwen Elliott
authored andcommitted
fixed bug on specification of dtype parameter
1 parent 8dff966 commit 73d9dcc

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

aten/src/ATen/native/mps/operations/RangeFactories.mm

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,15 @@
132132
auto xend = end.to<accscalar_t>();
133133
auto xstep = step.to<accscalar_t>();
134134

135-
double size_d = ((xend - xstart) / xstep) + 1;
135+
// double size_d = ((xend - xstart) / xstep) + 1;
136+
double size_d;
137+
if (std::is_same<scalar_t, int64_t>::value) {
138+
size_d = static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>())
139+
/ step.to<accscalar_t>() + 1;
140+
} else {
141+
size_d = static_cast<double>(end.to<double>() - start.to<double>())
142+
/ step.to<double>() + 1;
143+
}
136144

137145
TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
138146
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&

0 commit comments

Comments
 (0)