Conversation

@syed-ahmed
Collaborator

Summary: This PR is a follow-up to @mruberry's #9318. It tries to achieve the following:

  • Specialize the std common math functions for the at::Half type (see the sketch after this list).
  • Create CUDANumerics.cuh to hold the necessary parts of THCNumerics.cuh.
  • Update THCNumerics.cuh with new usage and comments that demonstrate the best practice for developers, paving the way for its deprecation.
  • Remove legacy/redundant code paths.
  • Remove unused CUDA HALF macros (see the separate PR Cuda half macros cleanup #10147).
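
As a rough illustration of the first bullet (not the PR's exact code), specializing a std math function for at::Half boils down to routing through float via at::Half's implicit conversions; the header path below is an assumption:

#include <cmath>
#include <ATen/Half.h>  // assumed location of at::Half at the time of this PR

// Minimal sketch of the Half-inl.h idea: define the std math function for
// at::Half by computing in float and converting back implicitly.
namespace std {
inline at::Half lgamma(at::Half a) {
  return std::lgamma(static_cast<float>(a));
}
} // namespace std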

Comments: CUDANumerics.cuh contains mathematical functions that are either not in the std namespace or are specialized for compilation with CUDA NVCC or CUDA NVRTC. This header is derived from the legacy THCNumerics.cuh. The following is some of the rationale for why some functions were kept while others were removed:

  • All arithmetic can now be done in ATen using binary CUDA kernels or the CUDA tensor pointwise apply (see Implement add, sub, mul, div using TensorIterator #8919 and CUDAApplyUtils). at::Half comparisons rely on implicit conversion to float.
  • Functions that are C/C++ standard compliant have been specialized for user-defined types; for instance, the std namespace has been opened up for at::Half with math function definitions for it. See Half-inl.h.
  • Some standard-compliant functions are specialized here for performance reasons. For instance, powi is used for pow calculation on integral types (a sketch follows this list). Moreover, abs, isinf, and isnan are specialized to save one API call versus going through std, although this is subject to change, depending on whether we really care about saving one API call.
  • Numeric limits such as max/min are removed since they just call standard defines. Moreover, numeric limits for
    at::Half are present in Half-inl.h. I understood that HIP has some issue with std::numeric_limits, and this is the related GitHub issue I found: std::numeric_limits<_T>::infinity() compilation problem ROCm/hip#374. @AlexVlx mentions that the issue can be avoided by invoking std::numeric_limits in __device__ code. Since we are launching lambdas with device contexts, I don't see why std::numeric_limits wouldn't compile on HIP when used from a device context within a kernel, unless I am unaware of the real reason max/min was in THCNumerics in the first place. (I have never tried a build with HIP.)
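
For reference, the powi mentioned above is roughly an exponentiation-by-squaring helper for integral types; this is an illustrative sketch, not necessarily the exact definition that landed:

template <typename scalar_t>
static inline __host__ __device__ scalar_t powi(scalar_t a, scalar_t b) {
  // Assumes an integral scalar_t and a non-negative exponent b.
  scalar_t result = 1;
  while (b) {
    if (b & 1) {
      result *= a;
    }
    b /= 2;
    a *= a;
  }
  return result;
}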

Here are some reference PRs that were handy in refactoring TH into ATen:

  • #6786
  • #5475
  • #9401
  • #8689
  • #8919

Contributor

@facebook-github-bot facebook-github-bot left a comment

ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang
Contributor

ezyang commented Aug 7, 2018

ROCm failure is not directly your fault, but it is probably real.

// compatibility

inline AT_HOSTDEVICE at::Half lgamma(at::Half a) {
return (at::Half)lgammaf((float)a);

This comment was marked as off-topic.

@ezyang
Contributor

ezyang commented Aug 7, 2018

To fix the ROCm error, I'd advise just disabling the Half related functionality when the ROCm build is enabled. Half seems to cause a lot of problems for the ROCm toolchain.

Contributor

@ezyang ezyang left a comment

Looks pretty reasonable to me, but build failures will have to be worked around somehow.

// at all. However, keeping all the function definitions to provide backward
// compatibility

inline AT_HOSTDEVICE at::Half lgamma(at::Half a) {

This comment was marked as off-topic.

}

template <>
struct numerics<uint8_t> {

This comment was marked as off-topic.

};

template <typename scalar_t>
static inline __host__ __device__ scalar_t powi(scalar_t a, scalar_t b) {

This comment was marked as off-topic.

// find the max
accscalar_t threadMax = ilpReduce<MaxFloat, ILP, scalar_t, accscalar_t>(
input, classes, MaxFloat<scalar_t, accscalar_t>(), -THCNumerics<accscalar_t>::max());
input, classes, MaxFloat<scalar_t, accscalar_t>(), -std::numeric_limits<accscalar_t>::max());

This comment was marked as off-topic.

This comment was marked as off-topic.

@syed-ahmed
Collaborator Author

@ezyang @colesbury Thanks for the review. Will follow up with fixes.

@syed-ahmed
Collaborator Author

syed-ahmed commented Aug 8, 2018

Change list:

  • Got rid of half overloads for math functions
  • Used implicit conversions to show current ATen usage
  • Created at::numeric_limits for HIP (a sketch follows this list)
  • Added the CUDA_DEVICE_DEBUG flag for device debugging using cuda-gdb
  • Updated function calls and usage in THCAtomics
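
The at::numeric_limits mentioned above is a __host__ __device__-friendly stand-in for std::numeric_limits (which is problematic on HIP). Below is a minimal sketch of the shape of such a specialization, modeled on the uint8_t snippet quoted further down; the int32_t values here are illustrative assumptions:

#include <limits.h>
#include <stdint.h>

namespace at {

template <typename T>
struct numeric_limits {};

template <>
struct numeric_limits<int32_t> {
  static inline __host__ __device__ int32_t lowest() { return INT_MIN; }
  static inline __host__ __device__ int32_t max() { return INT_MAX; }
};

} // namespace at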

Member

@colesbury colesbury left a comment

Looks pretty good. Thanks for cleaning up these headers. Just a few small comments.

#include <cuda.h>
#include <limits.h>

// CUDANumerics.cuh is a holder for mathematical functions that are either

This comment was marked as off-topic.

template <>
struct numeric_limits<uint8_t> {
static inline __host__ __device__ uint8_t lowest() { return 0; }
static inline __host__ __device__ uint8_t max() { return UCHAR_MAX; }

This comment was marked as off-topic.

// half API for the common mathematical functions.
// Note: When calling std math functions from device, don't
// use the std namespace, but just "::" so that the function
// gets resolved from nvcc math_functions.hpp

This comment was marked as off-topic.

This comment was marked as off-topic.
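
The quoted header comment above is about name resolution in device code: call the unqualified global form ("::") so nvcc picks up its own math overloads rather than the host standard library. A hedged illustration (the function name is made up):

__device__ float device_entropy_term(float p) {
  // ::log resolves to nvcc's device overload from math_functions.hpp;
  // std::log would be looked up in the host standard library headers.
  return (p > 0.0f) ? -p * ::log(p) : 0.0f;
}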

@syed-ahmed
Collaborator Author

Looks like the two failing builds are not related to this PR? Are there any more changes needed for this PR?

@ezyang
Contributor

ezyang commented Aug 15, 2018

@pytorchbot retest this please

@soumith soumith dismissed colesbury’s stale review August 15, 2018 18:20

comments have been addressed

Contributor

@facebook-github-bot facebook-github-bot left a comment

soumith has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor

@facebook-github-bot facebook-github-bot left a comment

soumith has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ailzhang
Contributor

@pytorchbot retest this please

@syed-ahmed
Collaborator Author

How is the internal CI looking for this PR?

@colesbury
Member

Looks good, I'll reimport and land.

@syed-ahmed syed-ahmed deleted the thcnumerics-refactor branch August 24, 2018 23:47
zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 25, 2018
Pull Request resolved: pytorch/pytorch#10301

Differential Revision: D9204758

Pulled By: soumith

fbshipit-source-id: 09f489c1656458c02367b6cd31c3eeeca5acdc8a
petrex pushed a commit to petrex/pytorch that referenced this pull request Aug 27, 2018
* upstream/master: (89 commits)
  move HeatmapMaxKeypointOp unittest to oss
  fix xfails involving literals (pytorch#10905)
  Bag of Distributions doc fixes (pytorch#10894)
  Remove FIXME_zerol() from test_jit.py (pytorch#10900)
  Increase BC for PackedSequence ctor (pytorch#9864)
  Remove ability of Scalars to hold Tensors.
  Begin a bestiary of MSVC/NVCC bugs. (pytorch#10883)
  Prevent JIT from overspecializing to every single size configuration (pytorch#10844)
  Handling failing test on ROCm.
  Update mobile predictor caller's interface
  Cache isContiguous and numel
  Create class constant for string literal 'blob_names'
  Conv BN fusion for 3D conv (pytorch#10239)
  Stop using symbolic override for tracing RNNs (pytorch#10638)
  Add registry to pybind_state (pytorch#10759)
  Remove the nanopb submodule
  Create at::linear (pytorch#10799)
  Refactor THCNumerics and add common math functions for at::Half (pytorch#10301)
  Remove Tensor constructor of Scalar. (pytorch#10852)
  Revert D9492561: [pytorch][PR] Moving the operator argument to the front for kernelPointwiseApply.
  ...
@varunagrawal
Contributor

varunagrawal commented Aug 28, 2018

I am getting the same issue as @soumith. However, this is when building a CUDA extension.

Error log:

/python3.7/site-packages/torch/lib/include/THC/THCNumerics.cuh(163): error: more than one operator "<" matches these operands:
            built-in operator "arithmetic < arithmetic"
            function "operator<(const __half &, const __half &)"
            operand types are: at::Half < at::Half

nvcc similarly complains about <=, > and >=.

@varunagrawal
Contributor

varunagrawal commented Aug 28, 2018

Can confirm. Checked out 87a7840 and compiled everything again. Extension compilation succeeds without a hitch.

@syed-ahmed
Collaborator Author

Hi @varunagrawal. Could you please post a snippet of the code you are compiling, especially the line that does the "<" comparison? From the error, it looks like you are comparing at::Half with at::Half, which shouldn't need THCNumerics.cuh. I'd be happy to reproduce on my side and help. :)

@varunagrawal
Contributor

The compiler isn't quite showing the file or code snippet that's causing the issue. This is what I get:

gcc -pthread -B /home/varun/anaconda3/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -DWITH_CUDA -I/home/varun/projects/torchvision/torchvision/csrc -I/home/varun/anaconda3/lib/python3.7/site-packages/torch/lib/include -I/home/varun/anaconda3/lib/python3.7/site-packages/torch/lib/include/TH -I/home/varun/anaconda3/lib/python3.7/site-packages/torch/lib/include/THC -I/usr/local/cuda/include -I/home/varun/anaconda3/include/python3.7m -c /home/varun/projects/torchvision/torchvision/csrc/cpu/ROIPool_cpu.cpp -o build/temp.linux-x86_64-3.7/home/varun/projects/torchvision/torchvision/csrc/cpu/ROIPool_cpu.o -DTORCH_EXTENSION_NAME=_C -std=c++11
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
/usr/local/cuda/bin/nvcc -DWITH_CUDA -I/home/varun/projects/torchvision/torchvision/csrc -I/home/varun/anaconda3/lib/python3.7/site-packages/torch/lib/include -I/home/varun/anaconda3/lib/python3.7/site-packages/torch/lib/include/TH -I/home/varun/anaconda3/lib/python3.7/site-packages/torch/lib/include/THC -I/usr/local/cuda/include -I/home/varun/anaconda3/include/python3.7m -c /home/varun/projects/torchvision/torchvision/csrc/cuda/ROIPool_cuda.cu -o build/temp.linux-x86_64-3.7/home/varun/projects/torchvision/torchvision/csrc/cuda/ROIPool_cuda.o -DTORCH_EXTENSION_NAME=_C --compiler-options '-fPIC' -std=c++11
/home/varun/anaconda3/lib/python3.7/site-packages/torch/lib/include/THC/THCNumerics.cuh(163): error: more than one operator "<" matches these operands:
            built-in operator "arithmetic < arithmetic"
            function "operator<(const __half &, const __half &)"
            operand types are: at::Half < at::Half

/home/varun/anaconda3/lib/python3.7/site-packages/torch/lib/include/THC/THCNumerics.cuh(167): error: more than one operator "<=" matches these operands:
            built-in operator "arithmetic <= arithmetic"
            function "operator<=(const __half &, const __half &)"
            operand types are: at::Half <= at::Half

/home/varun/anaconda3/lib/python3.7/site-packages/torch/lib/include/THC/THCNumerics.cuh(171): error: more than one operator ">" matches these operands:
            built-in operator "arithmetic > arithmetic"
            function "operator>(const __half &, const __half &)"
            operand types are: at::Half > at::Half

/home/varun/anaconda3/lib/python3.7/site-packages/torch/lib/include/THC/THCNumerics.cuh(175): error: more than one operator ">=" matches these operands:
            built-in operator "arithmetic >= arithmetic"
            function "operator>=(const __half &, const __half &)"
            operand types are: at::Half >= at::Half

4 errors detected in the compilation of "/tmp/tmpxft_00000b86_00000000-4_ROIPool_cuda.cpp4.ii".
error: command '/usr/local/cuda/bin/nvcc' failed with exit status 2

I honestly can't decipher the issue other than it seems to have problems linking with some code in THCNumerics.cuh.

@syed-ahmed
Collaborator Author

Could you please post the code for /home/varun/projects/torchvision/torchvision/csrc/cuda/ROIPool_cuda.cu? Also if it's better to communicate on slack, I'm available there: https://pytorch.slack.com

@varunagrawal
Contributor

varunagrawal commented Aug 28, 2018

@syed-ahmed I am not on that slack workspace and unfortunately I don't fit the criteria for creating an account there.

You can find the code for ROIPool_cuda.cu here.

@syed-ahmed
Collaborator Author

Is this the latest version of the ROIPool_cuda.cu code? I don't see a dispatch of Half (AT_DISPATCH_FLOATING_TYPES_AND_HALF for instance) which could touch half comparisons.

@varunagrawal
Contributor

It is the latest. Doesn't AT_DISPATCH_FLOATING_TYPES cover Half types as well?

@syed-ahmed
Collaborator Author

syed-ahmed commented Aug 28, 2018

From what I can see here: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h#L13, AT_DISPATCH_FLOATING_TYPES would throw an error if you passed a half type to it. Instead, the macros that include dispatch for half types are AT_DISPATCH_FLOATING_TYPES_AND_HALF and AT_DISPATCH_ALL_TYPES_AND_HALF.
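
For illustration, this is roughly how a CUDA extension would opt into Half dispatch with that macro; the kernel and function names are placeholders (not torchvision code), and the .type()/.data<scalar_t>() calls reflect the ATen API of that era:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical element-wise kernel used only to show the dispatch pattern.
template <typename scalar_t>
__global__ void scale_kernel(scalar_t* data, int64_t n, float factor) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // at::Half participates via its implicit conversions to/from float.
    data[i] = static_cast<scalar_t>(static_cast<float>(data[i]) * factor);
  }
}

void scale_(at::Tensor& t, float factor) {
  const int64_t n = t.numel();
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  // The _AND_HALF variant is what adds at::Half to the dispatch; plain
  // AT_DISPATCH_FLOATING_TYPES rejects Half tensors with an error.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.type(), "scale_", ([&] {
    scale_kernel<scalar_t><<<blocks, threads>>>(t.data<scalar_t>(), n, factor);
  }));
}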

@varunagrawal
Contributor

Huh, alright. Though it seems weird that AT_DISPATCH_ALL_TYPES doesn't include Half by default despite saying ALL_TYPES. 😕

@syed-ahmed
Collaborator Author

I was able to reproduce the error. What's happening is that when comparing at::Half with at::Half, the compiler gets confused about which conversion to apply, since there are multiple operator overloads to choose from (while this is a problem, it is currently avoided by passing half-related NVCC flags in Dependencies.cmake). That is, at::Half can implicitly convert to float, making the comparison float > float, for instance, or it can implicitly convert to __half, making the comparison __half > __half. Operator overloads for the comparisons exist in both cuda_fp16.hpp and THCHalfAutoNumerics (the latter is only used in THCUNN). So in your case, we need to make sure the half operators and conversions coming from the cuda_fp16 header are not compiled in.

Hence, the fix for your extension is to add the flags shown below when building with setuptools. PyTorch top of tree currently adds these flags in Dependencies.cmake: '-DCUDA_HAS_FP16=1', '-D__CUDA_NO_HALF_OPERATORS__', '-D__CUDA_NO_HALF_CONVERSIONS__', '-D__CUDA_NO_HALF2_OPERATORS__', which tell NVCC not to use any of the operator overloads from cuda_fp16.hpp. This error might not have surfaced before because you might be the first one to provide at::Half support in an extension.

import glob
import os

import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME


def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')

    main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
    source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
    source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))

    sources = main_file + source_cpu
    extension = CppExtension

    extra_compile_args = {'cxx': []}
    define_macros = []

    if torch.cuda.is_available() and CUDA_HOME is not None:
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [('WITH_CUDA', None)]
        # Disable the half operator overloads/conversions from cuda_fp16.hpp so
        # that at::Half comparisons resolve through the float path instead of
        # being ambiguous against __half.
        extra_compile_args['nvcc'] = [
            '-DCUDA_HAS_FP16=1',
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]

    sources = [os.path.join(extensions_dir, s) for s in sources]

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            'torchvision._C',
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]

    return ext_modules

syed-ahmed added a commit to syed-ahmed/tutorials that referenced this pull request Aug 29, 2018
When building CUDA extensions, we have to pass extra_compile_args now to avoid CUDA header collisions with half operator overloading (which happens through implicit half conversions); details: pytorch/pytorch#10301 (comment).
facebook-github-bot pushed a commit that referenced this pull request Sep 10, 2018
Summary:
varunagrawal found there are some issues when using comparison operators for half types when certain THC headers are included. I was able to reproduce it and added a test. I also fixed the issue by adding the proper definitions to avoid it.

Reported in #10301 (comment)
Related: pytorch/tutorials#292

soumith fmassa
Pull Request resolved: #11395

Differential Revision: D9725102

Pulled By: goldsborough

fbshipit-source-id: 630425829046bbebea3409bb792a9d62c91f41ad
PenghuiCheng pushed a commit to PenghuiCheng/pytorch that referenced this pull request Sep 11, 2018
Pull Request resolved: pytorch#10301
PenghuiCheng pushed a commit to PenghuiCheng/pytorch that referenced this pull request Sep 11, 2018
Pull Request resolved: pytorch#11395