xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/model_tracer/tracer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <iostream>
2 #include <sstream>
3 #include <string>
4 
/**
 * The tracer.cpp generates a binary that accepts multiple Torch Mobile
 * Model(s) (with bytecode.pkl), each of which has at least 1 bundled
 * input. This binary then feeds the bundled input(s) into each
 * corresponding model and executes it using the lite interpreter.
 *
 * Both root operators and called operators are recorded and saved
 * into a YAML file (whose path is provided on the command line).
 *
 * Note: Root operators may include primary and other operators that
 * are not invoked via the dispatcher, and hence they may not show
 * up in the Traced Operator list.
 */
19 
#include <fstream>

#include <ATen/core/dispatch/ObservedOperators.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
#include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
#include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/script.h>
31 
32 typedef std::map<std::string, std::set<std::string>> kt_type;
33 
34 C10_DEFINE_string(
35     model_input_path,
36     "",
37     "A comma separated list of path(s) to the input model file(s) (.ptl).");
38 
39 C10_DEFINE_string(
40     build_yaml_path,
41     "",
42     "The path of the output YAML file containing traced operator information.");
43 
44 #define REQUIRE_STRING_ARG(name)                            \
45   if (FLAGS_##name.empty()) {                               \
46     std::cerr << "You must specify the flag --" #name "\n"; \
47     return 1;                                               \
48   }
49 
50 #define REQUIRE_INT_ARG(name)                               \
51   if (FLAGS_##name == -1) {                                 \
52     std::cerr << "You must specify the flag --" #name "\n"; \
53     return 1;                                               \
54   }
55 
// Writes a single operator entry as a YAML mapping: the operator name as the
// key, with its three boolean attributes nested one level (indent + 2) below.
static void printOpYAML(
    std::ostream& out,
    int indent,
    const std::string& op_name,
    bool is_used_for_training,
    bool is_root_operator,
    bool include_all_overloads) {
  const std::string key_pad(indent, ' ');
  const std::string field_pad(indent + 2, ' ');
  const auto yaml_bool = [](bool value) { return value ? "true" : "false"; };

  out << key_pad << op_name << ":\n";
  out << field_pad << "is_used_for_training: " << yaml_bool(is_used_for_training)
      << '\n';
  out << field_pad << "is_root_operator: " << yaml_bool(is_root_operator)
      << '\n';
  out << field_pad
      << "include_all_overloads: " << yaml_bool(include_all_overloads) << '\n';
}
73 
printOpsYAML(std::ostream & out,const std::set<std::string> & operator_list,bool is_used_for_training,bool is_root_operator,bool include_all_overloads)74 static void printOpsYAML(
75     std::ostream& out,
76     const std::set<std::string>& operator_list,
77     bool is_used_for_training,
78     bool is_root_operator,
79     bool include_all_overloads) {
80   for (auto& it : operator_list) {
81     printOpYAML(out, 2, it, false, is_root_operator, false);
82   }
83 }
84 
// Writes one kernel tag as a YAML key followed by its dtypes as a sequence.
// Note: the "- dtype" items are emitted at the same indent as the key, which
// is valid YAML for a nested sequence.
static void printDTypeYAML(
    std::ostream& out,
    int indent,
    const std::string& kernel_tag_name,
    const std::set<std::string>& dtypes) {
  const std::string pad(indent, ' ');
  out << pad << kernel_tag_name << ":\n";
  for (const auto& dtype : dtypes) {
    out << pad << "- " << dtype << '\n';
  }
}
96 
printDTypesYAML(std::ostream & out,const torch::jit::mobile::KernelDTypeTracer::kernel_tags_type & kernel_tags)97 static void printDTypesYAML(
98     std::ostream& out,
99     const torch::jit::mobile::KernelDTypeTracer::kernel_tags_type&
100         kernel_tags) {
101   for (auto& it : kernel_tags) {
102     printDTypeYAML(out, 2, it.first, it.second);
103   }
104 }
105 
printCustomClassesYAML(std::ostream & out,const torch::jit::mobile::CustomClassTracer::custom_classes_type & loaded_classes)106 static void printCustomClassesYAML(
107     std::ostream& out,
108     const torch::jit::mobile::CustomClassTracer::custom_classes_type&
109         loaded_classes) {
110   for (auto& class_name : loaded_classes) {
111     out << "- " << class_name << '\n';
112   }
113 }
114 
115 /**
116  * Runs multiple PyTorch lite interpreter models, and additionally writes
117  * out a list of root and called operators, kernel dtypes, and loaded/used
118  * TorchBind custom classes.
119  */
main(int argc,char * argv[])120 int main(int argc, char* argv[]) {
121   if (!c10::ParseCommandLineFlags(&argc, &argv)) {
122     std::cerr << "Failed to parse command line flags!" << '\n';
123     return 1;
124   }
125 
126   REQUIRE_STRING_ARG(model_input_path);
127   REQUIRE_STRING_ARG(build_yaml_path);
128 
129   std::istringstream sin(FLAGS_model_input_path);
130   std::ofstream yaml_out(FLAGS_build_yaml_path);
131 
132   std::cout << "Output: " << FLAGS_build_yaml_path << '\n';
133   torch::jit::mobile::TracerResult tracer_result;
134   std::vector<std::string> model_input_paths;
135 
136   for (std::string model_input_path;
137        std::getline(sin, model_input_path, ',');) {
138     std::cout << "Processing: " << model_input_path << '\n';
139     model_input_paths.push_back(model_input_path);
140   }
141 
142   try {
143     tracer_result = torch::jit::mobile::trace_run(model_input_paths);
144   } catch (std::exception& ex) {
145     std::cerr
146         << "ModelTracer has not been able to load the module for the following reasons:\n"
147         << ex.what()
148         << "\nPlease consider opening an issue at https://github.com/pytorch/pytorch/issues "
149         << "with the detailed error message." << '\n';
150 
151     throw ex;
152   }
153 
154   if (tracer_result.traced_operators.size() <=
155       torch::jit::mobile::always_included_traced_ops.size()) {
156     std::cerr
157         << c10::str(
158                "Error traced_operators size: ",
159                tracer_result.traced_operators.size(),
160                ". Expected the traced operator list to be bigger then the default size ",
161                torch::jit::mobile::always_included_traced_ops.size(),
162                ". Please report a bug in PyTorch.")
163         << '\n';
164   }
165 
166   // If the op exist in both traced_ops and root_ops, leave it in root_ops only
167   for (const auto& root_op : tracer_result.root_ops) {
168     if (tracer_result.traced_operators.find(root_op) !=
169         tracer_result.traced_operators.end()) {
170       tracer_result.traced_operators.erase(root_op);
171     }
172   }
173 
174   yaml_out << "include_all_non_op_selectives: false" << '\n';
175   yaml_out << "build_features: []" << '\n';
176   yaml_out << "operators:" << '\n';
177   printOpsYAML(
178       yaml_out,
179       tracer_result.root_ops,
180       false /* is_used_for_training */,
181       true /* is_root_operator */,
182       false /* include_all_overloads */);
183   printOpsYAML(
184       yaml_out,
185       tracer_result.traced_operators,
186       false /* is_used_for_training */,
187       false /* is_root_operator */,
188       false /* include_all_overloads */);
189 
190   yaml_out << "kernel_metadata:";
191   if (tracer_result.called_kernel_tags.empty()) {
192     yaml_out << " []";
193   }
194   yaml_out << '\n';
195   printDTypesYAML(yaml_out, tracer_result.called_kernel_tags);
196 
197   yaml_out << "custom_classes:";
198   if (tracer_result.loaded_classes.empty()) {
199     yaml_out << " []";
200   }
201   yaml_out << '\n';
202   printCustomClassesYAML(yaml_out, tracer_result.loaded_classes);
203 
204   return 0;
205 }
206