Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -803,9 +803,9 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenFallbackKernelWith
EXPECT_EQ(4, outputs[0].toInt());
}

c10::optional<Tensor> called_arg2;
c10::optional<int64_t> called_arg3;
c10::optional<std::string> called_arg4;
c10::optional<Tensor> called_arg2 = c10::nullopt;
c10::optional<int64_t> called_arg3 = c10::nullopt;
c10::optional<std::string> called_arg4 = c10::nullopt;

void kernelWithOptInputWithoutOutput(Tensor arg1, const c10::optional<Tensor>& arg2, c10::optional<int64_t> arg3, c10::optional<std::string> arg4) {
called = true;
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/core/op_registration/kernel_function_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,9 +526,9 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenFallbackKernelWithoutTen
EXPECT_EQ(4, outputs[0].toInt());
}

c10::optional<Tensor> called_arg2;
c10::optional<int64_t> called_arg3;
c10::optional<std::string> called_arg4;
c10::optional<Tensor> called_arg2 = c10::nullopt;
c10::optional<int64_t> called_arg3 = c10::nullopt;
c10::optional<std::string> called_arg4 = c10::nullopt;

void kernelWithOptInputWithoutOutput(Tensor arg1, const c10::optional<Tensor>& arg2, c10::optional<int64_t> arg3, c10::optional<std::string> arg4) {
called = true;
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/core/op_registration/kernel_functor_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,9 +675,9 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenFallbackKernelWithoutTens
EXPECT_EQ(4, outputs[0].toInt());
}

c10::optional<Tensor> called_arg2;
c10::optional<int64_t> called_arg3;
c10::optional<std::string> called_arg4;
c10::optional<Tensor> called_arg2 = c10::nullopt;
c10::optional<int64_t> called_arg3 = c10::nullopt;
c10::optional<std::string> called_arg4 = c10::nullopt;

struct KernelWithOptInputWithoutOutput final : OperatorKernel {
void operator()(Tensor arg1, const c10::optional<Tensor>& arg2, c10::optional<int64_t> arg3, c10::optional<std::string> arg4) {
Expand Down
18 changes: 9 additions & 9 deletions aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,9 +729,9 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenFallbackKernelWithou

TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) {
bool called;
c10::optional<Tensor> called_arg2;
c10::optional<int64_t> called_arg3;
c10::optional<std::string> called_arg4;
c10::optional<Tensor> called_arg2 = c10::nullopt;
c10::optional<int64_t> called_arg3 = c10::nullopt;
c10::optional<std::string> called_arg4 = c10::nullopt;

auto registrar = RegisterOperators().op(
"_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()",
Expand Down Expand Up @@ -768,9 +768,9 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalIn

TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) {
bool called;
c10::optional<Tensor> called_arg2;
c10::optional<int64_t> called_arg3;
c10::optional<std::string> called_arg4;
c10::optional<Tensor> called_arg2 = c10::nullopt;
c10::optional<int64_t> called_arg3 = c10::nullopt;
c10::optional<std::string> called_arg4 = c10::nullopt;

auto registrar = RegisterOperators().op(
"_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?",
Expand Down Expand Up @@ -810,9 +810,9 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalIn

TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) {
bool called;
c10::optional<Tensor> called_arg2;
c10::optional<int64_t> called_arg3;
c10::optional<std::string> called_arg4;
c10::optional<Tensor> called_arg2 = c10::nullopt;
c10::optional<int64_t> called_arg3 = c10::nullopt;
c10::optional<std::string> called_arg4 = c10::nullopt;

auto registrar = RegisterOperators().op(
"_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)",
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/core/op_registration/kernel_lambda_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,9 +461,9 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenFallbackKernelWithoutTenso
EXPECT_EQ(4, outputs[0].toInt());
}

c10::optional<Tensor> called_arg2;
c10::optional<int64_t> called_arg3;
c10::optional<std::string> called_arg4;
c10::optional<Tensor> called_arg2 = c10::nullopt;
c10::optional<int64_t> called_arg3 = c10::nullopt;
c10::optional<std::string> called_arg4 = c10::nullopt;

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators().op(
Expand Down
7 changes: 6 additions & 1 deletion aten/src/ATen/core/op_registration/op_registration.h
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,12 @@ class CAFFE2_API RegisterOperators final {
);
}

Options() = default;
Options()
: schemaOrName_(c10::nullopt)
, legacyATenSchema_(c10::nullopt)
, kernels()
, aliasAnalysisKind_(c10::nullopt)
{}

// KernelRegistrationConfig accumulates all information from the config
// parameters passed to a RegisterOperators::op() call into one object.
Expand Down
24 changes: 12 additions & 12 deletions aten/src/ATen/core/op_registration/op_registration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -818,28 +818,28 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {

// optional types (with has_value() == false)
testArgTypes<c10::optional<double>>::test(
c10::optional<double>(), [] (const c10::optional<double>& v) {EXPECT_FALSE(v.has_value());},
c10::optional<double>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
c10::optional<double>(c10::nullopt), [] (const c10::optional<double>& v) {EXPECT_FALSE(v.has_value());},
c10::optional<double>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(float? a) -> float?");
testArgTypes<c10::optional<int64_t>>::test(
c10::optional<int64_t>(), [] (const c10::optional<int64_t>& v) {EXPECT_FALSE(v.has_value());},
c10::optional<int64_t>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
c10::optional<int64_t>(c10::nullopt), [] (const c10::optional<int64_t>& v) {EXPECT_FALSE(v.has_value());},
c10::optional<int64_t>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(int? a) -> int?");
testArgTypes<c10::optional<bool>>::test(
c10::optional<bool>(), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
c10::optional<bool>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
c10::optional<bool>(c10::nullopt), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
c10::optional<bool>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(bool? a) -> bool?");
testArgTypes<c10::optional<bool>>::test(
c10::optional<bool>(), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
c10::optional<bool>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
c10::optional<bool>(c10::nullopt), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
c10::optional<bool>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(bool? a) -> bool?");
testArgTypes<c10::optional<std::string>>::test(
c10::optional<std::string>(), [] (const c10::optional<std::string>& v) {EXPECT_FALSE(v.has_value());},
c10::optional<std::string>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
c10::optional<std::string>(c10::nullopt), [] (const c10::optional<std::string>& v) {EXPECT_FALSE(v.has_value());},
c10::optional<std::string>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(str? a) -> str?");
testArgTypes<c10::optional<Tensor>>::test(
c10::optional<Tensor>(), [] (const c10::optional<Tensor>& v) {EXPECT_FALSE(v.has_value());},
c10::optional<Tensor>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
c10::optional<Tensor>(c10::nullopt), [] (const c10::optional<Tensor>& v) {EXPECT_FALSE(v.has_value());},
c10::optional<Tensor>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(Tensor? a) -> Tensor?");


Expand Down