xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/kernel_benchmark_testlib.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
18 
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
29 
30 namespace tensorflow {
31 
32 class Device;
33 class FunctionLibraryRuntime;
34 class ProcessFunctionLibraryRuntime;
35 struct SessionOptions;
36 class StaticDeviceMgr;
37 
38 namespace test {
39 
40 class Benchmark {
41  public:
42   // "device" must be either "cpu" or "gpu".  Takes ownership of "g",
43   // "init", and one reference on "rendez" (if not null).
44   //
45   // old_benchmark_api: If true, the benchmark is running with older API
46   //   * In the old API, the timer needs to be stopped/restarted
47   //     by users.
48   //   * In the new API, the timer starts automatically at the first
49   //     iteration of the loop and stops after the last iteration.
50   // TODO(vyng) Remove this once we have migrated all code to newer API.
51   Benchmark(const string& device, Graph* g,
52             const SessionOptions* options = nullptr, Graph* init = nullptr,
53             Rendezvous* rendez = nullptr, const char* executor_type = "",
54             bool old_benchmark_api = false);
55 
56   Benchmark(const string& device, Graph* g, bool old_benchmark_api);
57 
58   ~Benchmark();
59 
60   void Run(benchmark::State& state);
61 
62   void RunWithRendezvousArgs(
63       const std::vector<std::pair<string, Tensor>>& inputs,
64       const std::vector<string>& outputs, benchmark::State& state);
65 
66  private:
67   thread::ThreadPool* pool_ = nullptr;  // Not owned.
68   Device* device_ = nullptr;            // Not owned.
69   Rendezvous* rendez_ = nullptr;
70   std::unique_ptr<StaticDeviceMgr> device_mgr_;
71   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
72   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
73   FunctionLibraryRuntime* flr_;  // Not owned.
74   std::unique_ptr<Executor> exec_;
75 
76   TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
77 };
78 
79 // Returns the rendezvous key associated with the given Send/Recv node.
80 string GetRendezvousKey(const Node* node);
81 
82 }  // end namespace test
83 }  // end namespace tensorflow
84 
85 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
86