@@ -181,20 +181,29 @@ struct LaunchGrouped {
181181 auto on_shuffled = [&]() { shuffles_completed.DecrementCount (); };
182182
183183 // Shuffle input into temporary tensor.
184- Tensor input_shuffled (input.dtype (), TensorShape (post_shuffle (input)));
184+ Tensor input_shuffled;
185+ OP_REQUIRES_OK (
186+ ctx, ctx->allocate_temp (input.dtype (), TensorShape (post_shuffle (input)),
187+ &input_shuffled));
185188 input_shuffled.tensor <T, 5 >().device (device, on_shuffled) =
186189 input.shaped <T, 5 >(pre_shuffle (input)).shuffle (shuffle);
187190
188191 // Shuffle filter into temporary tensor.
189- Tensor filter_shuffled (filter.dtype (), TensorShape (post_shuffle (filter)));
192+ Tensor filter_shuffled;
193+ OP_REQUIRES_OK (ctx, ctx->allocate_temp (filter.dtype (),
194+ TensorShape (post_shuffle (filter)),
195+ &filter_shuffled));
190196 filter_shuffled.tensor <T, 5 >().device (device, on_shuffled) =
191197 filter.shaped <T, 5 >(pre_shuffle (filter)).shuffle (shuffle);
192198
193199 // Wait for the completion of input/filter shuffles.
194200 shuffles_completed.Wait ();
195201
196202 // Write group convolution results into temporary output tensor.
197- Tensor output_shuffled (output->dtype (), TensorShape (post_shuffle (*output)));
203+ Tensor output_shuffled;
204+ OP_REQUIRES_OK (ctx, ctx->allocate_temp (output->dtype (),
205+ TensorShape (post_shuffle (*output)),
206+ &output_shuffled));
198207
199208 for (int64 i = 0 ; i < num_groups; ++i) {
200209 // TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor
0 commit comments