xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/profiler_python.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/autograd/profiler_python.h>
2 
3 #include <atomic>
4 #include <cstdint>
5 #include <deque>
6 #include <limits>
7 #include <memory>
8 #include <queue>
9 #include <string>
10 #include <utility>
11 #include <vector>
12 
13 #include <Python.h>
14 #include <frameobject.h>
15 
16 #include <ATen/core/TensorBase.h>
17 #include <c10/macros/Macros.h>
18 #include <c10/util/ApproximateClock.h>
19 #include <c10/util/Exception.h>
20 #include <c10/util/Logging.h>
21 #include <c10/util/flat_hash_map.h>
22 #include <c10/util/irange.h>
23 #include <torch/csrc/autograd/python_variable.h>
24 #include <torch/csrc/profiler/collection.h>
25 #include <torch/csrc/profiler/containers.h>
26 #include <torch/csrc/profiler/orchestration/python_tracer.h>
27 #include <torch/csrc/profiler/util.h>
28 #include <torch/csrc/utils/pybind.h>
29 #include <torch/csrc/utils/python_compat.h>
30 #include <torch/csrc/utils/python_strings.h>
31 #include <optional>
32 
33 namespace py = pybind11;
34 
35 namespace torch::profiler::impl {
36 namespace {
37 enum CallType { PyCall = 0, PyModuleCall, PyCCall, PyOptimizerCall };
38 static constexpr size_t CallTypeSize = 4;
39 using no_ephemeral_t = std::tuple<>;
40 static constexpr uint64_t NoTID = std::numeric_limits<uint64_t>::max();
41 
42 // ============================================================================
43 // == Miscellaneous structs and utils =========================================
44 // ============================================================================
45 struct CodeLocation {
46   CodeLocation() = default;
47   explicit CodeLocation(PyFrameObject* frame)
48       : line_number_{PyFrame_GetLineNumber(frame)} {
49     auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
50     filename_ = THPUtils_unpackStringView(code->co_filename).data();
51     name_ = THPUtils_unpackStringView(code->co_name).data();
52   }
53 
54   bool operator==(const CodeLocation& other) const {
55     return filename_ == other.filename_ && name_ == other.name_ &&
56         line_number_ == other.line_number_;
57   }
58 
59   const char* filename_{nullptr};
60   const char* name_{nullptr};
61   int line_number_{0};
62 };
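// Note: equality and hashing for CodeLocation operate on the raw `const char*`
// pointers rather than on string contents. This appears to rely on the owning
// PyUnicode objects (co_filename / co_name and their cached UTF-8 buffers)
// staying alive and stable for the duration of the trace; identical strings
// from different code objects simply become separate cache entries.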
63 
64 template <CallType C>
65 PyCodeObject* getCode();
66 
67 template <>
68 PyCodeObject* getCode<CallType::PyModuleCall>() {
69   static auto module_call_code = []() {
70     pybind11::gil_scoped_acquire gil;
71     auto res = py::module::import("torch.nn")
72                    .attr("Module")
73                    .attr("__call__")
74                    .attr("__code__")
75                    .ptr();
76     TORCH_INTERNAL_ASSERT(PyCode_Check(res));
77     return (PyCodeObject*)res;
78   }();
79   return module_call_code;
80 };
81 
82 template <>
83 PyCodeObject* getCode<CallType::PyOptimizerCall>() {
84   static auto optimizer_step_code = []() {
85     pybind11::gil_scoped_acquire gil;
86     auto res = py::module::import("torch.optim")
87                    .attr("Optimizer")
88                    .attr("_optimizer_step_code")
89                    .attr("__code__")
90                    .ptr();
91     TORCH_INTERNAL_ASSERT(PyCode_Check(res));
92     return (PyCodeObject*)res;
93   }();
94   return optimizer_step_code;
95 };
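// The two code objects above are resolved once and cached for the lifetime of
// the process. PythonTracer later compares them against `PyFrame_GetCode` of
// each entered frame (see `recordPyCall`) to recognize `nn.Module.__call__`
// and `Optimizer._optimizer_step_code` frames and give them richer treatment.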
96 
97 } // namespace
98 } // namespace torch::profiler::impl
99 
100 template <>
101 struct std::hash<torch::profiler::impl::CodeLocation> {
102   size_t operator()(const torch::profiler::impl::CodeLocation& x) {
103     return c10::get_hash(x.filename_, x.name_, x.line_number_);
104   }
105 };
106 
107 namespace torch::profiler::impl {
108 namespace {
109 // ============================================================================
110 // == CallTypeHelper: Tools for generic programming on specializations. =======
111 // ============================================================================
112 template <template <CallType> class ClassT>
113 class CallTypeHelper final {
114  private:
115   static_assert(
116       CallType::PyCall == 0,
117       "CallTypeHelper uses integer math which depends on a zero start.");
118   static constexpr size_t End = CallTypeSize;
119 
120   template <size_t... I>
121   static constexpr std::tuple<ClassT<(CallType)I>...> make_tuple_impl(
122       std::index_sequence<I...>);
123 
124   template <size_t C, typename T, typename FunctorT, typename... Args>
125   static void map(T& t, FunctorT& f, Args&&... args) {
126     f(std::get<C>(t), args...);
127     if constexpr (C + 1 < End) {
128       map<C + 1>(t, f, std::forward<Args>(args)...);
129     }
130   }
131 
132  public:
133   using tuple_type = decltype(make_tuple_impl(std::make_index_sequence<End>{}));
134 
135   template <typename FunctorT, typename... Args>
136   static void map(tuple_type& t, FunctorT& f, Args&&... args) {
137     map<0>(t, f, std::forward<Args>(args)...);
138   }
139 };
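// Illustrative usage (taken from PostProcess below): walking every per-thread
// TraceKey sub-cache with a functor, one invocation per CallType:
//
//   CallTypeHelper<TraceKeyCacheState>::map(
//       tls[python_tid].trace_keys_, *this, value_cache, python_tid);
//
// `map` expands at compile time, calling `f(std::get<C>(t), args...)` for
// C = 0 .. CallTypeSize - 1.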
140 
141 // ============================================================================
142 // == Event type definitions. =================================================
143 // ============================================================================
144 // When we are tracing a Python program, the general procedure is to record
145 // every time we enter or exit a function and later replay these events during
146 // post processing. Thus, during the profiling phase we want to do the MINIMAL
147 // amount of work to capture all of the information that we need; otherwise we
148 // will distort the profile. (While we don't wish to be terribly inefficient
149 // during post processing, we are willing to do extra fixup work in post if it
150 // reduces overhead in the profiling phase.)
151 //
152 // When the tracer first enters a frame, it constructs a CallKey for that
153 // location. The contents of the key vary by context. For a python function
154 // the key is the (PyCodeObject*, int) pair that defines the bytecode of the
155 // function. For an `nn.Module` the key is a (non-owning) pointer to `self`.
156 // For a bound C function it is a (non-owning) pointer to the bound function.
157 // A CallKey should be small, inexpensive, and POD.
158 //
159 // We then collect a CallKey<CallType::PyCall> for the calling frame for better
160 // source tracking. This pair is a `Callsite`, and serves as a first level key
161 // during tracing. We lookup the Callsite in a thread local cache which maps
162 // Callsite to a unique integer `TraceKey`. On a cache hit, we simply store the
163 // TraceKey and return. On a cache miss, we use a global value cache to store
164 // whatever fields we need from the two CallKeys, generate a new TraceKey, and
165 // update the local cache.
166 //
167 // During post processing we:
168 //   1) Determine the type represented by a TraceKey by checking which
169 //      sub-cache of the thread local cache it appears in.
170 //   2) Look up the pair of CallKeys from the thread local cache.
171 //   3) Look up the expanded values of each CallKey from the global value cache.
172 //
173 // To add a new event type to the cache:
174 //   1) Add an entry to the `CallType` enum.
175 //   2) Add a specialization of Config which defines key_t, ephemeral_t and
176 //      cache_t.
177 //   3) Add a specialization of ValueCache::store and ValueCache::load.
178 //
179 // -------------------------
180 // -- Ephemeral arguments --
181 // -------------------------
182 // The value cache mechanism assumes that `key_t` is enough to specify the
183 // correct value. However it may not be possible to materialize a value using
184 // only an instance of `key_t`. As a result, the cache also accepts "ephemeral"
185 // inputs which can be used to populate the value cache. Ephemeral inputs come
186 // with two caveats:
187 //  1) They are NOT safe to save, and cannot be used after `ValueCache::store`.
188 //  2) They should be used to access data that is not expected to change from
189 //     call to call, such as the name of a function.
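//
// ----------------------------------
// -- Illustrative (hypothetical) ---
// ----------------------------------
// A sketch of what a new call type might look like; `PyFooCall`, `FooKey` and
// `FooPayload` are made-up names, not part of this file:
//
//   enum CallType { ..., PyFooCall };              // (1) new enum entry,
//                                                  //     bump CallTypeSize
//   template <>
//   struct Config<CallType::PyFooCall> {           // (2) Config specialization
//     using key_t = FooKey;                        //     small, trivially copyable
//     using ephemeral_t = PyObject*;               //     only valid during store()
//     using cache_t = ska::flat_hash_map<key_t, FooPayload>;
//     static constexpr EventType event_type = EventType::PyCall;
//   };
//
//   template <>                                    // (3) store/load specializations
//   void ValueCache::store<CallType::PyFooCall>(const FooKey&, PyObject*);
//   template <>
//   ExtraFields<EventType::PyCall>::args_t
//   ValueCache::load<CallType::PyFooCall>(const FooKey&) const;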
190 
191 template <CallType>
192 struct Config;
193 
194 template <>
195 struct Config<CallType::PyCall> {
196   using key_t = CodeLocation;
197   using ephemeral_t = no_ephemeral_t;
198   using cache_t = ska::flat_hash_map<key_t, PyFrameState>;
199   static constexpr EventType event_type = EventType::PyCall;
200 };
201 
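// Shared config for the `nn.Module.__call__` and `Optimizer._optimizer_step_code`
// specializations below. The cache stores one shared `location_` (all such
// calls funnel through a single code object), per-`self` class and parameter
// metadata, and a per-class name.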
202 template <typename Key, typename Cls, typename ParameterInfo>
203 struct ExtendedPyCallConfig {
204   using key_t = Key;
205   using cls_t = Cls;
206   using ephemeral_t = PyFrameObject*;
207 
208   struct ClsAndParameters {
209     cls_t cls_;
210     std::vector<ParameterInfo> parameters_;
211   };
212 
213   struct Cache {
214     // `nn.Module.forward` or `optim.Optimizer._optimizer_step_code`
215     std::optional<CodeLocation> location_;
216     ska::flat_hash_map<key_t, ClsAndParameters> cls_and_parameters_;
217     ska::flat_hash_map<cls_t, at::StringView> cls_names_;
218   };
219   using cache_t = Cache;
220 
221   static constexpr EventType event_type = EventType::PyCall;
222 };
223 
224 template <>
225 struct Config<CallType::PyModuleCall> : ExtendedPyCallConfig<
226                                             PyModuleSelf,
227                                             PyModuleCls,
228                                             NNModuleInfo::ParameterInfo> {};
229 
230 template <>
231 struct Config<CallType::PyOptimizerCall> : ExtendedPyCallConfig<
232                                                PyOptimizerSelf,
233                                                PyOptimizerCls,
234                                                OptimizerInfo::ParameterInfo> {};
235 
236 template <>
237 struct Config<CallType::PyCCall> {
238   using key_t = PyMethod;
239   using ephemeral_t = PyObject*;
240   using cache_t = ska::flat_hash_map<key_t, at::StringView>;
241   static constexpr EventType event_type = EventType::PyCCall;
242 };
243 
244 // ============================================================================
245 // == Callsite & ValueCache: Storage during profiling =========================
246 // ============================================================================
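// A Callsite pairs the callee-specific key with the caller's CodeLocation. It
// is only used as a short-lived lookup key in the thread-local TraceKey cache;
// the durable, expanded data lives in ValueCache.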
247 template <CallType C>
248 class Callsite {
249  public:
250   static constexpr CallType call_type = C;
251   using key_t = typename Config<C>::key_t;
252 
253   static_assert(
254       std::is_trivially_copyable_v<key_t>,
255       "Key should be trivial, as it is passed by value.");
256 
257   template <typename U>
258   Callsite(U value, PyFrameObject* f_back) : value_(value), caller_(f_back) {}
259 
260   bool operator==(const Callsite<C>& other) const {
261     return value_ == other.value_ && caller_ == other.caller_;
262   }
263 
264   key_t value_;
265   Config<CallType::PyCall>::key_t caller_;
266 };
267 
268 // ============================================================================
269 // == Type specific store and load implementations. ===========================
270 // ============================================================================
271 using PyCallKey = Config<CallType::PyCall>::key_t;
272 using PyModuleCallKey = Config<CallType::PyModuleCall>::key_t;
273 using PyCCallKey = Config<CallType::PyCCall>::key_t;
274 using PyOptimizerCallKey = Config<CallType::PyOptimizerCall>::key_t;
275 
276 class ValueCache {
277  public:
278   ValueCache() = default;
279   ValueCache(const ValueCache&) = delete;
280 
281   template <CallType C>
282   void store(const typename Config<C>::key_t&, typename Config<C>::ephemeral_t);
283 
284   template <CallType C>
285   auto load(const Callsite<C>& callsite, size_t python_tid) const {
286     auto caller = load<CallType::PyCall>(callsite.caller_);
287     TORCH_INTERNAL_ASSERT(!caller.module_info_.has_value());
288     return ExtraFields<Config<C>::event_type>{
289         /*end_time_ns=*/std::numeric_limits<c10::time_t>::min(),
290         python_tid,
291         caller.frame_state_,
292         load<C>(callsite.value_)};
293   }
294 
295   std::optional<TensorMetadata> recordIfTensor(py::handle p);
296   std::vector<std::pair<std::string, TensorMetadata>> unpackTensorMap(
297       const py::dict& tensor_map);
298   void trimPrefixes();
299 
300  private:
301   template <CallType C>
302   typename ExtraFields<Config<C>::event_type>::args_t load(
303       const typename Config<C>::key_t&) const;
304 
305   template <CallType C>
306   using State = typename Config<C>::cache_t;
307 
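  // One sub-cache per CallType; because the enum starts at zero,
  // std::get<CallType::X>(state_) selects the matching Config<X>::cache_t.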
308   CallTypeHelper<State>::tuple_type state_;
309 };
310 
311 template <CallType C>
312 typename Config<C>::cls_t set_class(
313     ValueCache* value_cache,
314     typename Config<C>::cache_t& cache,
315     const typename Config<C>::key_t& key,
316     const typename Config<C>::ephemeral_t& frame) {
317   if (C10_UNLIKELY(!cache.location_.has_value())) {
318     auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
319     TORCH_INTERNAL_ASSERT(code.get() == getCode<C>());
320     cache.location_ = PyCallKey(frame);
321     value_cache->store<CallType::PyCall>(*cache.location_, no_ephemeral_t());
322   }
323 
324   auto cls_handle = py::handle((PyObject*)key).attr("__class__");
325   auto cls = typename Config<C>::cls_t(cls_handle.ptr());
326   if (cache.cls_names_.find(cls) == cache.cls_names_.end()) {
327     cache.cls_names_[cls] =
328         at::StringView(py::str(cls_handle.attr("__name__")));
329   }
330   return cls;
331 }
332 
333 TensorMetadata toTensorMetadata(PyObject* self) {
334   TORCH_INTERNAL_ASSERT(THPVariable_CheckExact(self));
335   const auto& t = THPVariable_Unpack(self);
336   RawTensorMetadata m{t};
337   return TensorMetadata{
338       m,
339       t.sizes().vec(),
340       m.layout_ == at::kStrided ? t.strides().vec() : std::vector<int64_t>()};
341 }
342 
343 std::optional<TensorMetadata> ValueCache::recordIfTensor(py::handle p) {
344   return THPVariable_CheckExact(p.ptr())
345       ? std::optional<TensorMetadata>{toTensorMetadata(p.ptr())}
346       : std::nullopt;
347 }
348 
349 std::vector<std::pair<std::string, TensorMetadata>> ValueCache::unpackTensorMap(
350     const py::dict& tensor_map) {
351   std::vector<std::pair<std::string, TensorMetadata>> out;
352   for (auto& it : tensor_map) {
353     auto* value = it.second.ptr();
354     if (py::isinstance<py::str>(it.first) && THPVariable_CheckExact(value)) {
355       out.emplace_back(
356           py::cast<std::string>(it.first), toTensorMetadata(value));
357     }
358   }
359   return out;
360 }
361 
362 template <>
363 void ValueCache::store<CallType::PyCall>(const PyCallKey& key, no_ephemeral_t) {
364   auto& locations = std::get<CallType::PyCall>(state_);
365   if (C10_UNLIKELY(locations.find(key) == locations.end())) {
366     locations[key] = {
367         key.line_number_,
368         at::StringView(key.filename_),
369         at::StringView(key.name_)};
370   }
371 }
372 
373 template <>
374 ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyCall>(
375     const PyCallKey& key) const {
376   return {std::get<CallType::PyCall>(state_).at(key), std::nullopt};
377 }
378 
379 template <>
380 void ValueCache::store<CallType::PyModuleCall>(
381     const PyModuleCallKey& key,
382     Config<CallType::PyModuleCall>::ephemeral_t frame) {
383   auto& cache = std::get<CallType::PyModuleCall>(state_);
384   if (C10_UNLIKELY(
385           cache.cls_and_parameters_.find(key) ==
386           cache.cls_and_parameters_.end())) {
387     auto cls = set_class<CallType::PyModuleCall>(this, cache, key, frame);
388 
389     py::dict params = py::handle((PyObject*)key).attr("_parameters");
390     std::vector<NNModuleInfo::ParameterInfo> params_;
391     for (auto& it : params) {
392       auto* p = it.second.ptr();
393       if (py::isinstance<py::str>(it.first) && THPVariable_CheckExact(p)) {
394         params_.push_back(
395             {it.first.cast<std::string>(),
396              toTensorMetadata(p),
397              recordIfTensor(py::getattr(it.second, "grad", py::none()))});
398       }
399     }
400     cache.cls_and_parameters_[key] = {cls, std::move(params_)};
401   }
402 }
403 
404 template <>
405 ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyModuleCall>(
406     const PyModuleCallKey& key) const {
407   auto& cache = std::get<CallType::PyModuleCall>(state_);
408   TORCH_INTERNAL_ASSERT(cache.location_.has_value());
409   const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
410   const auto& cls = cls_and_parameters.cls_;
411   NNModuleInfo info{
412       key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
413   return {
414       /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
415       /*module_info_=*/std::move(info),
416       /*optimizer_info_=*/std::nullopt};
417 }
418 
419 template <>
420 void ValueCache::store<CallType::PyOptimizerCall>(
421     const PyOptimizerCallKey& key,
422     Config<CallType::PyOptimizerCall>::ephemeral_t frame) {
423   auto& cache = std::get<CallType::PyOptimizerCall>(state_);
424   if (C10_UNLIKELY(
425           cache.cls_and_parameters_.find(key) ==
426           cache.cls_and_parameters_.end())) {
427     auto cls = set_class<CallType::PyOptimizerCall>(this, cache, key, frame);
428     const py::handle self{(PyObject*)key};
429     std::vector<OptimizerInfo::ParameterInfo> params;
430 
431     for (const auto& i : (py::list)self.attr("param_groups")) {
432       for (auto& param : py::cast<py::dict>(i).attr("get")("params")) {
433         if (THPVariable_CheckExact(param.ptr())) {
434           // While `self.state` is permitted to store data in an arbitrary way,
435           // all generic optimizers (SGD, Adam, etc) use param as the key since
436           // the state in question is tied to particular parameters. We can
437           // relax this assumption if the need arises.
438           params.push_back(
439               {toTensorMetadata(param.ptr()),
440                recordIfTensor(py::getattr(param, "grad", py::none())),
441                unpackTensorMap(py::cast<py::dict>(self.attr("state"))
442                                    .attr("get")(param, py::dict()))});
443         }
444       }
445     }
446 
447     cache.cls_and_parameters_[key] = {cls, std::move(params)};
448   }
449 }
450 
451 template <>
452 ExtraFields<EventType::PyCall>::args_t ValueCache::load<
453     CallType::PyOptimizerCall>(const PyOptimizerCallKey& key) const {
454   auto& cache = std::get<CallType::PyOptimizerCall>(state_);
455   const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
456   auto cls = cls_and_parameters.cls_;
457   OptimizerInfo info{
458       key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
459   return {
460       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
461       /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
462       /*module_info_=*/std::nullopt,
463       /*optimizer_info_=*/std::move(info)};
464 }
465 
466 template <>
467 void ValueCache::store<CallType::PyCCall>(
468     const PyCCallKey& key,
469     Config<CallType::PyCCall>::ephemeral_t arg) {
470   auto& names = std::get<CallType::PyCCall>(state_);
471   if (C10_UNLIKELY(names.find(key) == names.end())) {
472     names[key] = at::StringView(py::repr(arg));
473   }
474 }
475 
476 template <>
477 ExtraFields<EventType::PyCCall>::args_t ValueCache::load<CallType::PyCCall>(
478     const PyCCallKey& key) const {
479   return std::get<CallType::PyCCall>(state_).at(key);
480 }
481 
482 // TODO: Use re2.
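// Strips known installation prefixes (as reported by
// torch.profiler.python_tracer._prefix_regex) from recorded filenames so that
// reported paths are relative to the package roots.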
483 void ValueCache::trimPrefixes() {
484   static const auto prefixes = []() {
485     pybind11::gil_scoped_acquire gil;
486     return py::module::import("torch.profiler.python_tracer")
487         .attr("_prefix_regex")()
488         .cast<std::vector<std::string>>();
489   }();
490 
491   for (auto& it : std::get<CallType::PyCall>(state_)) {
492     std::string filename = it.second.filename_.str();
493     for (const auto& p : prefixes) {
494       if (filename.compare(0, p.size(), p) == 0) {
495         filename.erase(0, p.size());
496         it.second.filename_ = at::StringView(filename);
497         break;
498       }
499     }
500   }
501 }
502 
503 // ============================================================================
504 // == TraceKey cache ==========================================================
505 // ============================================================================
506 using python_tracer::TraceKey;
507 
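// Monotonically increasing key generator shared by all threads. The counter is
// pre-incremented, so the first TraceKey handed out is 1 and 0 is never produced.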
508 TraceKey nextKey() {
509   static std::atomic<uint64_t> key{0};
510   return TraceKey{++key};
511 }
512 
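// The thread-local half of the two level cache described above: `intern` maps
// a Callsite to a compact TraceKey, populating the shared ValueCache on the
// first (cache-miss) observation of that Callsite.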
513 template <CallType C>
514 struct TraceKeyCacheState {
515   struct Hash {
516     size_t operator()(const Callsite<C>& key) {
517       return c10::get_hash(key.value_, key.caller_);
518     }
519   };
520 
521   TraceKey intern(
522       Callsite<C> callsite,
523       typename Config<C>::ephemeral_t ephemeral,
524       ValueCache& value_cache) {
525     auto it = state_.find(callsite);
526     if (C10_UNLIKELY(it == state_.end())) {
527       value_cache.store<C>(callsite.value_, ephemeral);
528       value_cache.store<CallType::PyCall>(callsite.caller_, no_ephemeral_t());
529       it = state_.insert({callsite, nextKey()}).first;
530     }
531     return it->second;
532   }
533 
534   auto lookup(Callsite<C>& callsite, ValueCache& value_cache) const {
535     return std::make_pair(
536         value_cache.load<C>(callsite.value_),
537         value_cache.load<CallType::PyCall>(callsite.caller_));
538   }
539 
540   ska::flat_hash_map<Callsite<C>, TraceKey, Hash> state_;
541 };
542 
543 // ============================================================================
544 // == Core CPython data types =================================================
545 // ============================================================================
546 // PyObject that allows different threads to record events without colliding.
547 // It is passed as the second argument when enabling tracing via
548 // `PyEval_SetProfile`.
549 struct ThreadLocalResults;
550 struct TraceContext {
551   PyObject_HEAD;
552   ThreadLocalResults* thread_local_results_;
553 };
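// Each thread gets its own TraceContext (allocated via `tp_alloc` in
// ThreadLocalResults), which is handed to `PyEval_SetProfile` as the callback
// argument so `pyProfileFn` can recover that thread's state without a lookup.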
554 
555 // CPython boilerplate to define `TraceContext` as a proper python object.
556 static PyTypeObject TraceContextType = {
557     PyVarObject_HEAD_INIT(nullptr, 0) "TraceContext", /* tp_name */
558     sizeof(TraceContext), /* tp_basicsize */
559     0, /* tp_itemsize */
560     nullptr, /* tp_dealloc */
561     0,
562     /* tp_vectorcall_offset */
563     nullptr, /* tp_getattr */
564     nullptr, /* tp_setattr */
565     nullptr, /* tp_reserved */
566     nullptr, /* tp_repr */
567     nullptr, /* tp_as_number */
568     nullptr, /* tp_as_sequence */
569     nullptr, /* tp_as_mapping */
570     nullptr, /* tp_hash  */
571     nullptr, /* tp_call */
572     nullptr, /* tp_str */
573     nullptr, /* tp_getattro */
574     nullptr, /* tp_setattro */
575     nullptr, /* tp_as_buffer */
576     Py_TPFLAGS_DEFAULT, /* tp_flags */
577     "Python tracer TLS", /* tp_doc */
578     nullptr, /* tp_traverse */
579     nullptr, /* tp_clear */
580     nullptr, /* tp_richcompare */
581     0, /* tp_weaklistoffset */
582     nullptr, /* tp_iter */
583     nullptr, /* tp_iternext */
584     nullptr, /* tp_methods */
585     nullptr, /* tp_members */
586     nullptr, /* tp_getset */
587     nullptr, /* tp_base */
588     nullptr, /* tp_dict */
589     nullptr, /* tp_descr_get */
590     nullptr, /* tp_descr_set */
591     0, /* tp_dictoffset */
592     nullptr, /* tp_init */
593     nullptr, /* tp_alloc */
594     PyType_GenericNew, /* tp_new */
595     nullptr /* tp_free */
596 };
597 
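// RAII helper: acquires the GIL and remembers the current thread state so that
// the registration loops below, which call `PyThreadState_Swap` for every
// interpreter thread, restore the original thread state on scope exit.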
598 class gil_and_restore_thread {
599  public:
600   gil_and_restore_thread()
601       : gil_(), initial_thread_state_{PyThreadState_Get()} {}
602   ~gil_and_restore_thread() {
603     PyThreadState_Swap(initial_thread_state_);
604 
605     // `gil_scoped_acquire` is a bit fragile in on-demand mode:
606     // https://github.com/pytorch/pytorch/pull/91684#issuecomment-1413154458
607     if (!Py_IsInitialized()) {
608       gil_.disarm();
609     }
610   }
611 
612   PyThreadState* initial_thread_state() const {
613     return initial_thread_state_;
614   }
615 
616  private:
617   pybind11::gil_scoped_acquire gil_;
618   PyThreadState* initial_thread_state_;
619 };
620 
621 // ============================================================================
622 // == Thread local cache ======================================================
623 // ============================================================================
624 class PythonTracer;
625 struct ThreadLocalResults {
626   ThreadLocalResults(
627       PyThreadState* thread_state,
628       ValueCache* value_cache,
629       PythonTracer* active_tracer)
630       : thread_state_{thread_state},
631         ctx_{(TraceContext*)TraceContextType.tp_alloc(&TraceContextType, 0)},
632         value_cache_{value_cache},
633         active_tracer_{active_tracer} {
634     ctx_->thread_local_results_ = this;
635   }
636 
637   ThreadLocalResults() = delete;
638   ThreadLocalResults(const ThreadLocalResults&) = delete;
639   ThreadLocalResults(ThreadLocalResults&&) = delete;
640   ThreadLocalResults& operator=(const ThreadLocalResults&) = delete;
641   ThreadLocalResults& operator=(const ThreadLocalResults&&) = delete;
642 
643   ~ThreadLocalResults() {
644     Py_DECREF((PyObject*)ctx_);
645   }
646 
647   template <CallType C, EventType E, typename Ephemeral, typename... Args>
648   TraceKey intern(Ephemeral ephemeral, Args... args) {
649     static_assert(
650         Config<C>::event_type == E,
651         "ThreadLocalResults.intern called from the wrong typed context.");
652     auto callsite = Callsite<C>(std::forward<Args>(args)...);
653     return std::get<C>(trace_keys_).intern(callsite, ephemeral, *value_cache_);
654   }
655 
656   static constexpr size_t BLOCK_SIZE = 1024;
657 
658   PyThreadState* thread_state_;
659   TraceContext* ctx_;
660   ValueCache* value_cache_;
661   PythonTracer* active_tracer_;
662   CallTypeHelper<TraceKeyCacheState>::tuple_type trace_keys_;
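  // Exit timestamps only (no keys); they are paired back up with the matching
  // enter events during post processing (see PostProcess::addExits).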
663   AppendOnlyList<c10::approx_time_t, BLOCK_SIZE> exit_times_;
664   AppendOnlyList<c10::approx_time_t, BLOCK_SIZE> c_exit_times_;
665 };
666 
667 // ============================================================================
668 // == Tracing implementation ==================================================
669 // ============================================================================
670 class PythonTracer final : public python_tracer::PythonTracerBase {
671  public:
672   PythonTracer(torch::profiler::impl::RecordQueue* queue);
673   // NOLINTNEXTLINE(bugprone-exception-escape)
674   ~PythonTracer() override;
675 
676   static int pyProfileFn(
677       PyObject* obj,
678       PyFrameObject* frame,
679       int what,
680       PyObject* arg);
681 
682   void stop() override;
683   void restart() override;
684   std::vector<std::shared_ptr<Result>> getEvents(
685       std::function<c10::time_t(c10::approx_time_t)> time_converter,
686       std::vector<python_tracer::CompressedEvent>& enters,
687       c10::time_t end_time_ns) override;
688 
689   struct StartFrame {
690     TraceKey trace_key_;
691     c10::approx_time_t start_time{};
692   };
693 
694  private:
695   void recordPyCall(
696       ThreadLocalResults& tls,
697       PyFrameObject* frame,
698       bool is_startup_frame);
699 
700   void recordCCall(
701       ThreadLocalResults& tls,
702       PyFrameObject* frame,
703       PyObject* arg);
704 
705   const std::vector<PyThreadState*> interpreterThreads() const;
706 
707   std::atomic<bool> active_lock_{false};
708   bool active_{false};
709 
710   torch::profiler::impl::RecordQueue* queue_;
711   PyInterpreterState* interpreter_{nullptr};
712   PyCodeObject* module_call_code_;
713   PyCodeObject* optimizer_hook_;
714 
715   std::vector<StartFrame> start_frames_;
716   std::deque<ThreadLocalResults> thread_local_results_;
717   ValueCache value_cache_;
718 };
719 
720 const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
721   pybind11::gil_scoped_acquire gil;
722   std::vector<PyThreadState*> out;
723   if (SOFT_ASSERT(interpreter_)) {
724     auto* thread_state = PyInterpreterState_ThreadHead(interpreter_);
725     while (thread_state != nullptr) {
726       out.push_back(thread_state);
727       thread_state = PyThreadState_Next(thread_state);
728     }
729   }
730   return out;
731 }
732 
733 PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
734     : queue_(queue),
735 
736       module_call_code_(getCode<CallType::PyModuleCall>()),
737       optimizer_hook_(getCode<CallType::PyOptimizerCall>()) {
738   TORCH_CHECK(queue_ != nullptr);
739 
740   bool expected{false};
741   active_ = active_lock_.compare_exchange_strong(expected, true);
742   if (!active_) {
743     TORCH_WARN(
744         "There is already an active Python tracer. "
745         "Refusing to register profile functions.");
746     return;
747   }
748 
749   gil_and_restore_thread gil;
750   interpreter_ = PyInterpreterState_Get();
751 
752   if (!gil.initial_thread_state()) {
753     TORCH_WARN("PyThreadState_Get returned NULL");
754     return;
755   }
756 
757   // Register the tracer in each thread.
758   for (const auto thread_state : interpreterThreads()) {
759     PyThreadState_Swap(thread_state);
760 
761     thread_local_results_.emplace_back(thread_state, &value_cache_, this);
762     auto* ctx = thread_local_results_.back().ctx_;
763 
764     // When we begin profiling there are already frames on the Python
765     // interpreter stack. To ensure a complete trace, we must push calls
766     // to all the prior frames onto our event stack. (We stop at depth=128)
767 
768     std::vector<THPFrameObjectPtr> current_stack;
769     auto frame = PyEval_GetFrame();
770     Py_XINCREF(frame);
771 
772     size_t depth = 0; // Make sure we can't infinite loop.
773     while (frame != nullptr) {
774       current_stack.emplace_back(frame);
775       if (++depth == 128) {
776         break;
777       }
778 
779       // NB: `PyFrame_GetBack` returns a strong reference.
780       frame = PyFrame_GetBack(frame);
781     }
782 
783     for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
784       recordPyCall(thread_local_results_.back(), it->get(), true);
785       auto frame_refcount = Py_REFCNT(it->get());
786 
787       // We hold one reference in `current_stack`, and the interpreter holds
788       // another.
789       TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount);
790     }
791 
792     // Note:
793     //   This profile will not compose with other CPython profilers, and
794     //   cannot be round tripped via `sys.settrace(sys.gettrace())`
795     PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
796   }
797 };
798 
799 void PythonTracer::stop() {
800   gil_and_restore_thread gil;
801   if (active_) {
802     for (const auto thread_state : interpreterThreads()) {
803       if (thread_state->c_profilefunc == &PythonTracer::pyProfileFn) {
804         PyThreadState_Swap(thread_state);
805         PyEval_SetProfile(nullptr, nullptr);
806       }
807     }
808 
809     auto lock_returned = active_lock_.compare_exchange_strong(active_, false);
810     active_ = false;
811     SOFT_ASSERT(lock_returned, "Failed to return python tracer lock.");
812   }
813 }
814 
815 void PythonTracer::restart() {
816   gil_and_restore_thread gil;
817   active_ = active_lock_.compare_exchange_strong(active_, true);
818   if (!active_) {
819     TORCH_WARN(
820         "There is already an active Python tracer. "
821         "Refusing to register profile functions.");
822     return;
823   }
824   int cur_thread = 0;
825   for (const auto thread_state : interpreterThreads()) {
826     if (thread_state->c_profilefunc == nullptr) {
827       auto* ctx = thread_local_results_[cur_thread].ctx_;
828       PyThreadState_Swap(thread_state);
829       PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
830     }
831   }
832 }
833 
834 // NOLINTNEXTLINE(bugprone-exception-escape)
835 PythonTracer::~PythonTracer() {
836   if (active_) {
837     TORCH_WARN("`PythonTracer::stop()` was not called.");
838     stop();
839   }
840 }
841 
842 void PythonTracer::recordPyCall(
843     ThreadLocalResults& tls,
844     PyFrameObject* frame,
845     bool is_startup_frame) {
846   static constexpr auto E = EventType::PyCall;
847   const auto key = [&]() -> TraceKey {
848     auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
849     if (code.get() == module_call_code_) {
850       // By default, CPython stores locals in a "fast" format, with an array
851       // of names and an array of values. Consequently, frame->f_locals is
852       // NULL since the interpreter has no need to populate it.
853       //
854       // If these arrays were part of the public API then we could very
855       // quickly access `self`. Unfortunately they are not, and moreover are
856       // not stable across versions. As a result, we are forced to call
857       // `PyFrame_GetLocals`, which forces the interpreter to materialize
858       // the full dict of locals (the fast-to-locals conversion).
859       auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
860       auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
861       Py_INCREF(self.get());
862       auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
863       TORCH_INTERNAL_ASSERT(back != nullptr);
864       return tls.intern<CallType::PyModuleCall, E>(
865           frame, self.get(), back.get());
866     } else if (code.get() == optimizer_hook_) {
867       auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
868       auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
869       Py_INCREF(self.get());
870       auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
871       TORCH_INTERNAL_ASSERT(back != nullptr);
872       return tls.intern<CallType::PyOptimizerCall, E>(
873           frame, self.get(), back.get());
874     } else {
875       auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
876       auto f_back = (back.get() != nullptr) ? back.get() : frame;
877       return tls.intern<CallType::PyCall, E>(no_ephemeral_t(), frame, f_back);
878     }
879   }();
880   const auto time = c10::getApproximateTime();
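  // Frames already on the interpreter stack when profiling began ("startup"
  // frames) are stashed in start_frames_ instead of the SubQueue; PostProcess
  // later turns them into synthetic enter events tagged with NoTID.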
881   is_startup_frame ? start_frames_.push_back({key, time})
882                    : queue_->getSubqueue()->emplace_py_call(key, time);
883 }
884 
885 void PythonTracer::recordCCall(
886     ThreadLocalResults& tls,
887     PyFrameObject* frame,
888     PyObject* arg) {
889   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyCFunction_Check(arg));
890   auto fn = reinterpret_cast<PyCFunctionObject*>(arg);
891 
892   // NB: For C calls a new frame is not created, so we use `frame` rather than
893   //     `frame->f_back`.
894   auto key = tls.intern<CallType::PyCCall, EventType::PyCCall>(
895       arg, (void*)(fn->m_ml), frame);
896   queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime());
897 }
898 
899 // ============================================================================
900 // == Post processing =========================================================
901 // ============================================================================
902 struct Exit {
903   bool operator>(const Exit& other) const {
904     return t_ > other.t_;
905   }
906 
907   c10::time_t t_;
908   size_t python_tid_;
909 };
910 
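// PostProcess replays the recorded data: enter events carry full TraceKeys
// (from the SubQueue and start_frames_), while exits are bare timestamps kept
// in per-event-type min-heaps. `populate` walks the time-sorted enters,
// popping any earlier exits to close out frames on per-Python-thread stacks.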
911 class PostProcess {
912  public:
913   PostProcess(
914       std::function<c10::time_t(c10::approx_time_t)> time_converter,
915       std::deque<ThreadLocalResults>& tls,
916       const ValueCache& value_cache,
917       c10::time_t end_time_ns)
918       : end_time_{end_time_ns}, time_converter_{std::move(time_converter)} {
919     for (size_t python_tid : c10::irange(tls.size())) {
920       CallTypeHelper<TraceKeyCacheState>::map(
921           tls[python_tid].trace_keys_, *this, value_cache, python_tid);
922 
923       addExits<EventType::PyCall>(tls[python_tid].exit_times_, python_tid);
924       addExits<EventType::PyCCall>(tls[python_tid].c_exit_times_, python_tid);
925     }
926   }
927 
928   void set_start_frames(
929       const std::vector<PythonTracer::StartFrame>& start_frames,
930       std::vector<python_tracer::CompressedEvent>& enters) {
931     for (const auto& frame : start_frames) {
932       enters.push_back(
933           {frame.trace_key_,
934            NoTID, // Allows us to detect unhandled start frames
935            {},
936            time_converter_(frame.start_time)});
937     }
938   }
939 
940   template <CallType C>
941   void operator()(
942       const TraceKeyCacheState<C>& trace_cache,
943       const ValueCache& value_cache,
944       size_t python_tid) {
945     for (const auto& it : trace_cache.state_) {
946       const auto inserted = get_state<Config<C>::event_type>().fields_.insert(
947           {it.second, value_cache.load(it.first, python_tid)});
948       TORCH_INTERNAL_ASSERT(inserted.second, "Duplicate key: ", it.second);
949     }
950   }
951 
952   template <EventType E, size_t N>
953   void addExits(
954       AppendOnlyList<c10::approx_time_t, N>& exits,
955       size_t python_tid) {
956     for (const auto i : exits) {
957       get_state<E>().exits_.push({time_converter_(i), python_tid});
958     }
959   }
960 
961   std::vector<std::shared_ptr<Result>> run(
962       std::vector<python_tracer::CompressedEvent>& enters) {
963     std::stable_sort(
964         enters.begin(), enters.end(), [](const auto a, const auto b) {
965           return a.enter_t_ < b.enter_t_;
966         });
967     std::vector<std::shared_ptr<Result>> out;
968     populate<EventType::PyCall>(enters, out);
969     populate<EventType::PyCCall>(enters, out);
970     return out;
971   }
972 
973  private:
974   template <EventType E>
975   void populate(
976       std::vector<python_tracer::CompressedEvent>& enters,
977       std::vector<std::shared_ptr<Result>>& out) {
978     using stack_t = std::vector<std::shared_ptr<Result>>;
979     const auto initial_size = out.size();
980     auto pop = [](stack_t& stack, c10::time_t t) {
981       TORCH_INTERNAL_ASSERT(!stack.empty(), "Python replay stack is empty.");
982       std::get<ExtraFields<E>>(stack.back()->extra_fields_).end_time_ns_ = t;
983       stack.pop_back();
984     };
985 
986     ska::flat_hash_map<size_t, stack_t> stacks;
987     auto& state = get_state<E>();
988     for (const auto& enter : enters) {
989       auto fields_it = state.fields_.find(enter.key_);
990       if (fields_it != state.fields_.end()) {
991         while (!state.exits_.empty() &&
992                state.exits_.top().t_ < enter.enter_t_) {
993           auto& exit = state.exits_.top();
994           pop(stacks[exit.python_tid_], exit.t_);
995           state.exits_.pop();
996         }
997         out.push_back(Result::create(
998             enter.enter_t_,
999             enter.system_tid_,
1000             enter.kineto_info_,
1001             fields_it->second));
1002 
1003         stacks[fields_it->second.python_tid_].push_back(out.back());
1004       }
1005     }
1006 
1007     // Handle events which were still running when profiling ended.
1008     for (auto& i : stacks) {
1009       while (!i.second.empty()) {
1010         pop(i.second, end_time_);
1011       }
1012     }
1013 
1014     // Assign system TIDs to start events based on the system TID of the next
1015     // observed event with the same Python TID.
1016     ska::flat_hash_map<size_t, std::pair<size_t, kineto::DeviceAndResource>>
1017         tid_map;
1018     auto it = out.rbegin();
1019     for (C10_UNUSED auto _ : c10::irange(initial_size, out.size())) {
1020       const auto python_tid =
1021           std::get<ExtraFields<E>>((*it)->extra_fields_).python_tid_;
1022       if ((*it)->start_tid_ == NoTID && SOFT_ASSERT(E == EventType::PyCall)) {
1023         const auto& tid_info =
1024             tid_map.insert({python_tid, {NoTID, kineto::DeviceAndResource()}})
1025                 .first->second;
1026         (*it)->start_tid_ = tid_info.first;
1027         (*it)->kineto_info_ = tid_info.second;
1028       }
1029       tid_map[python_tid] = {(*it)->start_tid_, (*it)->kineto_info_};
1030       ++it;
1031     }
1032   }
1033 
1034   template <EventType E>
1035   struct State {
1036     ska::flat_hash_map<TraceKey, ExtraFields<E>> fields_;
1037     std::priority_queue<Exit, std::vector<Exit>, std::greater<>> exits_;
1038   };
1039 
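  // Note: the return statement below parses as
  // `std::get<(E == EventType::PyCall ? 0 : 1)>(state_)`; the comparison picks
  // the tuple slot for the event type at compile time.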
1040   template <EventType E>
1041   auto& get_state() {
1042     return std::get < E == EventType::PyCall ? 0 : 1 > (state_);
1043   }
1044 
1045   c10::time_t end_time_;
1046   std::function<c10::time_t(c10::approx_time_t)> time_converter_;
1047   std::tuple<State<EventType::PyCall>, State<EventType::PyCCall>> state_;
1048 };
1049 
1050 struct PythonIDVisitor {
1051   void operator()(ExtraFields<EventType::PyCall>& py_call) {
1052     py_call.id_ = ++current_python_id_;
1053     if (py_call.module_.has_value()) {
1054       auto& m = py_call.module_;
1055       auto& module_ids = module_ids_[m->cls_];
1056       m->id_ = module_ids.insert({m->self_, module_ids.size()}).first->second;
1057     }
1058   }
1059 
1060   void operator()(ExtraFields<EventType::PyCCall>& py_call) {
1061     py_call.id_ = ++current_python_id_;
1062   }
1063 
1064   template <typename T>
1065   void operator()(T&) {}
1066 
1067   size_t current_python_id_{0};
1068   ska::flat_hash_map<PyModuleCls, ska::flat_hash_map<PyModuleSelf, size_t>>
1069       module_ids_;
1070 };
1071 
1072 std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
1073     std::function<c10::time_t(c10::approx_time_t)> time_converter,
1074     std::vector<python_tracer::CompressedEvent>& enters,
1075     c10::time_t end_time_ns) {
1076   value_cache_.trimPrefixes();
1077   PostProcess post_process(
1078       std::move(time_converter),
1079       thread_local_results_,
1080       value_cache_,
1081       end_time_ns);
1082   post_process.set_start_frames(start_frames_, enters);
1083   auto out = post_process.run(enters);
1084 
1085   std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
1086     return a->start_time_ns_ < b->start_time_ns_;
1087   });
1088 
1089   PythonIDVisitor id_visitor;
1090   for (auto& i : out) {
1091     std::visit(id_visitor, i->extra_fields_);
1092   }
1093 
1094   return out;
1095 }
1096 
1097 // ============================================================================
1098 // == API =====================================================================
1099 // ============================================================================
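// Only CALL / C_CALL events resolve a full TraceKey; RETURN / EXCEPTION
// callbacks record just an approximate timestamp, which post processing later
// matches to the corresponding enter event.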
1100 int PythonTracer::pyProfileFn(
1101     PyObject* obj,
1102     PyFrameObject* frame,
1103     int what,
1104     PyObject* arg) {
1105   auto& local_results =
1106       *reinterpret_cast<TraceContext*>(obj)->thread_local_results_;
1107   switch (what) {
1108     case PyTrace_CALL:
1109       local_results.active_tracer_->recordPyCall(local_results, frame, false);
1110       break;
1111 
1112     case PyTrace_C_CALL:
1113       local_results.active_tracer_->recordCCall(local_results, frame, arg);
1114       break;
1115 
1116     case PyTrace_EXCEPTION:
1117     case PyTrace_RETURN:
1118       local_results.exit_times_.emplace_back(c10::getApproximateTime());
1119       break;
1120 
1121     case PyTrace_C_EXCEPTION:
1122     case PyTrace_C_RETURN:
1123       local_results.c_exit_times_.emplace_back(c10::getApproximateTime());
1124       break;
1125   }
1126   return 0;
1127 }
1128 
1129 std::unique_ptr<python_tracer::PythonTracerBase> getTracer(
1130     torch::profiler::impl::RecordQueue* queue) {
1131   return std::make_unique<PythonTracer>(queue);
1132 }
1133 } // namespace
1134 } // namespace torch::profiler::impl
1135 
1136 namespace torch::autograd::profiler::python_tracer {
1137 
1138 void init() {
1139   pybind11::gil_scoped_acquire gil;
1140   TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0);
1141   torch::profiler::impl::python_tracer::registerTracer(
1142       &torch::profiler::impl::getTracer);
1143 }
1144 } // namespace torch::autograd::profiler::python_tracer
1145