@@ -59,90 +59,10 @@ int decrement_nesting() {
   return --nesting;
 }
 
-// Policies correspond to op categories that need code-divergent handling.
-// Wrapper templates below are specialized based on a policy template parameter.
-enum class CastPolicy : uint8_t {
-  fp16 = 0, // Cast all inputs to at::kHalf before running the op.
-  fp32, // Cast all inputs to at::kFloat before running the op.
-  fp32_set_opt_dtype, // Treats functions (like softmax) that
-                      //   1. we'd like to run in fp32 and
-                      //   2. have a c10::optional<ScalarType> arg that controls the output type.
-                      // fp32_set_opt_dtype wrappers' policy is: if the output type is already set,
-                      // don't touch it, otherwise, set it to at::kFloat.
-  fp32_append_dtype, // Treats functions (like norm) that
-                     //   1. we'd like to run in fp32 and
-                     //   2. have some overloads that accept an output type and other overloads that don't.
-                     // fp32_append_dtype wrappers wrap the overloads that don't have an output dtype.
-                     // The wrapper policy is: append at::kFloat to the args, and redispatch to the
-                     // type-aware overload.
-  promote, // Run in the widest dtype among several args.
-};
-
-/********************************************************************
-Logic to extract the promote type from any Tensor or TensorList args.
-********************************************************************/
-
-// Overload to catch Tensor args.
-// If nextArg is floating-point, compare its scalar_type with our
-// current best guess for the promote type, and update if necessary.
-inline at::ScalarType prioritize(at::ScalarType current, const Tensor& nextArg) {
-  if (current == at::kDouble) {
-    AT_ERROR("promote type is double in at::autocast::prioritize");
-    return current;
-  }
-  if (nextArg.is_cuda() && nextArg.is_floating_point()) {
-    auto next = nextArg.scalar_type();
-    if (next == at::kDouble) {
-      return current; // ignores double tensors
-    } else if (current == at::kFloat || next == at::kFloat) {
-      return at::kFloat; // prioritizes float over half
-    } else if (current == at::kHalf && next == at::kHalf) {
-      return at::kHalf;
-    } else {
-      AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
-      return current;
-    }
-  } else {
-    return current;
-  }
-}
-
-// Overload to catch TensorList args (for e.g. cat, stack).
-// Reuses the overload above to process each Tensor in the list.
-inline at::ScalarType prioritize(at::ScalarType current, const TensorList& list) {
-  for (const auto& tensor : list) {
-    current = prioritize(current, tensor);
-  }
-  return current;
-}
-
-// Template to catch non-Tensor args (no-op that returns current best guess)
-template<typename T>
-inline at::ScalarType prioritize(at::ScalarType current, T nextArg) {
-  return current;
-}
-
-// Overload for the tail case.
-inline at::ScalarType promote_type(at::ScalarType current) {
-  return current;
-}
-
-// Unpack args and determine if incoming float16 tensors need to be promoted to float32.
-// Non-Tensor arguments are ignored.
-template<typename Arg0, typename... Args>
-inline at::ScalarType promote_type(at::ScalarType current, Arg0 arg0, Args... args) {
-  auto new_current = prioritize(current, arg0);
-  return promote_type(new_current, args...);
-}
-
-/****************************************************
-Logic to apply cached casting to any Tensor argument.
-****************************************************/
-inline bool is_eligible(const Tensor& arg) {
-  return (arg.is_cuda() && arg.is_floating_point() && (arg.scalar_type() != at::kDouble));
-}
-
 // Overload to catch Tensor args
+// TODO (possible optimization): Move cast_cache to an inline function in a header
+// (and refactor the can_try_cache branch to call a small non-inline helper function;
+// the can_try_cache branch is the only part that's hard to inline in other files).
 Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) {
   if (is_eligible(arg) && (arg.scalar_type() != to_type)) {
     // Heuristic: Do what Apex does, and cache fp16 casts of fp32 model weights (leaves).
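The body of the cache branch is elided between these two hunks. As a rough standalone illustration of the heuristic named in the comment above (convert an fp32 weight once, reuse the converted copy on later calls), here is a sketch in plain C++. FakeTensor, to_half, and cached_cast_to_half are names invented for the example, and uint16_t merely stands in for a half-precision type; none of this is the actual ATen implementation.

#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include <vector>

// Stand-in for a tensor: a buffer of floats with a stable address.
struct FakeTensor {
  std::vector<float> data;
};

// Cache keyed by the source buffer's address, so repeated requests to cast the
// same fp32 "weight" reuse the previously converted copy.
static std::unordered_map<const FakeTensor*, std::vector<uint16_t>> cast_cache;

// Pretend fp32 -> fp16 conversion (truncation stands in for a real cast).
static std::vector<uint16_t> to_half(const FakeTensor& t) {
  std::vector<uint16_t> out;
  out.reserve(t.data.size());
  for (float v : t.data) out.push_back(static_cast<uint16_t>(v));
  return out;
}

const std::vector<uint16_t>& cached_cast_to_half(const FakeTensor& weight) {
  auto it = cast_cache.find(&weight);
  if (it == cast_cache.end()) {
    it = cast_cache.emplace(&weight, to_half(weight)).first; // first use: convert and remember
  }
  return it->second; // later uses: cheap lookup, no re-conversion
}

int main() {
  FakeTensor w{{1.0f, 2.0f, 3.0f}};
  cached_cast_to_half(w);                                   // converts
  cached_cast_to_half(w);                                   // served from the cache
  std::printf("cache entries: %zu\n", cast_cache.size());   // prints 1
  cast_cache.clear();                                       // drop cached copies when leaving autocast
}

The sketch captures only the memoization idea; eligibility checks and cache lifetime are handled by the surrounding code shown in this diff.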
@@ -165,61 +85,24 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) {
   }
 }
 
-// Overload to process optional<Tensor>
-c10::optional<Tensor> cached_cast(at::ScalarType to_type, const c10::optional<Tensor>& arg) {
-  if (arg.has_value()) {
-    return cached_cast(to_type, *arg);
-  } else {
-    return c10::nullopt;
-  }
-}
-
-// Overload to process TensorLists
-std::vector<Tensor> cached_cast(at::ScalarType to_type, const TensorList& arg) {
-  std::vector<Tensor> vec;
-  vec.reserve(arg.size());
-  for (const auto& t : arg) {
-    vec.push_back(cached_cast(to_type, t));
-  }
-  return vec;
-}
-
-// Template to catch non-Tensor args.
-template<typename T>
-T cached_cast(at::ScalarType to_type, T arg) {
-  return arg;
-}
-
-/*******************************************************
-Logic to flip an output dtype flag.
-Keep it simple for now by assuming only one such flag is
-present in the argument list. If I ever need a function
-with more than one flag I'll figure out something else.
-The policy is:
-If the user has explicitly specified a dtype, respect it.
-Otherwise, set it to the autocast type.
-********************************************************/
-
-// Overload to catch dtype flags
-c10::optional<ScalarType> set_opt_dtype(at::ScalarType to_type, const c10::optional<ScalarType>& dtype) {
-  return dtype.has_value() ? dtype : to_type;
-}
-
-// Template to catch other args
-template<typename T>
-inline T set_opt_dtype(at::ScalarType to_type, T arg) {
-  return arg;
-}
-
-template<typename... Args>
-inline bool firstarg_is_eligible(const Tensor& arg, Args... args) {
-  return is_eligible(arg);
-}
-
-template<typename... Args>
-inline at::ScalarType type_from_firstarg(at::ScalarType to_type, const Tensor& arg, Args... args) {
-  return (is_eligible(arg) ? to_type : arg.scalar_type());
-}
+// Policies correspond to op categories that need code-divergent handling.
+// Wrapper templates below are specialized based on a policy template parameter.
+enum class CastPolicy : uint8_t {
+  fp16 = 0, // Cast all inputs to at::kHalf before running the op.
+  fp32, // Cast all inputs to at::kFloat before running the op.
+  fp32_set_opt_dtype, // Treats functions (like softmax) that
+                      //   1. we'd like to run in fp32 and
+                      //   2. have a c10::optional<ScalarType> arg that controls the output type.
+                      // fp32_set_opt_dtype wrappers' policy is: if the output type is already set,
+                      // don't touch it, otherwise, set it to at::kFloat.
+  fp32_append_dtype, // Treats functions (like norm) that
+                     //   1. we'd like to run in fp32 and
+                     //   2. have some overloads that accept an output type and other overloads that don't.
+                     // fp32_append_dtype wrappers wrap the overloads that don't have an output dtype.
+                     // The wrapper policy is: append at::kFloat to the args, and redispatch to the
+                     // type-aware overload.
+  promote, // Run in the widest dtype among several args.
+};
 
 /********************************************************************************************************
 Templates to provide wrapper functions
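As context for the fp32_set_opt_dtype policy above: the removed set_opt_dtype helpers implement "keep an explicitly requested dtype, otherwise fill in the type autocast wants". A minimal standalone sketch of that overload pattern, using std::optional and an invented ScalarType enum in place of the c10 types:

#include <cstdio>
#include <optional>

enum class ScalarType { Half, Float };

// Overload for the dtype flag: keep a user-supplied dtype, otherwise fill in
// the type autocast wants the op to run in.
std::optional<ScalarType> set_opt_dtype(ScalarType to_type, const std::optional<ScalarType>& dtype) {
  return dtype.has_value() ? dtype : std::optional<ScalarType>(to_type);
}

// Catch-all for every other argument: pass it through untouched, so a wrapper
// can apply set_opt_dtype to an entire argument pack uniformly.
template <typename T>
T set_opt_dtype(ScalarType /*to_type*/, T arg) {
  return arg;
}

int main() {
  std::optional<ScalarType> unset;                      // user did not pick a dtype
  std::optional<ScalarType> picked = ScalarType::Half;  // user picked one explicitly
  auto a = set_opt_dtype(ScalarType::Float, unset);     // -> Float (autocast's choice)
  auto b = set_opt_dtype(ScalarType::Float, picked);    // -> Half  (user's choice wins)
  int dim = set_opt_dtype(ScalarType::Float, 3);        // non-dtype args pass through
  std::printf("%d %d %d\n", static_cast<int>(*a), static_cast<int>(*b), dim);
}

The same trick, one specific overload plus a pass-through template, is what lets the removed cached_cast overloads above handle Tensor, optional<Tensor>, TensorList, and everything else from a single call site.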
@@ -239,7 +122,7 @@ template<CastPolicy policy, class Redispatch, Redispatch* F, class Ret, class Ar
 template<class Redispatch, Redispatch* F, class Ret, class... Args>
 struct WrapFunction_<CastPolicy::fp16, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
   static Ret call(Args... args) {
-    c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
     return (*F)(cached_cast(at::kHalf, args)...);
   }
 };
@@ -248,7 +131,7 @@ struct WrapFunction_<CastPolicy::fp16, Redispatch, F, Ret, guts::typelist::typel
 template<class Redispatch, Redispatch* F, class Ret, class... Args>
 struct WrapFunction_<CastPolicy::fp32, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
   static Ret call(Args... args) {
-    c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
     return (*F)(cached_cast(at::kFloat, args)...);
   }
 };
@@ -257,7 +140,7 @@ struct WrapFunction_<CastPolicy::fp32, Redispatch, F, Ret, guts::typelist::typel
 template<class Redispatch, Redispatch* F, class Ret, class... Args>
 struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
   static Ret call(Args... args) {
-    c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
     if (firstarg_is_eligible(args...)) {
       return (*F)(set_opt_dtype(at::kFloat, args)...);
     } else {
@@ -272,7 +155,7 @@ struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, Redispatch, F, Ret, guts::t
 template<class Redispatch, Redispatch* F, class Ret, class... Args>
 struct WrapFunction_<CastPolicy::fp32_append_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
   static Ret call(Args... args) {
-    c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
     at::ScalarType out_type = type_from_firstarg(at::kFloat, args...);
     return (*F)(args..., out_type);
   }
@@ -282,7 +165,7 @@ struct WrapFunction_<CastPolicy::fp32_append_dtype, Redispatch, F, Ret, guts::ty
 template<class Redispatch, Redispatch* F, class Ret, class... Args>
 struct WrapFunction_<CastPolicy::promote, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
   static Ret call(Args... args) {
-    c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
     auto to_type = promote_type(at::kHalf, args...);
     return (*F)(cached_cast(to_type, args)...);
   }
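The promote specialization above combines two pieces that also appear in the helpers this diff removes: a variadic fold (promote_type/prioritize) in which only tensor arguments influence the promoted dtype, and pack expansion that runs every argument through cached_cast before calling the redispatched function. Below is a self-contained sketch of that overall shape, with invented stand-ins (DType, Tens, cast_to, run_promoted) rather than the real ATen types; it simplifies the dtype lattice to Half-vs-Float and omits the double-precision error checks.

#include <cstdio>

// Simplified stand-ins: a "tensor" is just a value tagged with a dtype.
enum class DType { Half, Float };
struct Tens { DType dtype; float value; };

// Fold step: widen the running promote type when a Float tensor shows up.
DType prioritize(DType current, const Tens& t) {
  return (current == DType::Float || t.dtype == DType::Float) ? DType::Float : DType::Half;
}
// Non-tensor arguments never affect the promote type.
template <typename T>
DType prioritize(DType current, const T&) { return current; }

// Variadic fold over the argument pack, starting from Half.
DType promote_type(DType current) { return current; }
template <typename Arg0, typename... Args>
DType promote_type(DType current, const Arg0& a0, const Args&... rest) {
  return promote_type(prioritize(current, a0), rest...);
}

// Per-argument transform: cast tensors to the promote type, pass others through.
Tens cast_to(DType to, const Tens& t) { return Tens{to, t.value}; }
template <typename T>
T cast_to(DType, const T& arg) { return arg; }

// Wrapper in the spirit of WrapFunction_<CastPolicy::promote, ...>::call: compute
// the widest dtype among the args, then call the wrapped function with every
// argument run through cast_to via pack expansion.
template <typename F, typename... Args>
auto run_promoted(F f, Args... args) {
  DType to = promote_type(DType::Half, args...);
  return f(cast_to(to, args)...);
}

int main() {
  auto add = [](const Tens& a, const Tens& b, int scale) {
    std::printf("dtypes seen by the op: %d %d\n",
                static_cast<int>(a.dtype), static_cast<int>(b.dtype));
    return (a.value + b.value) * scale;
  };
  Tens a_half{DType::Half, 1.0f};
  Tens a_float{DType::Float, 2.0f};
  run_promoted(add, a_half, a_float, 3);  // both tensors arrive as Float
  return 0;
}

In the real wrappers the fold starts from at::kHalf and the per-argument transform is cached_cast, exactly as in the specialization above; the ExcludeDispatchKeyGuard redispatch step has no analogue in the sketch.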
@@ -319,6 +202,7 @@ Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const c10::op
            "safe to autocast.");
 }
 
+
 #ifndef USE_STATIC_DISPATCH
 namespace {
 /******************************************************************************************************************
@@ -422,7 +306,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   KERNEL(ADD_NS(layer_norm), "layer_norm", Tensor (const Tensor &, IntArrayRef, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double, bool), fp32)
   // The macro doesn't like this one so I had to write it out manually.
   m.impl("native_layer_norm",
-       TORCH_FN((&WrapFunction<CastPolicy::fp32, std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double), std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double), &ADD_NS(native_layer_norm)>::type::call)));
+       TORCH_FN((&WrapFunction<CastPolicy::fp32, std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double), std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double), &ADD_NS(native_layer_norm)>::type::call)));
   KERNEL(ADD_NS(group_norm), "group_norm", Tensor (const Tensor &, int64_t, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double, bool), fp32)
   KERNEL(ADD_NS(frobenius_norm), "frobenius_norm", Tensor (const Tensor &), fp32)
   KERNEL(ADD_NS(frobenius_norm), "frobenius_norm.dim", Tensor (const Tensor &, IntArrayRef, bool), fp32)
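Judging from the manually expanded native_layer_norm registration a few lines up, each KERNEL(...) line appears to boil down to a single m.impl(name, TORCH_FN(&WrapFunction<policy, sig, sig, &op>::type::call)) call; the macro's actual definition is not part of this diff, so that reading is an inference. The toy program below mimics only the "wrap the op, then register it under a name" step, with an invented KERNEL_LIKE macro and a plain std::map standing in for the dispatcher:

#include <cstdio>
#include <functional>
#include <map>
#include <string>

// Toy registry standing in for the dispatcher table behind m.impl:
// op name -> type-erased kernel.
static std::map<std::string, std::function<float(float)>> registry;

// Toy "WrapFunction": scales the input before calling the wrapped op, standing
// in for the cast-then-redispatch wrappers defined earlier in this file.
template <float (*F)(float)>
float wrapped(float x) { return F(x * 0.5f); }

// A KERNEL-like macro: registers the wrapped form of FUNC under NAME, the way
// each KERNEL(ADD_NS(op), "op", sig, policy) line registers one wrapper.
#define KERNEL_LIKE(FUNC, NAME) registry[NAME] = &wrapped<&FUNC>;

float relu(float x) { return x > 0.0f ? x : 0.0f; }

int main() {
  KERNEL_LIKE(relu, "relu")                      // analogous to one KERNEL(...) line
  std::printf("%f\n", registry["relu"](8.0f));   // runs the wrapper, then relu -> 4.0
  return 0;
}

The real KERNEL lines also carry the op's full signature and a CastPolicy, which presumably select the WrapFunction specialization to instantiate; the toy drops all of that.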
@@ -490,7 +374,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote)
 
   m.impl("binary_cross_entropy",
-       TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
+       TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
 }
 
 }