/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/bfloat16.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename Tindices>
class SparseTensorDenseMatMulOp : public OpKernel {
 public:
  explicit SparseTensorDenseMatMulOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint_a", &adjoint_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint_b", &adjoint_b_));
  }

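  // Computes out = op(A) * op(B), where A is the sparse operand given in COO
  // form by (a_indices, a_values, a_shape), B is dense, and op() is either
  // the identity or the adjoint, per the adjoint_a / adjoint_b attrs.  This
  // kernel backs, e.g., tf.sparse.sparse_dense_matmul() in the Python API.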
  void Compute(OpKernelContext* ctx) override {
    const Tensor* a_indices;
    const Tensor* a_values;
    const Tensor* a_shape;
    const Tensor* b;
    OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices));
    OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values));
    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape));
    OP_REQUIRES_OK(ctx, ctx->input("b", &b));

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b->shape()),
                errors::InvalidArgument("Tensor 'b' is not a matrix"));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_shape->shape()),
                errors::InvalidArgument("Tensor 'a_shape' is not a vector"));

    OP_REQUIRES(
        ctx, a_shape->NumElements() == 2,
        errors::InvalidArgument("Tensor 'a_shape' must have 2 elements"));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_values->shape()),
                errors::InvalidArgument("Tensor 'a_values' is not a vector"));

    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_indices->shape()),
                errors::InvalidArgument("Tensor 'a_indices' is not a matrix"));

    const int64_t nnz = a_indices->shape().dim_size(0);
    OP_REQUIRES(ctx, nnz == a_values->NumElements(),
                errors::InvalidArgument("Number of rows of a_indices does not "
                                        "match number of entries in a_values"));

    OP_REQUIRES(
        ctx, a_indices->shape().dim_size(1) == a_shape->NumElements(),
        errors::InvalidArgument("Number of columns of a_indices does not match "
                                "number of entries in a_shape"));

    auto a_shape_t = a_shape->vec<int64_t>();
    const int64_t outer_left = (adjoint_a_) ? a_shape_t(1) : a_shape_t(0);
    const int64_t outer_right =
        (adjoint_b_) ? b->shape().dim_size(0) : b->shape().dim_size(1);
    const int64_t inner_left = (adjoint_a_) ? a_shape_t(0) : a_shape_t(1);
    const int64_t inner_right =
        (adjoint_b_) ? b->shape().dim_size(1) : b->shape().dim_size(0);

    OP_REQUIRES(
        ctx, inner_right == inner_left,
        errors::InvalidArgument(
            "Cannot multiply A and B because inner dimension does not match: ",
            inner_left, " vs. ", inner_right,
            ".  Did you forget a transpose?  "
            "Dimensions of A: [",
            a_shape_t(0), ", ", a_shape_t(1),
            "].  Dimensions of B: ", b->shape().DebugString()));

    if (std::is_same<Device, GPUDevice>::value) {
      // The GPU implementation is optimized to use 32-bit indexing, so
      // give a friendly error to the programmer early on if any size
      // exceeds that limit.
      const int int32max = std::numeric_limits<int>::max();
      OP_REQUIRES(
          ctx,
          (FastBoundsCheck(inner_left, int32max) &&
           FastBoundsCheck(inner_right, int32max) &&
           FastBoundsCheck(outer_left, int32max) &&
           FastBoundsCheck(outer_right, int32max) &&
           FastBoundsCheck(b->NumElements(), int32max) &&
           FastBoundsCheck(outer_left * outer_right, int32max) &&
           FastBoundsCheck(a_values->NumElements(), int32max)),
          errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs"));
      OP_REQUIRES(ctx, FastBoundsCheck(nnz * outer_right, int32max),
                  errors::InvalidArgument(
                      "Cannot use GPU when output.shape[1] * nnz(a) > 2^31"));
    }

    TensorShape out_shape({outer_left, outer_right});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a_values->NumElements() == 0 || b->NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

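// MAYBE_ADJOINT dispatches the runtime (adjoint_a_, adjoint_b_) pair to the
// functor specialization with matching compile-time flags; exactly one of
// the four expansions below runs.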
#define MAYBE_ADJOINT(ADJ_A, ADJ_B)                                           \
  if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) {                           \
    Status functor_status = functor::SparseTensorDenseMatMulFunctor<          \
        Device, T, Tindices, ADJ_A,                                           \
        ADJ_B>::Compute(ctx, out->matrix<T>(), a_indices->matrix<Tindices>(), \
                        a_values->vec<T>(), b->matrix<T>());                  \
    OP_REQUIRES_OK(ctx, functor_status);                                      \
  }

    MAYBE_ADJOINT(false, false);
    MAYBE_ADJOINT(false, true);
    MAYBE_ADJOINT(true, false);
    MAYBE_ADJOINT(true, true);

#undef MAYBE_ADJOINT
  }

 private:
  bool adjoint_a_;
  bool adjoint_b_;
};

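// "a_shape" is constrained to host memory because Compute() reads its
// contents directly (a_shape->vec<int64_t>()); the GPU registrations below
// carry the same constraint.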
#define REGISTER_CPU(TypeT, TypeIndex)           \
  REGISTER_KERNEL_BUILDER(                       \
      Name("SparseTensorDenseMatMul")            \
          .Device(DEVICE_CPU)                    \
          .TypeConstraint<TypeT>("T")            \
          .TypeConstraint<TypeIndex>("Tindices") \
          .HostMemory("a_shape"),                \
      SparseTensorDenseMatMulOp<CPUDevice, TypeT, TypeIndex>);

#define REGISTER_KERNELS_CPU(T) \
  REGISTER_CPU(T, int64_t);     \
  REGISTER_CPU(T, int32)

REGISTER_KERNELS_CPU(Eigen::half);
REGISTER_KERNELS_CPU(float);
REGISTER_KERNELS_CPU(double);
REGISTER_KERNELS_CPU(int32);
REGISTER_KERNELS_CPU(complex64);
REGISTER_KERNELS_CPU(complex128);
REGISTER_KERNELS_CPU(bfloat16);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {
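// Declare the GPU functor specializations and mark them extern; their
// definitions are compiled in this op's GPU (.cu.cc) translation unit.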
#define DECLARE_GPU_SPEC(T, Tindices, ADJ_A, ADJ_B)                         \
  template <>                                                               \
  Status SparseTensorDenseMatMulFunctor<                                    \
      GPUDevice, T, Tindices, ADJ_A,                                        \
      ADJ_B>::Compute(OpKernelContext* ctx, typename TTypes<T>::Matrix out, \
                      TTypes<Tindices>::ConstMatrix a_indices,              \
                      typename TTypes<T>::ConstVec a_values,                \
                      typename TTypes<T>::ConstMatrix b);                   \
  extern template struct SparseTensorDenseMatMulFunctor<                    \
      GPUDevice, T, Tindices, ADJ_A, ADJ_B>;

#define REGISTER_GPU_SPEC(T, ADJ_A, ADJ_B)  \
  DECLARE_GPU_SPEC(T, int32, ADJ_A, ADJ_B); \
  DECLARE_GPU_SPEC(T, int64_t, ADJ_A, ADJ_B)

#define DECLARE_ADJOINT_GPU_SPEC(T)  \
  REGISTER_GPU_SPEC(T, false, false) \
  REGISTER_GPU_SPEC(T, false, true)  \
  REGISTER_GPU_SPEC(T, true, false)  \
  REGISTER_GPU_SPEC(T, true, true)

DECLARE_ADJOINT_GPU_SPEC(Eigen::half);
DECLARE_ADJOINT_GPU_SPEC(float);
DECLARE_ADJOINT_GPU_SPEC(double);
DECLARE_ADJOINT_GPU_SPEC(complex64);
DECLARE_ADJOINT_GPU_SPEC(complex128);

#undef DECLARE_ADJOINT_GPU_SPEC
#undef DECLARE_GPU_SPEC
#undef REGISTER_GPU_SPEC

}  // namespace functor

#define REGISTER_GPU(TypeT, TypeIndex)           \
  REGISTER_KERNEL_BUILDER(                       \
      Name("SparseTensorDenseMatMul")            \
          .Device(DEVICE_GPU)                    \
          .TypeConstraint<TypeT>("T")            \
          .TypeConstraint<TypeIndex>("Tindices") \
          .HostMemory("a_shape"),                \
      SparseTensorDenseMatMulOp<GPUDevice, TypeT, TypeIndex>);

#define REGISTER_KERNELS_GPU(T) \
  REGISTER_GPU(T, int64_t);     \
  REGISTER_GPU(T, int32)

REGISTER_KERNELS_GPU(Eigen::half);
REGISTER_KERNELS_GPU(float);
REGISTER_KERNELS_GPU(double);
REGISTER_KERNELS_GPU(complex64);
REGISTER_KERNELS_GPU(complex128);

#undef REGISTER_GPU
#undef REGISTER_KERNELS_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

namespace {
Status KOutOfBoundsError(int64_t k, std::size_t i, int rhs_index_a,
                         std::size_t lhs_right) {
  return errors::InvalidArgument("k (", k, ") from index[", i, ",", rhs_index_a,
                                 "] out of bounds (>=", lhs_right, ")");
}

Status MOutOfBoundsError(int64_t m, std::size_t i, int lhs_index_a,
                         int64_t out_dim0) {
  return errors::InvalidArgument("m (", m, ") from index[", i, ",", lhs_index_a,
                                 "] out of bounds (>=", out_dim0, ")");
}

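// Shared CPU implementation: accumulates out(m, n) += A(i) * B(k, n) over
// the nnz stored elements of A.  Two paths: a scalar loop when the output's
// right-hand dimension is small, and an Eigen chip()-based loop that updates
// whole output rows at once otherwise.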
template <typename T, typename Tsum, typename Tindices, bool ADJ_A, bool ADJ_B>
Status SparseTensorDenseMatMulImpl(
    typename TTypes<Tsum>::Matrix out,
    typename TTypes<Tindices>::ConstMatrix a_indices,
    typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b) {
  // Vectorize certain operations above this size.
  static constexpr std::size_t kNumVectorize = 32;

  const std::size_t nnz = a_values.size();
  const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1));
  const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0));
  const int lhs_index_a = ADJ_A ? 1 : 0;
  const int rhs_index_a = ADJ_A ? 0 : 1;
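  // Column lhs_index_a of a_indices supplies the output-row index m and
  // column rhs_index_a supplies the contraction index k; ADJ_A swaps the
  // two roles.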

  // TODO(ebrevdo): After many failed experiments, can't find a multi-threaded
  // approach that achieves the performance of the single-threaded
  // one.  Perhaps the Eigen threadpool implementation is just too slow?

  if (rhs_right < kNumVectorize) {
    // Disable vectorization if the output's right-hand dimension is too small.
    auto maybe_adjoint_b = MaybeAdjoint<decltype(b), ADJ_B>(b);

    for (std::size_t i = 0; i < nnz; ++i) {
      const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
      const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
      if (!FastBoundsCheck(k, lhs_right)) {
        return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);
      }
      if (!FastBoundsCheck(m, out.dimension(0))) {
        return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0));
      }
      const T a_value = ADJ_A ? MaybeConj(a_values(i)) : a_values(i);
      for (std::size_t n = 0; n < rhs_right; ++n) {
        const T b_value = maybe_adjoint_b(k, n);
        out(m, n) += static_cast<Tsum>(a_value) * static_cast<Tsum>(b_value);
      }
    }
  } else {
    // Vectorization via Eigen.
    const int b_chip_index = ADJ_B ? 1 : 0;

#define LOOP_NNZ(b_passed)                                                  \
  for (std::size_t i = 0; i < nnz; ++i) {                                   \
    const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \
    const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \
    const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i);       \
    if (!FastBoundsCheck(k, lhs_right)) {                                   \
      return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);               \
    }                                                                       \
    if (!FastBoundsCheck(m, out.dimension(0))) {                            \
      return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0));        \
    }                                                                       \
    out.template chip<0>(m) +=                                              \
        b_passed.template chip<b_chip_index>(k).template cast<Tsum>() *     \
        static_cast<Tsum>(a_value);                                         \
  }

    if (ADJ_B) {
      // Perform transpose and conjugation on B once, since we chip out B's
      // columns in the nnz loop.
      Eigen::array<int, 2> shuffle(1, 0);  // preserve dimension order
      Eigen::Tensor<T, 2, Eigen::ColMajor> col_major_conj_b =
          b.swap_layout().shuffle(shuffle).conjugate();
      LOOP_NNZ(col_major_conj_b);
    } else {
      LOOP_NNZ(b);
    }
#undef LOOP_NNZ
  }
  return OkStatus();
}
}  // namespace

template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
  static Status Compute(OpKernelContext* ctx, typename TTypes<T>::Matrix out,
                        typename TTypes<Tindices>::ConstMatrix a_indices,
                        typename TTypes<T>::ConstVec a_values,
                        typename TTypes<T>::ConstMatrix b) {
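    // Accumulate in a wider type when SumType provides one (e.g. float when
    // T is Eigen::half) to limit rounding error, then cast back to T.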
    using Tsum = typename SumType<T>::type;
    Tensor temp_out_t;
    if (!std::is_same<T, Tsum>::value) {
      TF_RETURN_IF_ERROR(ctx->allocate_temp(
          DataTypeToEnum<Tsum>::value,
          TensorShape({out.dimension(0), out.dimension(1)}), &temp_out_t));
      auto temp_out = temp_out_t.matrix<Tsum>();
      temp_out.setZero();
      TF_RETURN_IF_ERROR(
          SparseTensorDenseMatMulImpl<T, Tsum, Tindices, ADJ_A, ADJ_B>(
              temp_out, a_indices, a_values, b));
      out = temp_out.template cast<T>();
    } else {
      out.setZero();
      // This reinterpret_cast is just to avoid a compilation error. The result
      // is only used if Tsum == T.
      auto out_workaround =
          *reinterpret_cast<typename TTypes<Tsum>::Matrix*>(&out);
      TF_RETURN_IF_ERROR(
          SparseTensorDenseMatMulImpl<T, Tsum, Tindices, ADJ_A, ADJ_B>(
              out_workaround, a_indices, a_values, b));
    }
    return OkStatus();
  }
};

}  // namespace functor

}  // namespace tensorflow