@@ -211,8 +211,6 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
211211 auto requantize_multiplier = act_scale * weight_scale / output_scale;
212212 requantize_multiplier_tensor.fill_ (requantize_multiplier);
213213 c10::optional<at::Tensor> bias_multiplier_tensor;
214- c10::optional<at::Tensor> after_scales_bias;
215- c10::optional<at::Tensor> after_add;
216214 c10::optional<at::Tensor> broadcasted_bias;
217215 c10::optional<at::Tensor> after_relu;
218216 auto weight = orig_weight_.int_repr ();
@@ -229,11 +227,6 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
229227 bias_multiplier_tensor = at::empty (quantized_output.sizes (), at::device (at::kCUDA ).dtype (at::kFloat ), at::MemoryFormat::ChannelsLast);
230228 auto bias_multiplier = 1.0 / (act_scale * weight_scale);
231229 bias_multiplier_tensor.value ().fill_ (bias_multiplier);
232- after_scales_bias = at::empty (quantized_output.sizes (), at::device (at::kCUDA ).dtype (at::kFloat ), at::MemoryFormat::ChannelsLast);
233- after_add = at::empty (quantized_output.sizes (), at::device (at::kCUDA ).dtype (at::kFloat ), at::MemoryFormat::ChannelsLast);
234- }
235- if (kReluFused ) {
236- after_relu = at::empty (quantized_output.sizes (), at::device (at::kCUDA ).dtype (at::kFloat ), at::MemoryFormat::ChannelsLast);
237230 }
238231
239232 cudnnHandle_t handle = at::native::getCudnnHandle ();
@@ -271,15 +264,15 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
271264 uids = {'x', 'y', 'w', 's', 'r'};
272265 if (bias_.has_value ()) {
273266 data_ptrs.insert (data_ptrs.end (), {broadcasted_bias.value ().data_ptr (), bias_multiplier_tensor.value ().data_ptr (),
274- after_scales_bias.value().data_ptr(), after_add.value().data_ptr()});
267+ broadcasted_bias.value().data_ptr(), conv_output.data_ptr()});
275268 uids.insert(uids.end(), {'b', 'c', 'd', 'e'});
276269 if (kReluFused) {
277- data_ptrs.emplace_back(after_relu.value().data_ptr()),
270+ data_ptrs.emplace_back(conv_output.data_ptr()),
278271 uids.emplace_back('f');
279272 }
280273 } else {
281274 if (kReluFused ) {
282- data_ptrs.emplace_back(after_relu.value().data_ptr());
275+ data_ptrs.emplace_back(conv_output.data_ptr());
283276 uids.emplace_back('f');
284277 }
285278 }
@@ -315,28 +308,27 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
315308 // we can't directly assign bias_mult_op because operator= is deleted for cudnn_frontend::Operation;
316309 // alternatively, I think we can use std::unique_ptr and dynamically allocate these builder ops
317310 // but here, we chose to do it statically. c10::optional<T>::emplace() enables this approach
318- // TODO: can we assign the result back into bias and get rid of after_scales_bias? pending NVIDIA response
319311
320312 // bias_mult_op computes bias_fp32 / (act_scale * w_scale) or bias_fp32 * (1 / (act_scale * w_scale))
321313 // where bias_multiplier = (1 / (act_scale * w_scale))
322314 // output is a fp32 tensor
315+ // we use inplace operation here where the output is assigned to the input
323316 bias_mult_op.emplace (cudnn_frontend::OperationBuilder (CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
324317 .setxDesc(getTensorDescriptor(broadcasted_bias.value(), 'b', getAlignment(broadcasted_bias.value())))
325318 .setbDesc(getTensorDescriptor(bias_multiplier_tensor.value(), 'c', getAlignment(bias_multiplier_tensor.value())))
326- .setyDesc(getTensorDescriptor(after_scales_bias.value(), 'd', getAlignment(after_scales_bias.value())))
319+ .setyDesc(getTensorDescriptor(broadcasted_bias.value(), 'd', getAlignment(broadcasted_bias.value())))
327320 .setpwDesc (getPointWiseMulDescriptor (at::native::getCudnnDataType (bias_multiplier_tensor.value ())))
328321 .build ());
329322
330- // TODO: can we assign the result back into conv_output and get rid of after_add?
331-
332323 // computes (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)])
333- // where the 1st and 2nd summands is conv_output and after_scales_bias , resp.
324+ // where the 1st and 2nd summands is conv_output and broadcasted_bias , resp.
334325 // output is a fp32 tensor
326+ // we use inplace operation here where the output is assigned to the input
335327 sum_conv_bias_op.emplace (cudnn_frontend::OperationBuilder (CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
336328 .setxDesc (conv_op.getOutputTensor ())
337- .setbDesc(getTensorDescriptor(after_scales_bias.value(), 'd', getAlignment(after_scales_bias.value())))
338- .setyDesc(getTensorDescriptor(after_add.value(), 'e', getAlignment(after_add.value())))
339- .setpwDesc(getPointWiseAddDescriptor(at::native::getCudnnDataType(after_scales_bias.value())))
329+ .setbDesc(getTensorDescriptor(broadcasted_bias.value(), 'd', getAlignment(broadcasted_bias.value())))
330+ .setyDesc(getTensorDescriptor(conv_output, 'e', key.output_alignment))
331+ .setpwDesc(getPointWiseAddDescriptor(at::native::getCudnnDataType(broadcasted_bias.value())))
340332 .build ());
341333 }
342334
@@ -346,11 +338,11 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
346338 c10::optional<cudnn_frontend::Operation> relu_op;
347339 std::shared_ptr<cudnn_frontend::OpaqueBackendPointer> tensor2requant_ptr = bias_.has_value () ? sum_conv_bias_op.value ().getOutputTensor () : conv_op.getOutputTensor ();
348340 if (kReluFused ) {
349- // TODO: can we assign the result back into conv_output and get rid of after_relu?
341+ // we use inplace operation here where the output is assigned to the input
350342 relu_op.emplace (cudnn_frontend::OperationBuilder (CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
351343 .setxDesc (tensor2requant_ptr)
352- .setyDesc(getTensorDescriptor(after_relu.value(), 'f', getAlignment(after_relu.value())))
353- .setpwDesc(getPointWiseReluDescriptor(at::native::getCudnnDataType(after_relu.value())))
344+ .setyDesc(getTensorDescriptor(conv_output, 'f', key.output_alignment))
345+ .setpwDesc(getPointWiseReluDescriptor(at::native::getCudnnDataType(conv_output)))
354346 .build ());
355347 }
356348
0 commit comments