#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>

#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <fstream>

using namespace torch::jit;
using namespace torch::jit::tensorexpr;

namespace torch {
namespace jit {
namespace mobile {
namespace nnc {

// TODO(mvz): temporarily disable NNC backend in mobile builds.
/*
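// Extract the concrete sizes of a buffer, assuming every dimension is an
// immediate (LongImm) constant.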
static std::vector<int64_t> getConstSizes(const BufPtr b) {
  std::vector<int64_t> r;
  for (const auto& dim : b->dims()) {
    LongImmPtr imm_dim = to<LongImm>(dim);
    // TODO: assert it's actually immediate
    int64_t s = imm_dim->value();
    r.push_back(s);
  }
  return r;
}

// Construct input-specs vector from the inputs of the original graph
static std::vector<mobile::nnc::InputSpec> toInputSpecs(
    const std::shared_ptr<tensorexpr::TensorExprKernel>& kernel) {
  const std::shared_ptr<Graph>& g = kernel->graph();
  std::vector<mobile::nnc::InputSpec> specs;

  // Graph inputs include scalar values for symbolic shapes, for which we
  // don't need input specs. These scalar values come last among the graph
  // inputs
  auto num_inputs =
      g->inputs().size() - kernel->getSymbolicShapeInputs().size();

  for (const auto i : c10::irange(num_inputs)) {
    auto v = g->inputs()[i];
    const auto& t = v->type();
    mobile::nnc::InputSpec spec;
    TORCH_CHECK(t->kind() == TypeKind::TensorType, "Unsupported input type");
    const auto& tt = t->cast<TensorType>();
    spec.sizes_ = {};
    auto sizes_vec = *tt->sizes().sizes();
    for (auto s : sizes_vec) {
      spec.sizes_.push_back(s ? *s : 0);
    }
    spec.dtype_ = *tt->scalarType();
    specs.emplace_back(std::move(spec));
  }
  return specs;
}

// Locate symbolic shapes in shapes of the inputs.
//
// For each symbolic shape we're trying to find the input from which it can be
// extracted and the dimension index in that input.
// For instance, if we have
// graph(%x : Float(SS(-1), 10), %y : Long(20, SS(-2)), %ss_1 : int, %ss_2 : int)
// then we would need to find locations of two symbolic shapes: SS(-1) and
// SS(-2). The first one corresponds to the first dimension of the first input,
// the second one corresponds to the second dimension of the second input,
// so we will return {{0, 0}, {1, 1}}.
//
// If a symbolic shape cannot be found among dimensions of inputs, we
// will throw an error (this situation is possible when symbolic shape
// corresponds to the size of an intermediate - we don't support this
// case here yet).
//
// If a symbolic shape can be found in several different positions, we
// return the first one we find (TODO: maybe we should return all and
// verify that they all match at runtime).
static std::vector<SymbolicShapePosition> findSymbolicShapePositions(
    std::shared_ptr<tensorexpr::TensorExprKernel> kernel) {
  std::vector<SymbolicShapePosition> res;
  for (int64_t sym_idx : kernel->getSymbolicShapeInputs()) {
    bool found = false;
    for (int64_t input_idx : c10::irange(kernel->graph()->inputs().size())) {
      auto input = kernel->graph()->inputs()[input_idx];

      if (!input->type()->cast<TensorType>()) {
        continue;
      }
      auto tt = input->type()->expect<TensorType>();
      if (!tt->symbolic_sizes().sizes()) {
        continue;
      }
      std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
      for (int64_t dim_idx : c10::irange(shape_vec.size())) {
        if (shape_vec[dim_idx].value() == sym_idx) {
          res.emplace_back(input_idx, dim_idx);
          found = true;
          break;
        }
      }
      if (found) {
        break;
      }
    }
    TORCH_CHECK(
        found, "Could not locate a symbolic shape among input tensor shapes");
  }
  return res;
}

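// Assemble a mobile::nnc::Function for one compiled method: its name, input
// specs, constant parameters, a (currently empty) memory plan, output specs
// (including static qscale/qzero for quantized outputs), and the positions of
// symbolic shapes among the input dimensions.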
static std::unique_ptr<Function> compileMethod(
    std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
    const std::string& method_name,
    const std::vector<std::vector<int64_t>>& sizes,
    const std::vector<at::ScalarType>& types) {
  auto func = std::make_unique<Function>();
  func->set_name(method_name);
  func->set_input_specs(toInputSpecs(kernel));

  auto params = c10::impl::GenericList(c10::AnyType::get());
  auto const_descriptors = kernel->getConstantDescriptors();
  for (const auto& cd : const_descriptors) {
    auto sizes = getConstSizes(cd.buf);
    if (!cd.node) {
      // sizes.empty() needs to be handled as sizes can be empty for Scalar
      // Tensors
      at::Tensor const_tensor = !sizes.empty()
          ? at::from_blob(cd.ptr, sizes).clone()
          : at::native::wrapped_scalar_tensor(*static_cast<double*>(cd.ptr));
      params.push_back(const_tensor);
    } else {
      params.emplace_back(toIValue(cd.node->output()));
    }
  }
  func->set_parameters(params);

  MemoryPlan plan;
  plan.buffer_sizes_ = {}; // temp_sizes_;
  // TODO: implement prealloc optimization and fill in temp_sizes
  func->set_memory_plan(plan);

  int64_t n_inputs = kernel->graph()->inputs().size();
  int64_t n_outputs = kernel->graph()->outputs().size();
  std::vector<OutputSpec> out_spec;
  for (int64_t idx = n_inputs; idx < n_inputs + n_outputs; idx++) {
    const auto& ba = kernel->getBufferArgs()[idx];
    OutputSpec output;
    output.sizes_ = getConstSizes(ba.buf());
    // TODO: assert the output is a buffer and not a scalar
    output.dtype_ = ba.buf()->dtype().scalar_type();
    if (isQIntType(output.dtype_)) {
      // Supporting only static qscale/qzero
      output.qscale_ =
          to<DoubleImm>(torch::jit::tensorexpr::IRSimplifier::simplify(
                            ba.buf()->qscale()))
              ->value();
      output.qzero_ =
          to<LongImm>(
              torch::jit::tensorexpr::IRSimplifier::simplify(ba.buf()->qzero()))
              ->value();
    }
    out_spec.push_back(output);
  }
  func->set_output_specs(out_spec);
  func->set_sym_shape_positions(findSymbolicShapePositions(kernel));

  return func;
}

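// Compile one method with NNC: build a TensorExprKernel for the graph
// (marking all inputs and outputs as contiguous when symbolic dimensions are
// present) and return the resulting Function together with the generated
// LLVM assembly text.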
static std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
    const std::string& method_name,
    std::shared_ptr<Graph>& g,
    const std::vector<std::vector<int64_t>>& sizes,
    const std::vector<at::ScalarType>& types,
    const std::string& kernel_func_name,
    const std::vector<int64_t>& symbolic_ind) {
  GRAPH_DEBUG("Input sizes ", sizes);
  GRAPH_DEBUG("Input types ", types);
  GRAPH_DEBUG("Method name ", method_name);
  GRAPH_DEBUG("Kernel func name ", kernel_func_name);
  GRAPH_DEBUG("Symbolic indices ", symbolic_ind);

  std::shared_ptr<tensorexpr::TensorExprKernel> kernel;
  std::vector<torch::jit::StrideInput> stride_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  if (!symbolic_ind.empty()) {
    for (auto i : g->inputs()) {
      symbolic_strides[i] = stride_desc;
    }
    for (auto o : g->outputs()) {
      symbolic_strides[o] = stride_desc;
    }
  }
  kernel = std::make_shared<tensorexpr::TensorExprKernel>(TensorExprKernel(
      g, kernel_func_name, {}, symbolic_ind, false, symbolic_strides));

  const std::string compiled_assembly = kernel->getCodeText();
  auto func = compileMethod(kernel, method_name, sizes, types);
  return std::make_pair(std::move(func), compiled_assembly);
}

static void writeOutputLlvmAssembly(
    const std::string& asm_code,
    const std::string& output_llvm_file_name) {
  std::ofstream output(output_llvm_file_name);
  output << asm_code;
  GRAPH_DEBUG(
      "The compiled llvm assembly code was saved to ", output_llvm_file_name);
}

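// Split `string` on `separator`, dropping empty pieces unless `ignore_empty`
// is false.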
static std::vector<std::string> split(
    char separator,
    const std::string& string,
    bool ignore_empty = true) {
  std::vector<std::string> pieces;
  std::stringstream ss(string);
  std::string item;
  while (getline(ss, item, separator)) {
    if (!ignore_empty || !item.empty()) {
      pieces.push_back(std::move(item));
    }
  }
  return pieces;
}

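// Parse input shapes from a string such as "1,3,224,224;1,5": semicolons
// separate inputs, commas separate dimensions.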
static std::vector<std::vector<int64_t>> parseInputShapes(
    const std::string& input_dims_s) {
  std::vector<std::string> input_dims_list = split(';', input_dims_s);
  std::vector<std::vector<int64_t>> inputs;
  for (const auto& input_dims_item : input_dims_list) {
    auto input_dims_str = split(',', input_dims_item);
    std::vector<int64_t> input_dims;
    input_dims.reserve(input_dims_str.size());
    for (const auto& s : input_dims_str) {
      input_dims.push_back(std::stoi(s));
    }
    inputs.push_back(input_dims);
  }
  return inputs;
}

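// Parse a semicolon-separated list of input dtypes; only "float", "uint8" and
// "int64" are supported.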
static std::vector<at::ScalarType> parseInputTypes(
    const std::string& input_types_str) {
  std::vector<std::string> inputTypes = split(';', input_types_str);
  std::vector<at::ScalarType> scalarTypes;
  for (const auto& inputType : inputTypes) {
    at::ScalarType scalarType;
    if (inputType == "float") {
      scalarType = at::ScalarType::Float;
    } else if (inputType == "uint8") {
      scalarType = at::ScalarType::Byte;
    } else if (inputType == "int64") {
      scalarType = at::ScalarType::Long;
    } else {
      CAFFE_THROW("Unsupported input type: ", inputType);
    }
    scalarTypes.push_back(scalarType);
  }
  return scalarTypes;
}

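// Parse a semicolon-separated list of memory formats; only "contiguous" and
// "channels_last" are supported.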
static std::vector<at::MemoryFormat> parseInputMemoryFormats(
    const std::string& input_memory_format_str) {
  std::vector<std::string> memFormatsStr = split(';', input_memory_format_str);
  std::vector<at::MemoryFormat> memFormats;
  for (const auto& memFormatStr : memFormatsStr) {
    at::MemoryFormat memFormat;
    if (memFormatStr == "contiguous") {
      memFormat = at::MemoryFormat::Contiguous;
    } else if (memFormatStr == "channels_last") {
      memFormat = at::MemoryFormat::ChannelsLast;
    } else {
      CAFFE_THROW("Unsupported memory format: ", memFormatStr);
    }
    memFormats.push_back(memFormat);
  }
  return memFormats;
}

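// Parse a comma-separated list of dimension values that should be treated as
// dynamic (i.e. made symbolic during preprocessing).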
static std::vector<int64_t> parseInputDynamicShapes(
    const std::string& dynamic_dims_s) {
  std::vector<std::string> dynamic_dims_list = split(',', dynamic_dims_s);
  std::vector<int64_t> dynamic_dims;
  dynamic_dims.reserve(dynamic_dims_list.size());
  for (const auto& dim : dynamic_dims_list) {
    dynamic_dims.push_back(std::stoi(dim));
  }
  return dynamic_dims;
}

static std::string getNncKernelId(
    const std::string& model_name,
    const std::string& model_version,
    const std::string& method_name) {
  // TODO: calculate the version_token.
  const std::string version_token = "VERTOKEN";
  return model_name + ":" + model_version + ":" + method_name + ":" +
      version_token;
}

static std::string getNncKernelFuncName(
    const std::string& model_name,
    const std::string& model_version,
    const std::string& method_name) {
  return "nnc_" + model_name + "_" + model_version + "_" + method_name;
}

// Preprocesses the graph and returns the processed graph along with the
// symbolic values if dynamic input shapes are specified.
static std::pair<std::shared_ptr<Graph>, std::vector<int64_t>>
preprocessGraphPasses(
    std::shared_ptr<Graph>& graph,
    const std::vector<std::optional<at::Tensor>>& example_inputs,
    const std::vector<int64_t>& dynamic_sizes) {
  GRAPH_DEBUG("Before preprocessing graph passes: ", *graph);
  torch::jit::RemoveTensorMutation(graph);
  torch::jit::EliminateDeadCode(graph->block());
  graph = torch::jit::tensorexpr::removeUnusedSelfArgument(graph);

  torch::jit::tensorexpr::annotateInputShapes(graph, example_inputs);
  torch::jit::OptimizeFrozenGraph(graph, true);
  torch::jit::PropagateShapesOnGraph(graph);
  torch::jit::PeepholeOptimize(graph, false);
  torch::jit::ConstantPropagation(graph);
  torch::jit::PropagateShapesOnGraph(graph);
  torch::jit::PeepholeOptimize(graph, false);
  torch::jit::ConstantPropagation(graph);

  tensorexpr::removeUnusedSelfArgument(graph);

  std::vector<at::IValue> example_values;
  example_values.reserve(example_inputs.size());
  for (auto example_input : example_inputs) {
    example_values.emplace_back(*example_input);
  }
  graph = TraceGraph(graph, example_values);
  // TODO: Remove annotateInputShapes pass when TraceGraph can also capture
  // input shapes
  tensorexpr::annotateInputShapes(graph, example_inputs);

  RemoveListMutation(graph);
  RemoveTensorMutation(graph);
  EliminateDeadCode(graph);
  LowerAllTuples(graph);

  auto sym_val =
      torch::jit::tensorexpr::makeShapesSymbolic(graph, dynamic_sizes);

  GRAPH_DEBUG("After preprocessing graph passes: ", *graph);
  return std::make_pair(graph, sym_val);
}

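// Create random example tensors with the requested shapes, dtypes and memory
// formats; they are used to annotate input shapes and to trace the graph.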
static std::vector<std::optional<at::Tensor>> generateExampleInputs(
    const std::vector<std::vector<int64_t>>& inputShapes,
    const std::vector<at::ScalarType>& inputTypes,
    const std::vector<at::MemoryFormat>& inputMemoryFormats) {
  std::vector<std::optional<at::Tensor>> example_inputs;
  example_inputs.reserve(inputShapes.size());
  for (const auto i : c10::irange(inputShapes.size())) {
    const auto dtype = at::dtype(inputTypes[i]);
    const auto memory_format = inputMemoryFormats[i];
    example_inputs.emplace_back(
        at::rand(inputShapes[i]).to(dtype).contiguous(memory_format));
  }
  return example_inputs;
}

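// Backend-preprocess entry point: for each method in the compile spec, run the
// preprocessing passes, AOT-compile the graph with NNC, write the generated
// LLVM assembly to the requested file, and register the resulting function in
// a CompilationUnit, whose serialized form is returned.
//
// A purely illustrative (hypothetical) compile spec for a single method,
// matching the keys read below, might look like:
//   {"forward": {"model_name": "mnet", "model_version": "v1",
//                "asmfile": "fwd.s", "sizes": "1,3,224,224",
//                "types": "float", "dynamic_sizes": "",
//                "memory_formats": "contiguous"}}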
static c10::IValue preprocess(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& compile_spec,
    const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
  torch::jit::mobile::nnc::CompilationUnit cu;
  for (const auto& kv : compile_spec) {
    GRAPH_DEBUG("Key: ", kv.key());
    GRAPH_DEBUG("Value: ", kv.value());
    std::string method_name = *(kv.key().toString());
    GRAPH_DEBUG("Method name: ", method_name);
    auto method_spec = kv.value().toGenericDict();
    std::string model_name = *method_spec.at("model_name").toString();
    std::string model_version = *method_spec.at("model_version").toString();
    std::string asmfile_name = *method_spec.at("asmfile").toString();
    GRAPH_DEBUG("Model name: ", model_name);
    GRAPH_DEBUG("Model version: ", model_version);
    GRAPH_DEBUG("Asm file name: ", asmfile_name);

    auto method = mod.get_method(method_name);
    auto graph = toGraphFunction(method.function()).graph()->copy();

    auto sizes = parseInputShapes(*method_spec.at("sizes").toString());
    auto types = parseInputTypes(*method_spec.at("types").toString());
    auto dynamic_sizes =
        parseInputDynamicShapes(*method_spec.at("dynamic_sizes").toString());

    std::string memory_formats_str = method_spec.contains("memory_formats")
        ? (*method_spec.at("memory_formats").toString()).string()
        : "";
    auto memory_formats = memory_formats_str.empty()
        ? std::vector<at::MemoryFormat>(
              sizes.size(), at::MemoryFormat::Contiguous)
        : parseInputMemoryFormats(memory_formats_str);

    auto example_inputs = generateExampleInputs(sizes, types, memory_formats);
    auto preprocessed =
        preprocessGraphPasses(graph, example_inputs, dynamic_sizes);

    auto kernel_func_name =
        getNncKernelFuncName(model_name, model_version, method_name);
    auto processed_graph = preprocessed.first;
    auto sym_values = preprocessed.second;
    auto compiled = torch::jit::mobile::nnc::aotCompile(
        method_name,
        processed_graph,
        sizes,
        types,
        kernel_func_name,
        sym_values);
    writeOutputLlvmAssembly(compiled.second, asmfile_name);
    auto func = std::move(compiled.first);
    func->set_nnc_kernel_id(
        getNncKernelId(model_name, model_version, method_name));
    cu.register_function(std::move(func));
  }
  return cu.serialize();
}
*/

// static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess);

} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch