Commit 87b1623

Merge branch 'master' of github.com:pytorch/pytorch into bfloat-i0
2 parents: 3450bba + a011b86

35 files changed: +609 -337 lines

.jenkins/pytorch/test.sh

Lines changed: 2 additions & 0 deletions
@@ -340,6 +340,8 @@ test_benchmarks() {
     mkdir -p ${BENCHMARK_DATA}
     pytest benchmarks/fastrnns/test_bench.py --benchmark-sort=Name --benchmark-json=${BENCHMARK_DATA}/fastrnns_legacy_old.json --fuser=old --executor=legacy
     python benchmarks/upload_scribe.py --pytest_bench_json ${BENCHMARK_DATA}/fastrnns_legacy_old.json
+    pytest benchmarks/fastrnns/test_bench.py --benchmark-sort=Name --benchmark-json=${BENCHMARK_DATA}/fastrnns_profiling_te.json --fuser=te --executor=profiling
+    python benchmarks/upload_scribe.py --pytest_bench_json ${BENCHMARK_DATA}/fastrnns_profiling_te.json
     assert_git_not_dirty
   fi
 }

CODEOWNERS

Lines changed: 11 additions & 13 deletions
@@ -4,10 +4,6 @@
 /docs/cpp @goldsborough @ebetica @yf225
 /torch/csrc/api/ @ebetica @goldsborough @yf225
 /test/cpp/api/ @ebetica @goldsborough @yf225
-/torch/lib/c10d/ @pietern @mrshenli @zhaojuanmao
-/torch/csrc/distributed/ @pietern @mrshenli @zhaojuanmao
-/torch/distributed/ @apaszke @pietern @mrshenli @zhaojuanmao
-/test/test_c10d.py @pietern @mrshenli @zhaojuanmao
 /torch/utils/cpp_extension.py @goldsborough @fmassa @soumith @ezyang
 
 # Not there to strictly require the approval, but to be tagged as a reviewer
@@ -20,17 +16,19 @@
 /torch/jit/ @apaszke
 /torch/utils/data/ @apaszke
 
-# Distributed RPC Framework.
-/torch/csrc/distributed/rpc @mrshenli @pritamdamania87 @zhaojuanmao
-/torch/csrc/distributed/autograd @mrshenli @pritamdamania87 @zhaojuanmao
-/torch/distributed/rpc @mrshenli @pritamdamania87 @zhaojuanmao
-/torch/distributed/autograd @mrshenli @pritamdamania87 @zhaojuanmao
-/torch/distributed/optim @mrshenli @pritamdamania87 @zhaojuanmao @aazzolini
-
 # Tensorpipe RPC Agent.
 /torch/csrc/distributed/rpc/tensorpipe_agent.cpp @jiayisuse @osalpekar @lw @beauby
 /torch/csrc/distributed/rpc/tensorpipe_agent.h @jiayisuse @osalpekar @lw @beauby
 
+# Distributed package
+# This list is mostly if you'd like to be tagged as reviewer, feel free to add
+# or remove yourself from it.
+/torch/lib/c10d/ @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma
+/torch/csrc/distributed/ @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma
+/torch/distributed/ @apaszke @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma
+
 # Distributed tests
-/test/distributed @mrshenli @pritamdamania87 @zhaojuanmao
-/torch/testing/_internal/distributed @mrshenli @pritamdamania87 @zhaojuanmao
+# This list is mostly if you'd like to be tagged as reviewer, feel free to add
+# or remove yourself from it.
+/test/distributed @mrshenli @pritamdamania87 @zhaojuanmao @rohan-varma
+/torch/testing/_internal/distributed @mrshenli @pritamdamania87 @zhaojuanmao @rohan-varma

CONTRIBUTING.md

Lines changed: 3 additions & 2 deletions
@@ -825,8 +825,9 @@ static_assert(std::is_same(A*, decltype(A::singleton()))::value, "hmm");
 
 [Clang-Tidy](https://clang.llvm.org/extra/clang-tidy/index.html) is a C++
 linter and static analysis tool based on the clang compiler. We run clang-tidy
-in our CI to make sure that new C++ code is safe, sane and efficient. See our
-[.travis.yml](https://github.com/pytorch/pytorch/blob/master/.travis.yml) file
+in our CI to make sure that new C++ code is safe, sane and efficient. See the
+[`clang-tidy` job in our GitHub Workflow's
+lint.yml file](https://github.com/pytorch/pytorch/blob/master/.github/workflows/lint.yml)
 for the simple commands we use for this.
 
 To run clang-tidy locally, follow these steps:

aten/src/ATen/native/cuda/CompareEQKernel.cu

Lines changed: 2 additions & 4 deletions
@@ -12,10 +12,8 @@ namespace at { namespace native {
 
 void eq_kernel_cuda(TensorIterator& iter) {
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "eq_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "eq_cuda", [&] {
-      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
-        return a == b;
-      });
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
+      return a == b;
     });
   });
 }

aten/src/ATen/native/cuda/CompareGEKernel.cu

Lines changed: 2 additions & 4 deletions
@@ -12,10 +12,8 @@ namespace at { namespace native {
 
 void ge_kernel_cuda(TensorIterator& iter) {
   AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "ge_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "ge_cuda", [&] {
-      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
-        return a >= b;
-      });
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
+      return a >= b;
     });
   });
 }

aten/src/ATen/native/cuda/CompareGTKernel.cu

Lines changed: 2 additions & 4 deletions
@@ -12,10 +12,8 @@ namespace at { namespace native {
 
 void gt_kernel_cuda(TensorIterator& iter) {
   AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "gt_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "gt_cuda", [&] {
-      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
-        return a > b;
-      });
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
+      return a > b;
     });
   });
 }

aten/src/ATen/native/cuda/CompareLEKernel.cu

Lines changed: 2 additions & 4 deletions
@@ -12,10 +12,8 @@ namespace at { namespace native {
 
 void le_kernel_cuda(TensorIterator& iter) {
   AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "le_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "le_cuda", [&] {
-      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
-        return a <= b;
-      });
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
+      return a <= b;
     });
   });
 }

aten/src/ATen/native/cuda/CompareLTKernel.cu

Lines changed: 2 additions & 4 deletions
@@ -12,10 +12,8 @@ namespace at { namespace native {
 
 void lt_kernel_cuda(TensorIterator& iter) {
   AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "lt_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "lt_cuda", [&] {
-      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
-        return a < b;
-      });
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
+      return a < b;
     });
   });
 }
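
All five comparison kernels above receive the same mechanical edit: the AT_SKIP_BFLOAT16_IF_NOT_ROCM wrapper is deleted so the gpu_kernel_with_scalars lambda runs for every dtype the AT_DISPATCH_* macro names, bfloat16 included, on CUDA builds as well as ROCm. A simplified sketch of what such a guard plausibly looked like (illustrative stand-ins only, not the verbatim ATen macro) shows why removing it lifts the restriction:

```cpp
#include <stdexcept>
#include <type_traits>

// Illustrative stand-ins; the real at::BFloat16 and AT_ERROR live in c10/ATen.
namespace at { struct BFloat16 {}; }
#define AT_ERROR(...) throw std::runtime_error("not implemented")

// Simplified sketch of the removed guard: on ROCm builds it just runs the
// wrapped body; on plain CUDA builds it rejects BFloat16 before the body runs.
#if defined(__HIP_PLATFORM_HCC__)
#define AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, NAME, ...) __VA_ARGS__()
#else
#define AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, NAME, ...) \
  if (std::is_same<scalar_t, at::BFloat16>::value) {      \
    AT_ERROR(NAME, " not implemented for 'BFloat16'");    \
  } else {                                                \
    __VA_ARGS__();                                        \
  }
#endif

int main() {
  bool ran = false;
  AT_SKIP_BFLOAT16_IF_NOT_ROCM(float, "eq_cuda", [&] { ran = true; })
  return ran ? 0 : 1;  // float passes the guard; at::BFloat16 would throw
}
```

With the guard gone, the kBFloat16 entry in the dispatch macro is the only gate, so the bfloat16 instantiation compiles and dispatches unconditionally.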

aten/src/ATen/native/cuda/PowKernel.cu

Lines changed: 4 additions & 8 deletions
@@ -110,10 +110,8 @@ void pow_tensor_tensor_kernel(TensorIterator& iter) {
     });
   } else if (isFloatingType(iter.dtype())) {
     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "pow_cuda", [&]() {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "pow_cuda", [&] {
-        gpu_kernel(iter, []GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
-          return pow_(base, exp);
-        });
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
+        return pow_(base, exp);
       });
     });
   } else {
@@ -170,10 +168,8 @@ void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
     });
   } else if (isFloatingType(iter.dtype()) || exp_scalar.isIntegral(false)) {
     AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "pow_cuda", [&]() {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "pow_cuda", [&] {
-        const auto exp = exp_scalar.to<scalar_t>();
-        pow_tensor_scalar_kernel_impl<scalar_t>(iter, exp);
-      });
+      const auto exp = exp_scalar.to<scalar_t>();
+      pow_tensor_scalar_kernel_impl<scalar_t>(iter, exp);
     });
   } else {
     const auto exp = exp_scalar.to<float>();
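
The two pow kernels get the identical unwrapping. The second hunk also shows the scalar-exponent path: the dtype-erased Scalar is converted to the dispatched scalar_t inside the dispatch lambda, and a single templated impl then serves every dtype, bfloat16 now included. A minimal stand-alone analog of that flow (Scalar and pow_impl here are hypothetical stand-ins for at::Scalar and pow_tensor_scalar_kernel_impl, not the real ATen types):

```cpp
#include <cmath>
#include <iostream>

// Stand-in for at::Scalar: a dtype-erased value, converted on demand.
struct Scalar {
  double v;
  template <typename T>
  T to() const { return static_cast<T>(v); }
};

// Stand-in for pow_tensor_scalar_kernel_impl<scalar_t>: one template
// instantiated per dispatched dtype.
template <typename scalar_t>
scalar_t pow_impl(scalar_t base, scalar_t exp) {
  return static_cast<scalar_t>(
      std::pow(static_cast<double>(base), static_cast<double>(exp)));
}

int main() {
  Scalar exp_scalar{2.0};
  const auto exp = exp_scalar.to<float>();  // mirrors exp_scalar.to<scalar_t>()
  std::cout << pow_impl<float>(3.0f, exp) << "\n";  // prints 9
  return 0;
}
```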

aten/src/ATen/record_function.cpp

Lines changed: 25 additions & 12 deletions
@@ -92,11 +92,14 @@ class CallbackManager {
     bool found_needs_ids = false;
     auto init_handles = [
         scope, &found_active_cb, &found_needs_inputs, &found_needs_ids](
-        CallbackHandles& handles, RecordFunctionCallbacks& cbs) {
+        CallbackHandles& handles, RecordFunctionCallbacks& cbs, ObserverContextList& ctx_list) {
       handles.clear();
+
+      size_t num_callbacks = 0;
       for (const auto& cb : cbs) {
         if (cb.first.shouldRun(scope)) {
           handles.push_back(cb.second);
+          ++num_callbacks;
           found_active_cb = true;
           if (cb.first.needsInputs()) {
             found_needs_inputs = true;
@@ -106,10 +109,12 @@
           }
         }
       }
+      // Pre-allocate observer context list with nullptr.
+      ctx_list.resize(num_callbacks);
     };
 
-    init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_);
-    init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_);
+    init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_, rec_fn.tls_ctx_);
+    init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_, rec_fn.global_ctx_);
     rec_fn.active = found_active_cb;
     rec_fn.needs_inputs = found_needs_inputs;
     if (found_needs_ids && found_active_cb) {
@@ -121,11 +126,13 @@
     mergeRunCallbacks(
         sorted_global_callbacks_,
         rf.sorted_active_global_handles_,
+        rf.global_ctx_,
         /* is_start */ true,
         rf);
     mergeRunCallbacks(
         sorted_tls_callbacks_,
         rf.sorted_active_tls_handles_,
+        rf.tls_ctx_,
         /* is_start */ true,
         rf);
     rf.called_start_callbacks_ = true;
@@ -135,21 +142,30 @@
     mergeRunCallbacks(
         sorted_global_callbacks_,
         rf.sorted_active_global_handles_,
+        rf.global_ctx_,
         /* is_start */ false,
         rf);
     mergeRunCallbacks(
         sorted_tls_callbacks_,
         rf.sorted_active_tls_handles_,
+        rf.tls_ctx_,
         /* is_start */ false,
         rf);
   }
 
  private:
   bool tryRunCallback(
-      const std::function<void(const RecordFunction&)>& fn,
-      RecordFunction& rf) {
+      const RecordFunctionCallback& rfcb,
+      RecordFunction& rf,
+      std::unique_ptr<ObserverContext>& ctx,
+      bool is_start) {
     try {
-      fn(rf);
+      if (is_start) {
+        ctx = rfcb.start()(rf);
+      }
+      else {
+        rfcb.end()(rf, ctx.get());
+      }
       return true;
     } catch (const std::exception &e) {
       LOG(WARNING) << "Exception in RecordFunction callback: "
@@ -165,11 +181,12 @@
   void mergeRunCallbacks(
       const RecordFunctionCallbacks& sorted_callbacks,
       const CallbackHandles& sorted_handles,
+      ObserverContextList& ctx_list,
      bool is_start,
       RecordFunction& rf) {
     size_t num_executed = 0;
     size_t idx_c = 0;
-    for (size_t idx_h = 0; idx_h < sorted_handles.size(); ++idx_h) {
+    for (size_t idx_h = 0; idx_h < sorted_handles.size() && idx_h < ctx_list.size(); ++idx_h) {
       while (idx_c < sorted_callbacks.size() &&
              sorted_callbacks[idx_c].second < sorted_handles[idx_h]) {
         ++idx_c;
@@ -178,11 +195,7 @@
         break;
       }
       if (sorted_callbacks[idx_c].second == sorted_handles[idx_h]) {
-        if (is_start) {
-          tryRunCallback(sorted_callbacks[idx_c].first.start(), rf);
-        } else {
-          tryRunCallback(sorted_callbacks[idx_c].first.end(), rf);
-        }
+        tryRunCallback(sorted_callbacks[idx_c].first, rf, ctx_list[idx_h], is_start);
         ++num_executed;
       }
     }
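
The substance of this change: a start callback can now return a per-invocation ObserverContext that is stored positionally and handed back to the matching end callback. init_handles sizes one context slot per active handle, and tryRunCallback takes the whole RecordFunctionCallback plus an is_start flag so it can either populate or consume that slot. A minimal self-contained sketch of the pattern (every type here is an illustrative stand-in for the real ones in ATen/record_function.h):

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

// Illustrative stand-ins for the ATen types.
struct RecordFunction {};

struct ObserverContext {
  virtual ~ObserverContext() = default;
};

struct TimerContext : ObserverContext {
  long start_ns = 0;
};

struct Callback {
  // A start callback may allocate per-invocation state...
  std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start;
  // ...which the matching end callback receives back as a raw pointer.
  std::function<void(const RecordFunction&, ObserverContext*)> end;
};

// One context slot per active callback, so start() results can be routed to
// the corresponding end() even with several observers registered at once.
using ObserverContextList = std::vector<std::unique_ptr<ObserverContext>>;

int main() {
  Callback cb{
      [](const RecordFunction&) {
        auto ctx = std::make_unique<TimerContext>();
        ctx->start_ns = 42;  // e.g. a clock read at op start
        return std::unique_ptr<ObserverContext>(std::move(ctx));
      },
      [](const RecordFunction&, ObserverContext* ctx) {
        auto* t = static_cast<TimerContext*>(ctx);
        std::cout << "started at " << t->start_ns << "\n";
      }};

  RecordFunction rf;
  ObserverContextList ctx_list(1);  // pre-sized, like ctx_list.resize(...)
  ctx_list[0] = cb.start(rf);       // the is_start == true path
  cb.end(rf, ctx_list[0].get());    // the is_start == false path
  return 0;
}
```

Bounding the merge loop by ctx_list.size() as well as sorted_handles.size() keeps the positional lookup ctx_list[idx_h] in range even if the two lists ever disagree in length.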
