xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/lib/matrix.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_
18 
#include <array>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
29 
30 namespace xla {
31 
32 // Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere
33 // else.
34 XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64_t m,
35                      int64_t n);
36 
37 // Returns a mask where the 'diagonal'-th diagonal is true and everything else
38 // is false.
39 XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0);
40 
41 // Get the diagonals of the last two dimensions. Use k>0 for diagonals above the
42 // main diagonal, and k<0 for diagonals below the main diagonal.
43 //
44 // If 'x' has shape [..., M, N]
45 //  If k >= 0: then the output has shape [..., min(M, N - k)], containing the
46 //            diagonal elements (i.e., with indices [..., i, i + k]).
47 //  If k < 0: then the output has shape [..., min(M + k, N)], containing the
48 //            diagonal elements (i.e., with indices [..., i - k, i]).
49 XlaOp GetMatrixDiagonal(XlaOp x, int k = 0);
50 XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k = 0);
51 
52 // Places diag along the kth diagonal of target.
53 XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0);
54 
55 // Returns a lower-triangular mask, i.e., true below and including the
56 // `diagonal`-th diagonal and false above that diagonal.
57 XlaOp TriangleMask(XlaOp x, int diagonal);
58 
59 // Get the upper or lower triangle part of the last two dimensions
60 XlaOp Triangle(XlaOp x, bool lower);
61 
62 // Get the upper triangle part of the last two dimensions
63 XlaOp UpperTriangle(XlaOp x);
64 
65 // Get the lower triangle part of the last two dimensions
66 XlaOp LowerTriangle(XlaOp x);
67 
68 // If x is an array of shape [..., n, n], symmetrizes the matrix by replacing
69 // the upper triangle with the transpose of the lower triangle (if lower is
70 // True, vice-versa otherwise). If the type of `x` is complex, makes the matrix
71 // Hermitian by taking the conjugate of the complex part and setting the
72 // complex diagonal to zero.
73 XlaOp Symmetrize(XlaOp x, bool lower);
74 
75 // Multiplies slices of two tensors in batches.
76 
77 // Multiplies all slices of `Tensor` `x` and `y` (each slice can be
78 // viewed as an element of a batch), and arranges the individual results
79 // in a single output tensor of the same batch size.
80 //
81 // The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
82 // and `[..., r_y, c_y]`.
83 //
84 // The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:
85 //
86 //     r_o = c_x if transpose_x else r_x
87 //     c_o = r_y if transpose_y else c_y
88 //
89 // It is computed as:
90 //
91 //     output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
92 xla::XlaOp BatchDot(
93     xla::XlaOp x, xla::XlaOp y,
94     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT,
95     std::optional<PrimitiveType> preferred_element_type = std::nullopt);
96 xla::XlaOp BatchDot(
97     xla::XlaOp x, bool transpose_x, xla::XlaOp y, bool transpose_y,
98     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT,
99     std::optional<PrimitiveType> preferred_element_type = std::nullopt);
100 
101 // Parse an einsum string into dimension numbers:
102 //   "ab,cb->ac"
103 // becomes:
104 //   {{0, 1},{2, 1},{0, 2}}
105 //
106 // Each occurrence of ellipsis ("...") occurring in the input is replaced with
107 // the same numeric dimensions. The number of such dimensions is inferred from
108 // x_rank and y_rank. For example:
109 //   einsum_config: "...ab,...bcd->...acd"
110 //   x_rank: 4
111 //   y_rank: 5
112 // becomes:
113 //   {{0, 1, 2, 3},{0, 1, 3, 4, 5},{0, 1, 2, 4, 5}}
114 //
115 // NOTE: This function is meant for testing, there is no need to call it
116 // directly.
117 
118 StatusOr<std::array<std::vector<int64_t>, 3>> ParseEinsumString(
119     absl::string_view einsum_config, int64_t x_rank, int64_t y_rank);
120 
121 // If an einsum config does not contain an -> one will be added and the output
122 // config will be the sorted characters with any ellipsis at the beginning.
123 // Returns an empty string if the einsum string already has an ->.
124 std::string NormalizeEinsumString(absl::string_view einsum_config);
125 
126 // Supports two operand einsum notation like "ab,cb->ac".
127 xla::XlaOp Einsum(
128     xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
129     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT,
130     std::optional<PrimitiveType> preferred_element_type = std::nullopt);
131 xla::XlaOp Einsum(
132     xla::XlaOp x, absl::string_view einsum_config,
133     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
134 
135 
136 // Same as above but supporting numeric labels on dimensions. So "ab,cb->ac"
137 // becomes:
138 //   x_config = {0, 1}
139 //   y_config = {2, 1}
140 //   output_config = {0, 2}
141 xla::XlaOp Einsum(
142     xla::XlaOp x, absl::Span<const int64_t> x_config, xla::XlaOp y,
143     absl::Span<const int64_t> y_config, absl::Span<const int64_t> output_config,
144     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT,
145     std::optional<PrimitiveType> preferred_element_type = std::nullopt);
146 
147 // Transposes a stack of matrices `x` by swapping the last two dimensions.
148 xla::XlaOp TransposeInMinorDims(xla::XlaOp x);
149 
150 // Transposes `x` in its minor dimensions if `transpose` is true, otherwise
151 // returns `x` unchanged.
152 xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose);
153 
154 }  // namespace xla
155 
156 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_
157