1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_ 16 #define TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_ 17 18 #include <memory> 19 20 #include "tensorflow/core/tfrt/runtime/work_queue_interface.h" 21 #include "tfrt/host_context/resource_context.h" // from @tf_runtime 22 23 namespace tfrt { 24 class CoreRuntime; 25 class ConcurrentWorkQueue; 26 } // namespace tfrt 27 28 namespace tensorflow { 29 namespace tfrt_stub { 30 31 // This defines the runtime abstraction in tensorflow for TFRT. It is supposed 32 // to provide tensorflow specific functionalities that are implemented using 33 // TFRT. Currently, the only intended uses for this class are: 34 // 1) Creating the runtime instance with user specified dependencies (eg. 35 // thread pool). 36 // 2) Creating tensors that can be used by the runtime. 37 // 38 // It is temporary and will be replaced by the official 39 // tensorflow::experimental::cc::Runtime when it lands. 40 class Runtime { 41 public: 42 // Creates a runtime instance with specified threading configuration. Returns 43 // null upon creation error. 44 static std::unique_ptr<Runtime> Create(int num_inter_op_threads, 45 int num_intra_op_threads = 0); 46 47 // Creates a runtime instance with the specified work_queue. Returns null upon 48 // creation error. 49 static std::unique_ptr<Runtime> Create( 50 std::unique_ptr<WorkQueueInterface> work_queue); 51 52 ~Runtime(); 53 54 Runtime(Runtime&&) = default; 55 Runtime& operator=(Runtime&&) = default; 56 57 // TODO(tfrt-devs): Add methods for creating TFRT tensors. 58 59 // TODO(chky): Make this method private as it should be only used by 60 // tfrt::SavedModel. Simply making tfrt::SavedModel a friend class does not 61 // work because the it resides in a different namespace. But we should 62 // consider moving it to the same namespace. core_runtime()63 tfrt::CoreRuntime* core_runtime() const { return core_runtime_.get(); } work_queue()64 WorkQueueInterface* work_queue() const { return work_queue_; } 65 66 // `AddCreateRuntimeResourceFn` allows the client to inject per model 67 // resources that are related to system-wide concepts, such as devices, when 68 // loading a SavedModel. 69 // 70 // A longer term plan is to use a Device concept for this purpose, so that 71 // Runtime contains a vector of Devices. Since it will take some time to 72 // iterate on the Device concept and integrate with the existing 73 // `tfrt::Device` class, we use the callback function as a temporary solution. 74 // 75 // The argument `fn` should be thread-safe. AddCreateRuntimeResourceFn(std::function<void (tfrt::ResourceContext *)> fn)76 void AddCreateRuntimeResourceFn( 77 std::function<void(tfrt::ResourceContext*)> fn) { 78 runtime_resource_fns_.emplace_back(std::move(fn)); 79 } 80 81 // `CreateRuntimeResources` populates `resource_ctx` with runtime-related 82 // resources. 83 // 84 // This function is thread-safe. CreateRuntimeResources(tfrt::ResourceContext * resource_ctx)85 void CreateRuntimeResources(tfrt::ResourceContext* resource_ctx) const { 86 for (auto& fn : runtime_resource_fns_) { 87 fn(resource_ctx); 88 } 89 } 90 91 private: 92 explicit Runtime(std::unique_ptr<tfrt::CoreRuntime> core_runtime, 93 WorkQueueInterface* work_queue); 94 95 std::unique_ptr<tfrt::CoreRuntime> core_runtime_; 96 WorkQueueInterface* work_queue_ = nullptr; 97 std::vector<std::function<void(tfrt::ResourceContext*)>> 98 runtime_resource_fns_; 99 }; 100 101 } // namespace tfrt_stub 102 } // namespace tensorflow 103 104 #endif // TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_ 105