Skip to content
149 changes: 80 additions & 69 deletions crates/codegen/src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4030,7 +4030,9 @@ impl Compiler {
if let Some(ref guard) = m.guard {
// Compile guard and jump to end if false
self.compile_expression(guard)?;
emit!(self, Instruction::JumpIfFalseOrPop { target: end });
emit!(self, Instruction::CopyItem { index: 1_u32 });
emit!(self, Instruction::PopJumpIfFalse { target: end });
emit!(self, Instruction::PopTop);
}
self.compile_statements(&m.body)?;
}
Expand All @@ -4044,92 +4046,97 @@ impl Compiler {
Ok(())
}

fn compile_chained_comparison(
/// [CPython `compiler_addcompare`](https://github.com/python/cpython/blob/627894459a84be3488a1789919679c997056a03c/Python/compile.c#L2880-L2924)
fn compile_addcompare(&mut self, op: &CmpOp) {
use bytecode::ComparisonOperator::*;
match op {
CmpOp::Eq => emit!(self, Instruction::CompareOperation { op: Equal }),
CmpOp::NotEq => emit!(self, Instruction::CompareOperation { op: NotEqual }),
CmpOp::Lt => emit!(self, Instruction::CompareOperation { op: Less }),
CmpOp::LtE => emit!(self, Instruction::CompareOperation { op: LessOrEqual }),
CmpOp::Gt => emit!(self, Instruction::CompareOperation { op: Greater }),
CmpOp::GtE => {
emit!(self, Instruction::CompareOperation { op: GreaterOrEqual })
}
CmpOp::In => emit!(self, Instruction::ContainsOp(Invert::No)),
CmpOp::NotIn => emit!(self, Instruction::ContainsOp(Invert::Yes)),
CmpOp::Is => emit!(self, Instruction::IsOp(Invert::No)),
CmpOp::IsNot => emit!(self, Instruction::IsOp(Invert::Yes)),
}
}

/// Compile a chained comparison.
///
/// ```py
/// a == b == c == d
/// ```
///
/// Will compile into (pseudo code):
///
/// ```py
/// result = a == b
/// if result:
/// result = b == c
/// if result:
/// result = c == d
/// ```
///
/// # See Also
/// - [CPython `compiler_compare`](https://github.com/python/cpython/blob/627894459a84be3488a1789919679c997056a03c/Python/compile.c#L4678-L4717)
fn compile_compare(
&mut self,
left: &Expr,
ops: &[CmpOp],
exprs: &[Expr],
comparators: &[Expr],
) -> CompileResult<()> {
assert!(!ops.is_empty());
assert_eq!(exprs.len(), ops.len());
let (last_op, mid_ops) = ops.split_last().unwrap();
let (last_val, mid_exprs) = exprs.split_last().unwrap();

use bytecode::ComparisonOperator::*;
let compile_cmpop = |c: &mut Self, op: &CmpOp| match op {
CmpOp::Eq => emit!(c, Instruction::CompareOperation { op: Equal }),
CmpOp::NotEq => emit!(c, Instruction::CompareOperation { op: NotEqual }),
CmpOp::Lt => emit!(c, Instruction::CompareOperation { op: Less }),
CmpOp::LtE => emit!(c, Instruction::CompareOperation { op: LessOrEqual }),
CmpOp::Gt => emit!(c, Instruction::CompareOperation { op: Greater }),
CmpOp::GtE => {
emit!(c, Instruction::CompareOperation { op: GreaterOrEqual })
}
CmpOp::In => emit!(c, Instruction::ContainsOp(Invert::No)),
CmpOp::NotIn => emit!(c, Instruction::ContainsOp(Invert::Yes)),
CmpOp::Is => emit!(c, Instruction::IsOp(Invert::No)),
CmpOp::IsNot => emit!(c, Instruction::IsOp(Invert::Yes)),
};

// a == b == c == d
// compile into (pseudo code):
// result = a == b
// if result:
// result = b == c
// if result:
// result = c == d
let (last_comparator, mid_comparators) = comparators.split_last().unwrap();

// initialize lhs outside of loop
self.compile_expression(left)?;

let end_blocks = if mid_exprs.is_empty() {
None
} else {
let break_block = self.new_block();
let after_block = self.new_block();
Some((break_block, after_block))
};
if mid_comparators.is_empty() {
self.compile_expression(last_comparator)?;
self.compile_addcompare(last_op);

return Ok(());
}

let cleanup = self.new_block();

// for all comparisons except the last (as the last one doesn't need a conditional jump)
for (op, val) in mid_ops.iter().zip(mid_exprs) {
self.compile_expression(val)?;
for (op, comparator) in mid_ops.iter().zip(mid_comparators) {
self.compile_expression(comparator)?;

// store rhs for the next comparison in chain
emit!(self, Instruction::Swap { index: 2 });
emit!(self, Instruction::CopyItem { index: 2_u32 });
emit!(self, Instruction::CopyItem { index: 2 });

compile_cmpop(self, op);
self.compile_addcompare(op);

// if comparison result is false, we break with this value; if true, try the next one.
if let Some((break_block, _)) = end_blocks {
emit!(
self,
Instruction::JumpIfFalseOrPop {
target: break_block,
}
);
}
/*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what was the blocker of this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the only reason why both opcodes stayed:/

This change caused regression for

def test_no_wraparound_jump(self):
# See https://bugs.python.org/issue46724
def while_not_chained(a, b, c):
while not (a < b < c):
pass
for instr in dis.Bytecode(while_not_chained):
self.assertNotEqual(instr.opname, "EXTENDED_ARG")

Not because it found an EXTENDED_ARG, but because iterating over dis.Bytecode(...) gave an error

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if that's because our test_compile is an old one, replacing it will be also good.

emit!(self, Instruction::CopyItem { index: 1 });
// emit!(self, Instruction::ToBool); // TODO: Uncomment this
emit!(self, Instruction::PopJumpIfFalse { target: cleanup });
emit!(self, Instruction::PopTop);
*/

emit!(self, Instruction::JumpIfFalseOrPop { target: cleanup });
}

// handle the last comparison
self.compile_expression(last_val)?;
compile_cmpop(self, last_op);
self.compile_expression(last_comparator)?;
self.compile_addcompare(last_op);

if let Some((break_block, after_block)) = end_blocks {
emit!(
self,
Instruction::Jump {
target: after_block,
}
);

// early exit left us with stack: `rhs, comparison_result`. We need to clean up rhs.
self.switch_to_block(break_block);
emit!(self, Instruction::Swap { index: 2 });
emit!(self, Instruction::PopTop);
let end = self.new_block();
emit!(self, Instruction::Jump { target: end });

self.switch_to_block(after_block);
}
// early exit left us with stack: `rhs, comparison_result`. We need to clean up rhs.
self.switch_to_block(cleanup);
emit!(self, Instruction::Swap { index: 2 });
emit!(self, Instruction::PopTop);

self.switch_to_block(end);
Ok(())
}

Expand Down Expand Up @@ -4457,27 +4464,31 @@ impl Compiler {
let after_block = self.new_block();

let (last_value, values) = values.split_last().unwrap();

for value in values {
self.compile_expression(value)?;

emit!(self, Instruction::CopyItem { index: 1_u32 });
match op {
BoolOp::And => {
emit!(
self,
Instruction::JumpIfFalseOrPop {
Instruction::PopJumpIfFalse {
target: after_block,
}
);
}
BoolOp::Or => {
emit!(
self,
Instruction::JumpIfTrueOrPop {
Instruction::PopJumpIfTrue {
target: after_block,
}
);
}
}

emit!(self, Instruction::PopTop);
}

// If all values did not qualify, take the value of the last value:
Expand Down Expand Up @@ -4554,7 +4565,7 @@ impl Compiler {
comparators,
..
}) => {
self.compile_chained_comparison(left, ops, comparators)?;
self.compile_compare(left, ops, comparators)?;
}
// Expr::Constant(ExprConstant { value, .. }) => {
// self.emit_load_const(compile_constant(value));
Expand Down
Loading