// xref: /aosp_15_r20/external/pytorch/test/cpp/profiler/record_function.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <array>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include <fmt/format.h>
#include <gtest/gtest.h>

#include <ATen/Parallel.h>
#include <ATen/record_function.h>
#include <c10/util/irange.h>

17 // Test that we can add and remove callbacks (both global and thread local.)
TEST(RecordFunctionTest,AddRemove)18 TEST(RecordFunctionTest, AddRemove) {
19   at::clearCallbacks();
20   ASSERT_FALSE(at::hasCallbacks());
21 
22   auto start_callback =
23       [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
24     return nullptr;
25   };
26   auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {};
27 
28   auto handle = at::addThreadLocalCallback(
29       at::RecordFunctionCallback(start_callback, end_callback));
30 
31   ASSERT_TRUE(at::hasCallbacks());
32   ASSERT_TRUE(at::hasThreadLocalCallbacks());
33   ASSERT_FALSE(at::hasGlobalCallbacks());
34 
35   at::removeCallback(handle);
36   ASSERT_FALSE(at::hasCallbacks());
37 
38   handle = at::addGlobalCallback(
39       at::RecordFunctionCallback(start_callback, end_callback));
40 
41   ASSERT_TRUE(at::hasCallbacks());
42   ASSERT_FALSE(at::hasThreadLocalCallbacks());
43   ASSERT_TRUE(at::hasGlobalCallbacks());
44 
45   at::removeCallback(handle);
46   ASSERT_FALSE(at::hasCallbacks());
47 }
48 
49 // Test that the callbacks that we register are actually run.
TEST(RecordFunctionTest,ThreadLocalState)50 TEST(RecordFunctionTest, ThreadLocalState) {
51   at::clearCallbacks();
52   ASSERT_FALSE(at::hasCallbacks());
53 
54   static int tls_test_start_counter;
55   static int tls_test_end_counter;
56   tls_test_start_counter = 0;
57   tls_test_end_counter = 0;
58 
59   auto start_callback =
60       [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
61     ++tls_test_start_counter;
62     return nullptr;
63   };
64   auto end_callback = [](const at::RecordFunction&, at::ObserverContext*) {
65     ++tls_test_end_counter;
66   };
67 
68   auto handle = at::addThreadLocalCallback(
69       at::RecordFunctionCallback(start_callback, end_callback));
70 
71   {
72     at::RecordFunction guard(at::RecordScope::USER_SCOPE);
73     guard.before("Test");
74     EXPECT_EQ(tls_test_start_counter, 1);
75     EXPECT_EQ(tls_test_end_counter, 0);
76   }
77   EXPECT_EQ(tls_test_start_counter, 1);
78   EXPECT_EQ(tls_test_end_counter, 1);
79 
80   {
81     tls_test_start_counter = 0;
82     tls_test_end_counter = 0;
83     at::DisableRecordFunctionGuard no_profile_guard;
84     at::RecordFunction guard(at::RecordScope::USER_SCOPE);
85     guard.before("Test");
86     EXPECT_EQ(tls_test_start_counter, 0);
87     EXPECT_EQ(tls_test_end_counter, 0);
88   }
89   EXPECT_EQ(tls_test_start_counter, 0);
90   EXPECT_EQ(tls_test_end_counter, 0);
91 
92   {
93     tls_test_start_counter = 0;
94     tls_test_end_counter = 0;
95     RECORD_FUNCTION("Test", {});
96     EXPECT_EQ(tls_test_start_counter, 1);
97     EXPECT_EQ(tls_test_end_counter, 0);
98   }
99   EXPECT_EQ(tls_test_start_counter, 1);
100   EXPECT_EQ(tls_test_end_counter, 1);
101 
102   at::removeCallback(handle);
103   ASSERT_FALSE(at::hasCallbacks());
104 }
105 
106 // Test that callbacks are run in the order that they are registered.
TEST(RecordFunctionTest,CallOrder)107 TEST(RecordFunctionTest, CallOrder) {
108   at::clearCallbacks();
109   ASSERT_FALSE(at::hasCallbacks());
110 
111   static int current_index;
112   current_index = 0;
113 
114   static std::array<std::string, 8> expected_order = {
115       "Start Callback 0 Outer",
116       "Start Callback 1 Outer",
117       "Start Callback 0 Inner",
118       "Start Callback 1 Inner",
119       "End Callback 0 Inner",
120       "End Callback 1 Inner",
121       "End Callback 0 Outer",
122       "End Callback 1 Outer",
123   };
124 
125 #define REGISTER_CALLBACK(index)                                       \
126   at::addThreadLocalCallback(                                          \
127       at::RecordFunctionCallback(                                      \
128           [](const at::RecordFunction& fn)                             \
129               -> std::unique_ptr<at::ObserverContext> {                \
130             EXPECT_EQ(                                                 \
131                 fmt::format("Start Callback {} {}", index, fn.name()), \
132                 expected_order[current_index++]);                      \
133             return nullptr;                                            \
134           },                                                           \
135           [](const at::RecordFunction& fn, at::ObserverContext*) {     \
136             EXPECT_EQ(                                                 \
137                 fmt::format("End Callback {} {}", index, fn.name()),   \
138                 expected_order[current_index++]);                      \
139           })                                                           \
140           .scopes({at::RecordScope::FUNCTION}))
141 
142   REGISTER_CALLBACK(0);
143   REGISTER_CALLBACK(1);
144 #undef REGISTER_CALLBACK
145 
146   RECORD_FUNCTION("Outer", {});
147   { RECORD_FUNCTION("Inner", {}); }
148 
149   at::clearCallbacks();
150   ASSERT_FALSE(at::hasCallbacks());
151 }
152 
153 // Make sure TLS migrates when tasks are launched.
TEST(RecordFunctionTest,ThreadMigration)154 TEST(RecordFunctionTest, ThreadMigration) {
155   at::clearCallbacks();
156   ASSERT_FALSE(at::hasCallbacks());
157 
158   static int call_count;
159   call_count = 0;
160 
161   auto handle = at::addThreadLocalCallback(
162       at::RecordFunctionCallback(
163           [](const at::RecordFunction&)
164               -> std::unique_ptr<at::ObserverContext> { return nullptr; },
165           [](const at::RecordFunction&, at::ObserverContext*) { ++call_count; })
166           .scopes({at::RecordScope::FUNCTION}));
167 
168   EXPECT_EQ(call_count, 0);
169 
170   std::condition_variable cv;
171   std::mutex lock;
172   at::launch([&cv]() {
173     RECORD_FUNCTION("Test", {});
174     cv.notify_all();
175   });
176   auto guard = std::unique_lock<std::mutex>(lock);
177   cv.wait(guard, [] { return call_count > 0; });
178 
179   EXPECT_EQ(call_count, 1);
180 
181   at::removeCallback(handle);
182   ASSERT_FALSE(at::hasCallbacks());
183 }
184 
185 // Test sampling logic and validate that callbacks fire at the correct times.
TEST(RecordFunctionTest,Sampling)186 TEST(RecordFunctionTest, Sampling) {
187   at::clearCallbacks();
188   ASSERT_FALSE(at::hasCallbacks());
189 
190   static int sample_test_counter;
191   sample_test_counter = 0;
192 
193   uint32_t seed = 12345;
194   double p = 0.25;
195 
196   at::set_record_function_seed_for_testing(seed);
197   std::mt19937 generator;
198   generator.seed(seed);
199   auto dist = std::geometric_distribution<int>(p);
200 
201   // Make sure we know which steps should fire.
202   auto outcomes = std::array<int, 5>{7, 0, 0, 6, 2};
203   for (const auto i : c10::irange(outcomes.size())) {
204     ASSERT_EQ(dist(generator), outcomes[i]);
205   }
206 
207   std::vector<int> expected_counts;
208   int running_count = 0;
209   for (const auto i : c10::irange(outcomes.size())) {
210     for (const auto j : c10::irange(outcomes[i])) {
211       expected_counts.push_back(running_count);
212     }
213     expected_counts.push_back(++running_count);
214   }
215 
216   auto start_callback =
217       [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
218     ++sample_test_counter;
219     return nullptr;
220   };
221   auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {};
222 
223   auto handle = at::addThreadLocalCallback(
224       at::RecordFunctionCallback(start_callback, end_callback)
225           .samplingProb(p)
226           .scopes({at::RecordScope::FUNCTION}));
227 
228   for (const auto i : c10::irange(expected_counts.size())) {
229     RECORD_FUNCTION("Test", {});
230     EXPECT_EQ(sample_test_counter, expected_counts[i]);
231   }
232 
233   at::removeCallback(handle);
234   ASSERT_FALSE(at::hasCallbacks());
235 }
236 
237 // Validate sampling against a simple reference implementation for a complex set
238 // of registered callbacks.
TEST(RecordFunctionTest,MultipleCallbacks)239 TEST(RecordFunctionTest, MultipleCallbacks) {
240   at::clearCallbacks();
241   ASSERT_FALSE(at::hasCallbacks());
242 
243   uint32_t seed = 54321;
244 
245   std::mt19937 generator;
246   generator.seed(seed);
247 
248   auto sample = [&](double p) {
249     return (p < 1.0 ? std::geometric_distribution<int>(p)(generator) : 0) + 1;
250   };
251 
252   std::array<double, 4> probabilities{0.1, 1.0, 1.0, 0.3};
253   std::array<int, 4> next_call;
254   std::array<int, 4> counts;
255   static std::array<int, 4> counts_from_rec_fn;
256   counts_from_rec_fn.fill(0);
257 
258   auto start_callback_0 =
259       [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
260     ++counts_from_rec_fn[0];
261     return nullptr;
262   };
263 
264   auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {};
265 
266 #define REGISTER_CALLBACK(register_fn, index)                   \
267   register_fn(at::RecordFunctionCallback(                       \
268                   [](const at::RecordFunction& fn)              \
269                       -> std::unique_ptr<at::ObserverContext> { \
270                     ++counts_from_rec_fn[index];                \
271                     return nullptr;                             \
272                   },                                            \
273                   end_callback)                                 \
274                   .samplingProb(probabilities[index])           \
275                   .scopes({at::RecordScope::FUNCTION}))
276 
277   REGISTER_CALLBACK(at::addGlobalCallback, 0);
278   REGISTER_CALLBACK(at::addGlobalCallback, 1);
279   REGISTER_CALLBACK(at::addThreadLocalCallback, 2);
280 
281   // The RecordFunction machinery will rebuild callbacks whenever a new observer
282   // is registered, so we need to wait until the last callback to seed the
283   // random number generator.
284   at::set_record_function_seed_for_testing(seed);
285   REGISTER_CALLBACK(at::addThreadLocalCallback, 3);
286 #undef REGISTER_CALLBACK
287 
288   for (const auto i : c10::irange(probabilities.size())) {
289     next_call[i] = sample(probabilities[i]);
290   }
291 
292   for (const auto i : c10::irange(50)) {
293     RECORD_FUNCTION("Test", {});
294     for (const auto j : c10::irange(next_call.size())) {
295       if (!(--next_call[j])) {
296         ++counts[j];
297         next_call[j] = sample(probabilities[j]);
298       }
299       EXPECT_EQ(counts[j], counts_from_rec_fn[j]);
300     }
301   }
302 
303   at::clearCallbacks();
304   ASSERT_FALSE(at::hasCallbacks());
305 }
306