Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion aten/src/ATen/common_with_cwrap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# this code should be common among cwrap and ATen preprocessing
# for now, I have put it in one place but right now is copied out of cwrap

import copy

def parse_arguments(args):
new_args = []
Expand Down Expand Up @@ -50,11 +51,16 @@ def set_declaration_defaults(declaration):
declaration['unqual_operator_name_with_overload'] = ''
# Simulate multiple dispatch, even if it's not necessary
if 'options' not in declaration:
declaration['options'] = [{'arguments': declaration['arguments']}]
declaration['options'] = [{
'arguments': copy.deepcopy(declaration['arguments']),
'schema_order_arguments': copy.deepcopy(declaration['schema_order_arguments']),
}]
del declaration['arguments']
del declaration['schema_order_arguments']
# Parse arguments (some of them can be strings)
for option in declaration['options']:
option['arguments'] = parse_arguments(option['arguments'])
option['schema_order_arguments'] = parse_arguments(option['schema_order_arguments'])
# Propagate defaults from declaration to options
for option in declaration['options']:
for k, v in declaration.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ using supported_primitive_arg_types = guts::typelist::typelist<
at::Tensor,
at::Scalar,
c10::QScheme,
c10::ScalarType
c10::ScalarType,
c10::Device,
c10::Layout,
c10::MemoryFormat
>;

