1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/extension/threadpool/threadpool.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <mutex>
#include <numeric>
#include <random>
#include <thread>
#include <vector>

#include <executorch/extension/threadpool/threadpool_guard.h>

#include <gtest/gtest.h>
18
19 using namespace ::testing;
20
21 namespace {
22
// Returns ceil(dividend / divisor) for unsigned operands.
// Computed as quotient plus a remainder correction instead of
// (dividend + divisor - 1) / divisor, so the intermediate value cannot wrap
// around when `dividend` is close to SIZE_MAX. `divisor` must be non-zero.
size_t div_round_up(const size_t dividend, const size_t divisor) {
  return dividend / divisor + (dividend % divisor != 0 ? 1 : 0);
}
26
// Resizes `a` to `size` elements and fills every element with a pseudo-random
// value drawn uniformly from [1, 2 * size], with the upper bound clamped to the
// int32_t maximum. The generator is seeded from std::random_device, so the
// contents differ between runs.
void resize_and_fill_vector(std::vector<int32_t>& a, const size_t size) {
  a.resize(size);
  if (size == 0) {
    // Nothing to fill. Returning early also avoids constructing
    // uniform_int_distribution(1, 0), which violates its a <= b precondition.
    return;
  }
  // Compute the bound in size_t and clamp it before narrowing to int32_t.
  const size_t max_value = std::min<size_t>(
      size * 2, static_cast<size_t>(std::numeric_limits<int32_t>::max()));
  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_int_distribution<int32_t> distrib(
      1, static_cast<int32_t>(max_value));
  std::generate(
      a.begin(), a.end(), [&distrib, &gen]() { return distrib(gen); });
}
35
generate_add_test_inputs(std::vector<int32_t> & a,std::vector<int32_t> & b,std::vector<int32_t> & c_ref,std::vector<int32_t> & c,size_t vector_size)36 void generate_add_test_inputs(
37 std::vector<int32_t>& a,
38 std::vector<int32_t>& b,
39 std::vector<int32_t>& c_ref,
40 std::vector<int32_t>& c,
41 size_t vector_size) {
42 resize_and_fill_vector(a, vector_size);
43 resize_and_fill_vector(b, vector_size);
44 resize_and_fill_vector(c, vector_size);
45 resize_and_fill_vector(c_ref, vector_size);
46 for (size_t i = 0, size = a.size(); i < size; ++i) {
47 c_ref[i] = a[i] + b[i];
48 }
49 }
50
generate_reduce_test_inputs(std::vector<int32_t> & a,int32_t & c_ref,size_t vector_size)51 void generate_reduce_test_inputs(
52 std::vector<int32_t>& a,
53 int32_t& c_ref,
54 size_t vector_size) {
55 resize_and_fill_vector(a, vector_size);
56 c_ref = 0;
57 for (size_t i = 0, size = a.size(); i < size; ++i) {
58 c_ref += a[i];
59 }
60 }
61
run_lambda_with_size(std::function<void (size_t)> f,size_t range,size_t grain_size)62 void run_lambda_with_size(
63 std::function<void(size_t)> f,
64 size_t range,
65 size_t grain_size) {
66 size_t num_grains = div_round_up(range, grain_size);
67
68 auto threadpool = ::executorch::extension::threadpool::get_threadpool();
69 threadpool->run(f, range);
70 }
71 } // namespace
72
TEST(ThreadPoolTest, ParallelAdd) {
  std::vector<int32_t> a, b, c, c_ref;
  size_t vector_size = 0;
  size_t grain_size = 0;

  // Task i computes c = a + b over its grain
  // [i * grain_size, min((i + 1) * grain_size, vector_size)).
  // vector_size and grain_size are captured by reference and set per case below.
  auto add_lambda = [&](size_t i) {
    const size_t start_index = i * grain_size;
    const size_t end_index = std::min(start_index + grain_size, vector_size);
    for (size_t j = start_index; j < end_index; ++j) {
      c[j] = a[j] + b[j];
    }
  };

  auto threadpool = ::executorch::extension::threadpool::get_threadpool();
  ASSERT_NE(threadpool, nullptr);
  EXPECT_GT(threadpool->get_thread_count(), 1u);

  struct AddCase {
    size_t vector_size;
    size_t grain_size;
  };
  const AddCase cases[] = {
      {100, 10}, // grain evenly divides the vector
      {100, 5}, // smaller grain, more tasks
      {7, 5}, // final grain is partial
      {7, 10}, // a single grain larger than the whole vector
  };

  for (const AddCase& tc : cases) {
    vector_size = tc.vector_size;
    grain_size = tc.grain_size;
    generate_add_test_inputs(a, b, c_ref, c, vector_size);
    run_lambda_with_size(add_lambda, vector_size, grain_size);
    EXPECT_EQ(c, c_ref) << "vector_size=" << vector_size
                        << " grain_size=" << grain_size;
  }
}
111
// Tests a parallel sum reduction in which each task acquires a mutex inside the
// lambda to update the shared total.
TEST(ThreadPoolTest, ParallelReduce) {
  std::vector<int32_t> a;
  int32_t c = 0;
  int32_t c_ref = 0;
  size_t vector_size = 100;
  size_t grain_size = 11;
  std::mutex m;

  // Task i sums its grain [i * grain_size, min((i + 1) * grain_size,
  // vector_size)) into a local partial, then folds it into the shared total
  // `c` while holding `m`. Only the shared update needs the lock, so the
  // critical section is kept to that single addition.
  auto reduce_lambda = [&](size_t i) {
    const size_t start_index = i * grain_size;
    const size_t end_index = std::min(start_index + grain_size, vector_size);
    int32_t partial = 0;
    for (size_t j = start_index; j < end_index; ++j) {
      partial += a[j];
    }
    std::lock_guard<std::mutex> lock(m);
    c += partial;
  };

  auto threadpool = ::executorch::extension::threadpool::get_threadpool();
  ASSERT_NE(threadpool, nullptr);
  EXPECT_GT(threadpool->get_thread_count(), 1u);

  generate_reduce_test_inputs(a, c_ref, vector_size);
  run_lambda_with_size(reduce_lambda, vector_size, grain_size);
  EXPECT_EQ(c, c_ref);

  // A single grain larger than the whole vector.
  vector_size = 7;
  c = c_ref = 0;
  generate_reduce_test_inputs(a, c_ref, vector_size);
  run_lambda_with_size(reduce_lambda, vector_size, grain_size);
  EXPECT_EQ(c, c_ref);
}
143
// Copied from
// caffe2/aten/src/ATen/test/test_thread_pool_guard.cpp
TEST(TestNoThreadPoolGuard, TestThreadPoolGuard) {
  namespace tp = ::executorch::extension::threadpool;

  // Without any guard, a pthreadpool handle is expected to be available.
  const auto original = tp::get_pthreadpool();
  ASSERT_NE(original, nullptr);

  {
    tp::NoThreadPoolGuard outer_guard;
    // While a guard is alive, no handle is expected.
    ASSERT_EQ(tp::get_pthreadpool(), nullptr);

    {
      // Guards nest: a second guard also reports no handle.
      tp::NoThreadPoolGuard inner_guard;
      ASSERT_EQ(tp::get_pthreadpool(), nullptr);
    }

    // Leaving the inner scope should restore the previous value, which is
    // still nullptr because outer_guard is alive.
    ASSERT_EQ(tp::get_pthreadpool(), nullptr);
  }

  // With every guard destroyed, the original handle should be back.
  const auto restored = tp::get_pthreadpool();
  ASSERT_NE(restored, nullptr);
  ASSERT_EQ(restored, original);
}
174
TEST(TestNoThreadPoolGuard, TestRunWithGuard) {
  const std::vector<int64_t> array = {1, 2, 3};

  auto pool = ::executorch::extension::threadpool::get_threadpool();
  ASSERT_NE(pool, nullptr);
  int64_t inner = 0;
  {
    // With the guard active, run() is expected to execute every task on the
    // calling thread. That is why `inner` can be updated without
    // synchronization, and the thread-id check below verifies it.
    ::executorch::extension::threadpool::NoThreadPoolGuard g1;
    const std::thread::id caller = std::this_thread::get_id();
    auto fn = [&array, &inner, caller](const size_t task_id) {
      EXPECT_EQ(std::this_thread::get_id(), caller);
      inner += array[task_id];
    };
    // One task per array element.
    pool->run(fn, array.size());

    // confirm the guard is on
    auto threadpool_ptr =
        ::executorch::extension::threadpool::get_pthreadpool();
    ASSERT_EQ(threadpool_ptr, nullptr);
  }
  ASSERT_EQ(inner, 6);
}
195