Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
gh-108113: optimize ASTs in ast.parse/ast.literal_eval/compile(..., f…
…lags=ast.PyCF_ONLY_AST)
  • Loading branch information
iritkatriel committed Aug 18, 2023
commit c4911621eab5a6cc05adac63fc7d632955dd1d94
10 changes: 8 additions & 2 deletions Doc/library/ast.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2122,10 +2122,10 @@ Async and await
Apart from the node classes, the :mod:`ast` module defines these utility functions
and classes for traversing abstract syntax trees:

.. function:: parse(source, filename='<unknown>', mode='exec', *, type_comments=False, feature_version=None)
.. function:: parse(source, filename='<unknown>', mode='exec', *, type_comments=False, feature_version=None, optimize=-1)

Parse the source into an AST node. Equivalent to ``compile(source,
filename, mode, ast.PyCF_ONLY_AST)``.
filename, mode, flags=ast.PyCF_ONLY_AST, optimize=optimize)``.

If ``type_comments=True`` is given, the parser is modified to check
and return type comments as specified by :pep:`484` and :pep:`526`.
Expand Down Expand Up @@ -2172,6 +2172,10 @@ and classes for traversing abstract syntax trees:
.. versionchanged:: 3.13
The minimum supported version for feature_version is now (3,7)

The output AST is now optimized with constant folding.
The ``optimize`` argument was added to control additional
optimizations.


.. function:: unparse(ast_obj)

Expand Down Expand Up @@ -2229,6 +2233,8 @@ and classes for traversing abstract syntax trees:
.. versionchanged:: 3.10
For string inputs, leading spaces and tabs are now stripped.

.. versionchanged:: 3.13
This function now understands and collapses const expressions.

.. function:: get_docstring(node, clean=True)

Expand Down
14 changes: 14 additions & 0 deletions Doc/whatsnew/3.13.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ Other Language Changes
This change will affect tools using docstrings, like :mod:`doctest`.
(Contributed by Inada Naoki in :gh:`81283`.)

* The :func:`compile` built-in no longer ignores the ``optimize`` argument
when called with the ``ast.PyCF_ONLY_AST`` flag.
(Contributed by Irit Katriel in :gh:`108113`).

New Modules
===========

Expand All @@ -94,6 +98,16 @@ New Modules
Improved Modules
================

ast
---

* :func:`ast.parse` and :func:`ast.literal_eval` now perform constant folding
and other AST optimizations. This means that AST are more concise, and
:func:`ast.literal_eval` understands and collapses const expressions.
:func:`ast.parse` also accepts a new optional argument ``optimize``, which
it forwards to the :func:`compile` built-in.
(Contributed by Irit Katriel in :gh:`108113`).

array
-----

Expand Down
6 changes: 3 additions & 3 deletions Lib/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


def parse(source, filename='<unknown>', mode='exec', *,
type_comments=False, feature_version=None):
type_comments=False, feature_version=None, optimize=-1):
"""
Parse the source into an AST node.
Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).
Expand All @@ -50,7 +50,7 @@ def parse(source, filename='<unknown>', mode='exec', *,
feature_version = minor
# Else it should be an int giving the minor version for 3.x.
return compile(source, filename, mode, flags,
_feature_version=feature_version)
_feature_version=feature_version, optimize=optimize)


def literal_eval(node_or_string):
Expand All @@ -63,7 +63,7 @@ def literal_eval(node_or_string):
Caution: A complex expression can overflow the C stack and cause a crash.
"""
if isinstance(node_or_string, str):
node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval')
node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval', optimize=0)
if isinstance(node_or_string, Expression):
node_or_string = node_or_string.body
def _raise_malformed_node(node):
Expand Down
76 changes: 44 additions & 32 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
def to_tuple(t):
if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis:
return t
elif isinstance(t, list):
return [to_tuple(e) for e in t]
elif isinstance(t, (list, tuple)):
return type(t)([to_tuple(e) for e in t])
result = [t.__class__.__name__]
if hasattr(t, 'lineno') and hasattr(t, 'col_offset'):
result.append((t.lineno, t.col_offset))
Expand Down Expand Up @@ -274,7 +274,7 @@ def to_tuple(t):
# Tuple
"1,2,3",
# Tuple
"(1,2,3)",
"(1,x,3)",
# Empty tuple
"()",
# Combination
Expand Down Expand Up @@ -357,6 +357,15 @@ def test_ast_validation(self):
tree = ast.parse(snippet)
compile(tree, '<string>', 'exec')

