forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
THCTensorTypeUtils.cuh
83 lines (72 loc) · 2.6 KB
/
THCTensorTypeUtils.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#ifndef THC_TENSOR_TYPE_UTILS_INC
#define THC_TENSOR_TYPE_UTILS_INC
#include <cuda.h>
#include <assert.h>
#include <THC/THCGeneral.h>
#include <TH/THHalf.h>
#include <THC/THCTensor.hpp>
#include <THC/THCTensorInfo.cuh>
#include <THC/THCTensor.hpp>
/// A utility for accessing THCuda*Tensor types in a generic manner
/// SameType is equivalent to C++11's std::is_same (from <type_traits>); it is
/// used for comparing types for equality. It is provided here so that this
/// header does not have to assume the existence of C++11.
// Primary template: chosen when the two type arguments differ, so `same`
// is false.
template <typename Lhs, typename Rhs>
struct SameType {
  static const bool same = false;
};

// Partial specialization: chosen when both type arguments are the identical
// type, so `same` is true.
template <typename Only>
struct SameType<Only, Only> {
  static const bool same = true;
};
// Function form of SameType: returns SameType<Lhs, Rhs>::same, i.e. true
// exactly when both template arguments name the same type.
template <typename Lhs, typename Rhs>
bool isSameType() {
  const bool identical = SameType<Lhs, Rhs>::same;
  return identical;
}
// Utility function for constructing TensorInfo structs. In this case, the
// two template parameters are:
//
// 1. The TensorType, e.g. THCTensor in generic functions, or THCudaTensor,
// THCudaLongTensor etc.
//
// 2. The IndexType. This is always going to be an unsigned integral value,
// but depending on the size of the Tensor you may select uint16_t,
// uint32_t, uint64_t etc.
//
// Internally we use the THCTensor/THTensor *LegacyNoScalars accessor
// functions to get the necessary dims, sizes, strides etc.
//
// For example, suppose we have a THCudaTensor t, with dim = 2, size = [3, 4],
// stride = [4, 1], offset = 8, and we set our index type to be unsigned int.
// Then we yield a TensorInfo struct templatized with float, unsigned int and
// the following fields:
//
// data is a float* to the underlying storage at position 8
// dims is 2
// sizes is a MAX_CUTORCH_DIMS element array with [3, 4] in its first two positions
// strides is a MAX_CUTORCH_DIMS element array with [4, 1] in its first two positions
//
// TensorInfos can then be passed to CUDA kernels, but we can use the static functions
// defined above to perform Tensor Operations that are appropriate for each
// TensorType.
// Builds a TensorInfo<ScalarType, IndexType> describing the tensor `t`.
//
// Template parameters:
//   ScalarType - element type requested from t->data<ScalarType>()
//   TensorType - type of the tensor pointed to by `t`
//   IndexType  - type used to store each size and stride
//
// Parameters:
//   state - THC state, forwarded to THCTensor_nDimensionLegacyNoScalars
//   t     - tensor to describe
//
// Returns a TensorInfo constructed from the tensor's data pointer, its
// dimension count, and temporary size/stride arrays whose first `dims`
// entries have been filled in. Entries at index >= dims are left
// uninitialized.
//
// NOTE(review): sizes and strides are converted to IndexType without a range
// check; presumably the caller picks an IndexType wide enough for the
// tensor -- confirm against callers.
template <typename ScalarType, typename TensorType, typename IndexType>
TensorInfo<ScalarType, IndexType>
getTensorInfo(THCState* state, TensorType* t) {
  IndexType sz[MAX_CUTORCH_DIMS];
  IndexType st[MAX_CUTORCH_DIMS];

  const int dims = THCTensor_nDimensionLegacyNoScalars(state, t);

  // sz and st each hold exactly MAX_CUTORCH_DIMS entries. Fail loudly in
  // debug builds rather than writing past the end of the stack arrays below.
  assert(dims <= MAX_CUTORCH_DIMS);

  for (int i = 0; i < dims; ++i) {
    sz[i] = THTensor_sizeLegacyNoScalars(t, i);
    st[i] = THTensor_strideLegacyNoScalars(t, i);
  }

  return TensorInfo<ScalarType, IndexType>(
      t->template data<ScalarType>(), dims, sz, st);
}
// Functor-style helper whose static `to` returns the arithmetic negation of
// its argument. Callable from both host and device code.
template <typename T>
struct ScalarNegate {
  __host__ __device__ static T to(const T value) {
    return -value;
  }
};
// Functor-style helper whose static `to` returns one (converted to T)
// divided by its argument, using T's own division. Callable from both host
// and device code.
template <typename T>
struct ScalarInv {
  __host__ __device__ static T to(const T value) {
    return static_cast<T>(1) / value;
  }
};
#endif // THC_TENSOR_TYPE_UTILS_INC