Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
3ec6148
Replace Type dispatch with ATenDispatch
Jun 4, 2019
799330a
Update on "Replace Type dispatch with ATenDispatch"
Jun 4, 2019
897abf3
Update on "[wip] Replace Type dispatch with ATenDispatch"
Jun 4, 2019
5c15cdb
Update on "[wip] Replace Type dispatch with ATenDispatch"
Jun 4, 2019
75b66cd
Update on "[wip] Replace Type dispatch with ATenDispatch"
Jun 5, 2019
3e0aaad
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 6, 2019
7472cdc
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 6, 2019
11ccf9a
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 6, 2019
2babd58
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 7, 2019
cef4290
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 7, 2019
33fb37e
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 7, 2019
579dd3f
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 7, 2019
6840d9b
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 7, 2019
f495d97
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 10, 2019
98fb2ee
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 11, 2019
f1da595
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 12, 2019
d098e00
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 12, 2019
afcf125
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 12, 2019
d7f144d
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 14, 2019
bd29609
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 18, 2019
629371b
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 18, 2019
97ed318
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 19, 2019
c168c0f
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 19, 2019
a1e8382
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 19, 2019
cf1523d
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 19, 2019
68f6d9d
Update on "[BC-BREAKING] Replace Type dispatch with ATenDispatch"
Jun 19, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aten/src/ATen/ATen.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
#include <c10/core/Storage.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <ATen/core/ATenDispatch.h>
10 changes: 10 additions & 0 deletions aten/src/ATen/core/ATenDispatch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#include <ATen/core/ATenDispatch.h>

namespace at {

ATenDispatch & globalATenDispatch() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registering to the global table is not thread safe, right? Gotta be careful: library loads can happen in different threads.

static ATenDispatch singleton;
return singleton;
}

} // namespace at
106 changes: 106 additions & 0 deletions aten/src/ATen/core/ATenDispatch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#pragma once

#include <c10/core/Backend.h>
#include <unordered_map>
#include <c10/util/C++17.h>
#include <memory>
#include <mutex>

// This dispatch class serves as a replacement for our previous dispatch
// mechanism, in which all functions were members of a Type class. A derived
// class existed for each backend (and Variable), and the vtable was used to
// dispatch to the correct implementation. This class is to be replaced by
// the c10 dispatcher when it supports all argument and return types.
// This implementation opts to store implementations in a table of void*.

