1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "tensorflow/core/common_runtime/executor.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/graph/testlib.h" 25 #include "tensorflow/core/lib/core/threadpool.h" 26 #include "tensorflow/core/platform/macros.h" 27 #include "tensorflow/core/platform/test_benchmark.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace tensorflow { 31 32 class Device; 33 class FunctionLibraryRuntime; 34 class ProcessFunctionLibraryRuntime; 35 struct SessionOptions; 36 class StaticDeviceMgr; 37 38 namespace test { 39 40 class Benchmark { 41 public: 42 // "device" must be either "cpu" or "gpu". Takes ownership of "g", 43 // "init", and one reference on "rendez" (if not null). 44 // 45 // old_benchmark_api: If true, the benchmark is running with older API 46 // * In the old API, the timer needs to be stopped/restarted 47 // by users. 48 // * In the new API, the timer starts automatically at the first 49 // iteration of the loop and stops after the last iteration. 50 // TODO(vyng) Remove this once we have migrated all code to newer API. 51 Benchmark(const string& device, Graph* g, 52 const SessionOptions* options = nullptr, Graph* init = nullptr, 53 Rendezvous* rendez = nullptr, const char* executor_type = "", 54 bool old_benchmark_api = false); 55 56 Benchmark(const string& device, Graph* g, bool old_benchmark_api); 57 58 ~Benchmark(); 59 60 void Run(benchmark::State& state); 61 62 void RunWithRendezvousArgs( 63 const std::vector<std::pair<string, Tensor>>& inputs, 64 const std::vector<string>& outputs, benchmark::State& state); 65 66 private: 67 thread::ThreadPool* pool_ = nullptr; // Not owned. 68 Device* device_ = nullptr; // Not owned. 69 Rendezvous* rendez_ = nullptr; 70 std::unique_ptr<StaticDeviceMgr> device_mgr_; 71 std::unique_ptr<FunctionLibraryDefinition> flib_def_; 72 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; 73 FunctionLibraryRuntime* flr_; // Not owned. 74 std::unique_ptr<Executor> exec_; 75 76 TF_DISALLOW_COPY_AND_ASSIGN(Benchmark); 77 }; 78 79 // Returns the rendezvous key associated with the given Send/Recv node. 80 string GetRendezvousKey(const Node* node); 81 82 } // end namespace test 83 } // end namespace tensorflow 84 85 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_ 86