/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"

#include <initializer_list>
#include <set>
#include <utility>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// static
template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleMatrix(
    OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
  OP_REQUIRES(context, input_matrix_shapes.size() == 1,
              errors::InvalidArgument("Expected a single input matrix, got ",
                                      input_matrix_shapes.size(), "."));
  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]),
              errors::InvalidArgument("Input must be a matrix."));
}

// static
template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleSquareMatrix(
    OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
  OP_REQUIRES(context, input_matrix_shapes.size() == 1,
              errors::InvalidArgument("Expected a single input matrix, got ",
                                      input_matrix_shapes.size(), "."));
  OP_REQUIRES(context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]),
              errors::InvalidArgument("Input matrix must be square."));
}

// static
template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSolver(
    OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
  OP_REQUIRES(context, input_matrix_shapes.size() == 2,
              errors::InvalidArgument("Expected two input matrices, got ",
                                      input_matrix_shapes.size(), "."));
  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]),
              errors::InvalidArgument("First input (lhs) must be a matrix."));
  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[1]),
              errors::InvalidArgument("Second input (rhs) must be a matrix."));
  OP_REQUIRES(
      context,
      input_matrix_shapes[0].dim_size(0) == input_matrix_shapes[1].dim_size(0),
      errors::InvalidArgument(
          "Input matrix and rhs must have the same number of rows."));
}

// static
template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSquareSolver(
    OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
  OP_REQUIRES(context, input_matrix_shapes.size() == 2,
              errors::InvalidArgument("Expected two input matrices, got ",
                                      input_matrix_shapes.size(), "."));
  OP_REQUIRES(
      context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]),
      errors::InvalidArgument("First input (lhs) must be a square matrix."));
  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[1]),
              errors::InvalidArgument("Second input (rhs) must be a matrix."));
  OP_REQUIRES(
      context,
      input_matrix_shapes[0].dim_size(0) == input_matrix_shapes[1].dim_size(0),
      errors::InvalidArgument(
          "Input matrix and rhs must have the same number of rows."));
}
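
// For example (illustrative shapes, not checked by a test here):
// ValidateSquareSolver accepts an lhs of shape [n, n] together with an rhs of
// shape [n, k]; an rhs whose row count differs from n fails the row-equality
// check above.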

template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::Compute(
    OpKernelContext* context) {
  TensorInputs inputs;
  TensorShapes input_matrix_shapes;
  TensorShape batch_shape;
  AnalyzeInputs(context, &inputs, &input_matrix_shapes, &batch_shape);
  if (!context->status().ok()) return;

  TensorShapes output_matrix_shapes;
  TensorOutputs outputs;
  PrepareOutputs(context, input_matrix_shapes, batch_shape, &outputs,
                 &output_matrix_shapes);
  if (!context->status().ok()) return;

  // Process the individual matrix problems in parallel using a threadpool.
  auto shard = [this, &inputs, &input_matrix_shapes, &outputs,
                &output_matrix_shapes, context](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      ComputeTensorSlice(context, i, inputs, input_matrix_shapes, outputs,
                         output_matrix_shapes);
    }
  };
  auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
  Shard(worker_threads.num_threads, worker_threads.workers,
        batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
}
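
// A minimal sketch (illustrative only, not part of this file) of how a
// derived op plugs into the Compute/shard machinery above: the subclass
// supplies shape validation, output shapes, and the per-matrix kernel, and
// inherits batching and threading from LinearAlgebraOp::Compute.
// "MatrixIdentityOp" is a hypothetical name used purely for illustration;
// see e.g. CholeskyOp for a real derived op.
//
//   template <class Scalar>
//   class MatrixIdentityOp : public LinearAlgebraOp<Scalar> {
//    public:
//     using Base = LinearAlgebraOp<Scalar>;
//     explicit MatrixIdentityOp(OpKernelConstruction* context)
//         : Base(context) {}
//
//     void ValidateInputMatrixShapes(
//         OpKernelContext* context,
//         const typename Base::TensorShapes& shapes) const final {
//       Base::ValidateSingleSquareMatrix(context, shapes);
//     }
//
//     typename Base::TensorShapes GetOutputMatrixShapes(
//         const typename Base::TensorShapes& shapes) const final {
//       return shapes;  // One output with the same shape as the input.
//     }
//
//     void ComputeMatrix(OpKernelContext* context,
//                        const typename Base::InputConstMatrixMaps& inputs,
//                        typename Base::OutputMatrixMaps* outputs) final {
//       outputs->at(0) = inputs[0];  // Copy the matrix through unchanged.
//     }
//   };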

template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::AnalyzeInputs(
    OpKernelContext* context, TensorInputs* inputs,
    TensorShapes* input_matrix_shapes, TensorShape* batch_shape) {
  int input_rank = -1;
  for (int i = 0; i < NumMatrixInputs(context); ++i) {
    const Tensor& in = context->input(i);
    if (i == 0) {
      input_rank = in.dims();
      OP_REQUIRES(
          context, input_rank >= 2,
          errors::InvalidArgument("Input tensor ", i,
                                  " must have rank >= 2, got ", input_rank));
      // If the tensor rank is greater than 2, we consider the inner-most
      // dimensions as matrices, and loop over all the other outer ("batch")
      // dimensions to compute the results.
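      // For example (illustrative shapes), a [2, 3, 4, 5] input is treated
      // as a batch of 2 * 3 = 6 matrices, each of shape [4, 5].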
      for (int dim = 0; dim < input_rank - 2; ++dim) {
        batch_shape->AddDim(in.dim_size(dim));
      }
    } else {
      // Make sure that all inputs have the same rank and outer dimensions.
      OP_REQUIRES(context, input_rank == in.dims(),
                  errors::InvalidArgument(
                      "All input tensors must have the same rank."));
      for (int dim = 0; dim < input_rank - 2; ++dim) {
        OP_REQUIRES(
            context, in.dim_size(dim) == batch_shape->dim_size(dim),
            errors::InvalidArgument(
                "All input tensors must have the same outer dimensions."));
      }
    }

    const int row_dimension = input_rank - 2;
    const int col_dimension = input_rank - 1;
    const int64_t num_rows = in.dim_size(row_dimension);
    const int64_t num_cols = in.dim_size(col_dimension);
    input_matrix_shapes->emplace_back(
        std::initializer_list<int64_t>({num_rows, num_cols}));
    inputs->emplace_back(&in);
    OP_REQUIRES(
        context, in.dtype() == DataTypeToEnum<InputScalar>::v(),
        errors::InvalidArgument("Invalid input dtype ", in.dtype(), " vs ",
                                DataTypeToEnum<InputScalar>::v()));
  }
  // Have the derived class validate that the inputs are as expected.
  ValidateInputMatrixShapes(context, *input_matrix_shapes);
}

template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::PrepareOutputs(
    OpKernelContext* context, const TensorShapes& input_matrix_shapes,
    const TensorShape& batch_shape, TensorOutputs* outputs,
    TensorShapes* output_matrix_shapes) {
  // Get shape for each of the matrix outputs produced by the derived class.
  *output_matrix_shapes = GetOutputMatrixShapes(input_matrix_shapes);
  const int num_outputs = output_matrix_shapes->size();

  // Make sure the number of op outputs is what the derived class expects.
  OP_REQUIRES(
      context, num_outputs <= context->num_outputs(),
      errors::Internal("Derived class expected more outputs (", num_outputs,
                       ") than the op has (", context->num_outputs(), ")."));

  // Allocate outputs.
  std::set<int> unused_inputs;
  for (int input_idx = 0; input_idx < context->num_inputs(); ++input_idx) {
    unused_inputs.insert(input_idx);
  }
  for (int output_idx = 0; output_idx < context->num_outputs(); ++output_idx) {
    TensorShape output_tensor_shape({});
    if (output_idx < num_outputs) {
      // This output is used; set up the output shape and allocate it.
      const TensorShape& output_matrix_shape =
          output_matrix_shapes->at(output_idx);
      OP_REQUIRES(context, output_matrix_shape.dims() <= 2,
                  errors::InvalidArgument("Rank of matrix output no. ",
                                          output_idx,
                                          " must be 0, 1 or 2, got ",
                                          output_matrix_shape.dims(), "."));

      // The final output has the shape of the outer batch dimensions
      // concatenated with the output_matrix_shape (if the output is not
      // scalar).
      output_tensor_shape = batch_shape;
      output_tensor_shape.AppendShape(output_matrix_shape);
    }
    Tensor* out = nullptr;
    // See if there is an input buffer we can reuse for this output.
    bool reused_input = false;
    if (EnableInputForwarding()) {
      for (int input_idx : unused_inputs) {
        if (context->forward_input_to_output_with_shape(
                input_idx, output_idx, output_tensor_shape, &out)) {
          reused_input = true;
          unused_inputs.erase(input_idx);
          break;
        }
      }
    }
    if (!reused_input) {
      OP_REQUIRES_OK(context, context->allocate_output(
                                  output_idx, output_tensor_shape, &out));
    }
    OP_REQUIRES(
        context, out->dtype() == DataTypeToEnum<OutputScalar>::v(),
        errors::InvalidArgument("Invalid output dtype ", out->dtype(), " vs ",
                                DataTypeToEnum<OutputScalar>::v()));

    outputs->emplace_back(out);
  }
}
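
// Note (illustrative): when EnableInputForwarding() returns true, the loop
// above offers each not-yet-reused input to
// forward_input_to_output_with_shape(). When the runtime decides reuse is
// safe and the shape and dtype match, the output is written in place into
// that input's buffer instead of a freshly allocated tensor.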

template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::ComputeTensorSlice(
    OpKernelContext* context, int64_t matrix_index, const TensorInputs& inputs,
    const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
    const TensorShapes& output_matrix_shapes) {
  InputConstMatrixMaps matrix_inputs;
  for (size_t i = 0; i < inputs.size(); ++i) {
    // TODO(kalakris): Handle alignment if possible. Eigen::Map is
    // unaligned by default.
    matrix_inputs.emplace_back(
        inputs[i]->flat<InputScalar>().data() +
            matrix_index * input_matrix_shapes[i].num_elements(),
        input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
  }

  OutputMatrixMaps matrix_outputs;
  for (size_t i = 0; i < output_matrix_shapes.size(); ++i) {
    // The output shape may be rank 0 or 1 rather than a full matrix.
    int num_output_rows = output_matrix_shapes[i].dims() >= 1
                              ? output_matrix_shapes[i].dim_size(0)
                              : 1;
    int num_output_cols = output_matrix_shapes[i].dims() == 2
                              ? output_matrix_shapes[i].dim_size(1)
                              : 1;
    matrix_outputs.emplace_back(
        outputs[i]->flat<OutputScalar>().data() +
            matrix_index * output_matrix_shapes[i].num_elements(),
        num_output_rows, num_output_cols);
  }
  ComputeMatrix(context, matrix_inputs, &matrix_outputs);
}
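
// Illustrative only (hypothetical shapes): the matrices of a batch are stored
// back to back in row-major order, so matrix k with per-matrix shape [m, n]
// begins at flat offset k * m * n, which is exactly the pointer arithmetic
// used above. A standalone Eigen equivalent:
//
//   float data[12];  // 3 matrices of shape [2, 2], stored contiguously.
//   const int64_t m = 2, n = 2, k = 1;
//   Eigen::Map<
//       Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
//       second_matrix(data + k * m * n, m, n);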

// Explicitly instantiate LinearAlgebraOp for the scalar types we expect to
// use.
template class LinearAlgebraOp<Eigen::half>;
template class LinearAlgebraOp<float>;
template class LinearAlgebraOp<double>;
template class LinearAlgebraOp<complex64>;
template class LinearAlgebraOp<complex128>;
template class LinearAlgebraOp<float, complex64>;
template class LinearAlgebraOp<double, complex128>;

}  // namespace tensorflow