@@ -88,28 +88,23 @@ def _cast_to_type(g, input, to_type):
8888 return getattr (sym_opset9 , '_cast_{}' .format (to_type ))(g , input , False )
8989
9090
91- @wrap_logical_op_with_cast_to ('Byte' )
92- def gt (g , input , other ):
93- return gt_impl (g , input , other )
94-
95-
96- def gt_impl (g , input , other ):
91+ def _comparison_operator (g , input , other , op_name ):
9792 other = sym_help ._maybe_get_scalar (other )
9893 other = sym_help ._if_scalar_type_as (g , other , input )
9994 _ , input , other = _try_cast_integer_to_float (g , input , other )
100- return g .op ("Greater" , input , other )
95+ return g .op (op_name , input , other )
10196
10297
98+ # NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
99+ # integer input type not supported in opset8. Cast to float if possible.
103100@wrap_logical_op_with_cast_to ('Byte' )
104- def lt (g , input , other ):
105- return lt_impl (g , input , other )
101+ def gt (g , input , other ):
102+ return _comparison_operator (g , input , other , "Greater" )
106103
107104
108- def lt_impl (g , input , other ):
109- other = sym_help ._maybe_get_scalar (other )
110- other = sym_help ._if_scalar_type_as (g , other , input )
111- _ , input , other = _try_cast_integer_to_float (g , input , other )
112- return g .op ("Less" , input , other )
105+ @wrap_logical_op_with_cast_to ('Byte' )
106+ def lt (g , input , other ):
107+ return _comparison_operator (g , input , other , "Less" )
113108
114109
115110def bmm (g , self , other ):
@@ -121,11 +116,7 @@ def bmm(g, self, other):
121116
122117
123118def matmul (g , self , other ):
124- if _try_get_scalar_type (self ):
125- old_type , self , other = _try_cast_integer_to_float (g , self , other )
126- return _cast_to_type (g , g .op ("MatMul" , self , other ), old_type )
127- else :
128- return g .op ("MatMul" , self , other )
119+ return bmm (g , self , other )
129120
130121
131122def prelu (g , self , weight ):
0 commit comments