xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/record_function_ops.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <ATen/record_function.h>
#include <torch/custom_class.h>
#include <optional>
#include <string>
5 
6 namespace torch::autograd::profiler {
7 
8 struct PythonRecordFunction : public torch::CustomClassHolder {
9   at::RecordFunction record;
10 
11   explicit PythonRecordFunction(
12       at::RecordScope scope = at::RecordScope::FUNCTION)
recordPythonRecordFunction13       : record(scope) {}
14 };
15 
16 // Creates a new profiling scope using RecordFunction and invokes its starting
17 // callbacks.
18 TORCH_API c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
19     const std::string& name,
20     const std::optional<std::string>& args = std::nullopt);
21 
22 // Schedules RecordFunction's end callbacks to be run on completion of a future.
23 TORCH_API c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new(
24     const c10::intrusive_ptr<PythonRecordFunction>& record,
25     const c10::intrusive_ptr<c10::ivalue::Future>& fut);
26 
27 } // namespace torch::autograd::profiler
28