#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
#include <torch/csrc/jit/runtime/decomposition_registry_util.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/serialization/import_source.h>

#include <c10/util/Exception.h>
#include <torch/csrc/autograd/jit_decomp_interface.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <memory>
#include <unordered_map>

namespace torch::jit {
namespace {
std::mutex lock;

// CompilationUnit that holds all these Functions and keeps them alive.
auto compilation_unit = std::make_shared<CompilationUnit>();
std::unordered_map<const FunctionSchema*, std::shared_ptr<Graph>>
    schema_to_decomposition;

// Holds user-registered Functions and keeps them alive.
std::unordered_map<const FunctionSchema*, std::unique_ptr<Function>>
    user_registered_funcs;

std::unordered_map<const FunctionSchema*, Function*> schema_to_function;

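// Populates schema_to_function and schema_to_decomposition from the
// decomposition mapping defined in the given CompilationUnit.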
void loadModule(const CompilationUnit& module) {
  const auto& mappings = GetDecompositionMapping().getAllKeysAndValues();
  for (const auto& pair : mappings) {
    const FunctionSchema* schema = &pair.first->schema();
    const std::string& decomposition_function_name = pair.second;

    Function& decomposition_function =
        module.get_function(decomposition_function_name);
    std::shared_ptr<Graph> graph =
        toGraphFunction(decomposition_function).graph();

    schema_to_function[schema] = &decomposition_function;
    schema_to_decomposition[schema] = graph;
  }
}

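// Parses the serialized decomposition functions into the shared
// CompilationUnit the first time it is called and fills the lookup tables
// above; subsequent calls return immediately.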
void loadDecompositionFunctions() {
  std::lock_guard<std::mutex> guard(lock);
  if (!schema_to_decomposition.empty()) {
    return;
  }

  auto src = std::make_shared<Source>(GetSerializedDecompositions());
  std::stringstream ss;
  std::vector<at::IValue> constantTable;
  auto resolver = std::make_shared<SourceImporterImpl>(
      compilation_unit,
      &constantTable,
      [&](const std::string& name) -> std::shared_ptr<Source> { return src; },
      1);
  compilation_unit->define(
      std::nullopt, GetSerializedDecompositions(), resolver, nullptr);
  loadModule(*compilation_unit);
}

} // anonymous namespace

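// Replaces a single node with an inlined copy of its registered
// decomposition graph, if one exists for the node's schema. The node's
// outputs are rewired to the decomposition's outputs and the node is
// destroyed.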
static void DecomposeOp(Node* n) {
  auto schema = n->maybeSchema();
  if (!schema) {
    return;
  }
  auto decomposition = GetDecomposition(n->schema());
  if (!decomposition) {
    return;
  }
  WithInsertPoint guard(n);
  auto outputs =
      insertGraph(*n->owningGraph(), *decomposition->get(), n->inputs());
  TORCH_INTERNAL_ASSERT(outputs.size() == n->outputs().size());
  for (size_t i : c10::irange(outputs.size())) {
    n->outputs().at(i)->replaceAllUsesWith(outputs[i]);
  }
  n->destroy();
}

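// Recursively decomposes every node in a block, including nodes in
// sub-blocks.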
static void RunDecompositions(Block* block) {
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    Node* n = *it;
    it++; // advance the iterator first because the current node may be destroyed
    for (Block* b : n->blocks()) {
      RunDecompositions(b);
    }
    DecomposeOp(n);
  }
}

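// Entry point for a whole graph: decomposes all nodes and then runs a couple
// of rounds of peephole optimization and constant propagation to clean up
// the inlined decomposition bodies.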
void RunDecompositions(std::shared_ptr<Graph> g) {
  RunDecompositions(g->block());
  for (C10_UNUSED const auto _ : c10::irange(2)) {
    PeepholeOptimize(g, /*disable_shape_peephole*/ true);
    ConstantPropagation(g);
  }
}

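// Returns the decomposition graph registered for `schema`, or std::nullopt
// if none exists.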
std::optional<std::shared_ptr<Graph>> GetDecomposition(
    const FunctionSchema& schema) {
  loadDecompositionFunctions();
  GRAPH_DEBUG("Trying to find schema: ", schema);
  auto cache_it = schema_to_decomposition.find(&schema);
  if (cache_it != schema_to_decomposition.end()) {
    return cache_it->second;
  }
  GRAPH_DEBUG("Could not find schema: ", schema);

  return std::nullopt;
}

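// Returns the GraphFunction for `schema`'s decomposition, or std::nullopt if
// none exists.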
std::optional<GraphFunction*> GetDecompositionFunction(
    const FunctionSchema& schema) {
  loadDecompositionFunctions();
  auto cache_it = schema_to_function.find(&schema);
  GRAPH_DEBUG("Trying to find schema: ", schema);
  if (cache_it == schema_to_function.end()) {
    GRAPH_DEBUG("Could not find schema: ", schema);
    return std::nullopt;
  }
  auto& func = toGraphFunction(*cache_it->second);
  // Simple Executor:
  // To allow decompositions to run on tensor subclasses such as batched
  // tensors, we set decomposition execution to use the simple executor so
  // that optimizations that do not compose with arbitrary subclasses (such
  // as fusion) do not run.
  func._set_initial_executor_execution_mode(ExecutorExecutionMode::SIMPLE);
  return &func;
}

// RegisterDecomposition registers a decomposition Graph for `schema` so that
// we can initialize a GraphFunction that will run with the simple executor.
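//
// A minimal usage sketch (illustrative only: the choice of op and the
// decomposition body are assumptions, and parseIR comes from
// torch/csrc/jit/ir/irparser.h, which this file does not include):
//
//   const FunctionSchema& schema =
//       getOperatorForLiteral("aten::relu(Tensor self) -> Tensor")->schema();
//   auto graph = std::make_shared<Graph>();
//   parseIR(
//       R"IR(
//         graph(%self : Tensor):
//           %zero : int = prim::Constant[value=0]()
//           %out : Tensor = aten::clamp_min(%self, %zero)
//           return (%out))IR",
//       graph.get());
//   RegisterDecomposition(schema, graph);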
void RegisterDecomposition(
    const FunctionSchema& schema,
    std::shared_ptr<Graph> g) {
  loadDecompositionFunctions();
  std::lock_guard<std::mutex> guard(lock);
  Inline(*g);
  for (const auto i : c10::irange(2)) {
    (void)i; // Suppress unused variable warning
    PeepholeOptimize(g);
    ConstantPropagationImmutableTypes(g);
  }

  auto new_func = std::make_unique<GraphFunction>(
      schema.name(), g, nullptr, ExecutorExecutionMode::SIMPLE);
  user_registered_funcs.emplace(&schema, std::move(new_func));
  schema_to_function[&schema] = user_registered_funcs[&schema].get();
  schema_to_decomposition[&schema] = g;
}

// JitDecomp adapts this registry to the interface autograd uses to look up
// and run decompositions; see NOTE: [Jit Decomposition Interface].
struct JitDecomp final : torch::autograd::impl::JitDecompInterface {
  bool has_jit_decomposition(const c10::FunctionSchema& schema) const override;
  void run_jit_decomposition(
      const c10::OperatorHandle& op,
      torch::jit::Stack* stack) const override;
};

JitDecomp jitDecomp;
torch::autograd::impl::JitDecompRegisterer registerJitDecomp(&jitDecomp);

void JitDecomp::run_jit_decomposition(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) const {
  ::torch::jit::run_jit_decomposition(op, stack);
}

bool JitDecomp::has_jit_decomposition(const FunctionSchema& schema) const {
  return ::torch::jit::has_jit_decomposition(schema);
}

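// Runs the registered decomposition for `op` on `stack`. Decompositions with
// multiple outputs return them as a tuple, which is unpacked back onto the
// stack to match the operator's calling convention.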
void run_jit_decomposition(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  // TODO: templatize based on op and keep static trace_exec
  auto* trace_exec = torch::jit::GetDecompositionExecutor(schema);
  trace_exec->run((*stack));
  if (stack->back().isTuple()) {
    at::IValue tup = stack->back();
    stack->pop_back();
    for (const auto& elem : tup.toTuple()->elements()) {
      stack->push_back(elem);
    }
  }
}

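// Returns true if a decomposition function is registered for `schema`.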
bool has_jit_decomposition(const FunctionSchema& schema) {
  return GetDecompositionFunction(schema).has_value();
}

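// Like GetDecompositionFunction, but asserts that a decomposition exists and
// returns it as a plain Function pointer.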
Function* GetDecompositionExecutor(const FunctionSchema& schema) {
  auto maybe_func = GetDecompositionFunction(schema);
  TORCH_INTERNAL_ASSERT(maybe_func);
  return *maybe_func;
}

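// Convenience overload that first resolves the operator from a schema
// literal, e.g. (assuming the op has a registered decomposition):
//   GetDecompositionExecutor(
//       "aten::var(Tensor self, bool unbiased=True) -> Tensor");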
Function* GetDecompositionExecutor(const char* schema_literal) {
  auto& schema = getOperatorForLiteral(schema_literal)->schema();
  return GetDecompositionExecutor(schema);
}

} // namespace torch::jit