Strange behavior of comparison operations #4417

@amrit110

I see strange behavior when combining comparison operations with indexing into a tensor, either directly or via torch.index_select. The snippet below should gather only values below 0.8, but for length = 4100 and above it fails: some of the gathered values exceed the threshold. Here I use torch.lt, but the same thing happens with torch.gt if I switch the order of the comparison.

"""Test."""

import torch

def test():
    """Test."""
    print("Testing!")
    length = 4100
    threshold = 0.8
    error = False
    for i in range(100):
        indices = torch.linspace(0, length - 1, length).long()
        # 1-D tensor of uniform random numbers in the range [0, 1)
        test_tensor = torch.rand(length)
        select = test_tensor.lt(threshold)
        scores = test_tensor.index_select(0, indices[select])
        if scores.max() > threshold:
            error = True
            print(i)
            print(scores.max())
            print("Something wrong!")
    if not error:
        print("Ran with expected output!")

if __name__ == '__main__':
    test()
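As a cross-check on the snippet above, the same selection can be made through three routes that should be equivalent: index_select on masked indices, torch.masked_select, and direct boolean indexing. This is only a minimal sketch, and it assumes a recent PyTorch release in which comparison ops return a bool mask. The helper names `select_below` and `routes_agree` are hypothetical and not part of the original report, and torch.arange is used here instead of torch.linspace to build the indices.

```python
import torch


def select_below(values, threshold):
    """Gather entries of a 1-D tensor below threshold via three routes."""
    mask = values.lt(threshold)
    # Integer positions 0 .. n-1 (int64), as index_select expects.
    indices = torch.arange(values.numel())
    via_index_select = values.index_select(0, indices[mask])
    via_masked_select = torch.masked_select(values, mask)
    via_bool_index = values[mask]
    return via_index_select, via_masked_select, via_bool_index


def routes_agree(length=4100, threshold=0.8, seed=0):
    """Return True if all routes match and every value is below threshold."""
    torch.manual_seed(seed)
    values = torch.rand(length)
    a, b, c = select_below(values, threshold)
    same = torch.equal(a, b) and torch.equal(b, c)
    below = bool((a < threshold).all())
    return same and below


if __name__ == '__main__':
    print(routes_agree())
```

If `routes_agree()` returns False for some length, comparing the three results individually should show which route diverges.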
