|
1 | | -#include <c10/util/irange.h> |
| 1 | +#include <torch/csrc/lazy/core/config.h> |
2 | 2 | #include <torch/csrc/lazy/core/tensor.h> |
3 | 3 |
|
4 | | -#include <torch/csrc/lazy/core/config.h> |
| 4 | +#include <c10/util/irange.h> |
5 | 5 | #include <torch/csrc/lazy/core/helpers.h> |
6 | 6 | #include <torch/csrc/lazy/core/ir_dump_util.h> |
7 | 7 | #include <torch/csrc/lazy/core/lazy_graph_executor.h> |
@@ -64,22 +64,24 @@ LazyTensorPtr LazyTensor::Create(std::shared_ptr<Data> data) { |
64 | 64 | } |
65 | 65 |
|
66 | 66 | LazyTensor::LazyTensor(const at::Tensor& tensor, const BackendDevice& device) |
67 | | - : data_(std::make_shared<Data>(tensor, device)) {} |
| 67 | + : LazyTensor(std::make_shared<Data>(tensor, device)) {} |
68 | 68 |
|
69 | 69 | LazyTensor::LazyTensor(BackendDataPtr handle) |
70 | | - : data_(std::make_shared<Data>(handle, handle->device())) {} |
| 70 | + : LazyTensor(std::make_shared<Data>(handle, handle->device())) {} |
71 | 71 |
|
72 | 72 | LazyTensor::LazyTensor(Value ir_value, const BackendDevice& device) |
73 | | - : data_(std::make_shared<Data>(std::move(ir_value), device)) { |
| 73 | + : LazyTensor(std::make_shared<Data>(std::move(ir_value), device)) { |
74 | 74 | TryLimitGraphSize(); |
75 | 75 | } |
76 | 76 |
|
77 | 77 | LazyTensor::LazyTensor( |
78 | 78 | std::shared_ptr<LazyView> view, |
79 | 79 | const BackendDevice& device) |
80 | | - : data_(std::make_shared<Data>(std::move(view), device)) {} |
| 80 | + : LazyTensor(std::make_shared<Data>(std::move(view), device)) {} |
81 | 81 |
|
82 | | -LazyTensor::LazyTensor(std::shared_ptr<Data> data) : data_(std::move(data)) {} |
| 82 | +LazyTensor::LazyTensor(std::shared_ptr<Data> data) |
| 83 | + : data_(std::move(data)) |
| 84 | + , storage_(c10::Storage({}, 0, c10::DataPtr(nullptr, backendDeviceToAtenDevice(data_->device)))) {} |
83 | 85 |
|
84 | 86 | LazyTensor::Data* LazyTensor::data() const { |
85 | 87 | TORCH_CHECK(data_ != nullptr, "Trying to access a null cursor"); |
@@ -346,7 +348,9 @@ std::shared_ptr<LazyView> LazyTensor::CreateView(ViewInfo view_info) const { |
346 | 348 | } |
347 | 349 |
|
348 | 350 | LazyTensorPtr LazyTensor::CreateViewTensor(ViewInfo view_info) const { |
349 | | - return Create(CreateView(std::move(view_info)), GetDevice()); |
| 351 | + auto new_tensor = Create(CreateView(std::move(view_info)), GetDevice()); |
| 352 | + new_tensor->storage_ = Storage(); |
| 353 | + return new_tensor; |
350 | 354 | } |
351 | 355 |
|
352 | 356 | at::Tensor LazyTensor::ToTensor(bool detached) { |
|
0 commit comments