1 #include <torch/csrc/autograd/profiler_python.h>
2
3 #include <atomic>
4 #include <cstdint>
5 #include <deque>
6 #include <limits>
7 #include <memory>
8 #include <queue>
9 #include <string>
10 #include <utility>
11 #include <vector>
12
13 #include <Python.h>
14 #include <frameobject.h>
15
16 #include <ATen/core/TensorBase.h>
17 #include <c10/macros/Macros.h>
18 #include <c10/util/ApproximateClock.h>
19 #include <c10/util/Exception.h>
20 #include <c10/util/Logging.h>
21 #include <c10/util/flat_hash_map.h>
22 #include <c10/util/irange.h>
23 #include <torch/csrc/autograd/python_variable.h>
24 #include <torch/csrc/profiler/collection.h>
25 #include <torch/csrc/profiler/containers.h>
26 #include <torch/csrc/profiler/orchestration/python_tracer.h>
27 #include <torch/csrc/profiler/util.h>
28 #include <torch/csrc/utils/pybind.h>
29 #include <torch/csrc/utils/python_compat.h>
30 #include <torch/csrc/utils/python_strings.h>
31 #include <optional>
32
33 namespace py = pybind11;
34
35 namespace torch::profiler::impl {
36 namespace {
37 enum CallType { PyCall = 0, PyModuleCall, PyCCall, PyOptimizerCall };
38 static constexpr size_t CallTypeSize = 4;
39 using no_ephemeral_t = std::tuple<>;
40 static constexpr uint64_t NoTID = std::numeric_limits<uint64_t>::max();
41
42 // ============================================================================
43 // == Miscellaneous structs and utils =========================================
44 // ============================================================================
45 struct CodeLocation {
46 CodeLocation() = default;
47 explicit CodeLocation(PyFrameObject* frame)
48 : line_number_{PyFrame_GetLineNumber(frame)} {
49 auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
50 filename_ = THPUtils_unpackStringView(code->co_filename).data();
51 name_ = THPUtils_unpackStringView(code->co_name).data();
52 }
53
54 bool operator==(const CodeLocation& other) const {
55 return filename_ == other.filename_ && name_ == other.name_ &&
56 line_number_ == other.line_number_;
57 }
58
59 const char* filename_{nullptr};
60 const char* name_{nullptr};
61 int line_number_{0};
62 };
63
64 template <CallType C>
65 PyCodeObject* getCode();
66
67 template <>
68 PyCodeObject* getCode<CallType::PyModuleCall>() {
69 static auto module_call_code = []() {
70 pybind11::gil_scoped_acquire gil;
71 auto res = py::module::import("torch.nn")
72 .attr("Module")
73 .attr("__call__")
74 .attr("__code__")
75 .ptr();
76 TORCH_INTERNAL_ASSERT(PyCode_Check(res));
77 return (PyCodeObject*)res;
78 }();
79 return module_call_code;
80 };
81
82 template <>
83 PyCodeObject* getCode<CallType::PyOptimizerCall>() {
84 static auto optimizer_step_code = []() {
85 pybind11::gil_scoped_acquire gil;
86 auto res = py::module::import("torch.optim")
87 .attr("Optimizer")
88 .attr("_optimizer_step_code")
89 .attr("__code__")
90 .ptr();
91 TORCH_INTERNAL_ASSERT(PyCode_Check(res));
92 return (PyCodeObject*)res;
93 }();
94 return optimizer_step_code;
95 };
96
97 } // namespace
98 } // namespace torch::profiler::impl
99
100 template <>
101 struct std::hash<torch::profiler::impl::CodeLocation> {
102 size_t operator()(const torch::profiler::impl::CodeLocation& x) {
103 return c10::get_hash(x.filename_, x.name_, x.line_number_);
104 }
105 };
106
107 namespace torch::profiler::impl {
108 namespace {
109 // ============================================================================
110 // == CallTypeHelper: Tools for generic programming on specializations. =======
111 // ============================================================================
112 template <template <CallType> class ClassT>
113 class CallTypeHelper final {
114 private:
115 static_assert(
116 CallType::PyCall == 0,
117 "CallTypeHelper uses integer math which depends on a zero start.");
118 static constexpr size_t End = CallTypeSize;
119
120 template <size_t... I>
121 static constexpr std::tuple<ClassT<(CallType)I>...> make_tuple_impl(
122 std::index_sequence<I...>);
123
124 template <size_t C, typename T, typename FunctorT, typename... Args>
125 static void map(T& t, FunctorT& f, Args&&... args) {
126 f(std::get<C>(t), args...);
127 if constexpr (C + 1 < End) {
128 map<C + 1>(t, f, std::forward<Args>(args)...);
129 }
130 }
131
132 public:
133 using tuple_type = decltype(make_tuple_impl(std::make_index_sequence<End>{}));
134
135 template <typename FunctorT, typename... Args>
136 static void map(tuple_type& t, FunctorT& f, Args&&... args) {
137 map<0>(t, f, std::forward<Args>(args)...);
138 }
139 };
140
141 // ============================================================================
142 // == Event type definitions. =================================================
143 // ============================================================================
144 // When we are tracing a Python program, the general procedure is to record
145 // every time we enter or exit a function and later replay these events during
146 // post processing. Thus, during the profiling phase we want to do the MINIMAL
147 // amount of work to capture all of the information that we need; otherwise we
148 // will distort the profile. (While we don't wish to be terribly inefficient
149 // during post processing, we are willing to do extra fixup work in post if it
150 // reduces overhead in the profiling phase.)
151 //
152 // When the tracer first enters a frame, it constructs a CallKey for that
153 // location. The contents of the key vary by context. For a python function
154 // the key is the (PyCodeObject*, int) pair that defines the bytecode of the
155 // function. For an `nn.Module` the key is a (non-owning) pointer to `self`.
156 // For a bound C function it is a (non-owning) pointer to the bound function.
157 // A CallKey should be small, inexpensive, and POD.
158 //
159 // We then collect a CallKey<CallType::PyCall> for the calling frame for better
160 // source tracking. This pair is a `Callsite`, and serves as a first level key
161 during tracing. We look up the Callsite in a thread local cache which maps
162 // Callsite to a unique integer `TraceKey`. On a cache hit, we simply store the
163 // TraceKey and return. On a cache miss, we use a global value cache to store
164 // whatever fields we need from the two CallKeys, generate a new TraceKey, and
165 // update the local cache.
166 //
167 // During post processing we:
168 // 1) Determine the type represented by a TraceKey by checking which
169 // sub-cache of the thread local cache it appears in.
170 // 2) Look up the pair of CallKeys from the thread local cache.
171 // 3) Look up the expanded values of each CallKey from the global value cache.
172 //
173 // To add a new event type to the cache:
174 // 1) Add an entry to the `CallType` enum.
175 // 2) Add a specialization of Config which defines key_t, ephemeral_t and
176 // cache_t.
177 // 3) Add a specialization of ValueCache::store and ValueCache::load.
//    (A rough sketch of these steps appears at the end of this comment.)
178 //
179 // -------------------------
180 // -- Ephemeral arguments --
181 // -------------------------
182 // The value cache mechanism assumes that `key_t` is enough to specify the
183 // correct value. However it may not be possible to materialize a value using
184 // only an instance of `key_t`. As a result, the cache also accepts "ephemeral"
185 // inputs which can be used to populate the value cache. Ephemeral inputs come
186 // with two caveats:
187 // 1) They are NOT safe to save, and cannot be used after `ValueCache::store`.
188 // 2) They should be used to access data that is not expected to change from
189 // call to call, such as the name of a function.
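//
// As a rough, illustrative sketch of that checklist (kept in a comment so it
// is not compiled; the `PyFooCall` names are placeholders, not part of this
// file), a new event type would look something like:
//
//   // 1) enum CallType { ..., PyFooCall };   // and bump CallTypeSize
//   // 2) template <>
//   //    struct Config<CallType::PyFooCall> {
//   //      using key_t = PyObject*;             // small, trivially copyable
//   //      using ephemeral_t = PyFrameObject*;  // only valid during store()
//   //      using cache_t = ska::flat_hash_map<key_t, at::StringView>;
//   //      static constexpr EventType event_type = EventType::PyCall;
//   //    };
//   // 3) template <>
//   //    void ValueCache::store<CallType::PyFooCall>(const key_t&, ephemeral_t);
//   //    template <>
//   //    ExtraFields<EventType::PyCall>::args_t
//   //    ValueCache::load<CallType::PyFooCall>(const key_t&) const;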
190
191 template <CallType>
192 struct Config;
193
194 template <>
195 struct Config<CallType::PyCall> {
196 using key_t = CodeLocation;
197 using ephemeral_t = no_ephemeral_t;
198 using cache_t = ska::flat_hash_map<key_t, PyFrameState>;
199 static constexpr EventType event_type = EventType::PyCall;
200 };
201
202 template <typename Key, typename Cls, typename ParameterInfo>
203 struct ExtendedPyCallConfig {
204 using key_t = Key;
205 using cls_t = Cls;
206 using ephemeral_t = PyFrameObject*;
207
208 struct ClsAndParameters {
209 cls_t cls_;
210 std::vector<ParameterInfo> parameters_;
211 };
212
213 struct Cache {
214 // `nn.Module.forward` or `optim.Optimizer._optimizer_step_code`
215 std::optional<CodeLocation> location_;
216 ska::flat_hash_map<key_t, ClsAndParameters> cls_and_parameters_;
217 ska::flat_hash_map<cls_t, at::StringView> cls_names_;
218 };
219 using cache_t = Cache;
220
221 static constexpr EventType event_type = EventType::PyCall;
222 };
223
224 template <>
225 struct Config<CallType::PyModuleCall> : ExtendedPyCallConfig<
226 PyModuleSelf,
227 PyModuleCls,
228 NNModuleInfo::ParameterInfo> {};
229
230 template <>
231 struct Config<CallType::PyOptimizerCall> : ExtendedPyCallConfig<
232 PyOptimizerSelf,
233 PyOptimizerCls,
234 OptimizerInfo::ParameterInfo> {};
235
236 template <>
237 struct Config<CallType::PyCCall> {
238 using key_t = PyMethod;
239 using ephemeral_t = PyObject*;
240 using cache_t = ska::flat_hash_map<key_t, at::StringView>;
241 static constexpr EventType event_type = EventType::PyCCall;
242 };
243
244 // ============================================================================
245 // == Callsite & ValueCache: Storage during profiling =========================
246 // ============================================================================
247 template <CallType C>
248 class Callsite {
249 public:
250 static constexpr CallType call_type = C;
251 using key_t = typename Config<C>::key_t;
252
253 static_assert(
254 std::is_trivially_copyable_v<key_t>,
255 "Key should be trivial, as it is passed by value.");
256
257 template <typename U>
258 Callsite(U value, PyFrameObject* f_back) : value_(value), caller_(f_back) {}
259
260 bool operator==(const Callsite<C>& other) const {
261 return value_ == other.value_ && caller_ == other.caller_;
262 }
263
264 key_t value_;
265 Config<CallType::PyCall>::key_t caller_;
266 };
267
268 // ============================================================================
269 // == Type specific store and load implementations. ===========================
270 // ============================================================================
271 using PyCallKey = Config<CallType::PyCall>::key_t;
272 using PyModuleCallKey = Config<CallType::PyModuleCall>::key_t;
273 using PyCCallKey = Config<CallType::PyCCall>::key_t;
274 using PyOptimizerCallKey = Config<CallType::PyOptimizerCall>::key_t;
275
276 class ValueCache {
277 public:
278 ValueCache() = default;
279 ValueCache(const ValueCache&) = delete;
280
281 template <CallType C>
282 void store(const typename Config<C>::key_t&, typename Config<C>::ephemeral_t);
283
284 template <CallType C>
285 auto load(const Callsite<C>& callsite, size_t python_tid) const {
286 auto caller = load<CallType::PyCall>(callsite.caller_);
287 TORCH_INTERNAL_ASSERT(!caller.module_info_.has_value());
288 return ExtraFields<Config<C>::event_type>{
289 /*end_time_ns=*/std::numeric_limits<c10::time_t>::min(),
290 python_tid,
291 caller.frame_state_,
292 load<C>(callsite.value_)};
293 }
294
295 std::optional<TensorMetadata> recordIfTensor(py::handle p);
296 std::vector<std::pair<std::string, TensorMetadata>> unpackTensorMap(
297 const py::dict& tensor_map);
298 void trimPrefixes();
299
300 private:
301 template <CallType C>
302 typename ExtraFields<Config<C>::event_type>::args_t load(
303 const typename Config<C>::key_t&) const;
304
305 template <CallType C>
306 using State = typename Config<C>::cache_t;
307
308 CallTypeHelper<State>::tuple_type state_;
309 };
310
311 template <CallType C>
312 typename Config<C>::cls_t set_class(
313 ValueCache* value_cache,
314 typename Config<C>::cache_t& cache,
315 const typename Config<C>::key_t& key,
316 const typename Config<C>::ephemeral_t& frame) {
317 if (C10_UNLIKELY(!cache.location_.has_value())) {
318 auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
319 TORCH_INTERNAL_ASSERT(code.get() == getCode<C>());
320 cache.location_ = PyCallKey(frame);
321 value_cache->store<CallType::PyCall>(*cache.location_, no_ephemeral_t());
322 }
323
324 auto cls_handle = py::handle((PyObject*)key).attr("__class__");
325 auto cls = typename Config<C>::cls_t(cls_handle.ptr());
326 if (cache.cls_names_.find(cls) == cache.cls_names_.end()) {
327 cache.cls_names_[cls] =
328 at::StringView(py::str(cls_handle.attr("__name__")));
329 }
330 return cls;
331 }
332
333 TensorMetadata toTensorMetadata(PyObject* self) {
334 TORCH_INTERNAL_ASSERT(THPVariable_CheckExact(self));
335 const auto& t = THPVariable_Unpack(self);
336 RawTensorMetadata m{t};
337 return TensorMetadata{
338 m,
339 t.sizes().vec(),
340 m.layout_ == at::kStrided ? t.strides().vec() : std::vector<int64_t>()};
341 }
342
343 std::optional<TensorMetadata> ValueCache::recordIfTensor(py::handle p) {
344 return THPVariable_CheckExact(p.ptr())
345 ? std::optional<TensorMetadata>{toTensorMetadata(p.ptr())}
346 : std::nullopt;
347 }
348
349 std::vector<std::pair<std::string, TensorMetadata>> ValueCache::unpackTensorMap(
350 const py::dict& tensor_map) {
351 std::vector<std::pair<std::string, TensorMetadata>> out;
352 for (auto& it : tensor_map) {
353 auto* value = it.second.ptr();
354 if (py::isinstance<py::str>(it.first) && THPVariable_CheckExact(value)) {
355 out.emplace_back(
356 py::cast<std::string>(it.first), toTensorMetadata(value));
357 }
358 }
359 return out;
360 }
361
362 template <>
363 void ValueCache::store<CallType::PyCall>(const PyCallKey& key, no_ephemeral_t) {
364 auto& locations = std::get<CallType::PyCall>(state_);
365 if (C10_UNLIKELY(locations.find(key) == locations.end())) {
366 locations[key] = {
367 key.line_number_,
368 at::StringView(key.filename_),
369 at::StringView(key.name_)};
370 }
371 }
372
373 template <>
374 ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyCall>(
375 const PyCallKey& key) const {
376 return {std::get<CallType::PyCall>(state_).at(key), std::nullopt};
377 }
378
379 template <>
380 void ValueCache::store<CallType::PyModuleCall>(
381 const PyModuleCallKey& key,
382 Config<CallType::PyModuleCall>::ephemeral_t frame) {
383 auto& cache = std::get<CallType::PyModuleCall>(state_);
384 if (C10_UNLIKELY(
385 cache.cls_and_parameters_.find(key) ==
386 cache.cls_and_parameters_.end())) {
387 auto cls = set_class<CallType::PyModuleCall>(this, cache, key, frame);
388
389 py::dict params = py::handle((PyObject*)key).attr("_parameters");
390 std::vector<NNModuleInfo::ParameterInfo> params_;
391 for (auto& it : params) {
392 auto* p = it.second.ptr();
393 if (py::isinstance<py::str>(it.first) && THPVariable_CheckExact(p)) {
394 params_.push_back(
395 {it.first.cast<std::string>(),
396 toTensorMetadata(p),
397 recordIfTensor(py::getattr(it.second, "grad", py::none()))});
398 }
399 }
400 cache.cls_and_parameters_[key] = {cls, std::move(params_)};
401 }
402 }
403
404 template <>
405 ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyModuleCall>(
406 const PyModuleCallKey& key) const {
407 auto& cache = std::get<CallType::PyModuleCall>(state_);
408 TORCH_INTERNAL_ASSERT(cache.location_.has_value());
409 const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
410 const auto& cls = cls_and_parameters.cls_;
411 NNModuleInfo info{
412 key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
413 return {
414 /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
415 /*module_info_=*/std::move(info),
416 /*optimizer_info_=*/std::nullopt};
417 }
418
419 template <>
420 void ValueCache::store<CallType::PyOptimizerCall>(
421 const PyOptimizerCallKey& key,
422 Config<CallType::PyOptimizerCall>::ephemeral_t frame) {
423 auto& cache = std::get<CallType::PyOptimizerCall>(state_);
424 if (C10_UNLIKELY(
425 cache.cls_and_parameters_.find(key) ==
426 cache.cls_and_parameters_.end())) {
427 auto cls = set_class<CallType::PyOptimizerCall>(this, cache, key, frame);
428 const py::handle self{(PyObject*)key};
429 std::vector<OptimizerInfo::ParameterInfo> params;
430
431 for (const auto& i : (py::list)self.attr("param_groups")) {
432 for (auto& param : py::cast<py::dict>(i).attr("get")("params")) {
433 if (THPVariable_CheckExact(param.ptr())) {
434 // While `self.state` is permitted to store data in an arbitrary way,
435 // all generic optimizers (SGD, Adam, etc) use param as the key since
436 // the state in question is tied to particular parameters. We can
437 // relax this assumption if the need arises.
438 params.push_back(
439 {toTensorMetadata(param.ptr()),
440 recordIfTensor(py::getattr(param, "grad", py::none())),
441 unpackTensorMap(py::cast<py::dict>(self.attr("state"))
442 .attr("get")(param, py::dict()))});
443 }
444 }
445 }
446
447 cache.cls_and_parameters_[key] = {cls, std::move(params)};
448 }
449 }
450
451 template <>
452 ExtraFields<EventType::PyCall>::args_t ValueCache::load<
453 CallType::PyOptimizerCall>(const PyOptimizerCallKey& key) const {
454 auto& cache = std::get<CallType::PyOptimizerCall>(state_);
455 const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
456 auto cls = cls_and_parameters.cls_;
457 OptimizerInfo info{
458 key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
459 return {
460 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
461 /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
462 /*module_info_=*/std::nullopt,
463 /*optimizer_info_=*/std::move(info)};
464 }
465
466 template <>
467 void ValueCache::store<CallType::PyCCall>(
468 const PyCCallKey& key,
469 Config<CallType::PyCCall>::ephemeral_t arg) {
470 auto& names = std::get<CallType::PyCCall>(state_);
471 if (C10_UNLIKELY(names.find(key) == names.end())) {
472 names[key] = at::StringView(py::repr(arg));
473 }
474 }
475
476 template <>
477 ExtraFields<EventType::PyCCall>::args_t ValueCache::load<CallType::PyCCall>(
478 const PyCCallKey& key) const {
479 return std::get<CallType::PyCCall>(state_).at(key);
480 }
481
482 // TODO: Use re2.
483 void ValueCache::trimPrefixes() {
484 static const auto prefixes = []() {
485 pybind11::gil_scoped_acquire gil;
486 return py::module::import("torch.profiler.python_tracer")
487 .attr("_prefix_regex")()
488 .cast<std::vector<std::string>>();
489 }();
490
491 for (auto& it : std::get<CallType::PyCall>(state_)) {
492 std::string filename = it.second.filename_.str();
493 for (const auto& p : prefixes) {
494 if (filename.compare(0, p.size(), p) == 0) {
495 filename.erase(0, p.size());
496 it.second.filename_ = at::StringView(filename);
497 break;
498 }
499 }
500 }
501 }
502
503 // ============================================================================
504 // == TraceKey cache ==========================================================
505 // ============================================================================
506 using python_tracer::TraceKey;
507
508 TraceKey nextKey() {
509 static std::atomic<uint64_t> key{0};
510 return TraceKey{++key};
511 }
512
513 template <CallType C>
514 struct TraceKeyCacheState {
515 struct Hash {
516 size_t operator()(const Callsite<C>& key) {
517 return c10::get_hash(key.value_, key.caller_);
518 }
519 };
520
521 TraceKey intern(
522 Callsite<C> callsite,
523 typename Config<C>::ephemeral_t ephemeral,
524 ValueCache& value_cache) {
525 auto it = state_.find(callsite);
526 if (C10_UNLIKELY(it == state_.end())) {
527 value_cache.store<C>(callsite.value_, ephemeral);
528 value_cache.store<CallType::PyCall>(callsite.caller_, no_ephemeral_t());
529 it = state_.insert({callsite, nextKey()}).first;
530 }
531 return it->second;
532 }
533
534 auto lookup(Callsite<C>& callsite, ValueCache& value_cache) const {
535 return std::make_pair(
536 value_cache.load<C>(callsite.value_),
537 value_cache.load<CallType::PyCall>(callsite.caller_));
538 }
539
540 ska::flat_hash_map<Callsite<C>, TraceKey, Hash> state_;
541 };
542
543 // ============================================================================
544 // == Core CPython data types =================================================
545 // ============================================================================
546 // PyObject that allows different threads to record events without colliding.
547 // It is passed as the second argument when enabling tracing via
548 // `PyEval_SetProfile`.
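//
// Registration (see the PythonTracer constructor below) looks roughly like:
//   PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
// and `pyProfileFn` recovers the per-thread state by casting its first
// argument back to a `TraceContext*`.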
549 struct ThreadLocalResults;
550 struct TraceContext {
551 PyObject_HEAD;
552 ThreadLocalResults* thread_local_results_;
553 };
554
555 // CPython boilerplate to define `TraceContext` as a proper python object.
556 static PyTypeObject TraceContextType = {
557 PyVarObject_HEAD_INIT(nullptr, 0) "TraceContext", /* tp_name */
558 sizeof(TraceContext), /* tp_basicsize */
559 0, /* tp_itemsize */
560 nullptr, /* tp_dealloc */
561 0,
562 /* tp_vectorcall_offset */
563 nullptr, /* tp_getattr */
564 nullptr, /* tp_setattr */
565 nullptr, /* tp_reserved */
566 nullptr, /* tp_repr */
567 nullptr, /* tp_as_number */
568 nullptr, /* tp_as_sequence */
569 nullptr, /* tp_as_mapping */
570 nullptr, /* tp_hash */
571 nullptr, /* tp_call */
572 nullptr, /* tp_str */
573 nullptr, /* tp_getattro */
574 nullptr, /* tp_setattro */
575 nullptr, /* tp_as_buffer */
576 Py_TPFLAGS_DEFAULT, /* tp_flags */
577 "Python tracer TLS", /* tp_doc */
578 nullptr, /* tp_traverse */
579 nullptr, /* tp_clear */
580 nullptr, /* tp_richcompare */
581 0, /* tp_weaklistoffset */
582 nullptr, /* tp_iter */
583 nullptr, /* tp_iternext */
584 nullptr, /* tp_methods */
585 nullptr, /* tp_members */
586 nullptr, /* tp_getset */
587 nullptr, /* tp_base */
588 nullptr, /* tp_dict */
589 nullptr, /* tp_descr_get */
590 nullptr, /* tp_descr_set */
591 0, /* tp_dictoffset */
592 nullptr, /* tp_init */
593 nullptr, /* tp_alloc */
594 PyType_GenericNew, /* tp_new */
595 nullptr /* tp_free */
596 };
597
598 class gil_and_restore_thread {
599 public:
600 gil_and_restore_thread()
601 : gil_(), initial_thread_state_{PyThreadState_Get()} {}
602 ~gil_and_restore_thread() {
603 PyThreadState_Swap(initial_thread_state_);
604
605 // `gil_scoped_acquire` is a bit fragile in on-demand mode:
606 // https://github.com/pytorch/pytorch/pull/91684#issuecomment-1413154458
607 if (!Py_IsInitialized()) {
608 gil_.disarm();
609 }
610 }
611
612 PyThreadState* initial_thread_state() const {
613 return initial_thread_state_;
614 }
615
616 private:
617 pybind11::gil_scoped_acquire gil_;
618 PyThreadState* initial_thread_state_;
619 };
620
621 // ============================================================================
622 // == Thread local cache ======================================================
623 // ============================================================================
624 class PythonTracer;
625 struct ThreadLocalResults {
626 ThreadLocalResults(
627 PyThreadState* thread_state,
628 ValueCache* value_cache,
629 PythonTracer* active_tracer)
630 : thread_state_{thread_state},
631 ctx_{(TraceContext*)TraceContextType.tp_alloc(&TraceContextType, 0)},
632 value_cache_{value_cache},
633 active_tracer_{active_tracer} {
634 ctx_->thread_local_results_ = this;
635 }
636
637 ThreadLocalResults() = delete;
638 ThreadLocalResults(const ThreadLocalResults&) = delete;
639 ThreadLocalResults(ThreadLocalResults&&) = delete;
640 ThreadLocalResults& operator=(const ThreadLocalResults&) = delete;
641 ThreadLocalResults& operator=(const ThreadLocalResults&&) = delete;
642
643 ~ThreadLocalResults() {
644 Py_DECREF((PyObject*)ctx_);
645 }
646
647 template <CallType C, EventType E, typename Ephemeral, typename... Args>
648 TraceKey intern(Ephemeral ephemeral, Args... args) {
649 static_assert(
650 Config<C>::event_type == E,
651 "ThreadLocalResults.intern called from the wrong typed context.");
652 auto callsite = Callsite<C>(std::forward<Args>(args)...);
653 return std::get<C>(trace_keys_).intern(callsite, ephemeral, *value_cache_);
654 }
655
656 static constexpr size_t BLOCK_SIZE = 1024;
657
658 PyThreadState* thread_state_;
659 TraceContext* ctx_;
660 ValueCache* value_cache_;
661 PythonTracer* active_tracer_;
662 CallTypeHelper<TraceKeyCacheState>::tuple_type trace_keys_;
663 AppendOnlyList<c10::approx_time_t, BLOCK_SIZE> exit_times_;
664 AppendOnlyList<c10::approx_time_t, BLOCK_SIZE> c_exit_times_;
665 };
666
667 // ============================================================================
668 // == Tracing implementation ==================================================
669 // ============================================================================
670 class PythonTracer final : public python_tracer::PythonTracerBase {
671 public:
672 PythonTracer(torch::profiler::impl::RecordQueue* queue);
673 // NOLINTNEXTLINE(bugprone-exception-escape)
674 ~PythonTracer() override;
675
676 static int pyProfileFn(
677 PyObject* obj,
678 PyFrameObject* frame,
679 int what,
680 PyObject* arg);
681
682 void stop() override;
683 void restart() override;
684 std::vector<std::shared_ptr<Result>> getEvents(
685 std::function<c10::time_t(c10::approx_time_t)> time_converter,
686 std::vector<python_tracer::CompressedEvent>& enters,
687 c10::time_t end_time_ns) override;
688
689 struct StartFrame {
690 TraceKey trace_key_;
691 c10::approx_time_t start_time{};
692 };
693
694 private:
695 void recordPyCall(
696 ThreadLocalResults& tls,
697 PyFrameObject* frame,
698 bool is_startup_frame);
699
700 void recordCCall(
701 ThreadLocalResults& tls,
702 PyFrameObject* frame,
703 PyObject* arg);
704
705 const std::vector<PyThreadState*> interpreterThreads() const;
706
707 std::atomic<bool> active_lock_{false};
708 bool active_{false};
709
710 torch::profiler::impl::RecordQueue* queue_;
711 PyInterpreterState* interpreter_{nullptr};
712 PyCodeObject* module_call_code_;
713 PyCodeObject* optimizer_hook_;
714
715 std::vector<StartFrame> start_frames_;
716 std::deque<ThreadLocalResults> thread_local_results_;
717 ValueCache value_cache_;
718 };
719
720 const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
721 pybind11::gil_scoped_acquire gil;
722 std::vector<PyThreadState*> out;
723 if (SOFT_ASSERT(interpreter_)) {
724 auto* thread_state = PyInterpreterState_ThreadHead(interpreter_);
725 while (thread_state != nullptr) {
726 out.push_back(thread_state);
727 thread_state = PyThreadState_Next(thread_state);
728 }
729 }
730 return out;
731 }
732
733 PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
734 : queue_(queue),
735
736 module_call_code_(getCode<CallType::PyModuleCall>()),
737 optimizer_hook_(getCode<CallType::PyOptimizerCall>()) {
738 TORCH_CHECK(queue_ != nullptr);
739
740 bool expected{false};
741 active_ = active_lock_.compare_exchange_strong(expected, true);
742 if (!active_) {
743 TORCH_WARN(
744 "There is already an active Python tracer. "
745 "Refusing to register profile functions.");
746 return;
747 }
748
749 gil_and_restore_thread gil;
750 interpreter_ = PyInterpreterState_Get();
751
752 if (!gil.initial_thread_state()) {
753 TORCH_WARN("PyThreadState_Get returned NULL");
754 return;
755 }
756
757 // Register the tracer in each thread.
758 for (const auto thread_state : interpreterThreads()) {
759 PyThreadState_Swap(thread_state);
760
761 thread_local_results_.emplace_back(thread_state, &value_cache_, this);
762 auto* ctx = thread_local_results_.back().ctx_;
763
764 // When we begin profiling there are already frames on the Python
765 // interpreter stack. To ensure a complete trace, we must push calls
766 // to all the prior frames onto our event stack. (We stop at depth=128)
767
768 std::vector<THPFrameObjectPtr> current_stack;
769 auto frame = PyEval_GetFrame();
770 Py_XINCREF(frame);
771
772 size_t depth = 0; // Make sure we can't infinite loop.
773 while (frame != nullptr) {
774 current_stack.emplace_back(frame);
775 if (++depth == 128) {
776 break;
777 }
778
779 // NB: `PyFrame_GetBack` returns a strong reference.
780 frame = PyFrame_GetBack(frame);
781 }
782
783 for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
784 recordPyCall(thread_local_results_.back(), it->get(), true);
785 auto frame_refcount = Py_REFCNT(it->get());
786
787 // We hold one reference in `current_stack`, and the interpreter holds
788 // another.
789 TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount);
790 }
791
792 // Note:
793 // This profile will not compose with other CPython profilers, and
794 // cannot be round tripped via `sys.settrace(sys.gettrace())`
795 PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
796 }
797 };
798
799 void PythonTracer::stop() {
800 gil_and_restore_thread gil;
801 if (active_) {
802 for (const auto thread_state : interpreterThreads()) {
803 if (thread_state->c_profilefunc == &PythonTracer::pyProfileFn) {
804 PyThreadState_Swap(thread_state);
805 PyEval_SetProfile(nullptr, nullptr);
806 }
807 }
808
809 auto lock_returned = active_lock_.compare_exchange_strong(active_, false);
810 active_ = false;
811 SOFT_ASSERT(lock_returned, "Failed to return python tracer lock.");
812 }
813 }
814
815 void PythonTracer::restart() {
816 gil_and_restore_thread gil;
817 active_ = active_lock_.compare_exchange_strong(active_, true);
818 if (!active_) {
819 TORCH_WARN(
820 "There is already an active Python tracer. "
821 "Refusing to register profile functions.");
822 return;
823 }
824 int cur_thread = 0;
825 for (const auto thread_state : interpreterThreads()) {
826 if (thread_state->c_profilefunc == nullptr) {
827 auto* ctx = thread_local_results_[cur_thread].ctx_;
828 PyThreadState_Swap(thread_state);
829 PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
830 }
831 }
832 }
833
834 // NOLINTNEXTLINE(bugprone-exception-escape)
835 PythonTracer::~PythonTracer() {
836 if (active_) {
837 TORCH_WARN("`PythonTracer::stop()` was not called.");
838 stop();
839 }
840 }
841
842 void PythonTracer::recordPyCall(
843 ThreadLocalResults& tls,
844 PyFrameObject* frame,
845 bool is_startup_frame) {
846 static constexpr auto E = EventType::PyCall;
847 const auto key = [&]() -> TraceKey {
848 auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
849 if (code.get() == module_call_code_) {
850 // By default, CPython stores locals in a "fast" format, with an array
851 // of names and an array of values. Consequently, frame->f_locals is
852 // NULL since the interpreter has no need to populate it.
853 //
854 // If these arrays were part of the public API then we could very
855 // quickly access `self`. Unfortunately they are not, and moreover are
856 // not stable across versions. As a result, we are forced to call
857 // `PyFrame_FastToLocals` which forces the interpreter to materialize
858 // the full dict of locals.
859 auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
860 auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
861 Py_INCREF(self.get());
862 auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
863 TORCH_INTERNAL_ASSERT(back != nullptr);
864 return tls.intern<CallType::PyModuleCall, E>(
865 frame, self.get(), back.get());
866 } else if (code.get() == optimizer_hook_) {
867 auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
868 auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
869 Py_INCREF(self.get());
870 auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
871 TORCH_INTERNAL_ASSERT(back != nullptr);
872 return tls.intern<CallType::PyOptimizerCall, E>(
873 frame, self.get(), back.get());
874 } else {
875 auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
876 auto f_back = (back.get() != nullptr) ? back.get() : frame;
877 return tls.intern<CallType::PyCall, E>(no_ephemeral_t(), frame, f_back);
878 }
879 }();
880 const auto time = c10::getApproximateTime();
881 is_startup_frame ? start_frames_.push_back({key, time})
882 : queue_->getSubqueue()->emplace_py_call(key, time);
883 }
884
885 void PythonTracer::recordCCall(
886 ThreadLocalResults& tls,
887 PyFrameObject* frame,
888 PyObject* arg) {
889 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyCFunction_Check(arg));
890 auto fn = reinterpret_cast<PyCFunctionObject*>(arg);
891
892 // NB: For C calls a new frame is not created, so we use `frame` rather than
893 // `frame->f_back`.
894 auto key = tls.intern<CallType::PyCCall, EventType::PyCCall>(
895 arg, (void*)(fn->m_ml), frame);
896 queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime());
897 }
898
899 // ============================================================================
900 // == Post processing =========================================================
901 // ============================================================================
902 struct Exit {
903 bool operator>(const Exit& other) const {
904 return t_ > other.t_;
905 }
906
907 c10::time_t t_;
908 size_t python_tid_;
909 };
910
911 class PostProcess {
912 public:
913 PostProcess(
914 std::function<c10::time_t(c10::approx_time_t)> time_converter,
915 std::deque<ThreadLocalResults>& tls,
916 const ValueCache& value_cache,
917 c10::time_t end_time_ns)
918 : end_time_{end_time_ns}, time_converter_{std::move(time_converter)} {
919 for (size_t python_tid : c10::irange(tls.size())) {
920 CallTypeHelper<TraceKeyCacheState>::map(
921 tls[python_tid].trace_keys_, *this, value_cache, python_tid);
922
923 addExits<EventType::PyCall>(tls[python_tid].exit_times_, python_tid);
924 addExits<EventType::PyCCall>(tls[python_tid].c_exit_times_, python_tid);
925 }
926 }
927
928 void set_start_frames(
929 const std::vector<PythonTracer::StartFrame>& start_frames,
930 std::vector<python_tracer::CompressedEvent>& enters) {
931 for (const auto& frame : start_frames) {
932 enters.push_back(
933 {frame.trace_key_,
934 NoTID, // Allows us to detect unhandled start frames
935 {},
936 time_converter_(frame.start_time)});
937 }
938 }
939
940 template <CallType C>
941 void operator()(
942 const TraceKeyCacheState<C>& trace_cache,
943 const ValueCache& value_cache,
944 size_t python_tid) {
945 for (const auto& it : trace_cache.state_) {
946 const auto inserted = get_state<Config<C>::event_type>().fields_.insert(
947 {it.second, value_cache.load(it.first, python_tid)});
948 TORCH_INTERNAL_ASSERT(inserted.second, "Duplicate key: ", it.second);
949 }
950 }
951
952 template <EventType E, size_t N>
953 void addExits(
954 AppendOnlyList<c10::approx_time_t, N>& exits,
955 size_t python_tid) {
956 for (const auto i : exits) {
957 get_state<E>().exits_.push({time_converter_(i), python_tid});
958 }
959 }
960
961 std::vector<std::shared_ptr<Result>> run(
962 std::vector<python_tracer::CompressedEvent>& enters) {
963 std::stable_sort(
964 enters.begin(), enters.end(), [](const auto a, const auto b) {
965 return a.enter_t_ < b.enter_t_;
966 });
967 std::vector<std::shared_ptr<Result>> out;
968 populate<EventType::PyCall>(enters, out);
969 populate<EventType::PyCCall>(enters, out);
970 return out;
971 }
972
973 private:
974 template <EventType E>
975 void populate(
976 std::vector<python_tracer::CompressedEvent>& enters,
977 std::vector<std::shared_ptr<Result>>& out) {
978 using stack_t = std::vector<std::shared_ptr<Result>>;
979 const auto initial_size = out.size();
980 auto pop = [](stack_t& stack, c10::time_t t) {
981 TORCH_INTERNAL_ASSERT(!stack.empty(), "Python replay stack is empty.");
982 std::get<ExtraFields<E>>(stack.back()->extra_fields_).end_time_ns_ = t;
983 stack.pop_back();
984 };
985
986 ska::flat_hash_map<size_t, stack_t> stacks;
987 auto& state = get_state<E>();
988 for (const auto& enter : enters) {
989 auto fields_it = state.fields_.find(enter.key_);
990 if (fields_it != state.fields_.end()) {
991 while (!state.exits_.empty() &&
992 state.exits_.top().t_ < enter.enter_t_) {
993 auto& exit = state.exits_.top();
994 pop(stacks[exit.python_tid_], exit.t_);
995 state.exits_.pop();
996 }
997 out.push_back(Result::create(
998 enter.enter_t_,
999 enter.system_tid_,
1000 enter.kineto_info_,
1001 fields_it->second));
1002
1003 stacks[fields_it->second.python_tid_].push_back(out.back());
1004 }
1005 }
1006
1007 // Handle events which were still running when profiling ended.
1008 for (auto& i : stacks) {
1009 while (!i.second.empty()) {
1010 pop(i.second, end_time_);
1011 }
1012 }
1013
1014 // Assign system TIDs to start events based on the system TID of the next
1015 // observed event with the same Python TID.
1016 ska::flat_hash_map<size_t, std::pair<size_t, kineto::DeviceAndResource>>
1017 tid_map;
1018 auto it = out.rbegin();
1019 for (C10_UNUSED auto _ : c10::irange(initial_size, out.size())) {
1020 const auto python_tid =
1021 std::get<ExtraFields<E>>((*it)->extra_fields_).python_tid_;
1022 if ((*it)->start_tid_ == NoTID && SOFT_ASSERT(E == EventType::PyCall)) {
1023 const auto& tid_info =
1024 tid_map.insert({python_tid, {NoTID, kineto::DeviceAndResource()}})
1025 .first->second;
1026 (*it)->start_tid_ = tid_info.first;
1027 (*it)->kineto_info_ = tid_info.second;
1028 }
1029 tid_map[python_tid] = {(*it)->start_tid_, (*it)->kineto_info_};
1030 ++it;
1031 }
1032 }
1033
1034 template <EventType E>
1035 struct State {
1036 ska::flat_hash_map<TraceKey, ExtraFields<E>> fields_;
1037 std::priority_queue<Exit, std::vector<Exit>, std::greater<>> exits_;
1038 };
1039
1040 template <EventType E>
1041 auto& get_state() {
1042 return std::get < E == EventType::PyCall ? 0 : 1 > (state_);
1043 }
1044
1045 c10::time_t end_time_;
1046 std::function<c10::time_t(c10::approx_time_t)> time_converter_;
1047 std::tuple<State<EventType::PyCall>, State<EventType::PyCCall>> state_;
1048 };
1049
1050 struct PythonIDVisitor {
1051 void operator()(ExtraFields<EventType::PyCall>& py_call) {
1052 py_call.id_ = ++current_python_id_;
1053 if (py_call.module_.has_value()) {
1054 auto& m = py_call.module_;
1055 auto& module_ids = module_ids_[m->cls_];
1056 m->id_ = module_ids.insert({m->self_, module_ids.size()}).first->second;
1057 }
1058 }
1059
1060 void operator()(ExtraFields<EventType::PyCCall>& py_call) {
1061 py_call.id_ = ++current_python_id_;
1062 }
1063
1064 template <typename T>
1065 void operator()(T&) {}
1066
1067 size_t current_python_id_{0};
1068 ska::flat_hash_map<PyModuleCls, ska::flat_hash_map<PyModuleSelf, size_t>>
1069 module_ids_;
1070 };
1071
1072 std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
1073 std::function<c10::time_t(c10::approx_time_t)> time_converter,
1074 std::vector<python_tracer::CompressedEvent>& enters,
1075 c10::time_t end_time_ns) {
1076 value_cache_.trimPrefixes();
1077 PostProcess post_process(
1078 std::move(time_converter),
1079 thread_local_results_,
1080 value_cache_,
1081 end_time_ns);
1082 post_process.set_start_frames(start_frames_, enters);
1083 auto out = post_process.run(enters);
1084
1085 std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
1086 return a->start_time_ns_ < b->start_time_ns_;
1087 });
1088
1089 PythonIDVisitor id_visitor;
1090 for (auto& i : out) {
1091 std::visit(id_visitor, i->extra_fields_);
1092 }
1093
1094 return out;
1095 }
1096
1097 // ============================================================================
1098 // == API =====================================================================
1099 // ============================================================================
1100 int PythonTracer::pyProfileFn(
1101 PyObject* obj,
1102 PyFrameObject* frame,
1103 int what,
1104 PyObject* arg) {
1105 auto& local_results =
1106 *reinterpret_cast<TraceContext*>(obj)->thread_local_results_;
1107 switch (what) {
1108 case PyTrace_CALL:
1109 local_results.active_tracer_->recordPyCall(local_results, frame, false);
1110 break;
1111
1112 case PyTrace_C_CALL:
1113 local_results.active_tracer_->recordCCall(local_results, frame, arg);
1114 break;
1115
1116 case PyTrace_EXCEPTION:
1117 case PyTrace_RETURN:
1118 local_results.exit_times_.emplace_back(c10::getApproximateTime());
1119 break;
1120
1121 case PyTrace_C_EXCEPTION:
1122 case PyTrace_C_RETURN:
1123 local_results.c_exit_times_.emplace_back(c10::getApproximateTime());
1124 break;
1125 }
1126 return 0;
1127 }
1128
1129 std::unique_ptr<python_tracer::PythonTracerBase> getTracer(
1130 torch::profiler::impl::RecordQueue* queue) {
1131 return std::make_unique<PythonTracer>(queue);
1132 }
1133 } // namespace
1134 } // namespace torch::profiler::impl
1135
1136 namespace torch::autograd::profiler::python_tracer {
1137
1138 void init() {
1139 pybind11::gil_scoped_acquire gil;
1140 TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0);
1141 torch::profiler::impl::python_tracer::registerTracer(
1142 &torch::profiler::impl::getTracer);
1143 }
1144 } // namespace torch::autograd::profiler::python_tracer
1145