#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>

#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <fstream>

using namespace torch::jit;
using namespace torch::jit::tensorexpr;

namespace torch {
namespace jit {
namespace mobile {
namespace nnc {

// TODO(mvz): temporarily disable NNC backend in mobile builds.
/*
static std::vector<int64_t> getConstSizes(const BufPtr b) {
  std::vector<int64_t> r;
  for (const auto& dim : b->dims()) {
    LongImmPtr imm_dim = to<LongImm>(dim);
    // TODO: assert it's actually immediate
    int64_t s = imm_dim->value();
    r.push_back(s);
  }
  return r;
}

// Construct input-specs vector from the inputs of the original graph
static std::vector<mobile::nnc::InputSpec> toInputSpecs(
    const std::shared_ptr<tensorexpr::TensorExprKernel>& kernel) {
  const std::shared_ptr<Graph>& g = kernel->graph();
  std::vector<mobile::nnc::InputSpec> specs;

  // Graph inputs include scalar values for symbolic shapes, for which we
  // don't need input specs. These scalar values come last among the graph
  // inputs
  auto num_inputs =
      g->inputs().size() - kernel->getSymbolicShapeInputs().size();

  for (const auto i : c10::irange(num_inputs)) {
    auto v = g->inputs()[i];
    const auto& t = v->type();
    mobile::nnc::InputSpec spec;
    TORCH_CHECK(t->kind() == TypeKind::TensorType, "Unsupported input type");
    const auto& tt = t->cast<TensorType>();
    spec.sizes_ = {};
    auto sizes_vec = *tt->sizes().sizes();
    for (auto s : sizes_vec) {
      spec.sizes_.push_back(s ? *s : 0);
    }
    spec.dtype_ = *tt->scalarType();
    specs.emplace_back(std::move(spec));
  }
  return specs;
}
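// Illustrative note (hypothetical shapes): for a graph input typed as
// %x : Float(1, 3, 224, 224), the InputSpec built above would record
// sizes_ = {1, 3, 224, 224} and dtype_ = at::ScalarType::Float; a dimension
// whose static size is unknown (e.g. a symbolic dimension) is stored as 0.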
// Locate symbolic shapes in shapes of the inputs.
//
// For each symbolic shape we're trying to find the input from which it can be
// extracted and the dimension index in that input.
// For instance, if we have
// graph(%x : Float(SS(-1), 10), %y : Long(20, SS(-2)), %ss_1 : int, %ss_2 : int)
// then we would need to find locations of two symbolic shapes: SS(-1) and
// SS(-2). The first one corresponds to the first dimension of the first input,
// the second one corresponds to the second dimension of the second input,
// so we will return {{0, 0}, {1, 1}}.
//
// If a symbolic shape cannot be found among dimensions of inputs, we
// will throw an error (this situation is possible when a symbolic shape
// corresponds to the size of an intermediate - we don't support this
// case here yet).
//
// If a symbolic shape can be found in several different positions, we
// return the first one we find (TODO: maybe we should return all and
// verify that they all match at runtime).
static std::vector<SymbolicShapePosition> findSymbolicShapePositions(
    std::shared_ptr<tensorexpr::TensorExprKernel> kernel) {
  std::vector<SymbolicShapePosition> res;
  for (int64_t sym_idx : kernel->getSymbolicShapeInputs()) {
    bool found = false;
    for (int64_t input_idx : c10::irange(kernel->graph()->inputs().size())) {
      auto input = kernel->graph()->inputs()[input_idx];

      if (!input->type()->cast<TensorType>()) {
        continue;
      }
      auto tt = input->type()->expect<TensorType>();
      if (!tt->symbolic_sizes().sizes()) {
        continue;
      }
      std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
      for (int64_t dim_idx : c10::irange(shape_vec.size())) {
        if (shape_vec[dim_idx].value() == sym_idx) {
          res.emplace_back(input_idx, dim_idx);
          found = true;
          break;
        }
      }
      if (found) {
        break;
      }
    }
    TORCH_CHECK(
        found, "Could not locate a symbolic shape among input tensor shapes");
  }
  return res;
}
static std::unique_ptr<Function> compileMethod(
    std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
    const std::string& method_name,
    const std::vector<std::vector<int64_t>>& sizes,
    const std::vector<at::ScalarType>& types) {
  auto func = std::make_unique<Function>();
  func->set_name(method_name);
  func->set_input_specs(toInputSpecs(kernel));

  auto params = c10::impl::GenericList(c10::AnyType::get());
  auto const_descriptors = kernel->getConstantDescriptors();
  for (const auto& cd : const_descriptors) {
    auto sizes = getConstSizes(cd.buf);
    if (!cd.node) {
      // sizes.empty() needs to be handled as sizes can be empty for Scalar
      // Tensors
      at::Tensor const_tensor = !sizes.empty()
          ? at::from_blob(cd.ptr, sizes).clone()
          : at::native::wrapped_scalar_tensor(*static_cast<double*>(cd.ptr));
      params.push_back(const_tensor);
    } else {
      params.emplace_back(toIValue(cd.node->output()));
    }
  }
  func->set_parameters(params);

  MemoryPlan plan;
  plan.buffer_sizes_ = {}; // temp_sizes_;
  // TODO: implement prealloc optimization and fill in temp_sizes
  func->set_memory_plan(plan);

  int64_t n_inputs = kernel->graph()->inputs().size();
  int64_t n_outputs = kernel->graph()->outputs().size();
  std::vector<OutputSpec> out_spec;
  for (int64_t idx = n_inputs; idx < n_inputs + n_outputs; idx++) {
    const auto& ba = kernel->getBufferArgs()[idx];
    OutputSpec output;
    output.sizes_ = getConstSizes(ba.buf());
    // TODO: assert the output is a buffer and not a scalar
    output.dtype_ = ba.buf()->dtype().scalar_type();
    if (isQIntType(output.dtype_)) {
      // Supporting only static qscale/qzero
      output.qscale_ =
          to<DoubleImm>(torch::jit::tensorexpr::IRSimplifier::simplify(
                            ba.buf()->qscale()))
              ->value();
      output.qzero_ =
          to<LongImm>(torch::jit::tensorexpr::IRSimplifier::simplify(
                          ba.buf()->qzero()))
              ->value();
    }
    out_spec.push_back(output);
  }
  func->set_output_specs(out_spec);
  func->set_sym_shape_positions(findSymbolicShapePositions(kernel));

  return func;
}

static std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
    const std::string& method_name,
    std::shared_ptr<Graph>& g,
    const std::vector<std::vector<int64_t>>& sizes,
    const std::vector<at::ScalarType>& types,
    const std::string& kernel_func_name,
    const std::vector<int64_t>& symbolic_ind) {
  GRAPH_DEBUG("Input sizes ", sizes);
  GRAPH_DEBUG("Input types ", types);
  GRAPH_DEBUG("Method name ", method_name);
  GRAPH_DEBUG("Kernel func name ", kernel_func_name);
  GRAPH_DEBUG("Symbolic indices ", symbolic_ind);

  std::shared_ptr<tensorexpr::TensorExprKernel> kernel;
  std::vector<torch::jit::StrideInput> stride_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  if (!symbolic_ind.empty()) {
    for (auto i : g->inputs()) {
      symbolic_strides[i] = stride_desc;
    }
    for (auto o : g->outputs()) {
      symbolic_strides[o] = stride_desc;
    }
  }
  kernel = std::make_shared<tensorexpr::TensorExprKernel>(TensorExprKernel(
      g, kernel_func_name, {}, symbolic_ind, false, symbolic_strides));

  const std::string compiled_assembly = kernel->getCodeText();
  auto func = compileMethod(kernel, method_name, sizes, types);
  return std::make_pair(std::move(func), compiled_assembly);
}

static void writeOutputLlvmAssembly(
    const std::string& asm_code,
    const std::string& output_llvm_file_name) {
  std::ofstream output(output_llvm_file_name);
  output << asm_code;
  GRAPH_DEBUG(
      "The compiled llvm assembly code was saved to ", output_llvm_file_name);
}

static std::vector<std::string> split(
    char separator,
    const std::string& string,
    bool ignore_empty = true) {
  std::vector<std::string> pieces;
  std::stringstream ss(string);
  std::string item;
  while (getline(ss, item, separator)) {
    if (!ignore_empty || !item.empty()) {
      pieces.push_back(std::move(item));
    }
  }
  return pieces;
}
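// A minimal usage sketch for split() (illustrative values only):
//   split(';', "1,2,3;4,5") -> {"1,2,3", "4,5"}
//   split(',', "1,2,3")     -> {"1", "2", "3"}
// Empty pieces are dropped unless ignore_empty is set to false.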
static std::vector<std::vector<int64_t>> parseInputShapes(
    const std::string& input_dims_s) {
  std::vector<std::string> input_dims_list = split(';', input_dims_s);
  std::vector<std::vector<int64_t>> inputs;
  for (const auto& input_dims_item : input_dims_list) {
    auto input_dims_str = split(',', input_dims_item);
    std::vector<int64_t> input_dims;
    input_dims.reserve(input_dims_str.size());
    for (const auto& s : input_dims_str) {
      input_dims.push_back(std::stoi(s));
    }
    inputs.push_back(input_dims);
  }
  return inputs;
}

static std::vector<at::ScalarType> parseInputTypes(
    const std::string& input_types_str) {
  std::vector<std::string> inputTypes = split(';', input_types_str);
  std::vector<at::ScalarType> scalarTypes;
  for (const auto& inputType : inputTypes) {
    at::ScalarType scalarType;
    if (inputType == "float") {
      scalarType = at::ScalarType::Float;
    } else if (inputType == "uint8") {
      scalarType = at::ScalarType::Byte;
    } else if (inputType == "int64") {
      scalarType = at::ScalarType::Long;
    } else {
      CAFFE_THROW("Unsupported input type: ", inputType);
    }
    scalarTypes.push_back(scalarType);
  }
  return scalarTypes;
}

static std::vector<at::MemoryFormat> parseInputMemoryFormats(
    const std::string& input_memory_format_str) {
  std::vector<std::string> memFormatsStr = split(';', input_memory_format_str);
  std::vector<at::MemoryFormat> memFormats;
  for (const auto& memFormatStr : memFormatsStr) {
    at::MemoryFormat memFormat;
    if (memFormatStr == "contiguous") {
      memFormat = at::MemoryFormat::Contiguous;
    } else if (memFormatStr == "channels_last") {
      memFormat = at::MemoryFormat::ChannelsLast;
    } else {
      CAFFE_THROW("Unsupported memory format: ", memFormatStr);
    }
    memFormats.push_back(memFormat);
  }
  return memFormats;
}

static std::vector<int64_t> parseInputDynamicShapes(
    const std::string& dynamic_dims_s) {
  std::vector<std::string> dynamic_dims_list = split(',', dynamic_dims_s);
  std::vector<int64_t> dynamic_dims;
  dynamic_dims.reserve(dynamic_dims_list.size());
  for (const auto& dim : dynamic_dims_list) {
    dynamic_dims.push_back(std::stoi(dim));
  }
  return dynamic_dims;
}
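// Illustrative examples of the string formats consumed by the parsers above
// (all concrete values are hypothetical):
//   parseInputShapes("1,3,224,224;1,10")  -> {{1, 3, 224, 224}, {1, 10}}
//   parseInputTypes("float;uint8;int64")  -> {Float, Byte, Long}
//   parseInputMemoryFormats("contiguous;channels_last")
//       -> {Contiguous, ChannelsLast}
//   parseInputDynamicShapes("224,224")    -> {224, 224}
// For shapes, types and memory formats the per-input entries are separated
// by ';' and the dimensions of a single input by ','; the dynamic-shape list
// is a single ','-separated list.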
static std::string getNncKernelId(
    const std::string& model_name,
    const std::string& model_version,
    const std::string& method_name) {
  // TODO: calculate the version_token.
  const std::string version_token = "VERTOKEN";
  return model_name + ":" + model_version + ":" + method_name + ":" +
      version_token;
}

static std::string getNncKernelFuncName(
    const std::string& model_name,
    const std::string& model_version,
    const std::string& method_name) {
  return "nnc_" + model_name + "_" + model_version + "_" + method_name;
}

// Preprocesses the graph and returns the processed graph and
// symbolic values if dynamic input shapes are specified
static std::pair<std::shared_ptr<Graph>, std::vector<int64_t>>
preprocessGraphPasses(
    std::shared_ptr<Graph>& graph,
    const std::vector<std::optional<at::Tensor>>& example_inputs,
    const std::vector<int64_t>& dynamic_sizes) {
  GRAPH_DEBUG("Before preprocessing graph passes: ", *graph);
  torch::jit::RemoveTensorMutation(graph);
  torch::jit::EliminateDeadCode(graph->block());
  graph = torch::jit::tensorexpr::removeUnusedSelfArgument(graph);

  torch::jit::tensorexpr::annotateInputShapes(graph, example_inputs);
  torch::jit::OptimizeFrozenGraph(graph, true);
  torch::jit::PropagateShapesOnGraph(graph);
  torch::jit::PeepholeOptimize(graph, false);
  torch::jit::ConstantPropagation(graph);
  torch::jit::PropagateShapesOnGraph(graph);
  torch::jit::PeepholeOptimize(graph, false);
  torch::jit::ConstantPropagation(graph);

  tensorexpr::removeUnusedSelfArgument(graph);

  std::vector<at::IValue> example_values;
  example_values.reserve(example_inputs.size());
  for (auto example_input : example_inputs) {
    example_values.emplace_back(*example_input);
  }
  graph = TraceGraph(graph, example_values);
  // TODO: Remove annotateInputShapes pass when TraceGraph can also capture
  // input shapes
  tensorexpr::annotateInputShapes(graph, example_inputs);

  RemoveListMutation(graph);
  RemoveTensorMutation(graph);
  EliminateDeadCode(graph);
  LowerAllTuples(graph);

  auto sym_val =
      torch::jit::tensorexpr::makeShapesSymbolic(graph, dynamic_sizes);

  GRAPH_DEBUG("After preprocessing graph passes: ", *graph);
  return std::make_pair(graph, sym_val);
}

static std::vector<std::optional<at::Tensor>> generateExampleInputs(
    const std::vector<std::vector<int64_t>>& inputShapes,
    const std::vector<at::ScalarType>& inputTypes,
    const std::vector<at::MemoryFormat>& inputMemoryFormats) {
  std::vector<std::optional<at::Tensor>> example_inputs;
  example_inputs.reserve(inputShapes.size());
  for (const auto i : c10::irange(inputShapes.size())) {
    const auto dtype = at::dtype(inputTypes[i]);
    const auto memory_format = inputMemoryFormats[i];
    example_inputs.emplace_back(
        at::rand(inputShapes[i]).to(dtype).contiguous(memory_format));
  }
  return example_inputs;
}
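// preprocess() below expects the compile spec to map each method name to a
// dict of string options. A sketch of the layout, based on the keys read in
// that function (all concrete values are hypothetical):
//
//   {
//     "forward": {
//       "model_name": "mymodel",
//       "model_version": "v1",
//       "asmfile": "mymodel.forward.s",
//       "sizes": "1,3,224,224",
//       "types": "float",
//       "dynamic_sizes": "",             // required key; may be empty
//       "memory_formats": "contiguous",  // optional; defaults to contiguous
//     }
//   }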
static c10::IValue preprocess(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& compile_spec,
    const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
  torch::jit::mobile::nnc::CompilationUnit cu;
  for (const auto& kv : compile_spec) {
    GRAPH_DEBUG("Key: ", kv.key());
    GRAPH_DEBUG("Value: ", kv.value());
    std::string method_name = *(kv.key().toString());
    GRAPH_DEBUG("Method name: ", method_name);
    auto method_spec = kv.value().toGenericDict();
    std::string model_name = *method_spec.at("model_name").toString();
    std::string model_version = *method_spec.at("model_version").toString();
    std::string asmfile_name = *method_spec.at("asmfile").toString();
    GRAPH_DEBUG("Model name: ", model_name);
    GRAPH_DEBUG("Model version: ", model_version);
    GRAPH_DEBUG("Asm file name: ", asmfile_name);

    auto method = mod.get_method(method_name);
    auto graph = toGraphFunction(method.function()).graph()->copy();

    auto sizes = parseInputShapes(*method_spec.at("sizes").toString());
    auto types = parseInputTypes(*method_spec.at("types").toString());
    auto dynamic_sizes =
        parseInputDynamicShapes(*method_spec.at("dynamic_sizes").toString());

    std::string memory_formats_str = method_spec.contains("memory_formats")
        ? (*method_spec.at("memory_formats").toString()).string()
        : "";
    auto memory_formats = memory_formats_str.empty()
        ? std::vector<at::MemoryFormat>(
              sizes.size(), at::MemoryFormat::Contiguous)
        : parseInputMemoryFormats(memory_formats_str);

    auto example_inputs = generateExampleInputs(sizes, types, memory_formats);
    auto preprocessed =
        preprocessGraphPasses(graph, example_inputs, dynamic_sizes);

    auto kernel_func_name =
        getNncKernelFuncName(model_name, model_version, method_name);
    auto processed_graph = preprocessed.first;
    auto sym_values = preprocessed.second;
    auto compiled = torch::jit::mobile::nnc::aotCompile(
        method_name,
        processed_graph,
        sizes,
        types,
        kernel_func_name,
        sym_values);
    writeOutputLlvmAssembly(compiled.second, asmfile_name);
    auto func = std::move(compiled.first);
    func->set_nnc_kernel_id(
        getNncKernelId(model_name, model_version, method_name));
    cu.register_function(std::move(func));
  }
  return cu.serialize();
}
*/

// static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess);

} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch