
Consolidate definition of operators/gradients where possible #22024

@nairbv

Description

In working on #21088, there were cases where code changes needed to be made that were repetitive and could be error-prone. We could probably simplify/merge some of this code.

To modify an operator we need to update:

  • Tensor.h, TensorMethods.h, and Type.h for C++ interfaces.
  • native_functions.yaml for the primary interface definition used by codegen.
  • aten/src/ATen/native/*Ops.cpp for the C++ operator implementation.
  • external changes for xla/msnpu extensions.
  • derivatives.yaml to define the C++ gradient.
  • shape_analysis.cpp to define the dtype/shape returned for the operator in the jit graph.
  • symbolic_script.cpp to update the TorchScript definition of the operator and gradient.
  • symbolic_variable.h to update the jit graph for the operator.
  • symbolic_opset9.py to update the specification of the operator in ONNX.

Motivation

This isn't such a concern for compile-checked interface definitions like TensorMethods.h vs Tensor.h. Larger concerns are, for example:

  • We have gradient implementations in both TorchScript (symbolic_script.cpp) and C++ (in derivatives.yaml). It's possible for the two to become inconsistent, and that kind of drift is hard to test for.
  • Shape analysis determines dtypes of returned tensors using its own separate logic in parallel to the operator implementation. Changes made necessary by my modification of an operator were not caught by tests, and the dtypes returned were already incorrect for some operators. A single path to determining the correct shape/dtype would make this more robust.
  • It could be faster/easier to contribute changes if developers working on one part of the code don't necessarily need to be familiar with all of the code. Updating an onnx opset is somewhat challenging for someone unfamiliar with onnx.
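
One inexpensive guard against the first concern is a differential test that evaluates both gradient definitions on the same inputs and compares each against a finite-difference estimate. The sketch below is plain Python, with scalar functions standing in for tensors. The names `grad_cpp_style`, `grad_script_style`, and `check_gradients_agree` are hypothetical and used only for illustration; they are not PyTorch APIs.

```python
import math

def forward(x):
    # Stand-in for an operator's forward pass: f(x) = x * sin(x).
    return x * math.sin(x)

def grad_cpp_style(x):
    # Stand-in for the formula that would live in derivatives.yaml.
    return math.sin(x) + x * math.cos(x)

def grad_script_style(x):
    # Stand-in for the formula that would live in symbolic_script.cpp.
    return x * math.cos(x) + math.sin(x)

def numeric_grad(f, x, eps=1e-6):
    # Central finite-difference estimate of df/dx.
    return (f(x + eps) - f(x - eps)) / (2 * eps)

def check_gradients_agree(f, grads, points, tol=1e-5):
    """Return (point, name) pairs where a gradient disagrees with the numeric estimate."""
    failures = []
    for x in points:
        expected = numeric_grad(f, x)
        for name, g in grads.items():
            if abs(g(x) - expected) > tol:
                failures.append((x, name))
    return failures
```

Running `check_gradients_agree` over every operator that has two gradient definitions would flag a drifted formula by name, without requiring the contributor to know both code paths.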

Pitch

  • Gradients: Investigate if it's possible to generate the C++ gradient from torchscript or vice versa.
  • Shape analysis: Find some alternate way to determine the dtype/shape returned. Tracing?
  • SymbolicVariable: This mostly seems derived from the operator interface definition. I'm not sure if codegen would make sense here.
  • Interfaces: consider merging native_functions.yaml and derivatives.yaml to reduce repetition.
  • Onnx: ???
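
The "tracing" idea for shape analysis can be sketched roughly as follows: rather than trusting a second, hand-maintained table of output dtype rules, run the real implementation on a small probe input and record what comes back. The example uses plain Python numbers in place of tensors. `OPS`, `HAND_WRITTEN_DTYPE_RULES`, `infer_dtype_by_probe`, and `find_stale_rules` are all hypothetical names, and this is only an illustration of the approach, not how shape_analysis.cpp works.

```python
# Toy operator implementations; Python's int/float stand in for tensor dtypes.
OPS = {
    "add": lambda a, b: a + b,
    "true_div": lambda a, b: a / b,    # always produces a float in Python 3
    "floor_div": lambda a, b: a // b,
}

# A separately maintained rule table, analogous to hand-written shape analysis.
# The "true_div" entry is deliberately stale to show how drift is detected.
HAND_WRITTEN_DTYPE_RULES = {
    ("add", int, int): int,
    ("true_div", int, int): int,
    ("floor_div", int, int): int,
}

def infer_dtype_by_probe(op_name, *arg_types):
    """Derive the result type by running the implementation on probe values."""
    probes = [t(3) for t in arg_types]  # 3 avoids division by zero
    return type(OPS[op_name](*probes))

def find_stale_rules():
    """Return (op, claimed, actual) for every rule the implementation contradicts."""
    stale = []
    for (op_name, *arg_types), claimed in HAND_WRITTEN_DTYPE_RULES.items():
        actual = infer_dtype_by_probe(op_name, *arg_types)
        if actual is not claimed:
            stale.append((op_name, claimed.__name__, actual.__name__))
    return stale
```

The same probe could either replace the rule table outright or, more conservatively, run as a test that fails whenever the table and the implementation disagree.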

Alternatives

I don't know the best way to address these concerns, but I believe we could investigate and make some improvements to what we have. Some of this work (shape analysis?) may already be in progress.

cc @ezyang @bhosmer @smessmer @ljk53 @bdhirsh @ailzhang

Metadata

Labels

  • better-engineering (relatively self-contained tasks for better engineering contributors)
  • module: internals (related to internal abstractions in c10 and ATen)
  • triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
