1 #include <torch/csrc/jit/testing/hooks_for_testing.h> 2 3 #include <torch/csrc/jit/api/module.h> 4 5 namespace torch { 6 namespace jit { 7 8 static ModuleHook emit_module_callback; didFinishEmitModule(Module module)9void didFinishEmitModule(Module module) { 10 if (emit_module_callback) { 11 emit_module_callback(module); 12 } 13 } 14 15 static FunctionHook emit_function_callback; didFinishEmitFunction(StrongFunctionPtr fn)16void didFinishEmitFunction(StrongFunctionPtr fn) { 17 if (emit_function_callback) { 18 emit_function_callback(fn); 19 } 20 } 21 setEmitHooks(ModuleHook for_mod,FunctionHook for_fn)22void setEmitHooks(ModuleHook for_mod, FunctionHook for_fn) { 23 emit_module_callback = std::move(for_mod); 24 emit_function_callback = std::move(for_fn); 25 } 26 getEmitHooks()27std::pair<ModuleHook, FunctionHook> getEmitHooks() { 28 return std::make_pair(emit_module_callback, emit_function_callback); 29 } 30 31 } // namespace jit 32 } // namespace torch 33