
Commit fbf274f

Michael Carilli authored and facebook-github-bot committed
Autocast support for cudnn RNNs (#42385)
Summary: Should close #36428.

The cudnn RNN API expects weights to occupy a flat buffer in memory with a particular layout. This PR implements a "speed of light" fix: [`_cudnn_rnn_cast_reflatten`](https://github.com/pytorch/pytorch/pull/42385/files#diff-9ef93b6a4fb5a06a37c562b83737ac6aR327) (the autocast wrapper assigned to `_cudnn_rnn`) copies weights to the right slices of a flat FP16 buffer with a single read/write per weight (as opposed to casting them to FP16 individually and then reflattening the individual FP16 weights, which would require two reads/writes per weight). It isn't pretty, but IMO it doesn't make the rnn bindings much more tortuous than they already are.

The [test](https://github.com/pytorch/pytorch/pull/42385/files#diff-e68a7bc6ba14f212e5e7eb3727394b40R2683) tries a forward under autocast and a backward for the full cross product of RNN options and input/weight/hidden dtypes. As for all FP16-list autocast tests, forward output and backward grads are checked against a control where inputs (including the RNN module weights in this case) are precast to FP16 on the Python side.

Not sure who to ask for review; tagging ezyang and ngimel because Ed wrote this file (almost 2 years ago) and Natalia did the most recent major [surgery](#12600).

Side quests discovered:
- Should we update [persistent RNN heuristics](https://github.com/pytorch/pytorch/blob/dbdd28207c5cf6c4a35ceb1de0811c4812e8882c/aten/src/ATen/native/cudnn/RNN.cpp#L584) to include compute capability 8.0? Could be another PR, but it seems easy enough to include here.
- Many (maybe all?!) of the raw cudnn API calls in [RNN.cpp](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp) are deprecated in cudnn 8. I don't mind taking the AI to update them since my mental cache is full of rnn stuff, but that would be a substantial separate PR.

Pull Request resolved: #42385

Reviewed By: zhangguanheng66

Differential Revision: D23077782

Pulled By: ezyang

fbshipit-source-id: a2afb1bdab33ba0442879a703df13dc87f03ec2e
1 parent 0a9c35a commit fbf274f
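
For illustration only, here is a minimal Python sketch of the kind of check the summary describes: run a cudnn LSTM forward and backward under autocast, then compare against a control where the module weights and input are precast to FP16 by hand. The sizes, seed, and tolerances are invented for this sketch; it is not the PR's actual test.

```python
# Hypothetical sketch of the autocast-vs-precast-FP16 comparison described above.
import copy
import torch

torch.manual_seed(0)
rnn = torch.nn.LSTM(input_size=16, hidden_size=32, num_layers=2).cuda()
x = torch.randn(5, 3, 16, device="cuda", requires_grad=True)

# Autocast path: FP32 module and input; the cudnn RNN runs in FP16, with the
# weights copied into a single flat FP16 buffer internally.
with torch.cuda.amp.autocast():
    out_autocast, _ = rnn(x)
out_autocast.float().sum().backward()

# Control path: cast the module weights and the input to FP16 on the Python side.
rnn_fp16 = copy.deepcopy(rnn).half()
x_fp16 = x.detach().clone().half().requires_grad_(True)
out_control, _ = rnn_fp16(x_fp16)
out_control.float().sum().backward()

# Forward outputs and input grads should agree to FP16-level tolerance.
torch.testing.assert_allclose(out_autocast.float(), out_control.float(), rtol=1e-3, atol=1e-3)
torch.testing.assert_allclose(x.grad, x_fp16.grad.float(), rtol=1e-3, atol=1e-3)
```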

File tree: 12 files changed (+531, −229 lines)

BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -340,6 +340,7 @@ filegroup(
     "aten/src/ATen/cuda/CublasHandlePool.cpp",
     "aten/src/ATen/cuda/PinnedMemoryAllocator.cpp",
     "aten/src/ATen/cuda/detail/CUDAHooks.cpp",
+    "aten/src/ATen/cudnn/AutocastRNN.cpp",
     "aten/src/ATen/cudnn/Descriptors.cpp",
     "aten/src/ATen/cudnn/Handle.cpp",
     "aten/src/ATen/cudnn/Types.cpp",

aten/src/ATen/autocast_mode.cpp

Lines changed: 29 additions & 145 deletions
@@ -59,90 +59,10 @@ int decrement_nesting() {
   return --nesting;
 }
 
-// Policies correspond to op categories that need code-divergent handling.
-// Wrapper templates below are specialized based on a policy template parameter.
-enum class CastPolicy : uint8_t {
-  fp16 = 0, // Cast all inputs to at::kHalf before running the op.
-  fp32, // Cast all inputs to at::kFloat before running the op.
-  fp32_set_opt_dtype, // Treats functions (like softmax) that
-    // 1. we'd like to run in fp32 and
-    // 2. have a c10::optional<ScalarType> arg that controls the output type.
-    // fp32_set_opt_dtype wrappers' policy is: if the output type is already set,
-    // don't touch it, otherwise, set it to at::kFloat.
-  fp32_append_dtype, // Treats functions (like norm) that
-    // 1. we'd like to run in fp32 and
-    // 2. have some overloads that accept an output type and other overloads that don't.
-    // fp32_append_dtype wrappers wrap the overloads that don't have an output dtype.
-    // The wrapper policy is: append at::kFloat to the args, and redispatch to the
-    // type-aware overload.
-  promote, // Run in the widest dtype among several args.
-};
-
-/********************************************************************
-Logic to extract the promote type from any Tensor or TensorList args.
-********************************************************************/
-
-// Overload to catch Tensor args.
-// If nextArg is floating-point, compare its scalar_type with our
-// current best guess for the promote type, and update if necessary.
-inline at::ScalarType prioritize(at::ScalarType current, const Tensor& nextArg) {
-  if (current == at::kDouble) {
-    AT_ERROR("promote type is double in at::autocast::prioritize");
-    return current;
-  }
-  if (nextArg.is_cuda() && nextArg.is_floating_point()) {
-    auto next = nextArg.scalar_type();
-    if (next == at::kDouble) {
-      return current; // ignores double tensors
-    } else if (current == at::kFloat || next == at::kFloat) {
-      return at::kFloat; // prioritizes float over half
-    } else if (current == at::kHalf && next == at::kHalf) {
-      return at::kHalf;
-    } else {
-      AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
-      return current;
-    }
-  } else {
-    return current;
-  }
-}
-
-// Overload to catch TensorList args (for e.g. cat, stack).
-// Reuses the overload above to process each Tensor in the list.
-inline at::ScalarType prioritize(at::ScalarType current, const TensorList& list) {
-  for (const auto& tensor : list) {
-    current = prioritize(current, tensor);
-  }
-  return current;
-}
-
-// Template to catch non-Tensor args (no-op that returns current best guess)
-template<typename T>
-inline at::ScalarType prioritize(at::ScalarType current, T nextArg) {
-  return current;
-}
-
-// Overload for the tail case.
-inline at::ScalarType promote_type(at::ScalarType current) {
-  return current;
-}
-
-// Unpack args and determine if incoming float16 tensors need to be promoted to float32.
-// Non-Tensor arguments are ignored.
-template<typename Arg0, typename... Args>
-inline at::ScalarType promote_type(at::ScalarType current, Arg0 arg0, Args... args) {
-  auto new_current = prioritize(current, arg0);
-  return promote_type(new_current, args...);
-}
-
-/****************************************************
-Logic to apply cached casting to any Tensor argument.
-****************************************************/
-inline bool is_eligible(const Tensor& arg) {
-  return (arg.is_cuda() && arg.is_floating_point() && (arg.scalar_type() != at::kDouble));
-}
-
 // Overload to catch Tensor args
+// TODO (possible optimization): Move cast_cache to an inline function in a header
+// (+ refactor the can_try_cache branch to call a small non-inline helper function.
+// can_try_cache branch is the only part that's hard to inline in other files).
 Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) {
   if (is_eligible(arg) && (arg.scalar_type() != to_type)) {
     // Heuristic: Do what Apex does, and cache fp16 casts of fp32 model weights (leaves).
@@ -165,61 +85,24 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) {
   }
 }
 
-// Overload to process optional<Tensor>
-c10::optional<Tensor> cached_cast(at::ScalarType to_type, const c10::optional<Tensor>& arg) {
-  if (arg.has_value()) {
-    return cached_cast(to_type, *arg);
-  } else {
-    return c10::nullopt;
-  }
-}
-
-// Overload to process TensorLists
-std::vector<Tensor> cached_cast(at::ScalarType to_type, const TensorList& arg) {
-  std::vector<Tensor> vec;
-  vec.reserve(arg.size());
-  for (const auto& t : arg) {
-    vec.push_back(cached_cast(to_type, t));
-  }
-  return vec;
-}
-
-// Template to catch non-Tensor args.
-template<typename T>
-T cached_cast(at::ScalarType to_type, T arg) {
-  return arg;
-}
-
-/*******************************************************
-Logic to flip an output dtype flag.
-Keep it simple for now by assuming only one such flag is
-present in the argument list. If I ever need a function
-with more than flag I'll figure out something else.
-The policy is:
-If the user has explicity specified a dtype, respect it.
-Otherwise, set it to the autocast type.
-********************************************************/
-
-// Overload to catch dtype flags
-c10::optional<ScalarType> set_opt_dtype(at::ScalarType to_type, const c10::optional<ScalarType>& dtype) {
-  return dtype.has_value() ? dtype : to_type;
-}
-
-// Template to catch other args
-template<typename T>
-inline T set_opt_dtype(at::ScalarType to_type, T arg) {
-  return arg;
-}
-
-template<typename... Args>
-inline bool firstarg_is_eligible(const Tensor& arg, Args... args) {
-  return is_eligible(arg);
-}
-
-template<typename... Args>
-inline at::ScalarType type_from_firstarg(at::ScalarType to_type, const Tensor& arg, Args... args) {
-  return (is_eligible(arg) ? to_type : arg.scalar_type());
-}
+// Policies correspond to op categories that need code-divergent handling.
+// Wrapper templates below are specialized based on a policy template parameter.
+enum class CastPolicy : uint8_t {
+  fp16 = 0, // Cast all inputs to at::kHalf before running the op.
+  fp32, // Cast all inputs to at::kFloat before running the op.
+  fp32_set_opt_dtype, // Treats functions (like softmax) that
+    // 1. we'd like to run in fp32 and
+    // 2. have a c10::optional<ScalarType> arg that controls the output type.
+    // fp32_set_opt_dtype wrappers' policy is: if the output type is already set,
+    // don't touch it, otherwise, set it to at::kFloat.
+  fp32_append_dtype, // Treats functions (like norm) that
+    // 1. we'd like to run in fp32 and
+    // 2. have some overloads that accept an output type and other overloads that don't.
+    // fp32_append_dtype wrappers wrap the overloads that don't have an output dtype.
+    // The wrapper policy is: append at::kFloat to the args, and redispatch to the
+    // type-aware overload.
+  promote, // Run in the widest dtype among several args.
+};
 
 /********************************************************************************************************
 Templates to provide wrapper functions
@@ -239,7 +122,7 @@ template<CastPolicy policy, class Redispatch, Redispatch* F, class Ret, class Ar
 template<class Redispatch, Redispatch* F, class Ret, class... Args>
 struct WrapFunction_<CastPolicy::fp16, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
   static Ret call(Args... args) {
-    c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
     return (*F)(cached_cast(at::kHalf, args)...);
   }
 };
@@ -248,7 +131,7 @@ struct WrapFunction_<CastPolicy::fp16, Redispatch, F, Ret, guts::typelist::typel
 template<class Redispatch, Redispatch* F, class Ret, class... Args>
 struct WrapFunction_<CastPolicy::fp32, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
   static Ret call(Args... args) {
-    c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
     return (*F)(cached_cast(at::kFloat, args)...);
   }
 };
@@ -257,7 +140,7 @@ struct WrapFunction_<CastPolicy::fp32, Redispatch, F, Ret, guts::typelist::typel
 template<class Redispatch, Redispatch* F, class Ret, class... Args>
 struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
   static Ret call(Args... args) {
-    c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
     if (firstarg_is_eligible(args...)) {
       return (*F)(set_opt_dtype(at::kFloat, args)...);
     } else {
@@ -272,7 +155,7 @@ struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, Redispatch, F, Ret, guts::t
 template<class Redispatch, Redispatch* F, class Ret, class... Args>
 struct WrapFunction_<CastPolicy::fp32_append_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
   static Ret call(Args... args) {
-    c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
     at::ScalarType out_type = type_from_firstarg(at::kFloat, args...);
     return (*F)(args..., out_type);
   }
@@ -282,7 +165,7 @@ struct WrapFunction_<CastPolicy::fp32_append_dtype, Redispatch, F, Ret, guts::ty
 template<class Redispatch, Redispatch* F, class Ret, class... Args>
 struct WrapFunction_<CastPolicy::promote, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
   static Ret call(Args... args) {
-    c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
     auto to_type = promote_type(at::kHalf, args...);
     return (*F)(cached_cast(to_type, args)...);
   }
@@ -319,6 +202,7 @@ Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const c10::op
               "safe to autocast.");
 }
 
+
 #ifndef USE_STATIC_DISPATCH
 namespace {
 /*****************************************************************************************************************
@@ -422,7 +306,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   KERNEL(ADD_NS(layer_norm), "layer_norm", Tensor (const Tensor &, IntArrayRef, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double, bool), fp32)
   // The macro doesn't like this one so I had to write it out manually.
   m.impl("native_layer_norm",
-         TORCH_FN((&WrapFunction<CastPolicy::fp32, std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double), std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double), &ADD_NS(native_layer_norm)>::type::call)));
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double), std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double), &ADD_NS(native_layer_norm)>::type::call)));
   KERNEL(ADD_NS(group_norm), "group_norm", Tensor (const Tensor &, int64_t, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double, bool), fp32)
   KERNEL(ADD_NS(frobenius_norm), "frobenius_norm", Tensor (const Tensor &), fp32)
   KERNEL(ADD_NS(frobenius_norm), "frobenius_norm.dim", Tensor (const Tensor &, IntArrayRef, bool), fp32)
@@ -490,7 +374,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote)
 
   m.impl("binary_cross_entropy",
-         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
+         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
 }
 
 }
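
As a rough Python-level illustration of how the CastPolicy categories registered above behave (my own sketch, not part of this diff; the exact op-to-policy assignments can vary between releases):

```python
# Illustrative only: observable behavior of the fp16, fp32_set_opt_dtype, and
# promote policies from the Python API.
import torch

a = torch.randn(4, 4, device="cuda")          # float32
b = torch.randn(4, 4, device="cuda").half()   # float16

with torch.cuda.amp.autocast():
    mm = torch.mm(a, a)                                   # fp16 policy: inputs cast to half
    sm = torch.softmax(b, dim=1)                          # fp32_set_opt_dtype: runs and returns in float
    sm_half = torch.softmax(b, dim=1, dtype=torch.half)   # an explicitly requested dtype is respected
    cat = torch.cat([a, b])                               # promote: widest input dtype (float) wins

print(mm.dtype, sm.dtype, sm_half.dtype, cat.dtype)
# torch.float16 torch.float32 torch.float16 torch.float32
```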

aten/src/ATen/autocast_mode.h

Lines changed: 123 additions & 0 deletions
@@ -9,5 +9,128 @@ TORCH_API void clear_cache();
 TORCH_API int increment_nesting();
 TORCH_API int decrement_nesting();
 
+/********************************************************************
+Logic to extract the promote type from any Tensor or TensorList args.
+********************************************************************/
+
+// Overload to catch Tensor args.
+// If nextArg is floating-point, compare its scalar_type with our
+// current best guess for the promote type, and update if necessary.
+inline at::ScalarType prioritize(at::ScalarType current, const Tensor& nextArg) {
+  if (current == at::kDouble) {
+    AT_ERROR("promote type is double in at::autocast::prioritize");
+    return current;
+  }
+  if (nextArg.is_cuda() && nextArg.is_floating_point()) {
+    auto next = nextArg.scalar_type();
+    if (next == at::kDouble) {
+      return current; // ignores double tensors
+    } else if (current == at::kFloat || next == at::kFloat) {
+      return at::kFloat; // prioritizes float over half
+    } else if (current == at::kHalf && next == at::kHalf) {
+      return at::kHalf;
+    } else {
+      AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
+      return current;
+    }
+  } else {
+    return current;
+  }
+}
+
+// Overload to catch TensorList args (for e.g. cat, stack).
+// Reuses the overload above to process each Tensor in the list.
+inline at::ScalarType prioritize(at::ScalarType current, const TensorList& list) {
+  for (const auto& tensor : list) {
+    current = prioritize(current, tensor);
+  }
+  return current;
+}
+
+// Template to catch non-Tensor args (no-op that returns current best guess)
+template<typename T>
+inline at::ScalarType prioritize(at::ScalarType current, T nextArg) {
+  return current;
+}
+
+// Overload for the tail case.
+inline at::ScalarType promote_type(at::ScalarType current) {
+  return current;
+}
+
+// Unpack args and determine if incoming float16 tensors need to be promoted to float32.
+// Non-Tensor arguments are ignored.
+template<typename Arg0, typename... Args>
+inline at::ScalarType promote_type(at::ScalarType current, Arg0 arg0, Args... args) {
+  auto new_current = prioritize(current, arg0);
+  return promote_type(new_current, args...);
+}
+
+/****************************************************
+Logic to apply cached casting to any Tensor argument.
+****************************************************/
+inline bool is_eligible(const Tensor& arg) {
+  return (arg.defined() && arg.is_cuda() && arg.is_floating_point() && (arg.scalar_type() != at::kDouble));
+}
+
+// Overload to catch Tensor args
+TORCH_API Tensor cached_cast(at::ScalarType to_type, const Tensor& arg);
+
+// Overload to process optional<Tensor>
+inline c10::optional<Tensor> cached_cast(at::ScalarType to_type, const c10::optional<Tensor>& arg) {
+  if (arg.has_value()) {
+    return cached_cast(to_type, *arg);
+  } else {
+    return c10::nullopt;
+  }
+}
+
+// Overload to process TensorLists
+inline std::vector<Tensor> cached_cast(at::ScalarType to_type, const TensorList& arg) {
+  std::vector<Tensor> vec;
+  vec.reserve(arg.size());
+  for (const auto& t : arg) {
+    vec.push_back(cached_cast(to_type, t));
+  }
+  return vec;
+}
+
+// Template to catch non-Tensor args.
+template<typename T>
+inline T cached_cast(at::ScalarType to_type, T arg) {
+  return arg;
+}
+
+/*******************************************************
+Logic to flip an output dtype flag.
+Keep it simple for now by assuming only one such flag is
+present in the argument list. If I ever need a function
+with more than flag I'll figure out something else.
+The policy is:
+If the user has explicity specified a dtype, respect it.
+Otherwise, set it to the autocast type.
+********************************************************/
+
+// Overload to catch dtype flags
+c10::optional<ScalarType> inline set_opt_dtype(at::ScalarType to_type, const c10::optional<ScalarType>& dtype) {
+  return dtype.has_value() ? dtype : to_type;
+}
+
+// Template to catch other args
+template<typename T>
+inline T set_opt_dtype(at::ScalarType to_type, T arg) {
+  return arg;
+}
+
+template<typename... Args>
+inline bool firstarg_is_eligible(const Tensor& arg, Args... args) {
+  return is_eligible(arg);
+}
+
+template<typename... Args>
+inline at::ScalarType type_from_firstarg(at::ScalarType to_type, const Tensor& arg, Args... args) {
+  return (is_eligible(arg) ? to_type : arg.scalar_type());
+}
+
 } // namespace autocast
 } // namespace at

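A small Python-level sketch of what `is_eligible` above implies (again my own example, not part of the diff): only CUDA floating-point tensors that are not double participate in autocasting; double and CPU tensors pass through untouched.

```python
# Illustrative only: tensors rejected by is_eligible() are left alone by autocast.
import torch

w = torch.randn(8, 8, device="cuda")                      # eligible: cast to half for fp16 ops
d = torch.randn(8, 8, device="cuda", dtype=torch.double)  # double: ignored by autocast
c = torch.randn(8, 8)                                     # CPU: ignored by (CUDA) autocast

with torch.cuda.amp.autocast():
    print(torch.mm(w, w).dtype)   # torch.float16
    print(torch.mm(d, d).dtype)   # torch.float64
    print(torch.mm(c, c).dtype)   # torch.float32
```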