Skip to content

Commit b7bc09e

Browse files
committed
auto mark parent
1 parent 860094c commit b7bc09e

File tree

2 files changed

+79
-4
lines changed

2 files changed

+79
-4
lines changed

scripts/auto_mark_test.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def parse_results(result):
8383
test_results.stdout = result.stdout
8484
in_test_results = False
8585
for line in lines:
86-
if re.match(r"Run tests? sequentially", line):
86+
if re.search(r"Run \d+ tests? sequentially", line):
8787
in_test_results = True
8888
elif line.startswith("-----------"):
8989
in_test_results = False
@@ -161,6 +161,66 @@ def is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> boo
161161
return True
162162

163163

164+
def build_inheritance_info(tree: ast.Module) -> tuple[dict, dict]:
165+
"""
166+
Build inheritance information from AST.
167+
168+
Returns:
169+
class_bases: dict[str, list[str]] - parent classes for each class (only those defined in the file)
170+
class_methods: dict[str, set[str]] - methods directly defined in each class
171+
"""
172+
all_classes = {
173+
node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
174+
}
175+
class_bases = {}
176+
class_methods = {}
177+
178+
for node in ast.walk(tree):
179+
if isinstance(node, ast.ClassDef):
180+
# Collect only parent classes defined in this file
181+
bases = [
182+
base.id
183+
for base in node.bases
184+
if isinstance(base, ast.Name) and base.id in all_classes
185+
]
186+
class_bases[node.name] = bases
187+
188+
# Collect directly defined methods
189+
methods = {
190+
item.name
191+
for item in node.body
192+
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
193+
}
194+
class_methods[node.name] = methods
195+
196+
return class_bases, class_methods
197+
198+
199+
def find_method_definition(
200+
class_name: str, method_name: str, class_bases: dict, class_methods: dict
201+
) -> str | None:
202+
"""Find the class where a method is actually defined. Traverses inheritance chain (BFS)."""
203+
# Check current class first
204+
if method_name in class_methods.get(class_name, set()):
205+
return class_name
206+
207+
# Search parent classes
208+
visited = set()
209+
queue = list(class_bases.get(class_name, []))
210+
211+
while queue:
212+
current = queue.pop(0)
213+
if current in visited:
214+
continue
215+
visited.add(current)
216+
217+
if method_name in class_methods.get(current, set()):
218+
return current
219+
queue.extend(class_bases.get(current, []))
220+
221+
return None
222+
223+
164224
def remove_expected_failures(
165225
contents: str, tests_to_remove: set[tuple[str, str]]
166226
) -> str:
@@ -172,6 +232,18 @@ def remove_expected_failures(
172232
lines = contents.splitlines()
173233
lines_to_remove = set()
174234

235+
# Build inheritance information
236+
class_bases, class_methods = build_inheritance_info(tree)
237+
238+
# Resolve to actual defining classes
239+
resolved_tests = set()
240+
for class_name, method_name in tests_to_remove:
241+
defining_class = find_method_definition(
242+
class_name, method_name, class_bases, class_methods
243+
)
244+
if defining_class:
245+
resolved_tests.add((defining_class, method_name))
246+
175247
for node in ast.walk(tree):
176248
if not isinstance(node, ast.ClassDef):
177249
continue
@@ -180,7 +252,7 @@ def remove_expected_failures(
180252
if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
181253
continue
182254
method_name = item.name
183-
if (class_name, method_name) not in tests_to_remove:
255+
if (class_name, method_name) not in resolved_tests:
184256
continue
185257

186258
# Check if we should remove the entire method (super() call only)

scripts/lib_updater.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,14 @@ def build_patch_dict(it: "Iterator[PatchEntry]") -> Patches:
236236

237237

238238
def iter_patch_lines(tree: ast.Module, patches: Patches) -> "Iterator[tuple[int, str]]":
239-
cache = {} # Used in phase 2. Stores the end line location of a class name.
239+
# Build cache of all classes (for Phase 2 to find classes without methods)
240+
cache = {}
241+
for node in tree.body:
242+
if isinstance(node, ast.ClassDef):
243+
cache[node.name] = node.end_lineno
240244

241245
# Phase 1: Iterate and mark existing tests
242246
for cls_node, fn_node in iter_tests(tree):
243-
cache[cls_node.name] = cls_node.end_lineno
244247
specs = patches.get(cls_node.name, {}).pop(fn_node.name, None)
245248
if not specs:
246249
continue

0 commit comments

Comments
 (0)