Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 142 additions & 1 deletion Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def test_constant_as_name(self):
for constant in "True", "False", "None":
expr = ast.Expression(ast.Name(constant, ast.Load()))
ast.fix_missing_locations(expr)
with self.assertRaisesRegex(ValueError, f"Name node can't be used with '{constant}' constant"):
with self.assertRaisesRegex(ValueError, f"identifier field can't represent '{constant}' constant"):
compile(expr, "<test>", "eval")


Expand Down Expand Up @@ -1476,6 +1476,147 @@ def test_stdlib_validates(self):
mod = ast.parse(source, fn)
compile(mod, fn, "exec")

constant_1 = ast.Constant(1)
pattern_1 = ast.MatchValue(constant_1)

constant_x = ast.Constant('x')
pattern_x = ast.MatchValue(constant_x)

constant_true = ast.Constant(True)
pattern_true = ast.MatchSingleton(True)

name_carter = ast.Name('carter', ast.Load())

_MATCH_PATTERNS = [
ast.MatchValue(
ast.Attribute(
ast.Attribute(
ast.Name('x', ast.Store()),
'y', ast.Load()
),
'z', ast.Load()
)
),
ast.MatchValue(
ast.Attribute(
ast.Attribute(
ast.Name('x', ast.Load()),
'y', ast.Store()
),
'z', ast.Load()
)
),
ast.MatchValue(
ast.Constant(...)
),
ast.MatchValue(
ast.Constant(True)
),
ast.MatchValue(
ast.Constant((1,2,3))
),
ast.MatchSingleton('string'),
ast.MatchSequence([
ast.MatchSingleton('string')
]),
ast.MatchSequence(
[
ast.MatchSequence(
[
ast.MatchSingleton('string')
]
)
]
),
ast.MatchMapping(
[constant_1, constant_true],
[pattern_x]
),
ast.MatchMapping(
[constant_true, constant_1],
[pattern_x, pattern_1],
rest='True'
),
ast.MatchMapping(
[constant_true, ast.Starred(ast.Name('lol', ast.Load()), ast.Load())],
[pattern_x, pattern_1],
rest='legit'
),
ast.MatchClass(
ast.Attribute(
ast.Attribute(
constant_x,
'y', ast.Load()),
'z', ast.Load()),
patterns=[], kwd_attrs=[], kwd_patterns=[]
),
ast.MatchClass(
name_carter,
patterns=[],
kwd_attrs=['True'],
kwd_patterns=[pattern_1]
),
ast.MatchClass(
name_carter,
patterns=[],
kwd_attrs=[],
kwd_patterns=[pattern_1]
),
ast.MatchClass(
name_carter,
patterns=[ast.MatchSingleton('string')],
kwd_attrs=[],
kwd_patterns=[]
),
ast.MatchClass(
name_carter,
patterns=[ast.MatchStar()],
kwd_attrs=[],
kwd_patterns=[]
),
ast.MatchClass(
name_carter,
patterns=[],
kwd_attrs=[],
kwd_patterns=[ast.MatchStar()]
),
ast.MatchSequence(
[
ast.MatchStar("True")
]
),
ast.MatchAs(
name='False'
),
ast.MatchOr(
[]
),
ast.MatchOr(
[pattern_1]
),
ast.MatchOr(
[pattern_1, pattern_x, ast.MatchSingleton('xxx')]
)
]

def test_match_validation_pattern(self):
name_x = ast.Name('x', ast.Load())
for pattern in self._MATCH_PATTERNS:
with self.subTest(ast.dump(pattern, indent=4)):
node = ast.Match(
subject=name_x,
cases = [
ast.match_case(
pattern=pattern,
body = [ast.Pass()]
)
]
)
node = ast.fix_missing_locations(node)
module = ast.Module([node], [])
with self.assertRaises(ValueError):
compile(module, "<test>", "exec")


class ConstantTests(unittest.TestCase):
"""Tests on the ast.Constant node type."""
Expand Down
154 changes: 123 additions & 31 deletions Python/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ struct validator {
};

static int validate_stmts(struct validator *, asdl_stmt_seq *);
static int validate_exprs(struct validator *, asdl_expr_seq*, expr_context_ty, int);
static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
static int validate_stmt(struct validator *, stmt_ty);
static int validate_expr(struct validator *, expr_ty, expr_context_ty);
Expand All @@ -33,7 +34,7 @@ validate_name(PyObject *name)
};
for (int i = 0; forbidden[i] != NULL; i++) {
if (_PyUnicode_EqualToASCIIString(name, forbidden[i])) {
PyErr_Format(PyExc_ValueError, "Name node can't be used with '%s' constant", forbidden[i]);
PyErr_Format(PyExc_ValueError, "identifier field can't represent '%s' constant", forbidden[i]);
return 0;
}
}
Expand Down Expand Up @@ -448,6 +449,21 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
switch (exp->kind)
{
case Constant_kind:
/* Ellipsis and immutable sequences are not allowed.
For True, False and None, MatchSingleton() should
be used */
if (!validate_expr(state, exp, Load)) {
return 0;
}
PyObject *literal = exp->v.Constant.value;
if (PyLong_CheckExact(literal) || PyFloat_CheckExact(literal) ||
PyBytes_CheckExact(literal) || PyComplex_CheckExact(literal) ||
PyUnicode_CheckExact(literal)) {
return 1;
}
PyErr_SetString(PyExc_ValueError,
"unexpected constant inside of a literal pattern");
return 0;
case Attribute_kind:
// Constants and attribute lookups are always permitted
return 1;
Expand All @@ -465,11 +481,14 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
return 1;
}
break;
case JoinedStr_kind:
// Handled in the later stages
return 1;
default:
break;
}
PyErr_SetString(PyExc_SyntaxError,
"patterns may only match literals and attribute lookups");
PyErr_SetString(PyExc_ValueError,
"patterns may only match literals and attribute lookups");
return 0;
}

