Calling float() in modules converts integer buffers or parameters to floating point #3810

@lantiga

Description

I'm working on a module that takes a LongTensor as an argument, which I register as an `indices` buffer (https://github.com/lantiga/pytorch/blob/idx2col/torch/nn/modules/conv.py#L706).

Calling `float()` on the module casts the `indices` buffer to a FloatTensor, which I believe is not desirable behavior in general. In particular, it breaks subsequent calls to TH* functions, which expect either a DoubleTensor or a FloatTensor for certain arguments but always a LongTensor for others.

IMO, registered buffers/params with integer types should be immune to calls to `float()` or `double()`, which are really there to control the precision of floating-point buffers/params.
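A minimal sketch of the proposed rule, modeled without torch so the cast logic stands alone (the buffer names and the `cast_buffers` helper are hypothetical, purely for illustration): only floating-point buffers are retargeted, while integer-typed ones are left untouched.

```python
# Hypothetical model of a module's registered buffers as a name -> tensor-type
# mapping. Under the proposed rule, float()/double() only retarget
# floating-point types; integer types (e.g. LongTensor) pass through unchanged.

FLOAT_TYPES = {"FloatTensor", "DoubleTensor", "HalfTensor"}

def cast_buffers(buffers, target="FloatTensor"):
    """Return a new mapping where only floating-point buffers change type."""
    return {
        name: target if tensor_type in FLOAT_TYPES else tensor_type
        for name, tensor_type in buffers.items()
    }

buffers = {"weight": "DoubleTensor", "indices": "LongTensor"}
print(cast_buffers(buffers))
# {'weight': 'FloatTensor', 'indices': 'LongTensor'}
```

With this rule, `float()` on the module above would convert `weight` but leave the `indices` buffer as a LongTensor, so the TH* calls that require it keep working.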

I can take care of this if the above makes sense.
