#include <array>
#include <atomic>
#include <condition_variable>
#include <iostream>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include <fmt/format.h>
#include <gtest/gtest.h>

#include <ATen/Parallel.h>
#include <ATen/record_function.h>
#include <c10/util/irange.h>
16
17 // Test that we can add and remove callbacks (both global and thread local.)
TEST(RecordFunctionTest,AddRemove)18 TEST(RecordFunctionTest, AddRemove) {
19 at::clearCallbacks();
20 ASSERT_FALSE(at::hasCallbacks());
21
22 auto start_callback =
23 [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
24 return nullptr;
25 };
26 auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {};
27
28 auto handle = at::addThreadLocalCallback(
29 at::RecordFunctionCallback(start_callback, end_callback));
30
31 ASSERT_TRUE(at::hasCallbacks());
32 ASSERT_TRUE(at::hasThreadLocalCallbacks());
33 ASSERT_FALSE(at::hasGlobalCallbacks());
34
35 at::removeCallback(handle);
36 ASSERT_FALSE(at::hasCallbacks());
37
38 handle = at::addGlobalCallback(
39 at::RecordFunctionCallback(start_callback, end_callback));
40
41 ASSERT_TRUE(at::hasCallbacks());
42 ASSERT_FALSE(at::hasThreadLocalCallbacks());
43 ASSERT_TRUE(at::hasGlobalCallbacks());
44
45 at::removeCallback(handle);
46 ASSERT_FALSE(at::hasCallbacks());
47 }
48
49 // Test that the callbacks that we register are actually run.
TEST(RecordFunctionTest,ThreadLocalState)50 TEST(RecordFunctionTest, ThreadLocalState) {
51 at::clearCallbacks();
52 ASSERT_FALSE(at::hasCallbacks());
53
54 static int tls_test_start_counter;
55 static int tls_test_end_counter;
56 tls_test_start_counter = 0;
57 tls_test_end_counter = 0;
58
59 auto start_callback =
60 [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
61 ++tls_test_start_counter;
62 return nullptr;
63 };
64 auto end_callback = [](const at::RecordFunction&, at::ObserverContext*) {
65 ++tls_test_end_counter;
66 };
67
68 auto handle = at::addThreadLocalCallback(
69 at::RecordFunctionCallback(start_callback, end_callback));
70
71 {
72 at::RecordFunction guard(at::RecordScope::USER_SCOPE);
73 guard.before("Test");
74 EXPECT_EQ(tls_test_start_counter, 1);
75 EXPECT_EQ(tls_test_end_counter, 0);
76 }
77 EXPECT_EQ(tls_test_start_counter, 1);
78 EXPECT_EQ(tls_test_end_counter, 1);
79
80 {
81 tls_test_start_counter = 0;
82 tls_test_end_counter = 0;
83 at::DisableRecordFunctionGuard no_profile_guard;
84 at::RecordFunction guard(at::RecordScope::USER_SCOPE);
85 guard.before("Test");
86 EXPECT_EQ(tls_test_start_counter, 0);
87 EXPECT_EQ(tls_test_end_counter, 0);
88 }
89 EXPECT_EQ(tls_test_start_counter, 0);
90 EXPECT_EQ(tls_test_end_counter, 0);
91
92 {
93 tls_test_start_counter = 0;
94 tls_test_end_counter = 0;
95 RECORD_FUNCTION("Test", {});
96 EXPECT_EQ(tls_test_start_counter, 1);
97 EXPECT_EQ(tls_test_end_counter, 0);
98 }
99 EXPECT_EQ(tls_test_start_counter, 1);
100 EXPECT_EQ(tls_test_end_counter, 1);
101
102 at::removeCallback(handle);
103 ASSERT_FALSE(at::hasCallbacks());
104 }
105
106 // Test that callbacks are run in the order that they are registered.
TEST(RecordFunctionTest,CallOrder)107 TEST(RecordFunctionTest, CallOrder) {
108 at::clearCallbacks();
109 ASSERT_FALSE(at::hasCallbacks());
110
111 static int current_index;
112 current_index = 0;
113
114 static std::array<std::string, 8> expected_order = {
115 "Start Callback 0 Outer",
116 "Start Callback 1 Outer",
117 "Start Callback 0 Inner",
118 "Start Callback 1 Inner",
119 "End Callback 0 Inner",
120 "End Callback 1 Inner",
121 "End Callback 0 Outer",
122 "End Callback 1 Outer",
123 };
124
125 #define REGISTER_CALLBACK(index) \
126 at::addThreadLocalCallback( \
127 at::RecordFunctionCallback( \
128 [](const at::RecordFunction& fn) \
129 -> std::unique_ptr<at::ObserverContext> { \
130 EXPECT_EQ( \
131 fmt::format("Start Callback {} {}", index, fn.name()), \
132 expected_order[current_index++]); \
133 return nullptr; \
134 }, \
135 [](const at::RecordFunction& fn, at::ObserverContext*) { \
136 EXPECT_EQ( \
137 fmt::format("End Callback {} {}", index, fn.name()), \
138 expected_order[current_index++]); \
139 }) \
140 .scopes({at::RecordScope::FUNCTION}))
141
142 REGISTER_CALLBACK(0);
143 REGISTER_CALLBACK(1);
144 #undef REGISTER_CALLBACK
145
146 RECORD_FUNCTION("Outer", {});
147 { RECORD_FUNCTION("Inner", {}); }
148
149 at::clearCallbacks();
150 ASSERT_FALSE(at::hasCallbacks());
151 }
152
153 // Make sure TLS migrates when tasks are launched.
TEST(RecordFunctionTest,ThreadMigration)154 TEST(RecordFunctionTest, ThreadMigration) {
155 at::clearCallbacks();
156 ASSERT_FALSE(at::hasCallbacks());
157
158 static int call_count;
159 call_count = 0;
160
161 auto handle = at::addThreadLocalCallback(
162 at::RecordFunctionCallback(
163 [](const at::RecordFunction&)
164 -> std::unique_ptr<at::ObserverContext> { return nullptr; },
165 [](const at::RecordFunction&, at::ObserverContext*) { ++call_count; })
166 .scopes({at::RecordScope::FUNCTION}));
167
168 EXPECT_EQ(call_count, 0);
169
170 std::condition_variable cv;
171 std::mutex lock;
172 at::launch([&cv]() {
173 RECORD_FUNCTION("Test", {});
174 cv.notify_all();
175 });
176 auto guard = std::unique_lock<std::mutex>(lock);
177 cv.wait(guard, [] { return call_count > 0; });
178
179 EXPECT_EQ(call_count, 1);
180
181 at::removeCallback(handle);
182 ASSERT_FALSE(at::hasCallbacks());
183 }
184
185 // Test sampling logic and validate that callbacks fire at the correct times.
TEST(RecordFunctionTest,Sampling)186 TEST(RecordFunctionTest, Sampling) {
187 at::clearCallbacks();
188 ASSERT_FALSE(at::hasCallbacks());
189
190 static int sample_test_counter;
191 sample_test_counter = 0;
192
193 uint32_t seed = 12345;
194 double p = 0.25;
195
196 at::set_record_function_seed_for_testing(seed);
197 std::mt19937 generator;
198 generator.seed(seed);
199 auto dist = std::geometric_distribution<int>(p);
200
201 // Make sure we know which steps should fire.
202 auto outcomes = std::array<int, 5>{7, 0, 0, 6, 2};
203 for (const auto i : c10::irange(outcomes.size())) {
204 ASSERT_EQ(dist(generator), outcomes[i]);
205 }
206
207 std::vector<int> expected_counts;
208 int running_count = 0;
209 for (const auto i : c10::irange(outcomes.size())) {
210 for (const auto j : c10::irange(outcomes[i])) {
211 expected_counts.push_back(running_count);
212 }
213 expected_counts.push_back(++running_count);
214 }
215
216 auto start_callback =
217 [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
218 ++sample_test_counter;
219 return nullptr;
220 };
221 auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {};
222
223 auto handle = at::addThreadLocalCallback(
224 at::RecordFunctionCallback(start_callback, end_callback)
225 .samplingProb(p)
226 .scopes({at::RecordScope::FUNCTION}));
227
228 for (const auto i : c10::irange(expected_counts.size())) {
229 RECORD_FUNCTION("Test", {});
230 EXPECT_EQ(sample_test_counter, expected_counts[i]);
231 }
232
233 at::removeCallback(handle);
234 ASSERT_FALSE(at::hasCallbacks());
235 }
236
237 // Validate sampling against a simple reference implementation for a complex set
238 // of registered callbacks.
TEST(RecordFunctionTest,MultipleCallbacks)239 TEST(RecordFunctionTest, MultipleCallbacks) {
240 at::clearCallbacks();
241 ASSERT_FALSE(at::hasCallbacks());
242
243 uint32_t seed = 54321;
244
245 std::mt19937 generator;
246 generator.seed(seed);
247
248 auto sample = [&](double p) {
249 return (p < 1.0 ? std::geometric_distribution<int>(p)(generator) : 0) + 1;
250 };
251
252 std::array<double, 4> probabilities{0.1, 1.0, 1.0, 0.3};
253 std::array<int, 4> next_call;
254 std::array<int, 4> counts;
255 static std::array<int, 4> counts_from_rec_fn;
256 counts_from_rec_fn.fill(0);
257
258 auto start_callback_0 =
259 [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
260 ++counts_from_rec_fn[0];
261 return nullptr;
262 };
263
264 auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {};
265
266 #define REGISTER_CALLBACK(register_fn, index) \
267 register_fn(at::RecordFunctionCallback( \
268 [](const at::RecordFunction& fn) \
269 -> std::unique_ptr<at::ObserverContext> { \
270 ++counts_from_rec_fn[index]; \
271 return nullptr; \
272 }, \
273 end_callback) \
274 .samplingProb(probabilities[index]) \
275 .scopes({at::RecordScope::FUNCTION}))
276
277 REGISTER_CALLBACK(at::addGlobalCallback, 0);
278 REGISTER_CALLBACK(at::addGlobalCallback, 1);
279 REGISTER_CALLBACK(at::addThreadLocalCallback, 2);
280
281 // The RecordFunction machinery will rebuild callbacks whenever a new observer
282 // is registered, so we need to wait until the last callback to seed the
283 // random number generator.
284 at::set_record_function_seed_for_testing(seed);
285 REGISTER_CALLBACK(at::addThreadLocalCallback, 3);
286 #undef REGISTER_CALLBACK
287
288 for (const auto i : c10::irange(probabilities.size())) {
289 next_call[i] = sample(probabilities[i]);
290 }
291
292 for (const auto i : c10::irange(50)) {
293 RECORD_FUNCTION("Test", {});
294 for (const auto j : c10::irange(next_call.size())) {
295 if (!(--next_call[j])) {
296 ++counts[j];
297 next_call[j] = sample(probabilities[j]);
298 }
299 EXPECT_EQ(counts[j], counts_from_rec_fn[j]);
300 }
301 }
302
303 at::clearCallbacks();
304 ASSERT_FALSE(at::hasCallbacks());
305 }
306