Conversation

@goldsborough (Contributor)

This PR:

  1. Simplifies the type dispatch options to a consistent use of macros (rather than macros in some places and functions in others),
  2. Adds the dispatch header to ATen/ATen.h so that users (e.g. those writing extensions) can dispatch too (a usage sketch follows below).
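
A minimal sketch of what item 2 enables for extension authors, assuming the `AT_DISPATCH_FLOATING_TYPES(type, name, lambda)` macro is reachable through `ATen/ATen.h`; the function name `scale_cpu` and the kernel body are purely illustrative, not code from this PR:

```cpp
#include <ATen/ATen.h>

// Hypothetical extension-style CPU kernel. AT_DISPATCH_FLOATING_TYPES switches
// on the tensor's scalar type and invokes the lambda with scalar_t aliased to
// the matching C++ type (float or double).
at::Tensor scale_cpu(const at::Tensor& self, double alpha) {
  at::Tensor input = self.contiguous();
  at::Tensor result = input.type().tensor(input.sizes());
  AT_DISPATCH_FLOATING_TYPES(input.type(), "scale_cpu", [&] {
    const scalar_t* in = input.data<scalar_t>();
    scalar_t* out = result.data<scalar_t>();
    for (int64_t i = 0; i < input.numel(); ++i) {
      out[i] = static_cast<scalar_t>(alpha) * in[i];
    }
  });
  return result;
}
```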

@zdevito @ezyang @apaszke

@goldsborough goldsborough force-pushed the simplify-aten-dispatch branch from a990374 to 169ba53 on February 28, 2018 22:23


@goldsborough goldsborough mentioned this pull request Mar 1, 2018
@zdevito zdevito (Contributor) left a comment


This looks good to me. One small nit.

```
DISPATCH_ALL_FLOATING_TYPES(self.type(), "embedding_backward", [&]() {
using accscalar_t = cuda::acc_type<scalar_t>;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "embedding_backward", [&] {
using cuda_scalar_t = to_cuda_type<scalar_t>::type;
```
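
For context, a simplified sketch of how a dispatch macro of this shape can expand, consistent with the `default:` / `runtime_error` fragments quoted in this review but not the exact macro introduced by the PR; the name `MY_DISPATCH_FLOATING_TYPES` is made up for illustration:

```cpp
// Sketch only: switch on the runtime scalar type, alias the matching C++ type
// as scalar_t, and invoke the caller's lambda; unsupported types raise an
// error mirroring the default case shown in the diff fragments.
#define MY_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)            \
  [&] {                                                        \
    const at::Type& the_type = TYPE;                           \
    switch (the_type.scalarType()) {                           \
      case at::ScalarType::Double: {                           \
        using scalar_t = double;                               \
        return __VA_ARGS__();                                  \
      }                                                        \
      case at::ScalarType::Float: {                            \
        using scalar_t = float;                                \
        return __VA_ARGS__();                                  \
      }                                                        \
      default:                                                 \
        at::runtime_error("%s not implemented for '%s'", NAME, \
                          the_type.toString());                \
    }                                                          \
  }()
```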








@colesbury (Member)

LGTM too fwiw

@ezyang ezyang merged commit 5e188e4 into pytorch:master Mar 1, 2018
facebook-github-bot pushed a commit that referenced this pull request Aug 24, 2018
Summary:
**Summary**: This PR is a follow-up of mruberry's #9318. It tries to achieve the following:
- Specializes common std math functions for the `at::Half` type.
- Creates `CUDANumerics.cuh` to contain the necessary parts from `THCNumerics.cuh`.
- Updates `THCNumerics.cuh` with new usage and comments to demonstrate best practice for developers, paving the way for its deprecation.
- Removes legacy/redundant code paths.
- Removes unused CUDA HALF macros (see separate PR #10147).

**Comments**: `CUDANumerics.cuh` contains mathematical functions that are either not in the std namespace or are specialized for compilation with CUDA NVCC or CUDA NVRTC. This header is derived from the legacy `THCNumerics.cuh`. The following is some of the rationale for why some functions were kept while others were removed:
- All arithmetic can now be done in ATen using the binary CUDA kernel or CUDA tensor pointwise apply (see #8919 and `CUDAApplyUtils`). `at::Half` comparisons rely on implicit conversion to float.
- Functions that are C/C++ standard compliant have been specialized for user-defined types; for instance, the std namespace has been opened up for `at::Half` to define math functions for it. See `Half-inl.h`.
- Some standard-compliant functions are specialized here for performance reasons. For instance, `powi` is used for `pow` calculation on integral types (a sketch of the idea follows this list). Moreover, `abs`, `isinf`, and `isnan` are specialized to save one API call compared to going through std, although this is subject to change depending on whether we really care about saving that call.
- Numeric limits such as `max`/`min` are removed since they call standard definitions, and the numeric limits for `at::Half` are present in `Half-inl.h`. I understood that HIP has some issue with `std::numeric_limits`; the related GitHub issue I found is ROCm/hip#374. AlexVlx mentions that the issue can be avoided by using `std::numeric_limits` in `__device__` code. Since we are launching lambdas in device contexts, I don't see why `std::numeric_limits` wouldn't compile on HIP when used within a kernel, unless I am unaware of the real reason why `max`/`min` were in THCNumerics in the first place. (I have never tried a build with HIP.)
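
As a rough illustration of the `powi` idea mentioned in the list above (exponentiation by squaring, so integral `pow` avoids a round trip through floating point); this is a sketch assuming a non-negative exponent, not the exact `CUDANumerics.cuh` code:

```cpp
// Sketch: integer power via exponentiation by squaring. Assumes exp >= 0.
template <typename T>
static inline T powi(T base, T exp) {
  T result = 1;
  while (exp > 0) {
    if (exp & 1) {
      result *= base;   // multiply in the current bit's contribution
    }
    exp >>= 1;          // move to the next bit of the exponent
    base *= base;       // square the base for the next bit
  }
  return result;
}
```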

Here are some reference PRs that were handy in refactoring TH into ATen:
- #6786
- #5475
- #9401
- #8689
- #8919
Pull Request resolved: #10301

Differential Revision: D9204758

Pulled By: soumith

fbshipit-source-id: 09f489c1656458c02367b6cd31c3eeeca5acdc8a
zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 25, 2018
PenghuiCheng pushed a commit to PenghuiCheng/pytorch that referenced this pull request Sep 11, 2018