Commit 6d2d257 (2 parents: 14e7993 + 5cc58a7)

Update on "[FSDP][optim_state_dict][9/N] Rewrite the all-gather flow of optimizer state to support older GPUs"

[ghstack-poisoned]
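For orientation, the commit title refers to all-gathering sharded optimizer state across ranks. Below is a minimal illustrative sketch of that general pattern only, not this commit's implementation; gather_full_state is a hypothetical helper, and equal-sized flat shards per rank are an assumption.

# Hedged sketch of the generic all-gather pattern for sharded optimizer
# state (e.g. Adam's exp_avg). Each rank holds one equal-sized flat shard;
# all_gather reconstructs the full tensor on every rank.
import torch
import torch.distributed as dist

def gather_full_state(local_shard: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    # One receive buffer per rank; dist.all_gather fills them in rank order.
    buffers = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(buffers, local_shard)
    return torch.cat(buffers)  # concatenated, unsharded state tensor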

File tree: 161 files changed, +1962 −1325 lines

.github/ci_commit_pins/xla.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-d43a14a9adb89e6280b311f0513da2d3f3b0c618
+2ffe4a04df9d498e250153d931cadf9d92268510

.github/merge_rules.yaml

Lines changed: 11 additions & 0 deletions
@@ -350,6 +350,17 @@
   - Lint
   - pull
 
+- name: ROCm
+  patterns:
+  - '**rocm**'
+  - '**hip**'
+  approved_by:
+  - jeffdaily
+  mandatory_checks_name:
+  - EasyCLA
+  - Lint
+  - pull
+
 - name: superuser
   patterns:
   - '*'
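The '**rocm**' and '**hip**' entries are filename patterns that a PR's changed files are matched against. A minimal sketch of such matching, assuming fnmatch-style glob semantics (an assumption about the rule engine, and a made-up file path):

# Hedged sketch: does a changed file fall under the ROCm rule's patterns?
from fnmatch import fnmatch

patterns = ["**rocm**", "**hip**"]
changed_file = "aten/src/ATen/native/hip/SortingKernels.hip"  # hypothetical

matches = any(fnmatch(changed_file, p) for p in patterns)
print(matches)  # True: the path contains "hip"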

.github/scripts/get_workflow_job_id.py

Lines changed: 97 additions & 50 deletions
@@ -2,18 +2,74 @@
 # workflow. GitHub does not provide this information to workflow runs, so we
 # need to figure it out based on what they *do* provide.
 
-import requests
-import os
 import argparse
+import json
+import os
+import re
+import sys
+import urllib
+import urllib.parse
+
+from typing import Any, Callable, Dict, List, Tuple, Optional
+from urllib.request import Request, urlopen
+
+def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
+    links = {}
+    # Extract links which GH uses for pagination
+    # see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
+    if "Link" in conn.headers:
+        for elem in re.split(", *<", conn.headers["Link"]):
+            try:
+                url, params_ = elem.split(";", 1)
+            except ValueError:
+                continue
+            url = urllib.parse.unquote(url.strip("<> "))
+            qparams = urllib.parse.parse_qs(params_.strip(), separator=";")
+            params = {k: v[0].strip('"') for k, v in qparams.items() if type(v) is list and len(v) > 0}
+            params["url"] = url
+            if "rel" in params:
+                links[params["rel"]] = params
+
+    return json.load(conn), links
 
-def handle_bad_status(response: requests.Response) -> None:
-    if response.status_code != 200:
+def fetch_url(url: str, *,
+              headers: Optional[Dict[str, str]] = None,
+              reader: Callable[[Any], Any] = lambda x: x.read()) -> Any:
+    if headers is None:
+        headers = {}
+    try:
+        with urlopen(Request(url, headers=headers)) as conn:
+            return reader(conn)
+    except urllib.error.HTTPError as err:
         exception_message = (
             "Is github alright?",
-            f"Recieved status code '{response.status_code}' when attempting to retrieve runs:\n",
-            f"{response.content.decode()}"
+            f"Recieved status code '{err.code}' when attempting to retrieve {url}:\n",
+            f"{err.reason}\n\nheaders={err.headers}"
         )
-        raise RuntimeError(exception_message)
+        raise RuntimeError(exception_message) from err
+
+def parse_args() -> Any:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "workflow_run_id", help="The id of the workflow run, should be GITHUB_RUN_ID"
+    )
+    parser.add_argument(
+        "runner_name",
+        help="The name of the runner to retrieve the job id, should be RUNNER_NAME",
+    )
+
+    return parser.parse_args()
+
+
+def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
+    response, links = fetch_url(url, headers=headers, reader=parse_json_and_links)
+    jobs = response["jobs"]
+    assert type(jobs) is list
+    while "next" in links.keys():
+        response, links = fetch_url(links["next"]["url"], headers=headers, reader=parse_json_and_links)
+        jobs.extend(response["jobs"])
+
+    return jobs
 
 
 # Our strategy is to retrieve the parent workflow run, then filter its jobs on
@@ -29,46 +85,37 @@ def handle_bad_status(response: requests.Response) -> None:
 # since only one job can be scheduled on a runner at a time, we know that
 # looking for RUNNER_NAME will uniquely identify the job we're currently
 # running.
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "workflow_run_id", help="The id of the workflow run, should be GITHUB_RUN_ID"
-)
-parser.add_argument(
-    "runner_name",
-    help="The name of the runner to retrieve the job id, should be RUNNER_NAME",
-)
-
-args = parser.parse_args()
-
-
-# From https://docs.github.com/en/actions/learn-github-actions/environment-variables
-PYTORCH_REPO = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch")
-PYTORCH_GITHUB_API = f"https://api.github.com/repos/{PYTORCH_REPO}"
-GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]
-REQUEST_HEADERS = {
-    "Accept": "application/vnd.github.v3+json",
-    "Authorization": "token " + GITHUB_TOKEN,
-}
-
-response = requests.get(
-    f"{PYTORCH_GITHUB_API}/actions/runs/{args.workflow_run_id}/jobs?per_page=100",
-    headers=REQUEST_HEADERS,
-)
-handle_bad_status(response)
-
-jobs = response.json()["jobs"]
-while "next" in response.links.keys():
-    response = requests.get(response.links["next"]["url"], headers=REQUEST_HEADERS)
-    handle_bad_status(response)
-    jobs.extend(response.json()["jobs"])
-
-# Sort the jobs list by start time, in descending order. We want to get the most
-# recently scheduled job on the runner.
-jobs.sort(key=lambda job: job["started_at"], reverse=True)
-
-for job in jobs:
-    if job["runner_name"] == args.runner_name:
-        print(job["id"])
-        exit(0)
-
-exit(1)
+
+def find_job_id(args: Any) -> str:
+    # From https://docs.github.com/en/actions/learn-github-actions/environment-variables
+    PYTORCH_REPO = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch")
+    PYTORCH_GITHUB_API = f"https://api.github.com/repos/{PYTORCH_REPO}"
+    GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]
+    REQUEST_HEADERS = {
+        "Accept": "application/vnd.github.v3+json",
+        "Authorization": "token " + GITHUB_TOKEN,
+    }
+
+    url = f"{PYTORCH_GITHUB_API}/actions/runs/{args.workflow_run_id}/jobs?per_page=100"
+    jobs = fetch_jobs(url, REQUEST_HEADERS)
+
+    # Sort the jobs list by start time, in descending order. We want to get the most
+    # recently scheduled job on the runner.
+    jobs.sort(key=lambda job: job["started_at"], reverse=True)
+
+    for job in jobs:
+        if job["runner_name"] == args.runner_name:
+            return job["id"]
+
+    raise RuntimeError(f"Can't find job id for runner {args.runner_name}")
+
+def main() -> None:
+    args = parse_args()
+    try:
+        print(find_job_id(args))
+    except Exception as e:
+        print(repr(e), file=sys.stderr)
+        print(f"workflow-{args.workflow_run_id}")
+
+if __name__ == "__main__":
+    main()
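This rewrite drops the external requests dependency in favor of the standard library, so the script now parses GitHub's pagination Link header itself. A standalone sketch of how that parsing behaves on a sample header (the URLs are made up; parse_qs's separator keyword requires a reasonably recent Python 3 patch release):

# Demo of the Link-header parsing logic above on a made-up GitHub
# pagination header; not part of the commit itself.
import re
import urllib.parse

sample = (
    '<https://api.github.com/repos/pytorch/pytorch/actions/runs/1/jobs?page=2>; rel="next", '
    '<https://api.github.com/repos/pytorch/pytorch/actions/runs/1/jobs?page=3>; rel="last"'
)

links = {}
for elem in re.split(", *<", sample):
    try:
        url, params_ = elem.split(";", 1)
    except ValueError:
        continue
    url = urllib.parse.unquote(url.strip("<> "))
    qparams = urllib.parse.parse_qs(params_.strip(), separator=";")
    params = {k: v[0].strip('"') for k, v in qparams.items() if v}
    params["url"] = url
    if "rel" in params:
        links[params["rel"]] = params

print(links["next"]["url"])  # prints the page=2 URL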

.github/workflows/_rocm-test.yml

Lines changed: 0 additions & 1 deletion
@@ -118,7 +118,6 @@ jobs:
           SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2
           DOCKER_IMAGE: ${{ inputs.docker-image }}
           XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla
-          PYTORCH_JIT_ENABLE_NVFUSER: 1
           PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
           PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }}
         timeout-minutes: 270

.github/workflows/pull.yml

Lines changed: 1 addition & 1 deletion
@@ -217,7 +217,7 @@ jobs:
       docker-image-name: xla_base
       test-matrix: |
         { include: [
-          { config: "xla", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
+          { config: "xla", shard: 1, num_shards: 1, runner: "linux.4xlarge" },
         ]}
 
   linux-bionic-py3_7-clang8-xla-test:

.github/workflows/unstable.yml

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+name: unstable
+
+on:
+  push:
+    branches:
+      - master
+      - main
+    tags:
+      - ciflow/unstable/*
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
+  cancel-in-progress: true
+
+jobs:
+  # There must be at least one job here to satisfy GitHub action workflow syntax
+  introduction:
+    runs-on: ubuntu-latest
+    continue-on-error: true
+    steps:
+      - name: Introduce PyTorch unstable workflow
+        run: |
+          echo "PyTorch unstable workflow is used to host experimental or flaky jobs"
+          echo " that needs to be run for every commit, but doesn't block PR merging"
+          echo " as part of the stable pull or trunk workflows."
+          echo
+          echo "In addition, a new label called ciflow/unstable can be attached to the"
+          echo " PR to trigger this workflow. That can be done either manually or"
+          echo " automatically using PyTorch auto-label bot."
+          echo
+          echo "Once the jobs are deemed stable enough (% red signal < 20% and TTS < 3h),"
+          echo " they can graduate and move back to pull or trunk."

.jenkins/pytorch/multigpu-test.sh

Lines changed: 1 addition & 0 deletions
@@ -45,4 +45,5 @@ time python test/run_test.py --verbose -i distributed/_shard/test_partial_tensor
 time python test/run_test.py --verbose -i distributed/_shard/test_replicated_tensor
 # Other tests
 time python test/run_test.py --verbose -i test_cuda_primary_ctx
+time python test/run_test.py --verbose -i test_optim -- -k optimizers_with_varying_tensors
 assert_git_not_dirty

BUILD.bazel

Lines changed: 0 additions & 1 deletion
@@ -407,7 +407,6 @@ cc_library(
         "@cuda//:cusolver",
         "@cuda//:nvrtc",
         "@cudnn",
-        "@cudnn_frontend",
    ],
    alwayslink = True,
 )

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -195,6 +195,9 @@ cmake_dependent_option(
 cmake_dependent_option(
     BUILD_NVFUSER_BENCHMARK "Build C++ binaries for nvfuser benchmarks" OFF
     "USE_CUDA" OFF)
+cmake_dependent_option(
+    USE_EXPERIMENTAL_CUDNN_V8_API "Use experimental cuDNN v8 API" ON
+    "USE_CUDNN" OFF)
 option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
 option(USE_KINETO "Use Kineto profiling library" ON)
 option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
@@ -802,6 +805,7 @@ if(NOT MSVC)
   append_cxx_flag_if_supported("-Werror=non-virtual-dtor" CMAKE_CXX_FLAGS)
   append_cxx_flag_if_supported("-Werror=braced-scalar-init" CMAKE_CXX_FLAGS)
   append_cxx_flag_if_supported("-Werror=range-loop-construct" CMAKE_CXX_FLAGS)
+  append_cxx_flag_if_supported("-Werror=bool-operation" CMAKE_CXX_FLAGS)
   append_cxx_flag_if_supported("-Winconsistent-missing-override" CMAKE_CXX_FLAGS)
   append_cxx_flag_if_supported("-Wnarrowing" CMAKE_CXX_FLAGS)
   append_cxx_flag_if_supported("-Wno-missing-field-initializers" CMAKE_CXX_FLAGS)

WORKSPACE

Lines changed: 0 additions & 6 deletions
@@ -203,12 +203,6 @@ new_local_repository(
     path = "/usr/",
 )
 
-new_local_repository(
-    name = "cudnn_frontend",
-    build_file = "@//third_party:cudnn_frontend.BUILD",
-    path = "third_party/cudnn_frontend/",
-)
-
 local_repository(
     name = "com_github_google_flatbuffers",
     path = "third_party/flatbuffers",
