xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/record_function_ops.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ThreadLocalState.h>
2 #include <ATen/cpp_custom_type_hack.h>
3 #include <ATen/record_function.h>
4 #include <torch/csrc/autograd/record_function_ops.h>
5 
6 #include <torch/csrc/jit/runtime/operator.h>
7 #include <torch/library.h>
8 
namespace caffe2 {
// Required for cpp_custom_type_hack to work: declares at::RecordFunction as a
// known type via CAFFE_KNOWN_TYPE. The legacy handles in this file box an
// at::RecordFunction inside an at::Tensor (cpp_custom_type_hack::create) and
// unbox it again (cpp_custom_type_hack::cast).
// NOLINTNEXTLINE(bugprone-exception-escape)
CAFFE_KNOWN_TYPE(at::RecordFunction);
} // namespace caffe2
14 
15 namespace torch::autograd::profiler {
16 
17 // Creates a new profiling scope using RecordFunction and invokes its starting
18 // callbacks.
record_function_enter(const std::string & name,const std::optional<std::string> & args,at::RecordFunction & rec)19 static void record_function_enter(
20     const std::string& name,
21     const std::optional<std::string>& args,
22     at::RecordFunction& rec) {
23   if (rec.isActive()) {
24     if (rec.needsInputs() && args.has_value()) {
25       rec.before(
26           name, c10::ArrayRef<const c10::IValue>{c10::IValue{args.value()}});
27     } else {
28       rec.before(name);
29     }
30   }
31 }
32 
33 // Legacy signature using cpp_custom_type_hack
record_function_enter_legacy(const std::string & name,const std::optional<std::string> & args)34 static at::Tensor record_function_enter_legacy(
35     const std::string& name,
36     const std::optional<std::string>& args) {
37   auto rec = std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE);
38   record_function_enter(name, args, *rec);
39   return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions());
40 }
41 
42 // New signature using custom_class
record_function_enter_new(const std::string & name,const std::optional<std::string> & args)43 c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
44     const std::string& name,
45     const std::optional<std::string>& args) {
46   auto rec =
47       c10::make_intrusive<PythonRecordFunction>(at::RecordScope::USER_SCOPE);
48   record_function_enter(name, args, rec->record);
49   return rec;
50 }
51 
// Unboxes the at::RecordFunction stored in a legacy handle Tensor produced by
// record_function_enter_legacy. The returned reference presumably points into
// storage owned by `handle`, so it should only be used while `handle` is
// alive — confirm against cpp_custom_type_hack.
static at::RecordFunction& getRecordFunctionFromTensor(
    const at::Tensor& handle) {
  return at::cpp_custom_type_hack::cast<at::RecordFunction>(handle);
}
57 
58 // Ends the profiling scope created with record_function_enter.
record_function_exit(at::RecordFunction & rec)59 static void record_function_exit(at::RecordFunction& rec) {
60   rec.end();
61 }
62 
63 // Legacy signature using cpp_custom_type_hack
record_function_exit_legacy(const at::Tensor & handle)64 static void record_function_exit_legacy(const at::Tensor& handle) {
65   // We don't actually need to do anything with handle just need to persist the
66   // lifetime until now.
67   auto& rec = getRecordFunctionFromTensor(handle);
68   record_function_exit(rec);
69 }
70 
71 // New signature using custom_class
record_function_exit_new(const c10::intrusive_ptr<PythonRecordFunction> & record)72 static void record_function_exit_new(
73     const c10::intrusive_ptr<PythonRecordFunction>& record) {
74   record_function_exit(record->record);
75 }
76 
77 template <typename Func>
_call_end_callbacks_on_fut(Func get_record,const c10::intrusive_ptr<c10::ivalue::Future> & fut)78 c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut(
79     Func get_record,
80     const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
81   // Profiling callback that ends the associated record_function
82   // and returns the value of the passed in future.
83   auto futureProfilingFunc =
84       [get_record = std::move(get_record)](c10::ivalue::Future& fut) {
85         auto& rec = get_record();
86         rec.end();
87         // Note: this future is returned to the user to ensure that a call to
88         // wait() ensures that profiling callbacks have ran. To ensure that this
89         // is transparent, we must make this future propagate the value of the
90         // RPC future. Use value() here instead of constValue() to ensure we
91         // propagate errors.
92         return fut.value();
93       };
94   // Define a future that completes after the profiling callbacks are run.
95   auto profiledFut = fut->then(
96       at::wrapPropagateTLSState(std::move(futureProfilingFunc)),
97       fut->elementType());
98   return profiledFut;
99 }
100 
101 // Legacy signature using cpp_custom_type_hack
_call_end_callbacks_on_fut_legacy(const at::Tensor & handle,const c10::intrusive_ptr<c10::ivalue::Future> & fut)102 static c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_legacy(
103     const at::Tensor& handle,
104     const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
105   return _call_end_callbacks_on_fut(
106       [handle]() -> at::RecordFunction& {
107         TORCH_INTERNAL_ASSERT(
108             handle.defined(),
109             "Undefined RecordFunction handle. This can happen if the handle is "
110             "not correctly persisted and is destroyed before the future is "
111             "realized.");
112 
113         return getRecordFunctionFromTensor(handle);
114       },
115       fut);
116 }
117 
118 // New signature using custom_class
_call_end_callbacks_on_fut_new(const c10::intrusive_ptr<PythonRecordFunction> & record,const c10::intrusive_ptr<c10::ivalue::Future> & fut)119 c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new(
120     const c10::intrusive_ptr<PythonRecordFunction>& record,
121     const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
122   return _call_end_callbacks_on_fut(
123       [record]() -> at::RecordFunction& { return record->record; }, fut);
124 }
125 
126 // Internal only, do not use directly, use Python's record_function()
TORCH_LIBRARY_FRAGMENT(profiler,m)127 TORCH_LIBRARY_FRAGMENT(profiler, m) {
128   m.class_<PythonRecordFunction>("_RecordFunction");
129 
130   m.def(
131       "_record_function_enter(str name, str? args=None) -> Tensor",
132       &record_function_enter_legacy);
133   m.def(
134       "_record_function_enter_new(str name, str? args=None) -> "
135       "__torch__.torch.classes.profiler._RecordFunction",
136       &record_function_enter_new);
137   m.def("_record_function_exit", &record_function_exit_legacy);
138   m.def("_record_function_exit._RecordFunction", &record_function_exit_new);
139 
140   torch::jit::registerOperator(torch::jit::Operator(
141       "profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)",
142       [](jit::Stack& stack) {
143         // Pop inputs, which should be a future and a tensor
144         auto fut = jit::pop(stack).toFuture();
145         auto tensor = jit::pop(stack).toTensor();
146         auto profiledFut = _call_end_callbacks_on_fut_legacy(tensor, fut);
147         // return future that completes when profiling callbacks have run.
148         jit::push(stack, std::move(profiledFut));
149       },
150       c10::AliasAnalysisKind::FROM_SCHEMA));
151   torch::jit::registerOperator(torch::jit::Operator(
152       "profiler::_call_end_callbacks_on_jit_fut._RecordFunction("
153       "__torch__.torch.classes.profiler._RecordFunction x, Future(t) y) -> Future(t)",
154       [](c10::Stack& stack) {
155         // Pop inputs, which should be a future and a PythonRecordFunction
156         auto fut = torch::jit::pop(stack).toFuture();
157         auto tensor =
158             torch::jit::pop(stack).toCustomClass<PythonRecordFunction>();
159         auto profiledFut = _call_end_callbacks_on_fut_new(tensor, fut);
160         // return future that completes when profiling callbacks have run.
161         torch::jit::push(stack, std::move(profiledFut));
162       },
163       c10::AliasAnalysisKind::FROM_SCHEMA));
164 }
165 
166 } // namespace torch::autograd::profiler
167