Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 172 additions & 98 deletions scripts/update_lib/patch_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
- Applying patches to test files (JSON -> file)
"""

from __future__ import annotations

import ast
import collections
import contextlib
import enum
import re
import textwrap
import typing

Expand Down Expand Up @@ -91,33 +95,109 @@ def as_decorator(self) -> str:

return f"@{unparsed}"

@classmethod
def try_from_ast_node(
cls, node: ast.Attribute | ast.Call, lines: list[str]
) -> typing.Self | None:
if isinstance(node, ast.Attribute):
attr_node = node
elif isinstance(node, ast.Call):
attr_node = node.func
else:
return

if (
isinstance(attr_node, ast.Name)
or getattr(attr_node.value, "id", None) != UT
):
return

cond = None
try:
ut_method = UtMethod(attr_node.attr)
except ValueError:
return

# If our ut_method has args then,
# we need to search for a constant that contains our `COMMENT`.
# Otherwise we need to search it in the raw source code :/
if ut_method.has_args():
reason = next(
(
inner_node.value
for inner_node in ast.walk(node)
if isinstance(inner_node, ast.Constant)
and isinstance(inner_node.value, str)
and COMMENT in inner_node.value
),
None,
)

def _single_to_double_quotes(s: str) -> str:
"""Convert single-quoted strings to double-quoted strings.
# If we didn't find a constant containing <COMMENT>,
# then we didn't put this decorator
if not reason:
return

if ut_method.has_cond():
cond = ast.unparse(node.args[0])
else:
Comment on lines +141 to +143

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Handle keyword-based conditional decorators without crashing.

Line 142 assumes a positional condition (node.args[0]). Valid forms like @unittest.skipIf(condition=..., reason=...) make node.args empty and raise IndexError during extraction.

Proposed fix
-            if ut_method.has_cond():
-                cond = ast.unparse(node.args[0])
+            if ut_method.has_cond():
+                if not isinstance(node, ast.Call):
+                    return
+                if node.args:
+                    cond_node = node.args[0]
+                else:
+                    cond_node = next(
+                        (kw.value for kw in node.keywords if kw.arg in {"condition", "cond"}),
+                        None,
+                    )
+                    if cond_node is None:
+                        return
+                cond = ast.unparse(cond_node)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/update_lib/patch_spec.py` around lines 141 - 143, When extracting the
decorator condition in patch_spec.py (inside the block guarded by
ut_method.has_cond()), avoid indexing node.args directly; instead if node.args
is non-empty keep using node.args[0], otherwise search node.keywords for the
keyword named "condition" (i.e. iterate node.keywords and match kw.arg ==
"condition") and set cond = ast.unparse(kw.value) if found; if neither exists
set cond = None or skip extraction so the code using cond won't raise
IndexError. Reference: ut_method.has_cond(), cond variable, and node.args /
node.keywords.

pattern = re.compile(rf"{COMMENT}.?(.*)")
dec_lineno = node.lineno

curr_line = lines[dec_lineno - 1]
prev_line = lines[dec_lineno - 2]

Comment on lines +145 to +149

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Guard previous-line lookup at file start.

Line 148 reads lines[dec_lineno - 2] unconditionally. If the decorator is on the first line, this becomes lines[-1] and can incorrectly treat EOF comments as decorator metadata.

Proposed fix
-            prev_line = lines[dec_lineno - 2]
+            prev_line = lines[dec_lineno - 2] if dec_lineno > 1 else ""
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/update_lib/patch_spec.py` around lines 145 - 149, The code reads
prev_line = lines[dec_lineno - 2] without checking for file start, which makes
lines[-1] get used when a decorator is on the first line; change the logic
around dec_lineno (from node.lineno) so you only index dec_lineno - 2 when
dec_lineno > 1, otherwise set prev_line to an empty string or None; update any
downstream checks that read prev_line (in the same function handling decorator
parsing) to handle the empty/None value to avoid treating EOF/comments as
decorator metadata.

# If we see our comment at the decorator line, take it
if found := pattern.search(curr_line):
reason = found.group()
elif prev_line.strip().startswith("#") and (
found := pattern.search(prev_line)
):
# Search the previous line of the decorator,
# only take the comment if the line starts with a `#`
reason = found.group()
else:
# Didn't find our `COMMENT`, so the patch isn't ours :)
return

