Skip to content
72 changes: 72 additions & 0 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import weakref

from test import support
from test.support import findfile


def to_tuple(t):
if t is None or isinstance(t, (str, int, complex)):
Expand Down Expand Up @@ -432,6 +434,76 @@ def test_empty_yield_from(self):
self.assertIn("field value is required", str(cm.exception))


class ASTCompareTest(unittest.TestCase):
def test_normal_compare(self):
self.assertEqual(ast.parse('x = 10'), ast.parse('x = 10'))
self.assertNotEqual(ast.parse('x = 10'), ast.parse(''))
self.assertNotEqual(ast.parse('x = 10'), ast.parse('x'))
self.assertNotEqual(ast.parse('x = 10;y = 20'), ast.parse('class C:pass'))

def test_literals_compare(self):
self.assertEqual(ast.Num(), ast.Num())
self.assertEqual(ast.Num(-20), ast.Num(-20))
self.assertEqual(ast.Num(10), ast.Num(10))
self.assertEqual(ast.Num(2048), ast.Num(2048))
self.assertEqual(ast.Str(), ast.Str())
self.assertEqual(ast.Str("ABCD"), ast.Str("ABCD"))
self.assertEqual(ast.Str("中文字"), ast.Str("中文字"))

self.assertNotEqual(ast.Num(10), ast.Num(20))
self.assertNotEqual(ast.Num(-10), ast.Num(10))
self.assertNotEqual(ast.Str("AAAA"), ast.Str("BBBB"))
self.assertNotEqual(ast.Str("一二三"), ast.Str("中文字"))

self.assertNotEqual(ast.Num(10), ast.Num())
self.assertNotEqual(ast.Str("AB"), ast.Str())

def test_operator_compare(self):
self.assertEqual(ast.Add(), ast.Add())
self.assertEqual(ast.Sub(), ast.Sub())

self.assertNotEqual(ast.Add(), ast.Sub())
self.assertNotEqual(ast.Add(), ast.Num())

def test_complex_ast(self):
fps = [findfile('test_asyncgen.py'),
findfile('test_generators.py'),
findfile('test_unicode.py')]

for fp in fps:
with open(fp) as f:
try:
source = f.read()
except UnicodeDecodeError:
continue

a = ast.parse(source)
b = ast.parse(source)
self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b)))
self.assertFalse(a != b)

def test_exec_compare(self):
for source in exec_tests:
a = ast.parse(source, mode='exec')
b = ast.parse(source, mode='exec')
self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b)))
self.assertFalse(a != b)

def test_single_compare(self):
for source in single_tests:
a = ast.parse(source, mode='single')
b = ast.parse(source, mode='single')
self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b)))
self.assertFalse(a != b)

def test_eval_compare(self):
for source in eval_tests:
a = ast.parse(source, mode='eval')
b = ast.parse(source, mode='eval')
self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b)))
self.assertFalse(a != b)


class ASTHelpers_Test(unittest.TestCase):

def test_parse(self):
Expand Down
59 changes: 58 additions & 1 deletion Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,63 @@ def visitModule(self, mod):
return Py_BuildValue("O()", Py_TYPE(self));
}

static PyObject *
ast_richcompare(PyObject *self, PyObject *other, int op)
{
int i, len;
PyObject *fields, *key, *a = Py_None, *b = Py_None;

/* Check operator */
if ((op != Py_EQ && op != Py_NE) ||
!PyAST_Check(self) ||
!PyAST_Check(other)) {
Py_RETURN_NOTIMPLEMENTED;
}

/* Compare types */
if (Py_TYPE(self) != Py_TYPE(other)) {
if (op == Py_EQ)
Py_RETURN_FALSE;
else
Py_RETURN_TRUE;
}

/* Compare fields */
fields = PyObject_GetAttrString(self, "_fields");
len = PySequence_Size(fields);
for (i = 0; i < len; ++i) {
key = PySequence_GetItem(fields, i);

if (PyObject_HasAttr(self, key))
a = PyObject_GetAttr(self, key);
if (PyObject_HasAttr(other, key))
b = PyObject_GetAttr(other, key);

/* Check filed value type */
if (Py_TYPE(a) != Py_TYPE(b)) {
if (op == Py_EQ) {
Py_RETURN_FALSE;
}
}

if (op == Py_EQ) {
if (!PyObject_RichCompareBool(a, b, Py_EQ)) {
Py_RETURN_FALSE;
}
}
else if (op == Py_NE) {
if (PyObject_RichCompareBool(a, b, Py_NE)) {
Py_RETURN_TRUE;
}
}
}

if (op == Py_EQ)
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
}

static PyMethodDef ast_type_methods[] = {
{"__reduce__", ast_type_reduce, METH_NOARGS, NULL},
{NULL}
Expand Down Expand Up @@ -754,7 +811,7 @@ def visitModule(self, mod):
0, /* tp_doc */
(traverseproc)ast_traverse, /* tp_traverse */
(inquiry)ast_clear, /* tp_clear */
0, /* tp_richcompare */
ast_richcompare, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
Expand Down
59 changes: 58 additions & 1 deletion Python/Python-ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,63 @@ ast_type_reduce(PyObject *self, PyObject *unused)
return Py_BuildValue("O()", Py_TYPE(self));
}

static PyObject *
ast_richcompare(PyObject *self, PyObject *other, int op)
{
int i, len;
PyObject *fields, *key, *a = Py_None, *b = Py_None;

/* Check operator */
if ((op != Py_EQ && op != Py_NE) ||
!PyAST_Check(self) ||
!PyAST_Check(other)) {
Py_RETURN_NOTIMPLEMENTED;
}

/* Compare types */
if (Py_TYPE(self) != Py_TYPE(other)) {
if (op == Py_EQ)
Py_RETURN_FALSE;
else
Py_RETURN_TRUE;
}

/* Compare fields */
fields = PyObject_GetAttrString(self, "_fields");
len = PySequence_Size(fields);
for (i = 0; i < len; ++i) {
key = PySequence_GetItem(fields, i);

if (PyObject_HasAttr(self, key))
a = PyObject_GetAttr(self, key);
if (PyObject_HasAttr(other, key))
b = PyObject_GetAttr(other, key);

/* Check filed value type */
if (Py_TYPE(a) != Py_TYPE(b)) {
if (op == Py_EQ) {
Py_RETURN_FALSE;
}
}

if (op == Py_EQ) {
if (!PyObject_RichCompareBool(a, b, Py_EQ)) {
Py_RETURN_FALSE;
}
}
else if (op == Py_NE) {
if (PyObject_RichCompareBool(a, b, Py_NE)) {
Py_RETURN_TRUE;
}
}
}

if (op == Py_EQ)
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
}

static PyMethodDef ast_type_methods[] = {
{"__reduce__", ast_type_reduce, METH_NOARGS, NULL},
{NULL}
Expand Down Expand Up @@ -641,7 +698,7 @@ static PyTypeObject AST_type = {
0, /* tp_doc */
(traverseproc)ast_traverse, /* tp_traverse */
(inquiry)ast_clear, /* tp_clear */
0, /* tp_richcompare */
ast_richcompare, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
Expand Down