@@ -42,6 +42,18 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
4242 throw std::runtime_error (ss.str ());
4343}
4444
45+ static void checkIndexTensorTypes (TensorList indices) {
46+ for (auto & tensor : indices) {
47+ if (tensor.defined ()) {
48+ auto & type = tensor.type ();
49+ auto scalarType = type.scalarType ();
50+ if (scalarType != kLong && scalarType != kByte ) {
51+ throw std::runtime_error (" tensors used as indices must be long or byte tensors" );
52+ }
53+ }
54+ }
55+ }
56+
4557static std::vector<Tensor> expandByteTensors (const Tensor & self, TensorList indices) {
4658 // Expands byte tensors (masks) into the equivalent indexing by LongTensors
4759 std::vector<Tensor> result;
@@ -57,8 +69,15 @@ static std::vector<Tensor> expandByteTensors(const Tensor & self, TensorList ind
5769 }
5870 // Replace with nonzeros
5971 auto nonzero = index.nonzero ();
60- for (int64_t j = 0 ; j < nonzero.size (1 ); j++) {
61- result.emplace_back (nonzero.select (1 , j));
72+ auto is_empty = nonzero.numel () == 0 ;
73+ for (int64_t j = 0 ; j < index.dim (); j++) {
74+ if (is_empty) {
75+ // We can't call select on an empty tensor so we just create an empty
76+ // tensor.
77+ result.emplace_back (nonzero.type ().tensor ());
78+ } else {
79+ result.emplace_back (nonzero.select (1 , j));
80+ }
6281 }
6382 } else {
6483 result.emplace_back (index);
@@ -100,7 +119,7 @@ transposeToFront(Tensor self, TensorList indices) {
100119 transposedIndices.emplace_back ();
101120 }
102121 }
103- return std::make_tuple<> (self.permute (dims), std::move (transposedIndices));
122+ return std::make_tuple (self.permute (dims), std::move (transposedIndices));
104123}
105124
106125static std::vector<int64_t > computeLinearStride (const Tensor & tensor) {
@@ -176,9 +195,22 @@ static Tensor computeLinearIndex(const Tensor & src, TensorList indices) {
176195 return linearIndex;
177196}
178197
198+ static bool hasEmptyTensor (TensorList tensors) {
199+ for (auto & tensor : tensors) {
200+ if (tensor.defined () && tensor.numel () == 0 ) {
201+ return true ;
202+ }
203+ }
204+ return false ;
205+ }
206+
179207static std::tuple<Tensor, Tensor> makeLinearIndex (Tensor self, TensorList orig) {
208+ checkIndexTensorTypes (orig);
180209 // first expand ByteTensor (boolean masks) into 1 or more LongTensors
181210 auto indices = expandByteTensors (self, orig);
211+ if (hasEmptyTensor (indices)) {
212+ return std::make_tuple (self, self.type ().toScalarType (kLong ).tensor ());
213+ }
182214 // next broadcast all index tensors together
183215 indices = expand_outplace (indices);
184216 // add missing null Tensors so that it matches self.dim()
@@ -191,7 +223,7 @@ static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig)
191223 std::tie (self, indices) = transposeToFront (self, indices);
192224 }
193225 auto linearIndex = computeLinearIndex (self, indices);
194- return std::make_tuple<> (self, linearIndex);
226+ return std::make_tuple (self, linearIndex);
195227}
196228
197229Tensor index (const Tensor & self, TensorList indices) {
0 commit comments