/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_
#define TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_

#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tfrt/host_context/resource_context.h"  // from @tf_runtime

namespace tfrt {
class CoreRuntime;
class ConcurrentWorkQueue;
}  // namespace tfrt

namespace tensorflow {
namespace tfrt_stub {

// This defines the runtime abstraction in tensorflow for TFRT. It is supposed
// to provide tensorflow-specific functionalities that are implemented using
// TFRT. Currently, the only intended uses for this class are:
//  1) Creating the runtime instance with user-specified dependencies (e.g. a
//  thread pool).
//  2) Creating tensors that can be used by the runtime.
//
// It is temporary and will be replaced by the official
// tensorflow::experimental::cc::Runtime when it lands.
class Runtime {
 public:
  // Creates a runtime instance with the specified threading configuration.
  // Returns null upon creation error.
  static std::unique_ptr<Runtime> Create(int num_inter_op_threads,
                                         int num_intra_op_threads = 0);
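  //
  // Example use of the factory above (a minimal sketch; the thread count is an
  // arbitrary illustrative value, not a recommendation):
  //
  //   std::unique_ptr<tensorflow::tfrt_stub::Runtime> runtime =
  //       tensorflow::tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4);
  //   if (runtime == nullptr) {
  //     // Handle the creation error.
  //   }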

  // Creates a runtime instance with the specified work_queue. Returns null
  // upon creation error.
  static std::unique_ptr<Runtime> Create(
      std::unique_ptr<WorkQueueInterface> work_queue);
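  //
  // Example use of the overload above (a sketch; `MyWorkQueue` stands in for a
  // client-provided WorkQueueInterface implementation and is not defined in
  // this header):
  //
  //   std::unique_ptr<WorkQueueInterface> work_queue =
  //       std::make_unique<MyWorkQueue>();
  //   auto runtime = Runtime::Create(std::move(work_queue));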

  ~Runtime();

  Runtime(Runtime&&) = default;
  Runtime& operator=(Runtime&&) = default;

  // TODO(tfrt-devs): Add methods for creating TFRT tensors.

  // TODO(chky): Make this method private as it should only be used by
  // tfrt::SavedModel. Simply making tfrt::SavedModel a friend class does not
  // work because it resides in a different namespace. But we should consider
  // moving it to the same namespace.
  tfrt::CoreRuntime* core_runtime() const { return core_runtime_.get(); }
  WorkQueueInterface* work_queue() const { return work_queue_; }

  // `AddCreateRuntimeResourceFn` allows the client to inject per-model
  // resources that are related to system-wide concepts, such as devices, when
  // loading a SavedModel.
  //
  // A longer-term plan is to use a Device concept for this purpose, so that
  // Runtime contains a vector of Devices. Since it will take some time to
  // iterate on the Device concept and integrate with the existing
  // `tfrt::Device` class, we use the callback function as a temporary solution.
  //
  // The argument `fn` should be thread-safe.
  void AddCreateRuntimeResourceFn(
      std::function<void(tfrt::ResourceContext*)> fn) {
    runtime_resource_fns_.emplace_back(std::move(fn));
  }
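  //
  // Example registration (a sketch; `GpuDeviceTable` is a hypothetical
  // per-model resource type, the resource name is arbitrary, and this assumes
  // tfrt::ResourceContext's GetOrCreateResource<T>() helper):
  //
  //   runtime->AddCreateRuntimeResourceFn([](tfrt::ResourceContext* ctx) {
  //     ctx->GetOrCreateResource<GpuDeviceTable>("gpu_device_table");
  //   });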

  // `CreateRuntimeResources` populates `resource_ctx` with runtime-related
  // resources by invoking every callback registered via
  // `AddCreateRuntimeResourceFn`.
  //
  // This function is thread-safe.
  void CreateRuntimeResources(tfrt::ResourceContext* resource_ctx) const {
    for (auto& fn : runtime_resource_fns_) {
      fn(resource_ctx);
    }
  }
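  //
  // Example invocation (a sketch; in practice this is typically called from
  // the SavedModel loading path rather than by end users, and it assumes
  // tfrt::ResourceContext is default-constructible):
  //
  //   tfrt::ResourceContext resource_ctx;
  //   runtime->CreateRuntimeResources(&resource_ctx);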

 private:
  explicit Runtime(std::unique_ptr<tfrt::CoreRuntime> core_runtime,
                   WorkQueueInterface* work_queue);

  std::unique_ptr<tfrt::CoreRuntime> core_runtime_;
  WorkQueueInterface* work_queue_ = nullptr;
  std::vector<std::function<void(tfrt::ResourceContext*)>>
      runtime_resource_fns_;
};

}  // namespace tfrt_stub
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_