1+ import collections
12import functools
23import logging
34import re
@@ -50,20 +51,28 @@ def get_node_io(
5051) -> str :
5152 """Gets a string representing the node inputs and outputs including tensor shapes and dtypes"""
5253
53- def format_tensor_metadata (
54- metadata : Union [TensorMetadata , Sequence [TensorMetadata ]]
55- ) -> str :
54+ def format_tensor_metadata (metadata : Union [Any , Sequence [Any ]]) -> str :
5655 """Formats the metadata for a single node"""
5756 # If the provided data is a simple TensorMetadata object, parse it
58- if isinstance (metadata , TensorMetadata ):
59- return f"{ tuple (metadata .shape )} @{ metadata .dtype } "
57+ if isinstance (metadata , TensorMetadata ) or issubclass (
58+ type (metadata ), torch .Tensor
59+ ):
60+ return f"{ tuple (metadata .shape )} @{ metadata .dtype } " # type: ignore
61+ # If the provided data is a scalar, return it as is
62+ elif isinstance (metadata , (int , float , bool )):
63+ return f"{ metadata } @Python-{ type (metadata )} "
6064 # If the provided data is a sequence, recursively parse it
61- else :
65+ elif isinstance ( metadata , collections . abc . Sequence ) :
6266 formatted_str = "("
6367 for meta in metadata :
6468 formatted_str += format_tensor_metadata (meta ) + ", "
6569
6670 return formatted_str [:- 2 ] + ")"
71+ else :
72+ _LOGGER .warning (
73+ f"Detected unparseable type in node formatting: { type (metadata )} "
74+ )
75+ return ""
6776
6877 # Format input tensors
6978 metadata_string = "Inputs: ("
@@ -74,8 +83,10 @@ def format_tensor_metadata(
7483 if arg .op == "get_attr" :
7584 shape , dtype = constant_mapping [str (arg )]
7685 arg_repr = f"{ shape } @{ dtype } "
77- elif arg .meta .get ("tensor_meta" , False ) :
86+ elif arg .meta .get ("tensor_meta" ) is not None :
7887 arg_repr = format_tensor_metadata (arg .meta ["tensor_meta" ])
88+ elif arg .meta .get ("val" ) is not None :
89+ arg_repr = format_tensor_metadata (arg .meta ["val" ])
7990 else :
8091 arg_repr = ""
8192
@@ -92,8 +103,10 @@ def format_tensor_metadata(
92103 if node .op == "get_attr" :
93104 shape , dtype = constant_mapping [str (node )]
94105 node_repr = f"{ shape } @{ dtype } "
95- elif node .meta .get ("tensor_meta" , False ) :
106+ elif node .meta .get ("tensor_meta" ) is None :
96107 node_repr = format_tensor_metadata (node .meta ["tensor_meta" ])
108+ elif node .meta .get ("val" ) is None :
109+ node_repr = format_tensor_metadata (node .meta ["val" ])
97110 else :
98111 node_repr = ""
99112 metadata_string += f"{ node } : { node_repr } , "
0 commit comments