Skip to content

Commit 2d761d0

Browse files
authored
Support defining __bool__ (mypyc/mypyc#373)
To support this, refactor emitclass dunder handling to be even *more* data driven. The main reason I am not throwing in support for ~all of the other dunder methods is that I think we should add tests for them when we add support. Supporting __bool__ without always deoptimizing truthiness checks against Optional types requires adding a very simple "whole-program analysis" for whether a subclass defines __bool__.
1 parent c86a37b commit 2d761d0

File tree

7 files changed

+115
-20
lines changed

7 files changed

+115
-20
lines changed

mypyc/emitclass.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from mypyc.emitfunc import native_function_header
1111
from mypyc.emitwrapper import (
1212
generate_dunder_wrapper, generate_hash_wrapper, generate_richcompare_wrapper,
13+
generate_bool_wrapper,
1314
)
1415
from mypyc.ops import (
1516
ClassIR, FuncIR, FuncDecl, RType, RTuple, Environment, object_rprimitive, FuncSignature,
@@ -31,6 +32,7 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
3132
# We maintain a table from dunder function names to struct slots they
3233
# correspond to and functions that generate a wrapper (if necessary)
3334
# and return the function name to stick in the slot.
35+
# TODO: Add remaining dunder methods
3436
SlotGenerator = Callable[[ClassIR, FuncIR, Emitter], str]
3537
SlotTable = Mapping[str, Tuple[str, SlotGenerator]]
3638

@@ -44,6 +46,19 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
4446
'__hash__': ('tp_hash', generate_hash_wrapper),
4547
}
4648

49+
AS_MAPPING_SLOT_DEFS = {
50+
'__getitem__': ('mp_subscript', generate_dunder_wrapper),
51+
}
52+
53+
AS_NUMBER_SLOT_DEFS = {
54+
'__bool__': ('nb_bool', generate_bool_wrapper),
55+
}
56+
57+
SIDE_TABLES = [
58+
('as_mapping', 'PyMappingMethods', AS_MAPPING_SLOT_DEFS),
59+
('as_number', 'PyNumberMethods', AS_NUMBER_SLOT_DEFS),
60+
]
61+
4762

4863
def generate_slots(cl: ClassIR, table: SlotTable, emitter: Emitter) -> Dict[str, str]:
4964
fields = OrderedDict() # type: Dict[str, str]
@@ -103,12 +118,14 @@ def emit_line() -> None:
103118
init_fn = cl.get_method('__init__')
104119

105120
# Fill out slots in the type object from dunder methods.
106-
# TODO: Add remaining dunder methods
107121
fields.update(generate_slots(cl, SLOT_DEFS, emitter))
108122

109-
as_mapping_name = generate_as_mapping_for_class(cl, emitter)
110-
if as_mapping_name:
111-
fields['tp_as_mapping'] = '&{}'.format(as_mapping_name)
123+
# Fill out dunder methods that live in tables hanging off the side.
124+
for table_name, type, slot_defs in SIDE_TABLES:
125+
slots = generate_slots(cl, slot_defs, emitter)
126+
if slots:
127+
table_struct_name = generate_side_table_for_class(cl, table_name, type, slots, emitter)
128+
fields['tp_{}'.format(table_name)] = '&{}'.format(table_struct_name)
112129

113130
richcompare_name = generate_richcompare_wrapper(cl, emitter)
114131
if richcompare_name:
@@ -457,18 +474,13 @@ def generate_methods_table(cl: ClassIR,
457474
emitter.emit_line('};')
458475

459476

460-
AS_MAPPING_SLOT_DEFS = {
461-
'__getitem__': ('mp_subscript', generate_dunder_wrapper),
462-
}
463-
464-
465-
def generate_as_mapping_for_class(cl: ClassIR,
477+
def generate_side_table_for_class(cl: ClassIR,
478+
name: str,
479+
type: str,
480+
slots: Dict[str, str],
466481
emitter: Emitter) -> Optional[str]:
467-
slots = generate_slots(cl, AS_MAPPING_SLOT_DEFS, emitter)
468-
if not slots:
469-
return None
470-
name = '{}_as_mapping'.format(cl.name_prefix(emitter.names))
471-
emitter.emit_line('static PyMappingMethods {} = {{'.format(name))
482+
name = '{}_{}'.format(cl.name_prefix(emitter.names), name)
483+
emitter.emit_line('static {} {} = {{'.format(type, name))
472484
for field, value in slots.items():
473485
emitter.emit_line(".{} = {},".format(field, value))
474486
emitter.emit_line("};")

mypyc/emitwrapper.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from mypyc.emit import Emitter
55
from mypyc.ops import (
66
ClassIR, FuncIR, RType, RuntimeArg,
7-
is_object_rprimitive, is_int_rprimitive,
7+
is_object_rprimitive, is_int_rprimitive, is_bool_rprimitive,
8+
bool_rprimitive,
89
FUNC_STATICMETHOD,
910
)
1011
from mypyc.namegen import NameGenerator
@@ -131,6 +132,26 @@ def generate_hash_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
131132
return name
132133

133134

135+
def generate_bool_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
136+
"""Generates a wrapper for native __bool__ methods."""
137+
name = '{}{}{}'.format(DUNDER_PREFIX, fn.name, cl.name_prefix(emitter.names))
138+
emitter.emit_line('static int {name}(PyObject *self) {{'.format(
139+
name=name
140+
))
141+
emitter.emit_line('{}val = {}{}(self);'.format(emitter.ctype_spaced(fn.ret_type),
142+
NATIVE_PREFIX,
143+
fn.cname(emitter.names)))
144+
emitter.emit_error_check('val', fn.ret_type, 'return -1;')
145+
# This wouldn't be that hard to fix but it seems unimportant and
146+
# getting error handling and unboxing right would be fiddly. (And
147+
# way easier to do in IR!)
148+
assert is_bool_rprimitive(fn.ret_type), "Only bool return supported for __bool__"
149+
emitter.emit_line('return val;')
150+
emitter.emit_line('}')
151+
152+
return name
153+
154+
134155
def generate_wrapper_core(fn: FuncIR, emitter: Emitter,
135156
optional_args: List[RuntimeArg] = [],
136157
arg_names: Optional[List[str]] = None) -> None:

mypyc/genops.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,12 @@ def prepare_class_def(module_name: str, cdef: ClassDef, mapper: Mapper) -> None:
380380
ir.mro = mro
381381
ir.base_mro = base_mro
382382

383+
# We need to know whether any children of a class have a __bool__
384+
# method in order to know whether we can assume it is always true.
385+
if ir.has_method('__bool__'):
386+
for base in ir.mro:
387+
base.has_bool = True
388+
383389

384390
class FuncInfo(object):
385391
"""Contains information about functions as they are generated."""
@@ -2329,9 +2335,8 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) ->
23292335
is_none = self.binary_op(value, self.none(), 'is not', value.line)
23302336
branch = Branch(is_none, true, false, Branch.BOOL_EXPR)
23312337
self.add(branch)
2332-
if isinstance(value_type, RInstance):
2338+
if isinstance(value_type, RInstance) and not value_type.class_ir.has_bool:
23332339
# Optional[X] where X is always truthy
2334-
# TODO: Support __bool__
23352340
pass
23362341
else:
23372342
# Optional[X] where X may be falsey and requires a check

mypyc/ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,6 +1431,9 @@ def __init__(self, name: str, module_name: str, is_trait: bool = False,
14311431
# base_mro is the chain of concrete (non-trait) ancestors
14321432
self.base_mro = [self] # type: List[ClassIR]
14331433

1434+
# Does this class or any subclass have a __bool__method
1435+
self.has_bool = False
1436+
14341437
def real_base(self) -> Optional['ClassIR']:
14351438
"""Return the actual concrete base class, if there is one."""
14361439
if len(self.mro) > 1 and not self.mro[1].is_trait:

test-data/fixtures/ir.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ def __eq__(self, x:object) -> bool:pass
6464
def __ne__(self, x: object) -> bool: pass
6565
def join(self, x: Iterable[object]) -> bytes: pass
6666

67-
class bool: pass
67+
class bool:
68+
def __init__(self, o: object = ...) -> None: ...
69+
6870

6971
class tuple(Generic[T], Sized):
7072
def __init__(self, i: Iterable[T]) -> None: pass

test-data/genops-optional.test

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,49 @@ L2:
2424
r3 = 2
2525
return r3
2626

27+
[case testIsTruthyOverride]
28+
from typing import Optional
29+
30+
class A: pass
31+
32+
class B(A):
33+
def __bool__(self) -> bool:
34+
return False
35+
36+
37+
def f(x: Optional[A]) -> int:
38+
if x:
39+
return 1
40+
return 2
41+
[out]
42+
def B.__bool__(self):
43+
self :: B
44+
r0 :: bool
45+
L0:
46+
r0 = False
47+
return r0
48+
def f(x):
49+
x :: union[A, None]
50+
r0 :: None
51+
r1 :: bool
52+
r2 :: A
53+
r3 :: bool
54+
r4, r5 :: int
55+
L0:
56+
r0 = None
57+
r1 = x is not r0
58+
if r1 goto L1 else goto L3 :: bool
59+
L1:
60+
r2 = cast(A, x)
61+
r3 = bool r2 :: object
62+
if r3 goto L2 else goto L3 :: bool
63+
L2:
64+
r4 = 1
65+
return r4
66+
L3:
67+
r5 = 2
68+
return r5
69+
2770
[case testIsNotNone]
2871
from typing import Optional
2972

test-data/run.test

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2377,7 +2377,8 @@ class Item:
23772377
return self.value < x.value
23782378

23792379
class Subclass1(Item):
2380-
pass
2380+
def __bool__(self) -> bool:
2381+
return bool(self.value)
23812382

23822383
class NonBoxedThing:
23832384
def __getitem__(self, index: Item) -> Item:
@@ -2400,6 +2401,9 @@ def internal_index_into() -> None:
24002401
z = Item("3")
24012402
print(y[z].value)
24022403

2404+
def is_truthy(x: Item) -> bool:
2405+
return True if x else False
2406+
24032407
[file driver.py]
24042408
from native import *
24052409
x = BoxedThing()
@@ -2427,6 +2431,11 @@ assert i2 < i3
24272431
assert not i1 < i2
24282432
assert i1 == Subclass1('lolol')
24292433

2434+
assert is_truthy(Item(''))
2435+
assert is_truthy(Item('a'))
2436+
assert not is_truthy(Subclass1(''))
2437+
assert is_truthy(Subclass1('a'))
2438+
24302439
internal_index_into()
24312440
[out]
24322441
7 7

0 commit comments

Comments
 (0)