Conversation

li-roy (Contributor) commented Jun 19, 2019

Stack from ghstack:

Removes many usages of Type, in preparation for its deletion:

  • Use of Type in Python bindings for constructors; replaced with TensorTypeId
  • Use of Type to save the Python default tensor type; replaced with TensorTypeId

Differential Revision: D15893331

@pytorchbot added labels module: autograd, module: internals, module: operators, module: pybind on Jun 19, 2019
Remove many usages of Type

gh-metadata: pytorch pytorch 21941 gh/li-roy/30/head
royboy added 2 commits June 19, 2019 10:46
for (auto p : backends) {
  auto baseType = context.getNonVariableTypeRaw(static_cast<Backend>(p), ScalarType::Undefined);
  if (baseType) {
    res.emplace_back(VariableType::getVariableTypeFromBaseType(*baseType));
ezyang (Contributor) commented Jun 20, 2019
Weird, how come we didn't previously loop? EDIT: Oh, that's because you got rid of scalar type from Type. Makes sense.


static std::vector<PyTensorType> tensor_types;

void set_default_tensor_type(PyTensorType* type) {
Contributor:

This got moved, are there any substantive changes?

  for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); s++) {
-   cpu_map.emplace(type_to_string(*type, static_cast<ScalarType>(s)),
-                   std::make_pair(type, static_cast<ScalarType>(s)));
+   cpu_map.emplace(type_to_string(*type), type);
Contributor:

Good!

royboy added 2 commits June 20, 2019 14:05
@pytorchbot pytorchbot added the module: nn Related to torch.nn label Jun 20, 2019
  static bool isVariableType(const at::Type& type);
- static std::vector<at::Type*> allCUDATypes();
- static std::vector<at::Type*> allCPUTypes();
+ static std::vector<at::DeprecatedTypeProperties*> allCUDATypes();
Contributor:

is there a plan to not use DeprecatedTypeProperties here?

li-roy (Contributor Author) commented Jun 21, 2019

I think this will soon only be useful for Python types. Because we have the concepts of FloatTensor and cuda.FloatTensor, we need these to be coupled in some way so we can convert those strings back and forth. If we get rid of those constructs, we don't need this at all. But we can also get rid of this by just passing around a tuple of TensorTypeId/Backend and ScalarType.
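The back-and-forth coupling described here could be sketched roughly as follows (the names, the plain-dict representation, and the dtype strings are illustrative assumptions, not PyTorch's actual implementation):

```python
# Hypothetical sketch: legacy tensor-type strings like "torch.cuda.FloatTensor"
# are coupled to a (backend, scalar_type) pair so they can be converted in
# both directions. Not PyTorch's real data structures.
SCALAR_TYPES = {"Float": "float32", "Double": "float64", "Long": "int64"}

def build_name_maps():
    # Forward map: legacy name -> (backend, scalar_type)
    forward = {}
    for backend, prefix in (("cpu", "torch."), ("cuda", "torch.cuda.")):
        for name, scalar_type in SCALAR_TYPES.items():
            forward[prefix + name + "Tensor"] = (backend, scalar_type)
    # Reverse map: (backend, scalar_type) -> legacy name
    reverse = {pair: name for name, pair in forward.items()}
    return forward, reverse

forward, reverse = build_name_maps()
```

In this toy model, `forward["torch.cuda.FloatTensor"]` yields `("cuda", "float32")`, and the reverse map recovers the legacy string from the pair, which is the coupling the comment describes.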

PyObject *THPModule_getDefaultDtype(PyObject *_unused, PyObject *arg) {
  HANDLE_TH_ERRORS
- auto scalar_type = torch::tensors::get_default_scalar_type();
+ auto scalar_type = at::typeMetaToScalarType(torch::tensors::get_default_tensor_options().dtype());
Contributor:

I don't think we should do this.

There is no such thing as default tensor options. There's a default dtype, and a default device type for constructors, but that's it. Using tensor options for this implies there are dimensions of the default which are absolutely not applicable: currently layout, optionality, etc. But soon there will also be dimension names and that doesn't make sense in this context.

The right way to think about this, I think, is that TensorOptions are a mechanism for C++ users to specify kwarg-like options to our factory function. Every other use of them has been a mistake (Type.options(), hacking up our parsing in multiple places to pass TensorOptions instead of just the actual parameters, etc.)
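The "kwarg-like options" view could be sketched like this (a toy Python model with hypothetical names; the real TensorOptions is a C++ class):

```python
# Toy model of "options as kwargs to a factory function": the factory accepts
# only the parameters that genuinely have defaults (dtype, device type) and
# resolves each one independently. Hypothetical names, not the real API.
DEFAULT_DTYPE = "float32"
DEFAULT_DEVICE_TYPE = "cpu"

def zeros(shape, dtype=None, device=None):
    # Each unset kwarg falls back to its own global default; there is no
    # monolithic "default options" object.
    return {
        "shape": shape,
        "dtype": dtype if dtype is not None else DEFAULT_DTYPE,
        "device": device if device is not None else DEFAULT_DEVICE_TYPE,
    }
```

The design point is that only dtype and device type have defaults; layout, dimension names, etc. never enter the picture, which is why a bundled "default options" object over-specifies the state.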

li-roy (Contributor Author) commented Jun 21, 2019

Okay, I'll replace this with get_default_tensor_type_id, which is better than the current get_default_tensor_type because we still have a default backend; we just don't pass scalar_type info along with it.

li-roy (Contributor Author):

In the future we should probably have a default device type instead.

royboy added 3 commits June 21, 2019 15:27
royboy added 5 commits June 21, 2019 15:47
// torch.cuda. is a prefix of str
std::call_once(cuda_once, []() {
  for (auto type : autograd::VariableType::allCUDATypes()) {
    for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); s++) {
Contributor:

wait, do you still need to loop over ScalarTypes here?

li-roy (Contributor Author):

Yeah, we still need to return a (Backend, ScalarType) pair.

Contributor:

but you are never using s?

li-roy (Contributor Author):

Oh I see what you mean
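The point being raised could be sketched like this (illustrative Python; the real code is C++ and the names are assumptions): if the map is to hold (Backend, ScalarType) pairs, the scalar-type loop variable has to appear in each key, otherwise every iteration writes the same entry.

```python
# Illustrative: the inner scalar-type loop is only meaningful if `s` is used
# when building each map entry.
SCALAR_TYPES = ["Float", "Double", "Long"]

def build_cuda_map():
    cuda_map = {}
    for s in SCALAR_TYPES:
        # Without `s` in the key, all iterations would overwrite one entry.
        cuda_map["torch.cuda." + s + "Tensor"] = ("cuda", s)
    return cuda_map
```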

} else {
  std::call_once(cpu_once, []() {
    for (auto type : autograd::VariableType::allCPUTypes()) {
      for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); s++) {
Contributor:

same here.

with freeze_rng_state():
    output = test_case._forward(module, input)
-   grad_output = output.new(output.shape).normal_()
+   grad_output = output.new(output.shape, device=output.device).normal_()
Contributor:

is this necessary? If so, that's broken.

  "PyTorch splits its backend into two shared libraries: a CPU library "
  "and a CUDA library; this error has occurred because you are trying "
- "to use some CUDA functionality, but the CUDA library has not been "
+ "to use some CUDA functionality, but the CUDA library does not exist."
Contributor:

I can't really parse this change. What are you trying to get across here?

li-roy (Contributor Author):

This change isn't needed anymore, I'll revert.

Initially, this PR made it so this error was hit in some cases when PyTorch was built without CUDA, and I wanted to include that possibility in the error message. But the most recent changes mean it won't be hit anymore.

}
}

static inline Backend backendToBackendOfDeviceType(Backend b, DeviceType d) {
Contributor:

since this only seems to be used in one file (and let's be honest, it's not really clear what it does), can we just move it there instead of having it publicly available?
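For context, here is a guess at what the helper does, sketched in Python (the exact semantics are an assumption inferred from the name; the real function is C++): keep the backend's dense/sparse flavor while swapping in the requested device type.

```python
# Hypothetical model of backendToBackendOfDeviceType: preserve the
# sparse/dense flavor of the backend, but move it to the given device type.
def backend_of_device_type(backend, device_type):
    sparse = backend.startswith("Sparse")
    base = {"cpu": "CPU", "cuda": "CUDA"}[device_type]
    return ("Sparse" + base) if sparse else base
```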

  self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1], device="cuda"))
  self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).cuda())
- self.assertRaisesRegex(AssertionError, msg, lambda: torch.cuda.FloatTensor())
+ self.assertRaisesRegex(RuntimeError, msg, lambda: torch.cuda.FloatTensor())
Contributor:

why is this one different?

li-roy (Contributor Author):

It's hitting a different error now. After the most recent changes, it hits an error saying that cuda.FloatTensor isn't available.

PyObject *THPModule_isDefaultTypeCuda(PyObject *_unused, PyObject *arg) {
  HANDLE_TH_ERRORS
- if (torch::tensors::get_default_tensor_type().is_cuda()) {
+ if (backendToDeviceType(tensorTypeIdToBackend(torch::tensors::get_default_tensor_type_id())) == at::kCUDA) {
Contributor:

not your fault but...this is so crazy, lol.

}
- if (!isVariableType(t.dispatch_type())) {
-   AT_ERROR("Expected object of type Variable but found type ", t.dispatch_type().toString(), " at position #", i, " "
+ if (!t.type().is_variable()) {
Contributor:

does it work to do t.is_variable() ?

}
  auto scalar_type = static_cast<ScalarType>(tensor_type.scalar_type);
- return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(*aten_type, scalar_type, args, kwargs));
+ return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(tensor_type.get_type_id(), tensor_type.get_scalar_type(), args, kwargs));
Contributor:

I don't think your change caused this, but:

>>> torch.randn(2,3).to_mkldnn().new((2,3))
tensor([2., 3.])

is wrong. Do you know what's going on there?

li-roy (Contributor Author):

This occurs before my change too, right? I don't have MKLDNN built, but my guess is that new calls legacy_new_from_sequence, which calls internal_new_from_data, which uses to, so it doesn't support layouts.

Contributor:

yes, it does. I'll file an issue.


    throw unavailable_type(*type);
  }
- set_default_tensor_type(*aten_type, scalar_type);
+ set_default_tensor_type(type);
Contributor:

I don't know if this is due to this specific change or somewhere else, but this seems worse:

Before your patch (without CUDA):

torch.set_default_tensor_type(torch.cuda.FloatTensor)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-2-876b829d82b3> in <module>()
----> 1 torch.set_default_tensor_type(torch.cuda.FloatTensor)

/data/users/gchanan/_pytorch11/torch/__init__.py in set_default_tensor_type(t)
    151     if isinstance(t, _string_classes):
    152         t = _import_dotted_name(t)
--> 153     _C._set_default_tensor_type(t)
    154
    155

/data/users/gchanan/_pytorch11/torch/cuda/__init__.py in _lazy_init()
    176         raise RuntimeError(
    177             "Cannot re-initialize CUDA in forked subprocess. " + msg)
--> 178     _check_driver()
    179     torch._C._cuda_init()
    180     _cudart = _load_cudart()

/data/users/gchanan/_pytorch11/torch/cuda/__init__.py in _check_driver()
     90 def _check_driver():
     91     if not hasattr(torch._C, '_cuda_isDriverSufficient'):
---> 92         raise AssertionError("Torch not compiled with CUDA enabled")
     93     if not torch._C._cuda_isDriverSufficient():
     94         if torch._C._cuda_getDriverVersion() == 0:

AssertionError: Torch not compiled with CUDA enabled

with your patch:

In [4]: torch.set_default_tensor_type(torch.cuda.FloatTensor)

In [5]: torch.randn(2,3)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-5-1dc2fcf8de67> in <module>()
----> 1 torch.randn(2,3)

/data/users/gchanan/_pytorch8/torch/cuda/__init__.py in _lazy_init()
    176         raise RuntimeError(
    177             "Cannot re-initialize CUDA in forked subprocess. " + msg)
--> 178     _check_driver()
    179     torch._C._cuda_init()
    180     _cudart = _load_cudart()

/data/users/gchanan/_pytorch8/torch/cuda/__init__.py in _check_driver()
     90 def _check_driver():
     91     if not hasattr(torch._C, '_cuda_isDriverSufficient'):
---> 92         raise AssertionError("Torch not compiled with CUDA enabled")
     93     if not torch._C._cuda_isDriverSufficient():
     94         if torch._C._cuda_getDriverVersion() == 0:

AssertionError: Torch not compiled with CUDA enabled

li-roy (Contributor Author):

Fixed, and added a test for it.

maybe_initialize_cuda(type_id);
AutoNoGIL no_gil;
return torch::zeros(sizes, type.options(scalar_type, std::move(device)));
if (device.has_value()) {
Contributor:

for all of these that condition based on whether the device exists...can we just build that into your options function?
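The suggestion could look roughly like this (a Python sketch with hypothetical names): fold the "device was explicitly given" check into a single options helper, so call sites no longer branch on it.

```python
# Sketch: one options helper that falls back to the device implied by the
# tensor type id when the caller passed no explicit device. Names hypothetical.
def options(type_id_device, scalar_type, device=None):
    return {
        "device": device if device is not None else type_id_device,
        "dtype": scalar_type,
    }
```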

li-roy mentioned this pull request Jun 24, 2019
royboy added 3 commits June 24, 2019 12:54
li-roy requested a review from gchanan June 25, 2019 01:55
royboy added 7 commits June 25, 2019 01:23
zou3519 deleted the gh/li-roy/30/head branch June 30, 2019 11:14
facebook-github-bot (Contributor) commented:

@li-roy merged this pull request in 9c8f9f0.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 30, 2019
Summary:
Pull Request resolved: pytorch/pytorch#21941
ghimport-source-id: f20cca6229daba9eb8652adb3d959266ae081ef1

Test Plan: Imported from OSS

Differential Revision: D15893331

Pulled By: li-roy

fbshipit-source-id: c988b16008ff0e2725a88c6025afd4aabdaca45a
xzhu1900 pushed a commit to xzhu1900/pytorch that referenced this pull request Jul 5, 2019
Summary:
Pull Request resolved: pytorch#21941
ghimport-source-id: f20cca6

Test Plan: Imported from OSS

Differential Revision: D15893331

Pulled By: li-roy

fbshipit-source-id: c988b16008ff0e2725a88c6025afd4aabdaca45a

Labels

Merged · module: autograd · module: internals · module: nn · module: pybind

7 participants