JIT fuser can't handle scalar ops #9940

@apaszke

Description

#8919 changes the semantics of the most basic binary ops, and we need to update the JIT code to handle that.

Previously, the fuser depended on shape propagation to insert numerous expands, making the sizes of inputs to pointwise ops match (a hard requirement of our codegen). However, this is no longer a sound transformation, because the type inference rules now depend on input ranks and on whether an input tensor is an auto-wrapped scalar.
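To see why expanding first is unsound, consider a simplified Python model of the promotion rule (this is an illustration, not PyTorch's actual implementation): a wrapped scalar defers to the other operand's dtype, but once the scalar is expanded into a real tensor, the ordinary widest-type rule applies, so the same op infers a different output type.

```python
# Simplified model of type promotion (illustrative only).
WIDTH = {"int8": 0, "int64": 1, "float32": 2}

def promote(a_dtype, b_dtype, b_is_wrapped_scalar):
    # Hypothetical rule: a wrapped scalar defers to the tensor's dtype;
    # two real tensors promote to the wider dtype.
    if b_is_wrapped_scalar:
        return a_dtype
    return a_dtype if WIDTH[a_dtype] >= WIDTH[b_dtype] else b_dtype

# tensor<int8> + wrapped int64 scalar stays int8...
assert promote("int8", "int64", True) == "int8"
# ...but after an inserted expand, the scalar is a real int64 tensor,
# and the result widens to int64 -- a different inferred type.
assert promote("int8", "int64", False) == "int64"
```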

Since all of this is fundamentally only a requirement of the fuser, it makes sense to handle it there. We will need to special-case pointwise operators between regular tensors and (possibly CPU, possibly differently typed) tensors representing scalars. There are a few ways to approach this:

  1. Allow the codegen to take scalars as plain kernel arguments. This even gives us some nice perf benefits we didn't have before.
  2. If TensorIterator actually makes it easy to write CUDA kernels, and handles type promotion and expands for us, we might want to use it in our codegen and relax some constraints. I'm not very familiar with this functionality, so I don't know how feasible this is.
  3. (inferior to the previous options) Insert type_as and expand nodes to make the inputs amenable to fusion.
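Option 1 could look something like the following toy sketch of a codegen helper (hypothetical names; not the real fuser implementation): scalar inputs become plain by-value kernel parameters instead of tensor pointers, so no expand is needed and no per-element loads are emitted for them.

```python
# Toy sketch of option 1: emit a pointwise CUDA kernel where scalars
# are passed by value as plain arguments (hypothetical helper, for
# illustration only).
def emit_pointwise_kernel(name, tensor_args, scalar_args, expr):
    params = ", ".join(
        [f"const float* {t}" for t in tensor_args]
        + [f"float {s}" for s in scalar_args]   # scalar: by value, no pointer
        + ["float* out", "int n"]
    )
    loads = "\n        ".join(f"float {t}_v = {t}[i];" for t in tensor_args)
    return (
        f"__global__ void {name}({params}) {{\n"
        f"    int i = blockIdx.x * blockDim.x + threadIdx.x;\n"
        f"    if (i < n) {{\n"
        f"        {loads}\n"
        f"        out[i] = {expr};\n"
        f"    }}\n"
        f"}}\n"
    )

# e.g. a fused a * b + a where b is a scalar input:
src = emit_pointwise_kernel("fused_mul_add", ["a"], ["b"], "a_v * b + a_v")
print(src)
```

Because the scalar never materializes as a tensor, it sidesteps the promotion issue entirely and avoids redundant global-memory traffic, which is the perf benefit mentioned above.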

Labels

oncall: jit
