#include <ATen/record_function.h>
#include <torch/csrc/dynamo/cpp_shim.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
3
4 struct _PytorchRecordFunctionState {
5 at::RecordFunction guard;
6
_PytorchRecordFunctionState_PytorchRecordFunctionState7 _PytorchRecordFunctionState() : guard(at::RecordScope::FUNCTION) {}
8 };
9
_pytorch_record_function_enter(const char * name)10 _PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name) {
11 _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState();
12 state->guard.before(name);
13 return state;
14 }
15
16 static inline _PytorchRecordFunctionState*
_pytorch_record_function_enter_with_kwinputs(const char * name,const std::unordered_map<std::string,c10::IValue> * kwargs)17 _pytorch_record_function_enter_with_kwinputs(
18 const char* name,
19 const std::unordered_map<std::string, c10::IValue>* kwargs) {
20 _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState();
21 std::vector<c10::IValue> args;
22 state->guard.before(name, &args, kwargs);
23 return state;
24 }
25
_pytorch_record_function_enter_with_context(const char * name,const char * context)26 _PytorchRecordFunctionState* _pytorch_record_function_enter_with_context(
27 const char* name,
28 const char* context) {
29 auto map = std::unordered_map<std::string, c10::IValue>();
30 map.insert({"context", c10::IValue(context)});
31 return _pytorch_record_function_enter_with_kwinputs(name, &map);
32 }
33
_pytorch_record_function_exit(_PytorchRecordFunctionState * state)34 void _pytorch_record_function_exit(_PytorchRecordFunctionState* state) {
35 if (state == nullptr) {
36 return;
37 }
38 delete state;
39 }
40