Skip to content

Commit 04f09d4

Browse files
Will Feng authored and facebook-github-bot committed
Move unwrap logic from c10 to caffe2 (#21620)
Summary: After #17072, we are allowed to pass Variables into ATen ops, thus there is no need to unwrap input variables in the c10 call path. Note that since Caffe2 still expects inputs to be pure Tensors, we moved the unwrapping logic to the Caffe2 wrapper. Pull Request resolved: #21620 Differential Revision: D15763560 Pulled By: yf225 fbshipit-source-id: 5375f0e51eb320f380ae599ebf98e6b259f0bff8
1 parent 794ee6d commit 04f09d4

File tree

2 files changed

+43
-39
lines changed

2 files changed

+43
-39
lines changed

caffe2/core/export_caffe2_op_to_c10.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,44 @@ inline c10::ListPtr<at::Tensor> _call_caffe2_op(
2727
return std::move(op).move_newstyle_outputs();
2828
}
2929

30+
// Strips the autograd Variable wrapper from `tensor` so Caffe2 kernels
// receive a plain at::Tensor. Non-variable tensors are forwarded unchanged.
inline at::Tensor unwrap_tensor(at::Tensor&& tensor) {
  if (!tensor.is_variable()) {
    return std::move(tensor);
  }
  // Detach by shallow-copying the TensorImpl: the copy shares storage with
  // the original while carrying over the version counter and the
  // metadata-mutability flag.
  auto impl = tensor.unsafeGetTensorImpl();
  auto detached_impl = impl->shallow_copy_and_detach(
      /*version_counter=*/impl->version_counter(),
      /*allow_tensor_metadata_change=*/impl->allow_tensor_metadata_change());
  return at::Tensor(detached_impl);
}
41+
42+
// Recursively unwraps every Variable reachable from `ivalue` (bare tensors,
// tensor lists, generic lists, generic dicts) into plain tensors, preserving
// the container structure. All other value kinds pass through untouched.
inline IValue unwrap(IValue&& ivalue) {
  // TODO Remove the .defined() check once we don't have undefined tensors on the stack anymore (@wanchaol is working on this)
  if (ivalue.isTensor() && ivalue.toTensor().defined()) {
    return unwrap_tensor(std::move(ivalue).toTensor());
  }
  if (ivalue.isTensorList()) {
    c10::ListPtr<at::Tensor> tensors = std::move(ivalue).toTensorList();
    for (size_t idx = 0; idx < tensors.size(); ++idx) {
      // extract() moves the element out, letting us unwrap without a copy.
      tensors[idx] = unwrap_tensor(tensors.extract(idx));
    }
    return std::move(tensors);
  }
  if (ivalue.isGenericList()) {
    c10::impl::GenericListPtr elements = std::move(ivalue).toGenericList();
    for (size_t idx = 0; idx < elements.size(); ++idx) {
      elements[idx] = unwrap(elements.extract(idx));
    }
    return std::move(elements);
  }
  if (ivalue.isGenericDict()) {
    // Dict values are rewritten in place; keys are never tensors here.
    for (auto& entry : ivalue.toGenericDict()) {
      entry.setValue(unwrap(entry.value()));
    }
    return std::move(ivalue);
  }
  return std::move(ivalue);
}
67+
3068
// This function is inline in the hope that compilers optimizing for speed will
3169
// inline it into call_caffe2_op_from_c10, allowing call_op to be inlined and
3270
// avoiding the function pointer indirection, while compilers optimizing for
@@ -65,6 +103,11 @@ inline void _call_caffe2_op_from_c10(
65103
outputs = std::move(preallocated_outputs).toTensorList();
66104
}
67105

106+
// unwrap tensor inputs from variable
107+
for (auto iter = stack->end() - num_inputs; iter != stack->end(); ++iter) {
108+
*iter = unwrap(std::move(*iter));
109+
}
110+
68111
// TODO Avoid vector allocation. One idea would be to keep the std::vector
69112
// instances in the cache.
70113
std::vector<IValue> inputs = torch::jit::pop(*stack, num_inputs);

torch/csrc/jit/register_c10_ops.cpp

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,40 +7,6 @@ namespace torch {
77
namespace jit {
88
namespace {
99

10-
// Strips the autograd Variable wrapper: tensor_data() yields the underlying
// plain at::Tensor. Non-variable tensors are forwarded unchanged.
at::Tensor unwrap_tensor(at::Tensor&& tensor) {
11-
if (tensor.is_variable()) {
12-
return torch::autograd::Variable(std::move(tensor)).tensor_data();
13-
} else {
14-
return std::move(tensor);
15-
}
16-
}
17-
18-
// Recursively unwraps every Variable reachable from `ivalue` — bare tensors,
// tensor lists, generic lists, and generic dict values — into plain tensors,
// preserving container structure. Other value kinds pass through untouched.
IValue unwrap(IValue&& ivalue) {
19-
// TODO Remove the .defined() check once we don't have undefined tensors on the stack anymore (@wanchaol is working on this)
if (ivalue.isTensor() && ivalue.toTensor().defined()) {
20-
return unwrap_tensor(std::move(ivalue).toTensor());
21-
} else if (ivalue.isTensorList()) {
22-
c10::ListPtr<at::Tensor> list = std::move(ivalue).toTensorList();
23-
// extract(i) moves each element out so it can be unwrapped without a copy.
for (size_t i = 0; i < list.size(); ++i) {
24-
list[i] = unwrap_tensor(list.extract(i));
25-
}
26-
return std::move(list);
27-
} else if (ivalue.isGenericList()) {
28-
c10::impl::GenericListPtr list = std::move(ivalue).toGenericList();
29-
for (size_t i = 0; i < list.size(); ++i) {
30-
list[i] = unwrap(list.extract(i));
31-
}
32-
return std::move(list);
33-
} else if (ivalue.isGenericDict()) {
34-
// Dict values are rewritten in place; keys are never tensors here.
for (auto& item : ivalue.toGenericDict()) {
35-
item.setValue(unwrap(item.value()));
36-
}
37-
return std::move(ivalue);
38-
} else {
39-
return std::move(ivalue);
40-
}
41-
}
43-
4410
at::Tensor wrap_tensor(at::Tensor&& tensor) {
4511
if (tensor.is_variable()) {
4612
return std::move(tensor);
@@ -170,11 +136,6 @@ Operator createOperatorFromC10(const c10::OperatorHandle& op) {
170136
graph->insertNode(node);
171137
}
172138

173-
// unwrap tensor inputs from variable
174-
for (auto iter = stack.end() - input_size; iter != stack.end(); ++iter) {
175-
*iter = unwrap(std::move(*iter));
176-
}
177-
178139
c10::Dispatcher::singleton().lookup(op, &stack).call(&stack);
179140

180141
// wrap tensor outputs as variable

0 commit comments

Comments (0)