
Commit c47bdd7

bdhirsh authored and pytorchmergebot committed
*_scatter ops should preserve input stride/storage_offset (#91029)
It turns out that we *do* need to update *_scatter ops to return the exact same strides as their inputs. I added a test to `test/test_functionalization.py`, which now trips thanks to Ed's functionalization stride debugging check. It only actually ends up tripping silent correctness if you try to .backward() on that function. Pull Request resolved: #91029 Approved by: https://github.com/ezyang
1 parent a329161 · commit c47bdd7

22 files changed (+353, -184 lines)
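To make the invariant concrete, here is a small Python illustration (not part of the diff; it just exercises the public `torch.select_scatter` binding of one of the ops touched below):

```python
import torch

# A non-contiguous base: transposing gives size (3, 3) with stride (1, 3).
base = torch.ones(3, 3).t()
src = torch.zeros(3)

out = torch.select_scatter(base, src, 0, 0)

# With this change, the *_scatter output replicates the base's exact layout,
# which functionalization (and autograd, when you call .backward()) relies on.
# Previously the output was a plain contiguous clone with stride (3, 1).
assert out.size() == base.size()
assert out.stride() == base.stride()
assert out.storage_offset() == base.storage_offset()
```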

aten/src/ATen/FunctionalTensorWrapper.cpp

Lines changed: 8 additions & 4 deletions
@@ -146,10 +146,14 @@ functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_st
 void FunctionalTensorWrapper::commit_update() {
   auto storage_impl = functional_storage_impl();
   storage_impl->add_update(value_, view_metas_);
-  // Invariant: commit_update() is called during an inplace operation.
-  // Tensor inputs to the operation are synced before running the op,
-  // so the current tensor must be up-to-date with its alias at this point.
-  generation_ = storage_impl->generation();
+  // As an optimization, we used to mark the tensor here as "up-to-date".
+  // That way, code like:
+  //   x = torch.ones(1'000'000)
+  //   x[0].add_(1)
+  // doesn't result in an unnecessary materialization of the base.
+  // This optimization results in the slice temporarily having incorrect
+  // stride/storage_offset though, and DCE should handle that optimization anyway.
+  // generation_ = storage_impl->generation();
 }

 bool FunctionalTensorWrapper::is_up_to_date() const {
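For context, `commit_update()` runs when functionalization propagates a mutation on a view back to its base. A rough way to see that rewrite from Python (assuming the functorch `functionalize`/`make_fx` APIs, which are not part of this diff):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch import functionalize  # torch.func.functionalize on newer releases

def f(x):
    # In-place add on a view: x[0] is a select into x. Functionalization
    # rewrites this into select + add + select_scatter on the base,
    # which is why select_scatter's output strides must match the base.
    x[0].add_(1)
    return x

g = make_fx(functionalize(f))(torch.ones(8))
print(g.code)  # expect aten.select / aten.add / aten.select_scatter in the graph
```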

aten/src/ATen/MemoryOverlap.cpp

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@ MemOverlap has_internal_overlap(TensorImpl* t) {
     return MemOverlap::No;
   }

-  auto strides = t->strides();
-  auto sizes = t->sizes();
+  auto strides = t->sym_strides();
+  auto sizes = t->sym_sizes();
   for (const auto i : c10::irange(strides.size())) {
     if (strides[i] == 0 && sizes[i] > 1) {
       return MemOverlap::Yes;
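The check above flags "internal overlap": a zero stride paired with a size greater than one means several indices alias one memory location. A quick illustration in plain PyTorch (not part of this diff):

```python
import torch

t = torch.ones(1, 4).expand(3, 4)  # expand() sets stride 0 on the broadcast dim
print(t.size(), t.stride())        # torch.Size([3, 4]) (0, 1) -> internal overlap

# In-place writes to such a tensor are rejected, which is why
# clone_preserve_strides (added later in this commit) falls back to a
# plain clone() for overlapping inputs:
try:
    t.add_(1)
except RuntimeError as e:
    print(e)
```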

aten/src/ATen/functorch/BatchRulesScatterOps.cpp

Lines changed: 8 additions & 1 deletion
@@ -5,13 +5,14 @@
 // LICENSE file in the root directory of this source tree.

 #include <ATen/functorch/BatchRulesHelper.h>
-#include <iostream>
 #include <ATen/Operators.h>
 #include <ATen/functorch/PlumbingHelper.h>
 #include <ATen/functorch/BatchedFallback.h>
 #include <ATen/native/TensorAdvancedIndexing.h>
 #include <ATen/native/IndexKernel.h>
 #include <ATen/native/IndexingUtils.h>
+#include <iostream>
+#include <torch/library.h>


 namespace at { namespace functorch {
@@ -1074,6 +1075,12 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT(scatter_add, scatter_add_batch_rule);
   VMAP_SUPPORT2(scatter, reduce, scatter_reduce_batch_rule);
   VMAP_SUPPORT2(scatter, value_reduce, scatter_value_reduce_batch_rule);
+  // as_strided_scatter does not work with the for-loop fallback today,
+  // because as_strided_scatter will return an output that matches
+  // the strides/storage_offset of its input.
+  // With the for-loop fallback, each input tensor is a slice into
+  // the larger batched tensor.
+  m.impl("as_strided_scatter", torch::CppFunction::makeFromBoxedFunction<&vmapErrorFallback>());
 }

 }}
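A sketch of the failure mode this guards against (assuming functorch's `vmap`; the error text comes from `vmapErrorFallback`, added below):

```python
import torch
from functorch import vmap  # torch.func.vmap on newer releases

base = torch.zeros(8, 4)  # a batch of 8 bases, each of shape (4,)
src = torch.ones(8, 2)    # a batch of 8 sources, each of shape (2,)

def f(b, s):
    return torch.as_strided_scatter(b, s, size=(2,), stride=(1,), storage_offset=0)

# Under the for-loop fallback, each per-example `b` is a slice into the
# batched tensor, so preserving its storage_offset in the output would be
# wrong. As of this commit, vmap raises instead of silently doing that:
try:
    vmap(f)(base, src)
except RuntimeError as e:
    print(e)
```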

aten/src/ATen/functorch/BatchedFallback.cpp

Lines changed: 4 additions & 0 deletions
@@ -396,5 +396,9 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
   }
 }

+void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  TORCH_CHECK(false, "Error: ", op.operator_name(), " requires special handling, and does not yet have a batching rule. Feel free to file a github issue!");
+}
+
 }
 } // namespace at

aten/src/ATen/functorch/BatchedFallback.h

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,8 @@ namespace functorch {
 // write batching rules for operators whenever possible.
 void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);

+void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
 // The vmap fallback emits a warning by default, but it may be disabled if
 // the user finds it to be too annoying.
 TORCH_API bool isVmapFallbackWarningEnabled();

aten/src/ATen/native/Copy.cpp

Lines changed: 1 addition & 26 deletions
@@ -8,6 +8,7 @@
 #include <ATen/native/quantized/Copy.h>
 #include <ATen/native/mps/Copy.h>
 #include <ATen/native/vulkan/ops/Copy.h>
+#include <ATen/native/TensorShape.h>
 #include <ATen/quantized/Quantizer.h>
 #include <ATen/vulkan/Context.h>
 #include <ATen/metal/Context.h>
@@ -278,32 +279,6 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
   return self;
 }

-// NB: cribbed from https://github.com/pytorch/pytorch/pull/88198
-at::Tensor clone_preserve_strides(const at::Tensor& self) {
-  TORCH_INTERNAL_ASSERT(self.has_storage());
-  // In cases where the input tensor has internal memory overlap, we cannot actually
-  // preserve the strides/storage_offset of the input tensor, because
-  // *_scatter ops will try to copy_() into the cloned tensor.
-  // However, this should **never** show up in functionalized user code;
-  // most aten ops that try to mutate a tensor with internal memory overlap would error anyway.
-  //
-  // The one place that this does come up is in autograd - if there's a select_scatter
-  // in the forward, then autograd will generate one for the backward.
-  // If the input to the select_scatter is grad_output, then this could be an expanded tensor
-  // with internal overlap.
-  //if (at::has_internal_overlap(self) == at::MemOverlap::Yes) {
-  //  return self.clone();
-  //}
-  auto dtype_size = self.dtype().itemsize();
-  auto nbytes = self.storage().sym_nbytes();
-  TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
-  auto numel = nbytes / dtype_size;
-  auto self_full_size = self.as_strided_symint({numel}, {1}, 0);
-  auto clone = self_full_size.clone();
-  auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
-  return out;
-}
-
 Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
   // copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but:
   // (1) It isn't exposed to the frontend (no python bindings)
aten/src/ATen/native/TensorShape.cpp

Lines changed: 41 additions & 4 deletions
@@ -3801,22 +3801,58 @@ std::vector<Tensor> unflatten_dense_tensors(const Tensor& flat, TensorList tenso
   return outputs;
 }

+
+// Clones a tensor by cloning the underlying storage that it came from,
+// which allows us to replicate the exact strides/storage_offset in the cloned tensor.
+// Note [*_scatter ops preserve strides]
+// In order for functionalization to preserve stride correctness, the *_scatter
+// operators that it calls must preserve the striding behavior of their inputs.
+// Specifically, the output of *_scatter(base, mutated_view, ...)
+// should have identical size/stride/storage_offset to "base".
+at::Tensor clone_preserve_strides(const at::Tensor& self) {
+  TORCH_INTERNAL_ASSERT(self.has_storage());
+  // In cases where the input tensor has internal memory overlap, we cannot actually
+  // preserve the strides/storage_offset of the input tensor, because
+  // *_scatter ops will try to copy_() into the cloned tensor.
+  // However, this should **never** show up in functionalized user code;
+  // most aten ops that try to mutate a tensor with internal memory overlap would error anyway.
+  //
+  // The one place that this does come up is in autograd - if there's a select_scatter
+  // in the forward, then autograd will generate one for the backward.
+  // If the input to the select_scatter is grad_output, then this could be an expanded tensor
+  // with internal overlap.
+  if (at::has_internal_overlap(self) == at::MemOverlap::Yes) {
+    return self.clone();
+  }
+  auto dtype_size = self.dtype().itemsize();
+  auto nbytes = self.storage().sym_nbytes();
+  TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
+  auto numel = nbytes / dtype_size;
+  auto self_full_size = self.as_strided_symint({numel}, {1}, 0);
+  auto clone = self_full_size.clone();
+  auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
+  return out;
+}
+
+
 at::Tensor slice_scatter(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step) {
-  auto output = self.clone();
+  // See Note [*_scatter ops preserve strides]
+  auto output = clone_preserve_strides(self);
   auto slice = output.slice(dim, start, end, step);
   TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
   slice.copy_(src);
   return output;
 }
 at::Tensor select_scatter_symint(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::SymInt index) {
-  auto output = self.clone();
+  auto output = clone_preserve_strides(self);
   auto slice = output.select_symint(dim, index);
   TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
   slice.copy_(src);
   return output;
 }
 at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64_t offset, int64_t dim1, int64_t dim2) {
-  auto output = self.clone();
+  // See Note [*_scatter ops preserve strides]
+  auto output = clone_preserve_strides(self);
   auto slice = output.diagonal(offset, dim1, dim2);
   TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
   slice.copy_(src);
@@ -3825,7 +3861,8 @@ at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64
 at::Tensor as_strided_scatter_symint(const at::Tensor& self, const at::Tensor& src, at::SymIntArrayRef size, at::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) {
   // See Note [as_strided_scatter backward support]
   TORCH_INTERNAL_ASSERT(!self.requires_grad() || self.is_contiguous(), "as_strided_scatter is currently only supported for contiguous inputs");
-  auto output = self.clone();
+  // See Note [*_scatter ops preserve strides]
+  auto output = clone_preserve_strides(self);
   auto slice = output.as_strided_symint(size, stride, std::move(storage_offset));
   TORCH_CHECK(slice.sym_sizes() == src.sym_sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sym_sizes(), ", slice size = ", slice.sym_sizes());
   slice.copy_(src);
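The helper's logic, restated as a Python sketch (a hypothetical translation, not code from the commit; `untyped_storage()` is the Python-side storage accessor in current PyTorch, and the internal-overlap fast path is reduced to a comment):

```python
import torch

def clone_preserve_strides_sketch(self: torch.Tensor) -> torch.Tensor:
    # The C++ helper first handles tensors with internal memory overlap by
    # falling back to a plain clone(); that branch is omitted here.
    nbytes = self.untyped_storage().nbytes()
    dtype_size = self.element_size()
    assert nbytes % dtype_size == 0
    numel = nbytes // dtype_size

    # View the *whole* storage as a flat contiguous tensor, clone it, then
    # re-view the clone with the original size/stride/storage_offset.
    full = self.as_strided((numel,), (1,), 0)
    clone = full.clone()
    return clone.as_strided(self.size(), self.stride(), self.storage_offset())

# The clone matches the input's layout exactly, unlike a plain clone():
x = torch.ones(4, 4).t()[1:, :]  # non-contiguous, with a nonzero storage_offset
y = clone_preserve_strides_sketch(x)
assert y.stride() == x.stride() and y.storage_offset() == x.storage_offset()
```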

aten/src/ATen/native/TensorShape.h

Lines changed: 3 additions & 0 deletions
@@ -5,6 +5,9 @@

 namespace at {
 namespace native {
+
+TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);
+
 inline bool cat_should_skip_tensor(const Tensor& t) {
   return t.numel() == 0 && t.dim() == 1;
 }

test/functorch/test_aotdispatch.py

Lines changed: 40 additions & 30 deletions
@@ -613,8 +613,8 @@ def forward(self, primals_1):
     t_1 = torch.ops.aten.t.default(clone); clone = None
     select_scatter = torch.ops.aten.select_scatter.default(t_1, mul, 0, 0); t_1 = mul = None
     t_2 = torch.ops.aten.t.default(select_scatter); select_scatter = None
-    t_3 = torch.ops.aten.t.default(t_2); t_2 = None
-    return [t_3, 3, 3, 1, 3, 0]""")
+    t_4 = torch.ops.aten.t.default(t_2); t_2 = None
+    return [t_4, 3, 3, 1, 3, 0]""")

     def test_view_and_inplace_view(self):
         def f(a, b):
@@ -683,11 +683,12 @@ def forward(self, primals_1):
     clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
     as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
     mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
-    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = None
-    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
-    t_1 = torch.ops.aten.t.default(as_strided_5); as_strided_5 = None
+    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
+    as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+    as_strided_6 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
+    t_1 = torch.ops.aten.t.default(as_strided_6); as_strided_6 = None
     mul_1 = torch.ops.aten.mul.Tensor(t_1, 3); t_1 = None
-    return [mul, mul_1, 4, 1, 0]""")
+    return [as_strided_3, mul_1, 4, 1, 0]""")

     def test_input_mutation_aliases_other_input(self):
         def f(a, b):
@@ -712,10 +713,11 @@ def forward(self, primals_1):
     clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
     as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0)
     add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
-    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = None
-    as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 2); as_strided_scatter = None
-    add_1 = torch.ops.aten.add.Tensor(add, as_strided_4); as_strided_4 = None
-    return [add, add_1]""")
+    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
+    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
+    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 2); as_strided_scatter = None
+    add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_5 = None
+    return [as_strided_2, add_1]""")

     def test_input_mutation_aliases_other_input2(self):
         def f(a, b):
@@ -736,10 +738,11 @@ def forward(self, primals_1):
     clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
     as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0)
     add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
-    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = None
-    as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0); as_strided_scatter = None
-    add_1 = torch.ops.aten.add.Tensor(add, as_strided_4); as_strided_4 = None
-    return [add, add_1]""")
+    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
+    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
+    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0); as_strided_scatter = None
+    add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_5 = None
+    return [as_strided_2, add_1]""")

     def test_input_mutation_aliases_and_output_alias(self):
         def f(a, b):
@@ -758,9 +761,11 @@ def inp_callable():
         self.assertExpectedInline(fw_graph.code.strip(), """\
 def forward(self, primals_1):
     clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
-    as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0); clone = None
+    as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
     add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
-    return [add, 4, 1, 0]""")
+    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
+    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
+    return [as_strided_2, 4, 1, 0]""")

     def test_input_aliased_with_mutation_output_alias(self):
         def f(a, b, c):
@@ -783,10 +788,12 @@ def inp_callable():
         self.assertExpectedInline(fw_graph.code.strip(), """\
 def forward(self, primals_1, primals_2):
     clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
-    as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0); clone = None
+    as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
     mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
+    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
+    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
     add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
-    return [mul, add, 4, 1, 0]""")
+    return [as_strided_2, add, 4, 1, 0]""")

     def test_input_metadata_mutation_aliases(self):
         def f(a, b):
@@ -829,11 +836,12 @@ def forward(self, primals_1, primals_2):
     clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
     mul = torch.ops.aten.mul.Tensor(as_strided, 2); as_strided = None
-    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = None
-    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
-    add = torch.ops.aten.add.Tensor(as_strided_2, 1); as_strided_2 = None
+    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
+    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+    as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
+    add = torch.ops.aten.add.Tensor(as_strided_3, 1); as_strided_3 = None
     add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
-    return [mul, add, add_1]""")
+    return [as_strided_2, add, add_1]""")

     def test_input_mutation_aliases_bases_out_of_order(self):
         # This tests our calling convention: if b and d are aliased, then the outer calling convention
@@ -864,12 +872,13 @@ def forward(self, primals_1, primals_2, primals_3):
     clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
     add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
+    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
+    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
     add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
-    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = None
-    as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
-    t_1 = torch.ops.aten.t.default(as_strided_4); as_strided_4 = None
+    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
+    t_1 = torch.ops.aten.t.default(as_strided_5); as_strided_5 = None
     add_2 = torch.ops.aten.add.Tensor(add_1, t_1); add_1 = t_1 = None
-    return [add, add_2, 4, 1, 0, 4, 1, 0]""")
+    return [as_strided_2, add_2, 4, 1, 0, 4, 1, 0]""")

     # Mondo test that tests a combination of:
     # input is mutated, that aliases another input (so we make a synthetic base)
@@ -913,10 +922,11 @@ def forward(self, primals_1, primals_2):
     clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
     as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
    mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
-    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = None
-    as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
-    add = torch.ops.aten.add.Tensor(as_strided_4, mul); as_strided_4 = None
-    return [mul, add, 2, 2, 1, 2, 0, 2, 2, 2, 1, 0]""")
+    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
+    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
+    add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = None
+    return [as_strided_2, add, 2, 2, 1, 2, 0, 2, 2, 2, 1, 0]""")

     def test_no_grad_input_output(self):
         def f(a, b):
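For readers unfamiliar with these tests: the expected strings above are AOTAutograd forward graphs for functions that mutate (possibly aliased) inputs. A rough standalone harness for producing such a graph (an assumed setup, not the helpers the real tests use):

```python
import torch
from functorch.compile import aot_function, nop

def f(a):
    a.mul_(2)   # input mutation: functionalized into clone + mul (+ *_scatter for aliased inputs)
    return a * 3

def print_compiler(gm, example_inputs):
    print(gm.code)  # a forward graph analogous to the expected strings above
    return gm

compiled = aot_function(f, fw_compiler=print_compiler, bw_compiler=nop)
compiled(torch.ones(4, requires_grad=True))
```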
