#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
#include <torch/csrc/jit/runtime/decomposition_registry_util.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/serialization/import_source.h>

#include <c10/util/Exception.h>
#include <torch/csrc/autograd/jit_decomp_interface.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <memory>
#include <unordered_map>

namespace torch::jit {
namespace {
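// Guards lazy loading of the serialized decompositions and all user
// registrations below.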
std::mutex lock;

// CompilationUnit that holds the deserialized decomposition Functions and
// keeps them alive.
auto compilation_unit = std::make_shared<CompilationUnit>();
std::unordered_map<const FunctionSchema*, std::shared_ptr<Graph>>
    schema_to_decomposition;

// Holds user-registered Functions and keeps them alive.
std::unordered_map<const FunctionSchema*, std::unique_ptr<Function>>
    user_registered_funcs;

std::unordered_map<const FunctionSchema*, Function*> schema_to_function;
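// Maps every built-in decomposition's schema to its Function and Graph by
// looking the decomposition function up by name in `module`.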
void loadModule(const CompilationUnit& module) {
  const auto& mappings = GetDecompositionMapping().getAllKeysAndValues();
  for (const auto& pair : mappings) {
    const FunctionSchema* schema = &pair.first->schema();
    const std::string& decomposition_function_name = pair.second;

    Function& decomposition_function =
        module.get_function(decomposition_function_name);
    std::shared_ptr<Graph> graph =
        toGraphFunction(decomposition_function).graph();

    schema_to_function[schema] = &decomposition_function;
    schema_to_decomposition[schema] = graph;
  }
}
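// Lazily parses the serialized decompositions into `compilation_unit` and
// fills the lookup tables above; after the first call this is a no-op.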
void loadDecompositionFunctions() {
  std::lock_guard<std::mutex> guard(lock);
  if (!schema_to_decomposition.empty()) {
    return;
  }

  auto src = std::make_shared<Source>(GetSerializedDecompositions());
  std::vector<at::IValue> constantTable;
  auto resolver = std::make_shared<SourceImporterImpl>(
      compilation_unit,
      &constantTable,
      [&](const std::string& name) -> std::shared_ptr<Source> { return src; },
      1);
  compilation_unit->define(
      std::nullopt, GetSerializedDecompositions(), resolver, nullptr);
  loadModule(*compilation_unit);
}

} // anonymous namespace
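// Replaces `n` with an inlined copy of its decomposition graph, if one is
// registered for its schema; `n` is destroyed on success.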
static void DecomposeOp(Node* n) {
  auto schema = n->maybeSchema();
  if (!schema) {
    return;
  }
  auto decomposition = GetDecomposition(n->schema());
  if (!decomposition) {
    return;
  }
  WithInsertPoint guard(n);
  auto outputs =
      insertGraph(*n->owningGraph(), *decomposition->get(), n->inputs());
  TORCH_INTERNAL_ASSERT(outputs.size() == n->outputs().size());
  for (size_t i : c10::irange(outputs.size())) {
    n->outputs().at(i)->replaceAllUsesWith(outputs[i]);
  }
  n->destroy();
}
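// Recursively decomposes every node in `block`, visiting each node's nested
// blocks before decomposing the node itself.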
static void RunDecompositions(Block* block) {
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    Node* n = *it;
    it++; // advance the iterator first, since DecomposeOp may destroy n
    for (Block* b : n->blocks()) {
      RunDecompositions(b);
    }
    DecomposeOp(n);
  }
}
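// Public entry point: decomposes every op in `g`, then runs two rounds of
// peephole optimization and constant propagation to clean up the inlined
// subgraphs. A hypothetical call site (the graph's origin is assumed):
//
//   std::shared_ptr<Graph> graph = ...; // e.g. from a scripted function
//   RunDecompositions(graph); // rewrites the graph in place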
void RunDecompositions(std::shared_ptr<Graph> g) {
  RunDecompositions(g->block());
  for (C10_UNUSED const auto _ : c10::irange(2)) {
    PeepholeOptimize(g, /*disable_shape_peephole*/ true);
    ConstantPropagation(g);
  }
}
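// Returns the decomposition graph registered for `schema`, if any. The lookup
// keys on the schema's address, so callers must pass the operator's canonical
// schema object.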
std::optional<std::shared_ptr<Graph>> GetDecomposition(
    const FunctionSchema& schema) {
  loadDecompositionFunctions();
  GRAPH_DEBUG("Trying to find schema: ", schema);
  auto cache_it = schema_to_decomposition.find(&schema);
  if (cache_it != schema_to_decomposition.end()) {
    return cache_it->second;
  }
  GRAPH_DEBUG("Could not find schema: ", schema);

  return std::nullopt;
}
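// Like GetDecomposition, but returns the wrapping GraphFunction, configured
// to run with the simple executor (see below).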
std::optional<GraphFunction*> GetDecompositionFunction(
    const FunctionSchema& schema) {
  loadDecompositionFunctions();
  auto cache_it = schema_to_function.find(&schema);
  GRAPH_DEBUG("Trying to find schema: ", schema);
  if (cache_it == schema_to_function.end()) {
    GRAPH_DEBUG("Could not find schema: ", schema);
    return std::nullopt;
  }
  auto& func = toGraphFunction(*cache_it->second);
  // Simple executor: to allow decompositions to run on tensor subclasses such
  // as batched tensors, use the simple executor so that optimizations that do
  // not compose with arbitrary subclasses (such as fusion) do not run.
  func._set_initial_executor_execution_mode(ExecutorExecutionMode::SIMPLE);
  return &func;
}
// Registers a user-provided decomposition graph for `schema`, wrapping it in
// a GraphFunction that will run with the simple executor.
void RegisterDecomposition(
    const FunctionSchema& schema,
    std::shared_ptr<Graph> g) {
  loadDecompositionFunctions();
  std::lock_guard<std::mutex> guard(lock);
  Inline(*g);
  for (C10_UNUSED const auto _ : c10::irange(2)) {
    PeepholeOptimize(g);
    ConstantPropagationImmutableTypes(g);
  }

  auto new_func = std::make_unique<GraphFunction>(
      schema.name(), g, nullptr, ExecutorExecutionMode::SIMPLE);
  user_registered_funcs.emplace(&schema, std::move(new_func));
  schema_to_function[&schema] = user_registered_funcs[&schema].get();
  schema_to_decomposition[&schema] = g;
}
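// A hypothetical registration (`decomposition_graph` is assumed to be built
// by the caller; this is a sketch, not code in this file):
//
//   const FunctionSchema& schema =
//       getOperatorForLiteral("aten::relu(Tensor self) -> Tensor")->schema();
//   RegisterDecomposition(schema, decomposition_graph);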

// See NOTE: [Jit Decomposition Interface]
struct JitDecomp final : torch::autograd::impl::JitDecompInterface {
  bool has_jit_decomposition(const c10::FunctionSchema& schema) const override;
  void run_jit_decomposition(
      const c10::OperatorHandle& op,
      torch::jit::Stack* stack) const override;
};

JitDecomp jitDecomp;
torch::autograd::impl::JitDecompRegisterer registerJitDecomp(&jitDecomp);
void JitDecomp::run_jit_decomposition(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) const {
  ::torch::jit::run_jit_decomposition(op, stack);
}

bool JitDecomp::has_jit_decomposition(const FunctionSchema& schema) const {
  return ::torch::jit::has_jit_decomposition(schema);
}
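// Runs the registered decomposition for `op` on `stack`. Multi-output
// decompositions return a tuple, which is flattened back onto the stack to
// match the original operator's calling convention.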
void run_jit_decomposition(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  // TODO: templatize based on op and keep static trace_exec
  auto* trace_exec = torch::jit::GetDecompositionExecutor(schema);
  trace_exec->run(*stack);
  if (stack->back().isTuple()) {
    at::IValue tup = stack->back();
    stack->pop_back();
    for (const auto& elem : tup.toTuple()->elements()) {
      stack->push_back(elem);
    }
  }
}
bool has_jit_decomposition(const FunctionSchema& schema) {
  return GetDecompositionFunction(schema).has_value();
}
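// As GetDecompositionFunction, but asserts that a decomposition exists.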
Function* GetDecompositionExecutor(const FunctionSchema& schema) {
  auto maybe_func = GetDecompositionFunction(schema);
  TORCH_INTERNAL_ASSERT(maybe_func);
  return *maybe_func;
}
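// Convenience overload that resolves a schema literal first, e.g. (a sketch;
// the literal shown is an assumption, not taken from this file):
//
//   auto* fn = GetDecompositionExecutor(
//       "aten::var(Tensor self, bool unbiased=True) -> Tensor");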
Function* GetDecompositionExecutor(const char* schema_literal) {
  auto& schema = getOperatorForLiteral(schema_literal)->schema();
  return GetDecompositionExecutor(schema);
}

} // namespace torch::jit