xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/testing/hooks_for_testing.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/testing/hooks_for_testing.h>
2 
3 #include <torch/csrc/jit/api/module.h>
4 
5 namespace torch {
6 namespace jit {
7 
8 static ModuleHook emit_module_callback;
didFinishEmitModule(Module module)9 void didFinishEmitModule(Module module) {
10   if (emit_module_callback) {
11     emit_module_callback(module);
12   }
13 }
14 
15 static FunctionHook emit_function_callback;
didFinishEmitFunction(StrongFunctionPtr fn)16 void didFinishEmitFunction(StrongFunctionPtr fn) {
17   if (emit_function_callback) {
18     emit_function_callback(fn);
19   }
20 }
21 
setEmitHooks(ModuleHook for_mod,FunctionHook for_fn)22 void setEmitHooks(ModuleHook for_mod, FunctionHook for_fn) {
23   emit_module_callback = std::move(for_mod);
24   emit_function_callback = std::move(for_fn);
25 }
26 
getEmitHooks()27 std::pair<ModuleHook, FunctionHook> getEmitHooks() {
28   return std::make_pair(emit_module_callback, emit_function_callback);
29 }
30 
31 } // namespace jit
32 } // namespace torch
33