1 #include <iostream>
2 #include <sstream>
3 #include <string>
4
5 /**
6 * The tracer.cpp generates a binary that accepts multiple Torch Mobile Model(s)
7 * (with bytecode.pkl), each of which has at least 1 bundled
8 * input. This binary then feeds the bundled input(s) into each corresponding
9 * model and executes it using the lite interpreter.
10 *
11 * Both root operators as well as called operators are recorded and saved
12 * into a YAML file (whose path is provided on the command line).
13 *
14 * Note: Root operators may include primary and other operators that
15 * are not invoked using the dispatcher, and hence they may not show
16 * up in the Traced Operator list.
17 *
18 */
19
20 #include <ATen/core/dispatch/ObservedOperators.h>
21 #include <torch/csrc/autograd/grad_mode.h>
22 #include <torch/csrc/jit/mobile/import.h>
23 #include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
24 #include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
25 #include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
26 #include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
27 #include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
28 #include <torch/csrc/jit/mobile/module.h>
29 #include <torch/csrc/jit/mobile/parse_operators.h>
30 #include <torch/script.h>
31
// Alias for a string-keyed map whose values are sets of strings.
// NOTE(review): nothing visible in this file references kt_type; the name
// suggests "kernel tags" — confirm against other users before removing it.
using kt_type = std::map<std::string, std::set<std::string>>;
33
// Command line flag --model_input_path, read below as FLAGS_model_input_path.
// main() splits its value on ',' and traces every resulting path; main()
// rejects an empty value via REQUIRE_STRING_ARG. The second macro argument
// ("") is presumably the flag's default value — confirm against c10 flags.
C10_DEFINE_string(
    model_input_path,
    "",
    "A comma separated list of path(s) to the input model file(s) (.ptl).");

// Command line flag --build_yaml_path, read below as FLAGS_build_yaml_path.
// main() opens this path as the output stream that receives the generated
// YAML; main() rejects an empty value via REQUIRE_STRING_ARG.
C10_DEFINE_string(
    build_yaml_path,
    "",
    "The path of the output YAML file containing traced operator information.");
43
// Validates that the string flag --<name> was supplied. When FLAGS_<name> is
// empty, a message is printed to stderr and the ENCLOSING function returns 1,
// so this may only be used inside a function returning int (e.g. main).
//
// The body is wrapped in do { } while (0) so that the expansion is exactly one
// statement: it composes safely with an unbraced if/else at the call site and
// requires the usual trailing semicolon.
#define REQUIRE_STRING_ARG(name)                              \
  do {                                                        \
    if (FLAGS_##name.empty()) {                               \
      std::cerr << "You must specify the flag --" #name "\n"; \
      return 1;                                               \
    }                                                         \
  } while (0)
49
// Validates that the integer flag --<name> was supplied. The value -1 is
// treated as "not specified": in that case a message is printed to stderr and
// the ENCLOSING function returns 1, so this may only be used inside a function
// returning int (e.g. main).
//
// The body is wrapped in do { } while (0) so that the expansion is exactly one
// statement: it composes safely with an unbraced if/else at the call site and
// requires the usual trailing semicolon.
#define REQUIRE_INT_ARG(name)                                 \
  do {                                                        \
    if (FLAGS_##name == -1) {                                 \
      std::cerr << "You must specify the flag --" #name "\n"; \
      return 1;                                               \
    }                                                         \
  } while (0)
55
/**
 * Writes a single operator entry to `out`.
 *
 * The operator name is emitted as a mapping key ("<op_name>:") preceded by
 * `indent` spaces; the three boolean attributes follow, one per line, each
 * preceded by `indent + 2` spaces and rendered as "true" / "false".
 */
static void printOpYAML(
    std::ostream& out,
    int indent,
    const std::string& op_name,
    bool is_used_for_training,
    bool is_root_operator,
    bool include_all_overloads) {
  const std::string key_pad(indent, ' ');
  const std::string field_pad(indent + 2, ' ');
  const auto as_text = [](bool flag) { return flag ? "true" : "false"; };

  out << key_pad << op_name << ":" << '\n';
  out << field_pad << "is_used_for_training: " << as_text(is_used_for_training)
      << '\n';
  out << field_pad << "is_root_operator: " << as_text(is_root_operator)
      << '\n';
  out << field_pad
      << "include_all_overloads: " << as_text(include_all_overloads) << '\n';
}
73
printOpsYAML(std::ostream & out,const std::set<std::string> & operator_list,bool is_used_for_training,bool is_root_operator,bool include_all_overloads)74 static void printOpsYAML(
75 std::ostream& out,
76 const std::set<std::string>& operator_list,
77 bool is_used_for_training,
78 bool is_root_operator,
79 bool include_all_overloads) {
80 for (auto& it : operator_list) {
81 printOpYAML(out, 2, it, false, is_root_operator, false);
82 }
83 }
84
/**
 * Writes one kernel tag and its dtypes to `out`.
 *
 * The tag is emitted as a mapping key ("<kernel_tag_name>:") preceded by
 * `indent` spaces. Each dtype follows on its own line as a "- <dtype>" list
 * item using the SAME padding as the key (the items are not indented further).
 */