Falls back to original if conversion breaks the AST equivalence.
"""
import re
reason = reason.removeprefix(COMMENT).strip(";:, ")
return cls(ut_method, cond, reason)

def replace_string(match: re.Match) -> str:
content = match.group(1)
# Unescape single quotes and escape double quotes
content = content.replace("\\'", "'").replace('"', '\\"')
return f'"{content}"'

# Match single-quoted strings (handles escaped single quotes inside)
converted = re.sub(r"'((?:[^'\\]|\\.)*)'", replace_string, s)
class PatchEntryVisitor(ast.NodeVisitor):
def __init__(self, lines: list[str]):
self.current_class = None
self.patches = []
self.lines = lines

# Verify: parse converted and unparse should equal original
try:
converted_ast = ast.parse(converted, mode="eval")
if ast.unparse(converted_ast) == s:
return converted
except SyntaxError:
pass
def patches_from_node(
self, node: ast.FunctionDef | ast.AsyncFunctionDef
) -> Iterator[PatchEntry]:
for dec_node in node.decorator_list:
spec = PatchSpec.try_from_ast_node(dec_node, self.lines)

# Fall back to original if conversion failed
return s
if spec is None:
continue

yield PatchEntry(self.current_class, node.name, spec)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
self.patches.extend(self.patches_from_node(node))
# TODO: Support nested classes/methods
# self.generic_visit(node)

def visit_FunctionDef(self, node: ast.FunctionDef):
self.patches.extend(self.patches_from_node(node))
# TODO: Support nested classes/methods
# self.generic_visit(node)

def visit_ClassDef(self, node: ast.ClassDef):
with temp_attr(self, "current_class", node.name):
for patch in self.patches_from_node(node):
patch = patch._replace(test_name="__self__")
self.patches.append(patch)

self.generic_visit(node)


class PatchEntry(typing.NamedTuple):
Expand All @@ -142,76 +222,10 @@ class PatchEntry(typing.NamedTuple):
def iter_patch_entries(
cls, tree: ast.Module, lines: list[str]
) -> "Iterator[typing.Self]":
import re
import sys

for cls_node, fn_node in iter_tests(tree):
parent_class = cls_node.name
for dec_node in fn_node.decorator_list:
if not isinstance(dec_node, (ast.Attribute, ast.Call)):
continue

attr_node = (
dec_node if isinstance(dec_node, ast.Attribute) else dec_node.func
)

if (
isinstance(attr_node, ast.Name)
or getattr(attr_node.value, "id", None) != UT
):
continue

cond = None
try:
ut_method = UtMethod(attr_node.attr)
except ValueError:
continue

# If our ut_method has args then,
# we need to search for a constant that contains our `COMMENT`.
# Otherwise we need to search it in the raw source code :/
if ut_method.has_args():
reason = next(
(
node.value
for node in ast.walk(dec_node)
if isinstance(node, ast.Constant)
and isinstance(node.value, str)
and COMMENT in node.value
),
None,
)

# If we didn't find a constant containing <COMMENT>,
# then we didn't put this decorator
if not reason:
continue

if ut_method.has_cond():
cond = ast.unparse(dec_node.args[0])
else:
pattern = re.compile(rf"{COMMENT}.?(.*)")
dec_lineno = dec_node.lineno

curr_line = lines[dec_lineno - 1]
prev_line = lines[dec_lineno - 2]

# If we see our comment at the decorator line, take it
if found := pattern.search(curr_line):
reason = found.group()
elif prev_line.strip().startswith("#") and (
found := pattern.search(prev_line)
):
# Search the previous line of the decorator,
# only take the comment if the line starts with a `#`
reason = found.group()
else:
# Didn't find our `COMMENT`, so the patch isn't ours :)
continue

reason = reason.removeprefix(COMMENT).strip(";:, ")
spec = PatchSpec(ut_method, cond, reason)
yield cls(parent_class, fn_node.name, spec)
visitor = PatchEntryVisitor(lines)
visitor.visit(tree)
yield from visitor.patches


def iter_tests(
Expand Down Expand Up @@ -251,6 +265,15 @@ def extract_patches(contents: str) -> Patches:
return build_patch_dict(iter_patches(contents))


def modification_from_node_specs(node, specs):
lineno = min(
(dec_node.lineno for dec_node in node.decorator_list), default=node.lineno
)
indent = " " * node.col_offset
patch_lines = "\n".join(spec.as_decorator() for spec in specs)
return (lineno - 1, textwrap.indent(patch_lines, indent))


def _iter_patch_lines(
tree: ast.Module, patches: Patches
) -> "Iterator[tuple[int, str]]":
Expand All @@ -262,7 +285,15 @@ def _iter_patch_lines(
async_methods: dict[str, set[str]] = {}
# Track class bases for inherited async method lookup
class_bases: dict[str, list[str]] = {}
all_classes = {node.name for node in tree.body if isinstance(node, ast.ClassDef)}
all_classes = set()
all_class_nodes = []
for node in tree.body:
if not isinstance(node, ast.ClassDef):
continue

all_classes.add(node.name)
all_class_nodes.append(node)

for node in tree.body:
if isinstance(node, ast.ClassDef):
cache[node.name] = node.end_lineno
Expand All @@ -284,13 +315,7 @@ def _iter_patch_lines(
if not specs:
continue

lineno = min(
(dec_node.lineno for dec_node in fn_node.decorator_list),
default=fn_node.lineno,
)
indent = " " * fn_node.col_offset
patch_lines = "\n".join(spec.as_decorator() for spec in specs)
yield (lineno - 1, textwrap.indent(patch_lines, indent))
yield modification_from_node_specs(fn_node, specs)

# Phase 2: Iterate and mark inherited tests
for cls_name, tests in sorted(patches.items()):
Expand All @@ -300,6 +325,10 @@ def _iter_patch_lines(
continue

for test_name, specs in sorted(tests.items()):
if test_name == "__self__":
# Yielding modifications for the class itself should be done during phase 3
continue

decorators = "\n".join(spec.as_decorator() for spec in specs)
# Check current class and ancestors for async method
is_async = False
Expand All @@ -314,6 +343,7 @@ def _iter_patch_lines(
is_async = True
break
queue.extend(class_bases.get(cur, []))

if is_async:
patch_lines = f"""
{decorators}
Expand All @@ -328,6 +358,11 @@ def {test_name}(self):
""".rstrip()
yield (lineno, textwrap.indent(patch_lines, DEFAULT_INDENT))

# Phase 3: Mark the class itself
for cls_node in all_class_nodes:
if cls_specs := patches.get(cls_node.name, {}).pop("__self__", None):
yield modification_from_node_specs(cls_node, cls_specs)


def _has_unittest_import(tree: ast.Module) -> bool:
"""Check if 'import unittest' is already present in the file."""
Expand Down Expand Up @@ -406,3 +441,42 @@ def patches_from_json(data: dict) -> Patches:
}
for cls_name, tests in data.items()
}


def _single_to_double_quotes(s: str) -> str:
"""
Convert single-quoted strings to double-quoted strings.

