1 #pragma once 2 3 #include <chrono> 4 #include <map> 5 #include <string> 6 7 #include <ATen/core/ivalue.h> 8 #include <c10/macros/Macros.h> 9 #include <torch/csrc/jit/frontend/source_ref.h> 10 #include <torch/csrc/jit/ir/ir.h> 11 12 namespace torch::jit { 13 namespace profiling { 14 15 struct Datapoint { 16 using Timepoint = std::chrono::time_point<std::chrono::steady_clock>; 17 SourceRange sourceRange; 18 Timepoint start; 19 Timepoint end; 20 DatapointDatapoint21 explicit Datapoint(SourceRange sr) 22 : sourceRange(std::move(sr)), start(std::chrono::steady_clock::now()) {} 23 }; 24 25 class TORCH_API InstructionSpan { 26 public: 27 explicit InstructionSpan(Node&); 28 ~InstructionSpan(); 29 InstructionSpan(InstructionSpan&&) = delete; 30 InstructionSpan& operator=(InstructionSpan&&) = delete; 31 32 private: 33 std::unique_ptr<Datapoint> datapoint_; 34 }; 35 36 bool TORCH_API isProfilingOngoing(); 37 38 } // namespace profiling 39 40 struct TORCH_API InstructionStats : public CustomClassHolder { 41 int64_t count{0}; 42 std::chrono::nanoseconds duration{0}; 43 }; 44 45 class TORCH_API SourceStats : public CustomClassHolder { 46 public: 47 using LineMap = c10::Dict<int64_t, c10::intrusive_ptr<InstructionStats>>; 48 SourceStats(SourceRef source,LineMap lineMap)49 SourceStats(SourceRef source, LineMap lineMap) 50 : source_(std::move(source)), lineMap_(std::move(lineMap)) {} 51 getSourceRef()52 const SourceRef& getSourceRef() const { 53 return source_; 54 } 55 getLineMap()56 const LineMap& getLineMap() const { 57 return lineMap_; 58 } 59 60 private: 61 SourceRef source_; 62 LineMap lineMap_; 63 }; 64 65 /** 66 * ScriptProfile is an underlying C++ implementation for TorchScript profiling. 67 * The profiling section is specified by calling enable() and disable(): 68 * 69 * ... 70 * scriptProfile.enable(); 71 * ... 72 * (scripts) 73 * ... 74 * scriptProfile.disable(); 75 * ... 76 * 77 * NOTE: you cannot attach the profiler while the script is running. 78 * 79 * To retrieve collected runtime data, users may call dumpStats() and do 80 * arbitrary filtering on the data they want. Note that dumpStats() should 81 * not be called inside a profiling section. 82 * In general, stats are aggregated per source function body, and then by line 83 * number. 84 */ 85 class TORCH_API ScriptProfile : public CustomClassHolder { 86 // Aggregates datapoints by function source id, then by line number. 87 using LineMap = std::map<int64_t, InstructionStats>; 88 using SourceMap = std::map<SourceRef, LineMap, std::less<>>; 89 90 public: 91 void enable(); 92 void disable(); 93 const SourceMap& dumpStats(); 94 void addDatapoint(std::shared_ptr<profiling::Datapoint>); 95 ~ScriptProfile() override; 96 97 private: 98 bool enabled_{false}; 99 std::vector<std::shared_ptr<profiling::Datapoint>> datapoints_; 100 SourceMap sourceMap_; 101 }; 102 103 } // namespace torch::jit 104