#include <c10/util/Exception.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/serialized_shape_function_registry.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <unordered_map>

namespace torch::jit {
namespace {
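// Protects the registry caches and the compilation unit below; the public
// entry points take this lock and lazily populate the caches via
// loadFunctions().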
std::mutex lock;

// split here to satisfy MSVC++
// https://docs.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2026?view=msvc-170
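// These TorchScript shape functions are appended to the serialized shape
// function source in loadFunctions() below.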
const std::string _xnnpack_shape_compute_functions =
#ifdef USE_XNNPACK
    R"(def prepacked_conv2d_clamp_run(input: List[int], conv2dOpContext: Any):
    assert isinstance(conv2dOpContext, __torch__.torch.classes.xnnpack.Conv2dOpContext)
    (weight, bias, stride, padding, dilation, groups) = unchecked_cast(
        Tuple[List[int], Optional[List[int]], List[int], List[int], List[int], int],
        ops.prepacked.unpack_prepacked_sizes_conv2d(conv2dOpContext),
    )
    return conv2d(input, weight, bias, stride, padding, dilation, groups)

def prepacked_linear_clamp_run(input: List[int], linearOpContext: Any):
    assert isinstance(linearOpContext, __torch__.torch.classes.xnnpack.LinearOpContext)
    (weight, bias) = unchecked_cast(
        Tuple[List[int], Optional[List[int]]],
        ops.prepacked.unpack_prepacked_sizes_linear(linearOpContext),
    )
    return linear(input, weight, bias)
)"
#else
    ""
#endif
    ;

// Mapping a function schema to a shape compute graph allows multiple functions
// to share the same shape compute graph. This is memory efficient, and it also
// speeds up shape analysis by caching the result of running consecutive ops
// for a particular set of inputs with the same graph (e.g. a series of
// pointwise ops).
// We need a map from schema to shape compute graph because the aten schema is
// not recoverable from the shape compute graph: the shape compute graph
// replaces Tensor inputs with List[int], and there are operators such as Conv
// which natively take List[int] inputs.
// TODO: consider storing the shape compute graph directly on the operator,
// and merging it into native_functions.yaml.
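// For example, many pointwise ops can share the generic "unary" shape
// function, and broadcasting binary ops the "broadcast" one (see loadModule
// below).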

// Wrapped in a function so that operators get registered before the map is
// initialized.
// Conditionally defined ops (e.g. XNNPACK) are not yet supported in the
// python-serialized operator mappings, so they are registered here.
static const OperatorMap<std::string>& conditionally_defined_ops() {
  // clang-format off
  static const OperatorMap<std::string> schema_to_function_graph{
#ifdef USE_XNNPACK
      {"prepacked::conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y", "prepacked_conv2d_clamp_run"},
      {"prepacked::linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y", "prepacked_linear_clamp_run"},
#endif
  };
  // clang-format on
  return schema_to_function_graph;
}

std::unordered_map<const FunctionSchema*, std::shared_ptr<Graph>>
    cached_schema_to_graph;

std::unordered_map<const FunctionSchema*, BoundedShapeGraphs>
    cached_bounded_schema_to_graph;

// CompilationUnit that holds all these Functions and keeps them alive.
auto compilation_unit = std::make_shared<CompilationUnit>();

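// Find the in-place variant of a schema (e.g. aten::mul -> aten::mul_), if
// there is one whose first argument and return value both write to self.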
const std::optional<const FunctionSchema*> getInplaceVariant(
    const FunctionSchema& base_schema) {
  auto& inplace_variants =
      getAllOperatorsFor(c10::Symbol::fromQualString(base_schema.name() + "_"));

  for (const auto& variant : inplace_variants) {
    // All arguments must be the same as in the base schema, except for the
    // first (self) argument, which differs only in its alias info.
    const FunctionSchema* schema = &variant->schema();
    if (!schema->isSubtypeOf(base_schema, false)) {
      continue;
    }

    Argument self_arg = schema->arguments()[0];
    if (!self_arg.alias_info()->isWrite()) {
      continue;
    }

    Argument ret_arg = schema->returns()[0];
    if (!ret_arg.alias_info()->isWrite()) {
      continue;
    }

    return schema;
  }
  return std::nullopt;
}

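// Recursively replace Tensor with List[int] in a type, e.g.
// Tensor -> List[int] and Optional[Tensor] -> Optional[List[int]];
// non-Tensor leaf types are returned unchanged.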
TypePtr mapTensorToListOfInts(TypePtr type) {
  if (type->cast<TensorType>()) {
    return ListType::ofInts();
  }
  at::ArrayRef<TypePtr> contained = type->containedTypes();
  if (contained.empty()) {
    return type;
  }
  return type->withContained(
      fmap(type->containedTypes(), mapTensorToListOfInts));
}

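// Warn about while loops: partial evaluation of shape functions relies on
// loop unrolling, which currently only handles for-loops.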
void checkForWhileLoop(
    const FunctionSchema* schema,
    std::shared_ptr<Graph> graph) {
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    if (node->kind() != prim::Loop) {
      continue;
    }
    LoopView loop(node);
    if (loop.loopType() != LoopView::For) {
      TORCH_WARN(
          "While loops are not yet implemented in unrolling which may make this shape function difficult to partially evaluate: ",
          *node,
          " for schema ",
          *schema);
    }
  }
}

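// Shape functions must allocate and return new lists; returning an input
// directly would alias the caller's input list.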
void checkInputReturnedAsOutput(
    const FunctionSchema* schema,
    const std::shared_ptr<Graph>& graph) {
  // We could use the alias db here as well, but it is imprecise, so we would
  // only be able to warn.
  for (size_t i : c10::irange(graph->inputs().size())) {
    Value* input = graph->inputs().at(i);
    for (size_t j : c10::irange(graph->outputs().size())) {
      Value* output = graph->outputs().at(j);
      TORCH_CHECK(
          input != output,
          "For schema: ",
          *schema,
          " input index ",
          i,
          " is returned as output index ",
          j,
          ". Shape functions must return new unaliased lists");
    }
  }
}

void checkInputAndOutputTypes(
    const FunctionSchema* schema,
    const std::shared_ptr<Graph>& graph) {
  // Allow the schema to have extra arguments that the shape function does not
  // use, so that multiple schemas can map to a generic function such as unary.
  TORCH_CHECK(
      graph->inputs().size() <= schema->arguments().size(),
      "Shape function must not have more arguments than the schema. Got ",
      graph->inputs().size(),
      " graph arguments and ",
      schema->arguments().size(),
      " schema arguments of schema: ",
      *schema);

  for (auto i : c10::irange(graph->inputs().size())) {
    auto inp_type = schema->arguments().at(i).type();
    auto mapped_type = mapTensorToListOfInts(inp_type);
    auto graph_type = graph->inputs().at(i)->type();
    TORCH_INTERNAL_ASSERT(
        mapped_type->isSubtypeOf(graph_type),
        "For schema type: ",
        inp_type->str(),
        " Expected supertype of ",
        mapped_type->str(),
        " but got graph_type ",
        graph_type->str(),
        " at index ",
        i,
        " of schema: ",
        *schema);
  }

  TORCH_CHECK(
      graph->outputs().size() == schema->returns().size(),
      "Shape function must have the same number of outputs as the schema. Got ",
      graph->outputs().size(),
      " graph outputs and ",
      schema->returns().size(),
      " schema returns of schema: ",
      *schema);

  for (auto i : c10::irange(schema->returns().size())) {
    auto out_type = schema->returns().at(i).type();
    auto mapped_type = mapTensorToListOfInts(out_type);
    auto graph_type = graph->outputs().at(i)->type();
    TORCH_INTERNAL_ASSERT(
        mapped_type->isSubtypeOf(graph_type),
        "For schema type: ",
        out_type->str(),
        " Expected supertype of ",
        mapped_type->str(),
        " but got graph_type ",
        graph_type->str(),
        " at output index ",
        i,
        " of schema: ",
        *schema);
  }
}

void transformShapeFunction(
    const FunctionSchema* schema_string,
    const std::shared_ptr<Graph>& graph) {
  Inline(*graph);

  // ATen operators can return multiple unboxed values, in contrast to
  // functions defined in TorchScript or user-registered operators, which
  // must return a Tuple. Here, modify the shape graphs of ATen operators
  // with multiple outputs so that the outputs correspond one-to-one.
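  // For example, a shape graph returning a single Tuple[List[int], List[int]]
  // output is rewritten below to return two List[int] outputs.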
  if (schema_string->returns().size() > 1) {
    TORCH_INTERNAL_ASSERT(
        graph->outputs().size() == 1 &&
        graph->outputs().at(0)->type()->cast<TupleType>());
    auto tuple_node = graph->outputs().at(0)->node();
    WithInsertPoint guard(graph->return_node());
    auto tuple_unpack_values = createTupleUnpack(tuple_node->output());
    graph->eraseOutput(0);
    for (Value* v : tuple_unpack_values) {
      graph->registerOutput(v);
    }
    GRAPH_DUMP("After Output Tuple Unpacking", graph);
  }
}

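// Look up the named shape compute function in the compilation unit (or reuse
// an already-processed graph from reused_functions), transform it, and return
// its graph.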
std::shared_ptr<Graph> genShapeComputeFn(
    const FunctionSchema* schema_string,
    const std::string& shape_compute_function_name,
    std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
    const CompilationUnit& module) {
  std::shared_ptr<Graph> graph;
  GRAPH_DEBUG(
      "Registering schema: ",
      *schema_string,
      " with shape compute func: ",
      shape_compute_function_name);
  if (reused_functions.count(shape_compute_function_name)) {
    GRAPH_DEBUG("Registering reused schema");
    graph = reused_functions[shape_compute_function_name];
  } else {
    Function& shape_compute_function =
        module.get_function(shape_compute_function_name);
    graph = toGraphFunction(shape_compute_function).graph();

    transformShapeFunction(schema_string, graph);
    // NB: the shape functions registered in source are linted in a test file.
    // LintShapeComputeGraph(schema_string, graph);

    reused_functions[shape_compute_function_name] = graph;
  }
  // Allow extra unused schema arguments so that multiple schemas can map to a
  // generic shape function such as unary.
  TORCH_INTERNAL_ASSERT(
      graph->inputs().size() <= schema_string->arguments().size());
  return graph;
}

void registerSchema(
    const FunctionSchema* schema_string,
    const std::string& shape_compute_function_name,
    std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
    const CompilationUnit& module) {
  auto graph = genShapeComputeFn(
      schema_string, shape_compute_function_name, reused_functions, module);

  cached_schema_to_graph[schema_string] = graph;
}

void registerBoundedSchema(
    const FunctionSchema* schema_string,
    const std::string& lower_bound_function_name,
    const std::string& upper_bound_function_name,
    std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
    const CompilationUnit& module) {
  auto lower_graph = genShapeComputeFn(
      schema_string, lower_bound_function_name, reused_functions, module);
  auto upper_graph = genShapeComputeFn(
      schema_string, upper_bound_function_name, reused_functions, module);
  cached_bounded_schema_to_graph[schema_string] = {lower_graph, upper_graph};
}

void loadModule(const CompilationUnit& module) {
  std::unordered_map<std::string, std::shared_ptr<Graph>> reused_functions;

  std::vector<std::pair<std::shared_ptr<Operator>, std::string>>
      operator_pairs = conditionally_defined_ops().getAllKeysAndValues();
  auto te_ops = get_tensorexpr_elementwise_set().getAllKeysAndValues();
  operator_pairs.insert(operator_pairs.end(), te_ops.begin(), te_ops.end());
  auto more_mappings = GetShapeFunctionMappings().getAllKeysAndValues();
  operator_pairs.insert(
      operator_pairs.end(), more_mappings.begin(), more_mappings.end());

  for (const auto& pair : operator_pairs) {
    const FunctionSchema* schema_string = &pair.first->schema();
    const std::string& shape_compute_function_name = pair.second;

    registerSchema(
        schema_string, shape_compute_function_name, reused_functions, module);

    // Register the inplace variant if any for functions with common shape forms
    if (shape_compute_function_name == "unary") {
      auto inplace_schema = getInplaceVariant(*schema_string);
      if (inplace_schema.has_value()) {
        registerSchema(
            inplace_schema.value(), "unary", reused_functions, module);
      }
    }
    if (shape_compute_function_name == "broadcast") {
      auto inplace_schema = getInplaceVariant(*schema_string);
      if (inplace_schema.has_value()) {
        registerSchema(
            inplace_schema.value(),
            "broadcast_inplace",
            reused_functions,
            module);
      }
    }
  }

  // Now register the bounded schemas
  for (const auto& pair : GetBoundedShapeMappings().getAllKeysAndValues()) {
    const FunctionSchema* schema_string = &pair.first->schema();
    const std::string& lower_bound_function_name = pair.second.first;
    const std::string& upper_bound_function_name = pair.second.second;

    registerBoundedSchema(
        schema_string,
        lower_bound_function_name,
        upper_bound_function_name,
        reused_functions,
        module);
  }
}

void loadFunctions() {
  try {
    auto shape_compute_functions =
        GetSerializedShapeFunctions() + _xnnpack_shape_compute_functions;

    auto src = std::make_shared<Source>(shape_compute_functions);
    std::stringstream ss;
    std::vector<at::IValue> constantTable;
    auto resolver = std::make_shared<SourceImporterImpl>(
        compilation_unit,
        &constantTable,
        [&](const std::string& name) -> std::shared_ptr<Source> { return src; },
        1);
    compilation_unit->define(
        std::nullopt, shape_compute_functions, resolver, nullptr);
    loadModule(*compilation_unit);
  } catch (...) {
    // Reset the cache and compilation unit so that we don't get weird errors
    // in later tests when one of the shape functions is invalid.
    compilation_unit = std::make_shared<CompilationUnit>();
    cached_schema_to_graph.clear();
    throw;
  }
}
} // anonymous namespace

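// Returns the shape compute graph registered for `schema`, if any.
// Illustrative (hypothetical) caller:
//   if (auto graph = shapeComputeGraphForSchema(node->schema())) {
//     // ... partially evaluate *graph with the node's List[int] inputs ...
//   }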
std::optional<std::shared_ptr<Graph>> shapeComputeGraphForSchema(
    const FunctionSchema& schema) {
  std::lock_guard<std::mutex> guard(lock);
  if (cached_schema_to_graph.empty()) {
    loadFunctions();
  }

  GRAPH_DEBUG("Trying to find schema: ", schema);
  auto cache_it = cached_schema_to_graph.find(&schema);
  if (cache_it != cached_schema_to_graph.end()) {
    return cache_it->second;
  }
  GRAPH_DEBUG("Could not find schema: ", schema);

  return std::nullopt;
}

TORCH_API std::optional<BoundedShapeGraphs> boundedGraphsForSchema(
    const FunctionSchema& schema) {
  std::lock_guard<std::mutex> guard(lock);
  if (cached_bounded_schema_to_graph.empty()) {
    loadFunctions();
  }
  GRAPH_DEBUG("Trying to find schema in bounded graphs: ", schema);
  auto cache_it = cached_bounded_schema_to_graph.find(&schema);
  if (cache_it != cached_bounded_schema_to_graph.end()) {
    return cache_it->second;
  }

  return std::nullopt;
}

void RegisterShapeComputeGraphForSchema(
    const FunctionSchema& schema,
    const std::shared_ptr<Graph>& g) {
  std::lock_guard<std::mutex> guard(lock);
  if (cached_schema_to_graph.empty()) {
    loadFunctions();
  }
  transformShapeFunction(&schema, g);
  LintShapeComputeGraph(&schema, g);

  cached_schema_to_graph[&schema] = g;
}

std::vector<const FunctionSchema*> RegisteredShapeComputeSchemas() {
  std::lock_guard<std::mutex> guard(lock);
  if (cached_schema_to_graph.empty()) {
    loadFunctions();
  }

  std::vector<const FunctionSchema*> schemas;
  schemas.reserve(cached_schema_to_graph.size());
  for (const auto& pair : cached_schema_to_graph) {
    schemas.push_back(pair.first);
  }
  return schemas;
}

void LintShapeComputeGraph(
    const FunctionSchema* schema,
    const std::shared_ptr<Graph>& graph) {
  checkInputAndOutputTypes(schema, graph);
  checkForWhileLoop(schema, graph);
  checkInputReturnedAsOutput(schema, graph);
  // TODO: other checks? list ops which we don't symbolically optimize, etc.?
}

} // namespace torch::jit