Skip to content

Commit eed8bbd

Browse files
author
wxie
committed
Treat Scalar parameter as if it is constant
Pull Request resolved: #53582 We will pass `Scalar` by reference in the following commit, i.e. `const Scalar&`. ghstack-source-id: 123755068 Differential Revision: [D26904444](https://our.internmc.facebook.com/intern/diff/D26904444/)
1 parent aeb3e93 commit eed8bbd

File tree

3 files changed

+19
-23
lines changed

3 files changed

+19
-23
lines changed

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,9 +601,10 @@ static void addbmm_impl_(
601601
return;
602602
}
603603

604+
auto adjusted_beta(beta);
604605
for (int64_t batch = 0; batch < num_batches; ++batch) {
605-
result.addmm_(batch1[batch], batch2[batch], beta, alpha);
606-
beta = 1; // accumulate output once
606+
result.addmm_(batch1[batch], batch2[batch], adjusted_beta, alpha);
607+
adjusted_beta = 1; // accumulate output once
607608
}
608609
}
609610

aten/src/ATen/native/Pow.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ Tensor& float_power_out(Tensor& result, const Tensor& base, Scalar exp) {
105105

106106
// Note: need the casts inside the ternary because conversion functions return e.g. c10::complex,
107107
// which causes a complex scalar to always be returned.
108-
exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
109-
return at::pow_out(result, base.to(dtype), exp);
108+
auto casted_exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
109+
return at::pow_out(result, base.to(dtype), casted_exp);
110110
}
111111

112112
Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) {
@@ -115,20 +115,20 @@ Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) {
115115
"the output given to float_power has dtype ", result.scalar_type(),
116116
" but the operation's result requires dtype ", dtype);
117117

118-
base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble());
119-
return at::pow_out(result, base, exp.to(dtype));
118+
auto casted_base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble());
119+
return at::pow_out(result, casted_base, exp.to(dtype));
120120
}
121121

122122
Tensor float_power(const Tensor& base, Scalar exp) {
123123
auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble;
124-
exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
125-
return at::pow(base.to(dtype), exp);
124+
auto casted_exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
125+
return at::pow(base.to(dtype), casted_exp);
126126
}
127127

128128
Tensor float_power(Scalar base, const Tensor& exp) {
129129
auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble;
130-
base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble());
131-
return at::pow(base, exp.to(dtype));
130+
auto casted_base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble());
131+
return at::pow(casted_base, exp.to(dtype));
132132
}
133133

134134
Tensor float_power(const Tensor& base, const Tensor& exp) {
@@ -151,8 +151,8 @@ Tensor& float_power_(Tensor& base, Scalar exp) {
151151
"the base given to float_power_ has dtype ", base.scalar_type(),
152152
" but the operation's result requires dtype ", dtype);
153153

154-
exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
155-
return base.pow_(exp);
154+
auto casted_exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
155+
return base.pow_(casted_exp);
156156
}
157157

158158
} // namespace native

aten/src/ATen/native/TensorFactories.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,6 @@ void window_function_checks(
4545
window_length);
4646
}
4747

48-
// bool inputs are considered integral
49-
static inline bool allIntegral(std::initializer_list<std::reference_wrapper<Scalar>> l) {
50-
for (Scalar& s : l) {
51-
if (!s.isIntegral(true)) {
52-
return false;
53-
}
54-
}
55-
return true;
56-
}
57-
5848
} // namespace
5949

6050
DEFINE_DISPATCH(complex_stub);
@@ -75,7 +65,12 @@ Tensor arange(
7565
Scalar end,
7666
Scalar step,
7767
const TensorOptions& options) {
78-
bool set_to_integral_dtype = !options.has_dtype() && allIntegral({start, end, step});
68+
bool set_to_integral_dtype = !options.has_dtype() &&
69+
// bool inputs are considered integral
70+
start.isIntegral(true) &&
71+
end.isIntegral(true) &&
72+
step.isIntegral(true);
73+
7974
Tensor result = set_to_integral_dtype
8075
? at::empty({0}, options.dtype(at::ScalarType::Long))
8176
: at::empty({0}, options);

0 commit comments

Comments
 (0)