Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aten/src/ATen/ATen.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <ATen/core/Scalar.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorOptions.h>
#include <ATen/core/Reduction.h>
#include <c10/util/Exception.h>
#include <ATen/core/ATenDispatch.h>
#include <ATen/core/UnsafeFromTH.h>
7 changes: 0 additions & 7 deletions aten/src/ATen/CPUTypeDefault.cpp

This file was deleted.

11 changes: 0 additions & 11 deletions aten/src/ATen/CPUTypeDefault.h

This file was deleted.

56 changes: 2 additions & 54 deletions aten/src/ATen/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,16 @@
#include <string>
#include <stdexcept>

#include <ATen/RegisterCPU.h>
#include <ATen/Tensor.h>
#include <ATen/cpu/FlushDenormal.h>

#include <TH/TH.h> // for USE_LAPACK

namespace at {

static inline void errorHandler(const char * msg, void * data) {
throw std::runtime_error(msg);
}
static inline void argErrorHandler(int arg, const char * msg, void * data) {
std::stringstream new_error;
new_error << "invalid argument " << arg << ": " << msg;
throw std::runtime_error(new_error.str());
}

Context::Context()
: next_id(static_cast<size_t>(TypeID::NumOptions))
, thc_state(nullptr, [](THCState* p){ /* no-op */ } )
, thh_state(nullptr, [](THHState* p){ /* no-op */ } )
{

THSetDefaultErrorHandler(errorHandler,nullptr);
THSetDefaultArgErrorHandler(argErrorHandler,nullptr);
register_cpu_types(this);
}
: thc_state(nullptr, [](THCState* p){ /* no-op */ } )
, thh_state(nullptr, [](THHState* p){ /* no-op */ } ) {}

// TODO: This could be bad juju if someone calls globalContext() in the
// destructor of an object with static lifetime.
Expand Down Expand Up @@ -108,38 +91,6 @@ bool Context::setFlushDenormal(bool on) {
return at::cpu::set_flush_denormal(on);
}

// NOTE: We also check `at::NonVariableTypeMode`, and if it's enabled we always
// return non-Variable type in this function.
// See NOTE [ Treating Variables as non-Variables in type dispatch ]
TypeExtendedInterface& getType(TensorOptions options) {
return globalContext().getType(
options.backend(), typeMetaToScalarType(options.dtype()), options.is_variable() && !at::NonVariableTypeMode::is_enabled());
}

// NOTE: We also check `at::NonVariableTypeMode`, and if it's enabled we always
// return non-Variable type in this function.
// See NOTE [ Treating Variables as non-Variables in type dispatch ]
TypeExtendedInterface& getType(const TensorImpl* impl) {
Backend backend = tensorTypeIdToBackend(impl->type_id());
return globalContext().getType(
backend, typeMetaToScalarType(impl->dtype()), impl->is_variable());
}

TypeExtendedInterface& getType(const Tensor& t) {
return getType(t.unsafeGetTensorImpl());
}

LegacyTHDispatcher& getLegacyTHDispatcher(TensorOptions options) {
return globalContext().getLegacyTHDispatcher(
options.backend(), typeMetaToScalarType(options.dtype()));
}

LegacyTHDispatcher& getLegacyTHDispatcher(const TensorImpl* impl) {
Backend backend = tensorTypeIdToBackend(impl->type_id());
return globalContext().getLegacyTHDispatcher(
backend, typeMetaToScalarType(impl->dtype()));
}

Allocator* getCPUAllocator() {
return getTHDefaultAllocator();
}
Expand All @@ -155,9 +106,6 @@ struct LegacyDeviceTypeInit : public LegacyDeviceTypeInitInterface {
void initHIP() const override {
globalContext().lazyInitHIP();
}
void initComplex() const override {
globalContext().lazyInitComplex();
}
};
REGISTER_LEGACY_TYPE_INIT(LegacyDeviceTypeInit);

Expand Down
63 changes: 0 additions & 63 deletions aten/src/ATen/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,13 @@

#include <ATen/core/ATenGeneral.h>
#include <ATen/Tensor.h>
#include <ATen/TypeExtendedInterface.h>
#include <ATen/Utils.h>
#include <ATen/LegacyTHDispatch.h>
#include <ATen/LegacyTHDispatcher.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/Generator.h>
#include <ATen/CPUGenerator.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/core/VariableHooksInterface.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
#include <ATen/detail/ComplexHooksInterface.h>
#include <c10/util/Exception.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>

Expand All @@ -28,35 +23,6 @@ class Tensor;
class CAFFE2_API Context {
public:
Context();
TypeExtendedInterface* getNonVariableTypeRaw(Backend p, ScalarType s) {
return static_cast<TypeExtendedInterface*>(globalLegacyTypeDispatch().getNonVariableTypeRaw(p, s));
}
TypeExtendedInterface * getNonVariableTypeOpt(Backend p, ScalarType s) {
return static_cast<TypeExtendedInterface*>(globalLegacyTypeDispatch().getNonVariableTypeOpt(p, s));
}
TypeExtendedInterface & getNonVariableType(Backend p, ScalarType s) {
return static_cast<TypeExtendedInterface&>(globalLegacyTypeDispatch().getNonVariableType(p, s));
}
TypeExtendedInterface & getVariableType(Backend p, ScalarType s) {
return static_cast<TypeExtendedInterface&>(globalLegacyTypeDispatch().getVariableType(p, s));
}
TypeExtendedInterface & getType(Backend p, ScalarType s, bool is_variable) {
return static_cast<TypeExtendedInterface&>(globalLegacyTypeDispatch().getType(p, s, is_variable));
}
LegacyTHDispatcher& getLegacyTHDispatcher(Backend p, ScalarType s) {
return globalLegacyTHDispatch().getLegacyTHDispatcher(p, s);
}
// The passed in Type must be delete'able
// TODO: Just make it take a unique_ptr
void registerType(Backend b, Type* t) {
globalLegacyTypeDispatch().registerType(b,
LegacyTypeDispatch::TypeUniquePtr{t, LegacyTypeDeleter([](Type* p) { delete p; }) });
}

void registerLegacyTHDispatcher(Backend b, ScalarType s, LegacyTHDispatcher* t) {
globalLegacyTHDispatch().registerDispatcher(b, s,
LegacyTHDispatch::LegacyTHDispatcherUniquePtr{t, LegacyTHDispatcherDeleter([](LegacyTHDispatcher* p) { delete p; }) });
}

Generator & defaultGenerator(Device device) {
DeviceType device_type = device.type();
Expand Down Expand Up @@ -102,22 +68,15 @@ class CAFFE2_API Context {
THCState* lazyInitCUDA() {
std::call_once(thc_init,[&] {
thc_state = detail::getCUDAHooks().initCUDA();
detail::getCUDAHooks().registerCUDATypes(this);
});
return thc_state.get();
}
THHState* lazyInitHIP() {
std::call_once(thh_init,[&] {
thh_state = detail::getHIPHooks().initHIP();
detail::getHIPHooks().registerHIPTypes(this);
});
return thh_state.get();
}
void lazyInitComplex() {
std::call_once(complex_init_, [&] {
detail::getComplexHooks().registerComplexTypes(this);
});
}

THCState* getTHCState() {
// AT_ASSERT(thc_state);
Expand All @@ -127,9 +86,6 @@ class CAFFE2_API Context {
return thh_state.get();
}

size_t freshTypeID() {
return next_id++;
}
bool setFlushDenormal(bool on);

// NB: This method is *purely* whether or not a user requested
Expand All @@ -153,21 +109,13 @@ class CAFFE2_API Context {
lazyInitHIP();
}
}
void initComplexIfNeeded(ScalarType s) {
if (isComplexType(s)) {
lazyInitComplex();
}
}
std::once_flag thc_init;
std::once_flag thh_init;
std::once_flag complex_init_;
bool enabled_cudnn = true;
bool deterministic_cudnn = false;
bool benchmark_cudnn = false;
std::atomic<size_t> next_id;
std::unique_ptr<THCState, void(*)(THCState*)> thc_state;
std::unique_ptr<THHState, void(*)(THHState*)> thh_state;
friend struct Type;
};

CAFFE2_API Context& globalContext();
Expand All @@ -176,14 +124,6 @@ static inline void init() {
globalContext();
}

static inline TypeExtendedInterface& getNonVariableType(Backend p, ScalarType s) {
return globalContext().getNonVariableType(p, s);
}

CAFFE2_API TypeExtendedInterface& getType(TensorOptions options);
CAFFE2_API TypeExtendedInterface& getType(const TensorImpl*);
CAFFE2_API TypeExtendedInterface& getType(const Tensor&);

CAFFE2_API Allocator* getCPUAllocator();

static inline DeprecatedTypeProperties& getNonVariableDeprecatedTypeProperties(Backend p, ScalarType s) {
Expand All @@ -206,9 +146,6 @@ static inline DeprecatedTypeProperties& HIP(ScalarType s) {
Backend::HIP, s, /*is_variable*/false);
}

CAFFE2_API LegacyTHDispatcher& getLegacyTHDispatcher(TensorOptions options);
CAFFE2_API LegacyTHDispatcher& getLegacyTHDispatcher(const Tensor&);

static inline bool hasCUDA() {
return globalContext().hasCUDA();
}
Expand Down
12 changes: 0 additions & 12 deletions aten/src/ATen/LegacyTHDispatch.cpp

This file was deleted.

127 changes: 0 additions & 127 deletions aten/src/ATen/LegacyTHDispatch.h

This file was deleted.

2 changes: 1 addition & 1 deletion aten/src/ATen/TensorGeometry.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include <ATen/Type.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/Tensor.h>

namespace at {

Expand Down
Loading