xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/script_profile.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <chrono>
4 #include <map>
5 #include <string>
6 
7 #include <ATen/core/ivalue.h>
8 #include <c10/macros/Macros.h>
9 #include <torch/csrc/jit/frontend/source_ref.h>
10 #include <torch/csrc/jit/ir/ir.h>
11 
12 namespace torch::jit {
13 namespace profiling {
14 
15 struct Datapoint {
16   using Timepoint = std::chrono::time_point<std::chrono::steady_clock>;
17   SourceRange sourceRange;
18   Timepoint start;
19   Timepoint end;
20 
DatapointDatapoint21   explicit Datapoint(SourceRange sr)
22       : sourceRange(std::move(sr)), start(std::chrono::steady_clock::now()) {}
23 };
24 
25 class TORCH_API InstructionSpan {
26  public:
27   explicit InstructionSpan(Node&);
28   ~InstructionSpan();
29   InstructionSpan(InstructionSpan&&) = delete;
30   InstructionSpan& operator=(InstructionSpan&&) = delete;
31 
32  private:
33   std::unique_ptr<Datapoint> datapoint_;
34 };
35 
36 bool TORCH_API isProfilingOngoing();
37 
38 } // namespace profiling
39 
40 struct TORCH_API InstructionStats : public CustomClassHolder {
41   int64_t count{0};
42   std::chrono::nanoseconds duration{0};
43 };
44 
45 class TORCH_API SourceStats : public CustomClassHolder {
46  public:
47   using LineMap = c10::Dict<int64_t, c10::intrusive_ptr<InstructionStats>>;
48 
SourceStats(SourceRef source,LineMap lineMap)49   SourceStats(SourceRef source, LineMap lineMap)
50       : source_(std::move(source)), lineMap_(std::move(lineMap)) {}
51 
getSourceRef()52   const SourceRef& getSourceRef() const {
53     return source_;
54   }
55 
getLineMap()56   const LineMap& getLineMap() const {
57     return lineMap_;
58   }
59 
60  private:
61   SourceRef source_;
62   LineMap lineMap_;
63 };
64 
65 /**
66  * ScriptProfile is an underlying C++ implementation for TorchScript profiling.
67  * The profiling section is specified by calling enable() and disable():
68  *
69  * ...
70  * scriptProfile.enable();
71  * ...
72  * (scripts)
73  * ...
74  * scriptProfile.disable();
75  * ...
76  *
77  * NOTE: you cannot attach the profiler while the script is running.
78  *
79  * To retrieve collected runtime data, users may call dumpStats() and do
80  * arbitrary filtering on the data they want. Note that dumpStats() should
81  * not be called inside a profiling section.
82  * In general, stats are aggregated per source function body, and then by line
83  * number.
84  */
85 class TORCH_API ScriptProfile : public CustomClassHolder {
86   // Aggregates datapoints by function source id, then by line number.
87   using LineMap = std::map<int64_t, InstructionStats>;
88   using SourceMap = std::map<SourceRef, LineMap, std::less<>>;
89 
90  public:
91   void enable();
92   void disable();
93   const SourceMap& dumpStats();
94   void addDatapoint(std::shared_ptr<profiling::Datapoint>);
95   ~ScriptProfile() override;
96 
97  private:
98   bool enabled_{false};
99   std::vector<std::shared_ptr<profiling::Datapoint>> datapoints_;
100   SourceMap sourceMap_;
101 };
102 
103 } // namespace torch::jit
104