xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/model_tracer/TracerRunner.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <set>
4 #include <string>
5 #include <vector>
6 
7 #include <ATen/core/ivalue.h>
8 #include <torch/csrc/jit/mobile/model_tracer/BuildFeatureTracer.h>
9 #include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
10 #include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
11 
12 namespace torch::jit::mobile {
13 
/**
 * Operator names that the tracer treats as always traced.
 *
 * Per the note below they are invoked from setup sections; presumably that
 * means they may not be observed by the tracing run itself — TODO confirm
 * how TracerRunner.cpp merges them into the result.
 *
 * `inline` (C++17) makes this a single object with external linkage. A plain
 * namespace-scope `const` variable in a header has internal linkage, so every
 * translation unit including this header would otherwise build its own copy.
 */
inline const std::vector<std::string> always_included_traced_ops = {
    // The following are called from setup sections.
    "aten::resize_",
    "aten::slice.Tensor",
};
19 
20 struct TracerResult {
21   std::set<std::string> root_ops;
22   std::set<std::string> traced_operators;
23   KernelDTypeTracer::kernel_tags_type called_kernel_tags;
24   CustomClassTracer::custom_classes_type loaded_classes;
25   BuildFeatureTracer::build_feature_type build_features;
26   std::set<std::string> enabled_backends;
27 };
28 
29 /**
30  * Trace a single model and return the TracerResult.
31  */
32 TracerResult trace_run(const std::string& input_module_path);
33 
34 /**
35  * Trace multiple models and return the TracerResult.
36  */
37 TracerResult trace_run(const std::vector<std::string>& input_module_paths);
38 
39 } // namespace torch::jit::mobile
40