Move backward and set_data off of Type #21963
The first hunk replaces the `dispatch_type()`-based `Tensor::backward` and `Tensor::set_data` with direct lookups into the global ATen dispatch table:

```diff
@@ -56,18 +56,15 @@ inline TensorOptions Tensor::options() const {
     .is_variable(is_variable());
 }
 
-inline void Tensor::backward(
-    c10::optional<Tensor> gradient,
-    bool keep_graph,
-    bool create_graph) {
-  dispatch_type().backward(*this, std::move(gradient), keep_graph, create_graph);
+// all static inline to allow for inlining of the non-dynamic part of dispatch
+inline void Tensor::backward(const Tensor & gradient, bool keep_graph, bool create_graph) const {
+  static auto table = globalATenDispatch().getOpTable("aten::backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> void");
+  return table->getOp<void (const Tensor &, const Tensor &, bool, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, gradient, keep_graph, create_graph);
 }
 
-inline void Tensor::set_data(Tensor new_data) {
-  dispatch_type().set_data(*this, new_data);
+inline void Tensor::set_data(const Tensor & new_data) const {
+  static auto table = globalATenDispatch().getOpTable("aten::set_data(Tensor(a!) self, Tensor new_data) -> void");
+  return table->getOp<void (const Tensor &, const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, new_data);
 }
 
-// all static inline to allow for inlining of the non-dynamic part of dispatch
 inline Tensor Tensor::abs() const {
   static auto table = globalATenDispatch().getOpTable("aten::abs(Tensor self) -> Tensor");
   return table->getOp<Tensor (const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(*this);
```

A contributor commented inline on the new `backward` implementation:

> cool.
The second hunk adds a new file with native fallback implementations that simply error out:

```diff
@@ -0,0 +1,16 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+
+namespace at {
+namespace native {
+
+void backward(const Tensor& self, const Tensor& gradient, bool keep_graph, bool create_graph) {
+  AT_ERROR("backward is not implemented for Tensor");
+}
+
+void set_data(const Tensor& self, const Tensor& new_data) {
+  AT_ERROR("set_data is not implemented for Tensor");
+}
+
+} // namespace native
+} // namespace at
```

A contributor commented inline on `set_data`:

> the codegen for this looks wrong, because this is really an inplace method, so you'd think
Review thread:

> I think the original signature was intentionally `optional<Tensor>`, to make it clear that passing `nullopt` is valid (it's not obviously valid for other cases). Does the codegen choke on this?

> we've been planning to make `Tensor?` translate to `optional<Tensor>` for a while now -- maybe now is a good time to do it?

> Yeah this doesn't work on codegen yet.