-
Notifications
You must be signed in to change notification settings - Fork 1.4k
scripts/update_lib migrate to preserve patches on classes
#8057
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2029b9c
7a60d03
e774940
4ee79f7
1121618
d042676
addefe9
2bf0b6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,9 +6,13 @@ | |
| - Applying patches to test files (JSON -> file) | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import ast | ||
| import collections | ||
| import contextlib | ||
| import enum | ||
| import re | ||
| import textwrap | ||
| import typing | ||
|
|
||
|
|
@@ -91,33 +95,109 @@ def as_decorator(self) -> str: | |
|
|
||
| return f"@{unparsed}" | ||
|
|
||
| @classmethod | ||
| def try_from_ast_node( | ||
| cls, node: ast.Attribute | ast.Call, lines: list[str] | ||
| ) -> typing.Self | None: | ||
| if isinstance(node, ast.Attribute): | ||
| attr_node = node | ||
| elif isinstance(node, ast.Call): | ||
| attr_node = node.func | ||
| else: | ||
| return | ||
|
|
||
| if ( | ||
| isinstance(attr_node, ast.Name) | ||
| or getattr(attr_node.value, "id", None) != UT | ||
| ): | ||
| return | ||
|
|
||
| cond = None | ||
| try: | ||
| ut_method = UtMethod(attr_node.attr) | ||
| except ValueError: | ||
| return | ||
|
|
||
| # If our ut_method has args then, | ||
| # we need to search for a constant that contains our `COMMENT`. | ||
| # Otherwise we need to search it in the raw source code :/ | ||
| if ut_method.has_args(): | ||
| reason = next( | ||
| ( | ||
| inner_node.value | ||
| for inner_node in ast.walk(node) | ||
| if isinstance(inner_node, ast.Constant) | ||
| and isinstance(inner_node.value, str) | ||
| and COMMENT in inner_node.value | ||
| ), | ||
| None, | ||
| ) | ||
|
|
||
| def _single_to_double_quotes(s: str) -> str: | ||
| """Convert single-quoted strings to double-quoted strings. | ||
| # If we didn't find a constant containing <COMMENT>, | ||
| # then we didn't put this decorator | ||
| if not reason: | ||
| return | ||
|
|
||
| if ut_method.has_cond(): | ||
| cond = ast.unparse(node.args[0]) | ||
| else: | ||
| pattern = re.compile(rf"{COMMENT}.?(.*)") | ||
| dec_lineno = node.lineno | ||
|
|
||
| curr_line = lines[dec_lineno - 1] | ||
| prev_line = lines[dec_lineno - 2] | ||
|
|
||
|
Comment on lines
+145
to
+149
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Guard previous-line lookup at file start. Line 148 reads Proposed fix- prev_line = lines[dec_lineno - 2]
+ prev_line = lines[dec_lineno - 2] if dec_lineno > 1 else ""🤖 Prompt for AI Agents |
||
| # If we see our comment at the decorator line, take it | ||
| if found := pattern.search(curr_line): | ||
| reason = found.group() | ||
| elif prev_line.strip().startswith("#") and ( | ||
| found := pattern.search(prev_line) | ||
| ): | ||
| # Search the previous line of the decorator, | ||
| # only take the comment if the line starts with a `#` | ||
| reason = found.group() | ||
| else: | ||
| # Didn't find our `COMMENT`, so the patch isn't ours :) | ||
| return | ||
|
|
||
| Falls back to original if conversion breaks the AST equivalence. | ||
| """ | ||
| import re | ||
| reason = reason.removeprefix(COMMENT).strip(";:, ") | ||
| return cls(ut_method, cond, reason) | ||
|
|
||
| def replace_string(match: re.Match) -> str: | ||
| content = match.group(1) | ||
| # Unescape single quotes and escape double quotes | ||
| content = content.replace("\\'", "'").replace('"', '\\"') | ||
| return f'"{content}"' | ||
|
|
||
| # Match single-quoted strings (handles escaped single quotes inside) | ||
| converted = re.sub(r"'((?:[^'\\]|\\.)*)'", replace_string, s) | ||
| class PatchEntryVisitor(ast.NodeVisitor): | ||
| def __init__(self, lines: list[str]): | ||
| self.current_class = None | ||
| self.patches = [] | ||
| self.lines = lines | ||
|
|
||
| # Verify: parse converted and unparse should equal original | ||
| try: | ||
| converted_ast = ast.parse(converted, mode="eval") | ||
| if ast.unparse(converted_ast) == s: | ||
| return converted | ||
| except SyntaxError: | ||
| pass | ||
| def patches_from_node( | ||
| self, node: ast.FunctionDef | ast.AsyncFunctionDef | ||
| ) -> Iterator[PatchEntry]: | ||
| for dec_node in node.decorator_list: | ||
| spec = PatchSpec.try_from_ast_node(dec_node, self.lines) | ||
|
|
||
| # Fall back to original if conversion failed | ||
| return s | ||
| if spec is None: | ||
| continue | ||
|
|
||
| yield PatchEntry(self.current_class, node.name, spec) | ||
|
|
||
| def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): | ||
| self.patches.extend(self.patches_from_node(node)) | ||
| # TODO: Support nested classes/methods | ||
| # self.generic_visit(node) | ||
|
|
||
| def visit_FunctionDef(self, node: ast.FunctionDef): | ||
| self.patches.extend(self.patches_from_node(node)) | ||
| # TODO: Support nested classes/methods | ||
| # self.generic_visit(node) | ||
|
|
||
| def visit_ClassDef(self, node: ast.ClassDef): | ||
| with temp_attr(self, "current_class", node.name): | ||
| for patch in self.patches_from_node(node): | ||
| patch = patch._replace(test_name="__self__") | ||
| self.patches.append(patch) | ||
|
|
||
| self.generic_visit(node) | ||
|
|
||
|
|
||
| class PatchEntry(typing.NamedTuple): | ||
|
|
@@ -142,76 +222,10 @@ class PatchEntry(typing.NamedTuple): | |
| def iter_patch_entries( | ||
| cls, tree: ast.Module, lines: list[str] | ||
| ) -> "Iterator[typing.Self]": | ||
| import re | ||
| import sys | ||
|
|
||
| for cls_node, fn_node in iter_tests(tree): | ||
| parent_class = cls_node.name | ||
| for dec_node in fn_node.decorator_list: | ||
| if not isinstance(dec_node, (ast.Attribute, ast.Call)): | ||
| continue | ||
|
|
||
| attr_node = ( | ||
| dec_node if isinstance(dec_node, ast.Attribute) else dec_node.func | ||
| ) | ||
|
|
||
| if ( | ||
| isinstance(attr_node, ast.Name) | ||
| or getattr(attr_node.value, "id", None) != UT | ||
| ): | ||
| continue | ||
|
|
||
| cond = None | ||
| try: | ||
| ut_method = UtMethod(attr_node.attr) | ||
| except ValueError: | ||
| continue | ||
|
|
||
| # If our ut_method has args then, | ||
| # we need to search for a constant that contains our `COMMENT`. | ||
| # Otherwise we need to search it in the raw source code :/ | ||
| if ut_method.has_args(): | ||
| reason = next( | ||
| ( | ||
| node.value | ||
| for node in ast.walk(dec_node) | ||
| if isinstance(node, ast.Constant) | ||
| and isinstance(node.value, str) | ||
| and COMMENT in node.value | ||
| ), | ||
| None, | ||
| ) | ||
|
|
||
| # If we didn't find a constant containing <COMMENT>, | ||
| # then we didn't put this decorator | ||
| if not reason: | ||
| continue | ||
|
|
||
| if ut_method.has_cond(): | ||
| cond = ast.unparse(dec_node.args[0]) | ||
| else: | ||
| pattern = re.compile(rf"{COMMENT}.?(.*)") | ||
| dec_lineno = dec_node.lineno | ||
|
|
||
| curr_line = lines[dec_lineno - 1] | ||
| prev_line = lines[dec_lineno - 2] | ||
|
|
||
| # If we see our comment at the decorator line, take it | ||
| if found := pattern.search(curr_line): | ||
| reason = found.group() | ||
| elif prev_line.strip().startswith("#") and ( | ||
| found := pattern.search(prev_line) | ||
| ): | ||
| # Search the previous line of the decorator, | ||
| # only take the comment if the line starts with a `#` | ||
| reason = found.group() | ||
| else: | ||
| # Didn't find our `COMMENT`, so the patch isn't ours :) | ||
| continue | ||
|
|
||
| reason = reason.removeprefix(COMMENT).strip(";:, ") | ||
| spec = PatchSpec(ut_method, cond, reason) | ||
| yield cls(parent_class, fn_node.name, spec) | ||
| visitor = PatchEntryVisitor(lines) | ||
| visitor.visit(tree) | ||
| yield from visitor.patches | ||
|
|
||
|
|
||
| def iter_tests( | ||
|
|
@@ -251,6 +265,15 @@ def extract_patches(contents: str) -> Patches: | |
| return build_patch_dict(iter_patches(contents)) | ||
|
|
||
|
|
||
| def modification_from_node_specs(node, specs): | ||
| lineno = min( | ||
| (dec_node.lineno for dec_node in node.decorator_list), default=node.lineno | ||
| ) | ||
| indent = " " * node.col_offset | ||
| patch_lines = "\n".join(spec.as_decorator() for spec in specs) | ||
| return (lineno - 1, textwrap.indent(patch_lines, indent)) | ||
|
|
||
|
|
||
| def _iter_patch_lines( | ||
| tree: ast.Module, patches: Patches | ||
| ) -> "Iterator[tuple[int, str]]": | ||
|
|
@@ -262,7 +285,15 @@ def _iter_patch_lines( | |
| async_methods: dict[str, set[str]] = {} | ||
| # Track class bases for inherited async method lookup | ||
| class_bases: dict[str, list[str]] = {} | ||
| all_classes = {node.name for node in tree.body if isinstance(node, ast.ClassDef)} | ||
| all_classes = set() | ||
| all_class_nodes = [] | ||
| for node in tree.body: | ||
| if not isinstance(node, ast.ClassDef): | ||
| continue | ||
|
|
||
| all_classes.add(node.name) | ||
| all_class_nodes.append(node) | ||
|
|
||
| for node in tree.body: | ||
| if isinstance(node, ast.ClassDef): | ||
| cache[node.name] = node.end_lineno | ||
|
|
@@ -284,13 +315,7 @@ def _iter_patch_lines( | |
| if not specs: | ||
| continue | ||
|
|
||
| lineno = min( | ||
| (dec_node.lineno for dec_node in fn_node.decorator_list), | ||
| default=fn_node.lineno, | ||
| ) | ||
| indent = " " * fn_node.col_offset | ||
| patch_lines = "\n".join(spec.as_decorator() for spec in specs) | ||
| yield (lineno - 1, textwrap.indent(patch_lines, indent)) | ||
| yield modification_from_node_specs(fn_node, specs) | ||
|
|
||
| # Phase 2: Iterate and mark inherited tests | ||
| for cls_name, tests in sorted(patches.items()): | ||
|
|
@@ -300,6 +325,10 @@ def _iter_patch_lines( | |
| continue | ||
|
|
||
| for test_name, specs in sorted(tests.items()): | ||
| if test_name == "__self__": | ||
| # Yielding modifications for the class itself should be done during phase 3 | ||
| continue | ||
|
|
||
| decorators = "\n".join(spec.as_decorator() for spec in specs) | ||
| # Check current class and ancestors for async method | ||
| is_async = False | ||
|
|
@@ -314,6 +343,7 @@ def _iter_patch_lines( | |
| is_async = True | ||
| break | ||
| queue.extend(class_bases.get(cur, [])) | ||
|
|
||
| if is_async: | ||
| patch_lines = f""" | ||
| {decorators} | ||
|
|
@@ -328,6 +358,11 @@ def {test_name}(self): | |
| """.rstrip() | ||
| yield (lineno, textwrap.indent(patch_lines, DEFAULT_INDENT)) | ||
|
|
||
| # Phase 3: Mark the class itself | ||
| for cls_node in all_class_nodes: | ||
| if cls_specs := patches.get(cls_node.name, {}).pop("__self__", None): | ||
| yield modification_from_node_specs(cls_node, cls_specs) | ||
|
|
||
|
|
||
| def _has_unittest_import(tree: ast.Module) -> bool: | ||
| """Check if 'import unittest' is already present in the file.""" | ||
|
|
@@ -406,3 +441,42 @@ def patches_from_json(data: dict) -> Patches: | |
| } | ||
| for cls_name, tests in data.items() | ||
| } | ||
|
|
||
|
|
||
| def _single_to_double_quotes(s: str) -> str: | ||
| """ | ||
| Convert single-quoted strings to double-quoted strings. | ||
|
|
||
| Falls back to original if conversion breaks the AST equivalence. | ||
| """ | ||
| import re | ||
|
|
||
| def replace_string(match: re.Match) -> str: | ||
| content = match.group(1) | ||
| # Unescape single quotes and escape double quotes | ||
| content = content.replace("\\'", "'").replace('"', '\\"') | ||
| return f'"{content}"' | ||
|
|
||
| # Match single-quoted strings (handles escaped single quotes inside) | ||
| converted = re.sub(r"'((?:[^'\\]|\\.)*)'", replace_string, s) | ||
|
|
||
| # Verify: parse converted and unparse should equal original | ||
| try: | ||
| converted_ast = ast.parse(converted, mode="eval") | ||
| if ast.unparse(converted_ast) == s: | ||
| return converted | ||
| except SyntaxError: | ||
| pass | ||
|
|
||
| # Fall back to original if conversion failed | ||
| return s | ||
|
|
||
|
|
||
| @contextlib.contextmanager | ||
| def temp_attr(obj: object, attr: str, value: object): | ||
| old = getattr(obj, attr, None) | ||
| setattr(obj, attr, value) | ||
| try: | ||
| yield obj | ||
| finally: | ||
| setattr(obj, attr, old) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -345,6 +345,32 @@ def test_one(self): | |
| self.assertIn("@unittest.expectedFailure", result) | ||
| self.assertIn(COMMENT, result) | ||
|
|
||
| def test_round_trip_with_patches_on_class(self): | ||
| """Test that extracted patches can be re-applied.""" | ||
| original = f"""import unittest | ||
|
|
||
| @unittest.skipIf(a == b, "{COMMENT}") | ||
| @unittest.expectedFailure # {COMMENT} | ||
| class TestFoo(unittest.TestCase): | ||
| ... | ||
| """ | ||
| # Extract patches | ||
| patches = extract_patches(original) | ||
|
|
||
| # Apply to clean code | ||
| clean = """import unittest | ||
|
|
||
| class TestFoo(unittest.TestCase): | ||
| def test_one(self): | ||
| pass | ||
| """ | ||
| result = apply_patches(clean, patches) | ||
|
|
||
| # Should have the decorator | ||
| self.assertIn("@unittest.expectedFailure", result) | ||
| self.assertIn("@unittest.skipIf", result) | ||
| self.assertIn(COMMENT, result) | ||
|
|
||
|
Comment on lines
+348
to
+373
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test addition violates the repository’s test-file change policy. Lines 348-373 introduce new test logic/data/assertions in a As per coding guidelines, 🤖 Prompt for AI AgentsSources: Coding guidelines, Learnings |
||
|
|
||
| class TestFindImportInsertLine(unittest.TestCase): | ||
| """Tests for _find_import_insert_line function.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handle keyword-based conditional decorators without crashing.
Line 142 assumes a positional condition (
node.args[0]). Valid forms like@unittest.skipIf(condition=..., reason=...)makenode.argsempty and raiseIndexErrorduring extraction.Proposed fix
🤖 Prompt for AI Agents