xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/linalg/linalg_ops_common.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_
16 #define TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_
17 
18 // Classes to support linear algebra functionality, similar to the numpy.linalg
19 // module. Supports batch computation on several matrices at once, sharding the
20 // computations across different threads if necessary.
21 #include <algorithm>
22 
23 #include "third_party/eigen3/Eigen/Core"
24 #include "tensorflow/core/framework/kernel_def_builder.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/framework/tensor_types.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/core/util/work_sharder.h"
34 
35 namespace tensorflow {
36 
37 // Base class for linear algebra operators.
38 template <class InputScalar, class OutputScalar = InputScalar>
39 class LinearAlgebraOp : public OpKernel {
40  public:
LinearAlgebraOp(OpKernelConstruction * context)41   explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
42 
43   void Compute(OpKernelContext* context) override;
44 
45  protected:
46   using TensorShapes = gtl::InlinedVector<TensorShape, 4>;
47   // Returns the number of leading inputs that are to be treated as matrix
48   // inputs. By default this is all the inputs. Derived classes can override
49   // this to tell the base class to ignore one or more trailing inputs.
NumMatrixInputs(const OpKernelContext * context)50   virtual int NumMatrixInputs(const OpKernelContext* context) const {
51     return context->num_inputs();
52   }
53 
54   // Returns true if the number of inputs and their shapes are as expected.
55   // Many ops take a single square input matrix, so we provide that as a default
56   // implementation for convenience.
ValidateInputMatrixShapes(OpKernelContext * context,const TensorShapes & input_matrix_shapes)57   virtual void ValidateInputMatrixShapes(
58       OpKernelContext* context, const TensorShapes& input_matrix_shapes) const {
59     ValidateSingleSquareMatrix(context, input_matrix_shapes);
60   }
61 
62   // Convenience validators for common cases:
63   //
64   // Validate op taking a single matrix A.
65   static void ValidateSingleMatrix(OpKernelContext* context,
66                                    const TensorShapes& input_matrix_shapes);
67   // Validate op taking a single square matrix A.
68   static void ValidateSingleSquareMatrix(
69       OpKernelContext* context, const TensorShapes& input_matrix_shapes);
70   // Validate op taking two matrices A and B that have the same number of rows.
71   static void ValidateSolver(OpKernelContext* context,
72                              const TensorShapes& input_matrix_shapes);
73   // Validate op taking two matrices A and B that have the same number of rows
74   // and A is square.
75   static void ValidateSquareSolver(OpKernelContext* context,
76                                    const TensorShapes& input_matrix_shapes);
77 
78   // Returns the output shapes of each individual matrix operation. Output
79   // matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0.
80   //
81   // The derived class may return a number of shapes (N) less than
82   // context->num_outputs() (M) to indicate that a only leading subset of
83   // the outputs will be populated. In this case, a dummy scalar tensor with
84   // value zero will be return for the last M-N outputs.
85   //
86   // For many ops, the output dimensions are the same as the input dimensions,
87   // so we provide that as a default implementation for convenience.
GetOutputMatrixShapes(const TensorShapes & input_matrix_shapes)88   virtual TensorShapes GetOutputMatrixShapes(
89       const TensorShapes& input_matrix_shapes) const {
90     return input_matrix_shapes;
91   }
92 
93   // Returns the cost per matrix operation. This is used to determine the
94   // number of threads to use for parallelizing calls to ComputeMatrix in
95   // batch mode. Cost per unit is assumed to be roughly 1ns, based on comments
96   // in core/util/work_sharder.cc. Many linear algebra ops take roughly max(m,n)
97   // * min(m,n)^2, where the first input matrix is m-by-n. We provide that as a
98   // default implementation for convenience.
GetCostPerUnit(const TensorShapes & input_matrix_shapes)99   virtual int64_t GetCostPerUnit(
100       const TensorShapes& input_matrix_shapes) const {
101     double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
102     double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
103     double cost = std::max(m, n) * std::min(m, n) * std::min(m, n);
104     return cost >= static_cast<double>(kint64max) ? kint64max
105                                                   : static_cast<int64_t>(cost);
106   }
107 
108   // Returns true if it is safe to forward (alias) input to output buffer
109   // and expect the kernel to perform the computation inplace.
EnableInputForwarding()110   virtual bool EnableInputForwarding() const { return true; }
111 
112   using InputMatrix = Eigen::Matrix<InputScalar, Eigen::Dynamic, Eigen::Dynamic,
113                                     Eigen::RowMajor>;
114   using InputConstMatrixMap = Eigen::Map<const InputMatrix>;
115   using InputMatrixMap = Eigen::Map<InputMatrix>;
116   using InputConstVectorMap =
117       Eigen::Map<const Eigen::Matrix<InputScalar, 1, Eigen::Dynamic>>;
118   using InputConstMatrixMaps = gtl::InlinedVector<InputConstMatrixMap, 4>;
119   using InputMatrixMaps = gtl::InlinedVector<InputMatrixMap, 4>;
120   using InputRealScalar = typename Eigen::NumTraits<InputScalar>::Real;
121 
122   using OutputMatrix = Eigen::Matrix<OutputScalar, Eigen::Dynamic,
123                                      Eigen::Dynamic, Eigen::RowMajor>;
124   using OutputConstMatrixMap = Eigen::Map<const OutputMatrix>;
125   using OutputMatrixMap = Eigen::Map<OutputMatrix>;
126   using OutputConstVectorMap =
127       Eigen::Map<const Eigen::Matrix<OutputScalar, 1, Eigen::Dynamic>>;
128   using OutputConstMatrixMaps = gtl::InlinedVector<OutputConstMatrixMap, 4>;
129   using OutputMatrixMaps = gtl::InlinedVector<OutputMatrixMap, 4>;
130   using OutputRealScalar = typename Eigen::NumTraits<OutputScalar>::Real;
131 
132   // backward compatibility
133   using Scalar = OutputScalar;
134   using Matrix =
135       Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
136   using ConstMatrixMap = Eigen::Map<const Matrix>;
137   using MatrixMap = Eigen::Map<Matrix>;
138   using ConstVectorMap =
139       Eigen::Map<const Eigen::Matrix<Scalar, 1, Eigen::Dynamic>>;
140   using ConstMatrixMaps = gtl::InlinedVector<ConstMatrixMap, 4>;
141   using MatrixMaps = gtl::InlinedVector<MatrixMap, 4>;
142   using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
143 
144   // Performs a single matrix computation given input matrices, and
145   // stores the result in outputs. For batch operations, this will be called
146   // repeatedly for a single call to Compute() when multiple matrices exist in
147   // input Tensors with rank > 2. In this case the calls to ComputeMatrix are
148   // parallelized. The number of threads used is determined by a cost model from
149   // the value returned by GetCostPerUnit().
150   virtual void ComputeMatrix(OpKernelContext* context,
151                              const InputConstMatrixMaps& inputs,
152                              OutputMatrixMaps* outputs) = 0;
153 
154  private:
155   using TensorInputs = gtl::InlinedVector<const Tensor*, 4>;
156   using TensorOutputs = gtl::InlinedVector<Tensor*, 4>;
157   // This function maps 2-d slices (matrices) of the input and output tensors
158   // using Eigen::Map and calls ComputeMatrix implemented in terms of the
159   // Eigen::MatrixBase API by the derived class.
160   //
161   // The 'matrix_index' parameter specifies the index of the matrix to be used
162   // from each input tensor, and the index of the matrix to be written to each
163   // output tensor. The input matrices are in row major order, and located at
164   // the memory addresses
165   //   inputs[i].flat<Scalar>().data() +
166   //   matrix_index * input_matrix_shapes[i].num_elements()
167   // for i in 0...inputs.size()-1.
168   // The output matrices are in row major order, and located at the memory
169   // address
170   //   outputs[i]->flat<Scalar>().data() +
171   //   matrix_index * output_matrix_shapes[i].num_elements().
172   // for i in 0...outputs.size()-1.
173   //
174   void ComputeTensorSlice(OpKernelContext* context, int64_t matrix_index,
175                           const TensorInputs& inputs,
176                           const TensorShapes& input_matrix_shapes,
177                           const TensorOutputs& outputs,
178                           const TensorShapes& output_matrix_shapes);
179 
180   void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs,
181                      TensorShapes* input_matrix_shapes,
182                      TensorShape* batch_shape);
183 
184   void PrepareOutputs(OpKernelContext* context,
185                       const TensorShapes& input_matrix_shapes,
186                       const TensorShape& batch_shape, TensorOutputs* outputs,
187                       TensorShapes* output_matrix_shapes);
188 };
189 
190 // Declare LinearAlgebraOp, which is explicitly instantiated in
191 // linalg_ops_common.cc for half,float, double, complex64, and complex128.
192 extern template class LinearAlgebraOp<Eigen::half>;
193 extern template class LinearAlgebraOp<float>;
194 extern template class LinearAlgebraOp<double>;
195 extern template class LinearAlgebraOp<complex64>;
196 extern template class LinearAlgebraOp<complex128>;
197 
198 }  // namespace tensorflow
199 
// Re-exports the backward-compatible type aliases of LinearAlgebraOp<Scalar>
// (Matrix, MatrixMap(s), ConstMatrixMap(s), ConstVectorMap, TensorShapes) plus
// `Base` and `RealScalar` into the scope where the macro is expanded, so they
// can be used unqualified there.
#define INHERIT_LINALG_TYPEDEFS(Scalar)                       \
  using Base = LinearAlgebraOp<Scalar>;                       \
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real; \
  using Matrix = typename Base::Matrix;                       \
  using MatrixMap = typename Base::MatrixMap;                 \
  using MatrixMaps = typename Base::MatrixMaps;               \
  using ConstMatrixMap = typename Base::ConstMatrixMap;       \
  using ConstMatrixMaps = typename Base::ConstMatrixMaps;     \
  using ConstVectorMap = typename Base::ConstVectorMap;       \
  using TensorShapes = typename Base::TensorShapes;
210 
// Registers OpClass as the kernel for the op named OpName on DEVICE_CPU,
// constrained to the case where the type attribute "T" equals Scalar.
#define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \
  REGISTER_KERNEL_BUILDER(                              \
      Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass)
214 
// Registers OpClass as the kernel for the op named OpName on DEVICE_GPU,
// constrained to the case where the type attribute "T" equals Scalar.
#define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \
  REGISTER_KERNEL_BUILDER(                              \
      Name(OpName).Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), OpClass)
218 
// Deprecated, use one of the device-specific macros above.
// Kept as an alias that forwards to the CPU registration macro.
#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
  REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar)
222 
223 #endif  // TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_
224