Export tensor descriptor #8313 (Merged)
goldsborough merged 21 commits into pytorch:master from bstriner:export_TensorDescriptor on Jun 21, 2018
Changes from all commits (21 commits):
All 21 commits are by bstriner:

1. 668606e Export TensorDescriptor
2. 5fffade Export descriptors
3. 28edcb6 install cudnn_h
4. bf53069 Merge branch 'export_TensorDescriptor' of https://github.com/bstriner…
5. 01d0d51 Merge remote-tracking branch 'origin/master' into export_TensorDescri…
6. ac45bf4 Merge remote-tracking branch 'origin/master' into export_TensorDescri…
7. 7683c73 Add tests and with_cuda
8. 381b1b6 tab to space
9. c22a67b forgot cpp
10. a848a20 fix flake
11. 3676040 ld flags
12. 3346411 flake
13. 7c1f61f address comments
14. 287a89c clang-format
15. 328bc60 Merge remote-tracking branch 'origin/master' into export_TensorDescri…
16. a929a78 fixtest
17. ec3e1f6 fix test
18. dd3b4e6 extra headers
19. 1a41d07 extra headers
20. 4cd515d camelcasing
21. a81a2b5 Merge remote-tracking branch 'origin/master' into export_TensorDescri…
New file added by this PR (the diff shows `@@ -0,0 +1,80 @@`), a CuDNN ReLU extension that demonstrates the structure the exported `TensorDescriptor` enables:

```cpp
/*
 * CuDNN ReLU extension. Simple function but contains the general structure of
 * most CuDNN extensions:
 * 1) Check arguments. at::check* functions provide a standard way to validate
 *    input and provide pretty errors.
 * 2) Create descriptors. Most CuDNN functions require creating and setting a
 *    variety of descriptors.
 * 3) Apply the CuDNN function.
 * 4) Destroy your descriptors.
 * 5) Return something (optional).
 */

#include <torch/torch.h>

#include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
#include <ATen/cudnn/Exceptions.h>  // for CUDNN_CHECK
#include <ATen/cudnn/Handles.h>     // for getCudnnHandle

// Name of function in Python module and name used for error messages by
// at::check* functions.
const char* cudnn_relu_name = "cudnn_relu";

// Check arguments to cudnn_relu
void cudnn_relu_check(const at::Tensor& inputs, const at::Tensor& outputs) {
  // Create TensorArgs. These record the names and positions of each tensor as
  // a parameter.
  at::TensorArg arg_inputs(inputs, "inputs", 0);
  at::TensorArg arg_outputs(outputs, "outputs", 1);
  // Check arguments. No need to return anything. These functions will throw
  // an error if they fail. Messages are populated using information from
  // TensorArgs.
  at::checkContiguous(cudnn_relu_name, arg_inputs);
  at::checkScalarType(cudnn_relu_name, arg_inputs, at::kFloat);
  at::checkBackend(cudnn_relu_name, arg_inputs.tensor, at::kCUDA);
  at::checkContiguous(cudnn_relu_name, arg_outputs);
  at::checkScalarType(cudnn_relu_name, arg_outputs, at::kFloat);
  at::checkBackend(cudnn_relu_name, arg_outputs.tensor, at::kCUDA);
  at::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs);
}

void cudnn_relu(const at::Tensor& inputs, const at::Tensor& outputs) {
  // Most CuDNN extensions will follow a similar pattern.
  // Step 1: Check inputs. This will throw an error if inputs are invalid, so
  // no need to check return codes here.
  cudnn_relu_check(inputs, outputs);
  // Step 2: Create descriptors
  cudnnHandle_t cuDnn = at::native::getCudnnHandle();
  // Note: 4 is the minimum dim for a TensorDescriptor. Input and output are
  // the same size and type and contiguous, so one descriptor is sufficient.
  at::native::TensorDescriptor input_tensor_desc(inputs, 4);
  cudnnActivationDescriptor_t activationDesc;
  // Note: Always check the return value of cudnn functions using CUDNN_CHECK
  at::native::CUDNN_CHECK(cudnnCreateActivationDescriptor(&activationDesc));
  at::native::CUDNN_CHECK(cudnnSetActivationDescriptor(
      activationDesc,
      /*mode=*/CUDNN_ACTIVATION_RELU,
      /*reluNanOpt=*/CUDNN_PROPAGATE_NAN,
      /*coef=*/1.));
  // Step 3: Apply CuDNN function
  float alpha = 1.;
  float beta = 0.;
  at::native::CUDNN_CHECK(cudnnActivationForward(
      cuDnn,
      activationDesc,
      &alpha,
      input_tensor_desc.desc(),
      inputs.data_ptr(),
      &beta,
      input_tensor_desc.desc(), // output descriptor same as input
      outputs.data_ptr()));
  // Step 4: Destroy descriptors
  at::native::CUDNN_CHECK(cudnnDestroyActivationDescriptor(activationDesc));
  // Step 5: Return something (optional)
}

// Create the pybind11 module
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Use the same name as the check functions so error messages make sense
  m.def(cudnn_relu_name, &cudnn_relu, "CuDNN ReLU");
}
```