xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/tensor_types.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
18 
19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20 #include "tensorflow/core/platform/logging.h"
21 
22 namespace tensorflow {
23 
24 // Helper to define Tensor types given that the scalar is of type T.
25 template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
26 struct TTypes {
27   // Rank-<NDIMS> tensor of scalar type T.
28   typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
29                            Eigen::Aligned>
30       Tensor;
31   typedef Eigen::TensorMap<
32       Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
33       ConstTensor;
34 
35   // Unaligned Rank-<NDIMS> tensor of scalar type T.
36   typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType> >
37       UnalignedTensor;
38   typedef Eigen::TensorMap<
39       Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType> >
40       UnalignedConstTensor;
41 
42   typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>,
43                            Eigen::Aligned>
44       Tensor32Bit;
45 
46   // Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
47   typedef Eigen::TensorMap<
48       Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
49       Eigen::Aligned>
50       Scalar;
51   typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
52                                                   Eigen::RowMajor, IndexType>,
53                            Eigen::Aligned>
54       ConstScalar;
55 
56   // Unaligned Scalar tensor of scalar type T.
57   typedef Eigen::TensorMap<
58       Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType> >
59       UnalignedScalar;
60   typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
61                                                   Eigen::RowMajor, IndexType> >
62       UnalignedConstScalar;
63 
64   // Rank-1 tensor (vector) of scalar type T.
65   typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
66                            Eigen::Aligned>
67       Flat;
68   typedef Eigen::TensorMap<
69       Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
70       ConstFlat;
71   typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
72                            Eigen::Aligned>
73       Vec;
74   typedef Eigen::TensorMap<
75       Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
76       ConstVec;
77 
78   // Unaligned Rank-1 tensor (vector) of scalar type T.
79   typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> >
80       UnalignedFlat;
81   typedef Eigen::TensorMap<
82       Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> >
83       UnalignedConstFlat;
84   typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> >
85       UnalignedVec;
86   typedef Eigen::TensorMap<
87       Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> >
88       UnalignedConstVec;
89 
90   // Rank-2 tensor (matrix) of scalar type T.
91   typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
92                            Eigen::Aligned>
93       Matrix;
94   typedef Eigen::TensorMap<
95       Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
96       ConstMatrix;
97 
98   // Unaligned Rank-2 tensor (matrix) of scalar type T.
99   typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType> >
100       UnalignedMatrix;
101   typedef Eigen::TensorMap<
102       Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType> >
103       UnalignedConstMatrix;
104 };
105 
106 typedef typename TTypes<float, 1>::Tensor32Bit::Index Index32;
107 
108 template <typename Index, int NumDims>
SafeFor32BitIndexing(const Eigen::DSizes<Index,NumDims> & in)109 bool SafeFor32BitIndexing(const Eigen::DSizes<Index, NumDims>& in) {
110   for (int i = 0; i < NumDims; ++i) {
111     if (in[i] > std::numeric_limits<Index32>::max()) return false;
112   }
113   return true;
114 }
115 
116 template <typename Index, size_t NumDims>
SafeFor32BitIndexing(const Eigen::array<Index,NumDims> & in)117 bool SafeFor32BitIndexing(const Eigen::array<Index, NumDims>& in) {
118   for (size_t i = 0; i < NumDims; ++i) {
119     if (in[i] > std::numeric_limits<Index32>::max()) return false;
120   }
121   return true;
122 }
123 
124 template <typename TensorType,
125           typename Enable = typename TTypes<
126               typename TensorType::Scalar, TensorType::NumIndices>::Tensor32Bit>
SafeFor32BitIndexing(TensorType in)127 bool SafeFor32BitIndexing(TensorType in) {
128   return in.size() <= std::numeric_limits<Index32>::max();
129 }
130 
131 template <typename Index, int NumDims>
To32Bit(const Eigen::DSizes<Index,NumDims> & in)132 Eigen::DSizes<Index32, NumDims> To32Bit(
133     const Eigen::DSizes<Index, NumDims>& in) {
134   DCHECK(SafeFor32BitIndexing(in));
135   Eigen::DSizes<Index32, NumDims> out;
136   for (int i = 0; i < NumDims; ++i) {
137     out[i] = static_cast<Index32>(in[i]);
138   }
139   return out;
140 }
141 
142 template <typename Index, size_t NumDims>
To32Bit(const Eigen::array<Index,NumDims> & in)143 Eigen::array<Index32, NumDims> To32Bit(const Eigen::array<Index, NumDims>& in) {
144   DCHECK(SafeFor32BitIndexing(in));
145   Eigen::array<Index32, NumDims> out;
146   for (size_t i = 0; i < NumDims; ++i) {
147     out[i] = static_cast<Index32>(in[i]);
148   }
149   return out;
150 }
151 
152 template <typename TensorType>
153 typename TTypes<typename TensorType::Scalar,
154                 TensorType::NumIndices>::Tensor32Bit
To32Bit(TensorType in)155 To32Bit(TensorType in) {
156   typedef typename TTypes<typename TensorType::Scalar,
157                           TensorType::NumIndices>::Tensor32Bit RetType;
158   DCHECK(SafeFor32BitIndexing(in));
159   return RetType(in.data(), To32Bit(in.dimensions()));
160 }
161 
162 namespace internal {
163 
164 template <typename Device>
165 struct MaybeWith32BitIndexingImpl {
166   template <typename Func, typename... Args>
operatorMaybeWith32BitIndexingImpl167   void operator()(Func func, Args&&... args) const {
168     func(std::forward<Args>(args)...);
169   }
170 };
171 
172 template <>
173 struct MaybeWith32BitIndexingImpl<Eigen::GpuDevice> {
174   template <typename Func, typename... Args>
175   void operator()(Func func, Args&&... args) const {
176     auto all = [](const auto&... bool_vals) {
177       for (bool b : {bool_vals...}) {
178         if (!b) return false;
179       }
180       return true;
181     };
182     if (all(SafeFor32BitIndexing(std::forward<Args>(args))...)) {
183       func(To32Bit(std::forward<Args>(args))...);
184     } else {
185       func(std::forward<Args>(args)...);
186     }
187   }
188 };
189 
190 }  // namespace internal
191 
192 template <typename Device, typename Func, typename... Args>
193 void MaybeWith32BitIndexing(Func func, Args&&... args) {
194   return internal::MaybeWith32BitIndexingImpl<Device>()(
195       func, std::forward<Args>(args)...);
196 }
197 
198 }  // namespace tensorflow
199 #endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
200