Skip to content

Commit f43a11c

Browse files
reedwmmihaimaruseac
authored andcommitted
Fix segfault on OOM in Conv2D.
PiperOrigin-RevId: 404655317 Change-Id: I33588dbd3f5d0fef980e3c908bf5515a9ee09ce7
1 parent 4413827 commit f43a11c

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tensorflow/core/kernels/conv_ops.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,20 +183,29 @@ struct LaunchGrouped {
183183
auto on_shuffled = [&]() { shuffles_completed.DecrementCount(); };
184184

185185
// Shuffle input into temporary tensor.
186-
Tensor input_shuffled(input.dtype(), TensorShape(post_shuffle(input)));
186+
Tensor input_shuffled;
187+
OP_REQUIRES_OK(
188+
ctx, ctx->allocate_temp(input.dtype(), TensorShape(post_shuffle(input)),
189+
&input_shuffled));
187190
input_shuffled.tensor<T, 5>().device(device, on_shuffled) =
188191
input.shaped<T, 5>(pre_shuffle(input)).shuffle(shuffle);
189192

190193
// Shuffle filter into temporary tensor.
191-
Tensor filter_shuffled(filter.dtype(), TensorShape(post_shuffle(filter)));
194+
Tensor filter_shuffled;
195+
OP_REQUIRES_OK(ctx, ctx->allocate_temp(filter.dtype(),
196+
TensorShape(post_shuffle(filter)),
197+
&filter_shuffled));
192198
filter_shuffled.tensor<T, 5>().device(device, on_shuffled) =
193199
filter.shaped<T, 5>(pre_shuffle(filter)).shuffle(shuffle);
194200

195201
// Wait for the completion of input/filter shuffles.
196202
shuffles_completed.Wait();
197203

198204
// Write group convolution results into temporary output tensor.
199-
Tensor output_shuffled(output->dtype(), TensorShape(post_shuffle(*output)));
205+
Tensor output_shuffled;
206+
OP_REQUIRES_OK(ctx, ctx->allocate_temp(output->dtype(),
207+
TensorShape(post_shuffle(*output)),
208+
&output_shuffled));
200209

201210
for (int64_t i = 0; i < num_groups; ++i) {
202211
// TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor

0 commit comments

Comments
 (0)