Refactor THCNumerics and add common math functions for at::Half (pytorch#10301)

Summary:
**Summary**: This PR is a follow-up to mruberry's pytorch#9318. It does the following:
- Specialize common std math functions for the `at::Half` type.
- Create `CUDANumerics.cuh` to hold the necessary parts of `THCNumerics.cuh`.
- Update `THCNumerics.cuh` with new usage notes and comments that demonstrate best practice for developers, paving the way for its deprecation.
- Remove legacy/redundant code paths.
- Remove unused CUDA HALF macros (see separate PR pytorch#10147).

**Comments**: `CUDANumerics.cuh` contains mathematical functions that are either not in the std namespace or are specialized for compilation with CUDA NVCC or CUDA NVRTC. The header is derived from the legacy `THCNumerics.cuh`. The rationale for which functions were kept and which were removed is as follows:
- All arithmetic can now be done in ATen using the binary CUDA kernel or CUDA tensor pointwise apply (see pytorch#8919 and `CUDAApplyUtils`). `at::Half` comparisons rely on its implicit conversion to float.
- Functions that are C/C++ standard compliant have been specialized for user-defined types; for instance, the std namespace has been opened up for `at::Half`, with math function overloads defined for it. See `Half-inl.h` and the sketch after this list.
- Some standard-compliant functions are specialized here for performance. For instance, `powi` is used for `pow` on integral types, and `abs`, `isinf`, and `isnan` are specialized to save one API call versus their std counterparts. This is subject to change, depending on whether we really care about saving that one call.
- Numeric limits such as `max`/`min` are removed since they just forward to standard defines, and the numeric limits for `at::Half` are present in `Half-inl.h`. I understood that HIP has some issue with `std::numeric_limits`; the related GitHub issue I found is ROCm/HIP#374. AlexVlx mentions the issue can be avoided by invoking `std::numeric_limits` in `__device__` code. Since we launch lambdas in device contexts, I don't see why `std::numeric_limits` wouldn't compile on HIP when used within a kernel, unless I am unaware of the real reason `max`/`min` were in THCNumerics in the first place. (I have never tried a HIP build.)
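To make the second and third points concrete, here is a minimal sketch of both mechanisms, compiled with nvcc. It is illustrative only — the overload bodies and the `powi` loop are assumptions matching the description above, not the PR's exact code — and it assumes only that `at::Half` converts to and from `float`:

// Illustrative sketch (not the PR's exact code): Half-inl.h opens up the
// std namespace by giving each math function a thin at::Half overload
// that routes through float.
#include <cmath>
#include "ATen/ATen.h"  // provides at::Half with float conversions

namespace std {
inline at::Half lgamma(at::Half a) {
  return at::Half(std::lgamma(static_cast<float>(a)));
}
inline at::Half exp(at::Half a) {
  return at::Half(std::exp(static_cast<float>(a)));
}
// ... one such thin wrapper per math function ...
} // namespace std

// One possible powi: pow on integral types via repeated multiplication,
// avoiding a round-trip through floating point.
template <typename T>
static inline __host__ __device__ T powi(T a, T b) {
  T result = 1;
  for (T i = 0; i < b; ++i) {
    result *= a;
  }
  return result;
}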

Here are some reference PRs that were handy in refactoring TH into ATen:
- pytorch#6786
- pytorch#5475
- pytorch#9401
- pytorch#8689
- pytorch#8919
Pull Request resolved: pytorch#10301

Differential Revision: D9204758

Pulled By: soumith

fbshipit-source-id: 09f489c1656458c02367b6cd31c3eeeca5acdc8a
syed-ahmed authored and PenghuiCheng committed Sep 11, 2018
1 parent bd5cda5 commit 979bc58
Showing 13 changed files with 340 additions and 383 deletions.
75 changes: 75 additions & 0 deletions aten/src/ATen/cuda/NumericLimits.cuh
@@ -0,0 +1,75 @@
#pragma once

#include <cuda.h>
#include <limits.h>

// NumericLimits.cuh is a holder for numeric limits definitions of commonly used
// types. This header is very specific to ROCm HIP and may be removed in the future.
// This header is derived from the legacy THCNumerics.cuh.

namespace at{

template <typename T>
struct numeric_limits {
};

// WARNING: the following at::numeric_limits definitions are there only to support
// HIP compilation for the moment. Use std::numeric_limits if you are not
// compiling for ROCm.
// from @colesbury: "The functions on numeric_limits aren't marked with
// __device__ which is why they don't work with ROCm. CUDA allows them
// because they're constexpr."
template <>
struct numeric_limits<uint8_t> {
static inline __host__ __device__ uint8_t lowest() { return 0; }
static inline __host__ __device__ uint8_t max() { return UINT8_MAX; }
};

template <>
struct numeric_limits<int8_t> {
static inline __host__ __device__ int8_t lowest() { return INT8_MIN; }
static inline __host__ __device__ int8_t max() { return INT8_MAX; }
};

template <>
struct numeric_limits<int16_t> {
static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
static inline __host__ __device__ int16_t max() { return INT16_MAX; }
};

template <>
struct numeric_limits<int32_t> {
static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
static inline __host__ __device__ int32_t max() { return INT32_MAX; }
};

template <>
struct numeric_limits<int64_t> {
#ifdef _MSC_VER
static inline __host__ __device__ int64_t lowest() { return _I64_MIN; }
static inline __host__ __device__ int64_t max() { return _I64_MAX; }
#else
static inline __host__ __device__ int64_t lowest() { return INT64_MIN; }
static inline __host__ __device__ int64_t max() { return INT64_MAX; }
#endif
};

template <>
struct numeric_limits<at::Half> {
static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits); }
static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits); }
};

template <>
struct numeric_limits<float> {
static inline __host__ __device__ float lowest() { return -FLT_MAX; }
static inline __host__ __device__ float max() { return FLT_MAX; }
};

template <>
struct numeric_limits<double> {
static inline __host__ __device__ double lowest() { return -DBL_MAX; }
static inline __host__ __device__ double max() { return DBL_MAX; }
};

} // namespace at
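
As a usage sketch (a hypothetical kernel added for illustration, not part of this commit), a reduction seeds its accumulator from at::numeric_limits so the same code compiles for both CUDA and ROCm:

// Hypothetical kernel illustrating at::numeric_limits usage; launched
// single-threaded as max_kernel<<<1, 1>>>(in, out, n).
#include "ATen/cuda/NumericLimits.cuh"

template <typename scalar_t>
__global__ void max_kernel(const scalar_t* in, scalar_t* out, int n) {
  // Seed with lowest(), the most negative finite value -- the correct
  // identity element for a max-reduction.
  scalar_t m = at::numeric_limits<scalar_t>::lowest();
  for (int i = 0; i < n; ++i) {
    if (in[i] > m) m = in[i];
  }
  *out = m;
}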
10 changes: 5 additions & 5 deletions aten/src/ATen/native/cuda/SoftMax.cu
@@ -6,9 +6,9 @@
#include <THC/THCTensorMathReduce.cuh>
#include <THC/THCTensorSort.cuh>
#include <THC/THCThrustAllocator.cuh>
-#include <THC/THCNumerics.cuh>

#include "ATen/AccumulateType.h"
#include "ATen/cuda/NumericLimits.cuh"


namespace at {
@@ -200,7 +200,7 @@ __global__ void cunn_SpatialSoftMaxForward(
////////////////////////////////////////////////////////////

if (blockDim.x > 1) {
-accscalar_t max_input = THCNumerics<accscalar_t>::min();
+accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
max_input = Max<accscalar_t>()(max_input, value);
@@ -217,7 +217,7 @@
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
} else {
-accscalar_t max_input = THCNumerics<accscalar_t>::min();
+accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
max_input = Max<accscalar_t>()(max_input, value);
@@ -403,9 +403,9 @@ cunn_SoftMaxForward(scalar_t *output, scalar_t *input, int classes)

// find the max
accscalar_t threadMax = ilpReduce<MaxFloat, ILP, scalar_t, accscalar_t>(
-input, classes, MaxFloat<scalar_t, accscalar_t>(), -THCNumerics<accscalar_t>::max());
+input, classes, MaxFloat<scalar_t, accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
accscalar_t max_k = blockReduce<Max, accscalar_t>(
-sdata, threadMax, Max<accscalar_t>(), -THCNumerics<accscalar_t>::max());
+sdata, threadMax, Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());

// reduce all values
accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
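A note on the replacement above: for floating-point types, std::numeric_limits<T>::min() is the smallest positive normalized value, not the most negative one, so lowest() is the right seed for a max-reduction; the legacy THCNumerics<T>::min() had the "most negative" meaning, hence the rename. A host-side snippet (standard C++ facts, added for illustration) shows the difference:

#include <iostream>
#include <limits>

int main() {
  std::cout << std::numeric_limits<float>::min() << "\n";    // ~1.17549e-38 (FLT_MIN)
  std::cout << std::numeric_limits<float>::lowest() << "\n"; // -3.40282e+38 (-FLT_MAX)
  // Seeding a max-reduction with min() breaks on all-negative input;
  // lowest() is the identity element for max.
  return 0;
}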
3 changes: 2 additions & 1 deletion aten/src/ATen/test/CMakeLists.txt
@@ -25,7 +25,8 @@ list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/integer_divider_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_rng_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/apply_test.cpp
-${CMAKE_CURRENT_SOURCE_DIR}/stream_test.cpp)
+${CMAKE_CURRENT_SOURCE_DIR}/stream_test.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu)
if (CUDNN_FOUND)
list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_test.cpp)
90 changes: 90 additions & 0 deletions aten/src/ATen/test/cuda_half_test.cu
@@ -0,0 +1,90 @@
#define CATCH_CONFIG_MAIN
#include "catch.hpp"

#include "ATen/ATen.h"
#include "ATen/cuda/NumericLimits.cuh"
#include "cuda.h"
#include "cuda_fp16.h"
#include "cuda_runtime.h"

#include <assert.h>

using namespace at;

__device__ void test(){

// test half construction and implicit conversions in device
assert(Half(3) == Half(3.0f));
assert(static_cast<Half>(3.0f) == Half(3.0f));
// there is no float <=> __half implicit conversion
assert(static_cast<Half>(3.0f) == 3.0f);

__half a = __float2half(3.0f);
__half b = __float2half(2.0f);
__half c = a - Half(b);
assert(static_cast<Half>(c) == Half(1.0));

// Assert that the common math functions applied to half types give
// approximately the same results as their float counterparts.
// The purpose of these asserts is to exercise the device-side
// half API for the common mathematical functions.
// Note: When calling std math functions from device, don't
// use the std namespace, but just "::" so that the function
// gets resolved from nvcc math_functions.hpp

float threshold = 0.00001;
assert(::abs(::lgamma(Half(10.0)) - ::lgamma(10.0f)) <= threshold);
assert(::abs(::exp(Half(1.0)) - ::exp(1.0f)) <= threshold);
assert(::abs(::log(Half(1.0)) - ::log(1.0f)) <= threshold);
assert(::abs(::log10(Half(1000.0)) - ::log10(1000.0f)) <= threshold);
assert(::abs(::log1p(Half(0.0)) - ::log1p(0.0f)) <= threshold);
assert(::abs(::log2(Half(1000.0)) - ::log2(1000.0f)) <= threshold);
assert(::abs(::expm1(Half(1.0)) - ::expm1(1.0f)) <= threshold);
assert(::abs(::cos(Half(0.0)) - ::cos(0.0f)) <= threshold);
assert(::abs(::sin(Half(0.0)) - ::sin(0.0f)) <= threshold);
assert(::abs(::sqrt(Half(100.0)) - ::sqrt(100.0f)) <= threshold);
assert(::abs(::ceil(Half(2.4)) - ::ceil(2.4f)) <= threshold);
assert(::abs(::floor(Half(2.7)) - ::floor(2.7f)) <= threshold);
assert(::abs(::trunc(Half(2.7)) - ::trunc(2.7f)) <= threshold);
assert(::abs(::acos(Half(-1.0)) - ::acos(-1.0f)) <= threshold);
assert(::abs(::cosh(Half(1.0)) - ::cosh(1.0f)) <= threshold);
assert(::abs(::acosh(Half(1.0)) - ::acosh(1.0f)) <= threshold);
assert(::abs(::asin(Half(1.0)) - ::asin(1.0f)) <= threshold);
assert(::abs(::sinh(Half(1.0)) - ::sinh(1.0f)) <= threshold);
assert(::abs(::asinh(Half(1.0)) - ::asinh(1.0f)) <= threshold);
assert(::abs(::tan(Half(0.0)) - ::tan(0.0f)) <= threshold);
assert(::abs(::atan(Half(1.0)) - ::atan(1.0f)) <= threshold);
assert(::abs(::tanh(Half(1.0)) - ::tanh(1.0f)) <= threshold);
assert(::abs(::erf(Half(10.0)) - ::erf(10.0f)) <= threshold);
assert(::abs(::erfc(Half(10.0)) - ::erfc(10.0f)) <= threshold);
assert(::abs(::abs(Half(-3.0)) - ::abs(-3.0f)) <= threshold);
assert(::abs(::round(Half(2.3)) - ::round(2.3f)) <= threshold);
assert(::abs(::pow(Half(2.0), Half(10.0)) - ::pow(2.0f, 10.0f)) <= threshold);
assert(::abs(::atan2(Half(7.0), Half(0.0)) - ::atan2(7.0f, 0.0f)) <= threshold);
// note: can't use namespace on isnan and isinf in device code
#ifdef _MSC_VER
// Windows requires this explicit conversion. The reason is unclear
// related issue with clang: https://reviews.llvm.org/D37906
assert(::abs(::isnan((float)Half(0.0)) - ::isnan(0.0f)) <= threshold);
assert(::abs(::isinf((float)Half(0.0)) - ::isinf(0.0f)) <= threshold);
#else
assert(::abs(::isnan(Half(0.0)) - ::isnan(0.0f)) <= threshold);
assert(::abs(::isinf(Half(0.0)) - ::isinf(0.0f)) <= threshold);
#endif
}

__global__ void kernel(){
test();
}

void launch_function(){
kernel<<<1,1>>>();
}

TEST_CASE( "half common math functions tests in device", "[cuda]" ) {
launch_function();
cudaError_t err = cudaDeviceSynchronize();
REQUIRE(err == cudaSuccess);
}

43 changes: 43 additions & 0 deletions aten/src/ATen/test/half_test.cpp
@@ -5,7 +5,10 @@
#include <iostream>
#include <limits>
#include <sstream>
+#include <cmath>
+#include <type_traits>
#include "test_seed.h"
+#include "test_assert.h"

using namespace at;

@@ -115,3 +118,43 @@ ASSERT_SAME_TYPE(max_exponent);
ASSERT_SAME_TYPE(max_exponent10);
ASSERT_SAME_TYPE(traps);
ASSERT_SAME_TYPE(tinyness_before);

TEST_CASE( "half common math functions test", "[]" ) {
float threshold = 0.00001;
assert(std::abs(std::lgamma(Half(10.0)) - std::lgamma(10.0f)) <= threshold);
assert(std::abs(std::exp(Half(1.0)) - std::exp(1.0f)) <= threshold);
assert(std::abs(std::log(Half(1.0)) - std::log(1.0f)) <= threshold);
assert(std::abs(std::log10(Half(1000.0)) - std::log10(1000.0f)) <= threshold);
assert(std::abs(std::log1p(Half(0.0)) - std::log1p(0.0f)) <= threshold);
assert(std::abs(std::log2(Half(1000.0)) - std::log2(1000.0f)) <= threshold);
assert(std::abs(std::expm1(Half(1.0)) - std::expm1(1.0f)) <= threshold);
assert(std::abs(std::cos(Half(0.0)) - std::cos(0.0f)) <= threshold);
assert(std::abs(std::sin(Half(0.0)) - std::sin(0.0f)) <= threshold);
assert(std::abs(std::sqrt(Half(100.0)) - std::sqrt(100.0f)) <= threshold);
assert(std::abs(std::ceil(Half(2.4)) - std::ceil(2.4f)) <= threshold);
assert(std::abs(std::floor(Half(2.7)) - std::floor(2.7f)) <= threshold);
assert(std::abs(std::trunc(Half(2.7)) - std::trunc(2.7f)) <= threshold);
assert(std::abs(std::acos(Half(-1.0)) - std::acos(-1.0f)) <= threshold);
assert(std::abs(std::cosh(Half(1.0)) - std::cosh(1.0f)) <= threshold);
assert(std::abs(std::acosh(Half(1.0)) - std::acosh(1.0f)) <= threshold);
assert(std::abs(std::asin(Half(1.0)) - std::asin(1.0f)) <= threshold);
assert(std::abs(std::sinh(Half(1.0)) - std::sinh(1.0f)) <= threshold);
assert(std::abs(std::asinh(Half(1.0)) - std::asinh(1.0f)) <= threshold);
assert(std::abs(std::tan(Half(0.0)) - std::tan(0.0f)) <= threshold);
assert(std::abs(std::atan(Half(1.0)) - std::atan(1.0f)) <= threshold);
assert(std::abs(std::tanh(Half(1.0)) - std::tanh(1.0f)) <= threshold);
assert(std::abs(std::erf(Half(10.0)) - std::erf(10.0f)) <= threshold);
assert(std::abs(std::erfc(Half(10.0)) - std::erfc(10.0f)) <= threshold);
assert(std::abs(std::abs(Half(-3.0)) - std::abs(-3.0f)) <= threshold);
assert(std::abs(std::round(Half(2.3)) - std::round(2.3f)) <= threshold);
assert(std::abs(std::pow(Half(2.0), Half(10.0)) - std::pow(2.0f, 10.0f)) <= threshold);
assert(std::abs(std::atan2(Half(7.0), Half(0.0)) - std::atan2(7.0f, 0.0f)) <= threshold);
#ifdef __APPLE__
// @TODO: can macos do implicit conversion of Half?
assert(std::abs(std::isnan(static_cast<float>(Half(0.0))) - std::isnan(0.0f)) <= threshold);
assert(std::abs(std::isinf(static_cast<float>(Half(0.0))) - std::isinf(0.0f)) <= threshold);
#else
assert(std::abs(std::isnan(Half(0.0)) - std::isnan(0.0f)) <= threshold);
assert(std::abs(std::isinf(Half(0.0)) - std::isinf(0.0f)) <= threshold);
#endif
}
4 changes: 0 additions & 4 deletions aten/src/THC/CMakeLists.txt
@@ -18,10 +18,6 @@ foreach(THC_TYPE Byte Char Short Int Long Half Float Double)
endforeach()
endforeach()

-IF(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
-LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/THCHalf.cu)
-ENDIF()

set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
${CMAKE_CURRENT_SOURCE_DIR}/THCCachingAllocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THCCachingHostAllocator.cpp
7 changes: 3 additions & 4 deletions aten/src/THC/THCAtomics.cuh
@@ -4,8 +4,7 @@
#include "THC.h"
#include "THCHalf.h"
#include "THCNumerics.cuh"

-namespace at { struct Half; }
+#include "ATen/ATen.h"

template <typename T, size_t n>
struct AtomicAddIntegerImpl;
@@ -118,8 +117,8 @@ static inline __device__ void atomicAdd(half *address, half val) {
old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old);
}
-static inline __device__ void atomicAdd(at::Half *address, half val) {
-return atomicAdd(reinterpret_cast<half*>(address), val);
+static inline __device__ void atomicAdd(at::Half *address, at::Half val) {
+atomicAdd(reinterpret_cast<half*>(address), val);
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
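For context on the hunk above: on devices without a native 16-bit atomic, atomicAdd for half is emulated with an atomicCAS loop over the aligned 32-bit word containing the value. The following is a simplified sketch of that pattern — the committed code works through half/__half_raw with CUDA-version guards, so the helper name and the use of at::Half arithmetic here are illustrative assumptions:

// Simplified sketch of CAS-based half atomicAdd (illustrative, not the
// file's exact code).
static inline __device__ void atomic_add_half_sketch(at::Half* address, at::Half val) {
  // Locate the aligned 32-bit word holding this 16-bit value.
  unsigned int* word =
      (unsigned int*)((char*)address - ((size_t)address & 2));
  const bool high_half = ((size_t)address & 2) != 0;
  unsigned int old = *word;
  unsigned int assumed;
  do {
    assumed = old;
    // Extract the current bits and add via at::Half (float math inside).
    unsigned short bits = high_half ? (old >> 16) : (old & 0xffff);
    at::Half sum = at::Half(bits, at::Half::from_bits) + val;
    // Splice the updated bits back into the word and try to swap it in.
    unsigned int updated = high_half
        ? (old & 0x0000ffff) | ((unsigned int)sum.x << 16)
        : (old & 0xffff0000) | sum.x;
    old = atomicCAS(word, assumed, updated);
  } while (assumed != old); // retry if another thread raced the update
}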
51 changes: 0 additions & 51 deletions aten/src/THC/THCHalf.cu

This file was deleted.

8 changes: 0 additions & 8 deletions aten/src/THC/THCHalf.h
@@ -12,15 +12,7 @@ typedef __half_raw half;
#endif
#endif

-THC_EXTERNC void THCFloat2Half(THCState *state, half *out, float *in, ptrdiff_t len);
-THC_EXTERNC void THCHalf2Float(THCState *state, float *out, half *in, ptrdiff_t len);
-THC_API half THC_float2half(float a);
-THC_API float THC_half2float(half a);

-/* Check for native fp16 support on the current device (CC 5.3+) */
-THC_API int THC_nativeHalfInstructions(THCState *state);

-/* Check for performant native fp16 support on the current device */
-THC_API int THC_fastHalfInstructions(THCState *state);

#endif
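
The deleted conversion helpers need no direct replacement because at::Half converts to and from float on its own, as the tests above exercise. A hypothetical before/after, added for illustration:

// Illustrative only: what callers use instead of the deleted helpers.
#include "ATen/ATen.h"

float roundtrip(float x) {
  at::Half h(x);   // was: half h = THC_float2half(x);
  return h;        // was: return THC_half2float(h); (implicit Half -> float)
}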