|
28 | 28 |
|
29 | 29 | import fnmatch |
30 | 30 | import imp |
| 31 | +import itertools |
31 | 32 | import os |
32 | 33 | from contextlib import contextmanager |
33 | 34 |
|
@@ -88,6 +89,7 @@ def __init__(self, suite, test_class, test_config, test_root): |
88 | 89 | self.test_class = test_class |
89 | 90 | self.test_config = test_config |
90 | 91 | self.test_root = test_root |
| 92 | + self.test_count_estimation = len(list(self._list_test_filenames())) |
91 | 93 |
|
92 | 94 | def _list_test_filenames(self): |
93 | 95 | """Implemented by the subclassed TestLoaders to list filenames. |
@@ -199,6 +201,34 @@ def extension(self): |
199 | 201 | return ".js" |
200 | 202 |
|
201 | 203 |
|
| 204 | +class TestGenerator(object): |
| 205 | + def __init__(self, test_count_estimate, slow_tests, fast_tests): |
| 206 | + self.test_count_estimate = test_count_estimate |
| 207 | + self.slow_tests = slow_tests |
| 208 | + self.fast_tests = fast_tests |
| 209 | + self._rebuild_iterator() |
| 210 | + |
| 211 | + def _rebuild_iterator(self): |
| 212 | + self._iterator = itertools.chain(self.slow_tests, self.fast_tests) |
| 213 | + |
| 214 | + def __iter__(self): |
| 215 | + return self |
| 216 | + |
| 217 | + def __next__(self): |
| 218 | + return self.next() |
| 219 | + |
| 220 | + def next(self): |
| 221 | + return next(self._iterator) |
| 222 | + |
| 223 | + def merge(self, test_generator): |
| 224 | + self.test_count_estimate += test_generator.test_count_estimate |
| 225 | + self.slow_tests = itertools.chain( |
| 226 | + self.slow_tests, test_generator.slow_tests) |
| 227 | + self.fast_tests = itertools.chain( |
| 228 | + self.fast_tests, test_generator.fast_tests) |
| 229 | + self._rebuild_iterator() |
| 230 | + |
| 231 | + |
202 | 232 | @contextmanager |
203 | 233 | def _load_testsuite_module(name, root): |
204 | 234 | f = None |
@@ -236,14 +266,22 @@ def _test_loader_class(self): |
236 | 266 | def ListTests(self): |
237 | 267 | return self._test_loader.list_tests() |
238 | 268 |
|
| 269 | + def __initialize_test_count_estimation(self): |
| 270 | + # Retrieves a single test to initialize the test generator. |
| 271 | + next(iter(self.ListTests())) |
| 272 | + |
| 273 | + def __calculate_test_count(self): |
| 274 | + self.__initialize_test_count_estimation() |
| 275 | + return self._test_loader.test_count_estimation |
| 276 | + |
239 | 277 | def load_tests_from_disk(self, statusfile_variables): |
240 | 278 | self.statusfile = statusfile.StatusFile( |
241 | 279 | self.status_file(), statusfile_variables) |
242 | 280 |
|
| 281 | + test_count = self.__calculate_test_count() |
243 | 282 | slow_tests = (test for test in self.ListTests() if test.is_slow) |
244 | 283 | fast_tests = (test for test in self.ListTests() if not test.is_slow) |
245 | | - |
246 | | - return slow_tests, fast_tests |
| 284 | + return TestGenerator(test_count, slow_tests, fast_tests) |
247 | 285 |
|
248 | 286 | def get_variants_gen(self, variants): |
249 | 287 | return self._variants_gen_class()(variants) |
|
0 commit comments