Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 140 additions & 11 deletions scripts/fix_test.py → scripts/auto_mark_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,17 @@
"""

import argparse
import ast
import shutil
import sys
from pathlib import Path

from lib_updater import PatchSpec, UtMethod, apply_patches
from lib_updater import (
COMMENT,
PatchSpec,
UtMethod,
apply_patches,
)


def parse_args():
Expand Down Expand Up @@ -61,15 +67,18 @@ def __str__(self):
class TestResult:
tests_result: str = ""
tests = []
unexpected_successes = [] # Tests that passed but were marked as expectedFailure
stdout = ""

def __str__(self):
return f"TestResult(tests_result={self.tests_result},tests={len(self.tests)})"
return f"TestResult(tests_result={self.tests_result},tests={len(self.tests)},unexpected_successes={len(self.unexpected_successes)})"


def parse_results(result):
lines = result.stdout.splitlines()
test_results = TestResult()
test_results.tests = []
test_results.unexpected_successes = []
test_results.stdout = result.stdout
in_test_results = False
for line in lines:
Expand Down Expand Up @@ -107,6 +116,19 @@ def parse_results(result):
res = line.split("== Tests result: ")[1]
res = res.split(" ")[0]
test_results.tests_result = res
# Parse: "UNEXPECTED SUCCESS: test_name (path)"
elif line.startswith("UNEXPECTED SUCCESS: "):
rest = line[len("UNEXPECTED SUCCESS: ") :]
# Format: "test_name (path)"
first_space = rest.find(" ")
if first_space > 0:
test = Test()
test.name = rest[:first_space]
path_part = rest[first_space:].strip()
if path_part.startswith("(") and path_part.endswith(")"):
test.path = path_part[1:-1]
test.result = "unexpected_success"
test_results.unexpected_successes.append(test)
return test_results


Expand All @@ -117,6 +139,95 @@ def path_to_test(path) -> list[str]:
return parts[-2:] # Get class name and method name


def is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
"""Check if the method body is just 'return super().method_name()'."""
if len(func_node.body) != 1:
return False
stmt = func_node.body[0]
if not isinstance(stmt, ast.Return) or stmt.value is None:
return False
# Check for super().method_name() pattern
call = stmt.value
if not isinstance(call, ast.Call):
return False
if not isinstance(call.func, ast.Attribute):
return False
super_call = call.func.value
if not isinstance(super_call, ast.Call):
return False
if not isinstance(super_call.func, ast.Name) or super_call.func.id != "super":
return False
return True


def remove_expected_failures(
contents: str, tests_to_remove: set[tuple[str, str]]
) -> str:
"""Remove @unittest.expectedFailure decorators from tests that now pass."""
if not tests_to_remove:
return contents

tree = ast.parse(contents)
lines = contents.splitlines()
lines_to_remove = set()

for node in ast.walk(tree):
if not isinstance(node, ast.ClassDef):
continue
class_name = node.name
for item in node.body:
if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
method_name = item.name
if (class_name, method_name) not in tests_to_remove:
continue

# Check if we should remove the entire method (super() call only)
remove_entire_method = is_super_call_only(item)

if remove_entire_method:
# Remove entire method including decorators and any preceding comment
first_line = item.lineno - 1 # 0-indexed, def line
if item.decorator_list:
first_line = item.decorator_list[0].lineno - 1
# Check for TODO comment before first decorator/def
if first_line > 0:
prev_line = lines[first_line - 1].strip()
if prev_line.startswith("#") and COMMENT in prev_line:
first_line -= 1
# Remove from first_line to end_lineno (inclusive)
for i in range(first_line, item.end_lineno):
lines_to_remove.add(i)
else:
# Only remove the expectedFailure decorator
for dec in item.decorator_list:
dec_line = dec.lineno - 1 # 0-indexed
line_content = lines[dec_line]

# Check if it's @unittest.expectedFailure
if "expectedFailure" not in line_content:
continue

# Check if TODO: RUSTPYTHON is on the same line or the line before
has_comment_on_line = COMMENT in line_content
has_comment_before = (
dec_line > 0
and lines[dec_line - 1].strip().startswith("#")
and COMMENT in lines[dec_line - 1]
)

if has_comment_on_line or has_comment_before:
lines_to_remove.add(dec_line)
if has_comment_before:
lines_to_remove.add(dec_line - 1)

# Remove lines in reverse order to maintain line numbers
for line_idx in sorted(lines_to_remove, reverse=True):
del lines[line_idx]

return "\n".join(lines) + "\n" if lines else ""


def build_patches(test_parts_set: set[tuple[str, str]]) -> dict:
"""Convert failing tests to lib_updater patch format."""
patches = {}
Expand Down Expand Up @@ -190,20 +301,38 @@ def run_test(test_name):
f = test_path.read_text(encoding="utf-8")

# Collect failing tests (with deduplication for subtests)
seen_tests = set() # Track (class_name, method_name) to avoid duplicates
failing_tests = set() # Track (class_name, method_name) to avoid duplicates
for test in tests.tests:
if test.result == "fail" or test.result == "error":
test_parts = path_to_test(test.path)
if len(test_parts) == 2:
test_key = tuple(test_parts)
if test_key not in seen_tests:
seen_tests.add(test_key)
print(f"Marking test: {test_parts[0]}.{test_parts[1]}")

# Apply patches using lib_updater
if seen_tests:
patches = build_patches(seen_tests)
if test_key not in failing_tests:
failing_tests.add(test_key)
print(f"Marking as failing: {test_parts[0]}.{test_parts[1]}")

# Collect unexpected successes (tests that now pass but have expectedFailure)
unexpected_successes = set()
for test in tests.unexpected_successes:
test_parts = path_to_test(test.path)
if len(test_parts) == 2:
test_key = tuple(test_parts)
if test_key not in unexpected_successes:
unexpected_successes.add(test_key)
print(f"Removing expectedFailure: {test_parts[0]}.{test_parts[1]}")

# Remove expectedFailure from tests that now pass
if unexpected_successes:
f = remove_expected_failures(f, unexpected_successes)

# Apply patches for failing tests
if failing_tests:
patches = build_patches(failing_tests)
f = apply_patches(f, patches)

# Write changes if any modifications were made
if failing_tests or unexpected_successes:
test_path.write_text(f, encoding="utf-8")

print(f"Modified {len(seen_tests)} tests")
print(f"Added expectedFailure to {len(failing_tests)} tests")
print(f"Removed expectedFailure from {len(unexpected_successes)} tests")
Loading