@@ -13,6 +13,8 @@
 import subprocess
 import sys
 import tempfile
+import json
+from typing import Dict, Optional, List, cast, Any
 
 import torch
 from torch.utils import cpp_extension
@@ -25,14 +27,13 @@
     parser as common_parser,
 )
 import torch.distributed as dist
-from typing import Optional, List
 
 REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
 
 try:
     # using tools/ to optimize test run.
     sys.path.append(str(REPO_ROOT))
-    from tools.stats.import_test_stats import get_test_times
+    from tools.stats.export_test_times import TEST_TIMES_FILE
    from tools.testing.test_selections import (
        get_reordered_tests,
        get_test_case_configs,
@@ -72,6 +73,7 @@ def skip_test_p(name: str) -> bool:
     rc += extra_tests
     return sorted(rc)
 
+
 TESTS = discover_tests(
     blocklisted_patterns=[
         'ao',
@@ -268,9 +270,6 @@ def skip_test_p(name: str) -> bool:
     "test_torch"
 ]
 
-# the JSON file to store the S3 test stats
-TEST_TIMES_FILE = ".pytorch-test-times.json"
-
 # if a test file takes longer than 5 min, we add it to TARGET_DET_LIST
 SLOW_TEST_THRESHOLD = 300
 
@@ -395,14 +394,14 @@ def test_cuda_primary_ctx(test_module, test_directory, options):
         test_module, test_directory, options, extra_unittest_args=["--subprocess"]
     )
 
+
 run_test_with_subprocess = functools.partial(run_test, extra_unittest_args=["--subprocess"])
 
 
 def get_run_test_with_subprocess_fn():
     return lambda test_module, test_directory, options: run_test_with_subprocess(test_module, test_directory, options)
 
 
-
 def _test_cpp_extensions_aot(test_directory, options, use_ninja):
     if use_ninja:
         try:
@@ -570,6 +569,7 @@ def test_distributed(test_module, test_directory, options):
     "distributed/rpc/cuda/test_tensorpipe_agent": get_run_test_with_subprocess_fn(),
 }
 
+
 def parse_test_module(test):
     return test.split(".")[0]
 
@@ -862,14 +862,21 @@ def get_selected_tests(options):
             return selected_tests
 
         # Download previous test times to make sharding decisions
-        test_file_times = get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE)
-        if len(test_file_times) == 0:
+        path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
+        if os.path.exists(path):
+            with open(path, "r") as f:
+                test_file_times = cast(Dict[str, Any], json.load(f))
+        else:
+            test_file_times = {}
+        if os.environ["TEST_CONFIG"] not in test_file_times:
             print(
-                "::warning:: Gathered no stats from S3. Proceeding with default sharding plan."
+                "::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
             )
             selected_tests = selected_tests[which_shard - 1 :: num_shards]
         else:
-            shards = calculate_shards(num_shards, selected_tests, test_file_times)
+            print("Found test time stats from artifacts")
+            test_file_times_config = test_file_times[os.environ["TEST_CONFIG"]]
+            shards = calculate_shards(num_shards, selected_tests, test_file_times_config)
             _, tests_from_shard = shards[which_shard - 1]
             selected_tests = tests_from_shard
 
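Note on the last hunk: the new sharding path implies a particular shape for the exported test-times file. The JSON at REPO_ROOT/TEST_TIMES_FILE is expected to map a TEST_CONFIG name to per-test-file durations, which is what test_file_times[os.environ["TEST_CONFIG"]] and the calculate_shards(num_shards, selected_tests, test_file_times_config) call consume. The snippet below is only a minimal sketch of that flow under those assumptions; greedy_shards, the example durations, and the "default" config name are made up for illustration and are not the real tools.testing.test_selections.calculate_shards.

import os
from typing import Dict, List, Tuple

# Assumed shape of the exported test-times file (illustrative values only):
# { "<TEST_CONFIG>": { "<test module>": <seconds>, ... }, ... }
EXAMPLE_TEST_TIMES: Dict[str, Dict[str, float]] = {
    "default": {"test_torch": 1200.0, "test_nn": 900.0, "test_ops": 3000.0},
}


def greedy_shards(
    num_shards: int, tests: List[str], times: Dict[str, float]
) -> List[Tuple[float, List[str]]]:
    """Toy stand-in for the sharding helper: place the slowest known tests first,
    always onto the currently lightest shard; tests with no recorded time count as zero."""
    shards: List[Tuple[float, List[str]]] = [(0.0, []) for _ in range(num_shards)]
    for test in sorted(tests, key=lambda t: times.get(t, 0.0), reverse=True):
        # pick the shard with the smallest accumulated runtime so far
        idx = min(range(num_shards), key=lambda i: shards[i][0])
        total, members = shards[idx]
        shards[idx] = (total + times.get(test, 0.0), members + [test])
    return shards


if __name__ == "__main__":
    config = os.environ.get("TEST_CONFIG", "default")  # "default" is a placeholder, not a real CI config
    times = EXAMPLE_TEST_TIMES.get(config, {})
    for total, members in greedy_shards(2, ["test_torch", "test_nn", "test_ops"], times):
        print(f"{total:8.1f}s  {members}")

Under this sketch, a two-way split puts test_ops alone on one shard and test_torch plus test_nn on the other; the real helper may balance differently, but the keyed-by-config lookup is the part the diff above relies on.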