1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
18
#include <initializer_list>
#include <limits>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/logging.h"
21
22 namespace tensorflow {
23
24 // Helper to define Tensor types given that the scalar is of type T.
25 template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
26 struct TTypes {
27 // Rank-<NDIMS> tensor of scalar type T.
28 typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
29 Eigen::Aligned>
30 Tensor;
31 typedef Eigen::TensorMap<
32 Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
33 ConstTensor;
34
35 // Unaligned Rank-<NDIMS> tensor of scalar type T.
36 typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType> >
37 UnalignedTensor;
38 typedef Eigen::TensorMap<
39 Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType> >
40 UnalignedConstTensor;
41
42 typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>,
43 Eigen::Aligned>
44 Tensor32Bit;
45
46 // Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
47 typedef Eigen::TensorMap<
48 Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
49 Eigen::Aligned>
50 Scalar;
51 typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
52 Eigen::RowMajor, IndexType>,
53 Eigen::Aligned>
54 ConstScalar;
55
56 // Unaligned Scalar tensor of scalar type T.
57 typedef Eigen::TensorMap<
58 Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType> >
59 UnalignedScalar;
60 typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
61 Eigen::RowMajor, IndexType> >
62 UnalignedConstScalar;
63
64 // Rank-1 tensor (vector) of scalar type T.
65 typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
66 Eigen::Aligned>
67 Flat;
68 typedef Eigen::TensorMap<
69 Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
70 ConstFlat;
71 typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
72 Eigen::Aligned>
73 Vec;
74 typedef Eigen::TensorMap<
75 Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
76 ConstVec;
77
78 // Unaligned Rank-1 tensor (vector) of scalar type T.
79 typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> >
80 UnalignedFlat;
81 typedef Eigen::TensorMap<
82 Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> >
83 UnalignedConstFlat;
84 typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> >
85 UnalignedVec;
86 typedef Eigen::TensorMap<
87 Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> >
88 UnalignedConstVec;
89
90 // Rank-2 tensor (matrix) of scalar type T.
91 typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
92 Eigen::Aligned>
93 Matrix;
94 typedef Eigen::TensorMap<
95 Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
96 ConstMatrix;
97
98 // Unaligned Rank-2 tensor (matrix) of scalar type T.
99 typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType> >
100 UnalignedMatrix;
101 typedef Eigen::TensorMap<
102 Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType> >
103 UnalignedConstMatrix;
104 };
105
106 typedef typename TTypes<float, 1>::Tensor32Bit::Index Index32;
107
108 template <typename Index, int NumDims>
SafeFor32BitIndexing(const Eigen::DSizes<Index,NumDims> & in)109 bool SafeFor32BitIndexing(const Eigen::DSizes<Index, NumDims>& in) {
110 for (int i = 0; i < NumDims; ++i) {
111 if (in[i] > std::numeric_limits<Index32>::max()) return false;
112 }
113 return true;
114 }
115
116 template <typename Index, size_t NumDims>
SafeFor32BitIndexing(const Eigen::array<Index,NumDims> & in)117 bool SafeFor32BitIndexing(const Eigen::array<Index, NumDims>& in) {
118 for (size_t i = 0; i < NumDims; ++i) {
119 if (in[i] > std::numeric_limits<Index32>::max()) return false;
120 }
121 return true;
122 }
123
124 template <typename TensorType,
125 typename Enable = typename TTypes<
126 typename TensorType::Scalar, TensorType::NumIndices>::Tensor32Bit>
SafeFor32BitIndexing(TensorType in)127 bool SafeFor32BitIndexing(TensorType in) {
128 return in.size() <= std::numeric_limits<Index32>::max();
129 }
130
131 template <typename Index, int NumDims>
To32Bit(const Eigen::DSizes<Index,NumDims> & in)132 Eigen::DSizes<Index32, NumDims> To32Bit(
133 const Eigen::DSizes<Index, NumDims>& in) {
134 DCHECK(SafeFor32BitIndexing(in));
135 Eigen::DSizes<Index32, NumDims> out;
136 for (int i = 0; i < NumDims; ++i) {
137 out[i] = static_cast<Index32>(in[i]);
138 }
139 return out;
140 }
141
142 template <typename Index, size_t NumDims>
To32Bit(const Eigen::array<Index,NumDims> & in)143 Eigen::array<Index32, NumDims> To32Bit(const Eigen::array<Index, NumDims>& in) {
144 DCHECK(SafeFor32BitIndexing(in));
145 Eigen::array<Index32, NumDims> out;
146 for (size_t i = 0; i < NumDims; ++i) {
147 out[i] = static_cast<Index32>(in[i]);
148 }
149 return out;
150 }
151
152 template <typename TensorType>
153 typename TTypes<typename TensorType::Scalar,
154 TensorType::NumIndices>::Tensor32Bit
To32Bit(TensorType in)155 To32Bit(TensorType in) {
156 typedef typename TTypes<typename TensorType::Scalar,
157 TensorType::NumIndices>::Tensor32Bit RetType;
158 DCHECK(SafeFor32BitIndexing(in));
159 return RetType(in.data(), To32Bit(in.dimensions()));
160 }
161
162 namespace internal {
163
164 template <typename Device>
165 struct MaybeWith32BitIndexingImpl {
166 template <typename Func, typename... Args>
operatorMaybeWith32BitIndexingImpl167 void operator()(Func func, Args&&... args) const {
168 func(std::forward<Args>(args)...);
169 }
170 };
171
172 template <>
173 struct MaybeWith32BitIndexingImpl<Eigen::GpuDevice> {
174 template <typename Func, typename... Args>
175 void operator()(Func func, Args&&... args) const {
176 auto all = [](const auto&... bool_vals) {
177 for (bool b : {bool_vals...}) {
178 if (!b) return false;
179 }
180 return true;
181 };
182 if (all(SafeFor32BitIndexing(std::forward<Args>(args))...)) {
183 func(To32Bit(std::forward<Args>(args))...);
184 } else {
185 func(std::forward<Args>(args)...);
186 }
187 }
188 };
189
190 } // namespace internal
191
192 template <typename Device, typename Func, typename... Args>
193 void MaybeWith32BitIndexing(Func func, Args&&... args) {
194 return internal::MaybeWith32BitIndexingImpl<Device>()(
195 func, std::forward<Args>(args)...);
196 }
197
198 } // namespace tensorflow
199 #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
200