Skip to content

Commit cc72d68

Browse files
reedwmmihaimaruseac
authored andcommitted
Fix segfault on OOM in Conv2D.
PiperOrigin-RevId: 404655317 Change-Id: I33588dbd3f5d0fef980e3c908bf5515a9ee09ce7
1 parent a150be0 commit cc72d68

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tensorflow/core/kernels/conv_ops.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,20 +181,29 @@ struct LaunchGrouped {
181181
auto on_shuffled = [&]() { shuffles_completed.DecrementCount(); };
182182

183183
// Shuffle input into temporary tensor.
184-
Tensor input_shuffled(input.dtype(), TensorShape(post_shuffle(input)));
184+
Tensor input_shuffled;
185+
OP_REQUIRES_OK(
186+
ctx, ctx->allocate_temp(input.dtype(), TensorShape(post_shuffle(input)),
187+
&input_shuffled));
185188
input_shuffled.tensor<T, 5>().device(device, on_shuffled) =
186189
input.shaped<T, 5>(pre_shuffle(input)).shuffle(shuffle);
187190

188191
// Shuffle filter into temporary tensor.
189-
Tensor filter_shuffled(filter.dtype(), TensorShape(post_shuffle(filter)));
192+
Tensor filter_shuffled;
193+
OP_REQUIRES_OK(ctx, ctx->allocate_temp(filter.dtype(),
194+
TensorShape(post_shuffle(filter)),
195+
&filter_shuffled));
190196
filter_shuffled.tensor<T, 5>().device(device, on_shuffled) =
191197
filter.shaped<T, 5>(pre_shuffle(filter)).shuffle(shuffle);
192198

193199
// Wait for the completion of input/filter shuffles.
194200
shuffles_completed.Wait();
195201

196202
// Write group convolution results into temporary output tensor.
197-
Tensor output_shuffled(output->dtype(), TensorShape(post_shuffle(*output)));
203+
Tensor output_shuffled;
204+
OP_REQUIRES_OK(ctx, ctx->allocate_temp(output->dtype(),
205+
TensorShape(post_shuffle(*output)),
206+
&output_shuffled));
198207

199208
for (int64 i = 0; i < num_groups; ++i) {
200209
// TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor

0 commit comments

Comments
 (0)