@@ -183,20 +183,29 @@ struct LaunchGrouped {
183183 auto on_shuffled = [&]() { shuffles_completed.DecrementCount (); };
184184
185185 // Shuffle input into temporary tensor.
186- Tensor input_shuffled (input.dtype (), TensorShape (post_shuffle (input)));
186+ Tensor input_shuffled;
187+ OP_REQUIRES_OK (
188+ ctx, ctx->allocate_temp (input.dtype (), TensorShape (post_shuffle (input)),
189+ &input_shuffled));
187190 input_shuffled.tensor <T, 5 >().device (device, on_shuffled) =
188191 input.shaped <T, 5 >(pre_shuffle (input)).shuffle (shuffle);
189192
190193 // Shuffle filter into temporary tensor.
191- Tensor filter_shuffled (filter.dtype (), TensorShape (post_shuffle (filter)));
194+ Tensor filter_shuffled;
195+ OP_REQUIRES_OK (ctx, ctx->allocate_temp (filter.dtype (),
196+ TensorShape (post_shuffle (filter)),
197+ &filter_shuffled));
192198 filter_shuffled.tensor <T, 5 >().device (device, on_shuffled) =
193199 filter.shaped <T, 5 >(pre_shuffle (filter)).shuffle (shuffle);
194200
195201 // Wait for the completion of input/filter shuffles.
196202 shuffles_completed.Wait ();
197203
198204 // Write group convolution results into temporary output tensor.
199- Tensor output_shuffled (output->dtype (), TensorShape (post_shuffle (*output)));
205+ Tensor output_shuffled;
206+ OP_REQUIRES_OK (ctx, ctx->allocate_temp (output->dtype (),
207+ TensorShape (post_shuffle (*output)),
208+ &output_shuffled));
200209
201210 for (int64_t i = 0 ; i < num_groups; ++i) {
202211 // TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor
0 commit comments