1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_RUNTIME_WORK_QUEUE_INTERFACE_H_
16 #define TENSORFLOW_CORE_TFRT_RUNTIME_WORK_QUEUE_INTERFACE_H_
17
18 #include <cstdint>
19
20 #include "tensorflow/core/platform/context.h"
21 #include "tensorflow/core/platform/statusor.h"
22 #include "tensorflow/core/platform/threadpool_interface.h"
23 #include "tensorflow/core/profiler/lib/connected_traceme.h"
24 #include "tensorflow/core/profiler/lib/traceme_encode.h"
25 #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime
26 #include "tfrt/support/error_util.h" // from @tf_runtime
27
28 namespace tensorflow {
29 namespace tfrt_stub {
30
// This is an intermediate interface in TensorFlow for injecting a thread pool
// implementation into TFRT. It lets us add SavedModel/TensorFlow-specific
// methods (e.g. creating an intra-op thread pool) without changing TFRT core.
34 class WorkQueueInterface : public tfrt::ConcurrentWorkQueue {
35 public:
WorkQueueInterface(int64_t id)36 explicit WorkQueueInterface(int64_t id) : id_(id) {}
37 WorkQueueInterface() = default;
38 ~WorkQueueInterface() override = 0;
39
40 // Returns per-request work queue if possible. A nullptr should be returned if
41 // the implementation does not implement the per-request work queue.
42 //
43 // TODO(b/198671794): Remove per-request concepts from the work queue
44 // interface so that the interface is more composable. Per-request logic
45 // should be handled separately.
46 ABSL_DEPRECATED("Create the instance directly instead.")
InitializeRequest(tfrt::RequestContextBuilder * request_context_builder,thread::ThreadPoolInterface ** intra_op_threadpool)47 virtual StatusOr<std::unique_ptr<WorkQueueInterface>> InitializeRequest(
48 tfrt::RequestContextBuilder* request_context_builder,
49 thread::ThreadPoolInterface** intra_op_threadpool) const {
50 *intra_op_threadpool = nullptr;
51 return {nullptr};
52 }
53
id()54 int64_t id() const { return id_; }
55
56 private:
57 int64_t id_ = 0;
58 };
59
60 inline WorkQueueInterface::~WorkQueueInterface() = default;
61
// Creates a WorkQueueInterface from a ConcurrentWorkQueue. The returned
// WorkQueueInterface simply delegates all its public methods to the specified
// ConcurrentWorkQueue.
//
// `work_queue` is taken by std::unique_ptr, so ownership of the underlying
// ConcurrentWorkQueue moves into this call.
std::unique_ptr<WorkQueueInterface> WrapDefaultWorkQueue(
    std::unique_ptr<tfrt::ConcurrentWorkQueue> work_queue);

// Creates a WorkQueueInterface from a ConcurrentWorkQueue. The returned
// WorkQueueInterface simply delegates all its public methods to the specified
// ConcurrentWorkQueue. The `intra_thread_pool` is stored and will be passed out
// when `InitializeRequest()` is called (presumably through its
// `intra_op_threadpool` out-parameter).
//
// NOTE(review): `intra_thread_pool` is a raw pointer, so it is presumably not
// owned by the returned wrapper and must outlive it. Confirm against the
// implementation.
std::unique_ptr<WorkQueueInterface> WrapDefaultWorkQueue(
    std::unique_ptr<tfrt::ConcurrentWorkQueue> work_queue,
    thread::ThreadPoolInterface* intra_thread_pool);
75
76 // A helper function that wraps tasks with traceme events.
77 template <typename Callable>
WrapWork(int64_t id,absl::string_view name,Callable && work)78 tfrt::TaskFunction WrapWork(int64_t id, absl::string_view name,
79 Callable&& work) {
80 tensorflow::Context context(tensorflow::ContextKind::kThread);
81 return tfrt::TaskFunction([id, name = std::string(name),
82 context = std::move(context),
83 work = std::forward<Callable>(work)]() mutable {
84 // From TraceMeProducer in the function that launches graph execution, eg.
85 // SavedModelImpl::Run().
86 tensorflow::profiler::TraceMeConsumer activity(
87 [&]() {
88 return tensorflow::profiler::TraceMeEncode(name, {{"id", id}});
89 },
90 tensorflow::profiler::ContextType::kTfrtExecutor, id,
91 tensorflow::profiler::TraceMeLevel::kInfo);
92 tensorflow::WithContext wc(context);
93 std::forward<Callable>(work)();
94 });
95 }
96
97 } // namespace tfrt_stub
98 } // namespace tensorflow
99
100 #endif // TENSORFLOW_CORE_TFRT_RUNTIME_WORK_QUEUE_INTERFACE_H_
101