// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/symbolic_shape_registry.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/Exception.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/serialized_shape_function_registry.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <unordered_map>

namespace torch::jit {
namespace {
std::mutex lock;

// split here to satisfy MSVC++
// https://docs.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2026?view=msvc-170
const std::string _xnnpack_shape_compute_functions =
#ifdef USE_XNNPACK
    R"(def prepacked_conv2d_clamp_run(input: List[int], conv2dOpContext: Any):
    assert isinstance(conv2dOpContext, __torch__.torch.classes.xnnpack.Conv2dOpContext)
    (weight, bias, stride, padding, dilation, groups) = unchecked_cast(
        Tuple[List[int], Optional[List[int]], List[int], List[int], List[int], int],
        ops.prepacked.unpack_prepacked_sizes_conv2d(conv2dOpContext),
    )
    return conv2d(input, weight, bias, stride, padding, dilation, groups)

def prepacked_linear_clamp_run(input: List[int], linearOpContext: Any):
    assert isinstance(linearOpContext, __torch__.torch.classes.xnnpack.LinearOpContext)
    (weight, bias) = unchecked_cast(
        Tuple[List[int], Optional[List[int]]],
        ops.prepacked.unpack_prepacked_sizes_linear(linearOpContext),
    )
    return linear(input, weight, bias)
    )"
#else
    ""
#endif
    ;

// Mapping function schemas to shape compute graphs allows multiple functions
// to share the same shape compute graph. This is memory efficient, and it also
// speeds up shape analysis by caching the result of running consecutive ops
// that share a graph for a particular set of inputs, e.g. a series of
// pointwise ops.
// We need a map from schema to shape compute graph because the aten schema is
// not recoverable from the shape compute graph: the shape compute graph
// replaces Tensor inputs with List[int], and there are operators like Conv
// which natively have List[int] inputs.
// TODO: consider storing the shape compute graph directly on the operator,
// and merging it into native_functions.yaml
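// For example, many elementwise ops can all point at a single shared shape
// function such as "unary" or "broadcast"; see get_tensorexpr_elementwise_set()
// and GetShapeFunctionMappings() in loadModule() below.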

// Wrapped in a function so that operators get registered before the map is
// initialized.
// Conditionally defined ops are not yet supported in the Python serialized
// operators.
static const OperatorMap<std::string>& conditionally_defined_ops() {
  // clang-format off
  static const OperatorMap<std::string> schema_to_function_graph{
#ifdef USE_XNNPACK
      {"prepacked::conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y", "prepacked_conv2d_clamp_run"},
      {"prepacked::linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y", "prepacked_linear_clamp_run"},
#endif
  };
  // clang-format on
  return schema_to_function_graph;
}

std::unordered_map<const FunctionSchema*, std::shared_ptr<Graph>>
    cached_schema_to_graph;

std::unordered_map<const FunctionSchema*, BoundedShapeGraphs>
    cached_bounded_schema_to_graph;

// CompilationUnit that holds all these Functions and keeps them alive.
auto compilation_unit = std::make_shared<CompilationUnit>();

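// Looks up the in-place variant ("<name>_") of a schema, returning it only if
// its arguments are compatible with the base schema and both the self argument
// and the return are marked as written to.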
const std::optional<const FunctionSchema*> getInplaceVariant(
    const FunctionSchema& base_schema) {
  auto& inplace_variants =
      getAllOperatorsFor(c10::Symbol::fromQualString(base_schema.name() + "_"));

  for (const auto& variant : inplace_variants) {
    // All arguments must match the base schema; the first (self) argument may
    // differ only in its alias info, which must mark it as written to.
    const FunctionSchema* schema = &variant->schema();
    if (!schema->isSubtypeOf(base_schema, false)) {
      continue;
    }

    Argument self_arg = schema->arguments()[0];
    if (!self_arg.alias_info()->isWrite()) {
      continue;
    }

    Argument ret_arg = schema->returns()[0];
    if (!ret_arg.alias_info()->isWrite()) {
      continue;
    }

    return schema;
  }
  return std::nullopt;
}

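// Recursively replaces Tensor types with List[int], mirroring how shape
// functions take symbolic shapes (List[int]) in place of Tensors.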
TypePtr mapTensorToListOfInts(TypePtr type) {
  if (type->cast<TensorType>()) {
    return ListType::ofInts();
  }
  at::ArrayRef<TypePtr> contained = type->containedTypes();
  if (contained.empty()) {
    return type;
  }
  return type->withContained(
      fmap(type->containedTypes(), mapTensorToListOfInts));
}

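// Warns if a shape function contains a while loop; while loops are not yet
// handled by loop unrolling, which makes partial evaluation harder.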
void checkForWhileLoop(
    const FunctionSchema* schema,
    std::shared_ptr<Graph> graph) {
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    if (node->kind() != prim::Loop) {
      continue;
    }
    LoopView loop(node);
    if (loop.loopType() != LoopView::For) {
      TORCH_WARN(
          "While loops are not yet implemented in unrolling which may make this shape function difficult to partially evaluate: ",
          *node,
          " for schema ",
          *schema);
    }
  }
}

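// Shape functions must return freshly constructed lists; reject graphs that
// return one of their inputs directly as an output.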
void checkInputReturnedAsOutput(
    const FunctionSchema* schema,
    const std::shared_ptr<Graph>& graph) {
  // Could use alias db here as well but would have to warn because it's
  // imprecise
  for (size_t i : c10::irange(graph->inputs().size())) {
    Value* input = graph->inputs().at(i);
    for (size_t j : c10::irange(graph->outputs().size())) {
      Value* output = graph->outputs().at(j);
      TORCH_CHECK(
          input != output,
          "For schema: ",
          *schema,
          " input index ",
          i,
          " is returned as output index ",
          j,
          ". Shape functions must return new unaliased lists");
    }
  }
}

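// Checks that the shape graph's input and output types line up with the
// schema's argument and return types after mapping Tensor -> List[int].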
void checkInputAndOutputTypes(
    const FunctionSchema* schema,
    const std::shared_ptr<Graph>& graph) {
  // Allow extra unused schema arguments so that multiple functions can map to
  // the same shape function, e.g. unary
  TORCH_CHECK(
      graph->inputs().size() <= schema->arguments().size(),
      "Shape function must not have more arguments than the schema. Got ",
      graph->inputs().size(),
      " graph arguments and ",
      schema->arguments().size(),
      " schema arguments of schema: ",
      *schema);

  for (auto i : c10::irange(graph->inputs().size())) {
    auto inp_type = schema->arguments().at(i).type();
    auto mapped_type = mapTensorToListOfInts(inp_type);
    auto graph_type = graph->inputs().at(i)->type();
    TORCH_INTERNAL_ASSERT(
        mapped_type->isSubtypeOf(graph->inputs().at(i)->type()),
        "For schema type: ",
        inp_type->str(),
        " Expected supertype of ",
        mapped_type->str(),
        " but got graph_type ",
        graph_type->str(),
        " at index ",
        i,
        " of schema: ",
        *schema);
  }

  TORCH_CHECK(
      graph->outputs().size() == schema->returns().size(),
      "Shape function must have the same number of outputs as the schema. Got ",
      graph->outputs().size(),
      " graph outputs and ",
      schema->returns().size(),
      " schema returns of schema: ",
      *schema);

  for (auto i : c10::irange(schema->returns().size())) {
    auto out_type = schema->returns().at(i).type();
    auto mapped_type = mapTensorToListOfInts(out_type);
    auto graph_type = graph->outputs().at(i)->type();
    TORCH_INTERNAL_ASSERT(
        mapped_type->isSubtypeOf(graph->outputs().at(i)->type()),
        "For schema type: ",
        out_type->str(),
        " Expected supertype of ",
        mapped_type->str(),
        " but got graph_type ",
        graph_type->str(),
        " at output index ",
        i,
        " of schema: ",
        *schema);
  }
}

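// Inlines the shape graph and, for schemas with multiple returns, unpacks the
// graph's single tuple output into individual outputs.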
void transformShapeFunction(
    const FunctionSchema* schema_string,
    const std::shared_ptr<Graph>& graph) {
  Inline(*graph);

  // ATen operators can return multiple unboxed values, in contrast to
  // functions defined in TorchScript or user-registered operators,
  // which must use a Tuple.
  // Here, modify the shape graph of aten operators with multiple outputs
  // so that the graph's outputs correspond to the schema's returns.
  if (schema_string->returns().size() > 1) {
    TORCH_INTERNAL_ASSERT(
        graph->outputs().size() == 1 &&
        graph->outputs().at(0)->type()->cast<TupleType>());
    auto tuple_node = graph->outputs().at(0)->node();
    WithInsertPoint guard(graph->return_node());
    auto tuple_unpack_values = createTupleUnpack(tuple_node->output());
    graph->eraseOutput(0);
    for (Value* v : tuple_unpack_values) {
      graph->registerOutput(v);
    }
    GRAPH_DUMP("After Output Tuple Unpacking", graph);
  }
}

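// Retrieves (or builds and caches) the shape compute graph for a given shape
// function name, reusing graphs across schemas that share a function.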
std::shared_ptr<Graph> genShapeComputeFn(
    const FunctionSchema* schema_string,
    const std::string& shape_compute_function_name,
    std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
    const CompilationUnit& module) {
  std::shared_ptr<Graph> graph;
  GRAPH_DEBUG(
      "Registering schema: ",
      *schema_string,
      " with shape compute func: ",
      shape_compute_function_name);
  if (reused_functions.count(shape_compute_function_name)) {
    GRAPH_DEBUG("Registering reused schema");
    graph = reused_functions[shape_compute_function_name];
  } else {
    Function& shape_compute_function =
        module.get_function(shape_compute_function_name);
    graph = toGraphFunction(shape_compute_function).graph();

    transformShapeFunction(schema_string, graph);
    // NB: we lint the shape functions registered in source
    // in a test file
    // LintShapeComputeGraph(schema_string, graph);

    reused_functions[shape_compute_function_name] = graph;
  }
  // Allow extra unused schema arguments so that multiple functions can map to
  // the same shape function, e.g. unary
  TORCH_INTERNAL_ASSERT(
      graph->inputs().size() <= schema_string->arguments().size());
  return graph;
}

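// Registers the shape compute graph for a single schema in the global cache.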
void registerSchema(
    const FunctionSchema* schema_string,
    const std::string& shape_compute_function_name,
    std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
    const CompilationUnit& module) {
  auto graph = genShapeComputeFn(
      schema_string, shape_compute_function_name, reused_functions, module);

  cached_schema_to_graph[schema_string] = graph;
}

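// Registers a pair of lower- and upper-bound shape compute graphs for a schema.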
void registerBoundedSchema(
    const FunctionSchema* schema_string,
    const std::string& lower_bound_function_name,
    const std::string& upper_bound_function_name,
    std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
    const CompilationUnit& module) {
  auto lower_graph = genShapeComputeFn(
      schema_string, lower_bound_function_name, reused_functions, module);
  auto upper_graph = genShapeComputeFn(
      schema_string, upper_bound_function_name, reused_functions, module);
  cached_bounded_schema_to_graph[schema_string] = {lower_graph, upper_graph};
}

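// Builds the schema -> shape graph caches from the conditionally defined ops,
// the tensorexpr elementwise set, the serialized shape function mappings, and
// the bounded shape mappings.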
void loadModule(const CompilationUnit& module) {
  std::unordered_map<std::string, std::shared_ptr<Graph>> reused_functions;

  std::vector<std::pair<std::shared_ptr<Operator>, std::string>>
      operator_pairs = conditionally_defined_ops().getAllKeysAndValues();
  auto te_ops = get_tensorexpr_elementwise_set().getAllKeysAndValues();
  operator_pairs.insert(operator_pairs.end(), te_ops.begin(), te_ops.end());
  auto more_mappings = GetShapeFunctionMappings().getAllKeysAndValues();
  operator_pairs.insert(
      operator_pairs.end(), more_mappings.begin(), more_mappings.end());

  for (const auto& pair : operator_pairs) {
    const FunctionSchema* schema_string = &pair.first->schema();
    const std::string& shape_compute_function_name = pair.second;

    registerSchema(
        schema_string, shape_compute_function_name, reused_functions, module);

    // Register the in-place variant, if any, for functions with common shape forms
    if (shape_compute_function_name == "unary") {
      auto inplace_schema = getInplaceVariant(*schema_string);
      if (inplace_schema.has_value()) {
        registerSchema(
            inplace_schema.value(), "unary", reused_functions, module);
      }
    }
    if (shape_compute_function_name == "broadcast") {
      auto inplace_schema = getInplaceVariant(*schema_string);
      if (inplace_schema.has_value()) {
        registerSchema(
            inplace_schema.value(),
            "broadcast_inplace",
            reused_functions,
            module);
      }
    }
  }

  // Now register the bounded schemas
  for (const auto& pair : GetBoundedShapeMappings().getAllKeysAndValues()) {
    const FunctionSchema* schema_string = &pair.first->schema();
    const std::string& lower_bound_function_name = pair.second.first;
    const std::string& upper_bound_function_name = pair.second.second;

    registerBoundedSchema(
        schema_string,
        lower_bound_function_name,
        upper_bound_function_name,
        reused_functions,
        module);
  }
}

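// Compiles the serialized shape functions into the compilation unit and
// populates the caches; on failure, resets the compilation unit and cache
// before rethrowing.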
void loadFunctions() {
  try {
    auto shape_compute_functions =
        GetSerializedShapeFunctions() + _xnnpack_shape_compute_functions;

    auto src = std::make_shared<Source>(shape_compute_functions);
    std::stringstream ss;
    std::vector<at::IValue> constantTable;
    auto resolver = std::make_shared<SourceImporterImpl>(
        compilation_unit,
        &constantTable,
        [&](const std::string& name) -> std::shared_ptr<Source> { return src; },
        1);
    compilation_unit->define(
        std::nullopt, shape_compute_functions, resolver, nullptr);
    loadModule(*compilation_unit);
  } catch (...) {
    // Reset the cache and compilation unit so that we don't get weird errors
    // in later tests when one of the shape functions is invalid.
    compilation_unit = std::make_shared<CompilationUnit>();
    cached_schema_to_graph.clear();
    throw;
  }
}
} // anonymous namespace

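// Public lookup: returns the registered shape compute graph for a schema, or
// nullopt if none is registered.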
std::optional<std::shared_ptr<Graph>> shapeComputeGraphForSchema(
    const FunctionSchema& schema) {
  std::lock_guard<std::mutex> guard(lock);
  if (cached_schema_to_graph.empty()) {
    loadFunctions();
  }

  GRAPH_DEBUG("Trying to find schema: ", schema);
  auto cache_it = cached_schema_to_graph.find(&schema);
  if (cache_it != cached_schema_to_graph.end()) {
    return cache_it->second;
  }
  GRAPH_DEBUG("Could not find schema: ", schema);

  return std::nullopt;
}

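// Public lookup: returns the registered lower/upper bound shape graphs for a
// schema, or nullopt if none are registered.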
TORCH_API std::optional<BoundedShapeGraphs> boundedGraphsForSchema(
    const FunctionSchema& schema) {
  std::lock_guard<std::mutex> guard(lock);
  if (cached_bounded_schema_to_graph.empty()) {
    loadFunctions();
  }
  GRAPH_DEBUG("Trying to find schema in bounded graphs: ", schema);
  auto cache_it = cached_bounded_schema_to_graph.find(&schema);
  if (cache_it != cached_bounded_schema_to_graph.end()) {
    return cache_it->second;
  }

  return std::nullopt;
}

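// Registers a user-provided shape compute graph for a schema after applying
// the standard transform and lint checks.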
void RegisterShapeComputeGraphForSchema(
    const FunctionSchema& schema,
    const std::shared_ptr<Graph>& g) {
  std::lock_guard<std::mutex> guard(lock);
  if (cached_schema_to_graph.empty()) {
    loadFunctions();
  }
  transformShapeFunction(&schema, g);
  LintShapeComputeGraph(&schema, g);

  cached_schema_to_graph[&schema] = g;
}

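// Returns all schemas that currently have a registered shape compute graph.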
std::vector<const FunctionSchema*> RegisteredShapeComputeSchemas() {
  std::lock_guard<std::mutex> guard(lock);
  if (cached_schema_to_graph.empty()) {
    loadFunctions();
  }

  std::vector<const FunctionSchema*> schemas;
  schemas.reserve(cached_schema_to_graph.size());
  for (const auto& pair : cached_schema_to_graph) {
    schemas.push_back(pair.first);
  }
  return schemas;
}

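// Runs the sanity checks defined above over a shape compute graph for a schema.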
void LintShapeComputeGraph(
    const FunctionSchema* schema,
    const std::shared_ptr<Graph>& graph) {
  checkInputAndOutputTypes(schema, graph);
  checkForWhileLoop(schema, graph);
  checkInputReturnedAsOutput(schema, graph);
  // TODO: other checks? list ops which we don't symbolically optimize, etc.?
}

} // namespace torch::jit