xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/batching_util/bounded_executor_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/kernels/batching_util/bounded_executor.h"

#include <functional>

#include "absl/functional/bind_front.h"
#include "absl/time/time.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
26 
27 namespace tensorflow {
28 namespace serving {
29 
30 namespace {
31 // Tracks the number of concurrently running tasks.
32 class TaskTracker {
33  public:
34   // Creates a functor that invokes Run() with the given arguments.
MakeTask(int task_id,absl::Duration sleep_duration)35   std::function<void()> MakeTask(int task_id, absl::Duration sleep_duration) {
36     return absl::bind_front(&TaskTracker::Run, this, task_id, sleep_duration);
37   }
38 
39   // Updates run counts, sleeps for a short time and then returns.
40   // Exits early if fiber is cancelled.
Run(int task_id,absl::Duration sleep_duration)41   void Run(int task_id, absl::Duration sleep_duration) {
42     LOG(INFO) << "Entering task " << task_id;
43     // Update run counters.
44     {
45       mutex_lock l(mutex_);
46       ++task_count_;
47       ++running_count_;
48       if (running_count_ > max_running_count_) {
49         max_running_count_ = running_count_;
50       }
51     }
52 
53     // Use a sleep loop so we can quickly detect cancellation even when the
54     // total sleep time is very large.
55 
56     Env::Default()->SleepForMicroseconds(
57         absl::ToInt64Microseconds(sleep_duration));
58     // Update run counters.
59     {
60       mutex_lock l(mutex_);
61       --running_count_;
62     }
63     LOG(INFO) << "Task " << task_id << " exiting.";
64   }
65 
66   // Returns number of tasks that have been run.
task_count()67   int task_count() {
68     mutex_lock l(mutex_);
69     return task_count_;
70   }
71 
72   // Returns number of tasks that are currently running.
running_count()73   int running_count() {
74     mutex_lock l(mutex_);
75     return running_count_;
76   }
77 
78   // Returns the max number of tasks that have run concurrently.
max_running_count()79   int max_running_count() {
80     mutex_lock l(mutex_);
81     return max_running_count_;
82   }
83 
84  private:
85   mutex mutex_;
86   int task_count_ = 0;
87   int running_count_ = 0;
88   int max_running_count_ = 0;
89 };
90 
// Create() must reject options that carry a null Env.
TEST(BoundedExecutorTest, InvalidEmptyEnv) {
  BoundedExecutor::Options opts;
  opts.env = nullptr;
  opts.num_threads = 2;

  const auto result = BoundedExecutor::Create(opts);
  EXPECT_THAT(result,
              ::tensorflow::testing::StatusIs(
                  error::INVALID_ARGUMENT, "options.env must not be nullptr"));
}
98 }
99 
// Create() must reject any non-positive thread count.
TEST(BoundedExecutorTest, InvalidNumThreads) {
  for (const int bad_num_threads : {0, -1}) {
    BoundedExecutor::Options opts;
    opts.num_threads = bad_num_threads;
    EXPECT_THAT(BoundedExecutor::Create(opts),
                ::tensorflow::testing::StatusIs(
                    error::INVALID_ARGUMENT,
                    "options.num_threads must be positive"));
  }
}
118 }
119 
// Every closure handed to Schedule() is eventually executed.
TEST(BoundedExecutorTest, AddRunsFunctionsEventually) {
  BoundedExecutor::Options opts;
  opts.num_threads = 2;
  TF_ASSERT_OK_AND_ASSIGN(auto executor, BoundedExecutor::Create(opts));

  Notification first_done;
  executor->Schedule([&first_done] { first_done.Notify(); });
  Notification second_done;
  executor->Schedule([&second_done] { second_done.Notify(); });

  // Block until both closures have signalled.
  first_done.WaitForNotification();
  second_done.WaitForNotification();

  executor.reset();
}
133 }
134 
// No more than `num_threads` tasks run at once, and that bound is reached.
TEST(BoundedExecutorTest, MaxInflightLimit) {
  // Declared before the executor so it outlives every task the executor runs.
  TaskTracker tracker;
  constexpr int kNumTasks = 100;

  BoundedExecutor::Options opts;
  opts.num_threads = 5;
  TF_ASSERT_OK_AND_ASSIGN(auto executor, BoundedExecutor::Create(opts));

  for (int task_id = 0; task_id < kNumTasks; ++task_id) {
    executor->Schedule(tracker.MakeTask(task_id, absl::Seconds(1)));
  }
  // The checks below assume that destroying the executor waits for all
  // scheduled tasks to finish.
  executor.reset();

  EXPECT_EQ(tracker.task_count(), kNumTasks);
  EXPECT_EQ(tracker.max_running_count(), opts.num_threads);
  EXPECT_EQ(tracker.running_count(), 0);
}
150 }
151 
152 }  // namespace
153 }  // namespace serving
154 }  // namespace tensorflow
155