// Source: external/pytorch/torch/csrc/jit/backends/backend_detail.h (AOSP tree aosp_15_r20, revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <torch/csrc/jit/api/module.h>

#include <ATen/core/jit_type.h>

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

9 namespace torch {
10 namespace jit {
11 
12 using DebugHandleType = int64_t;
13 
14 using NodeToDebugHandle = std::unordered_map<Node*, DebugHandleType>;
15 
16 using BackendDebugHandleGenerator =
17     std::function<NodeToDebugHandle(const std::shared_ptr<Graph>&)>;
18 
19 namespace detail {
20 
21 using BackendPreprocessFunction = std::function<c10::IValue(
22     const Module&,
23     const c10::Dict<IValue, IValue>&,
24     const BackendDebugHandleGenerator& generate_debug_handles)>;
25 
26 TORCH_API void registerBackendPreprocessFunction(
27     const std::string& name,
28     const BackendPreprocessFunction& preprocess);
29 
30 bool hasBackendPreprocessFunction(const std::string& name);
31 
32 BackendPreprocessFunction getBackendPreprocessFunction(const std::string& name);
33 
34 TORCH_API Module codegen_backend_module(
35     const std::string& backend_name,
36     const Module& orig_module,
37     const c10::Dict<IValue, IValue>& method_compile_spec,
38     const c10::DictTypePtr& any_dict_ty);
39 } // namespace detail
40 } // namespace jit
41 } // namespace torch
42