#include <ATen/Functions.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/dispatch/ObservedOperators.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
#include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
#include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
#include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/script.h>

namespace torch::jit::mobile {

// Fetched from caffe2/aten/src/ATen/native/metal/MetalAten.mm
// Diffusion Link: https://fburl.com/diffusion/atwwmax2
const std::vector<std::string> gpu_metal_operators = {
    "aten::conv2d",
    "aten::add.Tensor",
    "aten::add_.Tensor",
    "aten::addmm",
    "aten::empty.memory_format",
    "aten::empty_strided",
    "aten::log_softmax.int",
    "aten::max_pool2d",
    "aten::mul.Tensor",
    "aten::relu",
    "aten::relu_",
    "aten::sigmoid",
    "aten::sub.Tensor",
    "aten::upsample_nearest2d.vec",
    "aten::view",
    "aten::adaptive_avg_pool2d",
    "aten::hardtanh_",
    "aten::reshape",
    "aten::flatten.using_ints",
};

/**
 * This is a collection of common ATen methods that are usually called
 * outside of the model's forward() run. They need to be traced to ensure
 * that the operators they use are included in the build. If/when this list
 * becomes too long, we can consider making it a per-model list.
 */
static void call_setup_methods() {
  at::zeros({2, 2});
  at::ones({2, 2});
  at::Tensor t1 = at::empty({7, 7});
  at::Tensor t2 = t1.fill_(3);
  // TODO: investigate how this is different from normal empty_strided
  at::Tensor t3 = t1.new_empty_strided({2, 3}, {3, 1});
  at::narrow(t2, 1, 0, 1);
  at::eq(t1, t2);
  const volatile bool nz = at::native::is_nonzero(at::zeros({1}));
  (void)nz;

  // Create a byte tensor and copy it
  auto zb = at::zeros({10}, at::kByte);
  auto zf = at::zeros({10}, at::kFloat);
  zb.copy_(zf);
  t2.div(1);

  // Typically, failures show up in CopyKernel.cpp, so we enumerate the
  // common dtypes that may show up there.
  const auto all_dtypes_for_copy = {
      at::kBool,
      at::kByte,
      at::kFloat,
      at::kInt,
      at::kChar,
      at::kDouble,
      at::kShort,
      at::kLong};
  for (const auto dtype : all_dtypes_for_copy) {
    auto tensor1 = at::empty({10}, dtype);
    tensor1.copy_(at::zeros({10}, at::kBool));
    tensor1.copy_(at::zeros({10}, at::kFloat));
    tensor1.copy_(at::zeros({10}, at::kInt));
  }

  torch::zeros({0, 0}, torch::ScalarType::Float);
  std::vector<float> storage(20, 1.0);
  std::vector<int64_t> sizes({2, 10});
  torch::from_blob(storage.data(), at::IntArrayRef(sizes), at::kFloat);
}

/**
 * Similar to the setup methods, there is a suite of functions that often
 * appear under certain conditions but may avoid getting called in the trace
 * due to the narrow nature of bundled inputs.
 */
static void call_dependent_methods(std::set<std::string>& root_ops) {
  bool is_training = false;
  bool has_batchnorm = false;
  bool has_dropout = false;
  for (const std::string& op : root_ops) {
    if (op.find("backward") != std::string::npos ||
        op.find("requires_grad_") != std::string::npos) {
      is_training = true;
    }
    if (op.find("batch_norm") != std::string::npos) {
      has_batchnorm = true;
    }
    if (op.find("dropout") != std::string::npos) {
      has_dropout = true;
    }
  }
  if (is_training && has_batchnorm) {
    at::batch_norm(
        at::ones({2, 2}),
        std::nullopt,
        std::nullopt,
        std::nullopt,
        std::nullopt,
        true,
        0.1,
        0.1,
        false);
  }
  if (is_training && has_dropout) {
    at::dropout(at::ones({20, 20, 20}), 0.2, true);
  }
}

/**
 * Call methods on the Tensor object that we expect to be called
 * in production on this Tensor.
 */
static void consume_tensor(const at::Tensor& t) {
  const at::Tensor& c = t;
  c.copy_(t.cpu());
}

static std::unordered_map<std::string, c10::FunctionSchema>
_get_runtime_ops_and_schema() {
  std::unordered_map<std::string, c10::FunctionSchema> result;

  // Grab the jit operators
  auto nonDispatcherOperators = torch::jit::getAllOperators();
  for (const auto& full_op : nonDispatcherOperators) {
    auto op = full_op->schema();
    auto op_name = op.name();
    if (!op.overload_name().empty()) {
      op_name += ("." + op.overload_name());
    }
    result.emplace(op_name, op);
  }

  // Grab the dispatcher operators
  auto dispatcherOperators = c10::Dispatcher::singleton().getAllOpNames();
  for (auto& op : dispatcherOperators) {
    // grab schema
    const auto op_handle = c10::Dispatcher::singleton().findOp(op);
    if (op_handle->hasSchema()) {
      auto op_name = op.name;
      if (!op.overload_name.empty()) {
        op_name += ("." + op.overload_name);
      }
      result.emplace(op_name, op_handle->schema());
    }
  }

  return result;
}
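
// Note on the keys built above: they follow the "name.overload_name"
// convention when an overload name is present, e.g. "aten::relu" vs.
// "aten::add.Tensor" (both appear in gpu_metal_operators above). This matches
// the naming used for the root/traced ops that
// recordCustomClassesFromOpSchemas() looks up below.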

/**
 * For the vast majority of use cases, the instrumentation in getCustomClass
 * will catch any custom classes referenced by a model. There are, however,
 * niche situations that avoid the getCustomClass instrumentation due to some
 * nuances of mobile model deserialization. To get around that, we can search
 * through all the used ops and inspect their schemas for any referenced
 * classes.
 * Example schema: prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None,
 *   Scalar? output_min=None, Scalar? output_max=None) ->
 *   __torch__.torch.classes.xnnpack.LinearOpContext
 */
static void recordCustomClassesFromOpSchemas(
    std::set<std::string>& root_ops,
    std::set<std::string>& traced_ops,
    std::set<std::string>& loaded_classes) {
  std::set<std::string> ops;
  ops.insert(root_ops.begin(), root_ops.end());
  ops.insert(traced_ops.begin(), traced_ops.end());
  auto ops_and_schemas = _get_runtime_ops_and_schema();

  auto record_if_class = [&](const std::string& type_name) {
    // All custom class types start with __torch__ (not sure if this is by
    // chance or guaranteed).
    if (type_name.find("__torch__") != std::string::npos) {
      // The name of a customClassType here is its fully qualified name, but
      // registration uses only the class name, so only record that.
      auto class_name = type_name.substr(type_name.find_last_of('.') + 1);
      // Function schemas can include other type indicators such as [], so we
      // also need to trim to just alphanumeric + '_' characters.
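      // Illustrative walk-through, using the example schema from the doc
      // comment above: a type_name of
      // "__torch__.torch.classes.xnnpack.LinearOpContext" is recorded as the
      // class_name "LinearOpContext".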
      class_name = class_name.substr(
          0,
          class_name.find_first_not_of(
              "aAbBcCdDeEfFgGhHiIjJkKlLmMnNoOpPqQrRsStTuUvVwWxXyYzZ_1234567890"));
      loaded_classes.insert(class_name);
    }
  };

  for (auto& op_name : ops) {
    // This check is only necessary because of GPU models.
    // Certain models can only run on a specific backend, say Metal.
    // Those ops will be present in the model's root ops, but likely
    // not in the tracer on Linux.
    if (ops_and_schemas.find(op_name) != ops_and_schemas.end()) {
      auto& schema = ops_and_schemas.at(op_name);
      for (auto& arg : schema.arguments()) {
        record_if_class(arg.type()->annotation_str());
      }
      for (auto& ret : schema.returns()) {
        record_if_class(ret.type()->annotation_str());
      }
    }
  }
}

static void run_model(
    const std::string& input_module_path,
    std::set<std::string>& root_ops,
    std::set<std::string>& enabled_backends,
    KernelDTypeTracer::kernel_tags_type& called_kernel_tags) {
  // Load the module on CPU with the flag to skip the operator exists check.
  // This is needed so that we can load any TorchBind objects (custom classes)
  // that this model refers to so that any operators being called from those
  // TorchBind objects can be traced by the model tracer.
  torch::jit::mobile::MobileModelRunner module_runner(input_module_path, 0);
  root_ops = module_runner.get_root_operators();
  std::cout << "Got " << root_ops.size() << " Root Operators." << '\n';

  if (torch::jit::mobile::MobileModelRunner::set_has_metal_gpu_operators(
          root_ops)) {
    std::cout << "Inferred Metal GPU Model." << '\n';
    root_ops.insert(gpu_metal_operators.begin(), gpu_metal_operators.end());
    called_kernel_tags["__unused__"] = {"Float"};
    enabled_backends.insert("Metal GPU");

    // When we encounter a GPU model, we should call .cpu().copy_() on the
    // tensors in the bundled inputs, since this is what will happen when
    // such a model is executed on an iOS device (to copy the Tensor to Metal
    // memory via a call to .metal()).
    module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);
  } else {
    std::cout << "Inferred CPU Model." << '\n';
    enabled_backends.insert("CPU");
    torch::jit::mobile::MobileModelRunner mobile_module_runner(
        input_module_path);

    // When we encounter a CPU model, we should call .cpu().copy_() on the
    // tensors in the bundled inputs, since this is what will happen when
    // such a model is executed on an Android device, where the PyTorch JNI
    // bindings call .cpu() in JIValue::newJIValueFromAtIValue().
    module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);

    // If a user bundled inputs after the API was updated to accept bundled
    // inputs for multiple methods, they should go down this route. Even if
    // they only bundled inputs for forward(), they will have the new-style
    // bundled inputs. Since at this point in tracer.cpp we do not know which
    // functions have bundled inputs, we must call
    // get_bundled_inputs_functions_and_info, if it exists, to get the set.
    if (mobile_module_runner.has_new_style_bundled_inputs()) {
      auto bundled_inputs_mapping =
          mobile_module_runner.get_many_functions_bundled_inputs();
      for (auto& entry : bundled_inputs_mapping) {
        std::string function_name = entry.first;
        std::vector<std::vector<at::IValue>> bundled_inputs = entry.second;
        std::cout << "Got " << bundled_inputs.size() << " bundled input(s) for "
                  << function_name << "\n\n";
        std::vector<at::IValue> results =
            mobile_module_runner.run_with_inputs(function_name, bundled_inputs);

        for (auto& result : results) {
          // Consume the result Tensor(s) when tracing on CPU since the
          // Android/Java JNI bindings will do the same.
          torch::jit::mobile::for_each_tensor_in_ivalue(result, consume_tensor);
        }
      }
      // If get_bundled_inputs_functions_and_info does not exist, we default
      // to assuming the inputs were bundled before that change was made. If
      // no bundled inputs are found here either, an error will be thrown.
    } else {
      std::vector<std::vector<at::IValue>> bundled_inputs =
          mobile_module_runner.get_all_bundled_inputs();
      std::cout << "Got " << bundled_inputs.size() << " bundled input(s)\n\n";
      std::vector<at::IValue> results =
          mobile_module_runner.run_with_inputs(bundled_inputs);

      for (auto& result : results) {
        // Consume the result Tensor(s) when tracing on CPU since the
        // Android/Java JNI bindings will do the same.
        torch::jit::mobile::for_each_tensor_in_ivalue(result, consume_tensor);
      }
    }
  }
}

TracerResult trace_run(const std::string& input_module_path) {
  return trace_run(std::vector<std::string>(1, input_module_path));
}

TracerResult trace_run(const std::vector<std::string>& input_module_paths) {
  at::globalContext().setQEngine(at::QEngine::QNNPACK);
  c10::ObservedOperators::getUnobservedOperatorList().clear();

  torch::jit::mobile::OperatorCallTracer op_tracer;
  torch::jit::mobile::KernelDTypeTracer kdtype_tracer;
  torch::jit::mobile::CustomClassTracer custom_class_tracer;
  torch::jit::mobile::BuildFeatureTracer build_feature_tracer;

  call_setup_methods();

  std::set<std::string> root_ops, traced_operators, enabled_backends,
      loaded_classes, build_features;
  torch::jit::mobile::KernelDTypeTracer::kernel_tags_type called_kernel_tags;

  using torch::jit::MobileModuleLoadOptions;

  for (auto& input_module_path : input_module_paths) {
    // run with QNNPACK
    at::globalContext().setQEngine(at::QEngine::QNNPACK);

    run_model(
        input_module_path, root_ops, enabled_backends, called_kernel_tags);
    // Not every model can be successfully run with FBGEMM, but for those
    // that can, this helps broaden the tracer's coverage beyond the
    // hyper-optimized QNNPACK paths.
    try {
      at::globalContext().setQEngine(at::QEngine::FBGEMM);
      run_model(
          input_module_path, root_ops, enabled_backends, called_kernel_tags);
    } catch (std::exception& ex) {
      std::cerr
          << "ModelTracer encountered an error while attempting to run the model in FBGEMM mode: "
          << ex.what() << "\nSkipping FBGEMM execution" << '\n';
    }
    try {
      at::globalContext().setQEngine(at::QEngine::QNNPACK);
      c10::InferenceMode guard(true);
      run_model(
          input_module_path, root_ops, enabled_backends, called_kernel_tags);
    } catch (std::exception& ex) {
      std::cerr
          << "ModelTracer encountered an error while attempting to run the model under an inference guard: "
          << ex.what() << "\nSkipping inference guard execution" << '\n';
    }
  }

  call_dependent_methods(root_ops);

  op_tracer.getCalledOperators().withLock(
      [&](std::set<std::string>& called_operators) {
        traced_operators = called_operators;
      });

  recordCustomClassesFromOpSchemas(root_ops, traced_operators, loaded_classes);

  kdtype_tracer.getCalledKernelTags().withLock(
      [&](KernelDTypeTracer::kernel_tags_type& kernel_tags) {
        called_kernel_tags.insert(kernel_tags.begin(), kernel_tags.end());
      });

  traced_operators.insert(
      always_included_traced_ops.begin(), always_included_traced_ops.end());

  custom_class_tracer.getLoadedClasses().withLock(
      [&](CustomClassTracer::custom_classes_type& custom_classes) {
        loaded_classes.insert(custom_classes.begin(), custom_classes.end());
      });

  build_feature_tracer.getBuildFeatures().withLock(
      [&](BuildFeatureTracer::build_feature_type& bf) {
        build_features.insert(bf.begin(), bf.end());
      });

  TracerResult tracer_result = {
      root_ops,
      traced_operators,
      called_kernel_tags,
      loaded_classes,
      build_features,
      enabled_backends};

  return tracer_result;
}
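
// Minimal usage sketch (illustrative only; the actual caller is the model
// tracer binary, see the tracer.cpp referenced in the comments above, and the
// module path here is hypothetical):
//
//   TracerResult result = torch::jit::mobile::trace_run(
//       std::vector<std::string>{"/path/to/model_with_bundled_inputs.ptl"});
//
// `result` carries the root ops, traced operators, kernel dtype tags, loaded
// custom classes, build features, and enabled backends gathered above.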

} // namespace torch::jit::mobile