Closed
Description
I get strange behavior when combining comparison operations with indexing elements from a tensor, either directly or through torch.index_select. The snippet below should gather only the values below 0.8, but for length = 4100 and above it also returns values above the threshold. Here I use torch.lt, but the same happens with torch.gt if I switch the order of the comparison.
"""Test."""
import torch


def test():
    """Test."""
    print("Testing!")
    length = 4100
    threshold = 0.8
    error = False
    for i in range(100):
        # integer indices 0 .. length - 1 (linspace returns floats, so cast to long)
        indices = torch.linspace(0, length - 1, length).long()
        # 1-D tensor of uniform random numbers in the range [0, 1)
        test_tensor = torch.rand(length)
        # boolean mask of the elements below the threshold
        select = test_tensor.lt(threshold)
        # gather the selected elements; every one of them should be < threshold
        scores = test_tensor.index_select(0, indices[select])
        if scores.max() > threshold:
            error = True
            print(i)
            print(scores.max())
            print("Something wrong!")
    if not error:
        print("Ran with expected output!")


if __name__ == '__main__':
    test()
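The report says the problem shows up both with direct indexing and with torch.index_select. A minimal sketch for checking several gathering paths at once, assuming a PyTorch version that supports boolean masks: the helper `masked_select_is_consistent` is hypothetical (not part of PyTorch). It compares `index_select`, direct boolean-mask indexing, and `torch.masked_select` against a plain-Python reference, for lengths on both sides of the reported 4100 threshold. It builds indices with `torch.arange`, which returns int64 values directly instead of casting a float `linspace`.

```python
import torch


def masked_select_is_consistent(length: int, threshold: float = 0.8, trials: int = 10) -> bool:
    """Hypothetical helper: True if every gathering path keeps exactly the values < threshold."""
    for _ in range(trials):
        x = torch.rand(length)                   # 1-D float32 tensor in [0, 1)
        mask = x.lt(threshold)                   # boolean mask, same shape as x
        indices = torch.arange(length)           # int64 indices 0 .. length - 1

        via_index_select = x.index_select(0, indices[mask])
        via_mask = x[mask]                       # direct boolean-mask indexing
        via_masked_select = torch.masked_select(x, mask)

        # Reference result computed element by element in plain Python.
        expected = [v for v in x.tolist() if v < threshold]

        for result in (via_index_select, via_mask, via_masked_select):
            if result.tolist() != expected:
                return False
    return True


if __name__ == "__main__":
    # Lengths below and above the 4100 elements mentioned in the report.
    for n in (1000, 4099, 4100, 8192, 100000):
        print(n, masked_select_is_consistent(n))
```

On a build without the bug, every line should print True; a False for the larger lengths would reproduce the reported behavior.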