xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/linalg/banded_triangular_solve_op_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/tensor.h"
19 #include "tensorflow/core/framework/tensor_shape.h"
20 #include "tensorflow/core/framework/types.pb.h"
21 #include "tensorflow/core/graph/graph.h"
22 #include "tensorflow/core/graph/node_builder.h"
23 #include "tensorflow/core/graph/testlib.h"
24 #include "tensorflow/core/kernels/linalg/matrix_set_diag_op.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/platform/test_benchmark.h"
28 
29 namespace tensorflow {
30 namespace {
31 
SetDiag(int num_bands,Graph * g,Node * bands,Node * triangular)32 Node* SetDiag(int num_bands, Graph* g, Node* bands, Node* triangular) {
33   Node* ret;
34   Tensor bandwidth(DT_INT32, TensorShape({2}));
35   bandwidth.flat<int32>()(0) = -(num_bands - 1);
36   bandwidth.flat<int32>()(1) = 0;
37   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatrixSetDiagV3")
38                   .Input(triangular)
39                   .Input(bands)
40                   .Input(test::graph::Constant(g, bandwidth))
41                   .Attr("align", "RIGHT_LEFT")
42                   .Finalize(g, &ret));
43   return ret;
44 }
45 
BandedTriangularSolve(Graph * g,Node * in0,Node * in1)46 Node* BandedTriangularSolve(Graph* g, Node* in0, Node* in1) {
47   Node* ret;
48   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BandedTriangularSolve")
49                   .Input(in0)
50                   .Input(in1)
51                   .Attr("lower", true)
52                   .Attr("adjoint", false)
53                   .Finalize(g, &ret));
54   return ret;
55 }
56 
MatrixTriangularSolve(Graph * g,Node * in0,Node * in1)57 Node* MatrixTriangularSolve(Graph* g, Node* in0, Node* in1) {
58   Node* ret;
59   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatrixTriangularSolve")
60                   .Input(in0)
61                   .Input(in1)
62                   .Attr("lower", true)
63                   .Attr("adjoint", false)
64                   .Finalize(g, &ret));
65   return ret;
66 }
67 
68 template <typename T>
BandedTriangularSolve(int64_t num_bands,int64_t n,int64_t m,bool use_banded_solver,DataType type)69 static Graph* BandedTriangularSolve(int64_t num_bands, int64_t n, int64_t m,
70                                     bool use_banded_solver, DataType type) {
71   Graph* g = new Graph(OpRegistry::Global());
72   Tensor in0(type, TensorShape({num_bands, n}));
73   // Set diagonal to nonzero to guarantee invertibility.
74   in0.flat<T>().setRandom();
75   in0.flat<T>() =
76       in0.flat<T>().abs() + in0.flat<T>().constant(static_cast<T>(0.5));
77   Tensor in1(type, TensorShape({n, m}));
78   in1.flat<T>().setRandom();
79   if (use_banded_solver) {
80     BandedTriangularSolve(g, test::graph::Constant(g, in0),
81                           test::graph::Constant(g, in1));
82   } else {
83     // Create a zero tensor.
84     Tensor in2(type, TensorShape({n, n}));
85     in2.flat<T>().setZero();
86     Node* triangular_matrix =
87         SetDiag(num_bands, g, test::graph::Constant(g, in0),
88                 test::graph::Constant(g, in2));
89     MatrixTriangularSolve(g, triangular_matrix, test::graph::Constant(g, in1));
90   }
91   return g;
92 }
93 
// Macro argument names: ---------------------------------------------------- //
//    K: Number of bands
//    N: Inner dimension of LHS, inner dimension of RHS
//    M: Outer dimension of RHS
//   BS: Boolean indicating whether to use the banded solver
//    T: C++ type of scalars (e.g. float, std::complex)
//   TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128)
//    D: Device name, stringified and passed to test::Benchmark (e.g. cpu)
//
// Defines and registers (with real-time measurement) one benchmark function
// named BM_BandedTriangularSolve_<K>_<N>_<M>_<BS>_<TT>.
//
// Items processed are reported as iterations * (K * N + N * M), i.e. the band
// entries plus the right-hand-side entries for every iteration. The sum is
// parenthesized so that both terms are scaled by the iteration count, and it
// is computed in 64-bit arithmetic.
#define BM_BandedTriangularSolveDev(K, N, M, BS, T, TT, D)              \
  static void BM_BandedTriangularSolve##_##K##_##N##_##M##_##BS##_##TT( \
      ::testing::benchmark::State& state) {                             \
    test::Benchmark(#D, BandedTriangularSolve<T>(K, N, M, BS, TT),      \
                    /*old_benchmark_api*/ false)                        \
        .Run(state);                                                    \
    state.SetItemsProcessed(state.iterations() *                        \
                            (static_cast<int64_t>(K) * (N) +            \
                             static_cast<int64_t>(N) * (M)));           \
  }                                                                     \
  BENCHMARK(BM_BandedTriangularSolve##_##K##_##N##_##M##_##BS##_##TT)   \
      ->UseRealTime();
111 
// Registers the (K, N, M, BS, D) benchmark configuration once per scalar
// type: float (DT_FLOAT) first, then double (DT_DOUBLE).
// BM_BandedTriangularSolveDev already ends its expansion with ';', so the two
// expansions are simply placed back to back.
#define BM_BandedTriangularSolve(K, N, M, BS, D)               \
  BM_BandedTriangularSolveDev(K, N, M, BS, float, DT_FLOAT, D) \
  BM_BandedTriangularSolveDev(K, N, M, BS, double, DT_DOUBLE, D)
115 
116 // Small number of bands, few rhs
117 BM_BandedTriangularSolve(2, 32, 1, true, cpu);
118 BM_BandedTriangularSolve(2, 32, 1, false, cpu);
119 BM_BandedTriangularSolve(4, 32, 1, true, cpu);
120 BM_BandedTriangularSolve(4, 32, 1, false, cpu);
121 BM_BandedTriangularSolve(8, 32, 1, true, cpu);
122 BM_BandedTriangularSolve(8, 32, 1, false, cpu);
123 BM_BandedTriangularSolve(16, 32, 1, true, cpu);
124 BM_BandedTriangularSolve(16, 32, 1, false, cpu);
125 BM_BandedTriangularSolve(2, 128, 1, true, cpu);
126 BM_BandedTriangularSolve(2, 128, 1, false, cpu);
127 BM_BandedTriangularSolve(4, 128, 1, true, cpu);
128 BM_BandedTriangularSolve(4, 128, 1, false, cpu);
129 BM_BandedTriangularSolve(8, 128, 1, true, cpu);
130 BM_BandedTriangularSolve(8, 128, 1, false, cpu);
131 BM_BandedTriangularSolve(16, 128, 1, true, cpu);
132 BM_BandedTriangularSolve(16, 128, 1, false, cpu);
133 BM_BandedTriangularSolve(2, 512, 1, true, cpu);
134 BM_BandedTriangularSolve(2, 512, 1, false, cpu);
135 BM_BandedTriangularSolve(4, 512, 1, true, cpu);
136 BM_BandedTriangularSolve(4, 512, 1, false, cpu);
137 BM_BandedTriangularSolve(8, 512, 1, true, cpu);
138 BM_BandedTriangularSolve(8, 512, 1, false, cpu);
139 BM_BandedTriangularSolve(16, 512, 1, true, cpu);
140 BM_BandedTriangularSolve(16, 512, 1, false, cpu);
141 
142 // Larger # rhs
143 BM_BandedTriangularSolve(2, 32, 32, true, cpu);
144 BM_BandedTriangularSolve(2, 32, 32, false, cpu);
145 BM_BandedTriangularSolve(4, 32, 32, true, cpu);
146 BM_BandedTriangularSolve(4, 32, 32, false, cpu);
147 BM_BandedTriangularSolve(8, 32, 32, true, cpu);
148 BM_BandedTriangularSolve(8, 32, 32, false, cpu);
149 BM_BandedTriangularSolve(16, 32, 32, true, cpu);
150 BM_BandedTriangularSolve(16, 32, 32, false, cpu);
151 BM_BandedTriangularSolve(2, 128, 128, true, cpu);
152 BM_BandedTriangularSolve(2, 128, 128, false, cpu);
153 BM_BandedTriangularSolve(4, 128, 128, true, cpu);
154 BM_BandedTriangularSolve(4, 128, 128, false, cpu);
155 BM_BandedTriangularSolve(8, 128, 128, true, cpu);
156 BM_BandedTriangularSolve(8, 128, 128, false, cpu);
157 BM_BandedTriangularSolve(16, 128, 128, true, cpu);
158 BM_BandedTriangularSolve(16, 128, 128, false, cpu);
159 BM_BandedTriangularSolve(2, 512, 512, true, cpu);
160 BM_BandedTriangularSolve(2, 512, 512, false, cpu);
161 BM_BandedTriangularSolve(4, 512, 512, true, cpu);
162 BM_BandedTriangularSolve(4, 512, 512, false, cpu);
163 BM_BandedTriangularSolve(8, 512, 512, true, cpu);
164 BM_BandedTriangularSolve(8, 512, 512, false, cpu);
165 BM_BandedTriangularSolve(16, 512, 512, true, cpu);
166 BM_BandedTriangularSolve(16, 512, 512, false, cpu);
167 
168 BM_BandedTriangularSolve(2, 2048, 2048, true, cpu);
169 BM_BandedTriangularSolve(2, 2048, 2048, false, cpu);
170 BM_BandedTriangularSolve(4, 2048, 2048, true, cpu);
171 BM_BandedTriangularSolve(4, 2048, 2048, false, cpu);
172 BM_BandedTriangularSolve(8, 2048, 2048, true, cpu);
173 BM_BandedTriangularSolve(8, 2048, 2048, false, cpu);
174 BM_BandedTriangularSolve(16, 2048, 2048, true, cpu);
175 BM_BandedTriangularSolve(16, 2048, 2048, false, cpu);
176 BM_BandedTriangularSolve(32, 2048, 2048, true, cpu);
177 BM_BandedTriangularSolve(32, 2048, 2048, false, cpu);
178 BM_BandedTriangularSolve(64, 2048, 2048, true, cpu);
179 BM_BandedTriangularSolve(64, 2048, 2048, false, cpu);
180 
181 }  // namespace
182 }  // namespace tensorflow
183