xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/linalg/linalg_ops_common.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
17 
18 #include <initializer_list>
19 #include <utility>
20 
21 #include "third_party/eigen3/Eigen/Core"
22 #include "tensorflow/core/framework/device_base.h"
23 #include "tensorflow/core/framework/kernel_def_builder.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/tensor_shape.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/types.h"
31 
32 namespace tensorflow {
33 
34 // static
35 template <class InputScalar, class OutputScalar>
ValidateSingleMatrix(OpKernelContext * context,const TensorShapes & input_matrix_shapes)36 void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleMatrix(
37     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
38   OP_REQUIRES(context, input_matrix_shapes.size() == 1,
39               errors::InvalidArgument("Expected a single input matrix, got %d.",
40                                       input_matrix_shapes.size()));
41   OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]),
42               errors::InvalidArgument("Input must be a matrix."));
43 }
44 
45 // static
46 template <class InputScalar, class OutputScalar>
ValidateSingleSquareMatrix(OpKernelContext * context,const TensorShapes & input_matrix_shapes)47 void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleSquareMatrix(
48     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
49   OP_REQUIRES(context, input_matrix_shapes.size() == 1,
50               errors::InvalidArgument("Expected a single input matrix, got %d.",
51                                       input_matrix_shapes.size()));
52   OP_REQUIRES(context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]),
53               errors::InvalidArgument("Input matrix must be square."));
54 }
55 
56 // static
57 template <class InputScalar, class OutputScalar>
ValidateSolver(OpKernelContext * context,const TensorShapes & input_matrix_shapes)58 void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSolver(
59     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
60   OP_REQUIRES(context, input_matrix_shapes.size() == 2,
61               errors::InvalidArgument("Expected two input matrices, got %d.",
62                                       input_matrix_shapes.size()));
63   OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]),
64               errors::InvalidArgument("First input (lhs) must be a matrix."));
65   OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[1]),
66               errors::InvalidArgument("Second input (rhs) must be a matrix."));
67   OP_REQUIRES(
68       context,
69       input_matrix_shapes[0].dim_size(0) == input_matrix_shapes[1].dim_size(0),
70       errors::InvalidArgument("Input matrix and rhs are incompatible."));
71 }
72 
73 // static
74 template <class InputScalar, class OutputScalar>
ValidateSquareSolver(OpKernelContext * context,const TensorShapes & input_matrix_shapes)75 void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSquareSolver(
76     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
77   OP_REQUIRES(context, input_matrix_shapes.size() == 2,
78               errors::InvalidArgument("Expected two input matrices, got %d.",
79                                       input_matrix_shapes.size()));
80   OP_REQUIRES(
81       context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]),
82       errors::InvalidArgument("First input (lhs) must be a square matrix."));
83   OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[1]),
84               errors::InvalidArgument("Second input (rhs) must be a matrix."));
85   OP_REQUIRES(
86       context,
87       input_matrix_shapes[0].dim_size(0) == input_matrix_shapes[1].dim_size(0),
88       errors::InvalidArgument("Input matrix and rhs are incompatible."));
89 }
90 
91 template <class InputScalar, class OutputScalar>
Compute(OpKernelContext * context)92 void LinearAlgebraOp<InputScalar, OutputScalar>::Compute(
93     OpKernelContext* context) {
94   TensorInputs inputs;
95   TensorShapes input_matrix_shapes;
96   TensorShape batch_shape;
97   AnalyzeInputs(context, &inputs, &input_matrix_shapes, &batch_shape);
98   if (!context->status().ok()) return;
99 
100   TensorShapes output_matrix_shapes;
101   TensorOutputs outputs;
102   PrepareOutputs(context, input_matrix_shapes, batch_shape, &outputs,
103                  &output_matrix_shapes);
104   if (!context->status().ok()) return;
105 
106   // Process the individual matrix problems in parallel using a threadpool.
107   auto shard = [this, &inputs, &input_matrix_shapes, &outputs,
108                 &output_matrix_shapes, context](int64_t begin, int64_t end) {
109     for (int64_t i = begin; i < end; ++i) {
110       ComputeTensorSlice(context, i, inputs, input_matrix_shapes, outputs,
111                          output_matrix_shapes);
112     }
113   };
114   auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
115   Shard(worker_threads.num_threads, worker_threads.workers,
116         batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
117 }
118 
119 template <class InputScalar, class OutputScalar>
AnalyzeInputs(OpKernelContext * context,TensorInputs * inputs,TensorShapes * input_matrix_shapes,TensorShape * batch_shape)120 void LinearAlgebraOp<InputScalar, OutputScalar>::AnalyzeInputs(
121     OpKernelContext* context, TensorInputs* inputs,
122     TensorShapes* input_matrix_shapes, TensorShape* batch_shape) {
123   int input_rank = -1;
124   for (int i = 0; i < NumMatrixInputs(context); ++i) {
125     const Tensor& in = context->input(i);
126     if (i == 0) {
127       input_rank = in.dims();
128       OP_REQUIRES(
129           context, input_rank >= 2,
130           errors::InvalidArgument("Input tensor ", i,
131                                   " must have rank >= 2, got ", input_rank));
132       // If the tensor rank is greater than 2, we consider the inner-most
133       // dimensions as matrices, and loop over all the other outer ("batch")
134       // dimensions to compute the results.
135       for (int dim = 0; dim < input_rank - 2; ++dim) {
136         batch_shape->AddDim(in.dim_size(dim));
137       }
138     } else {
139       // Make sure that all inputs have the same rank and outer dimensions.
140       OP_REQUIRES(context, input_rank == in.dims(),
141                   errors::InvalidArgument(
142                       "All input tensors must have the same rank."));
143       for (int dim = 0; dim < input_rank - 2; ++dim) {
144         OP_REQUIRES(
145             context, in.dim_size(dim) == batch_shape->dim_size(dim),
146             errors::InvalidArgument(
147                 "All input tensors must have the same outer dimensions."));
148       }
149     }
150 
151     const int row_dimension = input_rank - 2;
152     const int col_dimension = input_rank - 1;
153     const int64_t num_rows = in.dim_size(row_dimension);
154     const int64_t num_cols = in.dim_size(col_dimension);
155     input_matrix_shapes->emplace_back(
156         std::initializer_list<int64_t>({num_rows, num_cols}));
157     inputs->emplace_back(&in);
158     OP_REQUIRES(
159         context, in.dtype() == DataTypeToEnum<InputScalar>::v(),
160         errors::InvalidArgument("Invalid input dtype ", in.dtype(), " vs ",
161                                 DataTypeToEnum<InputScalar>::v()));
162   }
163   // Have the derived class validate that the inputs are as expected.
164   ValidateInputMatrixShapes(context, *input_matrix_shapes);
165 }
166 
167 template <class InputScalar, class OutputScalar>
PrepareOutputs(OpKernelContext * context,const TensorShapes & input_matrix_shapes,const TensorShape & batch_shape,TensorOutputs * outputs,TensorShapes * output_matrix_shapes)168 void LinearAlgebraOp<InputScalar, OutputScalar>::PrepareOutputs(
169     OpKernelContext* context, const TensorShapes& input_matrix_shapes,
170     const TensorShape& batch_shape, TensorOutputs* outputs,
171     TensorShapes* output_matrix_shapes) {
172   // Get shape for each of the matrix outputs produced by the derived class.
173   *output_matrix_shapes = GetOutputMatrixShapes(input_matrix_shapes);
174   const int num_outputs = output_matrix_shapes->size();
175 
176   // Make sure the number of op outputs is what the derived class expects.
177   OP_REQUIRES(
178       context, num_outputs <= context->num_outputs(),
179       errors::Internal(
180           "Derived class expected more outputs (%d) that the op has (%d).",
181           num_outputs, context->num_outputs()));
182 
183   // Allocate outputs.
184   std::set<int> unused_inputs;
185   for (int input_idx = 0; input_idx < context->num_inputs(); ++input_idx) {
186     unused_inputs.insert(input_idx);
187   }
188   for (int output_idx = 0; output_idx < context->num_outputs(); ++output_idx) {
189     TensorShape output_tensor_shape({});
190     if (output_idx < num_outputs) {
191       // This output is used, set up output shape and allocate it.
192       const TensorShape& output_matrix_shape =
193           output_matrix_shapes->at(output_idx);
194       OP_REQUIRES(context, output_matrix_shape.dims() <= 2,
195                   errors::InvalidArgument(
196                       "Rank of matrix output no. %d must be 0, 1 or 2, got %d.",
197                       output_idx, output_matrix_shape.dims()));
198 
199       // The final output has the shape of the outer batch dimensions
200       // concatenated with the output_matrix_shape (if the output is not
201       // scalar).
202       output_tensor_shape = batch_shape;
203       output_tensor_shape.AppendShape(output_matrix_shape);
204     }
205     Tensor* out = nullptr;
206     // See if there is an input buffer we can reuse for this output.
207     bool reused_input = false;
208     if (EnableInputForwarding()) {
209       for (int input_idx : unused_inputs) {
210         if (context->forward_input_to_output_with_shape(
211                 input_idx, output_idx, output_tensor_shape, &out)) {
212           reused_input = true;
213           unused_inputs.erase(input_idx);
214           break;
215         }
216       }
217     }
218     if (!reused_input) {
219       OP_REQUIRES_OK(context, context->allocate_output(
220                                   output_idx, output_tensor_shape, &out));
221     }
222     OP_REQUIRES(
223         context, out->dtype() == DataTypeToEnum<OutputScalar>::v(),
224         errors::InvalidArgument("Invalid output dtype ", out->dtype(), " vs ",
225                                 DataTypeToEnum<OutputScalar>::v()));
226 
227     outputs->emplace_back(out);
228   }
229 }
230 
231 template <class InputScalar, class OutputScalar>
ComputeTensorSlice(OpKernelContext * context,int64_t matrix_index,const TensorInputs & inputs,const TensorShapes & input_matrix_shapes,const TensorOutputs & outputs,const TensorShapes & output_matrix_shapes)232 void LinearAlgebraOp<InputScalar, OutputScalar>::ComputeTensorSlice(
233     OpKernelContext* context, int64_t matrix_index, const TensorInputs& inputs,
234     const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
235     const TensorShapes& output_matrix_shapes) {
236   InputConstMatrixMaps matrix_inputs;
237   for (size_t i = 0; i < inputs.size(); ++i) {
238     // TODO(kalakris): Handle alignment if possible. Eigen::Map is
239     // unaligned by default.
240     matrix_inputs.emplace_back(
241         inputs[i]->flat<InputScalar>().data() +
242             matrix_index * input_matrix_shapes[i].num_elements(),
243         input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
244   }
245 
246   OutputMatrixMaps matrix_outputs;
247   for (size_t i = 0; i < output_matrix_shapes.size(); ++i) {
248     // The output matrix shape may not be a matrix.
249     int num_output_rows = output_matrix_shapes[i].dims() >= 1
250                               ? output_matrix_shapes[i].dim_size(0)
251                               : 1;
252     int num_output_cols = output_matrix_shapes[i].dims() == 2
253                               ? output_matrix_shapes[i].dim_size(1)
254                               : 1;
255     matrix_outputs.emplace_back(
256         outputs[i]->flat<OutputScalar>().data() +
257             matrix_index * output_matrix_shapes[i].num_elements(),
258         num_output_rows, num_output_cols);
259   }
260   ComputeMatrix(context, matrix_inputs, &matrix_outputs);
261 }
262 
263 // Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use.
264 template class LinearAlgebraOp<Eigen::half>;
265 template class LinearAlgebraOp<float>;
266 template class LinearAlgebraOp<double>;
267 template class LinearAlgebraOp<complex64>;
268 template class LinearAlgebraOp<complex128>;
269 template class LinearAlgebraOp<float, complex64>;
270 template class LinearAlgebraOp<double, complex128>;
271 
272 }  // namespace tensorflow
273