Falls back to original if conversion breaks the AST equivalence.
"""
import re

def replace_string(match: re.Match) -> str:
content = match.group(1)
# Unescape single quotes and escape double quotes
content = content.replace("\\'", "'").replace('"', '\\"')
return f'"{content}"'

# Match single-quoted strings (handles escaped single quotes inside)
converted = re.sub(r"'((?:[^'\\]|\\.)*)'", replace_string, s)

# Verify: parse converted and unparse should equal original
try:
converted_ast = ast.parse(converted, mode="eval")
if ast.unparse(converted_ast) == s:
return converted
except SyntaxError:
pass

# Fall back to original if conversion failed
return s


@contextlib.contextmanager
def temp_attr(obj: object, attr: str, value: object):
old = getattr(obj, attr, None)
setattr(obj, attr, value)
try:
yield obj
finally:
setattr(obj, attr, old)
26 changes: 26 additions & 0 deletions scripts/update_lib/tests/test_patch_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,32 @@ def test_one(self):
self.assertIn("@unittest.expectedFailure", result)
self.assertIn(COMMENT, result)

def test_round_trip_with_patches_on_class(self):
"""Test that extracted patches can be re-applied."""
original = f"""import unittest

@unittest.skipIf(a == b, "{COMMENT}")
@unittest.expectedFailure # {COMMENT}
class TestFoo(unittest.TestCase):
...
"""
# Extract patches
patches = extract_patches(original)

# Apply to clean code
clean = """import unittest

class TestFoo(unittest.TestCase):
def test_one(self):
pass
"""
result = apply_patches(clean, patches)

# Should have the decorator
self.assertIn("@unittest.expectedFailure", result)
self.assertIn("@unittest.skipIf", result)
self.assertIn(COMMENT, result)

Comment on lines +348 to +373

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

This test addition violates the repository’s test-file change policy.

Lines 348-373 introduce new test logic/data/assertions in a *test*.py file, which is outside the allowed modification scope in this repo policy.

As per coding guidelines, **/*test*.py says: “NEVER modify test assertions, test logic, or test data” and limits acceptable edits to expectedFailure/TODO handling. Based on learnings, AGENTS guidance repeats the same restriction.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/update_lib/tests/test_patch_spec.py` around lines 348 - 373, The new
test method test_round_trip_with_patches_on_class violates the repo policy
against modifying test files; remove this added test (the function
test_round_trip_with_patches_on_class and its assertions referencing
extract_patches, apply_patches, and COMMENT) from the patch so no new test logic
or assertions are introduced in *test*.py files, or if needed, move equivalent
validation into an allowed non-test location (e.g., a new integration/example
script) and keep the test file unchanged.

Sources: Coding guidelines, Learnings


class TestFindImportInsertLine(unittest.TestCase):
"""Tests for _find_import_insert_line function."""
Expand Down
Loading