// xref: /aosp_15_r20/external/pytorch/torch/csrc/dynamo/cpp_shim.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/record_function.h>
#include <torch/csrc/dynamo/cpp_shim.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

4 struct _PytorchRecordFunctionState {
5   at::RecordFunction guard;
6 
_PytorchRecordFunctionState_PytorchRecordFunctionState7   _PytorchRecordFunctionState() : guard(at::RecordScope::FUNCTION) {}
8 };
9 
_pytorch_record_function_enter(const char * name)10 _PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name) {
11   _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState();
12   state->guard.before(name);
13   return state;
14 }
15 
16 static inline _PytorchRecordFunctionState*
_pytorch_record_function_enter_with_kwinputs(const char * name,const std::unordered_map<std::string,c10::IValue> * kwargs)17 _pytorch_record_function_enter_with_kwinputs(
18     const char* name,
19     const std::unordered_map<std::string, c10::IValue>* kwargs) {
20   _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState();
21   std::vector<c10::IValue> args;
22   state->guard.before(name, &args, kwargs);
23   return state;
24 }
25 
_pytorch_record_function_enter_with_context(const char * name,const char * context)26 _PytorchRecordFunctionState* _pytorch_record_function_enter_with_context(
27     const char* name,
28     const char* context) {
29   auto map = std::unordered_map<std::string, c10::IValue>();
30   map.insert({"context", c10::IValue(context)});
31   return _pytorch_record_function_enter_with_kwinputs(name, &map);
32 }
33 
_pytorch_record_function_exit(_PytorchRecordFunctionState * state)34 void _pytorch_record_function_exit(_PytorchRecordFunctionState* state) {
35   if (state == nullptr) {
36     return;
37   }
38   delete state;
39 }
40