static void printDTypeYAML(
    std::ostream& out,
    int indent,
    const std::string& kernel_tag_name,
    const std::set<std::string>& dtypes) {
  const std::string pad(indent, ' ');

  out << pad << kernel_tag_name << ":" << '\n';
  for (auto dtype_it = dtypes.begin(); dtype_it != dtypes.end(); ++dtype_it) {
    out << pad << "- " << *dtype_it << '\n';
  }
}
96
printDTypesYAML(std::ostream & out,const torch::jit::mobile::KernelDTypeTracer::kernel_tags_type & kernel_tags)97 static void printDTypesYAML(
98 std::ostream& out,
99 const torch::jit::mobile::KernelDTypeTracer::kernel_tags_type&
100 kernel_tags) {
101 for (auto& it : kernel_tags) {
102 printDTypeYAML(out, 2, it.first, it.second);
103 }
104 }
105
printCustomClassesYAML(std::ostream & out,const torch::jit::mobile::CustomClassTracer::custom_classes_type & loaded_classes)106 static void printCustomClassesYAML(
107 std::ostream& out,
108 const torch::jit::mobile::CustomClassTracer::custom_classes_type&
109 loaded_classes) {
110 for (auto& class_name : loaded_classes) {
111 out << "- " << class_name << '\n';
112 }
113 }
114
115 /**
116 * Runs multiple PyTorch lite interpreter models, and additionally writes
117 * out a list of root and called operators, kernel dtypes, and loaded/used
118 * TorchBind custom classes.
119 */
main(int argc,char * argv[])120 int main(int argc, char* argv[]) {
121 if (!c10::ParseCommandLineFlags(&argc, &argv)) {
122 std::cerr << "Failed to parse command line flags!" << '\n';
123 return 1;
124 }
125
126 REQUIRE_STRING_ARG(model_input_path);
127 REQUIRE_STRING_ARG(build_yaml_path);
128
129 std::istringstream sin(FLAGS_model_input_path);
130 std::ofstream yaml_out(FLAGS_build_yaml_path);
131
132 std::cout << "Output: " << FLAGS_build_yaml_path << '\n';
133 torch::jit::mobile::TracerResult tracer_result;
134 std::vector<std::string> model_input_paths;
135
136 for (std::string model_input_path;
137 std::getline(sin, model_input_path, ',');) {
138 std::cout << "Processing: " << model_input_path << '\n';
139 model_input_paths.push_back(model_input_path);
140 }
141
142 try {
143 tracer_result = torch::jit::mobile::trace_run(model_input_paths);
144 } catch (std::exception& ex) {
145 std::cerr
146 << "ModelTracer has not been able to load the module for the following reasons:\n"
147 << ex.what()
148 << "\nPlease consider opening an issue at https://github.com/pytorch/pytorch/issues "
149 << "with the detailed error message." << '\n';
150
151 throw ex;
152 }
153
154 if (tracer_result.traced_operators.size() <=
155 torch::jit::mobile::always_included_traced_ops.size()) {
156 std::cerr
157 << c10::str(
158 "Error traced_operators size: ",
159 tracer_result.traced_operators.size(),
160 ". Expected the traced operator list to be bigger then the default size ",
161 torch::jit::mobile::always_included_traced_ops.size(),
162 ". Please report a bug in PyTorch.")
163 << '\n';
164 }
165
166 // If the op exist in both traced_ops and root_ops, leave it in root_ops only
167 for (const auto& root_op : tracer_result.root_ops) {
168 if (tracer_result.traced_operators.find(root_op) !=
169 tracer_result.traced_operators.end()) {
170 tracer_result.traced_operators.erase(root_op);
171 }
172 }
173
174 yaml_out << "include_all_non_op_selectives: false" << '\n';
175 yaml_out << "build_features: []" << '\n';
176 yaml_out << "operators:" << '\n';
177 printOpsYAML(
178 yaml_out,
179 tracer_result.root_ops,
180 false /* is_used_for_training */,
181 true /* is_root_operator */,
182 false /* include_all_overloads */);
183 printOpsYAML(
184 yaml_out,
185 tracer_result.traced_operators,
186 false /* is_used_for_training */,
187 false /* is_root_operator */,
188 false /* include_all_overloads */);
189
190 yaml_out << "kernel_metadata:";
191 if (tracer_result.called_kernel_tags.empty()) {
192 yaml_out << " []";
193 }
194 yaml_out << '\n';
195 printDTypesYAML(yaml_out, tracer_result.called_kernel_tags);
196
197 yaml_out << "custom_classes:";
198 if (tracer_result.loaded_classes.empty()) {
199 yaml_out << " []";
200 }
201 yaml_out << '\n';
202 printCustomClassesYAML(yaml_out, tracer_result.loaded_classes);
203
204 return 0;
205 }
206