def test_optimization_levels(self):
cases = [(-1, __debug__), (0, True), (1, False), (2, False)]
for (optval, expected) in cases:
with self.subTest(optval=optval, expected=expected):
res = ast.parse("__debug__", optimize=optval)
self.assertIsInstance(res.body[0], ast.Expr)
self.assertIsInstance(res.body[0].value, ast.Constant)
self.assertEqual(res.body[0].value.value, expected)

def test_invalid_position_information(self):
invalid_linenos = [
(10, 1), (-10, -11), (10, -11), (-5, -2), (-5, 1)
Expand Down Expand Up @@ -948,7 +957,7 @@ def bad_normalize(*args):
self.assertRaises(TypeError, ast.parse, '\u03D5')

def test_issue18374_binop_col_offset(self):
tree = ast.parse('4+5+6+7')
tree = ast.parse('a+b+c+d')
parent_binop = tree.body[0].value
child_binop = parent_binop.left
grandchild_binop = child_binop.left
Expand All @@ -959,7 +968,7 @@ def test_issue18374_binop_col_offset(self):
self.assertEqual(grandchild_binop.col_offset, 0)
self.assertEqual(grandchild_binop.end_col_offset, 3)

tree = ast.parse('4+5-\\\n 6-7')
tree = ast.parse('a+b-\\\n c-d')
parent_binop = tree.body[0].value
child_binop = parent_binop.left
grandchild_binop = child_binop.left
Expand Down Expand Up @@ -1266,13 +1275,14 @@ def test_dump_incomplete(self):
)

def test_copy_location(self):
src = ast.parse('1 + 1', mode='eval')
src = ast.parse('x + 1', mode='eval')
src.body.right = ast.copy_location(ast.Constant(2), src.body.right)
self.assertEqual(ast.dump(src, include_attributes=True),
'Expression(body=BinOp(left=Constant(value=1, lineno=1, col_offset=0, '
'end_lineno=1, end_col_offset=1), op=Add(), right=Constant(value=2, '
'lineno=1, col_offset=4, end_lineno=1, end_col_offset=5), lineno=1, '
'col_offset=0, end_lineno=1, end_col_offset=5))'
"Expression(body=BinOp(left=Name(id='x', ctx=Load(), lineno=1, "
"col_offset=0, end_lineno=1, end_col_offset=1), op=Add(), "
"right=Constant(value=2, lineno=1, col_offset=4, end_lineno=1, "
"end_col_offset=5), lineno=1, col_offset=0, end_lineno=1, "
"end_col_offset=5))"
)
src = ast.Call(col_offset=1, lineno=1, end_lineno=1, end_col_offset=1)
new = ast.copy_location(src, ast.Call(col_offset=None, lineno=None))
Expand Down Expand Up @@ -1302,20 +1312,22 @@ def test_fix_missing_locations(self):
)

def test_increment_lineno(self):
src = ast.parse('1 + 1', mode='eval')
src = ast.parse('x + 1', mode='eval')
self.assertEqual(ast.increment_lineno(src, n=3), src)
self.assertEqual(ast.dump(src, include_attributes=True),
'Expression(body=BinOp(left=Constant(value=1, lineno=4, col_offset=0, '
'end_lineno=4, end_col_offset=1), op=Add(), right=Constant(value=1, '
'Expression(body=BinOp(left=Name(id=\'x\', ctx=Load(), '
'lineno=4, col_offset=0, end_lineno=4, end_col_offset=1), '
'op=Add(), right=Constant(value=1, '
'lineno=4, col_offset=4, end_lineno=4, end_col_offset=5), lineno=4, '
'col_offset=0, end_lineno=4, end_col_offset=5))'
)
# issue10869: do not increment lineno of root twice
src = ast.parse('1 + 1', mode='eval')
src = ast.parse('y + 2', mode='eval')
self.assertEqual(ast.increment_lineno(src.body, n=3), src.body)
self.assertEqual(ast.dump(src, include_attributes=True),
'Expression(body=BinOp(left=Constant(value=1, lineno=4, col_offset=0, '
'end_lineno=4, end_col_offset=1), op=Add(), right=Constant(value=1, '
'Expression(body=BinOp(left=Name(id=\'y\', ctx=Load(), '
'lineno=4, col_offset=0, end_lineno=4, end_col_offset=1), '
'op=Add(), right=Constant(value=2, '
'lineno=4, col_offset=4, end_lineno=4, end_col_offset=5), lineno=4, '
'col_offset=0, end_lineno=4, end_col_offset=5))'
)
Expand Down Expand Up @@ -1446,9 +1458,9 @@ def test_literal_eval(self):
self.assertEqual(ast.literal_eval('+3.25'), 3.25)
self.assertEqual(ast.literal_eval('-3.25'), -3.25)
self.assertEqual(repr(ast.literal_eval('-0.0')), '-0.0')
self.assertRaises(ValueError, ast.literal_eval, '++6')
self.assertRaises(ValueError, ast.literal_eval, '+True')
self.assertRaises(ValueError, ast.literal_eval, '2+3')
self.assertEqual(ast.literal_eval('++6'), 6)
self.assertEqual(ast.literal_eval('+True'), 1)
self.assertEqual(ast.literal_eval('2+3'), 5)

