@@ -83,7 +83,7 @@ def parse_results(result):
8383 test_results .stdout = result .stdout
8484 in_test_results = False
8585 for line in lines :
86- if re .match (r"Run tests? sequentially" , line ):
86+ if re .search (r"Run \d+ tests? sequentially" , line ):
8787 in_test_results = True
8888 elif line .startswith ("-----------" ):
8989 in_test_results = False
@@ -161,6 +161,66 @@ def is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> boo
161161 return True
162162
163163
164+ def build_inheritance_info (tree : ast .Module ) -> tuple [dict , dict ]:
165+ """
166+ Build inheritance information from AST.
167+
168+ Returns:
169+ class_bases: dict[str, list[str]] - parent classes for each class (only those defined in the file)
170+ class_methods: dict[str, set[str]] - methods directly defined in each class
171+ """
172+ all_classes = {
173+ node .name for node in ast .walk (tree ) if isinstance (node , ast .ClassDef )
174+ }
175+ class_bases = {}
176+ class_methods = {}
177+
178+ for node in ast .walk (tree ):
179+ if isinstance (node , ast .ClassDef ):
180+ # Collect only parent classes defined in this file
181+ bases = [
182+ base .id
183+ for base in node .bases
184+ if isinstance (base , ast .Name ) and base .id in all_classes
185+ ]
186+ class_bases [node .name ] = bases
187+
188+ # Collect directly defined methods
189+ methods = {
190+ item .name
191+ for item in node .body
192+ if isinstance (item , (ast .FunctionDef , ast .AsyncFunctionDef ))
193+ }
194+ class_methods [node .name ] = methods
195+
196+ return class_bases , class_methods
197+
198+
199+ def find_method_definition (
200+ class_name : str , method_name : str , class_bases : dict , class_methods : dict
201+ ) -> str | None :
202+ """Find the class where a method is actually defined. Traverses inheritance chain (BFS)."""
203+ # Check current class first
204+ if method_name in class_methods .get (class_name , set ()):
205+ return class_name
206+
207+ # Search parent classes
208+ visited = set ()
209+ queue = list (class_bases .get (class_name , []))
210+
211+ while queue :
212+ current = queue .pop (0 )
213+ if current in visited :
214+ continue
215+ visited .add (current )
216+
217+ if method_name in class_methods .get (current , set ()):
218+ return current
219+ queue .extend (class_bases .get (current , []))
220+
221+ return None
222+
223+
164224def remove_expected_failures (
165225 contents : str , tests_to_remove : set [tuple [str , str ]]
166226) -> str :
@@ -172,6 +232,18 @@ def remove_expected_failures(
172232 lines = contents .splitlines ()
173233 lines_to_remove = set ()
174234
235+ # Build inheritance information
236+ class_bases , class_methods = build_inheritance_info (tree )
237+
238+ # Resolve to actual defining classes
239+ resolved_tests = set ()
240+ for class_name , method_name in tests_to_remove :
241+ defining_class = find_method_definition (
242+ class_name , method_name , class_bases , class_methods
243+ )
244+ if defining_class :
245+ resolved_tests .add ((defining_class , method_name ))
246+
175247 for node in ast .walk (tree ):
176248 if not isinstance (node , ast .ClassDef ):
177249 continue
@@ -180,7 +252,7 @@ def remove_expected_failures(
180252 if not isinstance (item , (ast .FunctionDef , ast .AsyncFunctionDef )):
181253 continue
182254 method_name = item .name
183- if (class_name , method_name ) not in tests_to_remove :
255+ if (class_name , method_name ) not in resolved_tests :
184256 continue
185257
186258 # Check if we should remove the entire method (super() call only)
0 commit comments