|
32 | 32 | try: |
33 | 33 | # using tools/ to optimize test run. |
34 | 34 | sys.path.append(str(REPO_ROOT)) |
| 35 | + from tools.stats.import_test_stats import get_test_times |
35 | 36 | from tools.testing.test_selections import ( |
36 | | - export_S3_test_times, |
37 | | - get_shard_based_on_S3, |
38 | 37 | get_reordered_tests, |
39 | 38 | get_test_case_configs, |
| 39 | + calculate_shards, |
40 | 40 | ) |
41 | 41 | HAVE_TEST_SELECTION_TOOLS = True |
42 | 42 | except ImportError: |
@@ -677,13 +677,6 @@ def parse_args(): |
677 | 677 | help="additional arguments passed through to unittest, e.g., " |
678 | 678 | "python run_test.py -i sparse -- TestSparse.test_factory_size_check", |
679 | 679 | ) |
680 | | - parser.add_argument( |
681 | | - "--export-past-test-times", |
682 | | - nargs="?", |
683 | | - type=str, |
684 | | - const=TEST_TIMES_FILE, |
685 | | - help="dumps test times from previous S3 stats into a file, format JSON", |
686 | | - ) |
687 | 680 | parser.add_argument( |
688 | 681 | "--shard", |
689 | 682 | nargs=2, |
@@ -838,11 +831,21 @@ def get_selected_tests(options): |
838 | 831 | assert num_shards <= len( |
839 | 832 | selected_tests |
840 | 833 | ), f"Number of shards must be less than {len(selected_tests)}" |
841 | | - # TODO: fix this to use test_times_filename, but currently this is not working |
842 | | - # because setting the export arg immeidately halts the test execution. |
843 | | - selected_tests = get_shard_based_on_S3( |
844 | | - which_shard, num_shards, selected_tests, TEST_TIMES_FILE |
845 | | - ) |
| 834 | + |
| 835 | + if num_shards == 1: |
| 836 | + return selected_tests |
| 837 | + |
| 838 | + # Download previous test times to make sharding decisions |
| 839 | + test_file_times = get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE) |
| 840 | + if len(test_file_times) == 0: |
| 841 | + print( |
| 842 | + "::warning:: Gathered no stats from S3. Proceeding with default sharding plan." |
| 843 | + ) |
| 844 | + selected_tests = selected_tests[which_shard - 1 :: num_shards] |
| 845 | + else: |
| 846 | + shards = calculate_shards(num_shards, selected_tests, test_file_times) |
| 847 | + _, tests_from_shard = shards[which_shard - 1] |
| 848 | + selected_tests = tests_from_shard |
846 | 849 |
|
847 | 850 | # skip all distributed tests if distributed package is not available. |
848 | 851 | if not dist.is_available(): |
@@ -882,15 +885,6 @@ def run_test_module(test: str, test_directory: str, options) -> Optional[str]: |
882 | 885 | def main(): |
883 | 886 | options = parse_args() |
884 | 887 |
|
885 | | - # TODO: move this export & download function in tools/ folder |
886 | | - test_times_filename = options.export_past_test_times |
887 | | - if test_times_filename: |
888 | | - print( |
889 | | - f"Exporting past test times from S3 to {test_times_filename}, no tests will be run." |
890 | | - ) |
891 | | - export_S3_test_times(test_times_filename) |
892 | | - return |
893 | | - |
894 | 888 | test_directory = str(REPO_ROOT / "test") |
895 | 889 | selected_tests = get_selected_tests(options) |
896 | 890 |
|
|
0 commit comments