def test_literal_eval_str_int_limit(self):
with support.adjust_int_max_str_digits(4000):
Expand All @@ -1473,11 +1485,11 @@ def test_literal_eval_complex(self):
self.assertEqual(ast.literal_eval('3.25-6.75j'), 3.25-6.75j)
self.assertEqual(ast.literal_eval('-3.25-6.75j'), -3.25-6.75j)
self.assertEqual(ast.literal_eval('(3+6j)'), 3+6j)
self.assertRaises(ValueError, ast.literal_eval, '-6j+3')
self.assertRaises(ValueError, ast.literal_eval, '-6j+3j')
self.assertRaises(ValueError, ast.literal_eval, '3+-6j')
self.assertRaises(ValueError, ast.literal_eval, '3+(0+6j)')
self.assertRaises(ValueError, ast.literal_eval, '-(3+6j)')
self.assertEqual(ast.literal_eval('-6j+3'), 3-6j)
self.assertEqual(ast.literal_eval('-6j+3j'), -3j)
self.assertEqual(ast.literal_eval('3+-6j'), 3-6j)
self.assertEqual(ast.literal_eval('3+(0+6j)'), 3+6j)
self.assertEqual(ast.literal_eval('-(3+6j)'), -3-6j)

def test_literal_eval_malformed_dict_nodes(self):
malformed = ast.Dict(keys=[ast.Constant(1), ast.Constant(2)], values=[ast.Constant(3)])
Expand All @@ -1494,7 +1506,7 @@ def test_literal_eval_trailing_ws(self):
def test_literal_eval_malformed_lineno(self):
msg = r'malformed node or string on line 3:'
with self.assertRaisesRegex(ValueError, msg):
ast.literal_eval("{'a': 1,\n'b':2,\n'c':++3,\n'd':4}")
ast.literal_eval("{'a': 1,\n'b':2,\n'c':++x,\n'd':4}")

node = ast.UnaryOp(
ast.UAdd(), ast.UnaryOp(ast.UAdd(), ast.Constant(6)))
Expand Down Expand Up @@ -2265,7 +2277,7 @@ def test_load_const(self):
consts)

def test_literal_eval(self):
tree = ast.parse("1 + 2")
tree = ast.parse("x + 2")
binop = tree.body[0].value

new_left = ast.Constant(value=10)
Expand Down Expand Up @@ -2479,14 +2491,14 @@ def test_slices(self):

def test_binop(self):
s = dedent('''
(1 * 2 + (3 ) +
(1 * x + (3 ) +
4
)
''').strip()
binop = self._parse_value(s)
self._check_end_pos(binop, 2, 6)
self._check_content(s, binop.right, '4')
self._check_content(s, binop.left, '1 * 2 + (3 )')
self._check_content(s, binop.left, '1 * x + (3 )')
self._check_content(s, binop.left.right, '3')

