2222 LoadStatic , MethodCall , PrimitiveOp , OpDescription , RegisterOp , CallC , Truncate ,
2323 RaiseStandardError , Unreachable , LoadErrorValue , LoadGlobal ,
2424 NAMESPACE_TYPE , NAMESPACE_MODULE , NAMESPACE_STATIC , BinaryIntOp , GetElementPtr ,
25- LoadMem , LoadAddress
25+ LoadMem , ComparisonOp , LoadAddress
2626)
2727from mypyc .ir .rtypes import (
2828 RType , RUnion , RInstance , optional_value_type , int_rprimitive , float_rprimitive ,
5555from mypyc .primitives .misc_ops import (
5656 none_object_op , fast_isinstance_op , bool_op , type_is_op
5757)
58- from mypyc .primitives .int_ops import int_logical_op_mapping
58+ from mypyc .primitives .int_ops import int_comparison_op_mapping
5959from mypyc .rt_subtype import is_runtime_subtype
6060from mypyc .subtype import is_subtype
6161from mypyc .sametype import is_same_type
@@ -559,7 +559,11 @@ def binary_op(self,
559559 if value is not None :
560560 return value
561561
562- if is_tagged (lreg .type ) and is_tagged (rreg .type ) and expr_op in int_logical_op_mapping :
562+ # Special case 'is' and 'is not'
563+ if expr_op in ('is' , 'is not' ):
564+ return self .translate_is_op (lreg , rreg , expr_op , line )
565+
566+ if is_tagged (lreg .type ) and is_tagged (rreg .type ) and expr_op in int_comparison_op_mapping :
563567 return self .compare_tagged (lreg , rreg , expr_op , line )
564568
565569 call_c_ops_candidates = c_binary_ops .get (expr_op , [])
@@ -577,16 +581,15 @@ def check_tagged_short_int(self, val: Value, line: int) -> Value:
577581 bitwise_and = self .binary_int_op (c_pyssize_t_rprimitive , val ,
578582 int_tag , BinaryIntOp .AND , line )
579583 zero = self .add (LoadInt (0 , line , rtype = c_pyssize_t_rprimitive ))
580- check = self .binary_int_op ( bool_rprimitive , bitwise_and , zero , BinaryIntOp .EQ , line )
584+ check = self .comparison_op ( bitwise_and , zero , ComparisonOp .EQ , line )
581585 return check
582586
583587 def compare_tagged (self , lhs : Value , rhs : Value , op : str , line : int ) -> Value :
584588 """Compare two tagged integers using given op"""
585589 # generate fast binary logic ops on short ints
586590 if is_short_int_rprimitive (lhs .type ) and is_short_int_rprimitive (rhs .type ):
587- return self .binary_int_op (bool_rprimitive , lhs , rhs ,
588- int_logical_op_mapping [op ][0 ], line )
589- op_type , c_func_desc , negate_result , swap_op = int_logical_op_mapping [op ]
591+ return self .comparison_op (lhs , rhs , int_comparison_op_mapping [op ][0 ], line )
592+ op_type , c_func_desc , negate_result , swap_op = int_comparison_op_mapping [op ]
590593 result = self .alloc_temp (bool_rprimitive )
591594 short_int_block , int_block , out = BasicBlock (), BasicBlock (), BasicBlock ()
592595 check_lhs = self .check_tagged_short_int (lhs , line )
@@ -601,7 +604,7 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
601604 branch .negated = False
602605 self .add (branch )
603606 self .activate_block (short_int_block )
604- eq = self .binary_int_op ( bool_rprimitive , lhs , rhs , op_type , line )
607+ eq = self .comparison_op ( lhs , rhs , op_type , line )
605608 self .add (Assign (result , eq , line ))
606609 self .goto (out )
607610 self .activate_block (int_block )
@@ -725,7 +728,7 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) ->
725728 else :
726729 value_type = optional_value_type (value .type )
727730 if value_type is not None :
728- is_none = self .binary_op (value , self .none_object (), 'is not' , value .line )
731+ is_none = self .translate_is_op (value , self .none_object (), 'is not' , value .line )
729732 branch = Branch (is_none , true , false , Branch .BOOL_EXPR )
730733 self .add (branch )
731734 always_truthy = False
@@ -822,6 +825,9 @@ def matching_call_c(self,
822825 def binary_int_op (self , type : RType , lhs : Value , rhs : Value , op : int , line : int ) -> Value :
823826 return self .add (BinaryIntOp (type , lhs , rhs , op , line ))
824827
828+ def comparison_op (self , lhs : Value , rhs : Value , op : int , line : int ) -> Value :
829+ return self .add (ComparisonOp (lhs , rhs , op , line ))
830+
825831 def builtin_len (self , val : Value , line : int ) -> Value :
826832 typ = val .type
827833 if is_list_rprimitive (typ ) or is_tuple_rprimitive (typ ):
@@ -974,7 +980,7 @@ def translate_eq_cmp(self,
974980 if not class_ir .has_method ('__eq__' ):
975981 # There's no __eq__ defined, so just use object identity.
976982 identity_ref_op = 'is' if expr_op == '==' else 'is not'
977- return self .binary_op (lreg , rreg , identity_ref_op , line )
983+ return self .translate_is_op (lreg , rreg , identity_ref_op , line )
978984
979985 return self .gen_method_call (
980986 lreg ,
@@ -984,6 +990,21 @@ def translate_eq_cmp(self,
984990 line
985991 )
986992
993+ def translate_is_op (self ,
994+ lreg : Value ,
995+ rreg : Value ,
996+ expr_op : str ,
997+ line : int ) -> Value :
998+ """Create equality comparison operation between object identities
999+
1000+ Args:
1001+ expr_op: either 'is' or 'is not'
1002+ """
1003+ op = ComparisonOp .EQ if expr_op == 'is' else ComparisonOp .NEQ
1004+ lhs = self .coerce (lreg , object_rprimitive , line )
1005+ rhs = self .coerce (rreg , object_rprimitive , line )
1006+ return self .add (ComparisonOp (lhs , rhs , op , line ))
1007+
9871008 def _create_dict (self ,
9881009 keys : List [Value ],
9891010 values : List [Value ],
0 commit comments