cuda topk dim=0 crash on small tensor in two lines of code #4490

@davidbau

Description

Repro:

import torch
a = torch.randn(25, 1, 8).cuda()
b = torch.randn(25, 1, 9).cuda()
print('This one works')
print(a.topk(3, dim=0)[1])
print('This one crashes')
print(b.topk(3, dim=0)[1])

Expected: two small tensors of small integers (the top-3 indices along dim 0) are printed. Instead, only the first case works. In the second case, you get a series of failed assertions starting with this:

pytorch/aten/src/THC/THCTensorTopK.cuh:251: DataType findPattern(DataType *, DataType *, IndexType, IndexType, BitDataType, BitDataType) [with DataType = float, BitDataType = unsigned int, IndexType = unsigned int]: block: [46,0,0], thread: [0,0,0] Assertion `false` failed.

I wonder if it is related to #3959.
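
For reference, here is a minimal sketch of what the failing call is expected to return, computed with a full sort instead of `topk`. The helper name `topk_via_sort` is hypothetical (it is not part of PyTorch), and the example stays on the CPU; whether the same sort-based approach sidesteps the assertion on a CUDA tensor is an assumption that has not been tested here.

```python
import torch


def topk_via_sort(x, k, dim=0):
    """Hypothetical helper: k largest entries along `dim`, via a full descending sort."""
    values, indices = torch.sort(x, dim=dim, descending=True)
    # Keep only the first k slices along `dim` of the sorted result.
    return values.narrow(dim, 0, k), indices.narrow(dim, 0, k)


torch.manual_seed(0)
b = torch.randn(25, 1, 9)  # same shape as the crashing tensor, kept on the CPU
values, indices = topk_via_sort(b, 3, dim=0)
print(indices.shape)  # torch.Size([3, 1, 9])
```

The indices are positions along dim 0, so every entry lies in the range 0 to 24, which is the "small tensors of small integers" output the repro expects.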
