1 #include <ATen/ThreadLocalState.h>
2 #include <ATen/cpp_custom_type_hack.h>
3 #include <ATen/record_function.h>
4 #include <torch/csrc/autograd/record_function_ops.h>
5
6 #include <torch/csrc/jit/runtime/operator.h>
7 #include <torch/library.h>
8
namespace caffe2 {
// Declares at::RecordFunction as a caffe2 "known type" (TypeMeta entry).
// Required for cpp_custom_type_hack to work: the legacy profiler ops in this
// file store a RecordFunction inside a Tensor via cpp_custom_type_hack and
// cast it back out later.
// NOLINTNEXTLINE(bugprone-exception-escape)
CAFFE_KNOWN_TYPE(at::RecordFunction);
} // namespace caffe2
14
15 namespace torch::autograd::profiler {
16
17 // Creates a new profiling scope using RecordFunction and invokes its starting
18 // callbacks.
record_function_enter(const std::string & name,const std::optional<std::string> & args,at::RecordFunction & rec)19 static void record_function_enter(
20 const std::string& name,
21 const std::optional<std::string>& args,
22 at::RecordFunction& rec) {
23 if (rec.isActive()) {
24 if (rec.needsInputs() && args.has_value()) {
25 rec.before(
26 name, c10::ArrayRef<const c10::IValue>{c10::IValue{args.value()}});
27 } else {
28 rec.before(name);
29 }
30 }
31 }
32
33 // Legacy signature using cpp_custom_type_hack
record_function_enter_legacy(const std::string & name,const std::optional<std::string> & args)34 static at::Tensor record_function_enter_legacy(
35 const std::string& name,
36 const std::optional<std::string>& args) {
37 auto rec = std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE);
38 record_function_enter(name, args, *rec);
39 return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions());
40 }
41
42 // New signature using custom_class
record_function_enter_new(const std::string & name,const std::optional<std::string> & args)43 c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
44 const std::string& name,
45 const std::optional<std::string>& args) {
46 auto rec =
47 c10::make_intrusive<PythonRecordFunction>(at::RecordScope::USER_SCOPE);
48 record_function_enter(name, args, rec->record);
49 return rec;
50 }
51
// Returns a reference to the RecordFunction stored inside `handle`, a Tensor
// produced by cpp_custom_type_hack::create (see record_function_enter_legacy).
static at::RecordFunction& getRecordFunctionFromTensor(
    const at::Tensor& handle) {
  return at::cpp_custom_type_hack::cast<at::RecordFunction>(handle);
}
57
58 // Ends the profiling scope created with record_function_enter.
record_function_exit(at::RecordFunction & rec)59 static void record_function_exit(at::RecordFunction& rec) {
60 rec.end();
61 }
62
63 // Legacy signature using cpp_custom_type_hack
record_function_exit_legacy(const at::Tensor & handle)64 static void record_function_exit_legacy(const at::Tensor& handle) {
65 // We don't actually need to do anything with handle just need to persist the
66 // lifetime until now.
67 auto& rec = getRecordFunctionFromTensor(handle);
68 record_function_exit(rec);
69 }
70
71 // New signature using custom_class
record_function_exit_new(const c10::intrusive_ptr<PythonRecordFunction> & record)72 static void record_function_exit_new(
73 const c10::intrusive_ptr<PythonRecordFunction>& record) {
74 record_function_exit(record->record);
75 }
76
77 template <typename Func>
_call_end_callbacks_on_fut(Func get_record,const c10::intrusive_ptr<c10::ivalue::Future> & fut)78 c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut(
79 Func get_record,
80 const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
81 // Profiling callback that ends the associated record_function
82 // and returns the value of the passed in future.
83 auto futureProfilingFunc =
84 [get_record = std::move(get_record)](c10::ivalue::Future& fut) {
85 auto& rec = get_record();
86 rec.end();
87 // Note: this future is returned to the user to ensure that a call to
88 // wait() ensures that profiling callbacks have ran. To ensure that this
89 // is transparent, we must make this future propagate the value of the
90 // RPC future. Use value() here instead of constValue() to ensure we
91 // propagate errors.
92 return fut.value();
93 };
94 // Define a future that completes after the profiling callbacks are run.
95 auto profiledFut = fut->then(
96 at::wrapPropagateTLSState(std::move(futureProfilingFunc)),
97 fut->elementType());
98 return profiledFut;
99 }
100
101 // Legacy signature using cpp_custom_type_hack
_call_end_callbacks_on_fut_legacy(const at::Tensor & handle,const c10::intrusive_ptr<c10::ivalue::Future> & fut)102 static c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_legacy(
103 const at::Tensor& handle,
104 const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
105 return _call_end_callbacks_on_fut(
106 [handle]() -> at::RecordFunction& {
107 TORCH_INTERNAL_ASSERT(
108 handle.defined(),
109 "Undefined RecordFunction handle. This can happen if the handle is "
110 "not correctly persisted and is destroyed before the future is "
111 "realized.");
112
113 return getRecordFunctionFromTensor(handle);
114 },
115 fut);
116 }
117
118 // New signature using custom_class
_call_end_callbacks_on_fut_new(const c10::intrusive_ptr<PythonRecordFunction> & record,const c10::intrusive_ptr<c10::ivalue::Future> & fut)119 c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new(
120 const c10::intrusive_ptr<PythonRecordFunction>& record,
121 const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
122 return _call_end_callbacks_on_fut(
123 [record]() -> at::RecordFunction& { return record->record; }, fut);
124 }
125
126 // Internal only, do not use directly, use Python's record_function()
TORCH_LIBRARY_FRAGMENT(profiler,m)127 TORCH_LIBRARY_FRAGMENT(profiler, m) {
128 m.class_<PythonRecordFunction>("_RecordFunction");
129
130 m.def(
131 "_record_function_enter(str name, str? args=None) -> Tensor",
132 &record_function_enter_legacy);
133 m.def(
134 "_record_function_enter_new(str name, str? args=None) -> "
135 "__torch__.torch.classes.profiler._RecordFunction",
136 &record_function_enter_new);
137 m.def("_record_function_exit", &record_function_exit_legacy);
138 m.def("_record_function_exit._RecordFunction", &record_function_exit_new);
139
140 torch::jit::registerOperator(torch::jit::Operator(
141 "profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)",
142 [](jit::Stack& stack) {
143 // Pop inputs, which should be a future and a tensor
144 auto fut = jit::pop(stack).toFuture();
145 auto tensor = jit::pop(stack).toTensor();
146 auto profiledFut = _call_end_callbacks_on_fut_legacy(tensor, fut);
147 // return future that completes when profiling callbacks have run.
148 jit::push(stack, std::move(profiledFut));
149 },
150 c10::AliasAnalysisKind::FROM_SCHEMA));
151 torch::jit::registerOperator(torch::jit::Operator(
152 "profiler::_call_end_callbacks_on_jit_fut._RecordFunction("
153 "__torch__.torch.classes.profiler._RecordFunction x, Future(t) y) -> Future(t)",
154 [](c10::Stack& stack) {
155 // Pop inputs, which should be a future and a PythonRecordFunction
156 auto fut = torch::jit::pop(stack).toFuture();
157 auto tensor =
158 torch::jit::pop(stack).toCustomClass<PythonRecordFunction>();
159 auto profiledFut = _call_end_callbacks_on_fut_new(tensor, fut);
160 // return future that completes when profiling callbacks have run.
161 torch::jit::push(stack, std::move(profiledFut));
162 },
163 c10::AliasAnalysisKind::FROM_SCHEMA));
164 }
165
166 } // namespace torch::autograd::profiler
167