Expand All @@ -489,51 +508,101 @@ validate_pattern(struct validator *state, pattern_ty p)
ret = validate_pattern_match_value(state, p->v.MatchValue.value);
break;
case MatchSingleton_kind:
// TODO: Check constant is specifically None, True, or False
ret = validate_constant(state, p->v.MatchSingleton.value);
ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value);
if (!ret) {
PyErr_SetString(PyExc_ValueError,
"MatchSingleton can only contain True, False and None");
}
break;
case MatchSequence_kind:
// TODO: Validate all subpatterns
// return validate_patterns(state, p->v.MatchSequence.patterns);
ret = 1;
ret = validate_patterns(state, p->v.MatchSequence.patterns, /*star_ok=*/1);
break;
case MatchMapping_kind:
// TODO: check "rest" target name is valid
if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
PyErr_SetString(PyExc_ValueError,
"MatchMapping doesn't have the same number of keys as patterns");
return 0;
ret = 0;
break;
}
// null_ok=0 for key expressions, as rest-of-mapping is captured in "rest"
// TODO: replace with more restrictive expression validator, as per MatchValue above
if (!validate_exprs(state, p->v.MatchMapping.keys, Load, /*null_ok=*/ 0)) {
return 0;

if (p->v.MatchMapping.rest && !validate_name(p->v.MatchMapping.rest)) {
ret = 0;
break;
}
// TODO: Validate all subpatterns
// ret = validate_patterns(state, p->v.MatchMapping.patterns);
ret = 1;

asdl_expr_seq *keys = p->v.MatchMapping.keys;
for (Py_ssize_t i = 0; i < asdl_seq_LEN(keys); i++) {
expr_ty key = asdl_seq_GET(keys, i);
if (key->kind == Constant_kind) {
PyObject *literal = key->v.Constant.value;
if (literal == Py_None || PyBool_Check(literal)) {
/* validate_pattern_match_value will ensure the key
doesn't contain True, False and None but it is
syntactically valid, so we will pass those on in
a special case. */
continue;
}
}
if (!validate_pattern_match_value(state, key)) {
ret = 0;
break;
}
}

ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0);
break;
case MatchClass_kind:
if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
PyErr_SetString(PyExc_ValueError,
"MatchClass doesn't have the same number of keyword attributes as patterns");
return 0;
ret = 0;
break;
}
// TODO: Restrict cls lookup to being a name or attribute
if (!validate_expr(state, p->v.MatchClass.cls, Load)) {
return 0;
ret = 0;
break;
}
// TODO: Validate all subpatterns
// return validate_patterns(state, p->v.MatchClass.patterns) &&
// validate_patterns(state, p->v.MatchClass.kwd_patterns);
ret = 1;

expr_ty cls = p->v.MatchClass.cls;
while (1) {
if (cls->kind == Name_kind) {
break;
}
else if (cls->kind == Attribute_kind) {
cls = cls->v.Attribute.value;
continue;
}
else {
PyErr_SetString(PyExc_ValueError,
"MatchClass cls field can only contain Name or Attribute nodes.");
state->recursion_depth--;
return 0;
}
}

for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) {
PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i);
if (!validate_name(identifier)) {
state->recursion_depth--;
return 0;
}
}

if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) {
ret = 0;
break;
}

ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0);
break;
case MatchStar_kind:
// TODO: check target name is valid
ret = 1;
ret = p->v.MatchStar.name == NULL || validate_name(p->v.MatchStar.name);
break;
case MatchAs_kind:
// TODO: check target name is valid
if (p->v.MatchAs.name && !validate_name(p->v.MatchAs.name)) {
ret = 0;
break;
}
if (p->v.MatchAs.pattern == NULL) {
ret = 1;
}
Expand All @@ -547,9 +616,13 @@ validate_pattern(struct validator *state, pattern_ty p)
}
break;
case MatchOr_kind:
// TODO: Validate all subpatterns
// return validate_patterns(state, p->v.MatchOr.patterns);
ret = 1;
if (asdl_seq_LEN(p->v.MatchOr.patterns) < 2) {
PyErr_SetString(PyExc_ValueError,
"MatchOr requires at least 2 patterns");
ret = 0;
break;
}
ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0);
break;
// No default case, so the compiler will emit a warning if new pattern
// kinds are added without being handled here
Expand Down Expand Up @@ -815,6 +888,25 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct
return 1;
}

static int
validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok)
{
Py_ssize_t i;
for (i = 0; i < asdl_seq_LEN(patterns); i++) {
pattern_ty pattern = asdl_seq_GET(patterns, i);
if (pattern->kind == MatchStar_kind && !star_ok) {
PyErr_SetString(PyExc_ValueError,
"Can't use MatchStar within this sequence of patterns");
return 0;
}
if (!validate_pattern(state, pattern)) {
return 0;
}
}
return 1;
}


/* See comments in symtable.c. */
#define COMPILER_STACK_FRAME_SCALE 3

Expand Down