Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 57 additions & 18 deletions scripts/update_lib/cmd_auto_mark.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,53 +71,89 @@ def run_test(test_name: str, skip_build: bool = False) -> TestResult:
return parse_results(result)


def _try_parse_test_info(test_info: str) -> tuple[str, str] | None:
"""Try to extract (name, path) from 'test_name (path)' or 'test_name (path) [subtest]'."""
first_space = test_info.find(" ")
if first_space > 0:
name = test_info[:first_space]
rest = test_info[first_space:].strip()
if rest.startswith("("):
end_paren = rest.find(")")
if end_paren > 0:
return name, rest[1:end_paren]
return None


def parse_results(result: subprocess.CompletedProcess) -> TestResult:
"""Parse subprocess result into TestResult."""
lines = result.stdout.splitlines()
test_results = TestResult()
test_results.stdout = result.stdout
in_test_results = False
# For multiline format: "test_name (path)\ndocstring ... RESULT"
pending_test_info = None

for line in lines:
if re.search(r"Run \d+ tests? sequentially", line):
in_test_results = True
elif line.startswith("-----------"):
elif "== Tests result: " in line:
in_test_results = False

if in_test_results and " ... " in line:
line = line.strip()
stripped = line.strip()
# Skip lines that don't look like test results
if line.startswith("tests") or line.startswith("["):
if stripped.startswith("tests") or stripped.startswith("["):
pending_test_info = None
continue
# Parse: "test_name (path) [subtest] ... RESULT"
parts = line.split(" ... ")
parts = stripped.split(" ... ")
if len(parts) >= 2:
test_info = parts[0]
result_str = parts[-1].lower()
# Only process FAIL or ERROR
if result_str not in ("fail", "error"):
pending_test_info = None
continue
# Extract test name (first word)
first_space = test_info.find(" ")
if first_space > 0:
# Try parsing from this line (single-line format)
parsed = _try_parse_test_info(test_info)
if not parsed and pending_test_info:
# Multiline format: previous line had test_name (path)
parsed = _try_parse_test_info(pending_test_info)
if parsed:
test = Test()
test.name = test_info[:first_space]
# Extract path from (path)
rest = test_info[first_space:].strip()
if rest.startswith("("):
end_paren = rest.find(")")
if end_paren > 0:
test.path = rest[1:end_paren]
test.result = result_str
test_results.tests.append(test)
test.name, test.path = parsed
test.result = result_str
test_results.tests.append(test)
pending_test_info = None

elif in_test_results:
# Track test info for multiline format:
# test_name (path)
# docstring ... RESULT
stripped = line.strip()
if (
stripped
and "(" in stripped
and stripped.endswith(")")
and ":" not in stripped.split("(")[0]
):
pending_test_info = stripped
else:
pending_test_info = None

# Also check for Tests result on non-" ... " lines
if "== Tests result: " in line:
res = line.split("== Tests result: ")[1]
res = res.split(" ")[0]
test_results.tests_result = res

elif "== Tests result: " in line:
res = line.split("== Tests result: ")[1]
res = res.split(" ")[0]
test_results.tests_result = res

