1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/batching_util/bounded_executor.h"
17
18 #include "absl/functional/bind_front.h"
19 #include "absl/time/time.h"
20 #include "tensorflow/core/lib/core/notification.h"
21 #include "tensorflow/core/platform/status_matchers.h"
22 #include "tensorflow/core/platform/statusor.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/threadpool_interface.h"
25 #include "tensorflow/core/protobuf/error_codes.pb.h"
26
27 namespace tensorflow {
28 namespace serving {
29
30 namespace {
31 // Tracks the number of concurrently running tasks.
32 class TaskTracker {
33 public:
34 // Creates a functor that invokes Run() with the given arguments.
MakeTask(int task_id,absl::Duration sleep_duration)35 std::function<void()> MakeTask(int task_id, absl::Duration sleep_duration) {
36 return absl::bind_front(&TaskTracker::Run, this, task_id, sleep_duration);
37 }
38
39 // Updates run counts, sleeps for a short time and then returns.
40 // Exits early if fiber is cancelled.
Run(int task_id,absl::Duration sleep_duration)41 void Run(int task_id, absl::Duration sleep_duration) {
42 LOG(INFO) << "Entering task " << task_id;
43 // Update run counters.
44 {
45 mutex_lock l(mutex_);
46 ++task_count_;
47 ++running_count_;
48 if (running_count_ > max_running_count_) {
49 max_running_count_ = running_count_;
50 }
51 }
52
53 // Use a sleep loop so we can quickly detect cancellation even when the
54 // total sleep time is very large.
55
56 Env::Default()->SleepForMicroseconds(
57 absl::ToInt64Microseconds(sleep_duration));
58 // Update run counters.
59 {
60 mutex_lock l(mutex_);
61 --running_count_;
62 }
63 LOG(INFO) << "Task " << task_id << " exiting.";
64 }
65
66 // Returns number of tasks that have been run.
task_count()67 int task_count() {
68 mutex_lock l(mutex_);
69 return task_count_;
70 }
71
72 // Returns number of tasks that are currently running.
running_count()73 int running_count() {
74 mutex_lock l(mutex_);
75 return running_count_;
76 }
77
78 // Returns the max number of tasks that have run concurrently.
max_running_count()79 int max_running_count() {
80 mutex_lock l(mutex_);
81 return max_running_count_;
82 }
83
84 private:
85 mutex mutex_;
86 int task_count_ = 0;
87 int running_count_ = 0;
88 int max_running_count_ = 0;
89 };
90
TEST(BoundedExecutorTest, InvalidEmptyEnv) {
  // A null Env is rejected by Create() even when num_threads is valid.
  BoundedExecutor::Options opts;
  opts.num_threads = 2;
  opts.env = nullptr;

  const auto created = BoundedExecutor::Create(opts);
  EXPECT_THAT(created, ::tensorflow::testing::StatusIs(
                           error::INVALID_ARGUMENT,
                           "options.env must not be nullptr"));
}
99
TEST(BoundedExecutorTest, InvalidNumThreads) {
  // Zero and negative thread counts must both be rejected with the same error;
  // each iteration builds a fresh Options so the cases stay independent.
  for (const int bad_num_threads : {0, -1}) {
    BoundedExecutor::Options opts;
    opts.num_threads = bad_num_threads;
    EXPECT_THAT(
        BoundedExecutor::Create(opts),
        ::tensorflow::testing::StatusIs(
            error::INVALID_ARGUMENT, "options.num_threads must be positive"));
  }
}
119
TEST(BoundedExecutorTest, AddRunsFunctionsEventually) {
  BoundedExecutor::Options opts;
  opts.num_threads = 2;
  TF_ASSERT_OK_AND_ASSIGN(auto executor, BoundedExecutor::Create(opts));

  // Each scheduled closure fires its own notification; waiting on both shows
  // that every scheduled function was eventually run.
  Notification first_done;
  Notification second_done;
  executor->Schedule([&first_done] { first_done.Notify(); });
  executor->Schedule([&second_done] { second_done.Notify(); });
  first_done.WaitForNotification();
  second_done.WaitForNotification();

  // Tear the executor down explicitly while the notifications are still alive.
  executor.reset();
}
134
TEST(BoundedExecutorTest, MaxInflightLimit) {
  BoundedExecutor::Options opts;
  opts.num_threads = 5;
  TF_ASSERT_OK_AND_ASSIGN(auto executor, BoundedExecutor::Create(opts));

  // Schedule far more sleeping tasks than there are threads, so the executor
  // has to queue most of them.
  constexpr int kNumTasks = 100;
  TaskTracker tracker;
  for (int id = 0; id < kNumTasks; ++id) {
    executor->Schedule(tracker.MakeTask(id, absl::Seconds(1)));
  }

  // The checks below assume that destroying the executor blocks until every
  // scheduled task has finished.
  executor.reset();

  EXPECT_EQ(tracker.task_count(), kNumTasks);
  // Concurrency must have reached, and never exceeded, the thread count.
  EXPECT_EQ(tracker.max_running_count(), opts.num_threads);
  EXPECT_EQ(tracker.running_count(), 0);
}
151
152 } // namespace
153 } // namespace serving
154 } // namespace tensorflow
155