template<class T, bool AllowDeprecatedTypes, class Enable = void> struct assert_is_valid_input_type {
Expand Down
27 changes: 24 additions & 3 deletions aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1517,9 +1517,12 @@ namespace detail {
template <typename T>
struct getTypePtr_ final {
static TypePtr call() {
if (!isCustomClassRegistered<T>()) {
throw c10::Error("Type could not be converted to any of the known types.", "");
}
TORCH_CHECK(
isCustomClassRegistered<T>(),
"Type ",
c10::util::get_fully_qualified_type_name<T>(),
" could not be converted to any of the known types."
);
auto res = getCustomClassType<T>();
return std::dynamic_pointer_cast<Type>(std::move(res));
}
Expand Down Expand Up @@ -1557,6 +1560,24 @@ struct getTypePtr_<c10::ScalarType> final {
}
};
// getTypePtr_ specialization: c10::Device arguments map to the JIT's
// dedicated Device object type.
template <>
struct getTypePtr_<c10::Device> final {
static TypePtr call() {
return DeviceObjType::get();
}
};
// NOTE(review): Layout maps to IntType, presumably because layouts are
// represented as integer enum values in the JIT type system — confirm.
template <>
struct getTypePtr_<c10::Layout> final {
static TypePtr call() {
return IntType::get();
}
};
// NOTE(review): MemoryFormat likewise maps to IntType (integer-valued enum
// in the JIT) — confirm.
template <>
struct getTypePtr_<c10::MemoryFormat> final {
static TypePtr call() {
return IntType::get();
}
};
template <>
struct getTypePtr_<bool> final {
static TypePtr call() {
return BoolType::get();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#pragma once

#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/core/TensorOptions.h>
#include <c10/core/CompileTimeFunctionPointer.h>

// This file defines hacky_wrapper_for_legacy_signatures, which takes a kernel written in a legacy way
// (e.g. with TensorOptions packed) and wraps it into a kernel with the signature expected by
// the PyTorch operator library. The intention is to ultimately rewrite kernels to take the new signature
// and then delete this file. This transition process can happen kernel-by-kernel, since this wrapper
// is a no-op for kernels that already have a non-legacy signature.

namespace c10 {
namespace impl {

// Reconcile the two possible sources of a memory_format: the (legacy) packed
// TensorOptions and the (new-style) explicit optional argument.
//
// Setting it in both places is a caller bug and is rejected via TORCH_CHECK;
// otherwise the explicitly passed value wins, falling back to whatever the
// TensorOptions holds (which may itself be nullopt).
inline c10::optional<MemoryFormat> process_memory_format(const TensorOptions& options, c10::optional<MemoryFormat> memory_format) {
  TORCH_CHECK(
      !(options.has_memory_format() && memory_format.has_value()),
      "Cannot set memory_format both in TensorOptions and explicit argument; please delete "
      "the redundant setter.");
  // Explicit argument takes precedence over the TensorOptions field.
  return memory_format.has_value() ? memory_format : options.memory_format_opt();
}

namespace detail {

// with_scattered_tensor_options takes a function pointer that potentially takes a TensorOptions argument.
// If it does, then it creates a new function pointer that takes scattered arguments, internally
// gathers those arguments, and then calls the underlying function pointer. If the underlying
// function pointer does not take a TensorOptions argument, it is passed through unmodified.

// Metafunction: inherits std::true_type iff `Type` decays to TensorOptions,
// i.e. the parameter is TensorOptions by value, by (cv-qualified) reference,
// or by rvalue reference; std::false_type otherwise.
template<class Type, class Enable = void>
struct is_tensoroptions_arg : std::false_type {};

template<class Type>
struct is_tensoroptions_arg<
    Type,
    std::enable_if_t<std::is_same<TensorOptions, std::decay_t<Type>>::value>>
    : std::true_type {};

// Alias yielding std::true_type / std::false_type, suitable for use with the
// typelist algorithms below.
template<class Type>
using is_tensoroptions_arg_t = typename is_tensoroptions_arg<Type>::type;

// Returns true iff FuncType has a parameter whose decayed type is
// TensorOptions. Compilation fails if there is more than one such parameter,
// since the wrapper machinery below only supports a single one.
template<class FuncType>
inline constexpr bool has_tensoroptions_arg() {
  using params = typename guts::infer_function_traits_t<FuncType>::parameter_types;
  constexpr size_t num_tensoroptions = guts::typelist::count_if<is_tensoroptions_arg_t, params>::value;
  static_assert(num_tensoroptions <= 1, "Function has multiple TensorOptions parameters. We support at most one.");
  return num_tensoroptions > 0;
}

// sanity checks
// Compile-time smoke tests: detection fires for both by-reference and
// by-value TensorOptions parameters, and not for unrelated parameter types.
static_assert(has_tensoroptions_arg<int (int64_t, const TensorOptions&)>(), "");
static_assert(has_tensoroptions_arg<int (int64_t, TensorOptions)>(), "");
static_assert(!has_tensoroptions_arg<int (int64_t, std::string)>(), "");

// Implementation helper, defined further below: builds the scattered-signature
// wrapper from the parameter typelists before/after the TensorOptions arg.
template<class FuncPtr, class ParametersBeforeTensorOptions, class ParametersAfterTensorOptions> struct with_scattered_tensor_options_;

// Primary template; only the two SFINAE specializations below are ever chosen.
template<class FuncPtr, class Enable = void>
struct with_scattered_tensor_options final {};

// Case 1: the kernel has no TensorOptions parameter.
template<class UnderlyingFuncPtr>
struct with_scattered_tensor_options<UnderlyingFuncPtr, std::enable_if_t<!has_tensoroptions_arg<typename UnderlyingFuncPtr::FuncType>()>> final {
// FuncType does not have TensorOptions arguments.
// Don't wrap anything but just return the base pointer.
using FuncPtr = UnderlyingFuncPtr;
};

// Case 2: the kernel has a TensorOptions parameter (has_tensoroptions_arg
// static_asserts there is at most one).
template<class UnderlyingFuncPtr>
struct with_scattered_tensor_options<UnderlyingFuncPtr, std::enable_if_t<has_tensoroptions_arg<typename UnderlyingFuncPtr::FuncType>()>> final {
private:
// FuncType has TensorOptions arguments.
// Return a function pointer to a wrapper function that replaces those with expanded arguments.
using gathered_parameter_types = typename guts::infer_function_traits_t<typename UnderlyingFuncPtr::FuncType>::parameter_types;
// Index of the (unique) TensorOptions parameter within the gathered signature.
static constexpr size_t tensoroptions_arg_index =
guts::typelist::find_if<
gathered_parameter_types,
is_tensoroptions_arg_t
>::value;

// Parameter typelists to the left and to the right of the TensorOptions arg.
using parameters_before_tensoroptions =
guts::typelist::take_t<gathered_parameter_types, tensoroptions_arg_index>;
using parameters_after_tensoroptions =
guts::typelist::drop_t<gathered_parameter_types, tensoroptions_arg_index + 1>;

using wrapper = with_scattered_tensor_options_<UnderlyingFuncPtr, parameters_before_tensoroptions, parameters_after_tensoroptions>;
public:
// Expose the scattered-signature wrapper as a compile-time function pointer.
using FuncPtr = TORCH_FN_TYPE(&wrapper::wrapper);
};

// The actual wrapper: same parameters as the underlying kernel, except the
// single TensorOptions parameter is replaced by four scattered optionals
// (scalar_type, layout, device, pin_memory) which are gathered back into a
// TensorOptions before calling through to the kernel.
template<class FuncPtr, class... ParametersBeforeTensorOptions, class... ParametersAfterTensorOptions>
struct with_scattered_tensor_options_<FuncPtr, guts::typelist::typelist<ParametersBeforeTensorOptions...>, guts::typelist::typelist<ParametersAfterTensorOptions...>> final {
static decltype(auto) wrapper(
ParametersBeforeTensorOptions... parameters_before,
optional<ScalarType> scalar_type,
optional<Layout> layout,
optional<Device> device,
optional<bool> pin_memory,
ParametersAfterTensorOptions... parameters_after) {
// Gather the scattered arguments into a TensorOptions and forward all
// remaining arguments unchanged to the underlying kernel.
return (*FuncPtr::func_ptr())(
std::forward<ParametersBeforeTensorOptions>(parameters_before)...,
TensorOptions().dtype(scalar_type).device(device).layout(layout).pinned_memory(pin_memory),
std::forward<ParametersAfterTensorOptions>(parameters_after)...
);
}
};

}

// Entry point: given a compile-time function pointer (see TORCH_FN), return a
// compile-time function pointer with the same behavior but a non-legacy
// signature. Concretely: if the kernel takes a TensorOptions argument, the
// returned wrapper instead takes scattered
// (optional<ScalarType>, optional<Layout>, optional<Device>, optional<bool>)
// arguments and gathers them into a TensorOptions before calling the kernel.
// Kernels without a TensorOptions argument are passed through unchanged.
//
// The (unnamed) FuncPtr argument only carries its type; all information is
// compile-time.
template<class FuncPtr>
constexpr auto hacky_wrapper_for_legacy_signatures(FuncPtr) {
return typename detail::with_scattered_tensor_options<FuncPtr>::FuncPtr();
}

}
}
11 changes: 11 additions & 0 deletions aten/src/ATen/cwrap_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import yaml
import copy

try:
# use faster C loader if available
from yaml import CLoader as Loader
Expand All @@ -24,4 +26,13 @@ def parse(filename):
declarations.append(declaration)
elif in_declaration:
declaration_lines.append(line)
declarations = [process_declaration(declaration) for declaration in declarations]
return declarations

def process_declaration(declaration):
    """Return a processed deep copy of a parsed cwrap declaration.

    The copy gains a ``schema_order_arguments`` key mirroring ``arguments``
    (as an independent deep copy, so later mutation of one does not affect
    the other), and the same processing is applied recursively to every
    entry of ``options``. The caller's ``declaration`` is left untouched.
    """
    processed = copy.deepcopy(declaration)
    # Mirror the positional arguments into an independent schema-order list.
    if "arguments" in processed:
        processed["schema_order_arguments"] = copy.deepcopy(processed["arguments"])
    # Recurse into each overload option, if any.
    if "options" in processed:
        processed["options"] = [process_declaration(option) for option in processed["options"]]
    return processed
Loading