1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_ 16 #define TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_ 17 18 // Classes to support linear algebra functionality, similar to the numpy.linalg 19 // module. Supports batch computation on several matrices at once, sharding the 20 // computations across different threads if necessary. 21 #include <algorithm> 22 23 #include "third_party/eigen3/Eigen/Core" 24 #include "tensorflow/core/framework/kernel_def_builder.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 #include "tensorflow/core/framework/tensor_types.h" 29 #include "tensorflow/core/framework/types.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/gtl/inlined_vector.h" 32 #include "tensorflow/core/platform/types.h" 33 #include "tensorflow/core/util/work_sharder.h" 34 35 namespace tensorflow { 36 37 // Base class for linear algebra operators. 38 template <class InputScalar, class OutputScalar = InputScalar> 39 class LinearAlgebraOp : public OpKernel { 40 public: LinearAlgebraOp(OpKernelConstruction * context)41 explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {} 42 43 void Compute(OpKernelContext* context) override; 44 45 protected: 46 using TensorShapes = gtl::InlinedVector<TensorShape, 4>; 47 // Returns the number of leading inputs that are to be treated as matrix 48 // inputs. By default this is all the inputs. Derived classes can override 49 // this to tell the base class to ignore one or more trailing inputs. NumMatrixInputs(const OpKernelContext * context)50 virtual int NumMatrixInputs(const OpKernelContext* context) const { 51 return context->num_inputs(); 52 } 53 54 // Returns true if the number of inputs and their shapes are as expected. 55 // Many ops take a single square input matrix, so we provide that as a default 56 // implementation for convenience. ValidateInputMatrixShapes(OpKernelContext * context,const TensorShapes & input_matrix_shapes)57 virtual void ValidateInputMatrixShapes( 58 OpKernelContext* context, const TensorShapes& input_matrix_shapes) const { 59 ValidateSingleSquareMatrix(context, input_matrix_shapes); 60 } 61 62 // Convenience validators for common cases: 63 // 64 // Validate op taking a single matrix A. 65 static void ValidateSingleMatrix(OpKernelContext* context, 66 const TensorShapes& input_matrix_shapes); 67 // Validate op taking a single square matrix A. 68 static void ValidateSingleSquareMatrix( 69 OpKernelContext* context, const TensorShapes& input_matrix_shapes); 70 // Validate op taking two matrices A and B that have the same number of rows. 71 static void ValidateSolver(OpKernelContext* context, 72 const TensorShapes& input_matrix_shapes); 73 // Validate op taking two matrices A and B that have the same number of rows 74 // and A is square. 75 static void ValidateSquareSolver(OpKernelContext* context, 76 const TensorShapes& input_matrix_shapes); 77 78 // Returns the output shapes of each individual matrix operation. Output 79 // matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0. 80 // 81 // The derived class may return a number of shapes (N) less than 82 // context->num_outputs() (M) to indicate that a only leading subset of 83 // the outputs will be populated. In this case, a dummy scalar tensor with 84 // value zero will be return for the last M-N outputs. 85 // 86 // For many ops, the output dimensions are the same as the input dimensions, 87 // so we provide that as a default implementation for convenience. GetOutputMatrixShapes(const TensorShapes & input_matrix_shapes)88 virtual TensorShapes GetOutputMatrixShapes( 89 const TensorShapes& input_matrix_shapes) const { 90 return input_matrix_shapes; 91 } 92 93 // Returns the cost per matrix operation. This is used to determine the 94 // number of threads to use for parallelizing calls to ComputeMatrix in 95 // batch mode. Cost per unit is assumed to be roughly 1ns, based on comments 96 // in core/util/work_sharder.cc. Many linear algebra ops take roughly max(m,n) 97 // * min(m,n)^2, where the first input matrix is m-by-n. We provide that as a 98 // default implementation for convenience. GetCostPerUnit(const TensorShapes & input_matrix_shapes)99 virtual int64_t GetCostPerUnit( 100 const TensorShapes& input_matrix_shapes) const { 101 double m = static_cast<double>(input_matrix_shapes[0].dim_size(0)); 102 double n = static_cast<double>(input_matrix_shapes[0].dim_size(1)); 103 double cost = std::max(m, n) * std::min(m, n) * std::min(m, n); 104 return cost >= static_cast<double>(kint64max) ? kint64max 105 : static_cast<int64_t>(cost); 106 } 107 108 // Returns true if it is safe to forward (alias) input to output buffer 109 // and expect the kernel to perform the computation inplace. EnableInputForwarding()110 virtual bool EnableInputForwarding() const { return true; } 111 112 using InputMatrix = Eigen::Matrix<InputScalar, Eigen::Dynamic, Eigen::Dynamic, 113 Eigen::RowMajor>; 114 using InputConstMatrixMap = Eigen::Map<const InputMatrix>; 115 using InputMatrixMap = Eigen::Map<InputMatrix>; 116 using InputConstVectorMap = 117 Eigen::Map<const Eigen::Matrix<InputScalar, 1, Eigen::Dynamic>>; 118 using InputConstMatrixMaps = gtl::InlinedVector<InputConstMatrixMap, 4>; 119 using InputMatrixMaps = gtl::InlinedVector<InputMatrixMap, 4>; 120 using InputRealScalar = typename Eigen::NumTraits<InputScalar>::Real; 121 122 using OutputMatrix = Eigen::Matrix<OutputScalar, Eigen::Dynamic, 123 Eigen::Dynamic, Eigen::RowMajor>; 124 using OutputConstMatrixMap = Eigen::Map<const OutputMatrix>; 125 using OutputMatrixMap = Eigen::Map<OutputMatrix>; 126 using OutputConstVectorMap = 127 Eigen::Map<const Eigen::Matrix<OutputScalar, 1, Eigen::Dynamic>>; 128 using OutputConstMatrixMaps = gtl::InlinedVector<OutputConstMatrixMap, 4>; 129 using OutputMatrixMaps = gtl::InlinedVector<OutputMatrixMap, 4>; 130 using OutputRealScalar = typename Eigen::NumTraits<OutputScalar>::Real; 131 132 // backward compatibility 133 using Scalar = OutputScalar; 134 using Matrix = 135 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; 136 using ConstMatrixMap = Eigen::Map<const Matrix>; 137 using MatrixMap = Eigen::Map<Matrix>; 138 using ConstVectorMap = 139 Eigen::Map<const Eigen::Matrix<Scalar, 1, Eigen::Dynamic>>; 140 using ConstMatrixMaps = gtl::InlinedVector<ConstMatrixMap, 4>; 141 using MatrixMaps = gtl::InlinedVector<MatrixMap, 4>; 142 using RealScalar = typename Eigen::NumTraits<Scalar>::Real; 143 144 // Performs a single matrix computation given input matrices, and 145 // stores the result in outputs. For batch operations, this will be called 146 // repeatedly for a single call to Compute() when multiple matrices exist in 147 // input Tensors with rank > 2. In this case the calls to ComputeMatrix are 148 // parallelized. The number of threads used is determined by a cost model from 149 // the value returned by GetCostPerUnit(). 150 virtual void ComputeMatrix(OpKernelContext* context, 151 const InputConstMatrixMaps& inputs, 152 OutputMatrixMaps* outputs) = 0; 153 154 private: 155 using TensorInputs = gtl::InlinedVector<const Tensor*, 4>; 156 using TensorOutputs = gtl::InlinedVector<Tensor*, 4>; 157 // This function maps 2-d slices (matrices) of the input and output tensors 158 // using Eigen::Map and calls ComputeMatrix implemented in terms of the 159 // Eigen::MatrixBase API by the derived class. 160 // 161 // The 'matrix_index' parameter specifies the index of the matrix to be used 162 // from each input tensor, and the index of the matrix to be written to each 163 // output tensor. The input matrices are in row major order, and located at 164 // the memory addresses 165 // inputs[i].flat<Scalar>().data() + 166 // matrix_index * input_matrix_shapes[i].num_elements() 167 // for i in 0...inputs.size()-1. 168 // The output matrices are in row major order, and located at the memory 169 // address 170 // outputs[i]->flat<Scalar>().data() + 171 // matrix_index * output_matrix_shapes[i].num_elements(). 172 // for i in 0...outputs.size()-1. 173 // 174 void ComputeTensorSlice(OpKernelContext* context, int64_t matrix_index, 175 const TensorInputs& inputs, 176 const TensorShapes& input_matrix_shapes, 177 const TensorOutputs& outputs, 178 const TensorShapes& output_matrix_shapes); 179 180 void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs, 181 TensorShapes* input_matrix_shapes, 182 TensorShape* batch_shape); 183 184 void PrepareOutputs(OpKernelContext* context, 185 const TensorShapes& input_matrix_shapes, 186 const TensorShape& batch_shape, TensorOutputs* outputs, 187 TensorShapes* output_matrix_shapes); 188 }; 189 190 // Declare LinearAlgebraOp, which is explicitly instantiated in 191 // linalg_ops_common.cc for half,float, double, complex64, and complex128. 192 extern template class LinearAlgebraOp<Eigen::half>; 193 extern template class LinearAlgebraOp<float>; 194 extern template class LinearAlgebraOp<double>; 195 extern template class LinearAlgebraOp<complex64>; 196 extern template class LinearAlgebraOp<complex128>; 197 198 } // namespace tensorflow 199 200 #define INHERIT_LINALG_TYPEDEFS(Scalar) \ 201 typedef LinearAlgebraOp<Scalar> Base; \ 202 using RealScalar = typename Eigen::NumTraits<Scalar>::Real; \ 203 using Matrix = typename Base::Matrix; \ 204 using MatrixMap = typename Base::MatrixMap; \ 205 using MatrixMaps = typename Base::MatrixMaps; \ 206 using ConstMatrixMap = typename Base::ConstMatrixMap; \ 207 using ConstMatrixMaps = typename Base::ConstMatrixMaps; \ 208 using ConstVectorMap = typename Base::ConstVectorMap; \ 209 using TensorShapes = typename Base::TensorShapes; 210 211 #define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \ 212 REGISTER_KERNEL_BUILDER( \ 213 Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass) 214 215 #define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \ 216 REGISTER_KERNEL_BUILDER( \ 217 Name(OpName).Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), OpClass) 218 219 // Deprecated, use one of the device-specific macros above. 220 #define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \ 221 REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) 222 223 #endif // TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_ 224