# Parse: "UNEXPECTED SUCCESS: test_name (path)"
elif line.startswith("UNEXPECTED SUCCESS: "):
if line.startswith("UNEXPECTED SUCCESS: "):
rest = line[len("UNEXPECTED SUCCESS: ") :]
# Format: "test_name (path)"
first_space = rest.find(" ")
Expand Down Expand Up @@ -232,13 +268,16 @@ def build_patches(


def _is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
"""Check if the method body is just 'return super().method_name()'."""
"""Check if the method body is just 'return super().method_name()' or 'return await super().method_name()'."""
if len(func_node.body) != 1:
return False
stmt = func_node.body[0]
if not isinstance(stmt, ast.Return) or stmt.value is None:
return False
call = stmt.value
# Unwrap await for async methods
if isinstance(call, ast.Await):
call = call.value
if not isinstance(call, ast.Call):
return False
if not isinstance(call.func, ast.Attribute):
Expand Down
18 changes: 17 additions & 1 deletion scripts/update_lib/patch_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,17 @@ def _iter_patch_lines(

# Build cache of all classes (for Phase 2 to find classes without methods)
cache = {}
# Build per-class set of async method names (for Phase 2 to generate correct override)
async_methods: dict[str, set[str]] = {}
for node in tree.body:
if isinstance(node, ast.ClassDef):
cache[node.name] = node.end_lineno
cls_async: set[str] = set()
for item in node.body:
if isinstance(item, ast.AsyncFunctionDef):
cls_async.add(item.name)
if cls_async:
async_methods[node.name] = cls_async

# Phase 1: Iterate and mark existing tests
for cls_node, fn_node in iter_tests(tree):
Expand All @@ -274,7 +282,15 @@ def _iter_patch_lines(

for test_name, specs in tests.items():
decorators = "\n".join(spec.as_decorator() for spec in specs)
patch_lines = f"""
is_async = test_name in async_methods.get(cls_name, set())
if is_async:
patch_lines = f"""
{decorators}
async def {test_name}(self):
{DEFAULT_INDENT}return await super().{test_name}()
""".rstrip()
else:
patch_lines = f"""
{decorators}
def {test_name}(self):
{DEFAULT_INDENT}return super().{test_name}()
Expand Down
166 changes: 166 additions & 0 deletions scripts/update_lib/tests/test_auto_mark.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,76 @@ def test_parse_error_message(self):
self.assertEqual(len(result.tests), 1)
self.assertEqual(result.tests[0].error_message, "AssertionError: 1 != 2")

def test_parse_directory_test_multiple_submodules(self):
"""Test parsing directory test output with multiple submodules.

When running a directory test (e.g., test_asyncio), the output contains
multiple submodules separated by '------' lines. Failures in submodules
after the first one must still be detected.
"""
stdout = """\
Run 3 tests sequentially
0:00:00 [ 1/3] test_asyncio.test_buffered_proto
test_ok (test.test_asyncio.test_buffered_proto.TestProto.test_ok) ... ok

----------------------------------------------------------------------
Ran 1 tests in 0.1s

OK

0:00:01 [ 2/3] test_asyncio.test_events
test_create (test.test_asyncio.test_events.TestEvents.test_create) ... FAIL

----------------------------------------------------------------------
Ran 1 tests in 0.2s

FAILED (failures=1)

0:00:02 [ 3/3] test_asyncio.test_tasks
test_gather (test.test_asyncio.test_tasks.TestTasks.test_gather) ... ERROR

----------------------------------------------------------------------
Ran 1 tests in 0.3s

FAILED (errors=1)

== Tests result: FAILURE ==
"""
result = parse_results(self._make_result(stdout))
self.assertEqual(len(result.tests), 2)
names = {t.name for t in result.tests}
self.assertIn("test_create", names)
self.assertIn("test_gather", names)
# Verify results
test_create = next(t for t in result.tests if t.name == "test_create")
test_gather = next(t for t in result.tests if t.name == "test_gather")
self.assertEqual(test_create.result, "fail")
self.assertEqual(test_gather.result, "error")
self.assertEqual(result.tests_result, "FAILURE")

def test_parse_multiline_test_with_docstring(self):
"""Test parsing tests where docstring appears on a separate line.

Some tests have docstrings that cause the output to span two lines:
test_name (path)
docstring ... ERROR
"""
stdout = """\
Run 3 tests sequentially
test_ok (test.test_example.TestClass.test_ok) ... ok
test_with_doc (test.test_example.TestClass.test_with_doc)
Test that something works ... ERROR
test_normal_fail (test.test_example.TestClass.test_normal_fail) ... FAIL
"""
result = parse_results(self._make_result(stdout))
self.assertEqual(len(result.tests), 2)
names = {t.name for t in result.tests}
self.assertIn("test_with_doc", names)
self.assertIn("test_normal_fail", names)
test_doc = next(t for t in result.tests if t.name == "test_with_doc")
self.assertEqual(test_doc.path, "test.test_example.TestClass.test_with_doc")
self.assertEqual(test_doc.result, "error")

def test_parse_multiple_error_messages(self):
"""Test parsing multiple error messages."""
stdout = """
Expand Down Expand Up @@ -644,6 +714,102 @@ def test_one(self):
method = self._parse_method(code)
self.assertFalse(_is_super_call_only(method))

def test_async_await_super_call(self):
"""Test async method that awaits super().same_name()."""
code = """
class Foo:
async def test_one(self):
return await super().test_one()
"""
method = self._parse_method(code)
self.assertTrue(_is_super_call_only(method))

def test_async_await_mismatched_super_call(self):
"""Test async method that awaits super().different_name()."""
code = """
class Foo:
async def test_one(self):
return await super().test_two()
"""
method = self._parse_method(code)
self.assertFalse(_is_super_call_only(method))

def test_async_without_await(self):
"""Test async method that calls super() without await (sync super call in async method)."""
code = """
class Foo:
async def test_one(self):
return super().test_one()
"""
method = self._parse_method(code)
self.assertTrue(_is_super_call_only(method))


class TestAsyncInheritedOverride(unittest.TestCase):
"""Tests for async inherited method override generation."""

def test_inherited_async_method_generates_async_override(self):
"""Test that inherited async methods get async def + await override."""
code = """import unittest

class BaseTest:
async def test_async_one(self):
pass

class TestChild(BaseTest, unittest.TestCase):
pass
"""
failing = {("TestChild", "test_async_one")}
result = apply_test_changes(code, failing, set())

self.assertIn("async def test_async_one(self):", result)
self.assertIn("return await super().test_async_one()", result)
self.assertIn("@unittest.expectedFailure", result)

def test_inherited_sync_method_generates_sync_override(self):
"""Test that inherited sync methods get sync def override."""
code = """import unittest

class BaseTest:
def test_sync_one(self):
pass

class TestChild(BaseTest, unittest.TestCase):
pass
"""
failing = {("TestChild", "test_sync_one")}
result = apply_test_changes(code, failing, set())

self.assertIn("def test_sync_one(self):", result)
self.assertIn("return super().test_sync_one()", result)
self.assertNotIn("async def test_sync_one", result)
self.assertNotIn("await", result)

def test_remove_async_super_call_override(self):
"""Test removing async super call override on unexpected success."""
code = f"""import unittest

class BaseTest:
async def test_async_one(self):
pass

class TestChild(BaseTest, unittest.TestCase):
# {COMMENT}
@unittest.expectedFailure
async def test_async_one(self):
return await super().test_async_one()
"""
successes = {("TestChild", "test_async_one")}
result = apply_test_changes(code, set(), successes)

# The override in TestChild should be removed; base class method remains
self.assertNotIn("return await super().test_async_one()", result)
self.assertNotIn("@unittest.expectedFailure", result)
self.assertIn("class TestChild", result)
# Base class method should still be present
self.assertIn("class BaseTest", result)
self.assertIn("async def test_async_one(self):", result)


if __name__ == "__main__":
unittest.main()
Loading