#include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>

namespace torch::jit::mobile {
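// Converts the IValue returned by get_all_bundled_inputs (a List of input
// Tuples) into a vector of argument vectors, one per bundled input.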
std::vector<std::vector<at::IValue>> MobileModelRunner::
    ivalue_to_bundled_inputs(const c10::IValue& bundled_inputs) {
  CAFFE_ENFORCE(
      bundled_inputs.isList(),
      "Expected get_all_bundled_inputs to ",
      "return a list but got a ",
      bundled_inputs.tagKind(),
      " instead");

  c10::List<at::IValue> all_inputs = bundled_inputs.toList();
  CAFFE_ENFORCE(
      !all_inputs.empty(),
      "Expected at least 1 bundled input, ",
      "but found none. Please use ",
      "torch.utils.bundled_inputs.augment_model_with_bundled_inputs to add.");

  std::vector<std::vector<at::IValue>> ret;
  for (at::IValue input : all_inputs) {
    CAFFE_ENFORCE(
        input.isTuple(),
        "Expected list element to be a tuple ",
        "but got a ",
        input.tagKind(),
        " instead");
    ret.push_back(input.toTupleRef().elements());
  }

  return ret;
}

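// Converts the IValue returned by get_bundled_inputs_functions_and_info (a
// Dict mapping function names to info Dicts) into a map from each function
// name to the name of the method that returns that function's bundled inputs.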
std::unordered_map<std::string, std::string> MobileModelRunner::
    ivalue_to_bundled_inputs_map(const c10::IValue& bundled_inputs) {
  CAFFE_ENFORCE(
      bundled_inputs.isGenericDict(),
      "Expected get_bundled_inputs_functions_and_info to ",
      "return a dict but got a ",
      bundled_inputs.tagKind(),
      " instead");

  c10::Dict<at::IValue, at::IValue> all_inputs = bundled_inputs.toGenericDict();
  CAFFE_ENFORCE(
      !all_inputs.empty(),
      "Expected at least 1 function with bundled inputs, ",
      "but found none. Please use ",
      "torch.utils.bundled_inputs.augment_model_with_bundled_inputs to add.");

  std::unordered_map<std::string, std::string> ret;
  for (auto& input : all_inputs) {
    const at::IValue& function_name = input.key();
    const at::IValue& nested_dict = input.value();
    CAFFE_ENFORCE(
        function_name.isString(),
        "Expected function with inputs to be a string ",
        "but got a ",
        function_name.tagKind(),
        " instead");
    CAFFE_ENFORCE(
        nested_dict.isGenericDict(),
        "Expected function name to map to dictionary ",
        "but got a ",
        nested_dict.tagKind(),
        " instead");

    // Got the nested dict; now convert it into std types.
    c10::Dict<at::IValue, at::IValue> function_and_info_ival_dict =
        nested_dict.toGenericDict();
    std::unordered_map<std::string, std::vector<std::string>>
        function_and_info_dict;
    for (auto& entry : function_and_info_ival_dict) {
      const at::IValue& key = entry.key();
      const at::IValue& value = entry.value();
      CAFFE_ENFORCE(
          key.isString(),
          "Expected extra information key to be a string ",
          "but got a ",
          key.tagKind(),
          " instead");
      CAFFE_ENFORCE(
          value.isList(),
          "Expected extra information values to be a list ",
          "but got a ",
          value.tagKind(),
          " instead");

      // Got the value of the nested dict entry; now convert it to std types.
      std::vector<std::string> data_list;
      c10::List<at::IValue> ival_data = value.toList();
      for (at::IValue data : ival_data) {
        CAFFE_ENFORCE(
            data.isString(),
            "Expected list element of nested dict entries to be a string ",
            "but got a ",
            data.tagKind(),
            " instead");
        data_list.push_back(data.toStringRef());
      }

      // Add entry into std type mapping
      function_and_info_dict[key.toStringRef()] = data_list;
    }

    // Could store the full mapping of std types, but the 'info' section isn't
    // needed here
    std::string input_function =
        function_and_info_dict["get_inputs_function_name"][0];
    ret[function_name.toStringRef()] = input_function;
  }

  return ret;
}

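// Returns the bundled inputs attached to the module's forward method; fails
// if the model was not augmented with bundled inputs.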
std::vector<std::vector<at::IValue>> MobileModelRunner::
    get_all_bundled_inputs() {
  auto has_bundled_input = module_->find_method("get_all_bundled_inputs");
  CAFFE_ENFORCE(
      has_bundled_input,
      "Model does not have bundled inputs. ",
      "Use torch.utils.bundled_inputs.augment_model_with_bundled_inputs to add.");

  c10::IValue bundled_inputs = module_->run_method("get_all_bundled_inputs");
  return ivalue_to_bundled_inputs(bundled_inputs);
}

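// Returns the bundled inputs for every function that has them, keyed by
// function name (new-style bundled inputs).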
std::unordered_map<std::string, std::vector<std::vector<at::IValue>>>
MobileModelRunner::get_many_functions_bundled_inputs() {
  auto has_bundled_input =
      module_->find_method("get_bundled_inputs_functions_and_info");
  CAFFE_ENFORCE(
      has_bundled_input,
      "Model does not have bundled inputs. ",
      "Use torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs to add.");

  auto ival_bundled_inputs_mapping =
      module_->run_method("get_bundled_inputs_functions_and_info");
  auto bundled_inputs_mapping =
      ivalue_to_bundled_inputs_map(ival_bundled_inputs_mapping);

  std::unordered_map<std::string, std::vector<std::vector<at::IValue>>> ret;

  for (auto& entry : bundled_inputs_mapping) {
    std::string function_name = entry.first;
    std::string function_to_call = entry.second;

    auto has_func_to_call = module_->find_method(function_to_call);
    CAFFE_ENFORCE(
        has_func_to_call,
        "Model does not have ",
        function_to_call,
        ". Use torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs to add.");

    c10::IValue bundled_inputs = module_->run_method(function_to_call);
    ret[function_name] = ivalue_to_bundled_inputs(bundled_inputs);
  }
  return ret;
}

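// Runs the module's forward method once per bundled input and returns the
// outputs in order.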
std::vector<at::IValue> MobileModelRunner::run_with_inputs(
    std::vector<std::vector<at::IValue>> const& bundled_inputs) {
  std::vector<at::IValue> ret;
  ret.reserve(bundled_inputs.size());
  for (std::vector<at::IValue> const& input : bundled_inputs) {
    ret.emplace_back(module_->forward(input));
  }
  return ret;
}

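// Runs the named method once per bundled input and returns the outputs in
// order; fails if the model does not define that method.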
std::vector<at::IValue> MobileModelRunner::run_with_inputs(
    const std::string& function_name,
    std::vector<std::vector<at::IValue>> const& bundled_inputs) const {
  std::vector<at::IValue> ret;
  ret.reserve(bundled_inputs.size());
  auto has_bundled_input = module_->find_method(function_name);
  CAFFE_ENFORCE(
      has_bundled_input,
      "Model does not have the method named ",
      function_name,
      ". Please ensure that it was exported correctly.");
  for (std::vector<at::IValue> const& input : bundled_inputs) {
    auto func = module_->get_method(function_name);
    ret.emplace_back(func(input));
  }
  return ret;
}

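// Runs each listed method that takes no arguments, silently skipping any
// method the model does not define.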
void MobileModelRunner::run_argless_functions(
    const std::vector<std::string>& functions) {
  for (auto& function_name : functions) {
    if (module_->find_method(function_name)) {
      module_->run_method(function_name);
    }
  }
}

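// Returns true if any operator in op_list belongs to one of the Metal GPU
// namespaces (metal::, metal_prepack::, metal_prepack_unet::).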
bool MobileModelRunner::set_has_metal_gpu_operators(
    std::set<std::string> const& op_list) {
  for (std::string const& op : op_list) {
    if (op.find("metal::") == 0 || op.find("metal_prepack::") == 0 ||
        op.find("metal_prepack_unet::") == 0) {
      return true;
    }
  }
  return false;
}

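// Applies func to every tensor contained in the module's bundled inputs,
// handling both new-style (per-function) and old-style bundled inputs.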
void MobileModelRunner::for_each_tensor_in_bundled_inputs(
    std::function<void(const ::at::Tensor&)> const& func) {
  if (has_new_style_bundled_inputs()) {
    // Get the bundled inputs and access the arg level ivalues stored within
    auto bundled_inputs_mapping = this->get_many_functions_bundled_inputs();

    // Loop over functions
    for (auto& entry : bundled_inputs_mapping) {
      std::vector<std::vector<at::IValue>> bundled_inputs = entry.second;
      // Loop through inputs
      for (const std::vector<at::IValue>& input : bundled_inputs) {
        // Loop through values in an input
        for (const at::IValue& iv : input) {
          for_each_tensor_in_ivalue(iv, func);
        }
      }
    }
  } else {
    c10::IValue iv = module_->run_method("get_all_bundled_inputs");
    for_each_tensor_in_ivalue(iv, func);
  }
}
} // namespace torch::jit::mobile