Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
# If specified, try using the fast partitioner and fall back to the global one on failure
if settings.use_fast_partitioner:
try:
logger.info("Partitioning the graph via the fast partitioner")
partitioned_module, supported_ops = partitioning.fast_partition(
gm,
verbose=settings.debug,
Expand All @@ -322,14 +323,15 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
logger.error(
"Partitioning failed on the subgraph with fast partition. See trace above. "
+ "Retrying with global partition.",
"Retrying with global partition.",
exc_info=True,
)

fast_partitioner_failed = True
settings.use_fast_partitioner = False

if not settings.use_fast_partitioner:
logger.info("Partitioning the graph via the global partitioner")
partitioned_module, supported_ops = partitioning.global_partition(
gm,
verbose=settings.debug,
Expand Down Expand Up @@ -367,14 +369,15 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
# Get the submodule inputs for min, opt, max shapes of the graph inputs
submodule_inputs = partitioning.construct_submodule_inputs(submodule)

assert submodule_inputs is not None

logger.debug(
"Submodule name: %s\n Input shapes: %s\n %s",
"Converting submodule: %s\n Input shapes: %s\n %s",
str(name),
[input.shape for input in submodule_inputs],
str(submodule.graph),
)

assert submodule_inputs is not None
# Handle long/double inputs if requested by the user
if settings.truncate_double:
submodule_inputs = repair_double_inputs(
Expand Down
17 changes: 16 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import warnings
from datetime import datetime
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple

import numpy as np
import tensorrt as trt
Expand All @@ -21,6 +21,7 @@
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import CallingConvention
from torch_tensorrt.dynamo.conversion.converter_utils import (
get_node_io,
get_node_name,
get_trt_tensor,
)
Expand Down Expand Up @@ -106,6 +107,9 @@ def __init__(
[dtype._from(o) for o in output_dtypes] if output_dtypes else None
)

# Mapping of constants to shapes and dtypes
self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {}

def validate_conversion(self) -> Set[str]:
missing_converters: Set[str] = set()

Expand Down Expand Up @@ -361,8 +365,19 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
n.kwargs = kwargs

# run the node
_LOGGER.debug(
f"Running node {self._cur_node_name}, a {self._cur_node.op} node "
f"with target {self._cur_node.target} in the TensorRT Interpreter"
)
trt_node: torch.fx.Node = super().run_node(n)

if n.op == "get_attr":
self.const_mapping[str(n)] = (tuple(trt_node.shape), str(trt_node.dtype))

_LOGGER.debug(
f"Ran node {self._cur_node_name} with properties: {get_node_io(n, self.const_mapping)}"
)

# remove "_itensor_to_tensor_meta"
kwargs = dict(n.kwargs)
del kwargs["_itensor_to_tensor_meta"]
Expand Down
71 changes: 71 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import functools
import logging
import re
Expand All @@ -8,6 +9,7 @@
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Argument, Target
from torch.fx.passes.shape_prop import TensorMetadata
from torch_tensorrt import _enums
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
Expand Down Expand Up @@ -44,6 +46,75 @@ def get_node_name(node: torch.fx.Node) -> str:
return node_name


def get_node_io(
    node: torch.fx.Node, constant_mapping: Dict[str, Tuple[Sequence[int], str]]
) -> str:
    """Gets a string representing the node inputs and outputs including tensor shapes and dtypes

    Args:
        node: The FX node to describe.
        constant_mapping: Maps the string form of each ``get_attr`` node to a
            ``(shape, dtype)`` pair. It is indexed directly, so it must hold an
            entry for ``node`` (if it is a ``get_attr`` node) and for every
            ``get_attr`` node appearing in ``node.args``.

    Returns:
        A string of the form ``"Inputs: (...) | Outputs: (...)"``.

    Raises:
        KeyError: If a ``get_attr`` node is missing from ``constant_mapping``.
    """

    def format_tensor_metadata(metadata: Union[Any, Sequence[Any]]) -> str:
        """Formats the metadata for a single node"""
        # If the provided data is a TensorMetadata object or a tensor, parse it.
        # This check must precede the Sequence check: TensorMetadata is a tuple.
        if isinstance(metadata, (TensorMetadata, torch.Tensor)):
            return f"{tuple(metadata.shape)}@{metadata.dtype}"  # type: ignore
        # If the provided data is a scalar, return it as is. Strings are treated
        # as scalars too: a str is a Sequence of str, so recursing into it would
        # never terminate.
        elif isinstance(metadata, (int, float, bool, str)):
            return f"{metadata}@Python-{type(metadata)}"
        # If the provided data is a sequence, recursively parse it. Joining
        # (rather than trimming a trailing ", ") keeps empty sequences as "()".
        elif isinstance(metadata, collections.abc.Sequence):
            return (
                "(" + ", ".join(format_tensor_metadata(meta) for meta in metadata) + ")"
            )
        else:
            _LOGGER.warning(
                f"Detected unparseable type in node formatting: {type(metadata)}"
            )
            return ""

    def describe_node(fx_node: torch.fx.Node) -> str:
        """Formats shape/dtype information for one FX node, or "" if none is recorded"""
        # Constants carry no FX metadata here; their properties come from the mapping
        if fx_node.op == "get_attr":
            shape, dtype = constant_mapping[str(fx_node)]
            return f"{shape}@{dtype}"
        # Prefer "tensor_meta" and fall back to "val"
        for meta_key in ("tensor_meta", "val"):
            if fx_node.meta.get(meta_key) is not None:
                return format_tensor_metadata(fx_node.meta[meta_key])
        return ""

    # Format input arguments: nodes get "name: properties", anything else is printed as is
    formatted_inputs = ", ".join(
        f"{arg}: {describe_node(arg)}" if isinstance(arg, torch.fx.Node) else f"{arg}"
        for arg in node.args
    )

    # Format the output of the node itself
    formatted_outputs = f"{node}: {describe_node(node)}"

    return f"Inputs: ({formatted_inputs}) | Outputs: ({formatted_outputs})"


def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
"""Detects whether a call_function node is the only operator on a placeholder"""
# Returns true if the node operates on a placeholder and is a direct output
Expand Down
7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/truncate_double.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from typing import Optional, Sequence, Set

import torch
Expand All @@ -8,6 +9,8 @@
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.utils import get_torch_inputs

logger = logging.getLogger(__name__)


def _extract_downstream_get_nodes(
module_node: torch.fx.Node, output_indices: Set[int]
Expand Down Expand Up @@ -62,6 +65,10 @@ def _repair_64bit_input(
torch.float64,
), f"dtype argument must be torch.float64, got {dtype}"

logger.info(
f"Downcasting a 64-bit input at position {position} of submodule {submodule_name}"
)

# Determine target data type in 32 and 64 bit forms
dtype_64bit = dtype
dtype_32bit = torch.float32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def replace_max_pool_with_indices(
args=node.args,
kwargs=node.kwargs,
)
maxpool_fused.meta = node.meta

logger.debug(
f"Replacing all uses of nodes {node}, {getitem_node} with fused maxpool node {maxpool_fused} "
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def parse_complex_tensor_structs(

else:
raise ValueError(
f"Invalid input type {type(inputs)} encountered in parse_complex_tensor_structs parsing. "
f"Invalid input type {type(inputs)} encountered during Dynamo input parsing. "
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
)

Expand Down