Closed
Labels: oncall: jit
Description
#8919 changes the semantics of the most basic binary ops, and we need to update the JIT code to handle that.
Previously, the fuser depended on shape propagation to insert numerous expands that make the sizes of inputs to pointwise ops match (this is a hard requirement of our codegen). However, this is no longer a sound transformation, because the type inference rules now depend on input ranks and on the "scalar-auto-wrapping status" of input tensors: expanding a zero-dim tensor changes its rank, and can therefore change the result type of the op.
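To see why inserting an expand can silently change an op's result type, here is a toy model of a rank-aware promotion rule. It is a minimal sketch under stated assumptions: the rule resembles the one PyTorch uses today (dimensioned tensors take priority over zero-dim tensors, and a zero-dim operand only matters if its dtype is in a higher category). It is not the code from #8919, and it ignores the separate wrapped-number flag. `TensorType`, `result_dtype` and `expand` are hypothetical names.

```python
from dataclasses import dataclass

# Hypothetical, simplified dtype lattice: order within and across categories.
ORDER = ["int32", "int64", "float32", "float64"]
CATEGORY = {"int32": 0, "int64": 0, "float32": 1, "float64": 1}


@dataclass(frozen=True)
class TensorType:
    dtype: str
    shape: tuple  # () means a zero-dim tensor, e.g. a scalar wrapped as a tensor


def result_dtype(a: TensorType, b: TensorType) -> str:
    """Toy rank-aware promotion for a binary pointwise op (assumed rules)."""
    dimmed = [t.dtype for t in (a, b) if t.shape != ()]
    zero_dim = [t.dtype for t in (a, b) if t.shape == ()]
    if not dimmed or not zero_dim:
        # Both operands have the same priority: ordinary promotion.
        return max(dimmed + zero_dim, key=ORDER.index)
    best = max(dimmed, key=ORDER.index)
    # A zero-dim operand only wins if it is in a strictly higher category.
    higher = [d for d in zero_dim if CATEGORY[d] > CATEGORY[best]]
    if higher:
        best = max(higher, key=ORDER.index)
    return best


def expand(t: TensorType, shape: tuple) -> TensorType:
    # expand keeps the dtype but gives the tensor the target rank.
    return TensorType(t.dtype, shape)


x = TensorType("float32", (4, 3))
s = TensorType("float64", ())  # scalar represented as a zero-dim tensor
print(result_dtype(x, s))                  # float32
print(result_dtype(x, expand(s, (4, 3))))  # float64: the expand changed the type
```

Under these assumed rules, the expand that the old shape-propagation pass would insert turns a `float32` result into a `float64` one, which is exactly the unsoundness described above.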
Since all of this is fundamentally a requirement of the fuser alone, it makes sense to handle it there. We will need to special-case pointwise operators between regular tensors and tensors representing scalars (which may live on the CPU and may have a different type). There are a few ways to approach this:
- Allow the codegen to take scalars as plain kernel arguments. This even gives us some nice perf benefits we didn't have before.
- If `TensorIterator` actually makes it easy to write CUDA kernels, and handles type promotion and expands for us, we might want to use it in our codegen and relax some constraints. I'm not very familiar with this functionality, so I don't know how feasible this is.
- (inferior to the previous options) Insert `type_as` and `expand` nodes, making the inputs amenable to fusion.
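The first option (scalars as plain kernel arguments) can be sketched as a small source generator. This is a hypothetical illustration, not the fuser's actual codegen: `FusionInput`, `emit_pointwise_kernel` and the emitted kernel layout are assumptions. Tensor inputs become device pointers indexed per element, while scalar inputs are passed by value, so they need no expand, no device copy and no per-element load. Every operand is converted to the output type before computing. The sketch also assumes that all tensor inputs are contiguous and already have the same number of elements, which is the codegen's existing hard requirement.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FusionInput:
    name: str        # hypothetical variable name used in the emitted code
    dtype: str       # C type name, e.g. "float" or "double"
    is_scalar: bool  # True for a zero-dim (possibly CPU) tensor holding a scalar


def emit_pointwise_kernel(inputs, out_dtype, expr):
    """Emit CUDA source for a 1-D pointwise kernel (illustrative only)."""
    params, loads = [], []
    for inp in inputs:
        if inp.is_scalar:
            # Scalars arrive by value as ordinary kernel arguments.
            params.append(f"{inp.dtype} {inp.name}_arg")
            loads.append(f"    {out_dtype} {inp.name} = ({out_dtype}){inp.name}_arg;")
        else:
            # Tensors arrive as device pointers and are read per element.
            params.append(f"const {inp.dtype}* {inp.name}_ptr")
            loads.append(f"    {out_dtype} {inp.name} = ({out_dtype}){inp.name}_ptr[i];")
    params += [f"{out_dtype}* out", "long n"]
    lines = [
        f'extern "C" __global__ void fused_kernel({", ".join(params)}) {{',
        "  long i = blockIdx.x * (long)blockDim.x + threadIdx.x;",
        "  if (i < n) {",
        *loads,
        f"    out[i] = {expr};",
        "  }",
        "}",
    ]
    return "\n".join(lines)


src = emit_pointwise_kernel(
    [FusionInput("x", "float", False), FusionInput("s", "double", True)],
    out_dtype="float",
    expr="x * s",
)
print(src)
```

Here the output type is `float`, assuming rank-aware rules in which a zero-dim `double` operand does not promote a `float` tensor. The scalar keeps its own type in the kernel signature and is converted once per thread, instead of being materialized, expanded and cast as a full tensor.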