44from typing import Any , Callable , Dict , List , NamedTuple , Optional , Sequence , Set
55
66import numpy as np
7+ import tensorrt as trt
78import torch
89import torch .fx
910from torch .fx .node import _get_qualified_name
2526from torch_tensorrt .fx .observer import Observer
2627from torch_tensorrt .logging import TRT_LOGGER
2728
28- import tensorrt as trt
2929from packaging import version
3030
3131_LOGGER : logging .Logger = logging .getLogger (__name__ )
@@ -316,8 +316,10 @@ def run(
316316 )
317317 timing_cache = self ._create_timing_cache (builder_config , existing_cache )
318318
319- engine = self .builder .build_serialized_network (self .ctx .net , builder_config )
320- assert engine
319+ serialized_engine = self .builder .build_serialized_network (
320+ self .ctx .net , builder_config
321+ )
322+ assert serialized_engine
321323
322324 serialized_cache = (
323325 bytearray (timing_cache .serialize ())
@@ -327,10 +329,10 @@ def run(
327329 _LOGGER .info (
328330 f"Build TRT engine elapsed time: { datetime .now () - build_engine_start_time } "
329331 )
330- _LOGGER .info (f"TRT Engine uses: { engine .nbytes } bytes of Memory" )
332+ _LOGGER .info (f"TRT Engine uses: { serialized_engine .nbytes } bytes of Memory" )
331333
332334 return TRTInterpreterResult (
333- engine , self ._input_names , self ._output_names , serialized_cache
335+ serialized_engine , self ._input_names , self ._output_names , serialized_cache
334336 )
335337
336338 def run_node (self , n : torch .fx .Node ) -> torch .fx .Node :
0 commit comments