xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/sdca_ops_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
17 #include "tensorflow/core/framework/tensor.h"
18 #include "tensorflow/core/framework/tensor_testutil.h"
19 #include "tensorflow/core/graph/node_builder.h"
20 #include "tensorflow/core/lib/random/random.h"
21 #include "tensorflow/core/platform/test.h"
22 #include "tensorflow/core/platform/test_benchmark.h"
23 #include "tensorflow/core/public/session_options.h"
24 
25 namespace tensorflow {
26 
27 namespace {
28 
GetSingleThreadedOptions()29 const SessionOptions* GetSingleThreadedOptions() {
30   static const SessionOptions* const kSessionOptions = []() {
31     SessionOptions* const result = new SessionOptions();
32     result->config.set_intra_op_parallelism_threads(1);
33     result->config.set_inter_op_parallelism_threads(1);
34     result->config.add_session_inter_op_thread_pool()->set_num_threads(1);
35     return result;
36   }();
37   return kSessionOptions;
38 }
39 
GetMultiThreadedOptions()40 const SessionOptions* GetMultiThreadedOptions() {
41   static const SessionOptions* const kSessionOptions = []() {
42     SessionOptions* const result = new SessionOptions();
43     result->config.set_intra_op_parallelism_threads(0);  // Auto-configured.
44     result->config.set_inter_op_parallelism_threads(0);  // Auto-configured.
45     result->config.add_session_inter_op_thread_pool()->set_num_threads(
46         0);  // Auto-configured.
47     return result;
48   }();
49   return kSessionOptions;
50 }
51 
Var(Graph * const g,const int n)52 Node* Var(Graph* const g, const int n) {
53   return test::graph::Var(g, DT_FLOAT, TensorShape({n}));
54 }
55 
56 // Returns a vector of size 'nodes' with each node being of size 'node_size'.
VarVector(Graph * const g,const int nodes,const int node_size)57 std::vector<Node*> VarVector(Graph* const g, const int nodes,
58                              const int node_size) {
59   std::vector<Node*> result;
60   result.reserve(nodes);
61   for (int i = 0; i < nodes; ++i) {
62     result.push_back(Var(g, node_size));
63   }
64   return result;
65 }
66 
Zeros(Graph * const g,const TensorShape & shape)67 Node* Zeros(Graph* const g, const TensorShape& shape) {
68   Tensor data(DT_FLOAT, shape);
69   data.flat<float>().setZero();
70   return test::graph::Constant(g, data);
71 }
72 
Zeros(Graph * const g,const int n)73 Node* Zeros(Graph* const g, const int n) { return Zeros(g, TensorShape({n})); }
74 
Ones(Graph * const g,const int n)75 Node* Ones(Graph* const g, const int n) {
76   Tensor data(DT_FLOAT, TensorShape({n}));
77   test::FillFn<float>(&data, [](const int i) { return 1.0f; });
78   return test::graph::Constant(g, data);
79 }
80 
SparseIndices(Graph * const g,const int sparse_features_per_group)81 Node* SparseIndices(Graph* const g, const int sparse_features_per_group) {
82   Tensor data(DT_INT64, TensorShape({sparse_features_per_group}));
83   test::FillFn<int64_t>(&data, [&](const int i) { return i; });
84   return test::graph::Constant(g, data);
85 }
86 
SparseExampleIndices(Graph * const g,const int sparse_features_per_group,const int num_examples)87 Node* SparseExampleIndices(Graph* const g, const int sparse_features_per_group,
88                            const int num_examples) {
89   const int x_size = num_examples * 4;
90   Tensor data(DT_INT64, TensorShape({x_size}));
91   test::FillFn<int64_t>(&data, [&](const int i) { return i / 4; });
92   return test::graph::Constant(g, data);
93 }
94 
SparseFeatureIndices(Graph * const g,const int sparse_features_per_group,const int num_examples)95 Node* SparseFeatureIndices(Graph* const g, const int sparse_features_per_group,
96                            const int num_examples) {
97   const int x_size = num_examples * 4;
98   Tensor data(DT_INT64, TensorShape({x_size}));
99   test::FillFn<int64_t>(
100       &data, [&](const int i) { return i % sparse_features_per_group; });
101   return test::graph::Constant(g, data);
102 }
103 
RandomZeroOrOne(Graph * const g,const int n)104 Node* RandomZeroOrOne(Graph* const g, const int n) {
105   Tensor data(DT_FLOAT, TensorShape({n}));
106   test::FillFn<float>(&data, [](const int i) {
107     // Fill with 0.0 or 1.0 at random.
108     return (random::New64() % 2) == 0 ? 0.0f : 1.0f;
109   });
110   return test::graph::Constant(g, data);
111 }
112 
RandomZeroOrOneMatrix(Graph * const g,const int n,int d)113 Node* RandomZeroOrOneMatrix(Graph* const g, const int n, int d) {
114   Tensor data(DT_FLOAT, TensorShape({n, d}));
115   test::FillFn<float>(&data, [](const int i) {
116     // Fill with 0.0 or 1.0 at random.
117     return (random::New64() % 2) == 0 ? 0.0f : 1.0f;
118   });
119   return test::graph::Constant(g, data);
120 }
121 
GetGraphs(const int32_t num_examples,const int32_t num_sparse_feature_groups,const int32_t sparse_features_per_group,const int32_t num_dense_feature_groups,const int32_t dense_features_per_group,Graph ** const init_g,Graph ** train_g)122 void GetGraphs(const int32_t num_examples,
123                const int32_t num_sparse_feature_groups,
124                const int32_t sparse_features_per_group,
125                const int32_t num_dense_feature_groups,
126                const int32_t dense_features_per_group, Graph** const init_g,
127                Graph** train_g) {
128   {
129     // Build initialization graph
130     Graph* g = new Graph(OpRegistry::Global());
131 
132     // These nodes have to be created first, and in the same way as the
133     // nodes in the graph below.
134     std::vector<Node*> sparse_weight_nodes =
135         VarVector(g, num_sparse_feature_groups, sparse_features_per_group);
136     std::vector<Node*> dense_weight_nodes =
137         VarVector(g, num_dense_feature_groups, dense_features_per_group);
138     Node* const multi_zero = Zeros(g, sparse_features_per_group);
139     for (Node* n : sparse_weight_nodes) {
140       test::graph::Assign(g, n, multi_zero);
141     }
142     Node* const zero = Zeros(g, dense_features_per_group);
143     for (Node* n : dense_weight_nodes) {
144       test::graph::Assign(g, n, zero);
145     }
146 
147     *init_g = g;
148   }
149 
150   {
151     // Build execution graph
152     Graph* g = new Graph(OpRegistry::Global());
153 
154     // These nodes have to be created first, and in the same way as the
155     // nodes in the graph above.
156     std::vector<Node*> sparse_weight_nodes =
157         VarVector(g, num_sparse_feature_groups, sparse_features_per_group);
158     std::vector<Node*> dense_weight_nodes =
159         VarVector(g, num_dense_feature_groups, dense_features_per_group);
160 
161     std::vector<NodeBuilder::NodeOut> sparse_indices;
162     std::vector<NodeBuilder::NodeOut> sparse_weights;
163     for (Node* n : sparse_weight_nodes) {
164       sparse_indices.push_back(
165           NodeBuilder::NodeOut(SparseIndices(g, sparse_features_per_group)));
166       sparse_weights.push_back(NodeBuilder::NodeOut(n));
167     }
168     std::vector<NodeBuilder::NodeOut> dense_weights;
169     dense_weights.reserve(dense_weight_nodes.size());
170     for (Node* n : dense_weight_nodes) {
171       dense_weights.push_back(NodeBuilder::NodeOut(n));
172     }
173 
174     std::vector<NodeBuilder::NodeOut> sparse_example_indices;
175     std::vector<NodeBuilder::NodeOut> sparse_feature_indices;
176     std::vector<NodeBuilder::NodeOut> sparse_values;
177     sparse_example_indices.reserve(num_sparse_feature_groups);
178     for (int i = 0; i < num_sparse_feature_groups; ++i) {
179       sparse_example_indices.push_back(NodeBuilder::NodeOut(
180           SparseExampleIndices(g, sparse_features_per_group, num_examples)));
181     }
182     sparse_feature_indices.reserve(num_sparse_feature_groups);
183     for (int i = 0; i < num_sparse_feature_groups; ++i) {
184       sparse_feature_indices.push_back(NodeBuilder::NodeOut(
185           SparseFeatureIndices(g, sparse_features_per_group, num_examples)));
186     }
187     sparse_values.reserve(num_sparse_feature_groups);
188     for (int i = 0; i < num_sparse_feature_groups; ++i) {
189       sparse_values.push_back(
190           NodeBuilder::NodeOut(RandomZeroOrOne(g, num_examples * 4)));
191     }
192 
193     std::vector<NodeBuilder::NodeOut> dense_features;
194     dense_features.reserve(num_dense_feature_groups);
195     for (int i = 0; i < num_dense_feature_groups; ++i) {
196       dense_features.push_back(NodeBuilder::NodeOut(
197           RandomZeroOrOneMatrix(g, num_examples, dense_features_per_group)));
198     }
199 
200     Node* const weights = Ones(g, num_examples);
201     Node* const labels = RandomZeroOrOne(g, num_examples);
202     Node* const example_state_data = Zeros(g, TensorShape({num_examples, 4}));
203 
204     Node* sdca = nullptr;
205     TF_CHECK_OK(
206         NodeBuilder(g->NewName("sdca"), "SdcaOptimizer")
207             .Attr("loss_type", "logistic_loss")
208             .Attr("num_sparse_features", num_sparse_feature_groups)
209             .Attr("num_sparse_features_with_values", num_sparse_feature_groups)
210             .Attr("num_dense_features", num_dense_feature_groups)
211             .Attr("l1", 0.0)
212             .Attr("l2", 1.0)
213             .Attr("num_loss_partitions", 1)
214             .Attr("num_inner_iterations", 2)
215             .Input(sparse_example_indices)
216             .Input(sparse_feature_indices)
217             .Input(sparse_values)
218             .Input(dense_features)
219             .Input(weights)
220             .Input(labels)
221             .Input(sparse_indices)
222             .Input(sparse_weights)
223             .Input(dense_weights)
224             .Input(example_state_data)
225             .Finalize(g, &sdca));
226 
227     *train_g = g;
228   }
229 }
230 
BM_SDCA(::testing::benchmark::State & state)231 void BM_SDCA(::testing::benchmark::State& state) {
232   const int num_examples = state.range(0);
233   Graph* init = nullptr;
234   Graph* train = nullptr;
235   GetGraphs(num_examples, 20 /* sparse feature groups */,
236             5 /* sparse features per group */, 1 /* dense feature groups*/,
237             20 /* dense features per group */, &init, &train);
238   test::Benchmark("cpu", train, GetSingleThreadedOptions(), init, nullptr, "",
239                   /*old_benchmark_api*/ false)
240       .Run(state);
241 }
242 
BM_SDCA_LARGE_DENSE(::testing::benchmark::State & state)243 void BM_SDCA_LARGE_DENSE(::testing::benchmark::State& state) {
244   const int num_examples = state.range(0);
245 
246   Graph* init = nullptr;
247   Graph* train = nullptr;
248   GetGraphs(num_examples, 0 /* sparse feature groups */,
249             0 /* sparse features per group */, 5 /* dense feature groups*/,
250             200000 /* dense features per group */, &init, &train);
251   test::Benchmark("cpu", train, GetSingleThreadedOptions(), init, nullptr, "",
252                   /*old_benchmark_api*/ false)
253       .Run(state);
254 }
255 
BM_SDCA_LARGE_SPARSE(::testing::benchmark::State & state)256 void BM_SDCA_LARGE_SPARSE(::testing::benchmark::State& state) {
257   const int num_examples = state.range(0);
258 
259   Graph* init = nullptr;
260   Graph* train = nullptr;
261   GetGraphs(num_examples, 65 /* sparse feature groups */,
262             1e6 /* sparse features per group */, 0 /* dense feature groups*/,
263             0 /* dense features per group */, &init, &train);
264   test::Benchmark("cpu", train, GetMultiThreadedOptions(), init, nullptr, "",
265                   /*old_benchmark_api*/ false)
266       .Run(state);
267 }
268 }  // namespace
269 
// Registers the three SDCA benchmarks. Each one is run with the arguments
// 128, 256, 512 and 1024, which the benchmark functions read through
// state.range(0) as the number of examples.
270 BENCHMARK(BM_SDCA)->Arg(128)->Arg(256)->Arg(512)->Arg(1024);
271 BENCHMARK(BM_SDCA_LARGE_DENSE)->Arg(128)->Arg(256)->Arg(512)->Arg(1024);
272 BENCHMARK(BM_SDCA_LARGE_SPARSE)->Arg(128)->Arg(256)->Arg(512)->Arg(1024);
273 
274 }  // namespace tensorflow
275