def test_boolop(self):
Expand Down Expand Up @@ -3039,7 +3051,7 @@ def main():
('Module', [('FunctionDef', (1, 0, 1, 38), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (1, 34, 1, 38))], [], None, None, [('TypeVar', (1, 6, 1, 19), 'T', ('Tuple', (1, 9, 1, 19), [('Name', (1, 10, 1, 13), 'int', ('Load',)), ('Name', (1, 15, 1, 18), 'str', ('Load',))], ('Load',))), ('TypeVarTuple', (1, 21, 1, 24), 'Ts'), ('ParamSpec', (1, 26, 1, 29), 'P')])], []),
]
single_results = [
('Interactive', [('Expr', (1, 0, 1, 3), ('BinOp', (1, 0, 1, 3), ('Constant', (1, 0, 1, 1), 1, None), ('Add',), ('Constant', (1, 2, 1, 3), 2, None)))]),
('Interactive', [('Expr', (1, 0, 1, 3), ('Constant', (1, 0, 1, 3), 3, None))]),
]
eval_results = [
('Expression', ('Constant', (1, 0, 1, 4), None, None)),
Expand Down Expand Up @@ -3073,9 +3085,9 @@ def main():
('Expression', ('Name', (1, 0, 1, 1), 'v', ('Load',))),
('Expression', ('List', (1, 0, 1, 7), [('Constant', (1, 1, 1, 2), 1, None), ('Constant', (1, 3, 1, 4), 2, None), ('Constant', (1, 5, 1, 6), 3, None)], ('Load',))),
('Expression', ('List', (1, 0, 1, 2), [], ('Load',))),
('Expression', ('Tuple', (1, 0, 1, 5), [('Constant', (1, 0, 1, 1), 1, None), ('Constant', (1, 2, 1, 3), 2, None), ('Constant', (1, 4, 1, 5), 3, None)], ('Load',))),
('Expression', ('Tuple', (1, 0, 1, 7), [('Constant', (1, 1, 1, 2), 1, None), ('Constant', (1, 3, 1, 4), 2, None), ('Constant', (1, 5, 1, 6), 3, None)], ('Load',))),
('Expression', ('Tuple', (1, 0, 1, 2), [], ('Load',))),
('Expression', ('Constant', (1, 0, 1, 5), (1, 2, 3), None)),
('Expression', ('Tuple', (1, 0, 1, 7), [('Constant', (1, 1, 1, 2), 1, None), ('Name', (1, 3, 1, 4), 'x', ('Load',)), ('Constant', (1, 5, 1, 6), 3, None)], ('Load',))),
('Expression', ('Constant', (1, 0, 1, 2), (), None)),
('Expression', ('Call', (1, 0, 1, 17), ('Attribute', (1, 0, 1, 7), ('Attribute', (1, 0, 1, 5), ('Attribute', (1, 0, 1, 3), ('Name', (1, 0, 1, 1), 'a', ('Load',)), 'b', ('Load',)), 'c', ('Load',)), 'd', ('Load',)), [('Subscript', (1, 8, 1, 16), ('Attribute', (1, 8, 1, 11), ('Name', (1, 8, 1, 9), 'a', ('Load',)), 'b', ('Load',)), ('Slice', (1, 12, 1, 15), ('Constant', (1, 12, 1, 13), 1, None), ('Constant', (1, 14, 1, 15), 2, None), None), ('Load',))], [])),
]
main()
2 changes: 1 addition & 1 deletion Lib/test/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def f(): """doc"""
# test both direct compilation and compilation via AST
codeobjs = []
codeobjs.append(compile(codestr, "<test>", "exec", optimize=optval))
tree = ast.parse(codestr)
tree = ast.parse(codestr, optimize=optval)
codeobjs.append(compile(tree, "<test>", "exec", optimize=optval))
for code in codeobjs:
ns = {}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
The :func:`compile` built-in no longer ignores the ``optimize`` argument
when called with the ``ast.PyCF_ONLY_AST`` flag. The :func:`ast.parse`
function now accepts an optional argument ``optimize``, which it forwards to
:func:`compile`. :func:`ast.parse` and :func:`ast.literal_eval` perform
const folding, so ASTs are more concise and :func:`ast.literal_eval`
accepts const expressions.
23 changes: 23 additions & 0 deletions Python/pythonrun.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "pycore_pyerrors.h" // _PyErr_GetRaisedException, _Py_Offer_Suggestions
#include "pycore_pylifecycle.h" // _Py_UnhandledKeyboardInterrupt
#include "pycore_pystate.h" // _PyInterpreterState_GET()
#include "pycore_symtable.h" // _PyFuture_FromAST()
#include "pycore_sysmodule.h" // _PySys_Audit()
#include "pycore_traceback.h" // _PyTraceBack_Print_Indented()

Expand Down Expand Up @@ -1790,6 +1791,24 @@ run_pyc_file(FILE *fp, PyObject *globals, PyObject *locals,
return NULL;
}

static int
call_ast_optimize(mod_ty mod, PyObject *filename, PyCompilerFlags *cf,
int optimize, PyArena *arena)
{
PyFutureFeatures future;
if (!_PyFuture_FromAST(mod, filename, &future)) {
return -1;
}
int flags = future.ff_features | cf->cf_flags;
if (optimize == -1) {
optimize = _Py_GetConfig()->optimization_level;
}
if (!_PyAST_Optimize(mod, arena, optimize, flags)) {
return -1;
}
return 0;
}

PyObject *
Py_CompileStringObject(const char *str, PyObject *filename, int start,
PyCompilerFlags *flags, int optimize)
Expand All @@ -1806,6 +1825,10 @@ Py_CompileStringObject(const char *str, PyObject *filename, int start,
return NULL;
}
if (flags && (flags->cf_flags & PyCF_ONLY_AST)) {
if (call_ast_optimize(mod, filename, flags, optimize, arena) < 0) {
_PyArena_Free(arena);
return NULL;
}
PyObject *result = PyAST_mod2obj(mod);
_PyArena_Free(arena);
return result;
Expand Down