namespace at {

// ATenOpTable stores the implementations for each backend, in addition to
// an implementation for variables.
class CAFFE2_API ATenOpTable {
public:
ATenOpTable(std::string schema)
: schema_(std::move(schema)) {}

template<class FuncType>
FuncType* getOp(Backend backend, bool is_variable) const {
if (is_variable) {
return reinterpret_cast<FuncType*>(getVariableOp());
}
return reinterpret_cast<FuncType*>(getBaseOp(backend));
}
private:
void registerOp(Backend backend, void* fn) {
TORCH_CHECK(function_table_[static_cast<int64_t>(backend)] == nullptr,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Internal asserts here? (I guess if you're going to expose directly to extensions an internal assert is not appropriate, but for internal use, this isn't a public API right)

Copy link
Contributor

@ezyang ezyang Jun 14, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no type checking in the registration API right now; as long as we don't have any type checking I have a pretty strong preference of NOT allowing external parties to poke this API. I don't know what the situation is with XLA and this diff at the moment. (The reason for this preference is that if XLA uses this directly, and we start changing the types of functions, they'll start segfaulting. Ick!)

"Attempting to register variable function for schema ", schema_,
" and backend ", toString(backend),
" but there is already a function registered");
function_table_[static_cast<int64_t>(backend)] = fn;
}

void registerVariableOp(void* fn) {
TORCH_CHECK(variable_function_ == nullptr,
"Attempting to register variable function for schema ", schema_,
" but there is already a function registered");
variable_function_ = fn;
}

void* getBaseOp(Backend backend) const {
if (function_table_[static_cast<int64_t>(backend)] == nullptr) {
TORCH_CHECK(function_table_[static_cast<int64_t>(Backend::Undefined)] != nullptr,
"No function is registered for schema ", schema_, " on backend ", toString(backend));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re error message: not only is no function not registered, but there is no undefined implementation registered either.

Actually, why are we dispatching to Undefined at all, at this point? Can't we just error out straight up at this point?

return function_table_[static_cast<int64_t>(Backend::Undefined)];
}
return function_table_[static_cast<int64_t>(backend)];
}

void* getVariableOp() const {
TORCH_CHECK(variable_function_ != nullptr,
"No variable function registered for ", schema_);
return variable_function_;
}

friend class ATenDispatch;

std::string schema_;
void* function_table_[static_cast<int64_t>(Backend::NumOptions)] = {nullptr};
void* variable_function_ = nullptr;
};

class CAFFE2_API ATenDispatch {
public:
template<class FuncType>
ATenDispatch& registerOp(Backend backend, const char* schema, FuncType* fn) {
std::lock_guard<std::mutex> lock(mutex_);
if (op_tables_.find(schema) == op_tables_.end()) {
op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
}
op_tables_.at(schema).registerOp(backend, reinterpret_cast<void*>(fn));
return *this;
}

template <class FuncType>
ATenDispatch& registerVariableOp(const char* schema, FuncType* fn) {
std::lock_guard<std::mutex> lock(mutex_);
if (op_tables_.find(schema) == op_tables_.end()) {
op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
}
op_tables_.at(schema).registerVariableOp(reinterpret_cast<void*>(fn));
return *this;
}

const ATenOpTable* getOpTable(const char* schema) const {
auto iter = op_tables_.find(schema);
TORCH_CHECK(iter != op_tables_.end(),
"No functions are registered for schema ", schema);
return &iter->second;
}

private:
std::unordered_map<std::string, ATenOpTable> op_tables_;
std::mutex mutex_;
};

CAFFE2_API ATenDispatch& globalATenDispatch();

} // namespace at
18 changes: 7 additions & 11 deletions aten/src/ATen/core/LegacyTypeDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ class CAFFE2_API LegacyTypeDispatch {
}
Type * getNonVariableTypeOpt(Backend p, ScalarType s) {
if (p != Backend::Undefined) {
initForDeviceType(backendToDeviceType(p));
initForScalarType(s);
initForBackend(p);
}
auto type = getNonVariableTypeRaw(p, s);

Expand Down Expand Up @@ -103,10 +102,11 @@ class CAFFE2_API LegacyTypeDispatch {
type_registry[static_cast<int>(b)] = std::move(t);
detail::getVariableHooks().registerVariableTypeFor(this, b);
}
private:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why'd this become public?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, I'm calling this directly from Functions.h for all factory functions.

void initForDeviceType(DeviceType p) {
void initForBackend(Backend b) {
auto p = backendToDeviceType(b);
static std::once_flag cpu_once;
static std::once_flag cuda_once;
static std::once_flag complex_once;
if (p == DeviceType::CPU) {
std::call_once(cpu_once, [] {
getLegacyDeviceTypeInit().initCPU();
Expand All @@ -120,17 +120,13 @@ class CAFFE2_API LegacyTypeDispatch {
getLegacyDeviceTypeInit().initHIP();
});
}
}
void initForScalarType(ScalarType s) {
static std::once_flag once;
// Only complex may need initialization
if (isComplexType(s)) {
std::call_once(once, [] {
if (b == Backend::ComplexCPU || b == Backend::ComplexCUDA) {
std::call_once(complex_once, [] {
getLegacyDeviceTypeInit().initComplex();
});
}
}

private:
// NB: type_registry has nullptr for all CUDA backends until
// CUDA initialization has occurred
TypeUniquePtr type_registry
Expand Down
Loading