@@ -210,10 +210,7 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
210210 auto requantize_multiplier = act_scale * weight_scale / output_scale;
211211 requantize_multiplier_tensor.fill_ (requantize_multiplier);
212212 c10::optional<at::Tensor> bias_multiplier_tensor;
213- c10::optional<at::Tensor> after_scales_bias;
214- c10::optional<at::Tensor> after_add;
215213 c10::optional<at::Tensor> broadcasted_bias;
216- c10::optional<at::Tensor> after_relu;
217214 if (bias.has_value ()) {
218215 // the input bias is a 1-D tensor whose size is the same as the size of the second dimension of quantized_output.
219216 // we need to add trailing dimensions in order to properly broadcast bias, otherwise broadcast_to will fail.
@@ -227,11 +224,6 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
227224 bias_multiplier_tensor = at::empty (quantized_output.sizes (), at::device (at::kCUDA ).dtype (at::kFloat ), at::MemoryFormat::ChannelsLast);
228225 auto bias_multiplier = 1.0 / (act_scale * weight_scale);
229226 bias_multiplier_tensor.value ().fill_ (bias_multiplier);
230- after_scales_bias = at::empty (quantized_output.sizes (), at::device (at::kCUDA ).dtype (at::kFloat ), at::MemoryFormat::ChannelsLast);
231- after_add = at::empty (quantized_output.sizes (), at::device (at::kCUDA ).dtype (at::kFloat ), at::MemoryFormat::ChannelsLast);
232- }
233- if (kReluFused ) {
234- after_relu = at::empty (quantized_output.sizes (), at::device (at::kCUDA ).dtype (at::kFloat ), at::MemoryFormat::ChannelsLast);
235227 }
236228
237229 cudnnHandle_t handle = at::native::getCudnnHandle ();
@@ -269,15 +261,15 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
269261 uids = {' x' , ' y' , ' w' , ' s' , ' r' };
270262 if (bias.has_value ()) {
271263 data_ptrs.insert (data_ptrs.end (), {broadcasted_bias.value ().data_ptr (), bias_multiplier_tensor.value ().data_ptr (),
272- after_scales_bias .value ().data_ptr (), after_add. value () .data_ptr ()});
264+ broadcasted_bias .value ().data_ptr (), conv_output .data_ptr ()});
273265 uids.insert (uids.end (), {' b' , ' c' , ' d' , ' e' });
274266 if (kReluFused ) {
275- data_ptrs.emplace_back (after_relu. value () .data_ptr ()),
267+ data_ptrs.emplace_back (conv_output .data_ptr ()),
276268 uids.emplace_back (' f' );
277269 }
278270 } else {
279271 if (kReluFused ) {
280- data_ptrs.emplace_back (after_relu. value () .data_ptr ());
272+ data_ptrs.emplace_back (conv_output .data_ptr ());
281273 uids.emplace_back (' f' );
282274 }
283275 }
@@ -313,28 +305,27 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
313305 // we can't directly assign bias_mult_op because operator= is deleted for cudnn_frontend::Operation;
314306 // alternatively, I think we can use std::unique_ptr and dynamically allocate these builder ops
315307 // but here, we chose to do it statically. c10::optional<T>::emplace() enables this approach
316- // TODO: can we assign the result back into bias and get rid of after_scales_bias? pending NVIDIA response
317308
318309 // bias_mult_op computes bias_fp32 / (act_scale * w_scale) or bias_fp32 * (1 / (act_scale * w_scale))
319310 // where bias_multiplier = (1 / (act_scale * w_scale))
320311 // output is a fp32 tensor
312+ // we use an in-place operation here, where the output is assigned to the input
321313 bias_mult_op.emplace (cudnn_frontend::OperationBuilder (CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
322314 .setxDesc (getTensorDescriptor (broadcasted_bias.value (), ' b' , getAlignment (broadcasted_bias.value ())))
323315 .setbDesc (getTensorDescriptor (bias_multiplier_tensor.value (), ' c' , getAlignment (bias_multiplier_tensor.value ())))
324- .setyDesc (getTensorDescriptor (after_scales_bias .value (), ' d' , getAlignment (after_scales_bias .value ())))
316+ .setyDesc (getTensorDescriptor (broadcasted_bias .value (), ' d' , getAlignment (broadcasted_bias .value ())))
325317 .setpwDesc (getPointWiseMulDescriptor (at::native::getCudnnDataType (bias_multiplier_tensor.value ())))
326318 .build ());
327319
328- // TODO: can we assign the result back into conv_output and get rid of after_add?
329-
330320 // computes (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)])
331- // where the 1st and 2nd summands is conv_output and after_scales_bias , resp.
321 // where the 1st and 2nd summands are conv_output and broadcasted_bias, resp.
332322 // output is a fp32 tensor
323+ // we use an in-place operation here, where the output is assigned to the input
333324 sum_conv_bias_op.emplace (cudnn_frontend::OperationBuilder (CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
334325 .setxDesc (conv_op.getOutputTensor ())
335- .setbDesc (getTensorDescriptor (after_scales_bias .value (), ' d' , getAlignment (after_scales_bias .value ())))
336- .setyDesc (getTensorDescriptor (after_add. value () , ' e' , getAlignment (after_add. value ()) ))
337- .setpwDesc (getPointWiseAddDescriptor (at::native::getCudnnDataType (after_scales_bias .value ())))
326+ .setbDesc (getTensorDescriptor (broadcasted_bias .value (), ' d' , getAlignment (broadcasted_bias .value ())))
327+ .setyDesc (getTensorDescriptor (conv_output , ' e' , key. output_alignment ))
328+ .setpwDesc (getPointWiseAddDescriptor (at::native::getCudnnDataType (broadcasted_bias .value ())))
338329 .build ());
339330 }
340331
@@ -344,11 +335,11 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
344335 c10::optional<cudnn_frontend::Operation> relu_op;
345336 std::shared_ptr<cudnn_frontend::OpaqueBackendPointer> tensor2requant_ptr = bias.has_value () ? sum_conv_bias_op.value ().getOutputTensor () : conv_op.getOutputTensor ();
346337 if (kReluFused ) {
347- // TODO: can we assign the result back into conv_output and get rid of after_relu?
338 // we use an in-place operation here, where the output is assigned to the input
348339 relu_op.emplace (cudnn_frontend::OperationBuilder (CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
349340 .setxDesc (tensor2requant_ptr)
350- .setyDesc (getTensorDescriptor (after_relu. value () , ' f' , getAlignment (after_relu. value ()) ))
351- .setpwDesc (getPointWiseReluDescriptor (at::native::getCudnnDataType (after_relu. value () )))
341+ .setyDesc (getTensorDescriptor (conv_output , ' f' , key. output_alignment ))
342+ .setpwDesc (getPointWiseReluDescriptor (at::native::getCudnnDataType (conv_output )))
352343 .build ());
353344 }
354345
0 commit comments