Commit 1169e10

feat: improve engine caching and fix bugs (#3932)
1 parent 2e1fba6 commit 1169e10

File tree

7 files changed: +305 -148 lines


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 0 additions & 2 deletions
@@ -1419,8 +1419,6 @@ def convert_exported_program_to_serialized_trt_engine(
     interpreter_result = interpret_module_to_result(
         gm,
         inputs=flattened_input_list,
-        arg_inputs=list(trt_arg_inputs),
-        kwarg_inputs=trt_kwarg_inputs,
         settings=settings,
         engine_cache=engine_cache,
     )
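For orientation, here is a minimal, hedged usage sketch of convert_exported_program_to_serialized_trt_engine, the function this hunk modifies; the `inputs=` keyword and the exact call shape are assumptions for illustration, not part of this commit.

import torch
from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine

model = torch.nn.Linear(4, 4).cuda().eval()
example_inputs = (torch.randn(2, 4).cuda(),)
exp_program = torch.export.export(model, example_inputs)

# Produce serialized TRT engine bytes from the exported program.
# The `inputs=` keyword name is an assumption for this sketch.
serialized_engine = convert_exported_program_to_serialized_trt_engine(
    exp_program, inputs=example_inputs
)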

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 6 additions & 5 deletions
@@ -53,7 +53,7 @@
 logger = logging.getLogger(__name__)


-@needs_refit
+@needs_refit  # type: ignore[misc]
 def construct_refit_mapping(
     module: torch.fx.GraphModule,
     inputs: Sequence[Input],

@@ -86,7 +86,7 @@ def construct_refit_mapping(
     return weight_refit_map


-@needs_refit
+@needs_refit  # type: ignore[misc]
 def construct_refit_mapping_from_weight_name_map(
     weight_name_map: dict[Any, Any],
     state_dict: dict[Any, Any],

@@ -131,7 +131,7 @@ def construct_refit_mapping_from_weight_name_map(
     return engine_weight_map


-@needs_refit
+@needs_refit  # type: ignore[misc]
 def _refit_single_trt_engine_with_gm(
     new_gm: torch.fx.GraphModule,
     old_engine: trt.ICudaEngine,

@@ -214,7 +214,7 @@ def _refit_single_trt_engine_with_gm(
     raise AssertionError("Refitting failed.")


-@needs_refit
+@needs_refit  # type: ignore[misc]
 def refit_module_weights(
     compiled_module: torch.fx.GraphModule | ExportedProgram,
     new_weight_module: ExportedProgram,

@@ -554,9 +554,10 @@ def refit_module_weights(
                 weight_name_map=None,
             )

-            # clear EXCLUDE_WEIGHTS flag
+            # clear EXCLUDE_WEIGHTS flag and set INCLUDE_REFIT flag to make the engine refittable
             serialization_config = engine.create_serialization_config()
             serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
+            serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
             serialized_engine = engine.serialize_with_config(serialization_config)

             if isinstance(compiled_submodule, PythonTorchTensorRTModule):
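Taken on its own, the serialization-flag pattern added in this hunk looks like the following sketch, assuming `engine` is an already-deserialized, refitted trt.ICudaEngine that was built weight-stripped:

import tensorrt as trt

# `engine` is assumed to be an already-deserialized, refitted, weight-stripped engine.
serialization_config = engine.create_serialization_config()
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)  # write weights back into the plan
serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)      # keep the plan refittable after serialization
serialized_engine = engine.serialize_with_config(serialization_config)  # bytes ready to cache or deploy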

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 0 additions & 1 deletion
@@ -200,7 +200,6 @@ def __setstate__(self, state: dict[str, Any]) -> None:
     "engine_capability",
     "hardware_compatible",
     "refit_identical_engine_weights",
-    "strip_engine_weights",  # TODO: @Evan to remove this after implementing caching weight-stripped engines as default?
     "immutable_weights",
     "enable_weight_streaming",
     "tiling_optimization_level",

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 4 additions & 2 deletions
@@ -158,9 +158,11 @@ def _pretraced_backend(
             "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
         )
         if settings.strip_engine_weights:
-            logger.error(
-                "strip_engine_weights arg is not supported for torch.compile()"
+            logger.warning(
+                "strip_engine_weights=True is not supported for torch.compile(). It will be set to False automatically."
             )
+            settings.strip_engine_weights = False
+
         trt_compiled = compile_module(
             gm,
             torchtrt_inputs,
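A hedged sketch of the user-facing path this hunk changes: passing strip_engine_weights=True through torch.compile options is expected to reach _pretraced_backend, log the warning, and continue with strip_engine_weights=False (the options= plumbing shown here is an assumption for illustration).

import torch
import torch_tensorrt  # registers the "torch_tensorrt" backend

model = torch.nn.Linear(8, 8).cuda().eval()
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"strip_engine_weights": True},  # now coerced to False with a warning instead of erroring
)
compiled(torch.randn(2, 8).cuda())  # triggers compilation on first call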

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 1 addition & 94 deletions
@@ -31,7 +31,7 @@
 from torch_tensorrt._utils import is_tensorrt_version_supported
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
-from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,

@@ -594,79 +594,6 @@ def _save_weight_mapping(self) -> None:
         gc.collect()
         torch.cuda.empty_cache()

-    @needs_refit  # type: ignore[misc]
-    def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
-        # query the cached TRT engine
-        cached_data = self.engine_cache.check(hash_val)  # type: ignore[union-attr]
-        if cached_data is not None:  # hit the cache
-            (
-                serialized_engine,
-                self._input_names,
-                self._output_names,
-                cached_engine_input_specs,
-                engine_compilation_settings,
-                self.weight_name_map,
-                self.ctx.requires_output_allocator,
-            ) = cached_data
-
-            setting_compatiblity, incompattible_settings = settings_are_compatible(
-                self.compilation_settings, engine_compilation_settings
-            )
-            assert (
-                setting_compatiblity
-            ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})"
-
-            for i, e in enumerate(
-                [
-                    Input.equivalent_spec(c, i)
-                    for c, i in zip(cached_engine_input_specs, self.input_specs)
-                ]
-            ):
-                assert (
-                    e
-                ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}"
-
-            _LOGGER.info(
-                "Found the cached engine that corresponds to this graph. It is directly loaded."
-            )
-
-            # refit the cached engine with the new graph module
-            if not self.compilation_settings.strip_engine_weights:
-                runtime = trt.Runtime(TRT_LOGGER)
-                engine = runtime.deserialize_cuda_engine(serialized_engine)
-
-                from torch_tensorrt.dynamo._refit import (
-                    _refit_single_trt_engine_with_gm,
-                )
-
-                _refit_single_trt_engine_with_gm(
-                    new_gm=self.module,
-                    old_engine=engine,
-                    input_list=self.input_specs,
-                    settings=self.compilation_settings,
-                    weight_name_map=self.weight_name_map,
-                )
-
-            # TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine
-            # # EXCLUDE_WEIGHTS flag must be cleared
-            # serialization_config = engine.create_serialization_config()
-            # serialization_config.clear_flag(
-            #     trt.SerializationFlag.EXCLUDE_WEIGHTS
-            # )
-            # serialized_engine = engine.serialize_with_config(
-            #     serialization_config
-            # )
-            # # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller
-
-            return TRTInterpreterResult(
-                engine,
-                self._input_names,
-                self._output_names,
-                self.weight_name_map,
-                self.ctx.requires_output_allocator,
-            )
-        return None
-
     def run(
         self,
         strict_type_constraints: bool = False,

@@ -682,26 +609,6 @@ def run(
         Return:
             TRTInterpreterResult
         """
-        # self.engine_cache could be None if:
-        # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
-        # 2) both cache_built_engines and reuse_cached_engines are False
-        if (
-            self.engine_cache is not None
-            and not self.compilation_settings.immutable_weights
-        ):
-            if (
-                self.compilation_settings.cache_built_engines
-                or self.compilation_settings.reuse_cached_engines
-            ):
-                hash_val = self.engine_cache.get_hash(
-                    self.module, self.input_specs, self.compilation_settings
-                )
-
-                if self.compilation_settings.reuse_cached_engines:
-                    interpreter_result = self._pull_cached_engine(hash_val)
-                    if interpreter_result is not None:  # hit the cache
-                        return interpreter_result  # type: ignore[no-any-return]
-
         self._construct_trt_network_def()
         _LOGGER.debug(
             f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
