#include <ATen/Functions.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/dispatch/ObservedOperators.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
#include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
#include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
#include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/script.h>

namespace torch::jit::mobile {

// Fetched from caffe2/aten/src/ATen/native/metal/MetalAten.mm
// Diffusion Link: https://fburl.com/diffusion/atwwmax2
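// When run_model() below detects a Metal GPU model, these operators are
// merged into the model's root operator set, since the Metal kernels
// themselves cannot be exercised (and therefore traced) on the host that
// runs the tracer.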
const std::vector<std::string> gpu_metal_operators = {
    "aten::conv2d",
    "aten::add.Tensor",
    "aten::add_.Tensor",
    "aten::addmm",
    "aten::empty.memory_format",
    "aten::empty_strided",
    "aten::log_softmax.int",
    "aten::max_pool2d",
    "aten::mul.Tensor",
    "aten::relu",
    "aten::relu_",
    "aten::sigmoid",
    "aten::sub.Tensor",
    "aten::upsample_nearest2d.vec",
    "aten::view",
    "aten::adaptive_avg_pool2d",
    "aten::hardtanh_",
    "aten::reshape",
    "aten::flatten.using_ints",
};

/**
 * This is a collection of common ATen methods that are usually called
 * outside of a model's forward() run. They need to be traced so that the
 * operators they use are included in the build. If/when this list becomes
 * too long, we can consider making it a per-model list.
 */
static void call_setup_methods() {
  at::zeros({2, 2});
  at::ones({2, 2});
  at::Tensor t1 = at::empty({7, 7});
  at::Tensor t2 = t1.fill_(3);
  at::Tensor t3 = t1.new_empty_strided(
      {2, 3},
      {3,
       1}); // TODO investigate how this is different from normal empty_strided
  at::narrow(t2, 1, 0, 1);
  at::eq(t1, t2);
  const volatile bool nz = at::native::is_nonzero(at::zeros({1}));
  (void)nz;

  // Create a byte tensor and copy a float tensor into it
  auto zb = at::zeros({10}, at::kByte);
  auto zf = at::zeros({10}, at::kFloat);
  zb.copy_(zf);
  t2.div(1);

  // Typically, failures show up in CopyKernel.cpp, so enumerate the common
  // dtypes that may show up there.
  const auto all_dtypes_for_copy = {
      at::kBool,
      at::kByte,
      at::kFloat,
      at::kInt,
      at::kChar,
      at::kDouble,
      at::kShort,
      at::kLong};
  for (const auto dtype : all_dtypes_for_copy) {
    auto tensor1 = at::empty({10}, dtype);
    tensor1.copy_(at::zeros({10}, at::kBool));
    tensor1.copy_(at::zeros({10}, at::kFloat));
    tensor1.copy_(at::zeros({10}, at::kInt));
  }

  torch::zeros({0, 0}, torch::ScalarType::Float);
  std::vector<float> storage(20, 1.0);
  std::vector<int64_t> sizes({2, 10});
  torch::from_blob(storage.data(), at::IntArrayRef(sizes), at::kFloat);
}

/**
 * Similar to the setup methods, there is a suite of functions that often
 * appear under certain conditions but may avoid getting called in the trace
 * due to the narrow nature of bundled inputs.
 */
static void call_dependent_methods(std::set<std::string>& root_ops) {
  bool is_training = false;
  bool has_batchnorm = false;
  bool has_dropout = false;
  for (const std::string& op : root_ops) {
    if (op.find("backward") != std::string::npos ||
        op.find("requires_grad_") != std::string::npos) {
      is_training = true;
    }
    if (op.find("batch_norm") != std::string::npos) {
      has_batchnorm = true;
    }
    if (op.find("dropout") != std::string::npos) {
      has_dropout = true;
    }
  }
  if (is_training && has_batchnorm) {
    at::batch_norm(
        at::ones({2, 2}),
        std::nullopt,
        std::nullopt,
        std::nullopt,
        std::nullopt,
        true,
        0.1,
        0.1,
        false);
  }
  if (is_training && has_dropout) {
    at::dropout(at::ones({20, 20, 20}), 0.2, true);
  }
}

/**
 * Call methods on the Tensor object that we expect to be called
 * in production on this Tensor.
 */
static void consume_tensor(const at::Tensor& t) {
  const at::Tensor& c = t;
  c.copy_(t.cpu());
}

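/**
 * Build a map from "name.overload_name" to FunctionSchema for every operator
 * known to this runtime, combining the JIT operator registry with the
 * dispatcher's registered operators.
 */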
static std::unordered_map<std::string, c10::FunctionSchema>
_get_runtime_ops_and_schema() {
  std::unordered_map<std::string, c10::FunctionSchema> result;

  // Grab the jit operators
  auto nonDispatcherOperators = torch::jit::getAllOperators();
  for (const auto& full_op : nonDispatcherOperators) {
    auto op = full_op->schema();
    auto op_name = op.name();
    if (!op.overload_name().empty()) {
      op_name += ("." + op.overload_name());
    }
    result.emplace(op_name, op);
  }

  // Grab the dispatcher operators
  auto dispatcherOperators = c10::Dispatcher::singleton().getAllOpNames();
  for (auto& op : dispatcherOperators) {
    // grab schema
    const auto op_handle = c10::Dispatcher::singleton().findOp(op);
    if (op_handle->hasSchema()) {
      auto op_name = op.name;
      if (!op.overload_name.empty()) {
        op_name += ("." + op.overload_name);
      }
      result.emplace(op_name, op_handle->schema());
    }
  }

  return result;
}

/**
 * For the vast majority of use cases, the instrumentation in getCustomClass
 * will catch any custom classes referenced by a model. There are, however,
 * niche situations that avoid the getCustomClass instrumentation due to some
 * nuances of mobile model deserialization. To get around that we can search
 * through all the used ops and inspect their schemas for any referenced
 * classes.
 * Example schema: prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None,
 * Scalar? output_min=None, Scalar? output_max=None) ->
 * __torch__.torch.classes.xnnpack.LinearOpContext
 */
static void recordCustomClassesFromOpSchemas(
    std::set<std::string>& root_ops,
    std::set<std::string>& traced_ops,
    std::set<std::string>& loaded_classes) {
  std::set<std::string> ops;
  ops.insert(root_ops.begin(), root_ops.end());
  ops.insert(traced_ops.begin(), traced_ops.end());
  auto ops_and_schemas = _get_runtime_ops_and_schema();

  auto record_if_class = [&](const std::string& type_name) {
    // All custom class types start with __torch__ (it is unclear whether this
    // is by chance or guaranteed).
    if (type_name.find("__torch__") != std::string::npos) {
      // The name of a custom class type here is its fully qualified name, but
      // registration only uses the class name, so only record that.
      auto class_name = type_name.substr(type_name.find_last_of('.') + 1);
      // Function schemas can include other type indicators such as [], so we
      // also need to trim to just alphanumeric + '_' characters.
      class_name = class_name.substr(
          0,
          class_name.find_first_not_of(
              "aAbBcCdDeEfFgGhHiIjJkKlLmMnNoOpPqQrRsStTuUvVwWxXyYzZ_1234567890"));
      loaded_classes.insert(class_name);
    }
  };

  for (auto& op_name : ops) {
    // This check is only necessary because of GPU models. Certain models can
    // only run on a specific backend, say Metal. Those ops will be present in
    // the model's root ops, but likely not in the tracer running on Linux.
    if (ops_and_schemas.find(op_name) != ops_and_schemas.end()) {
      auto& schema = ops_and_schemas.at(op_name);
      for (auto& arg : schema.arguments()) {
        record_if_class(arg.type()->annotation_str());
      }
      for (auto& ret : schema.returns()) {
        record_if_class(ret.type()->annotation_str());
      }
    }
  }
}

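/**
 * Load a single module, determine whether it targets the CPU or the Metal GPU
 * backend, and execute its bundled inputs so that the tracers can observe the
 * operators, kernel dtypes, and custom classes it exercises.
 */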
static void run_model(
    const std::string& input_module_path,
    std::set<std::string>& root_ops,
    std::set<std::string>& enabled_backends,
    KernelDTypeTracer::kernel_tags_type& called_kernel_tags) {
  // Load the module on CPU with the flag to skip the operator exists check.
  // This is needed so that we can load any TorchBind objects (custom classes)
  // that this model refers to so that any operators being called from those
  // TorchBind objects can be traced by the model tracer.
  torch::jit::mobile::MobileModelRunner module_runner(input_module_path, 0);
  root_ops = module_runner.get_root_operators();
  std::cout << "Got " << root_ops.size() << " Root Operators." << '\n';

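  // Metal GPU models cannot be executed on the tracer host, so when one is
  // detected from its root operator set we only record the known Metal
  // operator list and consume the bundled inputs; CPU models are actually run
  // with their bundled inputs below.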
  if (torch::jit::mobile::MobileModelRunner::set_has_metal_gpu_operators(
          root_ops)) {
    std::cout << "Inferred Metal GPU Model." << '\n';
    root_ops.insert(gpu_metal_operators.begin(), gpu_metal_operators.end());
    called_kernel_tags["__unused__"] = {"Float"};
    enabled_backends.insert("Metal GPU");

    // When we encounter a GPU model, we should call .cpu().copy_() on the
    // tensors in the bundled inputs, since this is what will happen when
    // such a model is executed on an iOS device (to copy the Tensor to Metal
    // memory via a call to .metal()).
    module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);
  } else {
    std::cout << "Inferred CPU Model." << '\n';
    enabled_backends.insert("CPU");
    torch::jit::mobile::MobileModelRunner mobile_module_runner(
        input_module_path);

    // When we encounter a CPU model, we should call .cpu().copy_() on the
    // tensors in the bundled inputs, since this is what will happen when
    // such a model is executed on an Android device since the PyTorch JNI
    // bindings call .cpu() in JIValue::newJIValueFromAtIValue().
    module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);

    // If a user bundled inputs after that API was updated to accept bundled
    // inputs for multiple methods, they should go down this route. Even if
    // they only bundled inputs for forward, they will have the new-style
    // bundled inputs. Since at this time in tracer.cpp we do not know which
    // functions have bundled inputs, we must call
    // get_bundled_inputs_functions_and_info, if it exists, to get the set.
    if (mobile_module_runner.has_new_style_bundled_inputs()) {
      auto bundled_inputs_mapping =
          mobile_module_runner.get_many_functions_bundled_inputs();
      for (auto& entry : bundled_inputs_mapping) {
        std::string function_name = entry.first;
        std::vector<std::vector<at::IValue>> bundled_inputs = entry.second;
        std::cout << "Got " << bundled_inputs.size() << " bundled input(s) for "
                  << function_name << "\n\n";
        std::vector<at::IValue> results =
            mobile_module_runner.run_with_inputs(function_name, bundled_inputs);

        for (auto& result : results) {
          // Consume the result Tensor(s) when tracing on CPU since the
          // Android/Java JNI bindings will do the same.
          torch::jit::mobile::for_each_tensor_in_ivalue(result, consume_tensor);
        }
      }
      // If get_bundled_inputs_functions_and_info does not exist, we default
      // to assuming the inputs were bundled before that change was made. If
      // no bundled inputs are found here either, an error will be thrown.
    } else {
      std::vector<std::vector<at::IValue>> bundled_inputs =
          mobile_module_runner.get_all_bundled_inputs();
      std::cout << "Got " << bundled_inputs.size() << " bundled input(s)\n\n";
      std::vector<at::IValue> results =
          mobile_module_runner.run_with_inputs(bundled_inputs);

      for (auto& result : results) {
        // Consume the result Tensor(s) when tracing on CPU since the
        // Android/Java JNI bindings will do the same.
        torch::jit::mobile::for_each_tensor_in_ivalue(result, consume_tensor);
      }
    }
  }
}

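/**
 * Convenience overload: trace a single module path.
 */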
TracerResult trace_run(const std::string& input_module_path) {
  return trace_run(std::vector<std::string>(1, input_module_path));
}

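/**
 * Trace every module in input_module_paths: install the tracers, run the
 * setup and backend-dependent methods, execute each model's bundled inputs
 * under QNNPACK, FBGEMM, and InferenceMode, and collect the observed root
 * operators, traced operators, kernel dtypes, custom classes, and build
 * features into a TracerResult.
 */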
TracerResult trace_run(const std::vector<std::string>& input_module_paths) {
  at::globalContext().setQEngine(at::QEngine::QNNPACK);
  c10::ObservedOperators::getUnobservedOperatorList().clear();

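  // Constructing the tracers is what arms them: each one registers the
  // callbacks it needs so that operator calls, kernel dtypes, custom class
  // loads, and build features are recorded while the models run below.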
  torch::jit::mobile::OperatorCallTracer op_tracer;
  torch::jit::mobile::KernelDTypeTracer kdtype_tracer;
  torch::jit::mobile::CustomClassTracer custom_class_tracer;
  torch::jit::mobile::BuildFeatureTracer build_feature_tracer;

  call_setup_methods();

  std::set<std::string> root_ops, traced_operators, enabled_backends,
      loaded_classes, build_features;
  torch::jit::mobile::KernelDTypeTracer::kernel_tags_type called_kernel_tags;

  using torch::jit::MobileModuleLoadOptions;

  for (auto& input_module_path : input_module_paths) {
    // run with QNNPACK
    at::globalContext().setQEngine(at::QEngine::QNNPACK);

    run_model(
        input_module_path, root_ops, enabled_backends, called_kernel_tags);
    // Not every model can be successfully run with FBGEMM, but for those that
    // can, this helps broaden the tracer's coverage beyond the hyper-optimized
    // QNNPACK paths.
    try {
      at::globalContext().setQEngine(at::QEngine::FBGEMM);
      run_model(
          input_module_path, root_ops, enabled_backends, called_kernel_tags);
    } catch (std::exception& ex) {
      std::cerr
          << "ModelTracer encountered an error while attempting to run the model in FBGEMM mode: "
          << ex.what() << "\n Skipping FBGEMM execution" << '\n';
    }
    try {
      at::globalContext().setQEngine(at::QEngine::QNNPACK);
      c10::InferenceMode guard(true);
      run_model(
          input_module_path, root_ops, enabled_backends, called_kernel_tags);
    } catch (std::exception& ex) {
      std::cerr
          << "ModelTracer encountered an error while attempting to run the model under an inference guard: "
          << ex.what() << "\n Skipping inference guard execution" << '\n';
    }
  }

  call_dependent_methods(root_ops);

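  // Collect what the tracers observed during execution and merge it with the
  // root operators and the custom classes referenced by operator schemas.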
  op_tracer.getCalledOperators().withLock(
      [&](std::set<std::string>& called_operators) {
        traced_operators = called_operators;
      });

  recordCustomClassesFromOpSchemas(root_ops, traced_operators, loaded_classes);

  kdtype_tracer.getCalledKernelTags().withLock(
      [&](KernelDTypeTracer::kernel_tags_type& kernel_tags) {
        called_kernel_tags.insert(kernel_tags.begin(), kernel_tags.end());
      });

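  // Some operators are needed by the runtime itself and are always retained,
  // regardless of whether the trace happened to observe them.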
  traced_operators.insert(
      always_included_traced_ops.begin(), always_included_traced_ops.end());

  custom_class_tracer.getLoadedClasses().withLock(
      [&](CustomClassTracer::custom_classes_type& custom_classes) {
        loaded_classes.insert(custom_classes.begin(), custom_classes.end());
      });

  build_feature_tracer.getBuildFeatures().withLock(
      [&](BuildFeatureTracer::build_feature_type& bf) {
        build_features.insert(bf.begin(), bf.end());
      });

  TracerResult tracer_result = {
      root_ops,
      traced_operators,
      called_kernel_tags,
      loaded_classes,
      build_features,
      enabled_backends};

  return tracer_result;
}

} // namespace torch::jit::mobile