 namespace at::native {
 namespace {
 // Check if tensor list has either a boolean tensor or a integer tensor
-bool has_integral_tensor(TensorList tensors, const bool includeBool) {
+inline bool has_integral_tensor(TensorList tensors, const bool includeBool) {
   return std::any_of(
       tensors.begin(), tensors.end(), [&includeBool](const auto& t) {
         return at::isIntegralType(t.scalar_type(), includeBool);
       });
 }
 // check if tensor list has bool tensors
-bool has_bool_tensor(TensorList tensors) {
+inline bool has_bool_tensor(TensorList tensors) {
   return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool {
     return t.scalar_type() == ScalarType::Bool;
   });
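
The two predicates above are what individual foreach ops use to decide whether integer or bool inputs disqualify a given code path. A minimal sketch of that pattern, assuming the surrounding ATen context (TensorList, TORCH_CHECK); the function name and error message are illustrative, not part of this diff:

// Hypothetical guard: bool tensors are rejected outright, and any integral
// tensor disqualifies a float-only fast path.
bool example_can_take_float_fast_path(at::TensorList tensors) {
  TORCH_CHECK(
      !has_bool_tensor(tensors),
      "This example op does not support bool tensors.");
  return !has_integral_tensor(tensors, /*includeBool=*/false);
}
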
@@ -37,11 +37,11 @@ bool has_bool_tensor(TensorList tensors) {
 // - Tensor lists must be non-empty.
 // - All TensorLists and ScalarLists must have the same number of elements.
 // - Corresponding tensors must have the same size.
-void check_foreach_api_restrictions(TensorList tensors) {
+inline void check_foreach_api_restrictions(TensorList tensors) {
   TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor.");
 }

-void check_foreach_api_restrictions(
+inline void check_foreach_api_restrictions(
     TensorList tensors,
     ArrayRef<Scalar> scalars) {
   check_foreach_api_restrictions(tensors);
@@ -50,7 +50,9 @@ void check_foreach_api_restrictions(
5050 " Tensor list must have same number of elements as scalar list." );
5151}
5252
53- void check_foreach_api_restrictions (TensorList tensors1, TensorList tensors2) {
53+ inline void check_foreach_api_restrictions (
54+ TensorList tensors1,
55+ TensorList tensors2) {
5456 TORCH_CHECK (!tensors1.empty (), " Tensor list must have at least one tensor." );
5557 TORCH_CHECK (!tensors2.empty (), " Tensor list must have at least one tensor." );
5658 TORCH_CHECK (
@@ -61,7 +63,7 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2) {
       tensors2.size());
 }

-void check_foreach_api_restrictions(
+inline void check_foreach_api_restrictions(
     TensorList tensors1,
     TensorList tensors2,
     TensorList tensors3) {
@@ -82,7 +84,7 @@ void check_foreach_api_restrictions(
       tensors3.size());
 }

-void check_foreach_api_restrictions(
+inline void check_foreach_api_restrictions(
     TensorList tensors1,
     TensorList tensors2,
     TensorList tensors3,
@@ -99,7 +101,8 @@ void check_foreach_api_restrictions(
 // Helper function called in check_fast_path_restrictions to check whether all
 // corresponding tensors (aligning in index across the tensorLists) share the
 // same device and dtype.
-bool _check_tensors_share_device_and_dtype(ArrayRef<TensorList> tensorLists) {
+inline bool _check_tensors_share_device_and_dtype(
+    ArrayRef<TensorList> tensorLists) {
   const auto expected_dtype = tensorLists[0][0].dtype();
   const auto expected_device = tensorLists[0][0].device();

@@ -122,7 +125,8 @@ bool _check_tensors_share_device_and_dtype(ArrayRef<TensorList> tensorLists) {

 // Helper function called in check_fast_path_restrictions to check if
 // corresponding tensors in tensor lists have the same sizes and strides.
-bool _check_tensors_share_sizes_and_strides(ArrayRef<TensorList> tensorLists) {
+inline bool _check_tensors_share_sizes_and_strides(
+    ArrayRef<TensorList> tensorLists) {
   for (const auto i : c10::irange(1, tensorLists.size())) {
     for (const auto j : c10::irange(tensorLists[0].size())) {
       if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() ||
@@ -140,7 +144,7 @@ bool _check_tensors_share_sizes_and_strides(ArrayRef<TensorList> tensorLists) {
 // function assumes that _check_tensors_share_device_and_dtype has already been
 // called so that all corresponding tensors in tensorLists have the same dtype.
 // Then, it is sufficient to check the type promotion with just one tensorList.
-bool _check_tensors_do_type_promotion_with_scalars(
+inline bool _check_tensors_do_type_promotion_with_scalars(
     TensorList tensorList,
     ArrayRef<Scalar> scalarList = {},
     bool does_op_promote_integer_inputs_to_float = false) {
@@ -176,7 +180,7 @@ bool _check_tensors_do_type_promotion_with_scalars(

 // Please, make sure to call check_foreach_api_restrictions before calling this
 // method. There is a set of preconditions that have to be satisfied.
-bool check_fast_path_restrictions(
+inline bool check_fast_path_restrictions(
     ArrayRef<TensorList> tensorLists,
     ArrayRef<Scalar> scalarList = {},
     bool does_op_promote_integer_inputs_to_float = false) {
@@ -188,7 +192,7 @@ bool check_fast_path_restrictions(
           does_op_promote_integer_inputs_to_float);
 }

-std::vector<c10::Scalar> convert_tensor_to_scalar_list(
+inline std::vector<c10::Scalar> convert_tensor_to_scalar_list(
     const Tensor& scalarList_,
     int64_t expect_length) {
   std::vector<c10::Scalar> scalarList;
@@ -221,21 +225,21 @@ std::vector<c10::Scalar> convert_tensor_to_scalar_list(
             scalarList_.size(0),
             " instead.");
         for (int64_t i = 0; i < scalarList_.size(0); i++) {
-          scalarList.push_back(c10::Scalar(scalar_data[i]));
+          scalarList.emplace_back(scalar_data[i]);
         }
       });
   return scalarList;
 }

-bool can_use_fast_route(
+inline bool can_use_fast_route(
     ArrayRef<TensorList> tensorLists,
     ArrayRef<Scalar> scalarList = {},
     bool does_op_promote_integer_inputs_to_float = false) {
   return check_fast_path_restrictions(
       tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
 }

-bool can_use_fast_route(
+inline bool can_use_fast_route(
     TensorList tensors1,
     TensorList tensors2,
     bool does_op_promote_integer_inputs_to_float = false) {
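
As the comment above check_fast_path_restrictions notes, these helpers are meant to be called only after check_foreach_api_restrictions. A minimal sketch of the resulting dispatch pattern in a foreach op; the op and kernel names below are hypothetical placeholders, not part of this diff:

// Placeholder for a fused multi-tensor kernel; not a real ATen symbol.
std::vector<at::Tensor> example_fused_add_kernel(
    at::TensorList tensors1,
    at::TensorList tensors2);

// Hypothetical binary foreach op: validate the lists, then pick between the
// fused fast path and a plain per-tensor slow path.
std::vector<at::Tensor> example_foreach_add(
    at::TensorList tensors1,
    at::TensorList tensors2) {
  check_foreach_api_restrictions(tensors1, tensors2);
  if (!can_use_fast_route(
          tensors1,
          tensors2,
          /*does_op_promote_integer_inputs_to_float=*/false)) {
    // slow path: mixed devices/dtypes/layouts, so loop over tensor pairs
    std::vector<at::Tensor> result;
    result.reserve(tensors1.size());
    for (const auto i : c10::irange(tensors1.size())) {
      result.push_back(tensors1[i].add(tensors2[i]));
    }
    return result;
  }
  // fast path: every pair shares device, dtype, sizes and strides, so one
  // fused kernel launch can cover the whole list.
  return example_fused_add_kernel(tensors1, tensors2);
}
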
@@ -253,13 +257,13 @@ using FlatMap = std::unordered_map<
     TensorsAndIndicesT,
     ParamsHash<DeviceDtypeKey>>;

-FlatMap _group_tensors_by_first_tensors_device_and_dtype(
+inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
     const nested_optional_tensorvec_t& nested_tensorlist,
     const bool with_indices) {
   FlatMap grouped_tensors_with_indices;

-  TORCH_CHECK(nested_tensorlist.size() > 0);
-  TORCH_CHECK(nested_tensorlist[0].size() > 0);
+  TORCH_CHECK(!nested_tensorlist.empty());
+  TORCH_CHECK(!nested_tensorlist[0].empty());
   const auto num_lists = nested_tensorlist.size();
   const auto num_tensors = nested_tensorlist[0].size();

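
The grouping helper whose beginning is shown above is the building block for fused optimizer paths: callers bucket parameters, gradients, and optimizer state by (device, dtype) and launch one multi-tensor kernel per bucket. A rough usage sketch, assuming the FlatMap and nested_optional_tensorvec_t aliases from this header; the loop body is only a placeholder:

// Hypothetical caller: group the nested lists once, then iterate the buckets.
void example_apply_per_bucket(
    const nested_optional_tensorvec_t& nested_tensorlist) {
  const auto grouped = _group_tensors_by_first_tensors_device_and_dtype(
      nested_tensorlist, /*with_indices=*/false);
  for (const auto& entry : grouped) {
    // entry.first is the (device, dtype) key; entry.second holds the
    // regrouped tensors (and their indices, when requested). A real caller
    // would launch one fused kernel per bucket here.
  }
}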