1 #include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
2 #include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
3
4 namespace torch::jit::mobile {
5
6 std::vector<std::vector<at::IValue>> MobileModelRunner::
ivalue_to_bundled_inputs(const c10::IValue & bundled_inputs)7 ivalue_to_bundled_inputs(const c10::IValue& bundled_inputs) {
8 CAFFE_ENFORCE(
9 bundled_inputs.isList(),
10 "Expected get_all_bundled_inputs to ",
11 "return a list but got a ",
12 bundled_inputs.tagKind(),
13 " instead");
14
15 c10::List<at::IValue> all_inputs = bundled_inputs.toList();
16 CAFFE_ENFORCE(
17 !all_inputs.empty(),
18 "Expected at least 1 bundled input, ",
19 "but found none. Please use ",
20 "torch.utils.bundled_inputs.augment_model_with_bundled_inputs to add.");
21
22 std::vector<std::vector<at::IValue>> ret;
23 for (at::IValue input : all_inputs) {
24 CAFFE_ENFORCE(
25 input.isTuple(),
26 "Expected list element to be a tuple ",
27 "but got a ",
28 input.tagKind(),
29 " instead");
30 ret.push_back(input.toTupleRef().elements());
31 }
32
33 return ret;
34 }
35
36 std::unordered_map<std::string, std::string> MobileModelRunner::
ivalue_to_bundled_inputs_map(const c10::IValue & bundled_inputs)37 ivalue_to_bundled_inputs_map(const c10::IValue& bundled_inputs) {
38 CAFFE_ENFORCE(
39 bundled_inputs.isGenericDict(),
40 "Expected get_bundled_inputs_functions_and_info to ",
41 "return a dict but got a ",
42 bundled_inputs.tagKind(),
43 " instead");
44
45 c10::Dict<at::IValue, at::IValue> all_inputs = bundled_inputs.toGenericDict();
46 CAFFE_ENFORCE(
47 !all_inputs.empty(),
48 "Expected at least 1 function with bundled inputs, ",
49 "but found none. Please use ",
50 "torch.utils.bundled_inputs.augment_model_with_bundled_inputs to add.");
51
52 std::unordered_map<std::string, std::string> ret;
53 for (auto& input : all_inputs) {
54 const at::IValue& function_name = input.key();
55 const at::IValue& nested_dict = input.value();
56 CAFFE_ENFORCE(
57 function_name.isString(),
58 "Expected function with inputs to be a string ",
59 "but got a ",
60 function_name.tagKind(),
61 " instead");
62 CAFFE_ENFORCE(
63 nested_dict.isGenericDict(),
64 "Expected function name to map to dictionary ",
65 "but got a ",
66 nested_dict.tagKind(),
67 " instead");
68
69 // Got the nested dict now need to convert that into std types
70 c10::Dict<at::IValue, at::IValue> function_and_info_ival_dict =
71 nested_dict.toGenericDict();
72 std::unordered_map<std::string, std::vector<std::string>>
73 function_and_info_dict;
74 for (auto& entry : function_and_info_ival_dict) {
75 const at::IValue& key = entry.key();
76 const at::IValue& value = entry.value();
77 CAFFE_ENFORCE(
78 key.isString(),
79 "Expected extra information key to be a string ",
80 "but got a ",
81 value.tagKind(),
82 " instead");
83 CAFFE_ENFORCE(
84 value.isList(),
85 "Expected extra information values to be a list ",
86 "but got a ",
87 value.tagKind(),
88 " instead");
89
90 // Got the value of the nested dict entry now need to convert it to std
91 // types
92 std::vector<std::string> data_list;
93 c10::List<at::IValue> ival_data = value.toList();
94 for (at::IValue data : ival_data) {
95 CAFFE_ENFORCE(
96 data.isString(),
97 "Expected list element of nested dict entries to be a string ",
98 "but got a ",
99 data.tagKind(),
100 " instead");
101 data_list.push_back(data.toStringRef());
102 }
103
104 // Add entry into std type mapping
105 function_and_info_dict[key.toStringRef()] = data_list;
106 }
107
108 // Could store the full mapping of std types, but the 'info' section isnt
109 // needed here
110 std::string input_function =
111 function_and_info_dict["get_inputs_function_name"][0];
112 ret[function_name.toStringRef()] = input_function;
113 }
114
115 return ret;
116 }
117
118 std::vector<std::vector<at::IValue>> MobileModelRunner::
get_all_bundled_inputs()119 get_all_bundled_inputs() {
120 auto has_bundled_input = module_->find_method("get_all_bundled_inputs");
121 CAFFE_ENFORCE(
122 has_bundled_input,
123 "Model does not have bundled inputs. ",
124 "Use torch.utils.bundled_inputs.augment_model_with_bundled_inputs to add.");
125
126 c10::IValue bundled_inputs = module_->run_method("get_all_bundled_inputs");
127 return ivalue_to_bundled_inputs(bundled_inputs);
128 }
129
130 std::unordered_map<std::string, std::vector<std::vector<at::IValue>>>
get_many_functions_bundled_inputs()131 MobileModelRunner::get_many_functions_bundled_inputs() {
132 auto has_bundled_input =
133 module_->find_method("get_bundled_inputs_functions_and_info");
134 CAFFE_ENFORCE(
135 has_bundled_input,
136 "Model does not have bundled inputs. ",
137 "Use torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs to add.");
138
139 auto ival_bundled_inputs_mapping =
140 module_->run_method("get_bundled_inputs_functions_and_info");
141 auto bundled_inputs_mapping =
142 ivalue_to_bundled_inputs_map(ival_bundled_inputs_mapping);
143
144 std::unordered_map<std::string, std::vector<std::vector<at::IValue>>> ret;
145
146 for (auto& entry : bundled_inputs_mapping) {
147 std::string function_name = entry.first;
148 std::string function_to_call = entry.second;
149
150 auto has_func_to_call = module_->find_method(function_to_call);
151 CAFFE_ENFORCE(
152 has_func_to_call,
153 "Model does not have ",
154 function_to_call,
155 "Use torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs to add.");
156
157 c10::IValue bundled_inputs = module_->run_method(function_to_call);
158 ret[function_name] = ivalue_to_bundled_inputs(bundled_inputs);
159 }
160 return ret;
161 }
162
run_with_inputs(std::vector<std::vector<at::IValue>> const & bundled_inputs)163 std::vector<at::IValue> MobileModelRunner::run_with_inputs(
164 std::vector<std::vector<at::IValue>> const& bundled_inputs) {
165 std::vector<at::IValue> ret;
166 ret.reserve(bundled_inputs.size());
167 for (std::vector<at::IValue> const& input : bundled_inputs) {
168 ret.emplace_back(module_->forward(input));
169 }
170 return ret;
171 }
172
run_with_inputs(const std::string & function_name,std::vector<std::vector<at::IValue>> const & bundled_inputs) const173 std::vector<at::IValue> MobileModelRunner::run_with_inputs(
174 const std::string& function_name,
175 std::vector<std::vector<at::IValue>> const& bundled_inputs) const {
176 std::vector<at::IValue> ret;
177 ret.reserve(bundled_inputs.size());
178 auto has_bundled_input = module_->find_method(function_name);
179 CAFFE_ENFORCE(
180 has_bundled_input,
181 "Model does not have the method named ",
182 function_name,
183 "Please ensure that it was exported correctly");
184 for (std::vector<at::IValue> const& input : bundled_inputs) {
185 auto func = module_->get_method(function_name);
186 ret.emplace_back(func(input));
187 }
188 return ret;
189 }
190
run_argless_functions(const std::vector<std::string> & functions)191 void MobileModelRunner::run_argless_functions(
192 const std::vector<std::string>& functions) {
193 for (auto& function_name : functions) {
194 if (module_->find_method(function_name)) {
195 module_->run_method(function_name);
196 }
197 }
198 }
199
set_has_metal_gpu_operators(std::set<std::string> const & op_list)200 bool MobileModelRunner::set_has_metal_gpu_operators(
201 std::set<std::string> const& op_list) {
202 for (std::string const& op : op_list) {
203 if (op.find("metal::") == 0 || op.find("metal_prepack::") == 0 ||
204 op.find("metal_prepack_unet::") == 0) {
205 return true;
206 }
207 }
208 return false;
209 }
210
for_each_tensor_in_bundled_inputs(std::function<void (const::at::Tensor &)> const & func)211 void MobileModelRunner::for_each_tensor_in_bundled_inputs(
212 std::function<void(const ::at::Tensor&)> const& func) {
213 if (has_new_style_bundled_inputs()) {
214 // Get the bundled inputs and access the arg level ivalues stored within
215 auto bundled_inputs_mapping = this->get_many_functions_bundled_inputs();
216
217 // Loop over functions
218 for (auto& entry : bundled_inputs_mapping) {
219 std::vector<std::vector<at::IValue>> bundled_inputs = entry.second;
220 // Loop through inputs
221 for (const std::vector<at::IValue>& input : bundled_inputs) {
222 // Loop through values in an input
223 for (const at::IValue& iv : input) {
224 for_each_tensor_in_ivalue(iv, func);
225 }
226 }
227 }
228 } else {
229 c10::IValue iv = module_->run_method("get_all_bundled_inputs");
230 for_each_tensor_in_ivalue(iv, func);
231 }
232 }
233 } // namespace torch::jit::mobile
234