Commit 93d16a6

fix: Add I/O values for nodes
1 parent 0b84610 commit 93d16a6

5 files changed: +96 -7 lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 0 additions & 6 deletions
@@ -294,7 +294,6 @@ def compile_module(
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False

-    logger.info("Beginning TensorRT operator Partitioning Phase")
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
@@ -330,11 +329,6 @@ def compile_module(
     if not settings.use_fast_partitioner:
         dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))

-    logger.info(
-        "Successfully completed graph partitioning phase. "
-        "Beginning the conversion phase."
-    )
-
     # Store TRT replicas of Torch subgraphs
     trt_modules = {}
     # Iterate over all components that can be accelerated

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 12 additions & 1 deletion
@@ -1,7 +1,7 @@
 import logging
 import warnings
 from datetime import datetime
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple

 import numpy as np
 import tensorrt as trt
@@ -18,6 +18,7 @@
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import CallingConvention
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_node_io,
     get_node_name,
     get_trt_tensor,
 )
@@ -100,6 +101,9 @@ def __init__(
         # Data types for TRT Module output Tensors
         self.output_dtypes = output_dtypes

+        # Mapping of constants to shapes and dtypes
+        self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {}
+
     def validate_conversion(self) -> Set[str]:
         missing_converters: Set[str] = set()

@@ -284,6 +288,13 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
         )
         trt_node: torch.fx.Node = super().run_node(n)

+        if n.op == "get_attr":
+            self.const_mapping[str(n)] = (tuple(trt_node.shape), str(trt_node.dtype))
+
+        _LOGGER.debug(
+            f"Ran node {self._cur_node_name} with properties: {get_node_io(n, self.const_mapping)}"
+        )
+
         # remove "_itensor_to_tensor_meta"
         kwargs = dict(n.kwargs)
         del kwargs["_itensor_to_tensor_meta"]
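Illustration (not part of the commit): run_node records each get_attr constant as a (shape, dtype) pair keyed by the node's string name, and get_node_io later consults this mapping for constant arguments. A minimal sketch of the recorded format, using a hypothetical frozen weight named "_param_constant0" (the name is illustrative only):

import torch

# A hypothetical constant tensor, recorded the same way run_node stores it
weight = torch.rand(16, 3, 3, 3)
const_mapping = {"_param_constant0": (tuple(weight.shape), str(weight.dtype))}
print(const_mapping)
# {'_param_constant0': ((16, 3, 3, 3), 'torch.float32')}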

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 58 additions & 0 deletions
@@ -8,6 +8,7 @@
 import torch
 from torch import SymBool, SymFloat, SymInt
 from torch.fx.node import Argument, Target
+from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -47,6 +48,63 @@ def get_node_name(node: torch.fx.Node) -> str:
     return node_name


+def get_node_io(
+    node: torch.fx.Node, constant_mapping: Dict[str, Tuple[Sequence[int], str]]
+) -> str:
+    """Gets a string representing the node inputs and outputs including tensor shapes and dtypes"""
+
+    def format_tensor_metadata(
+        metadata: Union[TensorMetadata, Sequence[TensorMetadata]]
+    ) -> str:
+        """Formats the metadata for a single node"""
+        # If the provided data is a simple TensorMetadata object, parse it
+        if isinstance(metadata, TensorMetadata):
+            return f"{tuple(metadata.shape)}@{metadata.dtype}"
+        # If the provided data is a sequence, recursively parse it
+        else:
+            formatted_str = "("
+            for meta in metadata:
+                formatted_str += format_tensor_metadata(meta) + ", "
+
+            return formatted_str[:-2] + ")"
+
+    # Format input tensors
+    metadata_string = "Inputs: ("
+
+    # For each input argument, format it accordingly
+    for arg in node.args:
+        if isinstance(arg, torch.fx.Node):
+            if arg.op == "get_attr":
+                shape, dtype = constant_mapping[str(arg)]
+                arg_repr = f"{shape}@{dtype}"
+            elif arg.meta.get("tensor_meta", False):
+                arg_repr = format_tensor_metadata(arg.meta["tensor_meta"])
+            else:
+                arg_repr = ""
+
+            metadata_string += f"{arg}: {arg_repr}, "
+        else:
+            metadata_string += f"{arg}, "
+
+    metadata_string = (
+        metadata_string[:-2] if metadata_string[-1] != "(" else metadata_string
+    ) + ")"
+
+    # Format output tensors and arguments
+    metadata_string += " | Outputs: ("
+    if node.op == "get_attr":
+        shape, dtype = constant_mapping[str(node)]
+        node_repr = f"{shape}@{dtype}"
+    elif node.meta.get("tensor_meta", False):
+        node_repr = format_tensor_metadata(node.meta["tensor_meta"])
+    else:
+        node_repr = ""
+    metadata_string += f"{node}: {node_repr}, "
+    metadata_string = metadata_string[:-2] + ")"
+
+    return metadata_string
+
+
 def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
     """Detects whether a call_function node is the only operator on a placeholder"""
     # Returns true if the node operates on a placeholder and is a direct output
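Usage sketch (not part of the commit): get_node_io can be exercised on any FX graph whose nodes carry meta["tensor_meta"], e.g. after running ShapeProp on a traced module. The toy module below and the expected output line are illustrative assumptions, not taken from this diff:

import torch
from torch.fx.passes.shape_prop import ShapeProp
from torch_tensorrt.dynamo.conversion.converter_utils import get_node_io


class Add(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


gm = torch.fx.symbolic_trace(Add())
# Populate node.meta["tensor_meta"] so get_node_io has shapes/dtypes to report
ShapeProp(gm).propagate(torch.rand(2, 3), torch.rand(2, 3))

for node in gm.graph.nodes:
    if node.op == "call_function":
        # No get_attr constants in this toy graph, so pass an empty mapping
        print(get_node_io(node, {}))
# Roughly: Inputs: (x: (2, 3)@torch.float32, y: (2, 3)@torch.float32) | Outputs: (add: (2, 3)@torch.float32)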

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
 from .lower_efficient_attention import lower_efficient_attention
 from .lower_linear import lower_linear
 from .pass_manager import DynamoPassManager
+from .propagate_shapes import propagate_shapes
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
@@ -23,6 +24,7 @@
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
         view_to_reshape,
+        propagate_shapes,
     ]
 )

py/torch_tensorrt/dynamo/lowering/passes/propagate_shapes.py (new file)

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+import logging
+from typing import Sequence
+
+import torch
+from torch.fx.passes.shape_prop import ShapeProp
+
+logger = logging.getLogger(__name__)
+
+
+def propagate_shapes(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Attempts to propagate shapes through the graph"""
+
+    # Propagate shapes through the graph
+    try:
+        ShapeProp(gm).propagate(*sample_inputs)
+    except (RuntimeError, AssertionError):
+        logger.warning(
+            "Shape Propagation Failed on Graph, skipping propagate_shapes lowering pass",
+            exc_info=True,
+        )
+
+    return gm
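The pass can also be run standalone. A minimal sketch, assuming the module path torch_tensorrt.dynamo.lowering.passes.propagate_shapes implied by the import added in _aten_lowering_pass.py; the toy module and inputs are illustrative:

import torch
from torch_tensorrt.dynamo.lowering.passes.propagate_shapes import propagate_shapes


class MatMul(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x @ y


gm = torch.fx.symbolic_trace(MatMul())

# On success, nodes now carry meta["tensor_meta"] with shape and dtype;
# on failure the pass logs a warning and returns the graph unchanged.
gm = propagate_shapes(gm, [torch.rand(4, 8), torch.rand(8, 16)])

for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta")
    print(node.name, tuple(meta.shape) if meta is not None else None)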
