@@ -681,10 +681,22 @@ def upsample_bilinear2d(g, input, output_size, align_corners):
681681 mode_s = "linear" )
682682
683683
684- def wrap_logical_op_with_cast_to_uint8 (func ):
685- def wrap_with_cast (g , input , other ):
686- return g .op ("Cast" , func (g , input , other ), to_i = sym_help .cast_pytorch_to_onnx ['Byte' ])
687- return wrap_with_cast
684+ def wrap_logical_op_with_cast_to (to_type ):
685+ def decorator (fn ):
686+ def wrap_with_cast (g , input , other ):
687+ return g .op ("Cast" , fn (g , input , other ), to_i = sym_help .cast_pytorch_to_onnx [to_type ])
688+ return wrap_with_cast
689+ return decorator
690+
691+
692+ def wrap_logical_op_with_cast_to_and_from (to_type ):
693+ def decorator (fn ):
694+ def wrap_with_cast (g , input , other ):
695+ to_cast_func = globals ()['_cast_{}' .format (to_type )]
696+ from_cast_func = wrap_logical_op_with_cast_to (input .type ().scalarType ())(fn )
697+ return from_cast_func (g , to_cast_func (g , input , False ), to_cast_func (g , other , False ))
698+ return wrap_with_cast
699+ return decorator
688700
689701
690702def wrap_logical_op_with_negation (func ):
@@ -693,18 +705,18 @@ def wrap_with_not(g, input, other):
693705 return wrap_with_not
694706
695707
696- @wrap_logical_op_with_cast_to_uint8
708+ @wrap_logical_op_with_cast_to ( 'Byte' )
697709def eq (g , self , other ):
698710 return g .op ("Equal" , self , other )
699711
700712
701- @wrap_logical_op_with_cast_to_uint8
713+ @wrap_logical_op_with_cast_to ( 'Byte' )
702714@wrap_logical_op_with_negation
703715def ne (g , self , other ):
704716 return g .op ("Equal" , self , other )
705717
706718
707- @wrap_logical_op_with_cast_to_uint8
719+ @wrap_logical_op_with_cast_to ( 'Byte' )
708720def gt (g , input , other ):
709721 return gt_impl (g , input , other )
710722
@@ -714,7 +726,7 @@ def gt_impl(g, input, other):
714726 return g .op ("Greater" , input , sym_help ._if_scalar_type_as (g , other , input ))
715727
716728
717- @wrap_logical_op_with_cast_to_uint8
729+ @wrap_logical_op_with_cast_to ( 'Byte' )
718730def lt (g , input , other ):
719731 return lt_impl (g , input , other )
720732
@@ -724,20 +736,30 @@ def lt_impl(g, input, other):
724736 return g .op ("Less" , input , sym_help ._if_scalar_type_as (g , other , input ))
725737
726738
727- @wrap_logical_op_with_cast_to_uint8
739+ @wrap_logical_op_with_cast_to ( 'Byte' )
728740@wrap_logical_op_with_negation
729741def ge (g , input , other ):
730742 other = sym_help ._maybe_get_scalar (other )
731743 return lt_impl (g , input , sym_help ._if_scalar_type_as (g , other , input ))
732744
733745
734- @wrap_logical_op_with_cast_to_uint8
746+ @wrap_logical_op_with_cast_to ( 'Byte' )
735747@wrap_logical_op_with_negation
736748def le (g , input , other ):
737749 other = sym_help ._maybe_get_scalar (other )
738750 return gt_impl (g , input , sym_help ._if_scalar_type_as (g , other , input ))
739751
740752
753+ @wrap_logical_op_with_cast_to_and_from ('Bool' )
754+ def __and_ (g , input , other ):
755+ return g .op ('And' , input , other )
756+
757+
758+ @wrap_logical_op_with_cast_to_and_from ('Bool' )
759+ def __or_ (g , input , other ):
760+ return g .op ('Or' , input , other )
761+
762+
741763def where (g , condition , self , other ):
742764 return g .op ("Where" , condition , self , other )
743765
0 commit comments