Skip to content

Commit 901b43c

Browse files
committed
test: Add tests for recursive model equality to prevent infinite recursion
1 parent 2e2d973 commit 901b43c

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

tests/test_recursive_eq.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel
4+
5+
6+
def test_recursive_model_equality():
7+
"""Test that comparing models with self-references doesn't cause infinite recursion."""
8+
9+
class RecursiveModel(BaseModel):
10+
value: int
11+
parent: Optional['RecursiveModel'] = None
12+
13+
# Create a model with a reference to itself
14+
model = RecursiveModel(value=1)
15+
model.parent = model
16+
17+
# This should not cause infinite recursion
18+
assert model == model
19+
20+
# Create another model with the same structure
21+
model2 = RecursiveModel(value=1)
22+
model2.parent = model2
23+
24+
# These models should be equal
25+
assert model == model2
26+
27+
# Create a model with a different value
28+
model3 = RecursiveModel(value=2)
29+
model3.parent = model3
30+
31+
# These models should not be equal
32+
assert model != model3
33+
34+
35+
def test_recursive_model_complex_cycle():
36+
"""Test that comparing models with complex reference cycles doesn't cause infinite recursion."""
37+
38+
class Node(BaseModel):
39+
value: int
40+
children: list['Node'] = []
41+
42+
# Create a cycle: root -> child1 -> child2 -> root
43+
root = Node(value=1)
44+
child1 = Node(value=2)
45+
child2 = Node(value=3)
46+
47+
root.children = [child1]
48+
child1.children = [child2]
49+
child2.children = [root]
50+
51+
# This should not cause infinite recursion
52+
assert root == root
53+
54+
# Create another identical structure
55+
root2 = Node(value=1)
56+
child1_2 = Node(value=2)
57+
child2_2 = Node(value=3)
58+
59+
root2.children = [child1_2]
60+
child1_2.children = [child2_2]
61+
child2_2.children = [root2]
62+
63+
# These should be equal
64+
assert root == root2
65+
66+
# Create a structure with a different value
67+
root3 = Node(value=4)
68+
child1_3 = Node(value=2)
69+
child2_3 = Node(value=3)
70+
71+
root3.children = [child1_3]
72+
child1_3.children = [child2_3]
73+
child2_3.children = [root3]
74+
75+
# These should not be equal
76+
assert root != root3

0 commit comments

Comments
 (0)